"""All-in-one Gradio web UI for StableTTS.

Provides tabs for dataset preprocessing, training, TensorBoard control and
inference. Importing this module creates the working directories and downloads
the pretrained checkpoint and vocoders when they are missing.
"""
import gc
import json
import os
import platform
import random
import re
import shutil
import signal
import sys
import wave
from dataclasses import asdict, dataclass
from glob import glob
from pathlib import Path
from subprocess import Popen
from typing import List, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import psutil
import requests
import torch
import torchaudio
from tqdm import tqdm

# get function for repo
from utils.audio import LogMelSpectrogram, load_and_resample_audio
from preprocess import g2p_mapping, load_filelist
from api import StableTTSAPI
from config import MelConfig

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

system = platform.system()
python_executable = sys.executable or "python"

supported_languages = list(g2p_mapping.keys())

os.makedirs('./runs', exist_ok=True)
os.makedirs('./stableTTS_datasets', exist_ok=True)
os.makedirs("./checkpoints/pretrain", exist_ok=True)

# Process-wide state shared by the UI callbacks.
model_stable_tts = None
training_process = None
tenserboard_process = None

# settings
tts_model_path = './checkpoints/checkpoint_0.pt'
vocoder_model_path = './vocoders/pretrained/firefly-gan-base-generator.ckpt'
vocoder_type = 'ffgan'
current_model = ""

# models
pretrained_model_path = os.path.join("./checkpoints/pretrain", "checkpoint_0.pt")
url_model = "https://huggingface.co/KdaiP/StableTTS1.1/resolve/main/StableTTS/checkpoint_0.pt"

vocos_model_path = './vocoders/pretrained/vocos.pt'
url_vocos = "https://huggingface.co/KdaiP/StableTTS1.1/resolve/main/vocoders/vocos.pt"

firefly_model_path = './vocoders/pretrained/firefly-gan-base-generator.ckpt'
url_firefly = 'https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt'


# check model and models
def download_model_file(url: str, local_filename: str):
    """Stream ``url`` to ``local_filename`` while showing a tqdm progress bar.

    The payload is written to ``<local_filename>.part`` and renamed once the
    transfer completes, so an interrupted download never leaves a truncated
    file that the ``os.path.isfile`` checks below would accept as valid.
    A non-200 response is reported on stdout and nothing is written.
    """
    target_dir = os.path.dirname(local_filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    partial_filename = f"{local_filename}.part"

    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            print(f"Failed to download file. Status code: {response.status_code}")
            return

        print(f"Downloading: {local_filename}")

        total_size = int(response.headers.get('content-length', 0))

        with open(partial_filename, 'wb') as file, tqdm(
            total=total_size, unit='B', unit_scale=True, unit_divisor=1024
        ) as progress_bar:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
                progress_bar.update(len(chunk))

    os.replace(partial_filename, local_filename)
    print(f"File downloaded and saved as {local_filename}")


if not os.path.isfile(pretrained_model_path):
    download_model_file(url_model, pretrained_model_path)

if not os.path.isfile(vocos_model_path):
    download_model_file(url_vocos, vocos_model_path)

if not os.path.isfile(firefly_model_path):
    download_model_file(url_firefly, firefly_model_path)


# terminal
def terminate_process_tree(pid, including_parent=True):
    """Send SIGTERM to every descendant of ``pid`` and, optionally, to ``pid`` itself."""
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        # Process already terminated
        return

    children = parent.children(recursive=True)
    for child in children:
        try:
            os.kill(child.pid, signal.SIGTERM)  # or signal.SIGKILL
        except OSError:
            pass
    if including_parent:
        try:
            os.kill(parent.pid, signal.SIGTERM)  # or signal.SIGKILL
        except OSError:
            pass


def terminate_process(pid):
    """Terminate ``pid`` and its children (``taskkill`` on Windows, SIGTERM elsewhere)."""
    if system == "Windows":
        cmd = f"taskkill /t /f /pid {pid}"
        os.system(cmd)
    else:
        terminate_process_tree(pid)


# preprocess
def copy_speakers(input_file_list, destination_dir, num_files=3):
    """Copy the ``num_files`` longest WAV files of a processed file list.

    ``input_file_list`` is a JSON-lines file whose entries carry an
    ``audio_path`` key. Nothing is copied when ``destination_dir`` already
    contains ``.wav`` files or when the file list is missing or empty.
    Files that the ``wave`` module cannot read are skipped.
    """
    # The caller creates the directory beforehand, so "directory exists" cannot
    # be used as the "already copied" marker; look for existing clips instead.
    os.makedirs(destination_dir, exist_ok=True)
    if glob(os.path.join(destination_dir, "*.wav")):
        return

    if not os.path.isfile(input_file_list):
        return

    with open(input_file_list, 'r', encoding='utf-8') as file:
        data = [json.loads(line) for line in file if line.strip()]

    wav_files = [item["audio_path"] for item in data]
    if not wav_files:
        return

    file_durations = []
    for wav_file in wav_files:
        try:
            with wave.open(wav_file, 'r') as wav_obj:
                frames = wav_obj.getnframes()
                rate = wav_obj.getframerate()
        except (wave.Error, EOFError, OSError) as error:
            print(f'Skipping {wav_file}: {str(error)}')
            continue
        if rate:
            file_durations.append((wav_file, frames / float(rate)))

    file_durations.sort(key=lambda x: x[1], reverse=True)

    longest_files = [file for file, _ in file_durations[:num_files]]

    for wav_file in longest_files:
        shutil.copyfile(wav_file, os.path.join(destination_dir, os.path.basename(wav_file)))


def create_project(project_name: str):
    """Create ``./filelists/<project_name>`` and return the project's input list, output list and feature dir."""
    os.makedirs(f'./filelists/{project_name}', exist_ok=True)
    input_file_list = f'./filelists/{project_name}/filelist.txt'
    output_file_list = f'./filelists/{project_name}/filelist.json'
    output_feature_dir = f'./stableTTS_datasets/{project_name}'
    return [input_file_list, output_file_list, output_feature_dir]


def get_project_files(project_name: str):
    """Return the file list, run directory and checkpoint directory paths of a project."""
    output_file_list = f'./filelists/{project_name}/filelist.json'
    run_dir = f'./runs/{project_name}'
    checkpoint_dir = f'./checkpoints/{project_name}'
    return [output_file_list, run_dir, checkpoint_dir]


def get_projects(folder_path=r'./filelists') -> List[str]:
    """List sub-folders of ``folder_path`` that contain a ``filelist.json``."""
    if not os.path.isdir(folder_path):
        return []
    return [
        folder
        for folder in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, folder, "filelist.json"))
    ]


