Merge branch 'main' of github.com:alibaba-damo-academy/FunASR

add
This commit is contained in:
游雁 2023-03-24 11:39:40 +08:00
commit 0923d835cd
2 changed files with 99 additions and 155 deletions

View File

@ -1,57 +1,37 @@
import os
import logging
import torch
import torchaudio
import soundfile
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
import logging
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
os.environ["MODELSCOPE_CACHE"] = "./"
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
model_revision='v1.0.2')
waveform, sample_rate = torchaudio.load("waihu.wav")
speech_length = waveform.shape[1]
speech = waveform[0]
model_dir = os.path.join(os.environ["MODELSCOPE_CACHE"], "damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")
speech, sample_rate = soundfile.read(os.path.join(model_dir, "example/asr_example.wav"))
speech_length = speech.shape[0]
cache_en = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None}
cache_de = {"decode_fsmn": None}
cache = {"encoder": cache_en, "decoder": cache_de}
param_dict = {}
param_dict["cache"] = cache
first_chunk = True
speech_buffer = speech
speech_cache = []
sample_offset = 0
step = 4800 #300ms
param_dict = {"cache": dict(), "is_final": False}
final_result = ""
while len(speech_buffer) >= 960:
if first_chunk:
if len(speech_buffer) >= 14400:
rec_result = inference_pipeline(audio_in=speech_buffer[0:14400], param_dict=param_dict)
speech_buffer = speech_buffer[4800:]
else:
cache_en["stride"] = len(speech_buffer) // 960
cache_en["pad_right"] = 0
rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict)
speech_buffer = []
cache_en["start_idx"] = -5
first_chunk = False
else:
cache_en["start_idx"] += 10
if len(speech_buffer) >= 4800:
cache_en["pad_left"] = 5
rec_result = inference_pipeline(audio_in=speech_buffer[:19200], param_dict=param_dict)
speech_buffer = speech_buffer[9600:]
else:
cache_en["stride"] = len(speech_buffer) // 960
cache_en["pad_right"] = 0
rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict)
speech_buffer = []
if len(rec_result) !=0 and rec_result['text'] != "sil":
for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
if sample_offset + step >= speech_length - 1:
step = speech_length - sample_offset
param_dict["is_final"] = True
rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + step],
param_dict=param_dict)
if len(rec_result) != 0 and rec_result['text'] != "sil" and rec_result['text'] != "waiting_for_more_voice":
final_result += rec_result['text']
print(rec_result)
print(final_result)

View File

