mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
token extract
This commit is contained in:
parent
851474632d
commit
09bb6d8d03
@ -1631,6 +1631,7 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder):
|
||||
name = vq_config.pop("name", "costume_quantizer")
|
||||
if name == "costume_quantizer":
|
||||
from funasr.models.sense_voice.quantizer.costume_quantizer import CostumeQuantizer
|
||||
|
||||
quantizer = CostumeQuantizer(
|
||||
input_size=self.linear_units,
|
||||
**vq_config,
|
||||
@ -1639,6 +1640,7 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder):
|
||||
return quantizer
|
||||
elif name == "lookup_free_quantizer":
|
||||
from funasr.models.sense_voice.quantizer.lookup_free_quantizer import LFQ
|
||||
|
||||
quantizer = LFQ(
|
||||
input_size=self.linear_units,
|
||||
**vq_config,
|
||||
@ -1647,6 +1649,7 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder):
|
||||
return quantizer
|
||||
elif name == "finite_scalar_quantizer":
|
||||
from funasr.models.sense_voice.quantizer.finite_scalar_quantizer import FSQ
|
||||
|
||||
quantizer = FSQ(
|
||||
input_size=self.linear_units,
|
||||
**vq_config,
|
||||
@ -1716,8 +1719,11 @@ class SenseVoiceQuantizedEncoder(SenseVoiceEncoder):
|
||||
x = block(x, mask=padding_mask, position_ids=position_ids)
|
||||
if self.quantize_layer_idx is not None and self.quantizer is not None:
|
||||
if layer == self.quantize_layer_idx:
|
||||
hint_once(f"Quantization at layer {layer} wit {self.quantizer}",
|
||||
"normalize_quant_enc_out", rank=0)
|
||||
hint_once(
|
||||
f"Quantization at layer {layer} wit {self.quantizer}",
|
||||
"normalize_quant_enc_out",
|
||||
rank=0,
|
||||
)
|
||||
x, ret_dict = self.quantize_enc_outs(x)
|
||||
if only_extract_tokens:
|
||||
return (x, ret_dict), olens
|
||||
@ -2026,8 +2032,10 @@ class SenseVoiceL(nn.Module):
|
||||
if data_lengths is None:
|
||||
data_lengths = [x.shape[0] for x in audio_sample_list]
|
||||
speech, speech_lengths = extract_fbank(
|
||||
audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend,
|
||||
data_len=data_lengths
|
||||
audio_sample_list,
|
||||
data_type=kwargs.get("data_type", "sound"),
|
||||
frontend=frontend,
|
||||
data_len=data_lengths,
|
||||
)
|
||||
time3 = time.perf_counter()
|
||||
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
||||
@ -2039,9 +2047,10 @@ class SenseVoiceL(nn.Module):
|
||||
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
||||
|
||||
(outs, ret_dict), out_lens = self.model.encoder(
|
||||
speech, speech_lengths,
|
||||
only_extract_tokens=True
|
||||
speech, speech_lengths, only_extract_tokens=True
|
||||
)
|
||||
time4 = time.perf_counter()
|
||||
meta_data["extract_tokens"] = f"{time4 - time3:0.3f}"
|
||||
tokens = ret_dict["indices"]
|
||||
|
||||
text = "extract_token"
|
||||
@ -2057,12 +2066,15 @@ class SenseVoiceL(nn.Module):
|
||||
if not hasattr(self, "writer"):
|
||||
out_path = os.path.join(out_dir, f"enc_token")
|
||||
self.writer = kaldiio.WriteHelper(f"ark,scp,f:{out_path}.ark,{out_path}.scp")
|
||||
self.len_writer = open(out_path+"_len.txt", "wt")
|
||||
self.len_writer = open(out_path + "_len.txt", "wt")
|
||||
ark_writer = self.writer
|
||||
len_writer = self.len_writer
|
||||
|
||||
if ark_writer is not None:
|
||||
for k, v, l in zip(key, tokens.detach().cpu().numpy(), out_lens):
|
||||
ark_writer(k, v[:l])
|
||||
len_writer.write(f"{k}\t{l}\n")
|
||||
time5 = time.perf_counter()
|
||||
meta_data["write_tokens"] = f"{time5 - time4:0.3f}"
|
||||
|
||||
return results, meta_data
|
||||
|
||||
Loading…
Reference in New Issue
Block a user