def refresh_projects() -> Tuple[List[str], str]:
    """Return all preprocessed projects and the first one (``None`` when there are none)."""
    projects = get_projects()
    first_project = projects[0] if projects else None
    return projects, first_project


def export_language_value(file_path, language_value):
    """Write ``{"language": language_value}`` as JSON to ``file_path``."""
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump({"language": language_value}, file, ensure_ascii=False, indent=4)


def import_language_value(file_path):
    """Read the language stored by ``export_language_value``; ``"chinese"`` if the file is missing."""
    if not os.path.isfile(file_path):
        return "chinese"
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data["language"]


def preprocess_audio_files(input_file_list: str, output_feature_dir: str, output_file_list: str, language: str, copy_speaker: bool, progress=gr.Progress()):
    """Extract mel spectrograms and phonemes for every entry of a file list.

    Mels are saved under ``<output_feature_dir>/mels`` and one JSON line per
    successfully processed entry is written to ``output_file_list``. The
    chosen language is stored in ``checkpoints/<project>/speakers/lag.json``
    and, when ``copy_speaker`` is set, reference clips are copied next to it.

    Returns a status message for the UI.
    """
    if not os.path.isfile(input_file_list):
        return f"No such file or directory: '{input_file_list}'"

    text_to_phoneme = g2p_mapping.get(language)
    if text_to_phoneme is None:
        # Without this guard every entry would fail with a 'NoneType' call error.
        return f"Unsupported language: '{language}'"

    mel_config = MelConfig()

    mel_extractor = LogMelSpectrogram(**asdict(mel_config)).to(device)

    output_mel_dir = os.path.join(output_feature_dir, 'mels')
    os.makedirs(output_mel_dir, exist_ok=True)
    os.makedirs(os.path.dirname(output_file_list), exist_ok=True)

    @torch.inference_mode()
    def process_audio_file(line) -> str:
        # Returns the JSON line for one entry, or None when it could not be processed.
        idx, audio_path, text = line
        audio = load_and_resample_audio(audio_path, mel_config.sample_rate, device=device)
        if audio is not None:
            audio_name, _ = os.path.splitext(os.path.basename(audio_path))

            try:
                phonemes = text_to_phoneme(text)
                if phonemes:
                    mel = mel_extractor(audio.to(device)).cpu().squeeze(0)
                    output_mel_path = os.path.join(output_mel_dir, f'{idx}_{audio_name}.pt')
                    torch.save(mel, output_mel_path)

                    return json.dumps({'mel_path': output_mel_path, 'phone': phonemes, 'audio_path': audio_path, 'text': text, 'mel_length': mel.size(-1)}, ensure_ascii=False, allow_nan=False)
            except Exception as e:
                print(f'Error processing {audio_path}: {str(e)}')

    input_file_list_data = load_filelist(input_file_list)
    processed_files = []

    for line in progress.tqdm(input_file_list_data, desc="Processing files"):
        result = process_audio_file(line)
        if result:
            processed_files.append(f'{result}\n')

    with open(output_file_list, 'w', encoding='utf-8') as f:
        f.writelines(processed_files)

    # The project name is the folder that holds the input file list.
    path = Path(input_file_list)
    directory_name = path.parent.name
    directory_name = os.path.join("checkpoints", directory_name, "speakers")
    os.makedirs(directory_name, exist_ok=True)

    if copy_speaker:
        copy_speakers(output_file_list, directory_name, num_files=3)

    export_language_value(os.path.join(directory_name, "lag.json"), language)

    return f"File list has been saved to {output_file_list}"


# train
def save_config(config, file_path):
    """Dump ``config`` as indented JSON to ``file_path``."""
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(config, file, indent=4)
    print(f"Configuration saved to {file_path}")


def load_config(file_path):
    """Load and return the JSON document stored at ``file_path``."""
    with open(file_path, 'r', encoding='utf-8') as file:
        config = json.load(file)
    return config


