mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
token extract
This commit is contained in:
parent
09bb6d8d03
commit
6d2434f257
@ -304,10 +304,7 @@ class AutoModel:
|
||||
|
||||
time1 = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
if run_mode == "extract_token":
|
||||
res = model.extract_token(**batch, **kwargs)
|
||||
else:
|
||||
res = model.inference(**batch, **kwargs)
|
||||
res = model.inference(**batch, **kwargs)
|
||||
if isinstance(res, (list, tuple)):
|
||||
results = res[0] if len(res) > 0 else [{"text": ""}]
|
||||
meta_data = res[1] if len(res) > 1 else {}
|
||||
@ -329,7 +326,9 @@ class AutoModel:
|
||||
pbar.set_description(description)
|
||||
else:
|
||||
if log_interval is not None and count % log_interval == 0:
|
||||
logging.info(f"processed {count*batch_size}/{num_samples} samples: {key_batch[0]}")
|
||||
logging.info(
|
||||
f"processed {count*batch_size}/{num_samples} samples: {key_batch[0]}"
|
||||
)
|
||||
time_speech_total += batch_data_time
|
||||
time_escape_total += time_escape
|
||||
count += 1
|
||||
|
||||
0
funasr/models/extract_tokens/__init__.py
Normal file
0
funasr/models/extract_tokens/__init__.py
Normal file
1601
funasr/models/extract_tokens/model_small.py
Normal file
1601
funasr/models/extract_tokens/model_small.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1606,136 +1606,6 @@ class SenseVoiceEncoder(nn.Module):
|
||||
return x, olens
|
||||
|
||||
|
||||
@tables.register("encoder_classes", "SenseVoiceQuantizedEncoder")
class SenseVoiceQuantizedEncoder(SenseVoiceEncoder):
    """SenseVoice encoder with an optional vector quantizer inserted after a chosen block.

    Behaves like ``SenseVoiceEncoder`` except that, after transformer block
    ``quantize_layer_idx``, hidden states are (optionally L2-normalized and)
    passed through a quantizer. With ``only_extract_tokens=True`` the forward
    returns early with the quantized states and token indices.
    """

    def __init__(
        self,
        input_size,
        linear_units: int,
        attention_heads: int,
        num_blocks: int,
        quantize_layer_idx: int,
        normalized_quant_input: bool,
        quantizer_config: dict,
        **kwargs,
    ):
        """
        Args:
            input_size: feature dimension of the encoder input.
            linear_units: hidden size; also the quantizer input size.
            attention_heads: number of attention heads per block.
            num_blocks: number of transformer blocks.
            quantize_layer_idx: index of the block after which to quantize
                (``None`` disables quantization).
            normalized_quant_input: L2-normalize hidden states before quantizing.
            quantizer_config: config dict for :meth:`build_quantizer`; its
                ``"name"`` key selects the quantizer class. May be ``None``.
        """
        super().__init__(input_size, linear_units, attention_heads, num_blocks, **kwargs)
        self.linear_units = linear_units
        self.quantize_layer_idx = quantize_layer_idx
        self.normalized_quant_input = normalized_quant_input
        self.quantizer = self.build_quantizer(quantizer_config)

    def build_quantizer(self, vq_config):
        """Instantiate the quantizer selected by ``vq_config["name"]``.

        Returns ``None`` when ``vq_config`` is ``None``. The ``"name"`` key is
        popped so the remaining dict can be splatted into the constructor, then
        restored so the caller's config is left unchanged.

        Raises:
            NotImplementedError: for an unknown quantizer name.
        """
        if vq_config is None:
            return None
        name = vq_config.pop("name", "costume_quantizer")
        # Imports are deferred so only the selected quantizer module is loaded.
        if name == "costume_quantizer":
            from funasr.models.sense_voice.quantizer.costume_quantizer import (
                CostumeQuantizer as quantizer_cls,
            )
        elif name == "lookup_free_quantizer":
            from funasr.models.sense_voice.quantizer.lookup_free_quantizer import (
                LFQ as quantizer_cls,
            )
        elif name == "finite_scalar_quantizer":
            from funasr.models.sense_voice.quantizer.finite_scalar_quantizer import (
                FSQ as quantizer_cls,
            )
        else:
            # BUGFIX: was `raise NotImplemented(...)` — NotImplemented is a
            # singleton, not an exception class, and raising it is a TypeError.
            raise NotImplementedError("quantizer {} not implemented".format(name))
        quantizer = quantizer_cls(
            input_size=self.linear_units,
            **vq_config,
        )
        vq_config["name"] = name  # restore the popped key for the caller
        return quantizer

    def quantize_enc_outs(self, x):
        """Quantize hidden states ``x``; return (quantized x, diagnostics dict).

        The dict carries the quantizer input/output, the discrete token
        ``indices`` and the commitment ``quant_loss``.
        """
        ret_dict = {}

        if self.normalized_quant_input:
            x = F.normalize(x, dim=-1)
        ret_dict["quant_in"] = x
        x, indices, commit_loss, sub_quants = self.quantizer(x)
        ret_dict["quant_out"] = x
        ret_dict["indices"] = indices
        ret_dict["quant_loss"] = commit_loss

        return x, ret_dict

    def forward(
        self,
        x: torch.Tensor,
        ilens: torch.Tensor = None,
        **kwargs,
    ):
        """Encode features, quantizing after ``quantize_layer_idx``.

        Args:
            x: input features — presumably (batch, n_mels, frames) given the
               conv front-end; TODO confirm against the frontend output layout.
            ilens: optional input lengths; when given, output lengths are
               derived from the two conv strides and returned alongside x.
            only_extract_tokens (kwargs): return ((x, ret_dict), olens) right
               after quantization, skipping the remaining blocks and ln_post.
        """
        use_padmask = self.use_padmask
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # -> (batch, frames, channels)
        only_extract_tokens = kwargs.get("only_extract_tokens", False)

        n_frames = x.size(1)
        max_pos = n_frames

        if ilens is not None:
            # Standard conv output-length formula, applied per conv layer.
            if self.downsample_rate == 4:
                olens = (
                    1
                    + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0])
                    // self.conv1.stride[0]
                )
            else:
                olens = ilens
            olens = (
                1
                + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0])
                // self.conv2.stride[0]
            )
            olens = torch.clamp(olens, max=max_pos)
        else:
            olens = None

        if use_padmask and olens is not None:
            padding_mask = (~make_pad_mask(olens)[:, None, :]).to(torch.bool).to(x.device)
        else:
            padding_mask = None

        device = x.device
        seq_length = x.shape[1]
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

        for layer, block in enumerate(self.blocks):
            x = block(x, mask=padding_mask, position_ids=position_ids)
            if self.quantize_layer_idx is not None and self.quantizer is not None:
                if layer == self.quantize_layer_idx:
                    hint_once(
                        f"Quantization at layer {layer} with {self.quantizer}",
                        "normalize_quant_enc_out",
                        rank=0,
                    )
                    x, ret_dict = self.quantize_enc_outs(x)
                    if only_extract_tokens:
                        return (x, ret_dict), olens

        x = self.ln_post(x)

        if ilens is None:
            return x
        else:
            return x, olens
