update readme, fix seaco bug

This commit is contained in:
shixian.shi 2024-01-15 20:10:39 +08:00
parent 97d648c255
commit 55c09aeaa2
5 changed files with 22 additions and 19 deletions

View File

@ -90,12 +90,15 @@ Notes: Support recognition of single audio file, as well as file list in Kaldi-s
### Speech Recognition (Non-streaming) ### Speech Recognition (Non-streaming)
```python ```python
from funasr import AutoModel from funasr import AutoModel
# paraformer-zh is a multi-functional asr model
model = AutoModel(model="paraformer-zh") # use vad, punc, spk or not as you need
# for the long duration wav, you could add vad model model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \
# model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc") vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
res = model(input="asr_example_zh.wav", batch_size=64) spk_model="cam++", spk_model_revision="v2.0.2")
res = model(input=f"{model.model_path}/example/asr_example.wav",
batch_size=16,
hotword='魔搭')
print(res) print(res)
``` ```
Note: `model_hub`: represents the model repository, `ms` stands for selecting ModelScope download, `hf` stands for selecting Huggingface download. Note: `model_hub`: represents the model repository, `ms` stands for selecting ModelScope download, `hf` stands for selecting Huggingface download.
@ -108,7 +111,7 @@ chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.0") model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.2")
import soundfile import soundfile
import os import os
@ -163,7 +166,7 @@ for i in range(total_chunk_num):
```python ```python
from funasr import AutoModel from funasr import AutoModel
model = AutoModel(model="ct-punc", model_revision="v2.0.1") model = AutoModel(model="ct-punc", model_revision="v2.0.2")
res = model(input="那今天的会就到这里吧 happy new year 明年见") res = model(input="那今天的会就到这里吧 happy new year 明年见")
print(res) print(res)
@ -172,7 +175,7 @@ print(res)
```python ```python
from funasr import AutoModel from funasr import AutoModel
model = AutoModel(model="fa-zh", model_revision="v2.0.0") model = AutoModel(model="fa-zh", model_revision="v2.0.2")
wav_file = f"{model.model_path}/example/asr_example.wav" wav_file = f"{model.model_path}/example/asr_example.wav"
text_file = f"{model.model_path}/example/asr_example.wav" text_file = f"{model.model_path}/example/asr_example.wav"

View File

@ -11,8 +11,10 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
vad_model_revision="v2.0.2", vad_model_revision="v2.0.2",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
punc_model_revision="v2.0.2", punc_model_revision="v2.0.2",
spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
spk_model="v2.0.2",
) )
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", res = model(input=f"{model.model_path}/example/asr_example.wav",
hotword='达摩院 ') hotword='达摩院 ')
print(res) print(res)

View File

@ -1,14 +1,13 @@
name_maps_ms = { name_maps_ms = {
"paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
"paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
"paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020", "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
"paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020", "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
"paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
"fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
"ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large", "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
"ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
"fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline", "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
"cam++": "damo/speech_campplus_sv_zh-cn_16k-common",
} }
name_maps_hf = { name_maps_hf = {

View File

@ -344,7 +344,6 @@ class CTTransformer(torch.nn.Module):
punc_array = punctuations punc_array = punctuations
else: else:
punc_array = torch.cat([punc_array, punctuations], dim=0) punc_array = torch.cat([punc_array, punctuations], dim=0)
result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array} result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
results.append(result_i) results.append(result_i)

View File

@ -212,7 +212,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
ys_pad_lens, ys_pad_lens,
hw_list, hw_list,
nfilter=50, nfilter=50,
seaco_weight=1.0): seaco_weight=1.0):
# decoder forward # decoder forward
decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True) decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
decoder_pred = torch.log_softmax(decoder_out, dim=-1) decoder_pred = torch.log_softmax(decoder_out, dim=-1)
@ -254,10 +254,9 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation
dha_pred = torch.log_softmax(dha_output, dim=-1) dha_pred = torch.log_softmax(dha_output, dim=-1)
# import pdb; pdb.set_trace()
def _merge_res(dec_output, dha_output): def _merge_res(dec_output, dha_output):
lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0]) lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
dha_ids = dha_output.max(-1)[-1][0] dha_ids = dha_output.max(-1)[-1]# [0]
dha_mask = (dha_ids == 8377).int().unsqueeze(-1) dha_mask = (dha_ids == 8377).int().unsqueeze(-1)
a = (1 - lmbd) / lmbd a = (1 - lmbd) / lmbd
b = 1 / lmbd b = 1 / lmbd
@ -267,6 +266,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask) logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
return logits return logits
merged_pred = _merge_res(decoder_pred, dha_pred) merged_pred = _merge_res(decoder_pred, dha_pred)
# import pdb; pdb.set_trace()
return merged_pred return merged_pred
else: else:
return decoder_pred return decoder_pred