mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

support hotword parameter passing in the pipeline forward

commit f13cfbc18e (parent ea377e7e7b)
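This commit lets callers supply hotwords at inference time, through either param_dict or a plain keyword argument, instead of fixing them when the Speech2Text object is built. A minimal usage sketch, assuming a ModelScope pipeline wrapping a FunASR contextual Paraformer model; the model ID, hotword string, and wav path below are illustrative placeholders, not taken from this diff:

    # Hedged sketch: pass hotwords via param_dict when building the pipeline.
    # Model ID and file names are placeholders.
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    param_dict = {'hotword': 'nihao hello'}  # inline string, local .txt path, or URL
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model='damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404',
        param_dict=param_dict,
    )
    rec_result = inference_pipeline(audio_in='example.wav')
    print(rec_result)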
@@ -395,6 +395,7 @@ def inference_modelscope(
         output_dir_v2: Optional[str] = None,
         fs: dict = None,
         param_dict: dict = None,
+        **kwargs,
 ):
     # 3. Build data-iterator
     if data_path_and_name_and_type is None and raw_inputs is not None:
@@ -523,6 +523,7 @@ def inference_modelscope(
         output_dir_v2: Optional[str] = None,
         fs: dict = None,
         param_dict: dict = None,
+        **kwargs,
 ):
     # 3. Build data-iterator
     if data_path_and_name_and_type is None and raw_inputs is not None:
@@ -169,56 +169,8 @@ class Speech2Text:
         self.tokenizer = tokenizer
 
-        # 6. [Optional] Build hotword list from str, local file or url
-        # for None
-        if hotword_list_or_file is None:
-            self.hotword_list = None
-        # for text str input
-        elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
-            logging.info("Attempting to parse hotwords as str...")
-            self.hotword_list = []
-            hotword_str_list = []
-            for hw in hotword_list_or_file.strip().split():
-                hotword_str_list.append(hw)
-                self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-            self.hotword_list.append([self.asr_model.sos])
-            hotword_str_list.append('<s>')
-            logging.info("Hotword list: {}.".format(hotword_str_list))
-        # for local txt inputs
-        elif os.path.exists(hotword_list_or_file):
-            logging.info("Attempting to parse hotwords from local txt...")
-            self.hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hotword_str_list.append(hw)
-                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-            self.hotword_list.append([self.asr_model.sos])
-            hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                         .format(hotword_list_or_file, hotword_str_list))
-        # for url, download and generate txt
-        else:
-            logging.info("Attempting to parse hotwords from url...")
-            work_dir = tempfile.TemporaryDirectory().name
-            if not os.path.exists(work_dir):
-                os.makedirs(work_dir)
-            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
-            local_file = requests.get(hotword_list_or_file)
-            open(text_file_path, "wb").write(local_file.content)
-            hotword_list_or_file = text_file_path
-            self.hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hotword_str_list.append(hw)
-                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-            self.hotword_list.append([self.asr_model.sos])
-            hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                         .format(hotword_list_or_file, hotword_str_list))
+        self.hotword_list = None
+        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
 
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -337,6 +289,60 @@ class Speech2Text:
         # assert check_return_type(results)
         return results
 
+    def generate_hotwords_list(self, hotword_list_or_file):
+        # for None
+        if hotword_list_or_file is None:
+            hotword_list = None
+        # for local txt inputs
+        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords from local txt...")
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            hotword_list.append([self.asr_model.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for url, download and generate txt
+        elif hotword_list_or_file.startswith('http'):
+            logging.info("Attempting to parse hotwords from url...")
+            work_dir = tempfile.TemporaryDirectory().name
+            if not os.path.exists(work_dir):
+                os.makedirs(work_dir)
+            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+            local_file = requests.get(hotword_list_or_file)
+            open(text_file_path, "wb").write(local_file.content)
+            hotword_list_or_file = text_file_path
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            hotword_list.append([self.asr_model.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for text str input
+        elif not hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords as str...")
+            hotword_list = []
+            hotword_str_list = []
+            for hw in hotword_list_or_file.strip().split():
+                hotword_str_list.append(hw)
+                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            hotword_list.append([self.asr_model.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Hotword list: {}.".format(hotword_str_list))
+        else:
+            hotword_list = None
+        return hotword_list
+
 
 class Speech2TextExport:
     """Speech2TextExport class
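The new generate_hotwords_list helper centralizes what was previously duplicated in __init__: it accepts None, a local .txt file, an http(s) URL (downloaded into a temporary directory first; note that tempfile.TemporaryDirectory().name only yields a path, which is why the code recreates the directory with os.makedirs), or a plain space-separated string. Each hotword is split into characters, mapped to token IDs with converter.tokens2ids, and the list is terminated with the model's sos ID (logged as '<s>'). A short sketch of the accepted forms; s2t and all values are placeholders:

    # s2t stands for an initialized Speech2Text instance; inputs are examples.
    s2t.hotword_list = s2t.generate_hotwords_list('nihao hello')                       # inline string
    s2t.hotword_list = s2t.generate_hotwords_list('/data/hotwords.txt')                # local file, one hotword per line
    s2t.hotword_list = s2t.generate_hotwords_list('https://example.com/hotwords.txt')  # URL to a txt file
    s2t.hotword_list = s2t.generate_hotwords_list(None)                                # disables hotword biasing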
@@ -648,7 +654,19 @@ def inference_modelscope(
         output_dir_v2: Optional[str] = None,
         fs: dict = None,
         param_dict: dict = None,
+        **kwargs,
 ):
+
+    hotword_list_or_file = None
+    if param_dict is not None:
+        hotword_list_or_file = param_dict.get('hotword')
+
+    if 'hotword' in kwargs:
+        hotword_list_or_file = kwargs['hotword']
+
+    if speech2text.hotword_list is None:
+        speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
     # 3. Build data-iterator
     if data_path_and_name_and_type is None and raw_inputs is not None:
         if isinstance(raw_inputs, torch.Tensor):
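Note the resolution order in this forward path: a 'hotword' entry in **kwargs overrides one supplied via param_dict, and the list is only built while speech2text.hotword_list is still None, so hotwords from the first call are cached and reused on later calls unless the attribute is reset.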
@@ -228,7 +228,19 @@ def inference_modelscope(
         output_dir_v2: Optional[str] = None,
         fs: dict = None,
         param_dict: dict = None,
+        **kwargs,
 ):
+
+    hotword_list_or_file = None
+    if param_dict is not None:
+        hotword_list_or_file = param_dict.get('hotword')
+
+    if 'hotword' in kwargs:
+        hotword_list_or_file = kwargs['hotword']
+
+    if speech2text.hotword_list is None:
+        speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
     # 3. Build data-iterator
     if data_path_and_name_and_type is None and raw_inputs is not None:
         if isinstance(raw_inputs, torch.Tensor):
@@ -176,55 +176,8 @@ class Speech2Text:
         self.tokenizer = tokenizer
 
         # 6. [Optional] Build hotword list from str, local file or url
-        # for None
-        if hotword_list_or_file is None:
-            self.hotword_list = None
-        # for text str input
-        elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
-            logging.info("Attempting to parse hotwords as str...")
-            self.hotword_list = []
-            hotword_str_list = []
-            for hw in hotword_list_or_file.strip().split():
-                hotword_str_list.append(hw)
-                self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-            self.hotword_list.append([self.asr_model.sos])
-            hotword_str_list.append('<s>')
-            logging.info("Hotword list: {}.".format(hotword_str_list))
-        # for local txt inputs
-        elif os.path.exists(hotword_list_or_file):
-            logging.info("Attempting to parse hotwords from local txt...")
-            self.hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hotword_str_list.append(hw)
-                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-            self.hotword_list.append([self.asr_model.sos])
-            hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                         .format(hotword_list_or_file, hotword_str_list))
-        # for url, download and generate txt
-        else:
-            logging.info("Attempting to parse hotwords from url...")
-            work_dir = tempfile.TemporaryDirectory().name
-            if not os.path.exists(work_dir):
-                os.makedirs(work_dir)
-            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
-            local_file = requests.get(hotword_list_or_file)
-            open(text_file_path, "wb").write(local_file.content)
-            hotword_list_or_file = text_file_path
-            self.hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hotword_str_list.append(hw)
-                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-            self.hotword_list.append([self.asr_model.sos])
-            hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                         .format(hotword_list_or_file, hotword_str_list))
+        self.hotword_list = None
+        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
 
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -355,6 +308,59 @@ class Speech2Text:
         # assert check_return_type(results)
         return results
 
+    def generate_hotwords_list(self, hotword_list_or_file):
+        # for None
+        if hotword_list_or_file is None:
+            hotword_list = None
+        # for local txt inputs
+        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords from local txt...")
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            hotword_list.append([self.asr_model.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for url, download and generate txt
+        elif hotword_list_or_file.startswith('http'):
+            logging.info("Attempting to parse hotwords from url...")
+            work_dir = tempfile.TemporaryDirectory().name
+            if not os.path.exists(work_dir):
+                os.makedirs(work_dir)
+            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+            local_file = requests.get(hotword_list_or_file)
+            open(text_file_path, "wb").write(local_file.content)
+            hotword_list_or_file = text_file_path
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            hotword_list.append([self.asr_model.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for text str input
+        elif not hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords as str...")
+            hotword_list = []
+            hotword_str_list = []
+            for hw in hotword_list_or_file.strip().split():
+                hotword_str_list.append(hw)
+                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            hotword_list.append([self.asr_model.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Hotword list: {}.".format(hotword_str_list))
+        else:
+            hotword_list = None
+        return hotword_list
+
 
 class Speech2VadSegment:
     """Speech2VadSegment class
@@ -637,7 +643,19 @@ def inference_modelscope(
         output_dir_v2: Optional[str] = None,
         fs: dict = None,
         param_dict: dict = None,
+        **kwargs,
 ):
+
+    hotword_list_or_file = None
+    if param_dict is not None:
+        hotword_list_or_file = param_dict.get('hotword')
+
+    if 'hotword' in kwargs:
+        hotword_list_or_file = kwargs['hotword']
+
+    if speech2text.hotword_list is None:
+        speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
    # 3. Build data-iterator
     if data_path_and_name_and_type is None and raw_inputs is not None:
         if isinstance(raw_inputs, torch.Tensor):
@@ -433,6 +433,7 @@ def inference_modelscope(
         output_dir_v2: Optional[str] = None,
         fs: dict = None,
         param_dict: dict = None,
+        **kwargs,
 ):
     # 3. Build data-iterator
     if data_path_and_name_and_type is None and raw_inputs is not None:
@@ -433,6 +433,7 @@ def inference_modelscope(
         output_dir_v2: Optional[str] = None,
         fs: dict = None,
         param_dict: dict = None,
+        **kwargs,
 ):
     # 3. Build data-iterator
     if data_path_and_name_and_type is None and raw_inputs is not None: