diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index c722ebc3a..e12dbb544 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -377,6 +377,7 @@ class Speech2TextParaformer:
         self.asr_train_args = asr_train_args
         self.converter = converter
         self.tokenizer = tokenizer
+        self.cmvn_file = cmvn_file
 
         # 6. [Optional] Build hotword list from str, local file or url
         self.hotword_list = None
@@ -519,6 +520,59 @@ class Speech2TextParaformer:
         return results
 
     def generate_hotwords_list(self, hotword_list_or_file):
+        def load_seg_dict(seg_dict_file):
+            """Read a seg_dict file into a {word: "piece piece ..."} mapping.
+
+            Each non-empty line is "word piece [piece ...]"; the pieces are
+            stored re-joined by single spaces.
+            """
+            seg_dict = {}
+            assert isinstance(seg_dict_file, str)
+            with open(seg_dict_file, "r", encoding="utf8") as f:
+                lines = f.readlines()
+                for line in lines:
+                    s = line.strip().split()
+                    if not s:
+                        # Skip blank lines instead of raising IndexError on s[0].
+                        continue
+                    key = s[0]
+                    value = s[1:]
+                    seg_dict[key] = " ".join(value)
+            return seg_dict
+
+        def seg_tokenize(txt, seg_dict):
+            """Map a list of words to a flat list of tokens via seg_dict.
+
+            Words are lower-cased and looked up whole; a word made only of CJK
+            ideographs/digits that is not in seg_dict is looked up character by
+            character. Anything still unknown becomes "<unk>".
+            """
+            pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+            out_txt = ""
+            for word in txt:
+                word = word.lower()
+                if word in seg_dict:
+                    out_txt += seg_dict[word] + " "
+                else:
+                    if pattern.match(word):
+                        for char in word:
+                            if char in seg_dict:
+                                out_txt += seg_dict[char] + " "
+                            else:
+                                out_txt += "<unk>" + " "
+                    else:
+                        out_txt += "<unk>" + " "
+            return out_txt.strip().split()
+
+        seg_dict = None
+        if self.cmvn_file is not None:
+            # seg_dict is looked up in the same directory as the CMVN file.
+            model_dir = os.path.dirname(self.cmvn_file)
+            seg_dict_file = os.path.join(model_dir, 'seg_dict')
+            if os.path.exists(seg_dict_file):
+                seg_dict = load_seg_dict(seg_dict_file)
+            else:
+                seg_dict = None
         # for None
         if hotword_list_or_file is None:
             hotword_list = None
@@ -530,8 +584,11 @@ class Speech2TextParaformer:
             with codecs.open(hotword_list_or_file, 'r') as fin:
                 for line in fin.readlines():
                     hw = line.strip()
+                    hw_list = hw.split()
+                    if seg_dict is not None:
+                        hw_list = seg_tokenize(hw_list, seg_dict)
                     hotword_str_list.append(hw)
-                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                    hotword_list.append(self.converter.tokens2ids(hw_list))
                 hotword_list.append([self.asr_model.sos])
                 hotword_str_list.append('<s>')
             logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -551,8 +608,11 @@ class Speech2TextParaformer:
             with codecs.open(hotword_list_or_file, 'r') as fin:
                 for line in fin.readlines():
                     hw = line.strip()
+                    hw_list = hw.split()
+                    if seg_dict is not None:
+                        hw_list = seg_tokenize(hw_list, seg_dict)
                     hotword_str_list.append(hw)
-                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                    hotword_list.append(self.converter.tokens2ids(hw_list))
                 hotword_list.append([self.asr_model.sos])
                 hotword_str_list.append('<s>')
             logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -564,7 +624,11 @@ class Speech2TextParaformer:
             hotword_str_list = []
             for hw in hotword_list_or_file.strip().split():
                 hotword_str_list.append(hw)
-                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                # Pass a word list (not a bare str) so seg_tokenize looks up
+                # whole words, exactly as the file/url branches do.
+                hw_list = hw.split()
+                if seg_dict is not None:
+                    hw_list = seg_tokenize(hw_list, seg_dict)
+                hotword_list.append(self.converter.tokens2ids(hw_list))
             hotword_list.append([self.asr_model.sos])
             hotword_str_list.append('<s>')
             logging.info("Hotword list: {}.".format(hotword_str_list))