-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerator-adv.py
More file actions
49 lines (33 loc) · 1.57 KB
/
generator-adv.py
File metadata and controls
49 lines (33 loc) · 1.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import pickle
import fire
import jsonlines
from progressbar import progressbar
from AdvDecoder import decode
from BatchTextGenerationPipeline import BatchTextGenerationPipeline
from IsFakePipeline import IsFakePipelineHF, IsFakePipelineSklearn
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2LMHeadModel, RobertaForSequenceClassification
def main(file, lines, step=16, sequences_per_step=12, sequence_length=64,
         device=0):
    """Generate text with ``decode`` and append it to a JSON Lines file.

    Loads the pretrained "gpt2" tokenizer/LM as the generator and the
    "roberta-base-openai-detector" tokenizer/classifier as the detector,
    wraps them in the project pipelines, then calls ``decode`` once per
    requested line and writes each result as a ``{'text': ...}`` record.

    Args:
        file: Path of the JSON Lines output file. It is opened in append
            mode, so records already in the file are kept.
        lines: Number of records to generate (one ``decode`` call each).
        step: Forwarded to ``decode`` as ``step``.
        sequences_per_step: Forwarded to ``decode`` as
            ``sequences_per_step``.
        sequence_length: Forwarded to ``decode`` as ``generate_length``.
        device: Device identifier handed unchanged to ``model.to(...)`` and
            to both pipelines. Defaults to ``0``, the value that used to be
            hard-coded, so existing invocations behave exactly as before.
            NOTE(review): only integer CUDA indices are known to work here;
            whether the project pipelines accept other values (e.g. a CPU
            designator) should be confirmed before relying on it.
    """
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    detector_tokenizer = AutoTokenizer.from_pretrained("roberta-base-openai-detector")
    detector_model = RobertaForSequenceClassification.from_pretrained("roberta-base-openai-detector")

    # Both models and both pipelines must agree on the device, so a single
    # parameter drives all of them instead of five separate literals.
    model.to(device)
    detector_model.to(device)

    classifier = IsFakePipelineHF(model=detector_model, tokenizer=detector_tokenizer, device=device)
    # Alternative classifier (disabled): sklearn model + vectorizer loaded
    # from local pickle files. Only unpickle files you trust.
    # classifier = IsFakePipelineSklearn(
    #     model=pickle.load(open('./model/sim.tfidf_model_65536_feat.bin', 'rb')),
    #     vectorizer=pickle.load(open('./model/sim.tfidf_vect_65536_feat.bin', 'rb'))
    # )
    generator = BatchTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device)

    with jsonlines.open(file, mode='a') as writer:
        for _ in progressbar(range(lines)):
            # Every sample starts from an empty prompt.
            text = decode(
                prompt="",
                step=step,
                sequences_per_step=sequences_per_step,
                generate_length=sequence_length,
                generator=generator,
                classifier=classifier,
            )
            writer.write({'text': text})
if __name__ == "__main__":
    # python-fire turns main's parameters into command-line arguments/flags.
    fire.Fire(component=main)