-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
452 lines (385 loc) · 18.8 KB
/
Copy pathdata.py
File metadata and controls
452 lines (385 loc) · 18.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
"""
data.py — TranscriptionUnit and Transcript data structures.
Pipeline coverage:
Step 2 — TranscriptionUnit.__post_init__ (preprocess / normalize)
Step 3 — Transcript.sort
Step 4 — Transcript.find_overlaps
Step 5 — TranscriptionUnit.tokenize (delegates to tokens.tokenize_tu)
Step 6 — Transcript.check_overlaps
Step 7 — TranscriptionUnit.add_token_features
"""
from __future__ import annotations
import collections
import logging
from dataclasses import dataclass, field
from typing import Optional
import networkx as nx
import regex as re
import dataflags as df
from normalize import validate_and_normalize
from tokens import Token, tokenize_tu
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# TranscriptionUnit (steps 2, 5, 7)
# ---------------------------------------------------------------------------
@dataclass
class TranscriptionUnit:
    """One transcription unit (TU): a single speaker's annotated stretch of talk.

    Construction runs pipeline step 2 (preprocess/normalize the annotation and
    extract span offsets).  ``tokenize`` (step 5) and ``add_token_features``
    (step 7) are invoked separately by the pipeline.
    """

    tu_id: int
    speaker: str
    start: float
    end: float
    duration: float
    annotation: str
    parent_tu_id: Optional[int] = None
    # Full pipeline config dict; used for normalization and tokenization.
    cfg: dict = field(default_factory=dict, repr=False)

    # Computed in __post_init__
    orig_annotation: str = field(init=False, default="")
    include: bool = field(init=False, default=True)
    non_ita: df.languagevariation = field(init=False, default=df.languagevariation.none)

    # Span positions (char offsets into the normalized annotation), step 2f.
    # Each entry is a half-open (start, end) range that includes the markers.
    overlapping_spans: list[tuple[int, int]] = field(init=False, default_factory=list)
    slow_pace_spans: list[tuple[int, int]] = field(init=False, default_factory=list)
    fast_pace_spans: list[tuple[int, int]] = field(init=False, default_factory=list)
    low_volume_spans: list[tuple[int, int]] = field(init=False, default_factory=list)
    high_volume_spans: list[tuple[int, int]] = field(init=False, default_factory=list)
    guessing_spans: list[tuple[int, int]] = field(init=False, default_factory=list)

    # Populated by Transcript.check_overlaps (step 6).
    overlapping_times: dict = field(init=False, default_factory=dict)
    overlapping_matches: dict = field(init=False, default_factory=dict)
    overlap_duration: dict = field(init=False, default_factory=dict)

    # Rule name -> occurrence count / rule name -> error flag.
    warnings: dict[str, int] = field(init=False, default_factory=lambda: collections.defaultdict(int))
    errors: dict[str, bool] = field(init=False, default_factory=lambda: collections.defaultdict(bool))

    tokens: list[Token] = field(init=False, default_factory=list)

    # ------------------------------------------------------------------
    # Step 2 — Preprocess
    # ------------------------------------------------------------------
    def __post_init__(self):
        """Preprocess the raw annotation (pipeline step 2).

        Keeps the raw text in ``orig_annotation``, excludes empty TUs, handles
        the TU-level ``#_`` (entirely non-Italian; normalization skipped) and
        ``# `` (partly non-Italian) prefixes, runs ``validate_and_normalize``
        and merges its warnings/errors into this TU, excludes TUs that are
        empty after normalization, and extracts span offsets for pace,
        volume, overlap and guess markers.  A span family whose markers were
        reported as unbalanced is not extracted.
        """
        self.orig_annotation = self.annotation

        # 2a. Empty / exclude check.
        if not self.annotation or not self.annotation.strip():
            logger.info("TU %s: empty annotation, excluding", self.tu_id)
            self.include = False
            return
        self.annotation = self.annotation.strip()

        # 2b. TU-level language variation markers.
        if self.annotation.startswith("#_"):
            self.non_ita = df.languagevariation.all
            self.annotation = self.annotation[2:].strip()
            # Skip normalization for entirely non-Italian TUs.
            return
        if self.annotation.startswith("# "):
            self.non_ita = df.languagevariation.some
            self.annotation = self.annotation[1:].strip()

        # 2c–2g. Normalize, error-check, conditional fixes, symbol corrections.
        # validate_and_normalize covers: warning rules (SYMBOL_NOT_ALLOWED, META_TAGS,
        # UNEVEN_SPACES, TRIM_PAUSES, TRIM_PROSODICLINKS, OVERLAP_PROLONGATION,
        # MULTIPLE_SPACES, ACCENTS, NUMBERS, check_spaces_dots, check_spaces_angular,
        # SWITCHES, remove_empty_spans, flag_empty_unit) and error rules
        # (UNBALANCED_DOTS, UNBALANCED_PACE, UNBALANCED_GUESS, UNBALANCED_OVERLAP).
        norm_cfg = self.cfg.get("normalization", {})
        normalized, warnings, errors = validate_and_normalize(self.annotation, norm_cfg)
        for key, count in warnings.items():
            self.warnings[key] += count
        for key, has_error in errors.items():
            if has_error:
                self.errors[key] = True

        # 2h. All-symbol exclusion: flag_empty_unit returns "" when nothing remains.
        if not normalized:
            logger.info("TU %s: only symbols after normalization, excluding", self.tu_id)
            self.include = False
            return
        self.annotation = normalized

        # 2f. Span position extraction on the normalized annotation.
        text = self.annotation
        if "<" in text and not self.errors.get("UNBALANCED_PACE"):
            # Slow (<...>) and fast (>...<) share their marker characters, so
            # they must be paired in a single pass (see _find_pace_spans).
            self.slow_pace_spans, self.fast_pace_spans = self._find_pace_spans(text)

        if "°" in text and not self.errors.get("UNBALANCED_DOTS"):
            self.low_volume_spans = [
                (m.start(), m.end())
                for m in re.finditer(r"°[^°]+°", text)
            ]

        hv_matches = list(re.finditer(
            r"\b[A-ZÀÈÉÌÒÓÙ]+(?:\s+[A-ZÀÈÉÌÒÓÙ]+)*\b", text
        ))
        if hv_matches:
            self.high_volume_spans = [(m.start(), m.end()) for m in hv_matches]

        if "[" in text and not self.errors.get("UNBALANCED_OVERLAP"):
            self.overlapping_spans = [
                (m.start(), m.end())
                for m in re.finditer(r"\[[^\]]+\]", text)
            ]

        if "(" in text and not self.errors.get("UNBALANCED_GUESS"):
            self.guessing_spans = [
                (m.start(), m.end())
                for m in re.finditer(r"\([^)]+\)", text)
            ]

    @staticmethod
    def _find_pace_spans(text: str) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]:
        """Pair ``<``/``>`` markers in *text* into slow and fast pace spans.

        ``<...>`` delimits slow speech and ``>...<`` fast speech.  Markers are
        paired left to right, so the closing marker of one span is never
        reused as the opening marker of the next one.  Two independent regex
        scans would, for example, report ``"> b <"`` in ``"<a> b <c>"`` as a
        fast span.  A repeated opening marker restarts the span at the later
        position, which selects the same span ``<[^<>]*>`` / ``>[^<>]*<``
        would.  An unclosed trailing marker yields no span.

        Returns:
            ``(slow_spans, fast_spans)``: lists of half-open ``(start, end)``
            character offsets that include both markers.
        """
        slow: list[tuple[int, int]] = []
        fast: list[tuple[int, int]] = []
        opener: str | None = None
        opened_at = -1
        for pos, ch in enumerate(text):
            if ch != "<" and ch != ">":
                continue
            if opener is None or ch == opener:
                opener, opened_at = ch, pos
            else:
                (slow if opener == "<" else fast).append((opened_at, pos + 1))
                opener = None
        return slow, fast

    # ------------------------------------------------------------------
    # Step 5 — Tokenize
    # ------------------------------------------------------------------
    def tokenize(self, cfg: dict | None = None):
        """Tokenize the normalized annotation (pipeline step 5).

        Does nothing for excluded TUs.  After tokenization, the TU-level
        ``non_ita`` is raised to ``all`` when every token is non-Italian, or
        to ``some`` when at least one is.  It is never lowered.

        Args:
            cfg: pipeline config; defaults to the one given at construction.
        """
        if not self.include:
            return
        if cfg is None:
            cfg = self.cfg
        self.tokens = tokenize_tu(
            self.annotation,
            tu_id=self.tu_id,
            variation_context=self.non_ita,
            cfg_variation=cfg.get("variation_markers", {}),
        )
        # Update TU-level non_ita based on token-level flags.
        has_non_ita = any(t.non_ita for t in self.tokens)
        all_non_ita = bool(self.tokens) and all(t.non_ita for t in self.tokens)
        if all_non_ita:
            self.non_ita = df.languagevariation.all
        elif has_non_ita:
            self.non_ita = df.languagevariation.some

    # ------------------------------------------------------------------
    # Step 7 — Map span features to tokens
    # ------------------------------------------------------------------
    def add_token_features(self):
        """Map TU-level span positions onto individual tokens (pipeline step 7).

        For every span in ``slow_pace_spans``, ``fast_pace_spans``,
        ``low_volume_spans`` and ``guessing_spans``, each token the span
        touches gets ``{span index: (first form char, last form char + 1)}``
        stored in the matching token dict (``slow_pace``, ``fast_pace``,
        ``low_volume``, ``guesses``).  Low-volume tokens additionally get
        ``volume = low``.  Overlap spans are keyed by the match id in
        ``overlapping_matches`` instead of the span index.  The first and last
        tokens receive the start/end position flags.  Does nothing when the
        TU has no tokens.
        """
        if not self.tokens:
            return

        # Character-level index over the token orig_text values, laid out
        # back to back with one separator slot after each token.
        # NOTE(review): assumes the normalized annotation equals the token
        # orig_texts joined by exactly one character; confirm against
        # tokens.tokenize_tu, otherwise the span offsets drift.
        #   token_at[i] = list index of the owning token
        #                 (-1 = colon/punctuation, -2 = bracket marker, -3 = inter-token)
        #   form_idx[i] = form-char index within that token (negative for non-form chars)
        token_at: list[int] = []
        form_idx: list[int] = []
        for tok_i, tok in enumerate(self.tokens):
            next_form = 0
            for ch in tok.orig_text:
                if ch in ":.,?":
                    token_at.append(-1)
                    form_idx.append(-1)
                elif ch in "[]()<>°":
                    token_at.append(-2)
                    form_idx.append(-2)
                else:
                    token_at.append(tok_i)
                    form_idx.append(next_form)
                    next_form += 1
            # Sentinel between tokens.
            token_at.append(-3)
            form_idx.append(-3)

        def token_ranges(a: int, b: int) -> dict[int, tuple[int, int]]:
            """Map each token with form chars in [a, b) to its covered form-char range."""
            ranges: dict[int, tuple[int, int]] = {}
            for ti, fi in zip(token_at[a:b], form_idx[a:b]):
                if ti < 0:
                    continue
                if ti in ranges:
                    lo, hi = ranges[ti]
                    ranges[ti] = (min(lo, fi), max(hi, fi + 1))
                else:
                    ranges[ti] = (fi, fi + 1)
            return ranges

        # (Token dict attribute, TU spans). high_volume is detected per-token
        # in Token._classify, so it has no dict on Token and is not mapped here.
        span_features = (
            ("slow_pace", self.slow_pace_spans),
            ("fast_pace", self.fast_pace_spans),
            ("low_volume", self.low_volume_spans),
            ("guesses", self.guessing_spans),
        )
        for attr, spans in span_features:
            for span_id, (a, b) in enumerate(spans):
                for ti, char_range in token_ranges(a, b).items():
                    tok = self.tokens[ti]
                    getattr(tok, attr)[span_id] = char_range
                    if attr == "low_volume":
                        tok.volume = df.volume.low

        # Overlaps use match_id (clique id, or "?" when unresolved) as the key.
        for (a, b), match_id in self.overlapping_matches.items():
            for ti, char_range in token_ranges(a, b).items():
                self.tokens[ti].overlaps[match_id] = char_range

        # Position flags: first and last token of TU.
        self.tokens[0].set_position(df.position.start)
        self.tokens[-1].set_position(df.position.end)
