diff --git a/examples/industrial_data_pretraining/sense_voice/demo_ctc.py b/examples/industrial_data_pretraining/sense_voice/demo_ctc.py
new file mode 100644
index 000000000..b079cb379
--- /dev/null
+++ b/examples/industrial_data_pretraining/sense_voice/demo_ctc.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+# NOTE: site-local checkpoint directory; point this at your own copy of the model.
+model_dir = "/nfs/beinian.lzr/workspace/models/funasr_results/asr/sense_voice/sensevoice_sanm_ctc"
+input_file = (
+    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
+)
+
+model = AutoModel(
+    model=model_dir,
+)
+
+res = model.generate(
+    input=input_file,
+    cache={},
+    language="zh",
+    text_norm="wotextnorm",
+)
+
+print(res)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index c77930d39..697f50c00 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -1454,6 +1454,10 @@ class SenseVoiceSANMCTC(nn.Module):
         self.length_normalized_loss = length_normalized_loss
         self.encoder_output_size = encoder_output_size
 
+        self.lid_dict = {"zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
+        self.textnorm_dict = {"withtextnorm": 14, "wotextnorm": 15}
+        self.embed = torch.nn.Embedding(8 + len(self.lid_dict) + len(self.textnorm_dict), 560)
+
     def forward(
         self,
         speech: torch.Tensor,
@@ -1587,6 +1591,22 @@ class SenseVoiceSANMCTC(nn.Module):
 
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        # Prepend four learned query frames to the features, in this order:
+        # language id, ids 1 and 2 (event/emotion slots), text-norm flag.
+        # A missing or unknown language falls back to embedding index 0.
+        language = kwargs.get("language", None)
+        textnorm = kwargs.get("text_norm", "wotextnorm")
+        query_ids = torch.tensor(
+            [[self.lid_dict.get(language, 0), 1, 2, self.textnorm_dict[textnorm]]],
+            dtype=torch.long,
+            device=speech.device,
+        )
+        input_query = self.embed(query_ids).repeat(speech.size(0), 1, 1)
+        speech = torch.cat((input_query, speech), dim=1)
+        # Out-of-place add: ``.to()`` above may have returned the caller's tensor.
+        speech_lengths = speech_lengths + input_query.size(1)
+
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         if isinstance(encoder_out, tuple):
@@ -1630,11 +1650,9 @@ class SenseVoiceSANMCTC(nn.Module):
                 )
 
                 # Change integer-ids to tokens
-                token = tokenizer.ids2tokens(token_int)
-                text = tokenizer.tokens2text(token)
+                text = tokenizer.decode(token_int)
 
-                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
+                result_i = {"key": key[i], "text": text}
                 results.append(result_i)
 
                 if ibest_writer is not None: