From 0c4fbea66b7c4eddeec5734d4ff43ad85e32d5fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 15 Jun 2023 15:39:22 +0800 Subject: [PATCH] update repo --- funasr/bin/lm_inference_launch.py | 127 ++++++++---------- .../build_utils/build_streaming_iterator.py | 1 + 2 files changed, 56 insertions(+), 72 deletions(-) diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py index 1d99fcec5..c8482b8cc 100644 --- a/funasr/bin/lm_inference_launch.py +++ b/funasr/bin/lm_inference_launch.py @@ -1,5 +1,5 @@ -# -*- encoding: utf-8 -*- #!/usr/bin/env python3 +# -*- encoding: utf-8 -*- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) @@ -7,40 +7,25 @@ import argparse import logging import os import sys -from typing import Union, Dict, Any - -from funasr.utils import config_argparse -from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none -from funasr.utils.types import float_or_none -import argparse -import logging -from pathlib import Path -import sys -import os -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union -from typing import Dict from typing import Any from typing import List +from typing import Optional +from typing import Union import numpy as np import torch from torch.nn.parallel import data_parallel from typeguard import check_argument_types -from funasr.tasks.lm import LMTask +from funasr.build_utils.build_model_from_file import build_model_from_file +from funasr.build_utils.build_streaming_iterator import build_streaming_iterator from funasr.datasets.preprocessor import LMPreprocessor -from funasr.utils.cli_utils import get_commandline_args from funasr.fileio.datadir_writer import DatadirWriter from funasr.torch_utils.device_funcs import to_device from funasr.torch_utils.forward_adaptor import ForwardAdaptor from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args from funasr.utils.types import float_or_none from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str @@ -48,42 +33,42 @@ from funasr.utils.types import str_or_none def inference_lm( - batch_size: int, - dtype: str, - ngpu: int, - seed: int, - num_workers: int, - log_level: Union[int, str], - key_file: Optional[str], - train_config: Optional[str], - model_file: Optional[str], - log_base: Optional[float] = 10, - allow_variable_data_keys: bool = False, - split_with_space: Optional[bool] = False, - seg_dict_file: Optional[str] = None, - output_dir: Optional[str] = None, - param_dict: dict = None, - **kwargs, + batch_size: int, + dtype: str, + ngpu: int, + seed: int, + num_workers: int, + log_level: Union[int, str], + key_file: Optional[str], + train_config: Optional[str], + model_file: Optional[str], + log_base: Optional[float] = 10, + allow_variable_data_keys: bool = False, + split_with_space: Optional[bool] = False, + seg_dict_file: Optional[str] = None, + output_dir: Optional[str] = None, + param_dict: dict = None, + **kwargs, ): assert check_argument_types() ncpu = kwargs.get("ncpu", 1) torch.set_num_threads(ncpu) - + if ngpu >= 1 and torch.cuda.is_available(): device = "cuda" else: device = "cpu" - + # 1. Set random-seed set_all_random_seed(seed) - + # 2. Build Model - model, train_args = LMTask.build_model_from_file( - train_config, model_file, device) + model, train_args = build_model_from_file( + train_config, model_file, None, device, "lm") wrapped_model = ForwardAdaptor(model, "nll") wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval() logging.info(f"Model:\n{model}") - + preprocessor = LMPreprocessor( train=False, token_type=train_args.token_type, @@ -96,12 +81,12 @@ def inference_lm( split_with_space=split_with_space, seg_dict_file=seg_dict_file ) - + def _forward( - data_path_and_name_and_type, - raw_inputs: Union[List[Any], bytes, str] = None, - output_dir_v2: Optional[str] = None, - param_dict: dict = None, + data_path_and_name_and_type, + raw_inputs: Union[List[Any], bytes, str] = None, + output_dir_v2: Optional[str] = None, + param_dict: dict = None, ): results = [] output_path = output_dir_v2 if output_dir_v2 is not None else output_dir @@ -109,7 +94,7 @@ def inference_lm( writer = DatadirWriter(output_path) else: writer = None - + if raw_inputs != None: line = raw_inputs.strip() key = "lm demo" @@ -121,7 +106,7 @@ def inference_lm( batch['text'] = line if preprocessor != None: batch = preprocessor(key, batch) - + # Force data-precision for name in batch: value = batch[name] @@ -138,11 +123,11 @@ def inference_lm( else: raise NotImplementedError(f"Not supported dtype: {value.dtype}") batch[name] = value - + batch["text_lengths"] = torch.from_numpy( np.array([len(batch["text"])], dtype='int32')) batch["text"] = np.expand_dims(batch["text"], axis=0) - + with torch.no_grad(): batch = to_device(batch, device) if ngpu <= 1: @@ -173,7 +158,7 @@ def inference_lm( word_nll=round(word_nll.item(), 8) ) pre_word = cur_word - + sent_nll_mean = sent_nll.mean().cpu().numpy() sent_nll_sum = sent_nll.sum().cpu().numpy() if log_base is None: @@ -189,22 +174,20 @@ def inference_lm( if writer is not None: writer["ppl"][key + ":\n"] = ppl_out results.append(item) - + return results - + # 3. Build data-iterator - loader = LMTask.build_streaming_iterator( - data_path_and_name_and_type, + loader = build_streaming_iterator( + task_name="lm", + preprocess_args=train_args, + data_path_and_name_and_type=data_path_and_name_and_type, dtype=dtype, batch_size=batch_size, key_file=key_file, num_workers=num_workers, - preprocess_fn=preprocessor, - collate_fn=LMTask.build_collate_fn(train_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, ) - + # 4. Start for-loop total_nll = 0.0 total_ntokens = 0 @@ -214,7 +197,7 @@ def inference_lm( assert all(isinstance(s, str) for s in keys), keys _bs = len(next(iter(batch.values()))) assert len(keys) == _bs, f"{len(keys)} != {_bs}" - + ppl_out_batch = "" with torch.no_grad(): batch = to_device(batch, device) @@ -247,7 +230,7 @@ def inference_lm( word_nll=round(word_nll.item(), 8) ) pre_word = cur_word - + sent_nll_mean = sent_nll.mean().cpu().numpy() sent_nll_sum = sent_nll.sum().cpu().numpy() if log_base is None: @@ -265,9 +248,9 @@ def inference_lm( writer["ppl"][key + ":\n"] = ppl_out writer["utt2nll"][key] = str(utt2nll) results.append(item) - + ppl_out_all += ppl_out_batch - + assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths)) # nll: (B, L) -> (B,) nll = nll.detach().cpu().numpy().sum(1) @@ -275,12 +258,12 @@ def inference_lm( lengths = lengths.detach().cpu().numpy() total_nll += nll.sum() total_ntokens += lengths.sum() - + if log_base is None: ppl = np.exp(total_nll / total_ntokens) else: ppl = log_base ** (total_nll / total_ntokens / np.log(log_base)) - + avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format( total_nll=round(-total_nll.item(), 4), total_ppl=round(ppl.item(), 4) @@ -290,9 +273,9 @@ def inference_lm( if writer is not None: writer["ppl"]["AVG PPL : "] = avg_ppl results.append(item) - + return results - + return _forward @@ -302,7 +285,8 @@ def inference_launch(mode, **kwargs): else: logging.info("Unknown decoding mode: {}".format(mode)) return None - + + def get_parser(): parser = config_argparse.ArgumentParser( description="Calc perplexity", @@ -407,4 +391,3 @@ def main(cmd=None): if __name__ == "__main__": main() - diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py index da42929f1..ad36b4e61 100644 --- a/funasr/build_utils/build_streaming_iterator.py +++ b/funasr/build_utils/build_streaming_iterator.py @@ -26,6 +26,7 @@ def build_streaming_iterator( # preprocess if preprocess_args is not None: + preprocess_args.task_name = task_name preprocess_fn = build_preprocess(preprocess_args, train) else: preprocess_fn = None