def get_config_data(project_name):
    """Return the saved training settings of a project for the UI controls.

    The tuple is (batch_size, learning_rate, num_epochs, log_interval,
    save_interval, warmup_steps). Defaults are returned when no project is
    selected or no ``config.json`` exists. ``num_epochs`` is stored with the
    +1 that ``train_model`` adds, so it is decremented here.
    """
    if project_name is None:
        return 16, 0.0001, 200, 16, 1, 200
    config_dir = os.path.join(r'./filelists', project_name)
    config_file_json = os.path.join(config_dir, "config.json")
    if not os.path.isfile(config_file_json):
        return 16, 0.0001, 200, 16, 1, 200
    data = load_config(config_file_json)

    return data["batch_size"], data["learning_rate"], data["num_epochs"] - 1, data["log_interval"], data["save_interval"], data["warmup_steps"]


def create_training_config(config_file, config_file_json,
                           train_dataset_path: str = 'filelists/filelist.json',
                           test_dataset_path: str = 'filelists/filelist.json',
                           batch_size: int = 32,
                           learning_rate: float = 1e-4,
                           num_epochs: int = 10000,
                           model_save_path: str = './checkpoints',
                           log_dir: str = './runs',
                           log_interval: int = 16,
                           save_interval: int = 1,
                           warmup_steps: int = 200,
                           ):
    """Write a ``config.py`` module to ``config_file`` and the same training settings as JSON to ``config_file_json``."""
    config_content = f"""
from dataclasses import dataclass

@dataclass
class MelConfig:
    sample_rate: int = 44100
    n_fft: int = 2048
    win_length: int = 2048
    hop_length: int = 512
    f_min: float = 0.0
    f_max: float = None
    pad: int = 0
    n_mels: int = 128
    center: bool = False
    pad_mode: str = "reflect"
    mel_scale: str = "slaney"

    def __post_init__(self):
        if self.pad == 0:
            self.pad = (self.n_fft - self.hop_length) // 2

@dataclass
class ModelConfig:
    hidden_channels: int = 256
    filter_channels: int = 1024
    n_heads: int = 4
    n_enc_layers: int = 3
    n_dec_layers: int = 6
    kernel_size: int = 3
    p_dropout: int = 0.1
    gin_channels: int = 256

@dataclass
class TrainConfig:
    train_dataset_path: str = "{train_dataset_path}"
    test_dataset_path: str = "{test_dataset_path}"
    batch_size: int = {batch_size}
    learning_rate: float = {learning_rate}
    num_epochs: int = {num_epochs}
    model_save_path: str = "{model_save_path}"
    log_dir: str = "{log_dir}"
    log_interval: int = {log_interval}
    save_interval: int = {save_interval}
    warmup_steps: int = {warmup_steps}

@dataclass
class VocosConfig:
    input_channels: int = 128
    dim: int = 512
    intermediate_dim: int = 1536
    num_layers: int = 8
"""
    # Python reads source files as UTF-8, so write the generated module as UTF-8
    # regardless of the platform default (project paths may be non-ASCII).
    with open(config_file, "w", encoding='utf-8') as f:
        f.write(config_content)

    config = {
        "train_dataset_path": train_dataset_path,
        "test_dataset_path": test_dataset_path,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "num_epochs": num_epochs,
        "model_save_path": model_save_path,
        "log_dir": log_dir,
        "log_interval": log_interval,
        "save_interval": save_interval,
        "warmup_steps": warmup_steps
    }

    save_config(config, config_file_json)


def train_model(train_dataset_path: str, batch_size: int, learning_rate: float, num_epochs: int, model_save_path: str, log_dir: str, log_interval: int, save_interval: int, warmup_steps: int, use_finetune: bool, pretrain_model: str = r"./checkpoints/pretrain/checkpoint_0.pt"):
    """Generate the project config, install it as ``./config.py`` and run training.

    Generator used as a Gradio callback: yields (status, start-button update,
    stop-button update) when training starts and again when
    ``start_training`` returns. With ``use_finetune`` the pretrained
    checkpoint is copied into ``model_save_path`` unless a
    ``checkpoint_0.pt`` is already there.
    """
    config_dir = os.path.dirname(train_dataset_path)
    config_file = os.path.join(config_dir, "config.py")
    config_file_json = os.path.join(config_dir, "config.json")

    create_training_config(config_file, config_file_json, train_dataset_path, "", batch_size, learning_rate, num_epochs + 1, model_save_path, log_dir, log_interval, save_interval, warmup_steps)

    if os.path.isfile(config_file):
        shutil.copy(config_file, "config.py")

    if use_finetune:
        os.makedirs(model_save_path, exist_ok=True)
        finetune_model = os.path.join(model_save_path, "checkpoint_0.pt")
        if not os.path.isfile(finetune_model):
            shutil.copy(pretrain_model, finetune_model)

    yield "Training started !", gr.update(interactive=False), gr.update(interactive=True)

    # Drop the inference model before the training subprocess starts.
    clear_model()
    start_training()
    yield "Training finish !", gr.update(interactive=True), gr.update(interactive=False)


def clear_model():
    """Drop the loaded inference model and release cached CUDA memory."""
    global model_stable_tts
    global current_model
    if model_stable_tts is not None:
        model_stable_tts = None
        gc.collect()
        torch.cuda.empty_cache()
    # Forget the loaded path so update_model reloads on the next request.
    current_model = ""


