token extract

This commit is contained in:
游雁 2024-09-25 10:52:49 +08:00
parent 851474632d
commit 09bb6d8d03

View File

@ -1631,6 +1631,7 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder):
name = vq_config.pop("name", "costume_quantizer")
if name == "costume_quantizer":
from funasr.models.sense_voice.quantizer.costume_quantizer import CostumeQuantizer
quantizer = CostumeQuantizer(
input_size=self.linear_units,
**vq_config,
@ -1639,6 +1640,7 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder):
return quantizer
elif name == "lookup_free_quantizer":
from funasr.models.sense_voice.quantizer.lookup_free_quantizer import LFQ
quantizer = LFQ(
input_size=self.linear_units,
**vq_config,
@ -1647,6 +1649,7 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder):
return quantizer
elif name == "finite_scalar_quantizer":
from funasr.models.sense_voice.quantizer.finite_scalar_quantizer import FSQ
quantizer = FSQ(
input_size=self.linear_units,
**vq_config,
@ -1716,8 +1719,11 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder):
x = block(x, mask=padding_mask, position_ids=position_ids)
if self.quantize_layer_idx is not None and self.quantizer is not None:
if layer == self.quantize_layer_idx:
hint_once(f"Quantization at layer {layer} wit {self.quantizer}",
"normalize_quant_enc_out", rank=0)
hint_once(
f"Quantization at layer {layer} wit {self.quantizer}",
"normalize_quant_enc_out",
rank=0,
)
x, ret_dict = self.quantize_enc_outs(x)
if only_extract_tokens:
return (x, ret_dict), olens
@ -2026,8 +2032,10 @@ class SenseVoiceL(nn.Module):
if data_lengths is None:
data_lengths = [x.shape[0] for x in audio_sample_list]
speech, speech_lengths = extract_fbank(
audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend,
data_len=data_lengths
audio_sample_list,
data_type=kwargs.get("data_type", "sound"),
frontend=frontend,
data_len=data_lengths,
)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
@ -2039,9 +2047,10 @@ class SenseVoiceL(nn.Module):
speech_lengths = speech_lengths.to(device=kwargs["device"])
(outs, ret_dict), out_lens = self.model.encoder(
speech, speech_lengths,
only_extract_tokens=True
speech, speech_lengths, only_extract_tokens=True
)
time4 = time.perf_counter()
meta_data["extract_tokens"] = f"{time4 - time3:0.3f}"
tokens = ret_dict["indices"]
text = "extract_token"
@ -2057,12 +2066,15 @@ class SenseVoiceL(nn.Module):
if not hasattr(self, "writer"):
out_path = os.path.join(out_dir, f"enc_token")
self.writer = kaldiio.WriteHelper(f"ark,scp,f:{out_path}.ark,{out_path}.scp")
self.len_writer = open(out_path+"_len.txt", "wt")
self.len_writer = open(out_path + "_len.txt", "wt")
ark_writer = self.writer
len_writer = self.len_writer
if ark_writer is not None:
for k, v, l in zip(key, tokens.detach().cpu().numpy(), out_lens):
ark_writer(k, v[:l])
len_writer.write(f"{k}\t{l}\n")
time5 = time.perf_counter()
meta_data["write_tokens"] = f"{time5 - time4:0.3f}"
return results, meta_data