-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
137 lines (106 loc) · 3.54 KB
/
main.py
File metadata and controls
137 lines (106 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import sys
import logging
import os
from pathlib import Path
import subprocess
# Configure root logging once at import time so every pipeline step's
# messages share a uniform "timestamp - level - message" format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
# Module-level logger used by all pipeline helpers below.
logger = logging.getLogger(__name__)
# Determine data directory - use /data if writable, otherwise use local data/
def get_data_dir():
    """Return the dataset directory: /data when writable, else local data/."""
    preferred = Path("/data")
    return preferred if os.access(preferred, os.W_OK) else Path("data")
def run_command(cmd, cwd=None, description=None, env=None):
    """Run *cmd* as a subprocess and report success or failure.

    Args:
        cmd: Argument list for ``subprocess.run`` (no shell involved).
        cwd: Working directory for the child process; None inherits ours.
        description: Optional human-readable label for progress logging.
        env: Extra environment variables layered over ``os.environ``.

    Returns:
        True if the command exited with status 0, False otherwise.
    """
    if description:
        logger.info("Starting: %s", description)
    try:
        run_env = os.environ.copy()
        if env:
            run_env.update(env)
        # check=True raises CalledProcessError on a non-zero exit status.
        # Output is intentionally not captured so the child's progress
        # streams straight to the console (capture_output=False is the
        # default and was redundant in the original).
        subprocess.run(cmd, cwd=cwd, check=True, env=run_env)
        if description:
            logger.info("Completed: %s", description)
        return True
    except subprocess.CalledProcessError as e:
        # map(str, ...) keeps the join safe if cmd contains Path objects.
        logger.error("Failed to run command: %s", " ".join(map(str, cmd)))
        logger.error("Return code: %s", e.returncode)
        return False
def download_data():
    """Download the Wikipedia corpus unless it is already on disk.

    Returns:
        True on success (or when the corpus already exists), False on failure.
    """
    data_dir = get_data_dir()
    corpus_path = data_dir / "corpus.txt.gz"
    if corpus_path.exists():
        logger.info(f"Data already exists at {corpus_path}")
        return True
    # Point the helper script at the same directory we verify afterwards.
    # The original hard-coded "/data" even when get_data_dir() had fallen
    # back to the local data/ directory, so the download could land in a
    # different place than the existence check below.
    if not run_command(
        [sys.executable, "data/gatherData.py"],
        cwd=None,
        description="Download Wikipedia dataset",
        env={"DATA_DIR": str(data_dir)}
    ):
        return False
    if not corpus_path.exists():
        logger.error(f"Data download failed - {corpus_path} not found")
        return False
    return True
def preprocess_data():
    """Clean the downloaded corpus unless the cleaned file already exists.

    Returns:
        True on success (or when the cleaned corpus already exists),
        False on failure.
    """
    data_dir = get_data_dir()
    corpus_clean_path = data_dir / "corpus_clean.txt.gz"
    if corpus_clean_path.exists():
        logger.info(f"Preprocessed data already exists at {corpus_clean_path}")
        return True
    # Pass the directory actually selected by get_data_dir(); the original
    # hard-coded "/data" even when the local data/ fallback was in use.
    if not run_command(
        [sys.executable, "data/preprocess.py"],
        cwd=None,
        description="Preprocess and clean data",
        env={"DATA_DIR": str(data_dir)}
    ):
        return False
    if not corpus_clean_path.exists():
        logger.error(f"Data preprocessing failed - {corpus_clean_path} not found")
        return False
    return True
def train_tokenizer():
    """Train the BPE tokenizer unless model/tokenizer.json already exists.

    Returns:
        True when the tokenizer file is present after the step,
        False when the training script fails or produces no file.
    """
    tokenizer_path = Path("model") / "tokenizer.json"
    if tokenizer_path.exists():
        logger.info(f"Tokenizer already exists at {tokenizer_path}")
        return True
    ok = run_command(
        [sys.executable, "model/tokenizer.py"],
        cwd=None,
        description="Train BPE tokenizer"
    )
    if not ok:
        return False
    if tokenizer_path.exists():
        return True
    logger.error(f"Tokenizer training failed - {tokenizer_path} not found")
    return False
def train_model():
    """Launch the MLM training script.

    Returns:
        True when the training subprocess exits cleanly, False otherwise.
    """
    # run_command already yields a boolean, so its result is the answer.
    return run_command(
        [sys.executable, "model/train.py"],
        cwd=None,
        description="Train MLM model"
    )
def main():
    """Run the full pipeline: download, preprocess, tokenize, train.

    Exits the process with status 1 on the first failing step.
    """
    # The helper scripts are addressed by relative path, so the pipeline
    # must be launched from the repository root.
    if not Path("data").exists() or not Path("model").exists():
        logger.error("Error: data/ and model/ directories not found")
        logger.error("Please run this script from the CDiffusion root directory")
        # Bug fix: the original logged this error but kept running anyway.
        sys.exit(1)
    steps = [
        ("Download Data", download_data),
        ("Preprocess Data", preprocess_data),
        ("Train Tokenizer", train_tokenizer),
        ("Train Model", train_model),
    ]
    for step_name, step_func in steps:
        try:
            if not step_func():
                logger.error(f"Pipeline failed at: {step_name}")
                # sys.exit (not the interactive-only exit() builtin).
                sys.exit(1)
        except Exception as e:
            # SystemExit derives from BaseException, so the sys.exit above
            # is not swallowed by this handler.
            logger.error(f"Unexpected error during {step_name}: {e}")
            sys.exit(1)


if __name__ == "__main__":
    main()