def start_training():
    """Run ``train.py`` with the current interpreter and block until it exits.

    Returns a message, without starting anything, when a training process is
    already registered.
    """
    global training_process
    if training_process is not None:
        return "Train run already!"

    # Argument list instead of a shell string: interpreter paths with spaces work.
    process = Popen([python_executable, "train.py"])
    training_process = process
    try:
        process.wait()
    finally:
        # Allow a new run afterwards; stop_training may already have cleared it.
        if training_process is process:
            training_process = None


def stop_training():
    """Terminate the running training process; returns (status, start-button, stop-button) updates."""
    global training_process
    if training_process is None:
        return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
    terminate_process_tree(training_process.pid)
    training_process = None
    return "Training cancel !", gr.update(interactive=True), gr.update(interactive=False)


def refresh_dropdown_train():
    """Rebuild the training project dropdown from the projects on disk."""
    names, select = refresh_projects()
    if select == "":
        select = None
    return gr.Dropdown(choices=names, value=select, label="Project")


def refresh_train_stage(project_name):
    """Return interactivity updates for the start button and the two finetune controls.

    The start button is enabled only when a project is selected; the finetune
    controls are disabled once the project already has a ``config.json``.
    """
    bt_train = button_disable() if project_name is None else button_enable()
    if project_name is not None:
        config_dir = os.path.join(r'./filelists', project_name)
        config_file_json = os.path.join(config_dir, "config.json")
        value = button_disable() if os.path.isfile(config_file_json) else button_enable()
    else:
        value = button_enable()
    return bt_train, value, value


# tensorboard
def start_tensorboard(log_dir: str, port: int = 6006):
    """Launch TensorBoard for ``log_dir`` and block until the process exits.

    Generator used as a Gradio callback; every yielded tuple is (status,
    start-button, stop-button, open-button). All exits yield their status:
    a ``return value`` inside a generator would never reach the UI.
    """
    global tenserboard_process
    port = int(port)

    if tenserboard_process is not None:
        yield f"Tensorboard run on port {port}", gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=True)
        return

    try:
        # Argument list: paths with spaces work and nothing is shell-interpreted.
        process = Popen(["tensorboard", "--logdir", log_dir, "--port", str(port)])
    except Exception as e:
        # Nothing is running, so leave the start button usable for a retry.
        yield f"Failed to start TensorBoard: {str(e)}", gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)
        return

    tenserboard_process = process
    yield f"TensorBoard started. Open http://localhost:{port} to view.", gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=True)
    try:
        process.wait()
    finally:
        # Forget the handle once TensorBoard has exited so it can be restarted.
        if tenserboard_process is process:
            tenserboard_process = None


def stop_tensorboard():
    """Terminate the TensorBoard process; yields (status, start, stop, open) button updates."""
    global tenserboard_process
    if tenserboard_process is None:
        yield "Tensorboard not run !", gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)
        return

    try:
        terminate_process_tree(tenserboard_process.pid)
        tenserboard_process = None
        yield "Tensorboard stopped", gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)
    except Exception as e:
        yield f"Failed to stop TensorBoard: {str(e)}", gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)


def get_tensorboard_projects(folder_path=r'./runs') -> List[str]:
    """List the entries of the runs folder."""
    return os.listdir(folder_path)


def refresh_tensorboard_projects() -> Tuple[List[str], str]:
    """Return all run folders and the first one (``""`` when there are none)."""
    projects = get_tensorboard_projects()
    first_project = projects[0] if projects else ""
    return projects, first_project


def refresh_tensorboard_stage():
    """Enable the start button only when at least one run exists; stop/open stay disabled."""
    projects = get_tensorboard_projects()
    if projects == []:
        return button_disable(), button_disable(), button_disable()
    return button_enable(), button_disable(), button_disable()


def get_tensorboard_log_dir(project_name="", folder_path=r'./runs') -> str:
    """Return ``<folder_path>/<project_name>`` (``None`` is treated as an empty name)."""
    if project_name is None:
        project_name = ""
    return f"{folder_path}/{project_name}"


def refresh_dropdown_tensorboard(name=None):
    """Rebuild the TensorBoard project dropdown from the runs on disk; ``name`` is unused."""
    names, select = refresh_tensorboard_projects()
    if select == "":
        select = None
    return gr.Dropdown(choices=names, value=select, label="Project")


# interface
def update_model(tts_model_path: str, vocoder_model_path: str, vocoder_type: str):
    """Load the requested checkpoint unless that same checkpoint is already loaded."""
    global model_stable_tts
    global current_model
    if model_stable_tts is None or current_model != tts_model_path:
        # Free the previous model first so two models never share the device.
        clear_model()
        model_stable_tts = StableTTSAPI(tts_model_path, vocoder_model_path, vocoder_type).to(device)
        current_model = tts_model_path


# Function to extract the number from the file name
def get_checkpoint_number(filename):
    """Return the first integer found in ``filename``, or -1 when there is none."""
    match = re.search(r'(\d+)', filename)
    return int(match.group()) if match else -1


