mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
sensevoice
This commit is contained in:
parent
0033151b62
commit
ada76b6312
25
examples/industrial_data_pretraining/sense_voice/demo_ctc.py
Normal file
25
examples/industrial_data_pretraining/sense_voice/demo_ctc.py
Normal file
@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- encoding: utf-8 -*-
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
import sys
|
||||
from funasr import AutoModel
|
||||
|
||||
model_dir = "/nfs/beinian.lzr/workspace/models/funasr_results/asr/sense_voice/sensevoice_sanm_ctc"
|
||||
input_file = (
|
||||
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
|
||||
)
|
||||
|
||||
model = AutoModel(
|
||||
model=model_dir,
|
||||
)
|
||||
|
||||
res = model.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="zh",
|
||||
text_norm="wotextnorm",
|
||||
)
|
||||
|
||||
print(res)
|
||||
@ -1454,6 +1454,10 @@ class SenseVoiceSANMCTC(nn.Module):
|
||||
self.length_normalized_loss = length_normalized_loss
|
||||
self.encoder_output_size = encoder_output_size
|
||||
|
||||
self.lid_dict = {"zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
|
||||
self.textnorm_dict = {"withtextnorm": 14, "wotextnorm": 15}
|
||||
self.embed = torch.nn.Embedding(8 + len(self.lid_dict) + len(self.textnorm_dict), 560)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
@ -1587,6 +1591,22 @@ class SenseVoiceSANMCTC(nn.Module):
|
||||
|
||||
speech = speech.to(device=kwargs["device"])
|
||||
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
||||
|
||||
language = kwargs.get("language", None)
|
||||
if language is not None:
|
||||
language_query = self.embed(torch.LongTensor([[self.lid_dict[language] if language in self.lid_dict else 0]]).to(speech.device)).repeat(speech.size(0), 1, 1)
|
||||
else:
|
||||
language_query = self.embed(torch.LongTensor([[0]]).to(speech.device)).repeat(speech.size(0), 1, 1)
|
||||
textnorm = kwargs.get("text_norm", "wotextnorm")
|
||||
textnorm_query = self.embed(torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)).repeat(speech.size(0), 1, 1)
|
||||
speech = torch.cat((textnorm_query, speech), dim=1)
|
||||
speech_lengths += 1
|
||||
|
||||
event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
|
||||
input_query = torch.cat((language_query, event_emo_query), dim=1)
|
||||
speech = torch.cat((input_query, speech), dim=1)
|
||||
speech_lengths += 3
|
||||
|
||||
# Encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
if isinstance(encoder_out, tuple):
|
||||
@ -1630,11 +1650,9 @@ class SenseVoiceSANMCTC(nn.Module):
|
||||
)
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = tokenizer.ids2tokens(token_int)
|
||||
text = tokenizer.tokens2text(token)
|
||||
text = tokenizer.decode(token_int)
|
||||
|
||||
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
|
||||
result_i = {"key": key[i], "token": token, "text": text_postprocessed}
|
||||
result_i = {"key": key[i], "text": text}
|
||||
results.append(result_i)
|
||||
|
||||
if ibest_writer is not None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user