@ -544,11 +544,6 @@ def inference_modelscope(
)
export_mode = False
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
export_mode = param_dict.get("export_mode", False)
else:
hotword_list_or_file = None
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
@ -578,7 +573,6 @@ def inference_modelscope(
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
hotword_list_or_file=hotword_list_or_file,
)
if export_mode:
speech2text = Speech2TextExport(**speech2text_kwargs)
@ -594,123 +588,92 @@ def inference_modelscope(
**kwargs,
):
hotword_list_or_file = None
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
if 'hotword' in kwargs:
hotword_list_or_file = kwargs['hotword']
if hotword_list_or_file is not None or 'hotword' in kwargs:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
if param_dict is not None:
use_timestamp = param_dict.get('use_timestamp', True)
else:
use_timestamp = True
forward_time_total = 0.0
length_total = 0.0
finish_count = 0
file_count = 1
cache = None
is_final = False
if param_dict is not None and "cache" in param_dict:
cache = param_dict["cache"]
if param_dict is not None and "is_final" in param_dict:
is_final = param_dict["is_final"]
# 7 .Start for-loop
# FIXME(kamo): The output format should be discussed about
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
results = []
asr_result = ""
wait = True
if len(cache) == 0:
cache["encoder"] = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None}
cache_de = {"decode_fsmn": None}
cache["decoder"] = cache_de
cache["first_chunk"] = True
cache["speech"] = []
cache["chunk_index"] = 0
cache["speech_chunk"] = []
if raw_inputs is not None:
if len(cache["speech"]) == 0:
cache["speech"] = raw_inputs
else:
writer = None
if param_dict is not None and "cache" in param_dict:
cache = param_dict["cache"]
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
logging.info("decoding, utt_id: {}".format(keys))
# N-best list of (text, token, token_int, hyp_object)
time_beg = time.time()
results = speech2text(cache=cache, **batch)
if len(results) < 1:
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
time_end = time.time()
forward_time = time_end - time_beg
lfr_factor = results[0][-1]
length = results[0][-2]
forward_time_total += forward_time
length_total += length
rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time,
100 * forward_time / (
length * lfr_factor))
logging.info(rtf_cur)
for batch_id in range(_bs):
result = [results[batch_id][:-2]]
key = keys[batch_id]
for n, result in zip(range(1, nbest + 1), result):
text, token, token_int, hyp = result[0], result[1], result[2], result[3]
time_stamp = None if len(result) < 5 else result[4]
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
ibest_writer["rtf"][key] = rtf_cur
if text is not None:
if use_timestamp and time_stamp is not None:
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
cache["speech"] = torch.cat([cache["speech"], raw_inputs], dim=0)
if len(cache["speech_chunk"]) == 0:
cache["speech_chunk"] = raw_inputs
else:
postprocessed_result = postprocess_utils.sentence_postprocess(token)
time_stamp_postprocessed = ""
if len(postprocessed_result) == 3:
text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
postprocessed_result[1], \
postprocessed_result[2]
cache["speech_chunk"] = torch.cat([cache["speech_chunk"], raw_inputs], dim=0)
while len(cache["speech_chunk"]) >= 960:
if cache["first_chunk"]:
if len(cache["speech_chunk"]) >= 14400:
speech = torch.unsqueeze(cache["speech_chunk"][0:14400], axis=0)
speech_length = torch.tensor([14400])
results = speech2text(cache, speech, speech_length)
cache["speech_chunk"]= cache["speech_chunk"][4800:]
cache["first_chunk"] = False
cache["encoder"]["start_idx"] = -5
wait = False
else:
text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
item = {'key': key, 'value': text_postprocessed}
if time_stamp_postprocessed != "":
item['time_stamp'] = time_stamp_postprocessed
if is_final:
cache["encoder"]["stride"] = len(cache["speech_chunk"]) // 960
cache["encoder"]["pad_right"] = 0
speech = torch.unsqueeze(cache["speech_chunk"], axis=0)
speech_length = torch.tensor([len(cache["speech_chunk"])])
results = speech2text(cache, speech, speech_length)
cache["speech_chunk"] = []
wait = False
else:
break
else:
if len(cache["speech_chunk"]) >= 19200:
cache["encoder"]["start_idx"] += 10
cache["encoder"]["pad_left"] = 5
speech = torch.unsqueeze(cache["speech_chunk"][:19200], axis=0)
speech_length = torch.tensor([19200])
results = speech2text(cache, speech, speech_length)
cache["speech_chunk"] = cache["speech_chunk"][9600:]
wait = False
else:
if is_final:
cache["encoder"]["stride"] = len(cache["speech_chunk"]) // 960
cache["encoder"]["pad_right"] = 0
speech = torch.unsqueeze(cache["speech_chunk"], axis=0)
speech_length = torch.tensor([len(cache["speech_chunk"])])
results = speech2text(cache, speech, speech_length)
cache["speech_chunk"] = []
wait = False
else:
break
if len(results) >= 1:
asr_result += results[0][0]
if asr_result == "":
asr_result = "sil"
if wait:
asr_result = "waiting_for_more_voice"
item = {'key': "utt", 'value': asr_result}
asr_result_list.append(item)
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text_postprocessed
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
forward_time_total,
100 * forward_time_total / (
length_total * lfr_factor))
logging.info(rtf_avg)
if writer is not None:
ibest_writer["rtf"]["rtf_avf"] = rtf_avg
else:
return []
return asr_result_list
return _forward
@ -905,3 +868,4 @@ if __name__ == "__main__":
# rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
# print(rec_result)