def get_checkpoints(folder: str) -> Tuple[List[str], str]:
    """Return the checkpoint file names of a project, sorted by number, and the first one."""
    if not folder:
        return [], ""

    checkpoints_path = os.path.join('checkpoints', folder, "*.pt")
    checkpoints = glob(checkpoints_path)
    checkpoint_names = [os.path.basename(item) for item in checkpoints if 'checkpoint' in os.path.basename(item)]
    checkpoint_names = sorted(checkpoint_names, key=get_checkpoint_number)

    if checkpoint_names:
        return checkpoint_names, checkpoint_names[0]

    return [], ""


def get_speakers(folder: str) -> Tuple[List[str], str]:
    """Return the reference WAV names in ``checkpoints/<folder>/speakers`` and the first one."""
    if not folder:
        return [], ""
    speakes_path = os.path.join('checkpoints', folder, "speakers")
    if not os.path.isdir(speakes_path):
        return [], ""
    speakers = glob(os.path.join(speakes_path, "*.wav"))
    speaker_names = [os.path.basename(item) for item in speakers]

    if speaker_names:
        return speaker_names, speaker_names[0]

    return [], ""


def refresh_dropdown_speakers(folder):
    """Return a rebuilt speaker dropdown together with the selected speaker name."""
    names, select = get_speakers(folder)
    if select == "":
        select = None
    return gr.Dropdown(choices=names, value=select, label="Speaker"), select


def select_speaker(folder, filename):
    """Return the path of a project's speaker file, or ``None`` when it is not available."""
    if folder is None or not filename:
        return None
    speakes_path = os.path.join('checkpoints', folder, "speakers")
    if not os.path.isdir(speakes_path):
        return None
    filewave = os.path.join(speakes_path, filename)
    if not os.path.isfile(filewave):
        return None
    return filewave


def refresh_dropdown_checkpoints(folder):
    """Return a rebuilt checkpoint dropdown and the language stored for the project."""
    lag = "chinese"
    if folder is None:
        return gr.Dropdown(choices=[], value=None, label="Checkpoint"), lag
    file_lag = os.path.join('./checkpoints', folder, "speakers", 'lag.json')
    lag = import_language_value(file_lag)
    names, select = get_checkpoints(folder)
    return gr.Dropdown(choices=names, value=select, label="Checkpoint"), lag


def refresh_dropdown():
    """Rebuild the inference project dropdown from the checkpoints on disk."""
    names, select = refresh_projects_interface()
    return gr.Dropdown(choices=names, value=select, label="Project")


def get_projects_interface(folder_path=r'./checkpoints') -> List[str]:
    """List the entries of ``folder_path`` for which ``get_checkpoints`` finds at least one checkpoint."""
    folders = []
    for folder in os.listdir(folder_path):
        names, _ = get_checkpoints(folder)
        if names != []:
            folders.append(folder)
    return folders


def refresh_projects_interface() -> Tuple[List[str], str]:
    """Return all projects usable for inference and the first one (``None`` when there are none)."""
    projects = get_projects_interface()
    first_project = projects[0] if projects else None
    return projects, first_project


def _synthesize(text, ref_audio, language, step, temperature, length_scale, solver, cfg):
    """Run the loaded model and return ((sample_rate, int16 audio), mel figure)."""
    text = remove_newlines_after_punctuation(text)

    if language == 'chinese':
        text = text.replace(' ', '')

    audio, mel = model_stable_tts.inference(text, ref_audio, language, step, temperature, length_scale, solver, cfg)

    # Scale down only when the waveform exceeds [-1, 1].
    max_val = torch.max(torch.abs(audio))
    if max_val > 1:
        audio = audio / max_val

    audio_output = (model_stable_tts.mel_config.sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16))
    mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy())

    return audio_output, mel_output


@torch.inference_mode()
def generate_speech(text, ref_audio, language, step, temperature, length_scale, solver, cfg):
    """Synthesize ``text`` with the currently loaded model; returns (audio, mel figure)."""
    return _synthesize(text, ref_audio, language, step, temperature, length_scale, solver, cfg)


def plot_mel_spectrogram(mel_spectrogram):
    """Render a mel spectrogram as a borderless matplotlib figure."""
    plt.close()  # prevent memory leak
    fig, ax = plt.subplots(figsize=(20, 8))
    ax.imshow(mel_spectrogram, aspect='auto', origin='lower')
    plt.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)  # remove white edges
    return fig


def remove_newlines_after_punctuation(text):
    """Remove a newline that directly follows one of the listed punctuation characters."""
    pattern = r"""([,。!?、""《》【】;:,.!?'"<>()\[\]{}])\n"""
    return re.sub(pattern, r'\1', text)


def set_seed(seed):
    """Seed Python, NumPy and torch; ``-1`` picks a random seed. Returns the seed used."""
    seed = int(seed)
    seed = seed if seed != -1 else random.randrange(1 << 32)
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    try:
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cuda.matmul.allow_tf32 = False
            torch.backends.cudnn.allow_tf32 = False
    except Exception as e:
        # Best effort: CUDA seeding problems must not block generation.
        print(f'Could not seed CUDA: {str(e)}')
    return seed


