forked from Fadi987/StyleTTS2
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgenerate_TTS2_lists.py
More file actions
156 lines (126 loc) · 7.13 KB
/
generate_TTS2_lists.py
File metadata and controls
156 lines (126 loc) · 7.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/env python3
"""
generate_TTS2_lists.py
This script reads a metadata CSV file (produced by hfData2WavFiles.py) and generates two list files for StyleTTS2:
- A training list file (e.g., Data/train_list.txt)
- A validation list file (e.g., Data/val_list.txt)
Each output file contains one line per sample in the format:
file_name|phonetic_text|speakerID
where speakerID is 0 for female speakers and 1 for any other gender value.
For the training set only, if a target duration (in seconds) is specified, the script selects audio
until it reaches that duration **for each gender separately**.
Sorting is determined by --duration_order:
- "random" (default) shuffles the files,
- "min" sorts in ascending order (shorter files first),
- "max" sorts in descending order (longer files first).
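If the available data is less than `target_duration`, it selects **all files for that gender**.
The validation (test) set is never subsampled by --target_duration; only --max_duration (if given),
which drops files longer than that many seconds from both splits, affects it.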
Usage:
python generate_TTS2_lists.py --metadata_csv dataset_metadata.csv \
--train_split train --val_split test \
--train_list Data/train_list.txt --val_list Data/val_list.txt \
--text_field phonetic_text --target_duration 3600 --duration_order random
"""
import argparse
import os
import sys

import pandas as pd
def _cell_to_str(value) -> str:
    """Return a metadata cell as a stripped, single-line string; missing cells become ""."""
    # pd.isna covers None, float NaN and pd.NA. Without this, NaN.strip() raises
    # and str(NaN) leaks the literal text "nan" into the list file.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    text = str(value)
    # An embedded line break would split one entry across two lines of the list file.
    for newline in ("\r\n", "\r", "\n"):
        text = text.replace(newline, " ")
    return text.strip()


def write_list_file(data: pd.DataFrame, output_file: str, text_field: str):
    """Write a StyleTTS2 list file: one ``file_name|phonetic_text|speakerID`` line per row.

    Args:
        data: Metadata rows. The ``file_name``, ``gender`` and ``text_field``
            columns are read; an absent column or a missing (NaN/None) cell is
            written as an empty string. Values are stripped and internal line
            breaks are replaced by spaces.
        output_file: Destination path. Its parent directory is created if
            needed; an existing file is overwritten.
        text_field: Name of the column holding the (phonetic) transcription.

    The speaker ID is "0" when the gender is "female" (case- and
    whitespace-insensitive) and "1" for every other value, including unknown
    or missing genders.
    """
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    n_rows = len(data)

    def column(name):
        # An absent column reads as all-missing, i.e. written as "".
        return data[name].tolist() if name in data.columns else [None] * n_rows

    lines = []
    for file_name, text, gender in zip(column("file_name"), column(text_field), column("gender")):
        # Assign speakerID based on gender (0 for female, 1 for everything else).
        speaker_id = "0" if _cell_to_str(gender).lower() == "female" else "1"
        lines.append(f"{_cell_to_str(file_name)}|{_cell_to_str(text)}|{speaker_id}\n")

    with open(output_file, "w", encoding="utf-8") as f:
        f.writelines(lines)
    print(f"Wrote {len(data)} entries to {output_file}.")
def select_per_gender(df: pd.DataFrame, target_duration: float, order: str, seed: int = 42) -> pd.DataFrame:
    """
    Select rows **per gender** until each gender reaches a target cumulative duration.

    For each of 'male' and 'female' (matched case- and whitespace-insensitively):
    - If the gender's total duration is **less than `target_duration`**, all of its
      rows are kept, in their original order.
    - Otherwise the rows are ordered by `order` and taken one at a time until the
      running total first reaches `target_duration`. The row that crosses the
      threshold is included, so the gender gets at least `target_duration` seconds.

    Missing or non-numeric durations count as 0 seconds, so they cannot stall the
    running total.

    Args:
        df: Metadata with at least 'gender' and 'duration' columns.
        target_duration: Per-gender target, in seconds.
        order: 'random' (shuffle), 'min' (shortest first) or 'max' (longest
            first). Any other value falls back to 'random'.
        seed: Random seed for the shuffle (default 42).

    Returns:
        The selected male rows followed by the selected female rows, with a fresh
        RangeIndex. Rows of any other gender are dropped, unless neither gender has
        rows at all; in that case `df` is returned unchanged.
    """
    # Strip + lowercase so " Male " matches 'male'. fillna/astype(str) keeps the
    # .str accessor safe when the column contains missing or non-string cells.
    normalized_gender = df["gender"].fillna("").astype(str).str.strip().str.lower()

    selected_rows = []
    for gender in ("male", "female"):
        df_gender = df[normalized_gender == gender]
        if df_gender.empty:
            print(f"Warning: No data found for gender '{gender}'.")
            continue

        total_duration = float(pd.to_numeric(df_gender["duration"], errors="coerce").fillna(0.0).sum())
        if total_duration < target_duration:
            print(f"Total available duration for {gender} ({total_duration:.2f} sec) is less than {target_duration:.2f} sec. Selecting full data.")
            selected_rows.append(df_gender)
            continue

        # Apply the requested ordering. Sorting on the numeric value keeps missing
        # durations last, in both directions.
        if order in ("min", "max"):
            df_gender = df_gender.sort_values(
                by="duration",
                ascending=(order == "min"),
                key=lambda col: pd.to_numeric(col, errors="coerce"),
            )
        else:  # 'random' and any unrecognised value
            df_gender = df_gender.sample(frac=1, random_state=seed)

        # NaN -> 0 so a single missing duration cannot turn the running total into NaN.
        # A NaN total would never compare >= target.
        cumulative = pd.to_numeric(df_gender["duration"], errors="coerce").fillna(0.0).cumsum()
        reached = (cumulative >= target_duration).to_numpy().nonzero()[0]
        # Summation order can leave the running total a hair below a target that the
        # pre-computed total met. In that case keep every row.
        count = int(reached[0]) + 1 if len(reached) else len(df_gender)
        cum_duration = float(cumulative.iloc[count - 1])

        print(f"Selected {count} entries for gender '{gender}' with cumulative duration {cum_duration:.2f} sec.")
        # Positional slicing stays correct even if the index has duplicate labels.
        selected_rows.append(df_gender.iloc[:count])

    return pd.concat(selected_rows, ignore_index=True) if selected_rows else df
def main():
    """Command-line entry point: build StyleTTS2 train/validation list files from a metadata CSV.

    The steps are:
    - Read ``--metadata_csv``.
    - Optionally drop files longer than ``--max_duration`` from both splits.
    - Optionally apply per-gender duration selection (``--target_duration``) to the
      training split only. In that case the selected rows are also saved to
      ``--filtered_train_csv`` for verification.
    - Write ``--train_list`` and ``--val_list``.

    Exits with status 1 and a message on stderr if the CSV is missing or lacks a
    required column.
    """
    parser = argparse.ArgumentParser(description="Generate list files for StyleTTS2 from a metadata CSV file.")
    parser.add_argument("--metadata_csv", type=str, required=True, help="Path to metadata CSV file.")
    parser.add_argument("--train_split", type=str, default="train", help="Metadata 'split' value for training.")
    parser.add_argument("--val_split", type=str, default="test", help="Metadata 'split' value for validation.")
    parser.add_argument("--train_list", type=str, default="Data/train_list.txt", help="Output file for train list.")
    parser.add_argument("--val_list", type=str, default="Data/val_list.txt", help="Output file for validation list.")
    parser.add_argument("--text_field", type=str, default="phonetic_text", help="Column in metadata containing phonetic text.")
    parser.add_argument("--target_duration", type=float, default=None, help="Target duration **per gender** in seconds.")
    parser.add_argument("--duration_order", type=str, default="random", choices=["random", "min", "max"], help="Ordering method: 'random' (default), 'min' (ascending), 'max' (descending).")
    parser.add_argument("--max_duration", type=float, default=None, help="Skip files with duration greater than this value (in seconds).")
    parser.add_argument(
        "--filtered_train_csv",
        type=str,
        default="filtered_train_metadata.csv",
        help="Where to save the training metadata after --target_duration selection.",
    )
    args = parser.parse_args()

    # isfile (not exists) so a directory path is reported here rather than
    # crashing inside read_csv.
    if not os.path.isfile(args.metadata_csv):
        sys.exit(f"Error: Metadata CSV file '{args.metadata_csv}' not found.")
    df = pd.read_csv(args.metadata_csv)

    # Validate required columns exist; report every missing one before exiting.
    required_columns = ["file_name", "split", args.text_field, "gender", "duration"]
    missing_columns = [col for col in required_columns if col not in df.columns]
    for col in missing_columns:
        print(f"Error: Required column '{col}' not found in metadata.", file=sys.stderr)
    if missing_columns:
        sys.exit(1)

    # Filter out files with duration greater than max_duration (applies to both splits).
    # Rows with a missing duration also fail the comparison and are dropped.
    if args.max_duration is not None:
        original_count = len(df)
        df = df[df["duration"] <= args.max_duration]
        filtered_count = original_count - len(df)
        print(f"Filtered out {filtered_count} files with duration > {args.max_duration} sec.")

    df_train = df[df["split"] == args.train_split]
    df_val = df[df["split"] == args.val_split]
    print(f"Total training entries: {len(df_train)}")
    print(f"Total validation entries: {len(df_val)}")
    # A split name that matches nothing usually means a typo or a dtype mismatch.
    for label, frame, split_name in (("training", df_train, args.train_split), ("validation", df_val, args.val_split)):
        if frame.empty:
            print(f"Warning: no {label} rows have split == '{split_name}'.", file=sys.stderr)

    # Apply gender-balanced selection **only for training set**, if requested.
    if args.target_duration is not None:
        print(f"Applying gender-balanced selection for training set: {args.target_duration} sec per gender, sorted by '{args.duration_order}'.")
        df_train = select_per_gender(df_train, args.target_duration, args.duration_order)
        print(f"After filtering, training entries: {len(df_train)}")
        # Save filtered metadata for verification.
        filtered_dir = os.path.dirname(args.filtered_train_csv)
        if filtered_dir:
            os.makedirs(filtered_dir, exist_ok=True)
        df_train.to_csv(args.filtered_train_csv, index=False)
        print(f"Saved filtered training metadata to {args.filtered_train_csv}")

    # Write the training and validation list files.
    write_list_file(df_train, args.train_list, args.text_field)
    write_list_file(df_val, args.val_list, args.text_field)


if __name__ == "__main__":
    main()