diff --git a/funasr/models/sense_voice/model_small.py b/funasr/models/sense_voice/model_small.py index 42a5c0338..d2a19993b 100644 --- a/funasr/models/sense_voice/model_small.py +++ b/funasr/models/sense_voice/model_small.py @@ -1631,6 +1631,7 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder): name = vq_config.pop("name", "costume_quantizer") if name == "costume_quantizer": from funasr.models.sense_voice.quantizer.costume_quantizer import CostumeQuantizer + quantizer = CostumeQuantizer( input_size=self.linear_units, **vq_config, @@ -1639,6 +1640,7 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder): return quantizer elif name == "lookup_free_quantizer": from funasr.models.sense_voice.quantizer.lookup_free_quantizer import LFQ + quantizer = LFQ( input_size=self.linear_units, **vq_config, @@ -1647,6 +1649,7 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder): return quantizer elif name == "finite_scalar_quantizer": from funasr.models.sense_voice.quantizer.finite_scalar_quantizer import FSQ + quantizer = FSQ( input_size=self.linear_units, **vq_config, @@ -1716,8 +1719,11 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder): x = block(x, mask=padding_mask, position_ids=position_ids) if self.quantize_layer_idx is not None and self.quantizer is not None: if layer == self.quantize_layer_idx: - hint_once(f"Quantization at layer {layer} wit {self.quantizer}", - "normalize_quant_enc_out", rank=0) + hint_once( + f"Quantization at layer {layer} wit {self.quantizer}", + "normalize_quant_enc_out", + rank=0, + ) x, ret_dict = self.quantize_enc_outs(x) if only_extract_tokens: return (x, ret_dict), olens @@ -2026,8 +2032,10 @@ class SenseVoiceL(nn.Module): if data_lengths is None: data_lengths = [x.shape[0] for x in audio_sample_list] speech, speech_lengths = extract_fbank( - audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend, - data_len=data_lengths + audio_sample_list, + data_type=kwargs.get("data_type", "sound"), + frontend=frontend, + data_len=data_lengths, ) time3 = time.perf_counter() meta_data["extract_feat"] = f"{time3 - time2:0.3f}" @@ -2039,9 +2047,10 @@ class SenseVoiceL(nn.Module): speech_lengths = speech_lengths.to(device=kwargs["device"]) (outs, ret_dict), out_lens = self.model.encoder( - speech, speech_lengths, - only_extract_tokens=True + speech, speech_lengths, only_extract_tokens=True ) + time4 = time.perf_counter() + meta_data["extract_tokens"] = f"{time4 - time3:0.3f}" tokens = ret_dict["indices"] text = "extract_token" @@ -2057,12 +2066,15 @@ class SenseVoiceL(nn.Module): if not hasattr(self, "writer"): out_path = os.path.join(out_dir, f"enc_token") self.writer = kaldiio.WriteHelper(f"ark,scp,f:{out_path}.ark,{out_path}.scp") - self.len_writer = open(out_path+"_len.txt", "wt") + self.len_writer = open(out_path + "_len.txt", "wt") ark_writer = self.writer len_writer = self.len_writer + if ark_writer is not None: for k, v, l in zip(key, tokens.detach().cpu().numpy(), out_lens): ark_writer(k, v[:l]) len_writer.write(f"{k}\t{l}\n") + time5 = time.perf_counter() + meta_data["write_tokens"] = f"{time5 - time4:0.3f}" return results, meta_data