-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerator-simple.py
More file actions
38 lines (29 loc) · 1.29 KB
/
generator-simple.py
File metadata and controls
38 lines (29 loc) · 1.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import fire
import jsonlines
from progressbar import progressbar
from AdvDecoder import decode
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Tokenizer, GPT2LMHeadModel, RobertaForSequenceClassification
from BatchTextGenerationPipeline import BatchTextGenerationPipeline
from IsFakePipeline import IsFakePipelineHF
def main(file, lines, sequences_per_step=12, sequence_length=64,
         model_name="EleutherAI/gpt-neo-125M", device=0):
    """Generate machine-written text samples and append them to a JSONL file.

    Loads a causal language model, samples unconditional text (empty prompt)
    in batches, and appends one ``{"text": ...}`` record per generated
    sequence to *file*.

    Args:
        file: Path to the output JSON-lines file (opened in append mode, so
            repeated runs accumulate samples).
        lines: Approximate number of sequences to generate. Rounded down to
            a multiple of *sequences_per_step*.
        sequences_per_step: Batch size — sequences sampled per model call.
        sequence_length: Number of tokens to generate per sequence.
        model_name: HF hub identifier of the causal LM to load.
        device: Torch device index to place the model and pipeline on.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    generator = BatchTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device)
    # Append mode: existing samples in `file` are preserved across runs.
    with jsonlines.open(file, mode='a') as writer:
        # Integer division: if lines < sequences_per_step nothing is generated.
        for _ in progressbar(range(lines // sequences_per_step)):
            sequences = generator.generate(
                prompt='',  # unconditional generation
                generate_length=sequence_length,
                num_return_sequences=sequences_per_step,
                do_sample=True,
                top_p=0.99,  # near-unrestricted nucleus sampling
                no_repeat_ngram_size=3
            )
            for text in sequences:
                writer.write({'text': text})
# CLI entry point: python-fire exposes main()'s parameters as command-line flags.
if __name__ == '__main__':
    fire.Fire(main)