|
||||
|
||||
|
||||
import types
|
||||
import time
|
||||
import numpy as np
|
||||
@ -1989,92 +1859,93 @@ class SenseVoiceL(nn.Module):
|
||||
|
||||
return results, meta_data
|
||||
|
||||
def extract_token(
|
||||
self,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list = None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if frontend is None and not hasattr(self, "frontend"):
|
||||
frontend_class = tables.frontend_classes.get("WhisperFrontend")
|
||||
frontend = frontend_class(
|
||||
n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
|
||||
)
|
||||
self.frontend = frontend
|
||||
else:
|
||||
frontend = frontend if frontend is not None else self.frontend
|
||||
|
||||
meta_data = {}
|
||||
if (
|
||||
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
|
||||
): # fbank
|
||||
speech, speech_lengths = data_in, data_lengths
|
||||
if len(speech.shape) < 3:
|
||||
speech = speech[None, :, :]
|
||||
if speech_lengths is None:
|
||||
speech_lengths = speech.shape[1]
|
||||
else:
|
||||
# extract fbank feats
|
||||
time1 = time.perf_counter()
|
||||
audio_sample_list = load_audio_text_image_video(
|
||||
data_in,
|
||||
fs=frontend.fs if hasattr(frontend, "fs") else 16000,
|
||||
audio_fs=kwargs.get("fs", 16000),
|
||||
data_type=kwargs.get("data_type", "sound"),
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
time2 = time.perf_counter()
|
||||
meta_data["load_data"] = f"{time2 - time1:0.3f}"
|
||||
if data_lengths is None:
|
||||
data_lengths = [x.shape[0] for x in audio_sample_list]
|
||||
speech, speech_lengths = extract_fbank(
|
||||
audio_sample_list,
|
||||
data_type=kwargs.get("data_type", "sound"),
|
||||
frontend=frontend,
|
||||
data_len=data_lengths,
|
||||
)
|
||||
time3 = time.perf_counter()
|
||||
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
||||
frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
|
||||
lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
|
||||
meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
|
||||
|
||||
speech = speech.to(device=kwargs["device"])
|
||||
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
||||
|
||||
(outs, ret_dict), out_lens = self.model.encoder(
|
||||
speech, speech_lengths, only_extract_tokens=True
|
||||
)
|
||||
time4 = time.perf_counter()
|
||||
meta_data["extract_tokens"] = f"{time4 - time3:0.3f}"
|
||||
tokens = ret_dict["indices"]
|
||||
|
||||
text = "extract_token"
|
||||
results = []
|
||||
result_i = {"key": key[0], "text": text}
|
||||
|
||||
# results.append(result_i)
|
||||
|
||||
ark_writer, len_writer = None, None
|
||||
if kwargs.get("output_dir") is not None:
|
||||
out_dir = kwargs.get("output_dir")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
if not hasattr(self, "writer"):
|
||||
out_path = os.path.join(out_dir, f"enc_token")
|
||||
self.writer = kaldiio.WriteHelper(f"ark,scp,f:{out_path}.ark,{out_path}.scp")
|
||||
self.len_writer = open(out_path + "_len.txt", "wt")
|
||||
ark_writer = self.writer
|
||||
len_writer = self.len_writer
|
||||
|
||||
if ark_writer is not None:
|
||||
for k, v, l in zip(key, tokens.detach().cpu().numpy(), out_lens):
|
||||
ark_writer(k, v[:l])
|
||||
len_writer.write(f"{k}\t{l}\n")
|
||||
time5 = time.perf_counter()
|
||||
meta_data["write_tokens"] = f"{time5 - time4:0.3f}"
|
||||
|
||||
return results, meta_data
|
||||
# def extract_token(
|
||||
# self,
|
||||
# data_in,
|
||||
# data_lengths=None,
|
||||
# key: list = None,
|
||||
# tokenizer=None,
|
||||
# frontend=None,
|
||||
# **kwargs,
|
||||
# ):
|
||||
#
|
||||
# if frontend is None and not hasattr(self, "frontend"):
|
||||
# frontend_class = tables.frontend_classes.get("WhisperFrontend")
|
||||
# frontend = frontend_class(
|
||||
# n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
|
||||
# )
|
||||
# self.frontend = frontend
|
||||
# else:
|
||||
# frontend = frontend if frontend is not None else self.frontend
|
||||
#
|
||||
# meta_data = {}
|
||||
# if (
|
||||
# isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
|
||||
# ): # fbank
|
||||
# speech, speech_lengths = data_in, data_lengths
|
||||
# if len(speech.shape) < 3:
|
||||
# speech = speech[None, :, :]
|
||||
# if speech_lengths is None:
|
||||
# speech_lengths = speech.shape[1]
|
||||
# else:
|
||||
# # extract fbank feats
|
||||
# time1 = time.perf_counter()
|
||||
# audio_sample_list = load_audio_text_image_video(
|
||||
# data_in,
|
||||
# fs=frontend.fs if hasattr(frontend, "fs") else 16000,
|
||||
# audio_fs=kwargs.get("fs", 16000),
|
||||
# data_type=kwargs.get("data_type", "sound"),
|
||||
# tokenizer=tokenizer,
|
||||
# )
|
||||
# time2 = time.perf_counter()
|
||||
# meta_data["load_data"] = f"{time2 - time1:0.3f}"
|
||||
# if data_lengths is None:
|
||||
# data_lengths = [x.shape[0] for x in audio_sample_list]
|
||||
# speech, speech_lengths = extract_fbank(
|
||||
# audio_sample_list,
|
||||
# data_type=kwargs.get("data_type", "sound"),
|
||||
# frontend=frontend,
|
||||
# data_len=data_lengths,
|
||||
# )
|
||||
# time3 = time.perf_counter()
|
||||
# meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
||||
# frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
|
||||
# lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
|
||||
# meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
|
||||
#
|
||||
# speech = speech.to(device=kwargs["device"])
|
||||
# speech_lengths = speech_lengths.to(device=kwargs["device"])
|
||||
#
|
||||
# (outs, ret_dict), out_lens = self.model.encoder(
|
||||
# speech, speech_lengths, only_extract_tokens=True
|
||||
# )
|
||||
# time4 = time.perf_counter()
|
||||
# meta_data["extract_tokens"] = f"{time4 - time3:0.3f}"
|
||||
# print(f'extract_tokens: {meta_data["extract_tokens"]}')
|
||||
# tokens = ret_dict["indices"]
|
||||
#
|
||||
# text = "extract_token"
|
||||
# results = []
|
||||
# result_i = {"key": key[0], "text": text}
|
||||
#
|
||||
# # results.append(result_i)
|
||||
#
|
||||
# ark_writer, len_writer = None, None
|
||||
# if kwargs.get("output_dir") is not None:
|
||||
# out_dir = kwargs.get("output_dir")
|
||||
# os.makedirs(out_dir, exist_ok=True)
|
||||
# if not hasattr(self, "writer"):
|
||||
# out_path = os.path.join(out_dir, f"enc_token")
|
||||
# self.writer = kaldiio.WriteHelper(f"ark,scp,f:{out_path}.ark,{out_path}.scp")
|
||||
# self.len_writer = open(out_path + "_len.txt", "wt")
|
||||
# ark_writer = self.writer
|
||||
# len_writer = self.len_writer
|
||||
#
|
||||
# if ark_writer is not None:
|
||||
# for k, v, l in zip(key, tokens.detach().cpu().numpy(), out_lens):
|
||||
# ark_writer(k, v[:l])
|
||||
# len_writer.write(f"{k}\t{l}\n")
|
||||
# time5 = time.perf_counter()
|
||||
# meta_data["write_tokens"] = f"{time5 - time4:0.3f}"
|
||||
#
|
||||
# return results, meta_data
|
||||
|
||||
Loading…
Reference in New Issue
Block a user