@torch.inference_mode()
def generate_tts(folder, checkpoint, text, ref_audio, language, step, temperature, length_scale, solver, cfg, seed=0, random=True):
    """Load the selected checkpoint if needed and synthesize ``text``.

    Returns (audio, mel figure, seed used), or (None, None, seed) when the
    project, checkpoint or reference audio is missing. With ``random`` set a
    fresh seed is drawn instead of using ``seed``.
    """
    if folder is None or checkpoint is None or ref_audio is None:
        return None, None, seed

    if random:
        seed = -1
    seed = set_seed(seed)

    update_model(os.path.join("checkpoints", folder, checkpoint), vocoder_model_path, vocoder_type)

    audio_output, mel_output = _synthesize(text, ref_audio, language, step, temperature, length_scale, solver, cfg)

    return audio_output, mel_output, seed


def check_and_download(pretrained_model_path, vocos_model_path, firefly_model_path):
    """Download whichever model files are missing, yielding a UI update before each download and at the end."""
    footer_message = ""

    if not os.path.isfile(pretrained_model_path):
        footer_message = "Downloading Pretrained Model... "
        print(footer_message)
        yield gr.update(value=f"")
        download_model_file(url_model, pretrained_model_path)

    if not os.path.isfile(vocos_model_path):
        footer_message = "Downloading Vocos Model... "
        print(footer_message)
        yield gr.update(value=f"")
        download_model_file(url_vocos, vocos_model_path)

    if not os.path.isfile(firefly_model_path):
        footer_message = "Downloading Firefly Model... "
        print(footer_message)
        yield gr.update(value=f"")
        download_model_file(url_firefly, firefly_model_path)

    footer_message += "All models downloaded!"
    yield gr.update(value=f"")


def get_file():
    """Return the result of ``gr.File.file_select()``."""
    # NOTE(review): confirm gr.File.file_select exists in the installed Gradio version.
    file_path = gr.File.file_select()
    return file_path


def launch_tensorboard(port):
    """Open the local TensorBoard page in a new browser tab and return a status message."""
    import webbrowser
    url = f"http://localhost:{int(port)}"
    webbrowser.open(url, new=2)
    return f"TensorBoard launched at {url}"


def button_enable():
    """Return a Gradio update that makes a component interactive."""
    return gr.update(interactive=True)


def button_disable():
    """Return a Gradio update that makes a component non-interactive."""
    return gr.update(interactive=False)


def random_sample(nameproject):
    """Return the (text, audio_path) of a random entry from the project's ``filelist.json``.

    Returns ``("", None)`` when no project is selected or the list is missing or empty.
    """
    if nameproject is None:
        return "", None
    filelist_file = os.path.join('filelists', nameproject, "filelist.json")

    if not os.path.isfile(filelist_file):
        return "", None

    with open(filelist_file, 'r', encoding='utf-8') as file:
        data = [json.loads(line) for line in file if line.strip()]

    if not data:
        return "", None

    entry = random.choice(data)

    return entry["text"], entry["audio_path"]


def enable_button_setting(value):
    """Return a Gradio update whose interactivity follows ``value``."""
    return gr.update(interactive=value)


