token extract

This commit is contained in:
游雁 2024-09-25 11:21:13 +08:00
parent 09bb6d8d03
commit 6d2434f257
4 changed files with 1695 additions and 224 deletions

View File

@ -304,10 +304,7 @@ class AutoModel:
time1 = time.perf_counter()
with torch.no_grad():
if run_mode == "extract_token":
res = model.extract_token(**batch, **kwargs)
else:
res = model.inference(**batch, **kwargs)
res = model.inference(**batch, **kwargs)
if isinstance(res, (list, tuple)):
results = res[0] if len(res) > 0 else [{"text": ""}]
meta_data = res[1] if len(res) > 1 else {}
@ -329,7 +326,9 @@ class AutoModel:
pbar.set_description(description)
else:
if log_interval is not None and count % log_interval == 0:
logging.info(f"processed {count*batch_size}/{num_samples} samples: {key_batch[0]}")
logging.info(
f"processed {count*batch_size}/{num_samples} samples: {key_batch[0]}"
)
time_speech_total += batch_data_time
time_escape_total += time_escape
count += 1

View File

File diff suppressed because it is too large. [Load Diff]

View File

@ -1606,136 +1606,6 @@ class SenseVoiceEncoder(nn.Module):
return x, olens
@tables.register("encoder_classes", "SenseVoiceQuantizedEncoder")
class SenseVoiceQuantizedEncoder(SenseVoiceEncoder):
    """SenseVoice encoder with an optional vector-quantization bottleneck.

    Runs the parent encoder's conv front-end and transformer blocks, and
    quantizes the hidden states right after block ``quantize_layer_idx``
    using the quantizer described by ``quantizer_config``.
    """

    def __init__(
        self,
        input_size,
        linear_units: int,
        attention_heads: int,
        num_blocks: int,
        quantize_layer_idx: int,
        normalized_quant_input: bool,
        quantizer_config: dict,
        **kwargs,
    ):
        super().__init__(input_size, linear_units, attention_heads, num_blocks, **kwargs)
        self.linear_units = linear_units
        # Index of the transformer block after which quantization is applied.
        self.quantize_layer_idx = quantize_layer_idx
        # Whether to L2-normalize hidden states before feeding the quantizer.
        self.normalized_quant_input = normalized_quant_input
        self.quantizer = self.build_quantizer(quantizer_config)

    def build_quantizer(self, vq_config):
        """Instantiate the quantizer selected by ``vq_config["name"]``.

        Returns ``None`` when no config is given. The ``"name"`` key is popped
        to pick the implementation and restored afterwards, so the dict the
        caller holds is left unchanged on success.

        Raises:
            NotImplementedError: for an unknown quantizer name.
        """
        if vq_config is None:
            return None
        name = vq_config.pop("name", "costume_quantizer")
        if name == "costume_quantizer":
            from funasr.models.sense_voice.quantizer.costume_quantizer import CostumeQuantizer

            quantizer = CostumeQuantizer(input_size=self.linear_units, **vq_config)
        elif name == "lookup_free_quantizer":
            from funasr.models.sense_voice.quantizer.lookup_free_quantizer import LFQ

            quantizer = LFQ(input_size=self.linear_units, **vq_config)
        elif name == "finite_scalar_quantizer":
            from funasr.models.sense_voice.quantizer.finite_scalar_quantizer import FSQ

            quantizer = FSQ(input_size=self.linear_units, **vq_config)
        else:
            # BUG FIX: the original raised ``NotImplemented`` (a constant, not
            # an exception class), which itself raises a TypeError at runtime.
            raise NotImplementedError("quantizer {} not implemented".format(name))
        # Restore the popped key so the caller's config dict is unchanged.
        vq_config["name"] = name
        return quantizer

    def quantize_enc_outs(self, x):
        """Quantize hidden states ``x``.

        Returns:
            (quantized, ret_dict) where ret_dict carries the pre-quantization
            input, the quantized output, the codebook ``indices`` and the
            commitment ``quant_loss``.
        """
        ret_dict = {}
        if self.normalized_quant_input:
            x = F.normalize(x, dim=-1)
        ret_dict["quant_in"] = x
        x, indices, commit_loss, sub_quants = self.quantizer(x)
        ret_dict["quant_out"] = x
        ret_dict["indices"] = indices
        ret_dict["quant_loss"] = commit_loss
        return x, ret_dict

    def forward(
        self,
        x: torch.Tensor,
        ilens: torch.Tensor = None,
        **kwargs,
    ):
        """Encode features, optionally returning quantized tokens early.

        Args:
            x: input features — presumably (batch, time, feat) before the conv
               front-end; TODO confirm against SenseVoiceEncoder.
            ilens: optional input lengths, used to build the padding mask.
            **kwargs: ``only_extract_tokens=True`` makes the forward return
                ``((quantized, ret_dict), olens)`` right after quantization,
                skipping the remaining blocks and ``ln_post``.
        """
        use_padmask = self.use_padmask
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        only_extract_tokens = kwargs.get("only_extract_tokens", False)

        n_frames = x.size(1)
        max_pos = n_frames

        if ilens is not None:
            # Mirror the conv front-end strides to derive output lengths.
            if self.downsample_rate == 4:
                olens = (
                    1
                    + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0])
                    // self.conv1.stride[0]
                )
            else:
                olens = ilens
            olens = (
                1
                + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0])
                // self.conv2.stride[0]
            )
            olens = torch.clamp(olens, max=max_pos)
        else:
            olens = None

        if use_padmask and olens is not None:
            padding_mask = (~make_pad_mask(olens)[:, None, :]).to(torch.bool).to(x.device)
        else:
            padding_mask = None

        device = x.device
        seq_length = x.shape[1]
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

        for layer, block in enumerate(self.blocks):
            x = block(x, mask=padding_mask, position_ids=position_ids)
            if self.quantize_layer_idx is not None and self.quantizer is not None:
                if layer == self.quantize_layer_idx:
                    # Typo fix in the one-shot hint message: "wit" -> "with".
                    hint_once(
                        f"Quantization at layer {layer} with {self.quantizer}",
                        "normalize_quant_enc_out",
                        rank=0,
                    )
                    x, ret_dict = self.quantize_enc_outs(x)
                    if only_extract_tokens:
                        return (x, ret_dict), olens

        x = self.ln_post(x)
        if ilens is None:
            return x
        else:
            return x, olens
