diff --git a/egs/librispeech/asr/simple_v1/.gitignore b/egs/librispeech/asr/simple_v1/.gitignore index 2211df63..abc9727c 100644 --- a/egs/librispeech/asr/simple_v1/.gitignore +++ b/egs/librispeech/asr/simple_v1/.gitignore @@ -1 +1,7 @@ *.txt +exp/ +exp +exp-* +*.model +*.vocab +data diff --git a/egs/librispeech/asr/simple_v1/bpe_mmi_att_transformer_decode.py b/egs/librispeech/asr/simple_v1/bpe_mmi_att_transformer_decode.py new file mode 100755 index 00000000..968e1a3d --- /dev/null +++ b/egs/librispeech/asr/simple_v1/bpe_mmi_att_transformer_decode.py @@ -0,0 +1,608 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey +# Haowen Qiu +# Fangjun Kuang) +# 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Apache 2.0 + +# Usage of this script: +''' + +# Without LM rescoring + +## Use n-best decoding +./bpe_mmi_att_transformer_decode.py \ + --use-lm-rescoring=0 \ + --num-paths=100 \ + --max-duration=300 + +## Use 1-best decoding +./bpe_mmi_att_transformer_decode.py \ + --use-lm-rescoring=0 \ + --num-paths=1 \ + --max-duration=300 + +# With LM rescoring + +## Use whole lattice +./bpe_mmi_att_transformer_decode.py \ + --use-lm-rescoring=1 \ + --num-paths=-1 \ + --max-duration=300 + +## Use n-best list +./bpe_mmi_att_transformer_decode.py \ + --use-lm-rescoring=1 \ + --num-paths=100 \ + --max-duration=300 +''' + +import argparse +import k2 +import logging +import numpy as np +import os +import torch +from k2 import Fsa, SymbolTable +from collections import defaultdict +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + + +from snowfall.common import average_checkpoint, store_transcripts +from snowfall.common import find_first_disambig_symbol +from snowfall.common import get_texts +from snowfall.common import write_error_stats +from snowfall.common import load_checkpoint +from snowfall.common import setup_logger 
+from snowfall.common import str2bool +from snowfall.data import LibriSpeechAsrDataModule +from snowfall.decoding.graph import compile_HLG +from snowfall.decoding.lm_rescore import rescore_with_n_best_list +from snowfall.decoding.lm_rescore import rescore_with_whole_lattice +from snowfall.models import AcousticModel +from snowfall.models.transformer import Transformer +from snowfall.models.conformer import Conformer +from snowfall.models.contextnet import ContextNet +from snowfall.training.ctc_graph import build_ctc_topo +from snowfall.training.mmi_graph import create_bigram_phone_lm +from snowfall.training.mmi_graph import get_phone_symbols + +def nbest_decoding(lats: k2.Fsa, num_paths: int): + ''' + (Ideas of this function are from Dan) + + It implements something like CTC prefix beam search using n-best lists + + The basic idea is to first extract n-best paths from the given lattice, + build word seqs from these paths, and compute the total scores + of these sequences in the log-semiring. The one with the max score + is used as the decoding output. + ''' + + # First, extract `num_paths` paths for each sequence. + # paths is a k2.RaggedInt with axes [seq][path][arc_pos] + paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) + + # word_seqs is a k2.RaggedInt sharing the same shape as `paths` + # but it contains word IDs. Note that it also contains 0s and -1s. + # The last entry in each sublist is -1. + + word_seqs = k2.index(lats.aux_labels, paths) + # Note: the above operation supports also the case when + # lats.aux_labels is a ragged tensor. In that case, + # `remove_axis=True` is used inside the pybind11 binding code, + # so the resulting `word_seqs` still has 3 axes, like `paths`. + # The 3 axes are [seq][path][word] + + # Remove epsilons and -1 from word_seqs + word_seqs = k2.ragged.remove_values_leq(word_seqs, 0) + + # Remove repeated sequences to avoid redundant computation later. 
+ # + # Since k2.ragged.unique_sequences will reorder paths within a seq, + # `new2old` is a 1-D torch.Tensor mapping from the output path index + # to the input path index. + # new2old.numel() == unique_word_seqs.num_elements() + unique_word_seqs, _, new2old = k2.ragged.unique_sequences( + word_seqs, need_num_repeats=False, need_new2old_indexes=True) + # Note: unique_word_seqs still has the same axes as word_seqs + + seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0) + + # path_to_seq_map is a 1-D torch.Tensor. + # path_to_seq_map[i] is the seq to which the i-th path + # belongs. + path_to_seq_map = seq_to_path_shape.row_ids(1) + + # Remove the seq axis. + # Now unique_word_seqs has only two axes [path][word] + unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0) + + # word_fsas is an FsaVec with axes [path][state][arc] + word_fsas = k2.linear_fsa(unique_word_seqs) + + word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas) + + # lats has phone IDs as labels and word IDs as aux_labels. + # inv_lats has word IDs as labels and phone IDs as aux_labels + inv_lats = k2.invert(lats) + inv_lats = k2.arc_sort(inv_lats) # no-op if inv_lats is already arc-sorted + + path_lats = k2.intersect_device(inv_lats, + word_fsas_with_epsilon_loops, + b_to_a_map=path_to_seq_map, + sorted_match_a=True) + # path_lats has word IDs as labels and phone IDs as aux_labels + + path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device)) + + tot_scores = path_lats.get_tot_scores(True, True) + # RaggedFloat currently supports float32 only. + # We may bind Ragged as RaggedDouble if needed. + ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, + tot_scores.to(torch.float32)) + + argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + + # Since we invoked `k2.ragged.unique_sequences`, which reorders + # the index from `paths`, we use `new2old` + # here to convert argmax_indexes to the indexes into `paths`. 
+ # + # Use k2.index here since argmax_indexes' dtype is torch.int32 + best_path_indexes = k2.index(new2old, argmax_indexes) + + paths_2axes = k2.ragged.remove_axis(paths, 0) + + # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos] + best_paths = k2.index(paths_2axes, best_path_indexes) + + # labels is a k2.RaggedInt with 2 axes [path][phone_id] + # Note that it contains -1s. + labels = k2.index(lats.labels.contiguous(), best_paths) + + labels = k2.ragged.remove_values_eq(labels, -1) + + # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so + # aux_labels is also a k2.RaggedInt with 2 axes + aux_labels = k2.index(lats.aux_labels, best_paths.values()) + + best_path_fsas = k2.linear_fsa(labels) + best_path_fsas.aux_labels = aux_labels + + return best_path_fsas + + +def decode_one_batch(batch: Dict[str, Any], + model: AcousticModel, + HLG: k2.Fsa, + output_beam_size: float, + num_paths: int, + use_whole_lattice: bool, + G: Optional[k2.Fsa] = None)->Dict[str, List[List[int]]]: + ''' + Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + model: + The neural network model. + HLG: + The decoding graph. + output_beam_size: + Size of the beam for pruning. + use_whole_lattice: + If True, `G` must not be None and it will use whole lattice for + LM rescoring. 
+ If False and if `G` is not None, then `num_paths` must be positive + and it will use n-best list for LM rescoring. + num_paths: + It specifies the size of `n` in n-best list decoding. + G: + The LM. If it is None, no rescoring is used. + Otherwise, LM rescoring is used. + It supports two types of LM rescoring: n-best list rescoring + and whole lattice rescoring. + `use_whole_lattice` specifies which type to use. + + Returns: + Return the decoding result. See above description for the format of + the returned dict. + ''' + device = HLG.device + feature = batch['inputs'] + assert feature.ndim == 3 + feature = feature.to(device) + + # at entry, feature is [N, T, C] + feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + + supervisions = batch['supervisions'] + + nnet_output, _, _ = model(feature, supervisions) + # nnet_output is [N, C, T] + + nnet_output = nnet_output.permute(0, 2, 1) + # now nnet_output is [N, T, C] + + supervision_segments = torch.stack( + (supervisions['sequence_idx'], + (((supervisions['start_frame'] - 1) // 2 - 1) // 2), + (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), + 1).to(torch.int32) + + supervision_segments = torch.clamp(supervision_segments, min=0) + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + + dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) + + lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, output_beam_size, 30, 10000) + + if G is None: + if num_paths > 1: + best_paths = nbest_decoding(lattices, num_paths) + key=f'no_rescore-{num_paths}' + else: + key = 'no_rescore' + best_paths = k2.shortest_path(lattices, use_double_scores=True) + hyps = get_texts(best_paths, indices) + return {key: hyps} + + lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if use_whole_lattice: + best_paths_dict = rescore_with_whole_lattice(lattices, G, + lm_scale_list) + else: + 
best_paths_dict = rescore_with_n_best_list(lattices, G, num_paths, + lm_scale_list) + # best_paths_dict is a dict + # - key: lm_scale_xxx, where xxx is the value of lm_scale. An example + # key is lm_scale_1.2 + # - value: it is the best path obtained using the corresponding lm scale + # from the dict key. + + ans = dict() + for lm_scale_str, best_paths in best_paths_dict.items(): + hyps = get_texts(best_paths, indices) + ans[lm_scale_str] = hyps + return ans + + +@torch.no_grad() +def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel, + HLG: Fsa, symbols: SymbolTable, + num_paths: int, G: k2.Fsa, use_whole_lattice: bool, output_beam_size: float): + tot_num_cuts = len(dataloader.dataset.cuts) + num_cuts = 0 + results = defaultdict(list) + # results is a dict whose keys and values are: + # - key: It indicates the lm_scale, e.g., lm_scale_1.2. + # If no rescoring is used, the key is the literal string: no_rescore + # + # - value: It is a list of tuples (ref_words, hyp_words) + + for batch_idx, batch in enumerate(dataloader): + texts = batch['supervisions']['text'] + + hyps_dict = decode_one_batch(batch=batch, + model=model, + HLG=HLG, + output_beam_size=output_beam_size, + num_paths=num_paths, + use_whole_lattice=use_whole_lattice, + G=G) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + + for i in range(len(texts)): + hyp_words = [symbols.get(x) for x in hyps[i]] + ref_words = texts[i].split(' ') + this_batch.append((ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + if batch_idx % 10 == 0: + logging.info( + 'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format( + batch_idx, num_cuts, tot_num_cuts, + float(num_cuts) / tot_num_cuts * 100)) + + num_cuts += len(texts) + + return results + + +def get_parser(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--model-type', + type=str, + default="conformer", + 
choices=["transformer", "conformer", "contextnet"], + help="Model type.") + parser.add_argument( + '--epoch', + type=int, + default=10, + help="Decoding epoch.") + parser.add_argument( + '--avg', + type=int, + default=5, + help="Number of checkpionts to average. Automatically select " + "consecutive checkpoints before checkpoint specified by'--epoch'. ") + parser.add_argument( + '--att-rate', + type=float, + default=0.0, + help="Attention loss rate.") + parser.add_argument( + '--nhead', + type=int, + default=4, + help="Number of attention heads in transformer.") + parser.add_argument( + '--attention-dim', + type=int, + default=256, + help="Number of units in transformer attention layers.") + parser.add_argument( + '--output-beam-size', + type=float, + default=8, + help='Output beam size. Used in k2.intersect_dense_pruned.'\ + 'Choose a large value (e.g., 20), for 1-best decoding '\ + 'and n-best rescoring. Choose a small value (e.g., 8) for ' \ + 'rescoring with the whole lattice') + parser.add_argument( + '--use-lm-rescoring', + type=str2bool, + default=True, + help='When enabled, it uses LM for rescoring') + parser.add_argument( + '--num-paths', + type=int, + default=-1, + help='Number of paths for rescoring using n-best list.' 
\ + 'If it is negative, then rescore with the whole lattice.'\ + 'CAUTION: You have to reduce max_duration in case of CUDA OOM' + ) + parser.add_argument( + '--is-espnet-structure', + type=str2bool, + default=True, + help='When enabled, the conformer will have the ' \ + 'same structure like espnet') + parser.add_argument( + '--vgg-frontend', + type=str2bool, + default=True, + help='When enabled, it uses vgg style network for subsampling') + return parser + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + model_type = args.model_type + epoch = args.epoch + avg = args.avg + att_rate = args.att_rate + num_paths = args.num_paths + use_lm_rescoring = args.use_lm_rescoring + use_whole_lattice = False + if use_lm_rescoring and num_paths < 1: + # It doesn't make sense to use n-best list for rescoring + # when n is less than 1 + use_whole_lattice = True + + output_beam_size = args.output_beam_size + + exp_dir = Path('exp-bpe-' + model_type + '-mmi-att-sa-vgg-normlayer') + setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug') + + logging.info(f'output_beam_size: {output_beam_size}') + + # load L, G, symbol_table + lang_dir = Path('data/lang_bpe2') + symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt') + phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt') + + phone_ids = get_phone_symbols(phone_symbol_table) + + phone_ids_with_blank = [0] + phone_ids + ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank)) + + logging.debug("About to load model") + # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N + # device = torch.device('cuda', 1) + device = torch.device('cuda') + + if att_rate != 0.0: + num_decoder_layers = 6 + else: + num_decoder_layers = 0 + + if model_type == "transformer": + model = Transformer( + num_features=80, + nhead=args.nhead, + d_model=args.attention_dim, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + 
subsampling_factor=4, + num_decoder_layers=num_decoder_layers, + vgg_frontend=args.vgg_frontend) + elif model_type == "conformer": + model = Conformer( + num_features=80, + nhead=args.nhead, + d_model=args.attention_dim, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=4, + num_decoder_layers=num_decoder_layers, + vgg_frontend=args.vgg_frontend, + is_espnet_structure=args.is_espnet_structure) + elif model_type == "contextnet": + model = ContextNet( + num_features=80, + num_classes=len(phone_ids) + 1) # +1 for the blank symbol + else: + raise NotImplementedError("Model of type " + str(model_type) + " is not implemented") + + if avg == 1: + checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt') + load_checkpoint(checkpoint, model) + else: + checkpoints = [os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt') for avg_epoch in + range(epoch - avg, epoch)] + average_checkpoint(checkpoints, model) + + model.to(device) + model.eval() + + if not os.path.exists(lang_dir / 'HLG.pt'): + logging.debug("Loading L_disambig.fst.txt") + with open(lang_dir / 'L_disambig.fst.txt') as f: + L = k2.Fsa.from_openfst(f.read(), acceptor=False) + logging.debug("Loading G.fst.txt") + with open(lang_dir / 'G.fst.txt') as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table) + first_word_disambig_id = find_first_disambig_symbol(symbol_table) + HLG = compile_HLG(L=L, + G=G, + H=ctc_topo, + labels_disambig_id_start=first_phone_disambig_id, + aux_labels_disambig_id_start=first_word_disambig_id) + torch.save(HLG.as_dict(), lang_dir / 'HLG.pt') + else: + logging.debug("Loading pre-compiled HLG") + d = torch.load(lang_dir / 'HLG.pt') + HLG = k2.Fsa.from_dict(d) + + if use_lm_rescoring: + if use_whole_lattice: + logging.info('Rescoring with the whole lattice') + else: + logging.info(f'Rescoring with n-best list, n is {num_paths}') + first_word_disambig_id = 
find_first_disambig_symbol(symbol_table) + if not os.path.exists(lang_dir / 'G_4_gram.pt'): + logging.debug('Loading G_4_gram.fst.txt') + with open(lang_dir / 'G_4_gram.fst.txt') as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION(fangjun): The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + G = k2.create_fsa_vec([G]).to(device) + G = k2.arc_sort(G) + torch.save(G.as_dict(), lang_dir / 'G_4_gram.pt') + else: + logging.debug('Loading pre-compiled G_4_gram.pt') + d = torch.load(lang_dir / 'G_4_gram.pt') + G = k2.Fsa.from_dict(d).to(device) + + if use_whole_lattice: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + else: + logging.debug('Decoding without LM rescoring') + G = None + if num_paths > 1: + logging.debug(f'Use n-best list decoding, n is {num_paths}') + else: + logging.debug('Use 1-best decoding') + + logging.debug("convert HLG to device") + HLG = HLG.to(device) + HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0) + HLG.requires_grad_(False) + + if not hasattr(HLG, 'lm_scores'): + HLG.lm_scores = HLG.scores.clone() + + + # load dataset + librispeech = LibriSpeechAsrDataModule(args) + test_sets = ['test-clean', 'test-other'] + for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + logging.info(f'* DECODING: {test_set}') + + test_set_wers = dict() + results_dict = decode(dataloader=test_dl, + model=model, + HLG=HLG, + symbols=symbol_table, + num_paths=num_paths, + G=G, + use_whole_lattice=use_whole_lattice, + output_beam_size=output_beam_size) + + for key, results in results_dict.items(): + recog_path = exp_dir / f'recogs-{test_set}-{key}.txt' + store_transcripts(path=recog_path, texts=results) + logging.info(f'The transcripts are stored in {recog_path}') + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = exp_dir / f'errs-{test_set}-{key}.txt' + with open(errs_filename, 'w') as f: + wer = write_error_stats(f, f'{test_set}-{key}', results) + test_set_wers[key] = wer + + logging.info('Wrote detailed error stats to {}'.format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = exp_dir / f'wer-summary-{test_set}.txt' + with open(errs_info, 'w') as f: + print('settings\tWER', file=f) + for key, val in test_set_wers: + print('{}\t{}'.format(key, val), file=f) + + s = '\nFor {}, WER of different settings are:\n'.format(test_set) + note = '\tbest for {}'.format(test_set) + for key, val in test_set_wers: + s += '{}\t{}{}\n'.format(key, val, note) + note='' + logging.info(s) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == '__main__': + main() diff --git a/egs/librispeech/asr/simple_v1/bpe_mmi_att_transformer_train.py b/egs/librispeech/asr/simple_v1/bpe_mmi_att_transformer_train.py new file mode 100755 index 00000000..acb4d674 --- /dev/null +++ b/egs/librispeech/asr/simple_v1/bpe_mmi_att_transformer_train.py @@ -0,0 +1,698 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey +# Haowen Qiu +# Fangjun Kuang) +# 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Apache 2.0 + +import argparse +import logging +import os +import sys +from datetime import datetime +from pathlib import Path +from typing import Dict, Optional + +import k2 +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch import nn +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_value_ +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +from lhotse.utils import fix_random_seed, nullcontext +from snowfall.common import describe, str2bool +from snowfall.common 
import load_checkpoint, save_checkpoint +from snowfall.common import save_training_info +from snowfall.common import setup_logger +from snowfall.data.librispeech import LibriSpeechAsrDataModule +from snowfall.dist import cleanup_dist +from snowfall.dist import setup_dist +from snowfall.lexicon import Lexicon +from snowfall.models import AcousticModel +from snowfall.models.conformer import Conformer +from snowfall.models.contextnet import ContextNet +from snowfall.models.tdnn_lstm import TdnnLstm1b # alignment model +from snowfall.models.transformer import Noam, Transformer +from snowfall.objectives import LFMMILoss, encode_supervisions +from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change +from snowfall.training.mmi_graph import MmiTrainingGraphCompiler + + +def get_objf(batch: Dict, + model: AcousticModel, + ali_model: Optional[AcousticModel], + device: torch.device, + graph_compiler: MmiTrainingGraphCompiler, + use_pruned_intersect: bool, + is_training: bool, + is_update: bool, + accum_grad: int = 1, + den_scale: float = 1.0, + att_rate: float = 0.0, + tb_writer: Optional[SummaryWriter] = None, + global_batch_idx_train: Optional[int] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scaler: GradScaler = None + ): + feature = batch['inputs'] + # at entry, feature is [N, T, C] + feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch['supervisions'] + supervision_segments, texts = encode_supervisions(supervisions) + + loss_fn = LFMMILoss( + graph_compiler=graph_compiler, + den_scale=den_scale, + use_pruned_intersect=use_pruned_intersect + ) + + grad_context = nullcontext if is_training else torch.no_grad + + with autocast(enabled=scaler.is_enabled()), grad_context(): + + if att_rate == 0: + # Note: Make TorchScript happy by making the supervision dict strictly + # conform to type Dict[str, Tensor] + # Using the 
attention decoder with TorchScript is currently unsupported, + # we'll need to separate out the 'text' field from 'supervisions' first. + del supervisions['text'] + + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + if att_rate != 0.0: + att_loss = model.module.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler) + + if (ali_model is not None and global_batch_idx_train is not None and + global_batch_idx_train // accum_grad < 4000): + with torch.no_grad(): + ali_model_output = ali_model(feature) + # subsampling is done slightly differently, may be small length + # differences. + min_len = min(ali_model_output.shape[2], nnet_output.shape[2]) + # scale less than one so it will be encouraged + # to mimic ali_model's output + ali_model_scale = 500.0 / (global_batch_idx_train // accum_grad + 500) + nnet_output = nnet_output.clone() # or log-softmax backprop will fail. + nnet_output[:, :,:min_len] += ali_model_scale * ali_model_output[:, :,:min_len] + + # nnet_output is [N, C, T] + nnet_output = nnet_output.permute(0, 2, 1) # now nnet_output is [N, T, C] + + mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts, supervision_segments) + + if is_training: + def maybe_log_gradients(tag: str): + if tb_writer is not None and global_batch_idx_train is not None and global_batch_idx_train % 200 == 0: + tb_writer.add_scalars( + tag, + measure_gradient_norms(model, norm='l1'), + global_step=global_batch_idx_train + ) + + if att_rate != 0.0: + loss = (- (1.0 - att_rate) * mmi_loss + att_rate * att_loss) / (len(texts) * accum_grad) + else: + loss = (-mmi_loss) / (len(texts) * accum_grad) + scaler.scale(loss).backward() + if is_update: + maybe_log_gradients('train/grad_norms') + scaler.unscale_(optimizer) + clip_grad_value_(model.parameters(), 5.0) + maybe_log_gradients('train/clipped_grad_norms') + if tb_writer is not None and (global_batch_idx_train // accum_grad) % 200 == 0: + # Once in a time we will perform a more costly 
diagnostic + # to check the relative parameter change per minibatch. + deltas = optim_step_and_measure_param_change(model, optimizer, scaler) + tb_writer.add_scalars( + 'train/relative_param_change_per_minibatch', + deltas, + global_step=global_batch_idx_train + ) + else: + scaler.step(optimizer) + optimizer.zero_grad() + scaler.update() + + ans = -mmi_loss.detach().cpu().item(), tot_frames.cpu().item( + ), all_frames.cpu().item() + return ans + + +def get_validation_objf(dataloader: torch.utils.data.DataLoader, + model: AcousticModel, + ali_model: Optional[AcousticModel], + device: torch.device, + graph_compiler: MmiTrainingGraphCompiler, + use_pruned_intersect: bool, + scaler: GradScaler, + den_scale: float = 1, + ): + total_objf = 0. + total_frames = 0. # for display only + total_all_frames = 0. # all frames including those seqs that failed. + + model.eval() + + from torchaudio.datasets.utils import bg_iterator + for batch_idx, batch in enumerate(bg_iterator(dataloader, 2)): + objf, frames, all_frames = get_objf( + batch=batch, + model=model, + ali_model=ali_model, + device=device, + graph_compiler=graph_compiler, + use_pruned_intersect=use_pruned_intersect, + is_training=False, + is_update=False, + den_scale=den_scale, + scaler=scaler + ) + total_objf += objf + total_frames += frames + total_all_frames += all_frames + + return total_objf, total_frames, total_all_frames + + +def train_one_epoch(dataloader: torch.utils.data.DataLoader, + valid_dataloader: torch.utils.data.DataLoader, + model: AcousticModel, + ali_model: Optional[AcousticModel], + device: torch.device, + graph_compiler: MmiTrainingGraphCompiler, + use_pruned_intersect: bool, + optimizer: torch.optim.Optimizer, + accum_grad: int, + den_scale: float, + att_rate: float, + current_epoch: int, + tb_writer: SummaryWriter, + num_epochs: int, + global_batch_idx_train: int, + world_size: int, + scaler: GradScaler + ): + """One epoch training and validation. 
+ + Args: + dataloader: Training dataloader + valid_dataloader: Validation dataloader + model: Acoustic model to be trained + ali_model: Optional alignment model, passed through to get_objf (may be None) + device: Training device, torch.device("cpu") or torch.device("cuda", device_id) + graph_compiler: MMI training graph compiler + optimizer: Training optimizer + accum_grad: Number of gradient accumulation + den_scale: Denominator scale in mmi loss + att_rate: Attention loss rate, final loss is att_rate * att_loss + (1-att_rate) * other_loss + current_epoch: current training epoch, for logging only + tb_writer: tensorboard SummaryWriter + num_epochs: total number of training epochs, for logging only + global_batch_idx_train: global training batch index before this epoch, for logging only + + Returns: + A tuple of 3 scalars: (total_objf / total_frames, valid_average_objf, global_batch_idx_train) + - `total_objf / total_frames` is the average training loss + - `valid_average_objf` is the average validation loss + - `global_batch_idx_train` is the global training batch index after this epoch + """ + total_objf, total_frames, total_all_frames = 0., 0., 0. 
+ valid_average_objf = float('inf') + time_waiting_for_batch = 0 + forward_count = 0 + prev_timestamp = datetime.now() + + model.train() + for batch_idx, batch in enumerate(dataloader): + forward_count += 1 + if forward_count == accum_grad: + is_update = True + forward_count = 0 + else: + is_update = False + + global_batch_idx_train += 1 + timestamp = datetime.now() + time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds() + + curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf( + batch=batch, + model=model, + ali_model=ali_model, + device=device, + graph_compiler=graph_compiler, + use_pruned_intersect=use_pruned_intersect, + is_training=True, + is_update=is_update, + accum_grad=accum_grad, + den_scale=den_scale, + att_rate=att_rate, + tb_writer=tb_writer, + global_batch_idx_train=global_batch_idx_train, + optimizer=optimizer, + scaler=scaler + ) + + total_objf += curr_batch_objf + total_frames += curr_batch_frames + total_all_frames += curr_batch_all_frames + + if batch_idx % 10 == 0: + logging.info( + 'batch {}, epoch {}/{} ' + 'global average objf: {:.6f} over {} ' + 'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) ' + 'avg time waiting for batch {:.3f}s'.format( + batch_idx, current_epoch, num_epochs, + total_objf / total_frames, total_frames, + 100.0 * total_frames / total_all_frames, + curr_batch_objf / (curr_batch_frames + 0.001), + curr_batch_frames, + 100.0 * curr_batch_frames / curr_batch_all_frames, + time_waiting_for_batch / max(1, batch_idx))) + + if tb_writer is not None: + tb_writer.add_scalar('train/global_average_objf', + total_objf / total_frames, global_batch_idx_train) + + tb_writer.add_scalar('train/current_batch_average_objf', + curr_batch_objf / (curr_batch_frames + 0.001), + global_batch_idx_train) + # if batch_idx >= 10: + # print("Exiting early to get profile info") + # sys.exit(0) + + if batch_idx > 0 and batch_idx % 200 == 0: + total_valid_objf, total_valid_frames, 
total_valid_all_frames = get_validation_objf( + dataloader=valid_dataloader, + model=model, + ali_model=ali_model, + device=device, + graph_compiler=graph_compiler, + use_pruned_intersect=use_pruned_intersect, + scaler=scaler) + if world_size > 1: + s = torch.tensor([ + total_valid_objf, total_valid_frames, + total_valid_all_frames + ]).to(device) + + dist.all_reduce(s, op=dist.ReduceOp.SUM) + total_valid_objf, total_valid_frames, total_valid_all_frames = s.cpu().tolist() + + valid_average_objf = total_valid_objf / total_valid_frames + model.train() + logging.info( + 'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)' + .format(valid_average_objf, + total_valid_frames, + 100.0 * total_valid_frames / total_valid_all_frames)) + + if tb_writer is not None: + tb_writer.add_scalar('train/global_valid_average_objf', + valid_average_objf, + global_batch_idx_train) + model.module.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train) + prev_timestamp = datetime.now() + return total_objf / total_frames, valid_average_objf, global_batch_idx_train + + +def get_parser(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--world-size', + type=int, + default=1, + help='Number of GPUs for DDP training.') + parser.add_argument( + '--master-port', + type=int, + default=12354, + help='Master port to use for DDP training.') + parser.add_argument( + '--model-type', + type=str, + default="conformer", + choices=["transformer", "conformer", "contextnet"], + help="Model type.") + parser.add_argument( + '--num-epochs', + type=int, + default=10, + help="Number of training epochs.") + parser.add_argument( + '--start-epoch', + type=int, + default=0, + help="Number of start epoch.") + parser.add_argument( + '--warm-step', + type=int, + default=5000, + help='The number of warm-up steps for Noam optimizer.' 
+ ) + parser.add_argument( + '--lr-factor', + type=float, + default=1.0, + help='Learning rate factor for Noam optimizer.' + ) + parser.add_argument( + '--weight-decay', + type=float, + default=0.0, + help='weight decay (L2 penalty) for Noam optimizer.' + ) + parser.add_argument( + '--accum-grad', + type=int, + default=1, + help="Number of gradient accumulation.") + parser.add_argument( + '--den-scale', + type=float, + default=1.0, + help="denominator scale in mmi loss.") + parser.add_argument( + '--att-rate', + type=float, + default=0.0, + help="Attention loss rate.") + parser.add_argument( + '--nhead', + type=int, + default=4, + help="Number of attention heads in transformer.") + parser.add_argument( + '--attention-dim', + type=int, + default=256, + help="Number of units in transformer attention layers.") + parser.add_argument( + '--tensorboard', + type=str2bool, + default=True, + help='Should various information be logged in tensorboard.' + ) + parser.add_argument( + '--amp', + type=str2bool, + default=True, + help='Should we use automatic mixed precision (AMP) training.' + ) + parser.add_argument( + '--use-ali-model', + type=str2bool, + default=True, + help='If true, we assume that you have run ./ctc_train.py ' + 'and you have some checkpoints inside the directory ' + 'exp-lstm-adam-ctc-musan/ .' + 'It will use exp-lstm-adam-ctc-musan/epoch-{ali-model-epoch}.pt ' + 'as the pre-trained alignment model' + ) + parser.add_argument( + '--ali-model-epoch', + type=int, + default=7, + help='If --use-ali-model is True, load ' + 'exp-lstm-adam-ctc-musan/epoch-{ali-model-epoch}.pt as the alignment model.' + 'Used only if --use-ali-model is True.' + ) + parser.add_argument( + '--use-pruned-intersect', + type=str2bool, + default=False, + help='True to use pruned intersect to compute the denominator lattice. ' \ + 'You probably want to set it to True if you have a very large LM. ' \ + 'In that case, you will get an OOM if it is False. 
') + # See https://github.com/k2-fsa/k2/issues/739 for more details + parser.add_argument( + '--torchscript', + type=str2bool, + default=False, + help='Should we convert the model to TorchScript before starting training.' + ) + parser.add_argument( + '--torchscript-epoch', + type=int, + default=-1, + help='After which epoch should we start storing models with TorchScript,' + 'so that they can be simply loaded with torch.jit.load(). ' + '-1 disables this option.' + ) + return parser + + +def run(rank, world_size, args): + ''' + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + ''' + model_type = args.model_type + start_epoch = args.start_epoch + num_epochs = args.num_epochs + accum_grad = args.accum_grad + den_scale = args.den_scale + att_rate = args.att_rate + use_pruned_intersect = args.use_pruned_intersect + + fix_random_seed(42) + setup_dist(rank, world_size, args.master_port) + + exp_dir = Path('exp-bpe-' + model_type + '-mmi-att-sa-vgg-normlayer') + setup_logger(f'{exp_dir}/log/log-train-{rank}') + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') + else: + tb_writer = None + # tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') if args.tensorboard and rank == 0 else None + + logging.info("Loading lexicon and symbol tables") + lang_dir = Path('data/lang_bpe2') + lexicon = Lexicon(lang_dir) + + device_id = rank + device = torch.device('cuda', device_id) + + graph_compiler = MmiTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + phone_ids = lexicon.phone_symbols() + + librispeech = LibriSpeechAsrDataModule(args) + train_dl = librispeech.train_dataloaders() + valid_dl = librispeech.valid_dataloaders() + + if not torch.cuda.is_available(): + logging.error('No 
GPU detected!') + sys.exit(-1) + + if use_pruned_intersect: + logging.info('Use pruned intersect for den_lats') + else: + logging.info("Don't use pruned intersect for den_lats") + + logging.info("About to create model") + + if att_rate != 0.0: + num_decoder_layers = 6 + else: + num_decoder_layers = 0 + + if model_type == "transformer": + model = Transformer( + num_features=80, + nhead=args.nhead, + d_model=args.attention_dim, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=4, + num_decoder_layers=num_decoder_layers, + vgg_frontend=True) + elif model_type == "conformer": + model = Conformer( + num_features=80, + nhead=args.nhead, + d_model=args.attention_dim, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=4, + num_decoder_layers=num_decoder_layers, + vgg_frontend=True, + is_espnet_structure=True) + elif model_type == "contextnet": + model = ContextNet( + num_features=80, + num_classes=len(phone_ids) + 1) # +1 for the blank symbol + else: + raise NotImplementedError("Model of type " + str(model_type) + " is not implemented") + + if args.torchscript: + logging.info('Applying TorchScript to model...') + model = torch.jit.script(model) + + model.to(device) + describe(model) + + model = DDP(model, device_ids=[rank]) + + # Now for the alignment model, if any + if args.use_ali_model: + ali_model = TdnnLstm1b( + num_features=80, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=4) + + ali_model_fname = Path(f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt') + assert ali_model_fname.is_file(), \ + f'ali model filename {ali_model_fname} does not exist!' 
+ ali_model.load_state_dict(torch.load(ali_model_fname, map_location='cpu')['state_dict']) + ali_model.to(device) + + ali_model.eval() + ali_model.requires_grad_(False) + logging.info(f'Use ali_model: {ali_model_fname}') + else: + ali_model = None + logging.info('No ali_model') + + optimizer = Noam(model.parameters(), + model_size=args.attention_dim, + factor=args.lr_factor, + warm_step=args.warm_step, + weight_decay=args.weight_decay) + + scaler = GradScaler(enabled=args.amp) + + best_objf = np.inf + best_valid_objf = np.inf + best_epoch = start_epoch + best_model_path = os.path.join(exp_dir, 'best_model.pt') + best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info') + global_batch_idx_train = 0 # for logging only + + if start_epoch > 0: + model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1)) + ckpt = load_checkpoint(filename=model_path, model=model, optimizer=optimizer, scaler=scaler) + best_objf = ckpt['objf'] + best_valid_objf = ckpt['valid_objf'] + global_batch_idx_train = ckpt['global_batch_idx_train'] + logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}") + + for epoch in range(start_epoch, num_epochs): + train_dl.sampler.set_epoch(epoch) + curr_learning_rate = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar('train/learning_rate', curr_learning_rate, global_batch_idx_train) + tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train) + + logging.info('epoch {}, learning rate {}'.format(epoch, curr_learning_rate)) + objf, valid_objf, global_batch_idx_train = train_one_epoch( + dataloader=train_dl, + valid_dataloader=valid_dl, + model=model, + ali_model=ali_model, + device=device, + graph_compiler=graph_compiler, + use_pruned_intersect=use_pruned_intersect, + optimizer=optimizer, + accum_grad=accum_grad, + den_scale=den_scale, + att_rate=att_rate, + current_epoch=epoch, + tb_writer=tb_writer, + num_epochs=num_epochs, + 
global_batch_idx_train=global_batch_idx_train, + world_size=world_size, + scaler=scaler + ) + # the lower, the better + if valid_objf < best_valid_objf: + best_valid_objf = valid_objf + best_objf = objf + best_epoch = epoch + save_checkpoint(filename=best_model_path, + optimizer=None, + scheduler=None, + scaler=None, + model=model, + epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + valid_objf=valid_objf, + global_batch_idx_train=global_batch_idx_train, + local_rank=rank, + torchscript=args.torchscript_epoch != -1 and epoch >= args.torchscript_epoch + ) + save_training_info(filename=best_epoch_info_filename, + model_path=best_model_path, + current_epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + best_objf=best_objf, + valid_objf=valid_objf, + best_valid_objf=best_valid_objf, + best_epoch=best_epoch, + local_rank=rank) + + # we always save the model for every epoch + model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch)) + save_checkpoint(filename=model_path, + optimizer=optimizer, + scheduler=None, + scaler=scaler, + model=model, + epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + valid_objf=valid_objf, + global_batch_idx_train=global_batch_idx_train, + local_rank=rank, + torchscript=args.torchscript_epoch != -1 and epoch >= args.torchscript_epoch + ) + epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch)) + save_training_info(filename=epoch_info_filename, + model_path=model_path, + current_epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + best_objf=best_objf, + valid_objf=valid_objf, + best_valid_objf=best_valid_objf, + best_epoch=best_epoch, + local_rank=rank) + + logging.warning('Done') + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + world_size = args.world_size + assert world_size >= 1 + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + 
+ +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == '__main__': + main() diff --git a/egs/librispeech/asr/simple_v1/generate_bpe_lexicon.py b/egs/librispeech/asr/simple_v1/generate_bpe_lexicon.py new file mode 100755 index 00000000..36584fe3 --- /dev/null +++ b/egs/librispeech/asr/simple_v1/generate_bpe_lexicon.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +from pathlib import Path +from typing import List + +import argparse +import sentencepiece as spm + + +def read_words(words_txt: str, excluded=['', '']) -> List[str]: + '''Read words_txt and return a list of words. + + The file words_txt has the following format: + + + + That is, every line has two fields. This function + extracts the first field. + + Args: + words_txt: + Filename of words.txt. + excluded: + words in this list are not returned. + Returns: + Return a list of words. + ''' + ans = [] + with open(words_txt, 'r', encoding='latin-1') as f: + for line in f: + word, id = line.strip().split() + if word not in excluded: + ans.append(word) + return ans + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('--model-file', + type=str, + help='Pre-trained BPE model file') + + parser.add_argument('--words-file', type=str, help='Path to words.txt') + + args = parser.parse_args() + model_file = args.model_file + words_txt = args.words_file + assert Path(model_file).is_file(), f'{model_file} does not exist' + assert Path(words_txt).is_file(), f'{words_txt} does not exist' + + words = read_words(words_txt) + + sp = spm.SentencePieceProcessor(model_file=model_file) + words_pieces = sp.encode(words, out_type=str) + + for word, pieces in zip(words, words_pieces): + print(word, ' '.join(pieces)) + + print('', '') + + +if __name__ == '__main__': + main() diff --git a/egs/librispeech/asr/simple_v1/generate_bpe_tokens.py b/egs/librispeech/asr/simple_v1/generate_bpe_tokens.py new file mode 100755 index 
00000000..00188c33 --- /dev/null +++ b/egs/librispeech/asr/simple_v1/generate_bpe_tokens.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +''' +Example usage of this script: + +python3 ./generate_bpe_tokens.py \ + --model-file ./data/lang_bpe/bpe_unigram_500.model > data/lang_bpe/bpe_unigram_500.tokens +''' + +from pathlib import Path + +import argparse +import sentencepiece as spm + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('--model-file', + type=str, + help='Pre-trained BPE model file') + args = parser.parse_args() + model_file = args.model_file + assert Path(model_file).is_file(), f'{model_file} does not exist' + + sp = spm.SentencePieceProcessor(model_file=model_file) + vocab_size = sp.vocab_size() + for i in range(vocab_size): + print(sp.id_to_piece(i), i) + + +if __name__ == '__main__': + main() diff --git a/egs/librispeech/asr/simple_v1/run.sh b/egs/librispeech/asr/simple_v1/run.sh index 1aa18626..9ae09a9e 100755 --- a/egs/librispeech/asr/simple_v1/run.sh +++ b/egs/librispeech/asr/simple_v1/run.sh @@ -3,11 +3,36 @@ # Copyright 2020 Xiaomi Corporation (Author: Junbo Zhang) # Apache 2.0 -# Example of how to build L and G FST for K2. Most scripts of this example are copied from Kaldi. - set -eou pipefail -stage=0 +libri_dirs=( +/root/fangjun/data/librispeech/LibriSpeech +/export/corpora5/LibriSpeech +/home/storage04/zhuangweiji/data/open-source-data/librispeech/LibriSpeech +/export/common/data/corpora/ASR/openslr/SLR12/LibriSpeech +) + +libri_dir= +for d in ${libri_dirs[@]}; do + if [ -d $d ]; then + libri_dir=$d + break + fi +done + +if [ ! 
-d $libri_dir/train-clean-100 ]; then + echo "Please set LibriSpeech dataset path before running this script" + exit 1 +fi + +echo "LibriSpeech dataset dir: $libri_dir" + +stage=5 + +# settings for BPE training -- start +vocab_size=5000 +model_type=unigram # valid values: unigram, bpe, word, char +# settings for BPE training -- end if [ $stage -le 1 ]; then local/download_lm.sh "openslr.org/resources/11" data/local/lm @@ -70,6 +95,116 @@ if [ $stage -le 4 ]; then fi if [ $stage -le 5 ]; then + # TODO(fangjun): Move this stage to a separate script + echo "Preparing BPE training" + dir=data/lang_bpe2 + mkdir -p $dir + if [ ! -f $dir/transcript.txt ]; then + echo "Generating $dir/transcript.txt" + files=$( + find "$libri_dir/train-clean-100" -name "*.trans.txt" + find "$libri_dir/train-clean-360" -name "*.trans.txt" + find "$libri_dir/train-other-500" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $dir/transcript.txt + fi + + model_file=$dir/bpe_${model_type}_${vocab_size}.model + if [ ! -f $model_file ]; then + echo "Generating $model_file" + python3 ./train_bpe_model.py \ + --transcript $dir/transcript.txt \ + --model-type $model_type \ + --vocab-size $vocab_size \ + --output-dir $dir + else + echo "$model_file exists, skip BPE training" + fi + + if [ ! -f $dir/tokens.txt ]; then + python3 ./generate_bpe_tokens.py \ + --model-file $model_file > $dir/tokens.txt + fi + # Copy tokens.txt to phones.txt since the existing code + # expects a fixed name "phones.txt" + ln -fv $dir/tokens.txt $dir/phones.txt + + if [ ! -f $dir/words.txt ]; then + echo " 0" > $dir/words.txt + echo " 1" >> $dir/words.txt + cat $dir/transcript.txt | tr -s " " "\n" | sort | uniq | + awk '{print $0 " " NR+1}' >> $dir/words.txt + fi + + if [ ! -f $dir/lexicon.txt ]; then + python3 ./generate_bpe_lexicon.py \ + --model-file $model_file \ + --words-file $dir/words.txt > $dir/lexicon.txt + fi + + if [ ! 
-f $dir/lexiconp.txt ]; then + echo "**Creating $dir/lexiconp.txt from $dir/lexicon.txt" + perl -ape 's/(\S+\s+)(.+)/${1}1.0\t$2/;' < $dir/lexicon.txt > $dir/lexiconp.txt || exit 1 + fi + + ndisambig=$(local/add_lex_disambig.pl --pron-probs $dir/lexiconp.txt $dir/lexiconp_disambig.txt) + if ! grep "#0" $dir/words.txt > /dev/null 2>&1; then + max_word_id=$(tail -1 $dir/words.txt | awk '{print $2}') + for i in $(seq 0 $ndisambig); do + echo "#$i $((i+max_word_id+1))" + done >> $dir/words.txt + fi + + if ! grep "#0" $dir/phones.txt > /dev/null 2>&1 ; then + max_phone_id=$(tail -1 $dir/phones.txt | awk '{print $2}') + for i in $(seq 0 $ndisambig); do + echo "#$i $((i+max_phone_id+1))" + done >> $dir/phones.txt + fi + + if [ ! -f $dir/L.fst.txt ]; then + # NOTE: 1 is in `--map-oov 1`. + local/make_lexicon_fst.py $dir/lexiconp.txt | \ + local/sym2int.pl --map-oov 1 -f 3 $dir/tokens.txt | \ + local/sym2int.pl -f 4 $dir/words.txt > $dir/L.fst.txt || exit 1 + fi + + if [ ! -f $dir/L_disambig.fst.txt ]; then + wdisambig_phone=$(echo "#0" | local/sym2int.pl $dir/phones.txt) + wdisambig_word=$(echo "#0" | local/sym2int.pl $dir/words.txt) + + local/make_lexicon_fst.py \ + $dir/lexiconp_disambig.txt | \ + local/sym2int.pl --map-oov 1 -f 3 $dir/phones.txt | \ + local/sym2int.pl -f 4 $dir/words.txt | \ + local/fstaddselfloops.pl $wdisambig_phone $wdisambig_word > $dir/L_disambig.fst.txt || exit 1 + fi + + if [ ! -f $dir/G.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/local/lm/lm_tgmed.arpa > $dir/G.fst.txt + else + echo "Skip generating $dir/G.fst.txt" + fi + + if [ ! 
-f $dir/G_4_gram.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + data/local/lm/lm_fglarge.arpa > $dir/G_4_gram.fst.txt + else + echo "Skip generating $dir/G_4_gram.fst.txt" + fi +fi +exit 0 + +if [ $stage -le 6 ]; then python3 ./prepare.py fi @@ -79,7 +214,7 @@ fi # # exit 0 -if [ $stage -le 6 ]; then +if [ $stage -le 7 ]; then # python3 ./train.py # ctc training # python3 ./mmi_bigram_train.py # ctc training + bigram phone LM # python3 ./mmi_mbr_train.py @@ -99,7 +234,7 @@ if [ $stage -le 6 ]; then # python3 -m torch.distributed.launch --nproc_per_node=$ngpus ./mmi_bigram_train.py --world_size $ngpus fi -if [ $stage -le 7 ]; then +if [ $stage -le 8 ]; then # python3 ./decode.py # ctc decoding # python3 ./mmi_bigram_decode.py --epoch 9 # python3 ./mmi_mbr_decode.py diff --git a/egs/librispeech/asr/simple_v1/train_bpe_model.py b/egs/librispeech/asr/simple_v1/train_bpe_model.py new file mode 100755 index 00000000..0786b027 --- /dev/null +++ b/egs/librispeech/asr/simple_v1/train_bpe_model.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +''' +Example usage of this script: + +python3 ./train_bpe_model.py + --transcript data/lang_bpe/transcript.txt \ + --output-dir data/lang_bpe \ + --model-type unigram \ + --vocab-size 500 + +It will generate two files: +(1) data/lang_bpe/bpe_unigram_500.model +(2) data/lang_bpe/bpe_unigram_500.vocab + +We only use the first file "bpe_unigram_500.model". 
+''' + +from pathlib import Path +from typing import Iterable +from typing import List +from typing import Optional +from typing import Union + +import argparse +import sys + +try: + import sentencepiece as spm +except ImportError: + print('Please run:\n\n\t' + 'pip install sentencepiece\n\n' + 'before running this script.') + sys.exit(1) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('--transcript', + type=str, + help='Path to the transcript.') + + parser.add_argument('--vocab-size', + type=int, + default=500, + help='vocab size for BPE training') + + parser.add_argument('--model-type', + type=str, + default='unigram', + choices=['unigram', 'bpe', 'word', 'char'], + help='model algorithm for BPE training') + + parser.add_argument('--output-dir', + type=str, + required=True, + help='Output directory') + return parser + + +def main(): + args = get_parser().parse_args() + assert Path(args.transcript).is_file(), f'{args.transcript} does not exist' + assert args.model_type in ('unigram', 'bpe', 'word', 'char') + assert args.vocab_size > 0 + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + minloglevel = 0 # change it to 1 to disable INFO logs + transcript = args.transcript + vocab_size = args.vocab_size + model_type = args.model_type + model_prefix = f'{args.output_dir}/bpe_{model_type}_{vocab_size}' + input_sentence_size = 100000000 + user_defined_symbols = [''] + + # By default, unk_id is 0, but we want to map to 0, so + # We define unk_id to 1. 
You can choose any value (less than vocab_size), + # for unk_id except 0, which is occupied by the above `` symbol + unk_id = 1 + + # `` is guaranteed to be mapped to 0 + + spm.SentencePieceTrainer.train(input=transcript, + model_prefix=model_prefix, + vocab_size=vocab_size, + model_type=model_type, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + input_sentence_size=input_sentence_size, + minloglevel=minloglevel) + + print(f'Generated "{model_prefix}.model" and {model_prefix}.vocab') + + +if __name__ == '__main__': + main() diff --git a/requirements.txt b/requirements.txt index 72c7d824..34b85a49 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ tensorboard torch>=1.6.0 torchaudio click>=7.1 +sentencepiece diff --git a/snowfall/objectives/mmi.py b/snowfall/objectives/mmi.py index 88cd55af..abefe643 100644 --- a/snowfall/objectives/mmi.py +++ b/snowfall/objectives/mmi.py @@ -14,7 +14,6 @@ def _compute_mmi_loss_exact_optimized( texts: List[str], supervision_segments: torch.Tensor, graph_compiler: MmiTrainingGraphCompiler, - P: k2.Fsa, den_scale: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ''' @@ -36,13 +35,10 @@ def _compute_mmi_loss_exact_optimized( A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`. graph_compiler: Used to build num_graphs and den_graphs - P: - Represents a bigram Fsa. den_scale: The scale applied to the denominator tot_scores. ''' num_graphs, den_graphs = graph_compiler.compile(texts, - P, replicate_den=False) dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) @@ -111,7 +107,6 @@ def _compute_mmi_loss_exact_non_optimized( texts: List[str], supervision_segments: torch.Tensor, graph_compiler: MmiTrainingGraphCompiler, - P: k2.Fsa, den_scale: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ''' @@ -124,7 +119,6 @@ def _compute_mmi_loss_exact_non_optimized( It uses less memory at the cost of speed. It is slower. 
''' num_graphs, den_graphs = graph_compiler.compile(texts, - P, replicate_den=True) dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) @@ -149,7 +143,6 @@ def _compute_mmi_loss_pruned( texts: List[str], supervision_segments: torch.Tensor, graph_compiler: MmiTrainingGraphCompiler, - P: k2.Fsa, den_scale: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ''' @@ -163,7 +156,6 @@ def _compute_mmi_loss_pruned( to pruning. ''' num_graphs, den_graphs = graph_compiler.compile(texts, - P, replicate_den=False) dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) @@ -200,13 +192,11 @@ class LFMMILoss(nn.Module): def __init__( self, graph_compiler: MmiTrainingGraphCompiler, - P: k2.Fsa, use_pruned_intersect: bool = False, den_scale: float = 1.0, ): super().__init__() self.graph_compiler = graph_compiler - self.P = P self.den_scale = den_scale self.use_pruned_intersect = use_pruned_intersect @@ -223,5 +213,4 @@ def forward(self, nnet_output: torch.Tensor, texts: List[str], texts=texts, supervision_segments=supervision_segments, graph_compiler=self.graph_compiler, - P=self.P, den_scale=self.den_scale) diff --git a/snowfall/training/ctc_graph.py b/snowfall/training/ctc_graph.py index ea4b198a..25ec0e10 100644 --- a/snowfall/training/ctc_graph.py +++ b/snowfall/training/ctc_graph.py @@ -10,7 +10,7 @@ from snowfall.common import get_phone_symbols -def build_ctc_topo(tokens: List[int]) -> k2.Fsa: +def build_ctc_topo2(tokens: List[int]) -> k2.Fsa: '''Build CTC topology. A token which appears once on the right side (i.e. 
olabels) may appear multiple times on the left side (ilabels), possibly with @@ -42,7 +42,7 @@ def build_ctc_topo(tokens: List[int]) -> k2.Fsa: return ans -def build_ctc_topo2(phones: List[int]): +def build_ctc_topo(phones: List[int]): # See https://github.com/k2-fsa/k2/issues/746#issuecomment-856421616 assert 0 in phones, 'We assume 0 is the ID of the blank symbol' phones = phones.copy() diff --git a/snowfall/training/mmi_graph.py b/snowfall/training/mmi_graph.py index 758830f7..f611d9ca 100644 --- a/snowfall/training/mmi_graph.py +++ b/snowfall/training/mmi_graph.py @@ -81,11 +81,10 @@ def __init__( ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device) assert ctc_topo.requires_grad is False - self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert_()) + self.ctc_topo = k2.arc_sort(ctc_topo) def compile(self, texts: Iterable[str], - P: k2.Fsa, replicate_den: bool = True) -> Tuple[k2.Fsa, k2.Fsa]: '''Create numerator and denominator graphs from transcripts and the bigram phone LM. @@ -94,8 +93,6 @@ def compile(self, texts: A list of transcripts. Within a transcript, words are separated by spaces. - P: - The bigram phone LM created by :func:`create_bigram_phone_lm`. replicate_den: If True, the returned den_graph is replicated to match the number of FSAs in the returned num_graph; if False, the returned den_graph @@ -111,19 +108,6 @@ def compile(self, is an FsaVec containing only a single FSA. 
''' self_device = str(self.device) - if self_device == 'cuda': - # the compilers graph device does not specify GPU ID, just check that both tensors are on GPU - assert str(P.device).startswith( - 'cuda'), f'Assertion failed: GraphCompiler uses on "cuda", but P is on "{P.device}"' - else: - assert str(P.device) == str(self.device), f'Assertion failed: "{P.device} == {self.device}"' - P_with_self_loops = k2.add_epsilon_self_loops(P) - - ctc_topo_P = k2.intersect(self.ctc_topo_inv, - P_with_self_loops, - treat_epsilons_specially=False).invert() - - ctc_topo_P = k2.arc_sort(ctc_topo_P) num_graphs = self.build_num_graphs(texts) num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops( @@ -131,19 +115,19 @@ def compile(self, num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops) - num = k2.compose(ctc_topo_P, + num = k2.compose(self.ctc_topo, num_graphs_with_self_loops, treat_epsilons_specially=False) num = k2.arc_sort(num) - ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()]) + ctc_topo_vec = k2.create_fsa_vec([self.ctc_topo]) if replicate_den: indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device) - den = k2.index_fsa(ctc_topo_P_vec, indexes) + den = k2.index_fsa(ctc_topo_vec, indexes) else: - den = ctc_topo_P_vec + den = ctc_topo_vec return num, den