mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #660 from alibaba-damo-academy/dev_lhn
fix english hotwords bug
This commit is contained in:
commit
786ed53467
@ -377,6 +377,7 @@ class Speech2TextParaformer:
|
||||
self.asr_train_args = asr_train_args
|
||||
self.converter = converter
|
||||
self.tokenizer = tokenizer
|
||||
self.cmvn_file = cmvn_file
|
||||
|
||||
# 6. [Optional] Build hotword list from str, local file or url
|
||||
self.hotword_list = None
|
||||
@ -519,6 +520,44 @@ class Speech2TextParaformer:
|
||||
return results
|
||||
|
||||
def generate_hotwords_list(self, hotword_list_or_file):
|
||||
def load_seg_dict(seg_dict_file):
    """Load a segmentation dictionary from a whitespace-delimited text file.

    Every non-blank line is split on whitespace: the first field is the
    key and the remaining fields, re-joined with single spaces, form the
    value (an empty string when the line holds only a key). If a key
    occurs more than once, the last occurrence wins.

    Args:
        seg_dict_file: Path of the dictionary file; must be a ``str``.

    Returns:
        A ``dict`` mapping each key to its space-joined segmentation.
    """
    assert isinstance(seg_dict_file, str)
    seg_dict = {}
    with open(seg_dict_file, "r", encoding="utf8") as f:
        # Iterate the file lazily rather than materialising every line.
        for line in f:
            fields = line.split()
            if not fields:
                # Blank or whitespace-only line (e.g. a trailing newline at
                # the end of the file): there is no key to index, so skip it
                # instead of failing on fields[0].
                continue
            seg_dict[fields[0]] = " ".join(fields[1:])
    return seg_dict
|
||||
|
||||
def seg_tokenize(txt, seg_dict):
    """Map an iterable of words to a flat list of segmentation tokens.

    Each word is lower-cased and then resolved as follows:

    * if the word is a key of ``seg_dict``, its value is split on
      whitespace and the pieces are emitted;
    * otherwise, if the word consists solely of characters in the range
      U+4E00..U+9FA5 and/or ASCII digits, it is resolved character by
      character, emitting ``"<unk>"`` for characters missing from
      ``seg_dict``;
    * otherwise a single ``"<unk>"`` is emitted for the whole word.

    Args:
        txt: Iterable of words (strings). Passing a plain string iterates
            it character by character.
        seg_dict: Mapping from word/character to a space-separated token
            string.

    Returns:
        A list of token strings.
    """
    pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
    # Collect tokens in a list instead of growing a string with += and
    # re-splitting it afterwards; the resulting token sequence is the same.
    tokens = []
    for word in txt:
        word = word.lower()
        if word in seg_dict:
            tokens.extend(seg_dict[word].split())
        elif pattern.match(word):
            # No whole-word entry: fall back to per-character lookup.
            for char in word:
                tokens.extend(seg_dict.get(char, "<unk>").split())
        else:
            tokens.append("<unk>")
    return tokens
|
||||
|
||||
seg_dict = None
|
||||
if self.cmvn_file is not None:
|
||||
model_dir = os.path.dirname(self.cmvn_file)
|
||||
seg_dict_file = os.path.join(model_dir, 'seg_dict')
|
||||
if os.path.exists(seg_dict_file):
|
||||
seg_dict = load_seg_dict(seg_dict_file)
|
||||
else:
|
||||
seg_dict = None
|
||||
# for None
|
||||
if hotword_list_or_file is None:
|
||||
hotword_list = None
|
||||
@ -530,8 +569,11 @@ class Speech2TextParaformer:
|
||||
with codecs.open(hotword_list_or_file, 'r') as fin:
|
||||
for line in fin.readlines():
|
||||
hw = line.strip()
|
||||
hw_list = hw.split()
|
||||
if seg_dict is not None:
|
||||
hw_list = seg_tokenize(hw_list, seg_dict)
|
||||
hotword_str_list.append(hw)
|
||||
hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
||||
hotword_list.append(self.converter.tokens2ids(hw_list))
|
||||
hotword_list.append([self.asr_model.sos])
|
||||
hotword_str_list.append('<s>')
|
||||
logging.info("Initialized hotword list from file: {}, hotword list: {}."
|
||||
@ -551,8 +593,11 @@ class Speech2TextParaformer:
|
||||
with codecs.open(hotword_list_or_file, 'r') as fin:
|
||||
for line in fin.readlines():
|
||||
hw = line.strip()
|
||||
hw_list = hw.split()
|
||||
if seg_dict is not None:
|
||||
hw_list = seg_tokenize(hw_list, seg_dict)
|
||||
hotword_str_list.append(hw)
|
||||
hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
||||
hotword_list.append(self.converter.tokens2ids(hw_list))
|
||||
hotword_list.append([self.asr_model.sos])
|
||||
hotword_str_list.append('<s>')
|
||||
logging.info("Initialized hotword list from file: {}, hotword list: {}."
|
||||
@ -564,7 +609,10 @@ class Speech2TextParaformer:
|
||||
hotword_str_list = []
|
||||
for hw in hotword_list_or_file.strip().split():
|
||||
hotword_str_list.append(hw)
|
||||
hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
||||
hw_list = hw
|
||||
if seg_dict is not None:
|
||||
hw_list = seg_tokenize(hw_list, seg_dict)
|
||||
hotword_list.append(self.converter.tokens2ids(hw_list))
|
||||
hotword_list.append([self.asr_model.sos])
|
||||
hotword_str_list.append('<s>')
|
||||
logging.info("Hotword list: {}.".format(hotword_str_list))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user