-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_dummy_models.py
More file actions
44 lines (37 loc) · 1.52 KB
/
create_dummy_models.py
File metadata and controls
44 lines (37 loc) · 1.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import pickle
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def create_dummy_tfidf(output_dir="models"):
    """Fit a tiny TF-IDF + logistic-regression pair and pickle both to disk.

    The vectorizer and classifier are trained on four hard-coded sentences,
    so the result is a placeholder artifact, not a usable model.

    Args:
        output_dir: Directory that receives ``tfidf_vectorizer.pkl`` and
            ``logistic_regression.pkl``. It is created if missing. Defaults
            to ``"models"`` (relative to the current working directory).
    """
    print("Creating dummy TF-IDF models...")

    # Toy corpus with one binary label per sentence
    # (presumably 1 = positive, 0 = negative; confirm against the consumer).
    texts = ["I love this", "This is bad", "Happy day", "Sad moment"]
    labels = [1, 0, 1, 0]

    vectorizer = TfidfVectorizer()
    X = vectorizer.fit_transform(texts)

    model = LogisticRegression()
    model.fit(X, labels)

    os.makedirs(output_dir, exist_ok=True)

    # File names are fixed; only the containing directory is configurable.
    artifacts = {
        "tfidf_vectorizer.pkl": vectorizer,
        "logistic_regression.pkl": model,
    }
    for filename, obj in artifacts.items():
        with open(os.path.join(output_dir, filename), "wb") as f:
            pickle.dump(obj, f)

    print("TF-IDF models saved.")
def create_dummy_bert(
    model_name="distilbert-base-uncased-finetuned-sst-2-english",
    save_path="models/bert_sentiment",
):
    """Load a pretrained sequence-classification checkpoint and save a copy.

    The tokenizer and model are loaded with ``from_pretrained`` and written
    to ``save_path`` with ``save_pretrained``.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
            Defaults to the DistilBERT SST-2 sentiment checkpoint.
        save_path: Directory that receives the tokenizer and model files.
            It is created if missing. Defaults to
            ``"models/bert_sentiment"``.

    Note:
        Loading presumably needs network access (or a populated local cache)
        and may be slow. Any exception raised by ``from_pretrained`` or
        ``save_pretrained`` propagates to the caller.
    """
    print("Downloading/Saving base BERT model as dummy...")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    # The directory is created only after both loads succeed, so a failed
    # download leaves no empty directory behind.
    os.makedirs(save_path, exist_ok=True)
    tokenizer.save_pretrained(save_path)
    model.save_pretrained(save_path)

    print("BERT model saved.")
if __name__ == "__main__":
    # The TF-IDF artifacts have no failure handling here: any error in this
    # step aborts the script.
    create_dummy_tfidf()

    # The BERT step is best-effort: a failure is reported and skipped so the
    # TF-IDF artifacts written above are still usable.
    try:
        create_dummy_bert()
    except Exception as e:
        print(f"Skipping BERT dummy creation (requires internet/time): {e}")