import types
import time
import numpy as np
@ -1989,92 +1859,93 @@ class SenseVoiceL(nn.Module):
return results, meta_data
def extract_token(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    tokenizer=None,
    frontend=None,
    **kwargs,
):
    """Extract quantized encoder token indices for ``data_in`` and optionally
    persist them as Kaldi ark/scp plus a tab-separated length file.

    Args:
        data_in: audio input — anything ``load_audio_text_image_video``
            accepts, or a precomputed fbank tensor when
            ``kwargs["data_type"] == "fbank"``.
        data_lengths: optional lengths matching ``data_in``.
        key: utterance ids; required when ``output_dir`` is set.
        tokenizer: passed through to audio loading.
        frontend: optional feature frontend; a WhisperFrontend is lazily
            built and cached on the instance on first use.
        **kwargs: must contain ``"device"``; ``"output_dir"`` enables the
            ark/scp writers (cached on the instance, kept open across calls).

    Returns:
        (results, meta_data): ``results`` is intentionally an empty list —
        tokens go to disk, not to the caller; ``meta_data`` carries timing
        statistics.
    """
    if frontend is None and not hasattr(self, "frontend"):
        frontend_class = tables.frontend_classes.get("WhisperFrontend")
        frontend = frontend_class(
            n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
        )
        self.frontend = frontend
    else:
        frontend = frontend if frontend is not None else self.frontend

    meta_data = {}
    if (
        isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
    ):  # fbank
        speech, speech_lengths = data_in, data_lengths
        if len(speech.shape) < 3:
            speech = speech[None, :, :]
        if speech_lengths is None:
            # BUG FIX: the original stored a plain int here, which crashed on
            # the later ``speech_lengths.to(device=...)`` call. Build a proper
            # per-utterance length tensor instead.
            speech_lengths = torch.full(
                (speech.shape[0],), speech.shape[1], dtype=torch.int64
            )
    else:
        # Extract fbank features from raw audio.
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=frontend.fs if hasattr(frontend, "fs") else 16000,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        if data_lengths is None:
            data_lengths = [x.shape[0] for x in audio_sample_list]
        speech, speech_lengths = extract_fbank(
            audio_sample_list,
            data_type=kwargs.get("data_type", "sound"),
            frontend=frontend,
            data_len=data_lengths,
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
        lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
        meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

    speech = speech.to(device=kwargs["device"])
    speech_lengths = speech_lengths.to(device=kwargs["device"])

    # BUG FIX: anchor the token-extraction timer on both branches — ``time3``
    # was unbound when fbank input was supplied, raising NameError below.
    time3 = time.perf_counter()
    (outs, ret_dict), out_lens = self.model.encoder(
        speech, speech_lengths, only_extract_tokens=True
    )
    time4 = time.perf_counter()
    meta_data["extract_tokens"] = f"{time4 - time3:0.3f}"

    tokens = ret_dict["indices"]

    # Intentionally empty: per-utterance tokens are persisted below, not
    # returned (the dead ``result_i`` placeholder was removed).
    results = []

    ark_writer, len_writer = None, None
    if kwargs.get("output_dir") is not None:
        out_dir = kwargs.get("output_dir")
        os.makedirs(out_dir, exist_ok=True)
        if not hasattr(self, "writer"):
            # Writers are cached on the instance and deliberately kept open
            # across calls; they are never closed here.
            out_path = os.path.join(out_dir, "enc_token")
            self.writer = kaldiio.WriteHelper(f"ark,scp,f:{out_path}.ark,{out_path}.scp")
            self.len_writer = open(out_path + "_len.txt", "wt")
        ark_writer = self.writer
        len_writer = self.len_writer

    if ark_writer is not None:
        for k, v, l in zip(key, tokens.detach().cpu().numpy(), out_lens):
            ark_writer(k, v[:l])  # truncate padding before writing
            len_writer.write(f"{k}\t{l}\n")
    time5 = time.perf_counter()
    meta_data["write_tokens"] = f"{time5 - time4:0.3f}"

    return results, meta_data
# def extract_token(
# self,
# data_in,
# data_lengths=None,
# key: list = None,
# tokenizer=None,
# frontend=None,
# **kwargs,
# ):
#
# if frontend is None and not hasattr(self, "frontend"):
# frontend_class = tables.frontend_classes.get("WhisperFrontend")
# frontend = frontend_class(
# n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
# )
# self.frontend = frontend
# else:
# frontend = frontend if frontend is not None else self.frontend
#
# meta_data = {}
# if (
# isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
# ): # fbank
# speech, speech_lengths = data_in, data_lengths
# if len(speech.shape) < 3:
# speech = speech[None, :, :]
# if speech_lengths is None:
# speech_lengths = speech.shape[1]
# else:
# # extract fbank feats
# time1 = time.perf_counter()
# audio_sample_list = load_audio_text_image_video(
# data_in,
# fs=frontend.fs if hasattr(frontend, "fs") else 16000,
# audio_fs=kwargs.get("fs", 16000),
# data_type=kwargs.get("data_type", "sound"),
# tokenizer=tokenizer,
# )
# time2 = time.perf_counter()
# meta_data["load_data"] = f"{time2 - time1:0.3f}"
# if data_lengths is None:
# data_lengths = [x.shape[0] for x in audio_sample_list]
# speech, speech_lengths = extract_fbank(
# audio_sample_list,
# data_type=kwargs.get("data_type", "sound"),
# frontend=frontend,
# data_len=data_lengths,
# )
# time3 = time.perf_counter()
# meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
# frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
# lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
# meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
#
# speech = speech.to(device=kwargs["device"])
# speech_lengths = speech_lengths.to(device=kwargs["device"])
#
# (outs, ret_dict), out_lens = self.model.encoder(
# speech, speech_lengths, only_extract_tokens=True
# )
# time4 = time.perf_counter()
# meta_data["extract_tokens"] = f"{time4 - time3:0.3f}"
# print(f'extract_tokens: {meta_data["extract_tokens"]}')
# tokens = ret_dict["indices"]
#
# text = "extract_token"
# results = []
# result_i = {"key": key[0], "text": text}
#
# # results.append(result_i)
#
# ark_writer, len_writer = None, None
# if kwargs.get("output_dir") is not None:
# out_dir = kwargs.get("output_dir")
# os.makedirs(out_dir, exist_ok=True)
# if not hasattr(self, "writer"):
# out_path = os.path.join(out_dir, f"enc_token")
# self.writer = kaldiio.WriteHelper(f"ark,scp,f:{out_path}.ark,{out_path}.scp")
# self.len_writer = open(out_path + "_len.txt", "wt")
# ark_writer = self.writer
# len_writer = self.len_writer
#
# if ark_writer is not None:
# for k, v, l in zip(key, tokens.detach().cpu().numpy(), out_lens):
# ark_writer(k, v[:l])
# len_writer.write(f"{k}\t{l}\n")
# time5 = time.perf_counter()
# meta_data["write_tokens"] = f"{time5 - time4:0.3f}"
#
# return results, meta_data