-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparse_generated_data.py
More file actions
106 lines (77 loc) · 3.07 KB
/
parse_generated_data.py
File metadata and controls
106 lines (77 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import json
from opencc import OpenCC
from tqdm import tqdm
from curriculum_training.constants import MODEL_DISTAL_FROM
from news_with_rationale import NewsWithRationale
from utils import (
load_udn_news,
get_response_filename,
get_news_with_rationale_filename,
)
# Model name used by default in this module: it is the default `model_name`
# of `load_response` and selects the input/output filenames when the module
# is run as a script.
MODELNAME = MODEL_DISTAL_FROM
def load_response(
    filepath: str | None = None, model_name: str = MODELNAME
) -> list[dict]:
    """Load raw model responses from a JSON Lines file.

    Args:
        filepath: Path of the JSONL file to read. When ``None``, the path is
            resolved with ``get_response_filename(model_name)``.
        model_name: Model name used to resolve the path when ``filepath`` is
            not given.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        ValueError: If ``filepath`` is ``None`` and ``model_name`` is empty.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if filepath is None:
        if not model_name:
            raise ValueError("Either filepath or model_name must be provided.")
        filepath = get_response_filename(model_name)
    print(f"Loading from {filepath}")
    with open(filepath, "r", encoding="utf-8") as f:
        # Blank lines (for example an extra newline at the end of the file)
        # are not valid JSON documents, so skip them instead of crashing.
        return [json.loads(line) for line in f if line.strip()]
# Section headers expected in every response, in the order they must appear:
# essential aspects, triples, generated summary.
_SECTION_HEADERS = ("核心要素:", "三元組:", "生成摘要:")


def _split_response_sections(response: str) -> tuple[str, str, str] | None:
    """Split a raw response into its three labelled sections.

    Returns:
        ``(essential_aspects, triples, summary)`` as stripped text, or
        ``None`` when a header is missing or the headers are out of order.
    """
    aspects_hdr, triples_hdr, summary_hdr = _SECTION_HEADERS
    aspects_at = response.find(aspects_hdr)
    triples_at = response.find(triples_hdr)
    summary_at = response.find(summary_hdr)
    # A missing header yields -1, which also fails the ordering check below
    # for every position except the first, so test the first one explicitly.
    if aspects_at == -1 or not (aspects_at < triples_at < summary_at):
        return None
    return (
        response[aspects_at + len(aspects_hdr):triples_at].strip(),
        response[triples_at + len(triples_hdr):summary_at].strip(),
        response[summary_at + len(summary_hdr):].strip(),
    )


def parse_response(responses: list[dict]) -> list[NewsWithRationale]:
    """Convert raw model responses into ``NewsWithRationale`` records.

    Each response dict must provide a ``"news"`` entry (the article text) and
    a ``"response"`` entry (the model output containing the three section
    headers). Responses whose headers are missing or out of order are skipped
    and their indices are printed at the end. The input dicts are not
    modified.

    Args:
        responses: Response dicts, e.g. as returned by ``load_response``.

    Returns:
        One ``NewsWithRationale`` per well-formed response, in input order.

    Raises:
        ValueError: If the stripped article of a well-formed response is not
            found in the list returned by ``load_udn_news()``.
    """
    # 's2twp' is the OpenCC configuration name; per the OpenCC docs it maps
    # Simplified Chinese to Traditional Chinese (Taiwan standard, with
    # phrase conversion).
    cc = OpenCC('s2twp')
    corrupted_response_ids: set[int] = set()
    data: list[NewsWithRationale] = []
    news: list[str] = load_udn_news()

    # Build the article -> index map once instead of scanning the list for
    # every response. ``setdefault`` keeps the first occurrence, matching
    # what ``list.index`` would return for duplicated articles.
    news_index: dict[str, int] = {}
    for position, article_text in enumerate(news):
        news_index.setdefault(article_text, position)

    for i, d in tqdm(
        enumerate(responses),
        total=len(responses), desc="Parsing responses",
    ):
        # Strip surrounding whitespace (e.g. leading "\r\n") from the article.
        article = d["news"].strip()

        sections = _split_response_sections(d["response"])
        if sections is None:
            corrupted_response_ids.add(i)
            continue

        critical_elements_str, triples_str, summary = (
            cc.convert(section) for section in sections
        )
        essential_aspects: list[str] = critical_elements_str.split('\n')
        triples: list[str] = triples_str.split('\n')

        try:
            news_id = news_index[article]
        except KeyError:
            raise ValueError(
                f"Article of response {i} is not in the loaded news list"
            ) from None

        data.append(NewsWithRationale(
            article=article,
            summary=summary,
            id=news_id,
            label=[-1],
            essential_aspects=essential_aspects,
            triples=triples,
            rationale_summary=summary,
        ))

    print(f"Parsed {len(data)} responses")
    print(f"Corrupted response ids: {sorted(corrupted_response_ids)}")
    return data
if __name__ == "__main__":
    # Script entry point: parse the stored responses for MODELNAME and write
    # the resulting records out as JSON Lines (one object per line).
    parsed_records = parse_response(load_response(model_name=MODELNAME))
    output_path = get_news_with_rationale_filename(MODELNAME)
    with open(output_path, "w", encoding="utf-8") as out:
        for record in parsed_records:
            serialized = json.dumps(record.__dict__, ensure_ascii=False)
            out.write(serialized + '\n')