mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

Commit 0923d835cd
Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
@@ -1,57 +1,37 @@
import os
import logging
import torch
import torchaudio
import soundfile

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

from modelscope.utils.logger import get_logger
import logging

logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)

os.environ["MODELSCOPE_CACHE"] = "./"
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.2')

waveform, sample_rate = torchaudio.load("waihu.wav")
speech_length = waveform.shape[1]
speech = waveform[0]
model_dir = os.path.join(os.environ["MODELSCOPE_CACHE"], "damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")
speech, sample_rate = soundfile.read(os.path.join(model_dir, "example/asr_example.wav"))
speech_length = speech.shape[0]

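# Streaming cache passed to the pipeline. A descriptive note: field meanings are inferred
# from their names and from how the loop below updates them, assuming a 16 kHz model where
# 960 samples equal one 60 ms frame. "stride" counts frames in the current chunk,
# "pad_left"/"pad_right" are look-back/look-ahead frames, and "cif_hidden"/"cif_alphas"
# carry the CIF predictor state across chunks.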
cache_en = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None}
cache_de = {"decode_fsmn": None}
cache = {"encoder": cache_en, "decoder": cache_de}
param_dict = {}
param_dict["cache"] = cache

first_chunk = True
speech_buffer = speech
speech_cache = []
sample_offset = 0
step = 4800  # 300ms
param_dict = {"cache": dict(), "is_final": False}
final_result = ""

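# Manual chunking scheme used by the while-loop below: decode while at least one 60 ms
# frame (960 samples) is buffered; the first call feeds a 900 ms chunk (14400 samples)
# and advances 300 ms, later calls feed up to 1200 ms (19200 samples) and advance 600 ms
# (9600 samples); shorter tails are flushed with pad_right set to 0. Durations assume
# 16 kHz audio, consistent with the "# 300ms" note on the step size above.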
while len(speech_buffer) >= 960:
    if first_chunk:
        if len(speech_buffer) >= 14400:
            rec_result = inference_pipeline(audio_in=speech_buffer[0:14400], param_dict=param_dict)
            speech_buffer = speech_buffer[4800:]
        else:
            cache_en["stride"] = len(speech_buffer) // 960
            cache_en["pad_right"] = 0
            rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict)
            speech_buffer = []
        cache_en["start_idx"] = -5
        first_chunk = False
    else:
        cache_en["start_idx"] += 10
        if len(speech_buffer) >= 4800:
            cache_en["pad_left"] = 5
            rec_result = inference_pipeline(audio_in=speech_buffer[:19200], param_dict=param_dict)
            speech_buffer = speech_buffer[9600:]
        else:
            cache_en["stride"] = len(speech_buffer) // 960
            cache_en["pad_right"] = 0
            rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict)
            speech_buffer = []
    if len(rec_result) != 0 and rec_result['text'] != "sil":

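# Pipeline-managed streaming: here the cache lives inside param_dict, so the caller only
# slices the audio into 300 ms (4800-sample) steps and sets "is_final" on the last chunk;
# chunks whose result is "sil" or "waiting_for_more_voice" (sentinels produced by the
# inference code further down) are skipped.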
for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
    if sample_offset + step >= speech_length - 1:
        step = speech_length - sample_offset
        param_dict["is_final"] = True
    rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + step],
                                    param_dict=param_dict)
    if len(rec_result) != 0 and rec_result['text'] != "sil" and rec_result['text'] != "waiting_for_more_voice":
        final_result += rec_result['text']
        print(rec_result)
print(final_result)

@@ -544,11 +544,6 @@ def inference_modelscope(
    )

    export_mode = False
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
        export_mode = param_dict.get("export_mode", False)
    else:
        hotword_list_or_file = None

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
@@ -578,7 +573,6 @@ def inference_modelscope(
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
    )
    if export_mode:
        speech2text = Speech2TextExport(**speech2text_kwargs)
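    # export_mode is read from param_dict above; when set, the export-oriented
    # Speech2TextExport wrapper is built instead of the regular Speech2Text
    # (as the class name suggests; a descriptive note, not taken from the source).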
@@ -594,123 +588,92 @@ def inference_modelscope(
        **kwargs,
    ):

        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
        if 'hotword' in kwargs:
            hotword_list_or_file = kwargs['hotword']
        if hotword_list_or_file is not None or 'hotword' in kwargs:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)

        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        if isinstance(raw_inputs, np.ndarray):
            raw_inputs = torch.tensor(raw_inputs)

        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True

        forward_time_total = 0.0
        length_total = 0.0
        finish_count = 0
        file_count = 1
        cache = None
        is_final = False
        if param_dict is not None and "cache" in param_dict:
            cache = param_dict["cache"]
        if param_dict is not None and "is_final" in param_dict:
            is_final = param_dict["is_final"]
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        results = []
        asr_result = ""
        wait = True
        if len(cache) == 0:
            cache["encoder"] = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None}
            cache_de = {"decode_fsmn": None}
            cache["decoder"] = cache_de
            cache["first_chunk"] = True
            cache["speech"] = []
            cache["chunk_index"] = 0
            cache["speech_chunk"] = []

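        # Audio passed in via raw_inputs is accumulated across calls: cache["speech"] keeps
        # the full utterance so far, while cache["speech_chunk"] holds the not-yet-decoded
        # tail that the chunking loop below consumes.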
        if raw_inputs is not None:
            if len(cache["speech"]) == 0:
                cache["speech"] = raw_inputs
        else:
            writer = None
        if param_dict is not None and "cache" in param_dict:
            cache = param_dict["cache"]
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
            logging.info("decoding, utt_id: {}".format(keys))
            # N-best list of (text, token, token_int, hyp_object)

            time_beg = time.time()
            results = speech2text(cache=cache, **batch)
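            # If the model produced nothing for this chunk, substitute a placeholder "sil"
            # hypothesis so the bookkeeping below still has an entry to unpack.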
            if len(results) < 1:
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
            time_end = time.time()
            forward_time = time_end - time_beg
            lfr_factor = results[0][-1]
            length = results[0][-2]
            forward_time_total += forward_time
            length_total += length
            rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(
                length, forward_time, 100 * forward_time / (length * lfr_factor))
            logging.info(rtf_cur)

            for batch_id in range(_bs):
                result = [results[batch_id][:-2]]

                key = keys[batch_id]
                for n, result in zip(range(1, nbest + 1), result):
                    text, token, token_int, hyp = result[0], result[1], result[2], result[3]
                    time_stamp = None if len(result) < 5 else result[4]
                    # Create a directory: outdir/{n}best_recog
                    if writer is not None:
                        ibest_writer = writer[f"{n}best_recog"]

                        # Write the result to each file
                        ibest_writer["token"][key] = " ".join(token)
                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                        ibest_writer["score"][key] = str(hyp.score)
                        ibest_writer["rtf"][key] = rtf_cur

                    if text is not None:
                        if use_timestamp and time_stamp is not None:
                            postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
                cache["speech"] = torch.cat([cache["speech"], raw_inputs], dim=0)
            if len(cache["speech_chunk"]) == 0:
                cache["speech_chunk"] = raw_inputs
                        else:
                            postprocessed_result = postprocess_utils.sentence_postprocess(token)
                        time_stamp_postprocessed = ""
                        if len(postprocessed_result) == 3:
                            text_postprocessed, time_stamp_postprocessed, word_lists = \
                                postprocessed_result[0], postprocessed_result[1], postprocessed_result[2]
                cache["speech_chunk"] = torch.cat([cache["speech_chunk"], raw_inputs], dim=0)

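        # Chunking inside the pipeline mirrors the demo script above: decode while at least
        # one 60 ms frame (960 samples) is buffered, feed 14400 samples (900 ms) for the
        # first chunk and 19200 samples (1200 ms) afterwards, advancing by 4800 / 9600
        # samples (300 / 600 ms); when is_final is set the remaining tail is flushed with
        # pad_right = 0. Durations assume 16 kHz input, as in the model name.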
        while len(cache["speech_chunk"]) >= 960:
            if cache["first_chunk"]:
                if len(cache["speech_chunk"]) >= 14400:
                    speech = torch.unsqueeze(cache["speech_chunk"][0:14400], axis=0)
                    speech_length = torch.tensor([14400])
                    results = speech2text(cache, speech, speech_length)
                    cache["speech_chunk"] = cache["speech_chunk"][4800:]
                    cache["first_chunk"] = False
                    cache["encoder"]["start_idx"] = -5
                    wait = False
                        else:
                            text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
                        item = {'key': key, 'value': text_postprocessed}
                        if time_stamp_postprocessed != "":
                            item['time_stamp'] = time_stamp_postprocessed
                    if is_final:
                        cache["encoder"]["stride"] = len(cache["speech_chunk"]) // 960
                        cache["encoder"]["pad_right"] = 0
                        speech = torch.unsqueeze(cache["speech_chunk"], axis=0)
                        speech_length = torch.tensor([len(cache["speech_chunk"])])
                        results = speech2text(cache, speech, speech_length)
                        cache["speech_chunk"] = []
                        wait = False
                    else:
                        break
            else:
                if len(cache["speech_chunk"]) >= 19200:
                    cache["encoder"]["start_idx"] += 10
                    cache["encoder"]["pad_left"] = 5
                    speech = torch.unsqueeze(cache["speech_chunk"][:19200], axis=0)
                    speech_length = torch.tensor([19200])
                    results = speech2text(cache, speech, speech_length)
                    cache["speech_chunk"] = cache["speech_chunk"][9600:]
                    wait = False
                else:
                    if is_final:
                        cache["encoder"]["stride"] = len(cache["speech_chunk"]) // 960
                        cache["encoder"]["pad_right"] = 0
                        speech = torch.unsqueeze(cache["speech_chunk"], axis=0)
                        speech_length = torch.tensor([len(cache["speech_chunk"])])
                        results = speech2text(cache, speech, speech_length)
                        cache["speech_chunk"] = []
                        wait = False
                    else:
                        break

            if len(results) >= 1:
                asr_result += results[0][0]
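        # Sentinel outputs: an empty transcript maps to "sil", and a call that only buffered
        # audio without decoding a chunk (wait still True) returns "waiting_for_more_voice";
        # the demo script above filters both strings out of the final result.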
        if asr_result == "":
            asr_result = "sil"
        if wait:
            asr_result = "waiting_for_more_voice"
        item = {'key': "utt", 'value': asr_result}
        asr_result_list.append(item)
                    finish_count += 1
                    # asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text_postprocessed

                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(
            length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
        logging.info(rtf_avg)
        if writer is not None:
            ibest_writer["rtf"]["rtf_avf"] = rtf_avg
        else:
            return []
        return asr_result_list

    return _forward
@@ -905,3 +868,4 @@ if __name__ == "__main__":
    # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
    # print(rec_result)
