-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_loss.py
More file actions
76 lines (58 loc) · 2.24 KB
/
Copy pathplot_loss.py
File metadata and controls
76 lines (58 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
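Fetch the training log_history (stored in trainer_state.json inside the newest
checkpoint on the Modal volume) and plot the SFT loss curve locally.
Usage:
    modal run plot_loss.py        # saves logs to ./training_logs/sft_log_history.json
    python plot_loss.py --plot    # plots those local logs and saves docs/sft_loss_curve.png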
"""
import argparse
import modal
# Modal app that registers fetch_logs and the local entrypoint, plus the
# volume the SFT run wrote its checkpoints to (fetch_logs mounts it at /output).
app, volume = (
    modal.App("plot-loss"),
    modal.Volume.from_name("qwen-sft-output"),
)
@app.function(volumes={"/output": volume})
def fetch_logs(run_dir="/output/qwen-alpaca-sft"):
    """Return ``log_history`` from trainer_state.json in the newest checkpoint.

    Runs remotely with the volume mounted at ``/output``. Subdirectories of
    ``run_dir`` named ``checkpoint-<N>`` are considered, and the one with the
    largest integer ``N`` is used. Entries that match ``checkpoint-*`` but
    have a non-integer suffix (e.g. ``checkpoint-tmp``), or that are not
    directories, are ignored instead of crashing the selection.

    Args:
        run_dir: Directory containing the ``checkpoint-*`` subdirectories.
            Defaults to the path this script has always read.

    Returns:
        The value stored under ``"log_history"`` in that checkpoint's
        ``trainer_state.json``.

    Raises:
        FileNotFoundError: If no numbered checkpoint directory exists under
            ``run_dir``, or the chosen one has no ``trainer_state.json``.
        KeyError: If ``trainer_state.json`` has no ``"log_history"`` key.
    """
    import glob
    import json
    import os

    def checkpoint_step(path):
        # ".../checkpoint-500" -> 500; a non-numeric suffix -> None (skipped).
        suffix = os.path.basename(path).rpartition("-")[2]
        return int(suffix) if suffix.isdecimal() else None

    by_step = {}
    for path in glob.glob(os.path.join(run_dir, "checkpoint-*")):
        step = checkpoint_step(path)
        if step is not None and os.path.isdir(path):
            by_step[step] = path
    if not by_step:
        raise FileNotFoundError(f"No checkpoints found in volume under {run_dir}")

    # Select by numeric step, not lexically: checkpoint-1000 must beat checkpoint-900.
    final = by_step[max(by_step)]
    print(f"Using {final}")
    with open(os.path.join(final, "trainer_state.json"), encoding="utf-8") as f:
        state = json.load(f)
    return state["log_history"]
@app.local_entrypoint()
def main():
    """Pull ``log_history`` from the remote checkpoint and cache it locally.

    Entry point for ``modal run plot_loss.py``. It calls ``fetch_logs`` on
    Modal and writes the returned list as indented JSON to
    ``training_logs/sft_log_history.json``, creating the directory if needed.
    It then prints the entry count and the follow-up plotting command.
    """
    import json
    import os

    # The remote call comes first, so a failure leaves no empty directory behind.
    entries = fetch_logs.remote()

    target = "training_logs/sft_log_history.json"
    os.makedirs(os.path.dirname(target), exist_ok=True)
    with open(target, "w") as fh:
        json.dump(entries, fh, indent=2)

    print(
        f"Saved {len(entries)} log entries to {target}",
        "Now run: python plot_loss.py --plot",
        sep="\n",
    )
def extract_loss_curve(history):
    """Split ``log_history`` entries into parallel step and loss lists.

    Only entries that contain a ``"loss"`` key are kept. Entries without it
    (e.g. ones carrying only ``eval_loss`` or ``train_loss``) are skipped.

    Args:
        history: List of dicts, as written by ``main``. Every kept entry
            must also have a ``"step"`` key.

    Returns:
        ``(steps, losses)``: two lists of equal length, in input order.
    """
    points = [(entry["step"], entry["loss"]) for entry in history if "loss" in entry]
    steps = [step for step, _ in points]
    losses = [loss for _, loss in points]
    return steps, losses


if __name__ == "__main__":
    import json
    import os

    parser = argparse.ArgumentParser(
        description="Plot the SFT training-loss curve from locally downloaded logs."
    )
    parser.add_argument(
        "--plot",
        action="store_true",
        help="plot training_logs/sft_log_history.json and save docs/sft_loss_curve.png",
    )
    args = parser.parse_args()

    if not args.plot:
        # Downloading is done via `modal run`, not plain python. Without --plot
        # there is nothing to do, so show usage instead of exiting silently.
        parser.print_help()
    else:
        log_path = "training_logs/sft_log_history.json"
        try:
            with open(log_path, encoding="utf-8") as f:
                history = json.load(f)
        except FileNotFoundError:
            raise SystemExit(
                f"{log_path} not found; run `modal run plot_loss.py` first to download it."
            ) from None

        steps, losses = extract_loss_curve(history)
        if not steps:
            raise SystemExit(f"No 'loss' entries in {log_path}; nothing to plot.")

        # Imported only here, so the Modal path never needs matplotlib.
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(9, 5))
        ax.plot(steps, losses, color="#2563eb", linewidth=1.6)
        ax.set_xlabel("Step")
        ax.set_ylabel("Training loss")
        ax.set_title("SFT — Qwen2.5-0.5B on Alpaca")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.grid(True, alpha=0.25)
        plt.tight_layout()

        out = "docs/sft_loss_curve.png"
        # savefig does not create parent directories; make sure docs/ exists.
        os.makedirs(os.path.dirname(out), exist_ok=True)
        plt.savefig(out, dpi=150)
        print(f"Saved plot to {out}")
        plt.show()