# ---------------------------------------------------------------------------
# Transcript (steps 1, 3, 4, 6)
# ---------------------------------------------------------------------------
@dataclass
class Transcript:
    """A whole transcript: its TUs, per-speaker counts and overlap events.

    Pipeline usage: ``add`` each TU (step 1), ``sort`` (step 3),
    ``find_overlaps`` (step 4), then ``check_overlaps`` (step 6).
    """

    tr_id: str
    # Speaker -> number of *included* TUs (speakers with only excluded TUs map to 0).
    speakers: dict[str, int] = field(default_factory=dict)
    _tu_by_id: dict[int, TranscriptionUnit] = field(default_factory=dict, repr=False)
    # Filled by sort(), ordered by TU start time.
    transcription_units: list[TranscriptionUnit] = field(default_factory=list)
    tot_length: float = 0.0
    time_based_overlaps: nx.Graph = field(default_factory=nx.Graph)
    # Clique id -> (overlap start, overlap end); filled by check_overlaps().
    overlap_events: dict[int, tuple[float, float]] = field(default_factory=dict)

    # ------------------------------------------------------------------
    # Step 1 — Add TUs
    # ------------------------------------------------------------------
    def add(self, tu: TranscriptionUnit):
        """Register *tu* and count it for its speaker if it is included.

        A TU whose id is already registered replaces the earlier one.  The
        earlier TU's speaker count is reverted, so ``speakers`` keeps
        matching the stored TUs.
        """
        previous = self._tu_by_id.get(tu.tu_id)
        if previous is not None:
            logger.warning("Duplicate TU id %s, replacing earlier TU", tu.tu_id)
            if previous.include:
                self.speakers[previous.speaker] -= 1
        self.speakers.setdefault(tu.speaker, 0)
        if tu.include:
            self.speakers[tu.speaker] += 1
        self._tu_by_id[tu.tu_id] = tu

    # ------------------------------------------------------------------
    # Step 3 — Sort
    # ------------------------------------------------------------------
    def sort(self):
        """Order all registered TUs by start time and update ``tot_length``.

        ``tot_length`` becomes the latest end time over all TUs.  Ordering by
        start does not make the last TU the last to end, because a long TU
        can outlast later-starting ones, so the maximum is taken explicitly.
        ``tot_length`` is left unchanged when there are no TUs.
        """
        self.transcription_units = sorted(
            self._tu_by_id.values(), key=lambda tu: tu.start
        )
        if self.transcription_units:
            self.tot_length = max(tu.end for tu in self.transcription_units)

    # ------------------------------------------------------------------
    # Step 4 — Find time-based overlaps
    # ------------------------------------------------------------------
    def find_overlaps(self, duration_threshold: float = 0.0):
        """Build the graph of time-overlapping included TUs into ``time_based_overlaps``.

        Nodes are TU ids with attributes ``speaker`` and ``overlaps`` (the
        TU's annotated overlap spans).  An edge links two TUs whose time
        intervals intersect (strictly: touching endpoints do not count) and
        stores the intersection as ``start``, ``end`` and ``duration``.

        Args:
            duration_threshold: unused here (short overlaps are filtered in
                ``check_overlaps``); kept for interface compatibility.
        """
        graph = nx.Graph()
        # Sorting is stable, so this keeps sort()'s order; it also makes the
        # early break below valid if transcription_units was filled otherwise.
        tus = sorted(
            (tu for tu in self.transcription_units if tu.include),
            key=lambda tu: tu.start,
        )
        n = len(tus)
        for i, tu1 in enumerate(tus):
            for j in range(i + 1, n):
                tu2 = tus[j]
                # Every later TU starts at or after tu2, so none can overlap tu1.
                if tu2.start >= tu1.end:
                    break
                if tu2.end <= tu1.start:
                    continue
                if tu1.tu_id not in graph:
                    graph.add_node(tu1.tu_id, speaker=tu1.speaker,
                                   overlaps=tu1.overlapping_spans)
                if tu2.tu_id not in graph:
                    graph.add_node(tu2.tu_id, speaker=tu2.speaker,
                                   overlaps=tu2.overlapping_spans)
                start = max(tu1.start, tu2.start)
                end = min(tu1.end, tu2.end)
                graph.add_edge(tu1.tu_id, tu2.tu_id,
                               start=start, end=end, duration=end - start)
        self.time_based_overlaps = graph

    # ------------------------------------------------------------------
    # Step 6 — Resolve overlaps
    # ------------------------------------------------------------------
    def _is_nvb_only(self, tu_id: int) -> bool:
        """Return True when every token of TU *tu_id* is non-verbal behaviour."""
        # NOTE(review): all() is vacuously True for a TU without tokens
        # (e.g. not tokenized yet), which then counts as NVB-only.
        return all(df.tokentype.nonverbalbehavior in t.token_type
                   for t in self._tu_by_id[tu_id].tokens)

    def check_overlaps(
        self,
        duration_threshold: float,
        relations_to_ignore: list[tuple] | None = None,
        nvb_participates: bool = False,
    ):
        """Prune the overlap graph, derive overlap events and match annotations.

        Steps:
          6a. unless *nvb_participates*, drop edges where either TU is
              entirely non-verbal behaviour;
          6b. drop the TU-id pairs listed in *relations_to_ignore*;
          6c. drop overlaps shorter than *duration_threshold* when neither TU
              has an annotated overlap span.  The earlier-starting TU's end
              and the later TU's start are each moved by half the overlap
              (warning MOVED_BOUNDARIES);
          6d. each maximal clique becomes an overlap event stored in
              ``overlap_events`` and in each member's ``overlapping_times``;
          6e. each TU's annotated overlap spans are matched to its events
              (see ``_match_overlaps``).

        Args:
            duration_threshold: minimum overlap duration (same units as TU
                times) for an overlap to require annotation.
            relations_to_ignore: (tu_id, tu_id) pairs to drop from the graph.
            nvb_participates: whether non-verbal TUs take part in overlaps.
        """
        if relations_to_ignore is None:
            relations_to_ignore = []
        graph = self.time_based_overlaps

        # 6a. Remove NVB-only edges (unless nvb_participates is True).
        if not nvb_participates:
            to_remove = [
                (u, v) for u, v in graph.edges()
                if self._is_nvb_only(u) or self._is_nvb_only(v)
            ]
            for u, v in to_remove:
                logger.warning("Removing NVB edge %s-%s", u, v)
                graph.remove_edge(u, v)

        # 6b. Remove manually ignored pairs.
        for u, v in relations_to_ignore:
            if graph.has_edge(u, v):
                logger.warning("Removing ignored edge %s-%s", u, v)
                graph.remove_edge(u, v)

        # 6c. Remove short unannotated overlaps; nudge TU boundaries.
        to_remove = []
        for u, v in list(graph.edges()):
            edge = graph[u][v]
            tu_u = self._tu_by_id[u]
            tu_v = self._tu_by_id[v]
            if (edge["duration"] < duration_threshold and
                    not tu_u.overlapping_spans and not tu_v.overlapping_spans):
                half = edge["duration"] / 2
                # Order by time (id only breaks ties): the TU that starts first
                # is the one whose end runs into the other.
                first, second = sorted((tu_u, tu_v), key=lambda t: (t.start, t.tu_id))
                first.end -= half
                second.start += half
                first.warnings["MOVED_BOUNDARIES"] += 1
                second.warnings["MOVED_BOUNDARIES"] += 1
                to_remove.append((u, v))
        for u, v in to_remove:
            logger.warning("Removing short unannotated overlap %s-%s", u, v)
            graph.remove_edge(u, v)

        # 6d. Cliques → overlap events.
        cliques = sorted(
            (c for c in nx.find_cliques(graph) if len(c) > 1),
            key=len,
        )
        self.overlap_events = {}
        for clique_id, clique in enumerate(cliques):
            members = [self._tu_by_id[n] for n in clique]
            nvb_in_clique = any(
                any(df.tokentype.nonverbalbehavior in t.token_type for t in member.tokens)
                for member in members
            )
            overlap_start = max(member.start for member in members)
            overlap_end = min(member.end for member in members)
            self.overlap_events[clique_id] = (overlap_start, overlap_end)
            for node in clique:
                partners = tuple(n for n in clique if n != node)
                self._tu_by_id[node].overlapping_times[partners] = (
                    overlap_start, overlap_end, clique_id, nvb_in_clique
                )

        # 6e. Match annotated spans to overlap events.
        for tu in self._tu_by_id.values():
            self._match_overlaps(tu, duration_threshold, nvb_participates)

    @staticmethod
    def _match_overlaps(tu: TranscriptionUnit, duration_threshold: float,
                        nvb_participates: bool):
        """Match *tu*'s annotated overlap spans to its overlap events (step 6e).

        - Equal counts: spans (in text order) are paired with events in
          chronological order in ``overlapping_matches``.
        - No spans but events: durations are recorded.  If every event is
          removable (too short, or NVB-involving while NVB does not
          participate), the MISMATCHING_OVERLAPS warning is counted.
          Otherwise the OVERLAPS:MISSING_ANNOTATION error is flagged.
        - Spans but no events: OVERLAPS:MISSING_TIME error; spans map to "?".
        - More events than spans: if dropping exactly the removable events
          makes the counts agree, pair the rest and count the
          MISMATCHING_OVERLAPS warning.  Otherwise it is a mismatch error.
        - More spans than events: mismatch error.
        A mismatch error flags MISMATCHING_OVERLAPS, maps spans to "?" and
        records durations.
        """
        spans = tu.overlapping_spans
        times = tu.overlapping_times
        n_spans = len(spans)
        n_times = len(times)

        def is_removable(event) -> bool:
            ev_start, ev_end, _, nvb = event
            return (nvb and not nvb_participates) or (ev_end - ev_start < duration_threshold)

        def record_durations():
            for partners, (ev_start, ev_end, _, _) in times.items():
                tu.overlap_duration["+".join(str(p) for p in partners)] = ev_end - ev_start

        def flag_mismatch():
            tu.errors["MISMATCHING_OVERLAPS"] = True
            tu.overlapping_matches = {span: "?" for span in spans}
            record_durations()

        # Events ordered by overlap start time; each value is
        # (start, end, clique_id, nvb_in_clique).
        events = sorted(times.values(), key=lambda ev: ev[0])

        if n_spans == n_times:
            tu.overlapping_matches = dict(zip(spans, (ev[2] for ev in events)))
        elif n_spans == 0:
            record_durations()
            removable_ids = {ev[2] for ev in times.values() if is_removable(ev)}
            all_clique_ids = {ev[2] for ev in times.values()}
            if removable_ids >= all_clique_ids:
                tu.warnings["MISMATCHING_OVERLAPS"] += 1
            else:
                tu.errors["OVERLAPS:MISSING_ANNOTATION"] = True
        elif n_times == 0:
            tu.errors["OVERLAPS:MISSING_TIME"] = True
            tu.overlapping_matches = {span: "?" for span in spans}
        elif n_times > n_spans:
            removable_ids = {ev[2] for ev in times.values() if is_removable(ev)}
            if len(removable_ids) == n_times - n_spans:
                keep_ids = [ev[2] for ev in events if ev[2] not in removable_ids]
                tu.overlapping_matches = dict(zip(spans, keep_ids))
                tu.warnings["MISMATCHING_OVERLAPS"] += 1
            else:
                flag_mismatch()
        else:
            flag_mismatch()

    def __iter__(self):
        """Iterate over the TUs in the order established by ``sort``."""
        return iter(self.transcription_units)