From 292d34f2366d05a249088db844fc177e61708281 Mon Sep 17 00:00:00 2001
From: "shixian.shi"
Date: Thu, 25 Jan 2024 17:23:54 +0800
Subject: [PATCH] Bug fix for res combine

---
 .../seaco_paraformer/demo.py                 |  6 ++++--
 funasr/auto/auto_model.py                    | 20 +++++++++----------
 funasr/models/contextual_paraformer/model.py |  2 --
 funasr/models/seaco_paraformer/model.py      |  2 --
 4 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index 8385ccc67..bba5268ba 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -15,6 +15,8 @@ model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-com
                   # spk_model_revision="v2.0.2",
                   )
 
-res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
-                     hotword='达摩院 魔搭')
+res = model.generate(input="/Users/shixian/Downloads/output_16000.wav",
+                     hotword='达摩院 魔搭',
+                     # sentence_timestamp=True,
+                     )
 print(res)
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index e774a8ffc..4d0f3021c 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -123,7 +123,6 @@ class AutoModel:
         self.preset_spk_num = kwargs.get("preset_spk_num", None)
         if self.preset_spk_num:
             logging.warning("Using preset speaker number: {}".format(self.preset_spk_num))
-            logging.warning("Many to print when using speaker model...")
 
         self.kwargs = kwargs
         self.model = model
@@ -329,8 +328,6 @@ class AutoModel:
                 speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
                 results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg)
                 if self.spk_model is not None:
-
-
                     # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                     for _b in range(len(speech_j)):
                         vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0,
@@ -345,8 +342,6 @@ class AutoModel:
                 if len(results) < 1:
                     continue
                 results_sorted.extend(results)
-
-
 
             end_asr_total = time.time()
             time_escape_total_per_sample = end_asr_total - beg_asr_total
@@ -355,7 +350,6 @@ class AutoModel:
                          f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
                          f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
-
 
         restored_data = [0] * n
         for j in range(n):
             index = sorted_data[j][1]
@@ -378,7 +372,7 @@ class AutoModel:
                         result[k] = restored_data[j][k]
                     else:
                         result[k] = torch.cat([result[k], restored_data[j][k]], dim=0)
-                elif k == 'raw_text':
+                elif 'text' in k:
                     if k not in result:
                         result[k] = restored_data[j][k]
                     else:
@@ -393,8 +387,9 @@ class AutoModel:
             if self.punc_model is not None:
                 self.punc_kwargs.update(cfg)
                 punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
+                import copy; raw_text = copy.copy(result["text"])
                 result["text"] = punc_res[0]["text"]
-
+
             # speaker embedding cluster after resorted
             if self.spk_model is not None:
                 all_segments = sorted(all_segments, key=lambda x: x[0])
@@ -402,19 +397,24 @@ class AutoModel:
                 spk_embedding = result['spk_embedding']
                 labels = self.cb_model(spk_embedding.cpu(), oracle_num=self.preset_spk_num)
                 del result['spk_embedding']
                 sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
-                if self.spk_mode == 'vad_segment':
+                if self.spk_mode == 'vad_segment': # recover sentence_list
                     sentence_list = []
                     for res, vadsegment in zip(restored_data, vadsegments):
                         sentence_list.append({"start": vadsegment[0],\
                                               "end": vadsegment[1],
                                               "sentence": res['raw_text'],
                                               "timestamp": res['timestamp']})
-                else: # punc_segment
+                elif self.spk_mode == 'punc_segment':
                     sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                        result['timestamp'], \
                                                        result['raw_text'])
                 distribute_spk(sentence_list, sv_output)
                 result['sentence_info'] = sentence_list
+            elif kwargs.get("sentence_timestamp", False):
+                sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
+                                                   result['timestamp'], \
+                                                   result['raw_text'])
+                result['sentence_info'] = sentence_list
             result["key"] = key
             results_ret_list.append(result)
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index abbac8c4f..3f79eedf2 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -65,11 +65,9 @@ class ContextualParaformer(Paraformer):
 
         if bias_encoder_type == 'lstm':
-            logging.warning("enable bias encoder sampling and contextual training")
             self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
             self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
         elif bias_encoder_type == 'mean':
-            logging.warning("enable bias encoder sampling and contextual training")
             self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
         else:
             logging.error("Unsupport bias encoder type: {}".format(bias_encoder_type))
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 04410bafa..8b8e97e53 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -66,7 +66,6 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
 
         # bias encoder
         if self.bias_encoder_type == 'lstm':
-            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_encoder = torch.nn.LSTM(self.inner_dim,
                                              self.inner_dim,
                                              2,
@@ -79,7 +78,6 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
             self.lstm_proj = None
             self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim)
         elif self.bias_encoder_type == 'mean':
-            logging.warning("enable bias encoder sampling and contextual training")
             self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim)
         else:
             logging.error("Unsupport bias encoder type: {}".format(self.bias_encoder_type))
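
Usage note: a minimal sketch of the sentence-timestamp path this patch adds to funasr/auto/auto_model.py. With a VAD model and a punctuation model configured, generate() can now return per-sentence text and timestamps via result['sentence_info'] even when no speaker model is loaded. The ModelScope model ids and the audio path below are assumptions for illustration, not taken from this commit.

# Illustrative only; adjust model ids and the wav path to what is available locally.
from funasr import AutoModel

model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  )

res = model.generate(input="asr_example_zh.wav",
                     hotword='达摩院 魔搭',
                     sentence_timestamp=True,
                     )
# Each result now carries sentence_info even though no spk_model was configured.
print(res[0]["sentence_info"])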