From 03875965c8a650571903283810c822005b01d22b Mon Sep 17 00:00:00 2001 From: lzr265946 Date: Thu, 9 Feb 2023 15:13:14 +0800 Subject: [PATCH 1/3] remove global vars --- funasr/bin/asr_inference.py | 10 ---------- funasr/utils/asr_utils.py | 16 ++++------------ 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py index d419018ec..16fa3e51c 100644 --- a/funasr/bin/asr_inference.py +++ b/funasr/bin/asr_inference.py @@ -464,16 +464,6 @@ def inference_modelscope( return _forward -def set_parameters(language: str = None, - sample_rate: Union[int, Dict[Any, int]] = None): - if language is not None: - global global_asr_language - global_asr_language = language - if sample_rate is not None: - global global_sample_rate - global_sample_rate = sample_rate - - def get_parser(): parser = config_argparse.ArgumentParser( description="ASR Decoding", diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py index 76b20a51a..a3ff3e331 100644 --- a/funasr/utils/asr_utils.py +++ b/funasr/utils/asr_utils.py @@ -186,23 +186,12 @@ def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]: return wav_list - -def set_parameters(language: str = None): - if language is not None: - global global_asr_language - global_asr_language = language - - def compute_wer(hyp_list: List[Any], ref_list: List[Any], lang: str = None) -> Dict[str, Any]: assert len(hyp_list) > 0, 'hyp list is empty' assert len(ref_list) > 0, 'ref list is empty' - if lang is not None: - global global_asr_language - global_asr_language = lang - rst = { 'Wrd': 0, 'Corr': 0, @@ -216,12 +205,15 @@ def compute_wer(hyp_list: List[Any], 'wrong_sentences': 0 } + if lang is None: + lang = global_asr_language + for h_item in hyp_list: for r_item in ref_list: if h_item['key'] == r_item['key']: out_item = compute_wer_by_line(h_item['value'], r_item['value'], - global_asr_language) + lang) rst['Wrd'] += out_item['nwords'] rst['Corr'] += out_item['cor'] rst['wrong_words'] += out_item['wrong'] From 983bb9382ebcbd29c8c59251b4bcbd328bdf3bf1 Mon Sep 17 00:00:00 2001 From: lzr265946 Date: Thu, 9 Feb 2023 15:14:20 +0800 Subject: [PATCH 2/3] fix bug in predictor tail_process_fn --- funasr/models/predictor/cif.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index 00c5a3e92..c28146062 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -68,7 +68,8 @@ class CifPredictor(nn.Module): mask_2 = torch.cat([ones_t, mask], dim=1) mask = mask_2 - mask_1 tail_threshold = mask * tail_threshold - alphas = torch.cat([alphas, tail_threshold], dim=1) + alphas = torch.cat([alphas, zeros_t], dim=1) + alphas = torch.add(alphas, tail_threshold) else: tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device) tail_threshold = torch.reshape(tail_threshold, (1, 1)) From 44fe2e811fed3d1d2341e54fa72a06931b011da1 Mon Sep 17 00:00:00 2001 From: lzr265946 Date: Thu, 9 Feb 2023 15:23:40 +0800 Subject: [PATCH 3/3] fix bug in predictor tail_process_fn --- funasr/models/predictor/cif.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index c28146062..c34759d0d 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -598,7 +598,8 @@ class CifPredictorV3(nn.Module): mask_2 = torch.cat([ones_t, mask], dim=1) mask = mask_2 - mask_1 tail_threshold = mask * tail_threshold - alphas = torch.cat([alphas, tail_threshold], dim=1) + alphas = torch.cat([alphas, zeros_t], dim=1) + alphas = torch.add(alphas, tail_threshold) else: tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device) tail_threshold = torch.reshape(tail_threshold, (1, 1))