diff --git a/funasr/bin/lm_calc_perplexity.py b/funasr/bin/lm_calc_perplexity.py
index 27a8a71fc..198d5784c 100755
--- a/funasr/bin/lm_calc_perplexity.py
+++ b/funasr/bin/lm_calc_perplexity.py
@@ -56,7 +56,7 @@ def calc_perplexity(
set_all_random_seed(seed)
# 2. Build LM
- model, train_args = LMTask.build_model_from_file(train_config, model_file, device)
+ model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device)
# Wrape model to make model.nll() data-parallel
wrapped_model = ForwardAdaptor(model, "nll")
wrapped_model.to(dtype=getattr(torch, dtype)).eval()
@@ -111,6 +111,7 @@ def calc_perplexity(
utt_ppl = log_base ** (_nll / ntoken / np.log(log_base))
# Write PPL of each utts for debugging or analysis
+ writer["utt2nll"][key] = str(-_nll)
writer["utt2ppl"][key] = str(utt_ppl)
writer["utt2ntokens"][key] = str(ntoken)
diff --git a/funasr/bin/lm_inference.py b/funasr/bin/lm_inference.py
new file mode 100644
index 000000000..909cb02da
--- /dev/null
+++ b/funasr/bin/lm_inference.py
@@ -0,0 +1,406 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from pathlib import Path
+import sys
+import os
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from torch.nn.parallel import data_parallel
+from typeguard import check_argument_types
+
+from funasr.tasks.lm import LMTask
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import float_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+
+def inference(
+ output_dir: str,
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ log_base: Optional[float],
+ key_file: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ split_with_space: Optional[bool] = False,
+ seg_dict_file: Optional[str] = None,
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ **kwargs,
+):
+ inference_pipeline = inference_modelscope(
+ output_dir=output_dir,
+ raw_inputs=raw_inputs,
+ batch_size=batch_size,
+ dtype=dtype,
+ ngpu=ngpu,
+ seed=seed,
+ num_workers=num_workers,
+ log_level=log_level,
+ key_file=key_file,
+ train_config=train_config,
+ model_file=model_file,
+ log_base = log_base,
+ allow_variable_data_keys = allow_variable_data_keys,
+ split_with_space=split_with_space,
+ seg_dict_file=seg_dict_file,
+ **kwargs,
+ )
+ return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+
+def inference_modelscope(
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ log_base: Optional[float] = 10,
+ allow_variable_data_keys: bool = False,
+ split_with_space: Optional[bool] = False,
+ seg_dict_file: Optional[str] = None,
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build Model
+ model, train_args = LMTask.build_model_from_file(
+ train_config, model_file, device)
+ wrapped_model = ForwardAdaptor(model, "nll")
+ wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+ logging.info(f"Model:\n{model}")
+
+ preprocessor = LMPreprocessor(
+ train=False,
+ token_type=train_args.token_type,
+ token_list=train_args.token_list,
+ bpemodel=train_args.bpemodel,
+ text_cleaner=train_args.cleaner,
+ g2p_type=train_args.g2p,
+ text_name="text",
+ non_linguistic_symbols=train_args.non_linguistic_symbols,
+ split_with_space=split_with_space,
+ seg_dict_file=seg_dict_file
+ )
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
+ ):
+ results = []
+ if output_dir_v2 is not None:
+ writer = DatadirWriter(output_dir_v2)
+ else:
+ writer = None
+
+ if raw_inputs != None:
+ line = raw_inputs.strip()
+ key = "lm demo"
+ if line=="":
+ item = {'key': key, 'value': ""}
+ results.append(item)
+ return results
+ batch = {}
+ batch['text'] = line
+ if preprocessor != None:
+ batch = preprocessor(key, batch)
+
+ # Force data-precision
+ for name in batch:
+ value = batch[name]
+ if not isinstance(value, np.ndarray):
+ raise RuntimeError(
+ f"All values must be converted to np.ndarray object "
+ f'by preprocessing, but "{name}" is still {type(value)}.'
+ )
+ # Cast to desired type
+ if value.dtype.kind == "f":
+ value = value.astype("float32")
+ elif value.dtype.kind == "i":
+ value = value.astype("long")
+ else:
+ raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+ batch[name] = value
+
+ batch["text_lengths"] = torch.from_numpy(
+ np.array([len(batch["text"])], dtype='int32'))
+ batch["text"] = np.expand_dims(batch["text"], axis=0)
+
+ with torch.no_grad():
+ batch = to_device(batch, device)
+ if ngpu <= 1:
+ nll, lengths = wrapped_model(**batch)
+ else:
+ nll, lengths = data_parallel(
+ wrapped_model, (), range(ngpu), module_kwargs=batch
+ )
+ ## compute ppl
+ ppl_out_batch = ""
+ ids2tokens = preprocessor.token_id_converter.ids2tokens
+ for sent_ids, sent_nll in zip(batch['text'], nll):
+ pre_word = ""
+ cur_word = None
+ sent_lst = ids2tokens(sent_ids) + ['']
+ ppl_out = " ".join(sent_lst) + "\n"
+ for word, word_nll in zip(sent_lst, sent_nll):
+ cur_word = word
+ word_nll = -word_nll.cpu()
+ if log_base is None:
+ word_prob = np.exp(word_nll)
+ else:
+ word_prob = log_base ** (word_nll / np.log(log_base))
+ ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+ cur=cur_word,
+ pre=pre_word,
+ prob=round(word_prob.item(), 8),
+ word_nll=round(word_nll.item(), 8)
+ )
+ pre_word = cur_word
+
+ sent_nll_mean = sent_nll.mean().cpu().numpy()
+ sent_nll_sum = sent_nll.sum().cpu().numpy()
+ if log_base is None:
+ sent_ppl = np.exp(sent_nll_mean)
+ else:
+ sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+ ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+ sent_nll=round(-sent_nll_sum.item(), 4),
+ sent_ppl=round(sent_ppl.item(), 4)
+ )
+ ppl_out_batch += ppl_out
+ item = {'key': key, 'value': ppl_out}
+ if writer is not None:
+ writer["ppl"][key+":\n"] = ppl_out
+ results.append(item)
+
+ return results
+
+ # 3. Build data-iterator
+ loader = LMTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=preprocessor,
+ collate_fn=LMTask.build_collate_fn(train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ # 4. Start for-loop
+ total_nll = 0.0
+ total_ntokens = 0
+ ppl_out_all = ""
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+ ppl_out_batch = ""
+ with torch.no_grad():
+ batch = to_device(batch, device)
+ if ngpu <= 1:
+ # NOTE(kamo): data_parallel also should work with ngpu=1,
+ # but for debuggability it's better to keep this block.
+ nll, lengths = wrapped_model(**batch)
+ else:
+ nll, lengths = data_parallel(
+ wrapped_model, (), range(ngpu), module_kwargs=batch
+ )
+ ## print ppl
+ ids2tokens = preprocessor.token_id_converter.ids2tokens
+ for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
+ pre_word = ""
+ cur_word = None
+ sent_lst = ids2tokens(sent_ids) + ['']
+ ppl_out = " ".join(sent_lst) + "\n"
+ for word, word_nll in zip(sent_lst, sent_nll):
+ cur_word = word
+ word_nll = -word_nll.cpu()
+ if log_base is None:
+ word_prob = np.exp(word_nll)
+ else:
+ word_prob = log_base ** (word_nll / np.log(log_base))
+ ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+ cur=cur_word,
+ pre=pre_word,
+ prob=round(word_prob.item(), 8),
+ word_nll=round(word_nll.item(), 8)
+ )
+ pre_word = cur_word
+
+ sent_nll_mean = sent_nll.mean().cpu().numpy()
+ sent_nll_sum = sent_nll.sum().cpu().numpy()
+ if log_base is None:
+ sent_ppl = np.exp(sent_nll_mean)
+ else:
+ sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+ ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+ sent_nll=round(-sent_nll_sum.item(), 4),
+ sent_ppl=round(sent_ppl.item(), 4)
+ )
+ ppl_out_batch += ppl_out
+ utt2nll = round(-sent_nll_sum.item(), 5)
+ item = {'key': key, 'value': ppl_out}
+ if writer is not None:
+ writer["ppl"][key+":\n"] = ppl_out
+ writer["utt2nll"][key] = str(utt2nll)
+ results.append(item)
+
+ ppl_out_all += ppl_out_batch
+
+ assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
+ # nll: (B, L) -> (B,)
+ nll = nll.detach().cpu().numpy().sum(1)
+ # lengths: (B,)
+ lengths = lengths.detach().cpu().numpy()
+ total_nll += nll.sum()
+ total_ntokens += lengths.sum()
+
+ if log_base is None:
+ ppl = np.exp(total_nll / total_ntokens)
+ else:
+ ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
+
+ avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
+ total_nll=round(-total_nll.item(), 4),
+ total_ppl=round(ppl.item(), 4)
+ )
+ item = {'key': 'AVG PPL', 'value': avg_ppl}
+ ppl_out_all += avg_ppl
+ if writer is not None:
+ writer["ppl"]["AVG PPL : "] = avg_ppl
+ results.append(item)
+
+ return results
+
+ return _forward
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+        description="LM inference (per-word probabilities and perplexity)",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument("--output_dir", type=str, required=False)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ parser.add_argument(
+ "--log_base",
+ type=float_or_none,
+ default=10,
+ help="The base of logarithm for Perplexity. "
+ "If None, napier's constant is used.",
+ required=False
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ action="append",
+ required=False
+ )
+ group.add_argument(
+ "--raw_inputs",
+ type=str,
+ required=False
+ )
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+ group.add_argument("--split_with_space", type=str2bool, default=False)
+ group.add_argument("--seg_dict_file", type=str_or_none)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument("--train_config", type=str)
+ group.add_argument("--model_file", type=str)
+
+ return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ inference(**kwargs)
+
+if __name__ == "__main__":
+ main()
+
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
new file mode 100644
index 000000000..492ebab5a
--- /dev/null
+++ b/funasr/bin/lm_inference_launch.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import argparse
+import logging
+import os
+import sys
+from typing import Union, Dict, Any
+
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils.types import float_or_none
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+        description="LM inference launcher",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument("--gpuid_list", type=str, required=True)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument("--njob", type=int, default=1, help="The number of jobs per GPU")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ parser.add_argument(
+ "--log_base",
+ type=float_or_none,
+ default=10,
+ help="The base of logarithm for Perplexity. "
+ "If None, napier's constant is used.",
+ required=False
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ action="append",
+ required=False
+ )
+ group.add_argument(
+ "--raw_inputs",
+ type=str,
+ required=False
+ )
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+ group.add_argument("--split_with_space", type=str2bool, default=False)
+ group.add_argument("--seg_dict_file", type=str_or_none)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument("--train_config", type=str)
+ group.add_argument("--model_file", type=str)
+ group.add_argument("--mode", type=str, default="lm")
+ return parser
+
+def inference_launch(mode, **kwargs):
+ if mode == "transformer":
+ from funasr.bin.lm_inference import inference_modelscope
+ return inference_modelscope(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ kwargs.pop("config", None)
+
+ # set logging messages
+ logging.basicConfig(
+ level=args.log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.info("Decoding args: {}".format(kwargs))
+
+ # gpu setting
+ if args.ngpu > 0:
+ jobid = int(args.output_dir.split(".")[-1])
+ gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+
+ kwargs.pop("gpuid_list", None)
+ kwargs.pop("njob", None)
+ results = inference_launch(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/funasr/bin/lm_train.py b/funasr/bin/lm_train.py
index faa7a4596..8641465eb 100755
--- a/funasr/bin/lm_train.py
+++ b/funasr/bin/lm_train.py
@@ -1,22 +1,46 @@
#!/usr/bin/env python3
+
+import os
+
from funasr.tasks.lm import LMTask
-def get_parser():
+# for LM Training
+def parse_args():
parser = LMTask.get_parser()
- return parser
+ parser.add_argument(
+ "--gpu_id",
+ type=int,
+ default=0,
+ help="local gpu id.",
+ )
+ args = parser.parse_args()
+ return args
-def main(cmd=None):
- """LM training.
-
- Example:
-
- % python lm_train.py asr --print_config --optim adadelta
- % python lm_train.py --config conf/train_asr.yaml
- """
- LMTask.main(cmd=cmd)
+def main(args=None, cmd=None):
+ # for LM Training
+ LMTask.main(args=args, cmd=cmd)
-if __name__ == "__main__":
- main()
+if __name__ == '__main__':
+ args = parse_args()
+
+ # setup local gpu_id
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+ # DDP settings
+ if args.ngpu > 1:
+ args.distributed = True
+ else:
+ args.distributed = False
+ assert args.num_worker_count == 1
+
+ # re-compute batch size: when dataset type is small
+ if args.dataset_type == "small" and args.ngpu != 0:
+ if args.batch_size is not None:
+ args.batch_size = args.batch_size * args.ngpu
+ if args.batch_bins is not None:
+ args.batch_bins = args.batch_bins * args.ngpu
+
+ main(args=args)
diff --git a/funasr/bin/tokenize_text.py b/funasr/bin/tokenize_text.py
new file mode 100755
index 000000000..dc565d03c
--- /dev/null
+++ b/funasr/bin/tokenize_text.py
@@ -0,0 +1,283 @@
+#!/usr/bin/env python3
+import argparse
+from collections import Counter
+import logging
+from pathlib import Path
+import sys
+from typing import List
+from typing import Optional
+
+from typeguard import check_argument_types
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.cleaner import TextCleaner
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+
+def field2slice(field: Optional[str]) -> slice:
+ """Convert field string to slice
+
+ Note that field string accepts 1-based integer.
+
+ Examples:
+ >>> field2slice("1-")
+ slice(0, None, None)
+ >>> field2slice("1-3")
+ slice(0, 3, None)
+ >>> field2slice("-3")
+ slice(None, 3, None)
+ """
+ field = field.strip()
+ try:
+ if "-" in field:
+ # e.g. "2-" or "2-5" or "-7"
+ s1, s2 = field.split("-", maxsplit=1)
+ if s1.strip() == "":
+ s1 = None
+ else:
+ s1 = int(s1)
+ if s1 == 0:
+ raise ValueError("1-based string")
+ if s2.strip() == "":
+ s2 = None
+ else:
+ s2 = int(s2)
+ else:
+ # e.g. "2"
+ s1 = int(field)
+ s2 = s1 + 1
+ if s1 == 0:
+ raise ValueError("must be 1 or more value")
+ except ValueError:
+ raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")
+
+ if s1 is None:
+ slic = slice(None, s2)
+ else:
+ # -1 because of 1-based integer following "cut" command
+ # e.g "1-3" -> slice(0, 3)
+ slic = slice(s1 - 1, s2)
+ return slic
+
+
+def tokenize(
+ input: str,
+ output: str,
+ field: Optional[str],
+ delimiter: Optional[str],
+ token_type: str,
+ space_symbol: str,
+ non_linguistic_symbols: Optional[str],
+ bpemodel: Optional[str],
+ log_level: str,
+ write_vocabulary: bool,
+ vocabulary_size: int,
+ remove_non_linguistic_symbols: bool,
+ cutoff: int,
+ add_symbol: List[str],
+ cleaner: Optional[str],
+ g2p: Optional[str],
+):
+ assert check_argument_types()
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ if input == "-":
+ fin = sys.stdin
+ else:
+ fin = Path(input).open("r", encoding="utf-8")
+ if output == "-":
+ fout = sys.stdout
+ else:
+ p = Path(output)
+ p.parent.mkdir(parents=True, exist_ok=True)
+ fout = p.open("w", encoding="utf-8")
+
+ cleaner = TextCleaner(cleaner)
+ tokenizer = build_tokenizer(
+ token_type=token_type,
+ bpemodel=bpemodel,
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ remove_non_linguistic_symbols=remove_non_linguistic_symbols,
+ g2p_type=g2p,
+ )
+
+ counter = Counter()
+ if field is not None:
+ field = field2slice(field)
+
+ for line in fin:
+ line = line.rstrip()
+ if field is not None:
+ # e.g. field="2-"
+ # uttidA hello world!! -> hello world!!
+ tokens = line.split(delimiter)
+ tokens = tokens[field]
+ if delimiter is None:
+ line = " ".join(tokens)
+ else:
+ line = delimiter.join(tokens)
+
+ line = cleaner(line)
+ tokens = tokenizer.text2tokens(line)
+ if not write_vocabulary:
+ fout.write(" ".join(tokens) + "\n")
+ else:
+ for t in tokens:
+ counter[t] += 1
+
+ if not write_vocabulary:
+ return
+
+ ## FIXME
+ ## del duplicate add_symbols in counter
+ for symbol_and_id in add_symbol:
+ # e.g symbol=":0"
+ try:
+ symbol, idx = symbol_and_id.split(":")
+ except ValueError:
+ raise RuntimeError(f"Format error: e.g. ':0': {symbol_and_id}")
+ symbol = symbol.strip()
+ if symbol in counter:
+ del counter[symbol]
+
+ # ======= write_vocabulary mode from here =======
+ # Sort by the number of occurrences in descending order
+ # and filter lower frequency words than cutoff value
+ words_and_counts = list(
+ filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
+ )
+ # Restrict the vocabulary size
+ if vocabulary_size > 0:
+ if vocabulary_size < len(add_symbol):
+ raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
+ words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]
+
+ # Parse the values of --add_symbol
+ for symbol_and_id in add_symbol:
+ # e.g symbol=":0"
+ try:
+ symbol, idx = symbol_and_id.split(":")
+ idx = int(idx)
+ except ValueError:
+ raise RuntimeError(f"Format error: e.g. ':0': {symbol_and_id}")
+ symbol = symbol.strip()
+
+ # e.g. idx=0 -> append as the first symbol
+ # e.g. idx=-1 -> append as the last symbol
+ if idx < 0:
+ idx = len(words_and_counts) + 1 + idx
+ words_and_counts.insert(idx, (symbol, None))
+
+ # Write words
+ for w, c in words_and_counts:
+ fout.write(w + "\n")
+
+ # Logging
+ total_count = sum(counter.values())
+ invocab_count = sum(c for w, c in words_and_counts if c is not None)
+ logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")
+
+
+def get_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(
+ description="Tokenize texts",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument(
+ "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
+ )
+ parser.add_argument(
+ "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
+ )
+ parser.add_argument(
+ "--field",
+ "-f",
+ help="The target columns of the input text as 1-based integer. e.g 2-",
+ )
+ parser.add_argument(
+ "--token_type",
+ "-t",
+ default="char",
+ choices=["char", "bpe", "word", "phn"],
+ help="Token type",
+ )
+ parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
+ parser.add_argument("--space_symbol", default="", help="The space symbol")
+ parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
+ parser.add_argument(
+ "--non_linguistic_symbols",
+ type=str_or_none,
+ help="non_linguistic_symbols file path",
+ )
+ parser.add_argument(
+ "--remove_non_linguistic_symbols",
+ type=str2bool,
+ default=False,
+ help="Remove non-language-symbols from tokens",
+ )
+ parser.add_argument(
+ "--cleaner",
+ type=str_or_none,
+ choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
+ default=None,
+ help="Apply text cleaning",
+ )
+ parser.add_argument(
+ "--g2p",
+ type=str_or_none,
+ choices=g2p_choices,
+ default=None,
+ help="Specify g2p method if --token_type=phn",
+ )
+
+ group = parser.add_argument_group("write_vocabulary mode related")
+ group.add_argument(
+ "--write_vocabulary",
+ type=str2bool,
+ default=False,
+ help="Write tokens list instead of tokenized text per line",
+ )
+ group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
+ group.add_argument(
+ "--cutoff",
+ default=0,
+ type=int,
+ help="cut-off frequency used for write-vocabulary mode",
+ )
+ group.add_argument(
+ "--add_symbol",
+ type=str,
+ default=[],
+ action="append",
+ help="Append symbol e.g. --add_symbol ':0' --add_symbol ':1'",
+ )
+
+ return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ tokenize(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 10fbccba7..79540c1fa 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -58,6 +58,15 @@ def seg_tokenize(txt, seg_dict):
continue
return out_txt.strip().split()
+def seg_tokenize_wo_pattern(txt, seg_dict):
+ out_txt = ""
+ for word in txt:
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ out_txt += "" + " "
+ return out_txt.strip().split()
+
def framing(
x,
@@ -372,6 +381,70 @@ class CommonPreprocessor(AbsPreprocessor):
data = self._text_process(data)
return data
+## FIXME
+class LMPreprocessor(CommonPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: str = None,
+ token_list: Union[Path, str, Iterable[str]] = None,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "",
+ space_symbol: str = "",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: str = "text",
+ split_with_space: bool = False,
+ seg_dict_file: str = None,
+ ):
+ super().__init__(train,
+ token_type,
+ token_list,
+ bpemodel,
+ text_cleaner,
+ g2p_type,
+ unk_symbol,
+ space_symbol,
+ non_linguistic_symbols,
+ delimiter,
+ rir_scp,
+ rir_apply_prob,
+ noise_scp,
+ noise_apply_prob,
+ noise_db_range,
+ speech_volume_normalize,
+ speech_name,
+ text_name,
+ split_with_space,
+ seg_dict_file,
+ )
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ if self.text_name in data and self.tokenizer is not None:
+ text = data[self.text_name]
+ text = self.text_cleaner(text)
+ if self.split_with_space:
+ tokens = text.strip().split(" ")
+ if self.seg_dict is not None:
+ tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
+ else:
+ tokens = self.tokenizer.text2tokens(text)
+ text_ints = self.token_id_converter.tokens2ids(tokens)
+ data[self.text_name] = np.array(text_ints, dtype=np.int64)
+ assert check_return_type(data)
+ return data
+
class CommonPreprocessor_multi(AbsPreprocessor):
def __init__(
diff --git a/funasr/lm/espnet_model.py b/funasr/lm/espnet_model.py
index 4fc3b49c8..db11b6741 100644
--- a/funasr/lm/espnet_model.py
+++ b/funasr/lm/espnet_model.py
@@ -46,10 +46,10 @@ class ESPnetLanguageModel(AbsESPnetModel):
# 1. Create a sentence pair like ' w1 w2 w3' and 'w1 w2 w3 '
# text: (Batch, Length) -> x, y: (Batch, Length + 1)
- x = F.pad(text, [1, 0], "constant", self.eos)
+ x = F.pad(text, [1, 0], "constant", self.sos)
t = F.pad(text, [0, 1], "constant", self.ignore_id)
for i, l in enumerate(text_lengths):
- t[i, l] = self.sos
+ t[i, l] = self.eos
x_lengths = text_lengths + 1
# 2. Forward Language model
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 789940019..02311fda9 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -43,6 +43,7 @@ from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.iterators.chunk_iter_factory import ChunkIterFactory
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
+from funasr.main_funcs.collect_stats import collect_stats
from funasr.optimizers.sgd import SGD
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.samplers.build_batch_sampler import BATCH_TYPES
@@ -1272,6 +1273,52 @@ class AbsTask(ABC):
if args.dry_run:
pass
+ elif args.collect_stats:
+ # Perform on collect_stats mode. This mode has two roles
+ # - Derive the length and dimension of all input data
+ # - Accumulate feats, square values, and the length for whitening
+
+ if args.valid_batch_size is None:
+ args.valid_batch_size = args.batch_size
+
+ if len(args.train_shape_file) != 0:
+ train_key_file = args.train_shape_file[0]
+ else:
+ train_key_file = None
+ if len(args.valid_shape_file) != 0:
+ valid_key_file = args.valid_shape_file[0]
+ else:
+ valid_key_file = None
+
+ collect_stats(
+ model=model,
+ train_iter=cls.build_streaming_iterator(
+ data_path_and_name_and_type=args.train_data_path_and_name_and_type,
+ key_file=train_key_file,
+ batch_size=args.batch_size,
+ dtype=args.train_dtype,
+ num_workers=args.num_workers,
+ allow_variable_data_keys=args.allow_variable_data_keys,
+ ngpu=args.ngpu,
+ preprocess_fn=cls.build_preprocess_fn(args, train=False),
+ collate_fn=cls.build_collate_fn(args, train=False),
+ ),
+ valid_iter=cls.build_streaming_iterator(
+ data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
+ key_file=valid_key_file,
+ batch_size=args.valid_batch_size,
+ dtype=args.train_dtype,
+ num_workers=args.num_workers,
+ allow_variable_data_keys=args.allow_variable_data_keys,
+ ngpu=args.ngpu,
+ preprocess_fn=cls.build_preprocess_fn(args, train=False),
+ collate_fn=cls.build_collate_fn(args, train=False),
+ ),
+ output_dir=output_dir,
+ ngpu=args.ngpu,
+ log_interval=args.log_interval,
+ write_collected_feats=args.write_collected_feats,
+ )
else:
logging.info("Training args: {}".format(args))
# 6. Loads pre-trained model
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 46b9fe089..608c1d3eb 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -58,7 +58,7 @@ class LMTask(AbsTask):
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
required = parser.get_default("required")
- required += ["token_list"]
+ # required += ["token_list"]
group.add_argument(
"--token_list",