diff --git a/funasr/bin/lm_calc_perplexity.py b/funasr/bin/lm_calc_perplexity.py index 27a8a71fc..198d5784c 100755 --- a/funasr/bin/lm_calc_perplexity.py +++ b/funasr/bin/lm_calc_perplexity.py @@ -56,7 +56,7 @@ def calc_perplexity( set_all_random_seed(seed) # 2. Build LM - model, train_args = LMTask.build_model_from_file(train_config, model_file, device) + model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device) # Wrape model to make model.nll() data-parallel wrapped_model = ForwardAdaptor(model, "nll") wrapped_model.to(dtype=getattr(torch, dtype)).eval() @@ -111,6 +111,7 @@ def calc_perplexity( utt_ppl = log_base ** (_nll / ntoken / np.log(log_base)) # Write PPL of each utts for debugging or analysis + writer["utt2nll"][key] = str(-_nll) writer["utt2ppl"][key] = str(utt_ppl) writer["utt2ntokens"][key] = str(ntoken) diff --git a/funasr/bin/lm_inference.py b/funasr/bin/lm_inference.py new file mode 100644 index 000000000..909cb02da --- /dev/null +++ b/funasr/bin/lm_inference.py @@ -0,0 +1,406 @@ +#!/usr/bin/env python3 +import argparse +import logging +from pathlib import Path +import sys +import os +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +from typing import Dict +from typing import Any +from typing import List + +import numpy as np +import torch +from torch.nn.parallel import data_parallel +from typeguard import check_argument_types + +from funasr.tasks.lm import LMTask +from funasr.datasets.preprocessor import LMPreprocessor +from funasr.utils.cli_utils import get_commandline_args +from funasr.fileio.datadir_writer import DatadirWriter +from funasr.torch_utils.device_funcs import to_device +from funasr.torch_utils.forward_adaptor import ForwardAdaptor +from funasr.torch_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import config_argparse +from funasr.utils.types import float_or_none +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none + +def inference( + output_dir: str, + batch_size: int, + dtype: str, + ngpu: int, + seed: int, + num_workers: int, + log_level: Union[int, str], + train_config: Optional[str], + model_file: Optional[str], + log_base: Optional[float], + key_file: Optional[str] = None, + allow_variable_data_keys: bool = False, + split_with_space: Optional[bool] = False, + seg_dict_file: Optional[str] = None, + data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None, + raw_inputs: Union[List[Any], bytes, str] = None, + **kwargs, +): + inference_pipeline = inference_modelscope( + output_dir=output_dir, + raw_inputs=raw_inputs, + batch_size=batch_size, + dtype=dtype, + ngpu=ngpu, + seed=seed, + num_workers=num_workers, + log_level=log_level, + key_file=key_file, + train_config=train_config, + model_file=model_file, + log_base = log_base, + allow_variable_data_keys = allow_variable_data_keys, + split_with_space=split_with_space, + seg_dict_file=seg_dict_file, + **kwargs, + ) + return inference_pipeline(data_path_and_name_and_type, raw_inputs) + + +def inference_modelscope( + batch_size: int, + dtype: str, + ngpu: int, + seed: int, + num_workers: int, + log_level: Union[int, str], + key_file: Optional[str], + train_config: Optional[str], + model_file: Optional[str], + log_base: Optional[float] = 10, + allow_variable_data_keys: bool = False, + split_with_space: Optional[bool] = False, + seg_dict_file: Optional[str] = None, + output_dir: Optional[str] = None, + param_dict: dict = None, + **kwargs, +): + assert check_argument_types() + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if ngpu >= 1 and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build Model + model, train_args = LMTask.build_model_from_file( + train_config, model_file, device) + wrapped_model = ForwardAdaptor(model, "nll") + wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval() + logging.info(f"Model:\n{model}") + + preprocessor = LMPreprocessor( + train=False, + token_type=train_args.token_type, + token_list=train_args.token_list, + bpemodel=train_args.bpemodel, + text_cleaner=train_args.cleaner, + g2p_type=train_args.g2p, + text_name="text", + non_linguistic_symbols=train_args.non_linguistic_symbols, + split_with_space=split_with_space, + seg_dict_file=seg_dict_file + ) + + def _forward( + data_path_and_name_and_type, + raw_inputs: Union[List[Any], bytes, str] = None, + output_dir_v2: Optional[str] = None, + param_dict: dict = None, + ): + results = [] + if output_dir_v2 is not None: + writer = DatadirWriter(output_dir_v2) + else: + writer = None + + if raw_inputs != None: + line = raw_inputs.strip() + key = "lm demo" + if line=="": + item = {'key': key, 'value': ""} + results.append(item) + return results + batch = {} + batch['text'] = line + if preprocessor != None: + batch = preprocessor(key, batch) + + # Force data-precision + for name in batch: + value = batch[name] + if not isinstance(value, np.ndarray): + raise RuntimeError( + f"All values must be converted to np.ndarray object " + f'by preprocessing, but "{name}" is still {type(value)}.' + ) + # Cast to desired type + if value.dtype.kind == "f": + value = value.astype("float32") + elif value.dtype.kind == "i": + value = value.astype("long") + else: + raise NotImplementedError(f"Not supported dtype: {value.dtype}") + batch[name] = value + + batch["text_lengths"] = torch.from_numpy( + np.array([len(batch["text"])], dtype='int32')) + batch["text"] = np.expand_dims(batch["text"], axis=0) + + with torch.no_grad(): + batch = to_device(batch, device) + if ngpu <= 1: + nll, lengths = wrapped_model(**batch) + else: + nll, lengths = data_parallel( + wrapped_model, (), range(ngpu), module_kwargs=batch + ) + ## compute ppl + ppl_out_batch = "" + ids2tokens = preprocessor.token_id_converter.ids2tokens + for sent_ids, sent_nll in zip(batch['text'], nll): + pre_word = "" + cur_word = None + sent_lst = ids2tokens(sent_ids) + [''] + ppl_out = " ".join(sent_lst) + "\n" + for word, word_nll in zip(sent_lst, sent_nll): + cur_word = word + word_nll = -word_nll.cpu() + if log_base is None: + word_prob = np.exp(word_nll) + else: + word_prob = log_base ** (word_nll / np.log(log_base)) + ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format( + cur=cur_word, + pre=pre_word, + prob=round(word_prob.item(), 8), + word_nll=round(word_nll.item(), 8) + ) + pre_word = cur_word + + sent_nll_mean = sent_nll.mean().cpu().numpy() + sent_nll_sum = sent_nll.sum().cpu().numpy() + if log_base is None: + sent_ppl = np.exp(sent_nll_mean) + else: + sent_ppl = log_base ** (sent_nll_mean / np.log(log_base)) + ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format( + sent_nll=round(-sent_nll_sum.item(), 4), + sent_ppl=round(sent_ppl.item(), 4) + ) + ppl_out_batch += ppl_out + item = {'key': key, 'value': ppl_out} + if writer is not None: + writer["ppl"][key+":\n"] = ppl_out + results.append(item) + + return results + + # 3. Build data-iterator + loader = LMTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=preprocessor, + collate_fn=LMTask.build_collate_fn(train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + # 4. Start for-loop + total_nll = 0.0 + total_ntokens = 0 + ppl_out_all = "" + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + + ppl_out_batch = "" + with torch.no_grad(): + batch = to_device(batch, device) + if ngpu <= 1: + # NOTE(kamo): data_parallel also should work with ngpu=1, + # but for debuggability it's better to keep this block. + nll, lengths = wrapped_model(**batch) + else: + nll, lengths = data_parallel( + wrapped_model, (), range(ngpu), module_kwargs=batch + ) + ## print ppl + ids2tokens = preprocessor.token_id_converter.ids2tokens + for key, sent_ids, sent_nll in zip(keys, batch['text'], nll): + pre_word = "" + cur_word = None + sent_lst = ids2tokens(sent_ids) + [''] + ppl_out = " ".join(sent_lst) + "\n" + for word, word_nll in zip(sent_lst, sent_nll): + cur_word = word + word_nll = -word_nll.cpu() + if log_base is None: + word_prob = np.exp(word_nll) + else: + word_prob = log_base ** (word_nll / np.log(log_base)) + ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format( + cur=cur_word, + pre=pre_word, + prob=round(word_prob.item(), 8), + word_nll=round(word_nll.item(), 8) + ) + pre_word = cur_word + + sent_nll_mean = sent_nll.mean().cpu().numpy() + sent_nll_sum = sent_nll.sum().cpu().numpy() + if log_base is None: + sent_ppl = np.exp(sent_nll_mean) + else: + sent_ppl = log_base ** (sent_nll_mean / np.log(log_base)) + ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format( + sent_nll=round(-sent_nll_sum.item(), 4), + sent_ppl=round(sent_ppl.item(), 4) + ) + ppl_out_batch += ppl_out + utt2nll = round(-sent_nll_sum.item(), 5) + item = {'key': key, 'value': ppl_out} + if writer is not None: + writer["ppl"][key+":\n"] = ppl_out + writer["utt2nll"][key] = str(utt2nll) + results.append(item) + + ppl_out_all += ppl_out_batch + + assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths)) + # nll: (B, L) -> (B,) + nll = nll.detach().cpu().numpy().sum(1) + # lengths: (B,) + lengths = lengths.detach().cpu().numpy() + total_nll += nll.sum() + total_ntokens += lengths.sum() + + if log_base is None: + ppl = np.exp(total_nll / total_ntokens) + else: + ppl = log_base ** (total_nll / total_ntokens / np.log(log_base)) + + avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format( + total_nll=round(-total_nll.item(), 4), + total_ppl=round(ppl.item(), 4) + ) + item = {'key': 'AVG PPL', 'value': avg_ppl} + ppl_out_all += avg_ppl + if writer is not None: + writer["ppl"]["AVG PPL : "] = avg_ppl + results.append(item) + + return results + + return _forward + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="Calc perplexity", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=False) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + parser.add_argument( + "--log_base", + type=float_or_none, + default=10, + help="The base of logarithm for Perplexity. " + "If None, napier's constant is used.", + required=False + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + action="append", + required=False + ) + group.add_argument( + "--raw_inputs", + type=str, + required=False + ) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group.add_argument("--split_with_space", type=str2bool, default=False) + group.add_argument("--seg_dict_file", type=str_or_none) + + group = parser.add_argument_group("The model configuration related") + group.add_argument("--train_config", type=str) + group.add_argument("--model_file", type=str) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + inference(**kwargs) + +if __name__ == "__main__": + main() + diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py new file mode 100644 index 000000000..492ebab5a --- /dev/null +++ b/funasr/bin/lm_inference_launch.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import argparse +import logging +import os +import sys +from typing import Union, Dict, Any + +from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none +from funasr.utils.types import float_or_none + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="Calc perplexity", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--gpuid_list", type=str, required=True) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument("--njob", type=int, default=1, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + parser.add_argument( + "--log_base", + type=float_or_none, + default=10, + help="The base of logarithm for Perplexity. " + "If None, napier's constant is used.", + required=False + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + action="append", + required=False + ) + group.add_argument( + "--raw_inputs", + type=str, + required=False + ) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group.add_argument("--split_with_space", type=str2bool, default=False) + group.add_argument("--seg_dict_file", type=str_or_none) + + group = parser.add_argument_group("The model configuration related") + group.add_argument("--train_config", type=str) + group.add_argument("--model_file", type=str) + group.add_argument("--mode", type=str, default="lm") + return parser + +def inference_launch(mode, **kwargs): + if mode == "transformer": + from funasr.bin.lm_inference import inference_modelscope + return inference_modelscope(**kwargs) + else: + logging.info("Unknown decoding mode: {}".format(mode)) + return None + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + + # set logging messages + logging.basicConfig( + level=args.log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.info("Decoding args: {}".format(kwargs)) + + # gpu setting + if args.ngpu > 0: + jobid = int(args.output_dir.split(".")[-1]) + gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob] + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = gpuid + + kwargs.pop("gpuid_list", None) + kwargs.pop("njob", None) + results = inference_launch(**kwargs) + + +if __name__ == "__main__": + main() + diff --git a/funasr/bin/lm_train.py b/funasr/bin/lm_train.py index faa7a4596..8641465eb 100755 --- a/funasr/bin/lm_train.py +++ b/funasr/bin/lm_train.py @@ -1,22 +1,46 @@ #!/usr/bin/env python3 + +import os + from funasr.tasks.lm import LMTask -def get_parser(): +# for LM Training +def parse_args(): parser = LMTask.get_parser() - return parser + parser.add_argument( + "--gpu_id", + type=int, + default=0, + help="local gpu id.", + ) + args = parser.parse_args() + return args -def main(cmd=None): - """LM training. - - Example: - - % python lm_train.py asr --print_config --optim adadelta - % python lm_train.py --config conf/train_asr.yaml - """ - LMTask.main(cmd=cmd) +def main(args=None, cmd=None): + # for LM Training + LMTask.main(args=args, cmd=cmd) -if __name__ == "__main__": - main() +if __name__ == '__main__': + args = parse_args() + + # setup local gpu_id + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) + + # DDP settings + if args.ngpu > 1: + args.distributed = True + else: + args.distributed = False + assert args.num_worker_count == 1 + + # re-compute batch size: when dataset type is small + if args.dataset_type == "small" and args.ngpu != 0: + if args.batch_size is not None: + args.batch_size = args.batch_size * args.ngpu + if args.batch_bins is not None: + args.batch_bins = args.batch_bins * args.ngpu + + main(args=args) diff --git a/funasr/bin/tokenize_text.py b/funasr/bin/tokenize_text.py new file mode 100755 index 000000000..dc565d03c --- /dev/null +++ b/funasr/bin/tokenize_text.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +import argparse +from collections import Counter +import logging +from pathlib import Path +import sys +from typing import List +from typing import Optional + +from typeguard import check_argument_types + +from funasr.utils.cli_utils import get_commandline_args +from funasr.text.build_tokenizer import build_tokenizer +from funasr.text.cleaner import TextCleaner +from funasr.text.phoneme_tokenizer import g2p_choices +from funasr.utils.types import str2bool +from funasr.utils.types import str_or_none + + +def field2slice(field: Optional[str]) -> slice: + """Convert field string to slice + + Note that field string accepts 1-based integer. + + Examples: + >>> field2slice("1-") + slice(0, None, None) + >>> field2slice("1-3") + slice(0, 3, None) + >>> field2slice("-3") + slice(None, 3, None) + """ + field = field.strip() + try: + if "-" in field: + # e.g. "2-" or "2-5" or "-7" + s1, s2 = field.split("-", maxsplit=1) + if s1.strip() == "": + s1 = None + else: + s1 = int(s1) + if s1 == 0: + raise ValueError("1-based string") + if s2.strip() == "": + s2 = None + else: + s2 = int(s2) + else: + # e.g. "2" + s1 = int(field) + s2 = s1 + 1 + if s1 == 0: + raise ValueError("must be 1 or more value") + except ValueError: + raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}") + + if s1 is None: + slic = slice(None, s2) + else: + # -1 because of 1-based integer following "cut" command + # e.g "1-3" -> slice(0, 3) + slic = slice(s1 - 1, s2) + return slic + + +def tokenize( + input: str, + output: str, + field: Optional[str], + delimiter: Optional[str], + token_type: str, + space_symbol: str, + non_linguistic_symbols: Optional[str], + bpemodel: Optional[str], + log_level: str, + write_vocabulary: bool, + vocabulary_size: int, + remove_non_linguistic_symbols: bool, + cutoff: int, + add_symbol: List[str], + cleaner: Optional[str], + g2p: Optional[str], +): + assert check_argument_types() + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + if input == "-": + fin = sys.stdin + else: + fin = Path(input).open("r", encoding="utf-8") + if output == "-": + fout = sys.stdout + else: + p = Path(output) + p.parent.mkdir(parents=True, exist_ok=True) + fout = p.open("w", encoding="utf-8") + + cleaner = TextCleaner(cleaner) + tokenizer = build_tokenizer( + token_type=token_type, + bpemodel=bpemodel, + delimiter=delimiter, + space_symbol=space_symbol, + non_linguistic_symbols=non_linguistic_symbols, + remove_non_linguistic_symbols=remove_non_linguistic_symbols, + g2p_type=g2p, + ) + + counter = Counter() + if field is not None: + field = field2slice(field) + + for line in fin: + line = line.rstrip() + if field is not None: + # e.g. field="2-" + # uttidA hello world!! -> hello world!! + tokens = line.split(delimiter) + tokens = tokens[field] + if delimiter is None: + line = " ".join(tokens) + else: + line = delimiter.join(tokens) + + line = cleaner(line) + tokens = tokenizer.text2tokens(line) + if not write_vocabulary: + fout.write(" ".join(tokens) + "\n") + else: + for t in tokens: + counter[t] += 1 + + if not write_vocabulary: + return + + ## FIXME + ## del duplicate add_symbols in counter + for symbol_and_id in add_symbol: + # e.g symbol=":0" + try: + symbol, idx = symbol_and_id.split(":") + except ValueError: + raise RuntimeError(f"Format error: e.g. ':0': {symbol_and_id}") + symbol = symbol.strip() + if symbol in counter: + del counter[symbol] + + # ======= write_vocabulary mode from here ======= + # Sort by the number of occurrences in descending order + # and filter lower frequency words than cutoff value + words_and_counts = list( + filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1])) + ) + # Restrict the vocabulary size + if vocabulary_size > 0: + if vocabulary_size < len(add_symbol): + raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}") + words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)] + + # Parse the values of --add_symbol + for symbol_and_id in add_symbol: + # e.g symbol=":0" + try: + symbol, idx = symbol_and_id.split(":") + idx = int(idx) + except ValueError: + raise RuntimeError(f"Format error: e.g. ':0': {symbol_and_id}") + symbol = symbol.strip() + + # e.g. idx=0 -> append as the first symbol + # e.g. idx=-1 -> append as the last symbol + if idx < 0: + idx = len(words_and_counts) + 1 + idx + words_and_counts.insert(idx, (symbol, None)) + + # Write words + for w, c in words_and_counts: + fout.write(w + "\n") + + # Logging + total_count = sum(counter.values()) + invocab_count = sum(c for w, c in words_and_counts if c is not None) + logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %") + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Tokenize texts", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument( + "--input", "-i", required=True, help="Input text. - indicates sys.stdin" + ) + parser.add_argument( + "--output", "-o", required=True, help="Output text. - indicates sys.stdout" + ) + parser.add_argument( + "--field", + "-f", + help="The target columns of the input text as 1-based integer. e.g 2-", + ) + parser.add_argument( + "--token_type", + "-t", + default="char", + choices=["char", "bpe", "word", "phn"], + help="Token type", + ) + parser.add_argument("--delimiter", "-d", default=None, help="The delimiter") + parser.add_argument("--space_symbol", default="", help="The space symbol") + parser.add_argument("--bpemodel", default=None, help="The bpemodel file path") + parser.add_argument( + "--non_linguistic_symbols", + type=str_or_none, + help="non_linguistic_symbols file path", + ) + parser.add_argument( + "--remove_non_linguistic_symbols", + type=str2bool, + default=False, + help="Remove non-language-symbols from tokens", + ) + parser.add_argument( + "--cleaner", + type=str_or_none, + choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"], + default=None, + help="Apply text cleaning", + ) + parser.add_argument( + "--g2p", + type=str_or_none, + choices=g2p_choices, + default=None, + help="Specify g2p method if --token_type=phn", + ) + + group = parser.add_argument_group("write_vocabulary mode related") + group.add_argument( + "--write_vocabulary", + type=str2bool, + default=False, + help="Write tokens list instead of tokenized text per line", + ) + group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size") + group.add_argument( + "--cutoff", + default=0, + type=int, + help="cut-off frequency used for write-vocabulary mode", + ) + group.add_argument( + "--add_symbol", + type=str, + default=[], + action="append", + help="Append symbol e.g. --add_symbol ':0' --add_symbol ':1'", + ) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + tokenize(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py index 10fbccba7..79540c1fa 100644 --- a/funasr/datasets/preprocessor.py +++ b/funasr/datasets/preprocessor.py @@ -58,6 +58,15 @@ def seg_tokenize(txt, seg_dict): continue return out_txt.strip().split() +def seg_tokenize_wo_pattern(txt, seg_dict): + out_txt = "" + for word in txt: + if word in seg_dict: + out_txt += seg_dict[word] + " " + else: + out_txt += "" + " " + return out_txt.strip().split() + def framing( x, @@ -372,6 +381,70 @@ class CommonPreprocessor(AbsPreprocessor): data = self._text_process(data) return data +## FIXME +class LMPreprocessor(CommonPreprocessor): + def __init__( + self, + train: bool, + token_type: str = None, + token_list: Union[Path, str, Iterable[str]] = None, + bpemodel: Union[Path, str, Iterable[str]] = None, + text_cleaner: Collection[str] = None, + g2p_type: str = None, + unk_symbol: str = "", + space_symbol: str = "", + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + delimiter: str = None, + rir_scp: str = None, + rir_apply_prob: float = 1.0, + noise_scp: str = None, + noise_apply_prob: float = 1.0, + noise_db_range: str = "3_10", + speech_volume_normalize: float = None, + speech_name: str = "speech", + text_name: str = "text", + split_with_space: bool = False, + seg_dict_file: str = None, + ): + super().__init__(train, + token_type, + token_list, + bpemodel, + text_cleaner, + g2p_type, + unk_symbol, + space_symbol, + non_linguistic_symbols, + delimiter, + rir_scp, + rir_apply_prob, + noise_scp, + noise_apply_prob, + noise_db_range, + speech_volume_normalize, + speech_name, + text_name, + split_with_space, + seg_dict_file, + ) + + def _text_process( + self, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + if self.text_name in data and self.tokenizer is not None: + text = data[self.text_name] + text = self.text_cleaner(text) + if self.split_with_space: + tokens = text.strip().split(" ") + if self.seg_dict is not None: + tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict) + else: + tokens = self.tokenizer.text2tokens(text) + text_ints = self.token_id_converter.tokens2ids(tokens) + data[self.text_name] = np.array(text_ints, dtype=np.int64) + assert check_return_type(data) + return data + class CommonPreprocessor_multi(AbsPreprocessor): def __init__( diff --git a/funasr/lm/espnet_model.py b/funasr/lm/espnet_model.py index 4fc3b49c8..db11b6741 100644 --- a/funasr/lm/espnet_model.py +++ b/funasr/lm/espnet_model.py @@ -46,10 +46,10 @@ class ESPnetLanguageModel(AbsESPnetModel): # 1. Create a sentence pair like ' w1 w2 w3' and 'w1 w2 w3 ' # text: (Batch, Length) -> x, y: (Batch, Length + 1) - x = F.pad(text, [1, 0], "constant", self.eos) + x = F.pad(text, [1, 0], "constant", self.sos) t = F.pad(text, [0, 1], "constant", self.ignore_id) for i, l in enumerate(text_lengths): - t[i, l] = self.sos + t[i, l] = self.eos x_lengths = text_lengths + 1 # 2. Forward Language model diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index 789940019..02311fda9 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -43,6 +43,7 @@ from funasr.iterators.abs_iter_factory import AbsIterFactory from funasr.iterators.chunk_iter_factory import ChunkIterFactory from funasr.iterators.multiple_iter_factory import MultipleIterFactory from funasr.iterators.sequence_iter_factory import SequenceIterFactory +from funasr.main_funcs.collect_stats import collect_stats from funasr.optimizers.sgd import SGD from funasr.optimizers.fairseq_adam import FairseqAdam from funasr.samplers.build_batch_sampler import BATCH_TYPES @@ -1272,6 +1273,52 @@ class AbsTask(ABC): if args.dry_run: pass + elif args.collect_stats: + # Perform on collect_stats mode. This mode has two roles + # - Derive the length and dimension of all input data + # - Accumulate feats, square values, and the length for whitening + + if args.valid_batch_size is None: + args.valid_batch_size = args.batch_size + + if len(args.train_shape_file) != 0: + train_key_file = args.train_shape_file[0] + else: + train_key_file = None + if len(args.valid_shape_file) != 0: + valid_key_file = args.valid_shape_file[0] + else: + valid_key_file = None + + collect_stats( + model=model, + train_iter=cls.build_streaming_iterator( + data_path_and_name_and_type=args.train_data_path_and_name_and_type, + key_file=train_key_file, + batch_size=args.batch_size, + dtype=args.train_dtype, + num_workers=args.num_workers, + allow_variable_data_keys=args.allow_variable_data_keys, + ngpu=args.ngpu, + preprocess_fn=cls.build_preprocess_fn(args, train=False), + collate_fn=cls.build_collate_fn(args, train=False), + ), + valid_iter=cls.build_streaming_iterator( + data_path_and_name_and_type=args.valid_data_path_and_name_and_type, + key_file=valid_key_file, + batch_size=args.valid_batch_size, + dtype=args.train_dtype, + num_workers=args.num_workers, + allow_variable_data_keys=args.allow_variable_data_keys, + ngpu=args.ngpu, + preprocess_fn=cls.build_preprocess_fn(args, train=False), + collate_fn=cls.build_collate_fn(args, train=False), + ), + output_dir=output_dir, + ngpu=args.ngpu, + log_interval=args.log_interval, + write_collected_feats=args.write_collected_feats, + ) else: logging.info("Training args: {}".format(args)) # 6. Loads pre-trained model diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py index 46b9fe089..608c1d3eb 100644 --- a/funasr/tasks/lm.py +++ b/funasr/tasks/lm.py @@ -58,7 +58,7 @@ class LMTask(AbsTask): # NOTE(kamo): add_arguments(..., required=True) can't be used # to provide --print_config mode. Instead of it, do as required = parser.get_default("required") - required += ["token_list"] + # required += ["token_list"] group.add_argument( "--token_list",