def create_interface():
    """Build and return the Gradio Blocks application."""
    # NOTE(review): the tab/row nesting below was reconstructed from a
    # whitespace-collapsed source - confirm the hierarchy against the intended layout.
    with gr.Blocks() as app:
        gui_title = 'StableTTS ALL IN ONE'
        gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""
        example_text = """你指尖跳动的电光,是我永恒不变的信仰。唯我超电磁炮永世长存!"""

        with gr.Row():
            with gr.Column():
                gr.Markdown(f"# {gui_title}")
                gr.Markdown(gui_description)

        with gr.Tabs():

            with gr.TabItem("Train"):

                with gr.TabItem("Preprocess Data"):
                    with gr.Row():
                        project_name = gr.Textbox(label="Project Name", value='test')
                        create_project_btn = gr.Button("Create")

                    with gr.Row():
                        input_file_list = gr.Textbox(label="Input File List Path", value='./filelists/filelist.txt', interactive=False)
                        output_file_list = gr.Textbox(label="Output File List Path", value='./filelists/filelist.json', interactive=False)
                        output_feature_dir = gr.Textbox(label="Output Feature Directory", value='./stableTTS_datasets', interactive=False)

                    language = gr.Dropdown(label="Language", choices=supported_languages, value="chinese")

                    copy_speaker = gr.Checkbox(label="Copy 3 audio files for use as references.", value=True)
                    preprocess_output = gr.Textbox(label="Preprocess Output", lines=4)

                    preprocess_btn = gr.Button("Preprocess Data", interactive=False)

                    preprocess_btn.click(
                        fn=preprocess_audio_files,
                        inputs=[input_file_list, output_feature_dir, output_file_list, language, copy_speaker],
                        outputs=preprocess_output
                    )

                    create_project_btn.click(fn=create_project, inputs=[project_name], outputs=[input_file_list, output_file_list, output_feature_dir])
                    create_project_btn.click(fn=button_enable, outputs=[preprocess_btn])

                with gr.TabItem("Train Model"):
                    initial_projects = get_projects()
                    initial_project = initial_projects[0] if initial_projects else None

                    if initial_project is not None:
                        train_value, log_value, model_value = get_project_files(initial_project)
                    else:
                        train_value, log_value, model_value = './filelists/filelist.json', './runs', './checkpoints'

                    with gr.Row():
                        project_dropdown = gr.Dropdown(choices=initial_projects, value=initial_project, label="Project", interactive=True, allow_custom_value=True)
                        refresh_projects_btn = gr.Button("Refresh Projects")

                    with gr.Row():
                        train_dataset_path = gr.Textbox(label="Train Dataset Path", value=train_value, interactive=False)
                        log_dir = gr.Textbox(label="Log Directory", value=log_value, interactive=False)
                        model_save_path = gr.Textbox(label="Model Save Path", value=model_value, interactive=False)

                    with gr.Row():
                        batch_size = gr.Slider(label="Batch Size", minimum=1, maximum=128, step=1, value=16)
                        log_interval = gr.Slider(label="Log Interval", minimum=1, maximum=100, step=1, value=16)
                        warmup_steps = gr.Slider(label="Warmup Steps", minimum=1, maximum=10000, step=1, value=200)

                    with gr.Row():
                        num_epochs = gr.Slider(label="Number of Epochs", minimum=1, maximum=10000, step=1, value=200)
                        save_interval = gr.Slider(label="Save Interval", minimum=1, maximum=100, step=1, value=1)
                        learning_rate = gr.Number(label="Learning Rate", value=1e-4)

                    with gr.Row():
                        use_finetune = gr.Checkbox(label="Use Finetune", value=True)
                        pretrain_model_path = gr.Textbox(label="Pretrain Model Path", value=pretrained_model_path, interactive=False)
                        use_finetune.change(enable_button_setting, inputs=[use_finetune], outputs=[pretrain_model_path])

                    train_output = gr.Textbox(label="Training Output", lines=4)

                    with gr.Row():

                        train_start_model_btn = gr.Button("Start Train", interactive=False)
                        train_stop_model_btn = gr.Button("Stop Train", interactive=False)

                    refresh_projects_btn.click(fn=get_config_data, inputs=project_dropdown, outputs=[batch_size, learning_rate, num_epochs, log_interval, save_interval, warmup_steps])
                    refresh_projects_btn.click(fn=refresh_projects, outputs=[project_dropdown, project_dropdown])
                    refresh_projects_btn.click(fn=refresh_dropdown_train, outputs=[project_dropdown])

                    project_dropdown.change(fn=get_config_data, inputs=project_dropdown, outputs=[batch_size, learning_rate, num_epochs, log_interval, save_interval, warmup_steps])
                    project_dropdown.change(fn=get_project_files, inputs=project_dropdown, outputs=[train_dataset_path, log_dir, model_save_path])

                    refresh_projects_btn.click(fn=refresh_train_stage, inputs=[project_dropdown], outputs=[train_start_model_btn, pretrain_model_path, use_finetune])

                    train_stop_model_btn.click(fn=stop_training, outputs=[train_output, train_start_model_btn, train_stop_model_btn])
                    train_start_model_btn.click(
                        fn=train_model,
                        inputs=[train_dataset_path, batch_size, learning_rate, num_epochs, model_save_path, log_dir, log_interval, save_interval, warmup_steps, use_finetune, pretrain_model_path],
                        outputs=[train_output, train_start_model_btn, train_stop_model_btn])

                with gr.TabItem("Tensorboard"):
                    with gr.Row():
                        initial_tensorboard_projects = get_tensorboard_projects()
                        initial_tensorboard_project = initial_tensorboard_projects[0] if initial_tensorboard_projects else None
                        tensorboard_project_dropdown = gr.Dropdown(choices=initial_tensorboard_projects, value=initial_tensorboard_project, label="Project", interactive=True, allow_custom_value=True)
                        refresh_tensorboard_btn = gr.Button("Refresh Tensorboard")

                    if initial_tensorboard_project is not None:
                        log_value = get_tensorboard_log_dir(initial_tensorboard_project)
                    else:
                        log_value = './runs'

                    tensorboard_log_path = gr.Textbox(label="Tensorboard Log Directory", value=log_value, interactive=False)

                    with gr.Row():
                        port_tensorboard = gr.Number(label="Port", value=6006)
                        start_tensorboard_btn = gr.Button("Start TensorBoard", interactive=False)
                        stop_tensorboard_btn = gr.Button("Stop TensorBoard", interactive=False)
                        open_tensorboard_btn = gr.Button("Open TensorBoard", interactive=False)

                    tensorboard_output = gr.Textbox(label="TensorBoard Output", lines=2)

                    start_tensorboard_btn.click(
                        fn=start_tensorboard,
                        inputs=[tensorboard_log_path, port_tensorboard],
                        outputs=[tensorboard_output, start_tensorboard_btn, stop_tensorboard_btn, open_tensorboard_btn],
                    )

                    stop_tensorboard_btn.click(
                        fn=stop_tensorboard,
                        outputs=[tensorboard_output, start_tensorboard_btn, stop_tensorboard_btn, open_tensorboard_btn]
                    )

                    open_tensorboard_btn.click(
                        fn=launch_tensorboard,
                        inputs=[port_tensorboard],
                        outputs=[tensorboard_output],
                    )

                    refresh_tensorboard_btn.click(fn=refresh_tensorboard_projects, outputs=[tensorboard_project_dropdown, tensorboard_project_dropdown])
                    # Rebuilding the choices belongs on the refresh button (as in the
                    # Train tab); on the dropdown's own change event it would reset
                    # every selection the user makes.
                    refresh_tensorboard_btn.click(fn=refresh_dropdown_tensorboard, outputs=[tensorboard_project_dropdown])
                    refresh_tensorboard_btn.click(fn=refresh_tensorboard_stage, outputs=[start_tensorboard_btn, stop_tensorboard_btn, open_tensorboard_btn])
                    tensorboard_project_dropdown.change(fn=get_tensorboard_log_dir, inputs=tensorboard_project_dropdown, outputs=[tensorboard_log_path])

            with gr.TabItem("Interface"):
                with gr.Blocks(theme=gr.themes.Base()) as demo:
                    demo.load(None, None, js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}")

                    with gr.Row():
                        with gr.Column():
                            initial_projects_interface, initial_project_interface = refresh_projects_interface()
                            initial_checkpoints, initial_checkpoint = get_checkpoints(initial_project_interface)

                            with gr.Row():
                                model_project_dropdown = gr.Dropdown(choices=initial_projects_interface, value=initial_project_interface, label="Project", interactive=True, allow_custom_value=True)
                                model_checkpoint_dropdown = gr.Dropdown(choices=initial_checkpoints, value=initial_checkpoint, label="Checkpoint", interactive=True, allow_custom_value=True)
                                refresh_model_btn = gr.Button("Refresh Projects")

                            # Speakers live next to the checkpoints, so look them up for
                            # the project selected in this tab.
                            initial_speakers, initial_speaker = get_speakers(initial_project_interface)
                            with gr.Row():
                                random_model_btn = gr.Button("Random Sample")
                                speaker_dropdown = gr.Dropdown(choices=initial_speakers, value=initial_speaker, label="Speaker", interactive=True, allow_custom_value=True)

                            input_text = gr.Textbox(
                                label="Input Text",
                                info="Enter your text here",
                                value=example_text
                            )

                            reference_audio = gr.Audio(
                                label="Reference Audio",
                                type="filepath"
                            )

                            generated_audio = gr.Audio(label="Generated Audio", autoplay=True)

                        with gr.Column():

                            with gr.Row():

                                language_dropdown = gr.Dropdown(
                                    label='Language',
                                    choices=supported_languages,
                                    value='chinese'
                                )

                                solver_dropdown = gr.Dropdown(
                                    label='ODE Solver',
                                    choices=['euler', 'midpoint', 'dopri5', 'rk4', 'implicit_adams', 'bosh3', 'fehlberg2', 'adaptive_heun'],
                                    value='dopri5'
                                )

                            with gr.Row():
                                step_slider = gr.Slider(
                                    label='Step',
                                    minimum=1,
                                    maximum=100,
                                    value=25,
                                    step=1
                                )

                                temperature_slider = gr.Slider(
                                    label='Temperature',
                                    minimum=0,
                                    maximum=2,
                                    value=1,
                                )

                            with gr.Row():
                                length_scale_slider = gr.Slider(
                                    label='Length Scale',
                                    minimum=0,
                                    maximum=5,
                                    value=1,
                                )

                                cfg_slider = gr.Slider(
                                    label='CFG',
                                    minimum=0,
                                    maximum=10,
                                    value=3,
                                )

                            with gr.Row():

                                seed_bool = gr.Checkbox(
                                    label='Random',
                                    value=True
                                )

                                seed_value = gr.Slider(
                                    label='Seeds',
                                    step=1,
                                    minimum=0,
                                    maximum=100000000,
                                    value=0,
                                    interactive=True
                                )

                            # NOTE(review): placement of the button and plot inside this
                            # column is assumed - confirm.
                            generate_btn = gr.Button("Generate", elem_id="send-btn", visible=True, variant="primary")
                            mel_plot = gr.Plot(label="Mel Spectrogram Visualization")

                    refresh_model_btn.click(fn=refresh_projects_interface, outputs=[model_project_dropdown, model_project_dropdown])
                    refresh_model_btn.click(fn=refresh_dropdown_speakers, inputs=[model_project_dropdown], outputs=[speaker_dropdown, speaker_dropdown])
                    refresh_model_btn.click(fn=get_checkpoints, inputs=[model_project_dropdown], outputs=[model_checkpoint_dropdown, model_checkpoint_dropdown])
                    refresh_model_btn.click(fn=refresh_dropdown_checkpoints, inputs=[model_project_dropdown], outputs=[model_checkpoint_dropdown, language_dropdown])
                    refresh_model_btn.click(fn=refresh_dropdown, outputs=[model_project_dropdown])

                    speaker_dropdown.change(fn=select_speaker, inputs=[model_project_dropdown, speaker_dropdown], outputs=[reference_audio])

                    model_project_dropdown.change(fn=refresh_dropdown_checkpoints, inputs=[model_project_dropdown], outputs=[model_checkpoint_dropdown, language_dropdown])

                    random_model_btn.click(fn=random_sample, inputs=[model_project_dropdown], outputs=[input_text, reference_audio])

                    generate_btn.click(generate_tts, [model_project_dropdown, model_checkpoint_dropdown, input_text, reference_audio, language_dropdown, step_slider, temperature_slider, length_scale_slider, solver_dropdown, cfg_slider, seed_value, seed_bool], outputs=[generated_audio, mel_plot, seed_value])

                    seed_bool.change(fn=enable_button_setting, inputs=seed_bool, outputs=seed_value)

        footer = gr.HTML(f"""

        """)

    return app


if __name__ == "__main__":
    create_interface().launch(debug=True)