Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

Commit: 62178770dc ("test")
Parent: 0a7384a1ec
@@ -209,14 +209,12 @@ class AutoModel:
        kwargs.update(cfg)
        model = self.model if model is None else model
        model.eval()
        pdb.set_trace()

        batch_size = kwargs.get("batch_size", 1)
        # if kwargs.get("device", "cpu") == "cpu":
        #     batch_size = 1

        key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
        pdb.set_trace()

        speed_stats = {}
        asr_result_list = []
@@ -225,14 +223,12 @@ class AutoModel:
        pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None
        time_speech_total = 0.0
        time_escape_total = 0.0
        pdb.set_trace()
        for beg_idx in range(0, num_samples, batch_size):
            pdb.set_trace()
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "key": key_batch}
            pdb.set_trace()

            if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank":  # fbank
                batch["data_in"] = data_batch[0]
                batch["data_lengths"] = input_len
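Note: the loop above walks the prepared samples in fixed-size slices; the last batch may be shorter. A minimal, self-contained sketch of the same slicing pattern (names and data are illustrative, not the FunASR API):

# Minimal sketch of the batch-slicing pattern above; names and data are illustrative.
data_list = ["utt_a.wav", "utt_b.wav", "utt_c.wav", "utt_d.wav", "utt_e.wav"]
key_list = [f"key_{i}" for i in range(len(data_list))]
num_samples = len(data_list)
batch_size = 2

for beg_idx in range(0, num_samples, batch_size):
    end_idx = min(num_samples, beg_idx + batch_size)  # final slice may be shorter
    batch = {"data_in": data_list[beg_idx:end_idx], "key": key_list[beg_idx:end_idx]}
    print(batch["key"], len(batch["data_in"]))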
@@ -102,17 +102,16 @@ class ContextualParaformer(Paraformer):
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        pdb.set_trace()

        batch_size = speech.shape[0]

        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")
        pdb.set_trace()

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        pdb.set_trace()
        loss_ctc, cer_ctc = None, None

        stats = dict()
@@ -127,12 +126,11 @@ class ContextualParaformer(Paraformer):
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        pdb.set_trace()
        # 2b. Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
            encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
        )
        pdb.set_trace()

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
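Note: the hunk ends inside the standard hybrid CTC/attention loss combination; only the ctc_weight == 0.0 branch is visible here. The sketch below fills in the usual interpolation for the other branch as an assumption (the common Paraformer/ESPnet recipe, not lines shown in this diff):

def combine_losses(loss_att, loss_ctc, loss_pre, ctc_weight, predictor_weight):
    # Attention-only training when CTC is disabled (the branch shown in the hunk).
    if ctc_weight == 0.0:
        return loss_att + loss_pre * predictor_weight
    # Assumed interpolation for the branch not shown in the hunk.
    return ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att + loss_pre * predictor_weight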
@@ -170,26 +168,24 @@ class ContextualParaformer(Paraformer):
    ):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        pdb.set_trace()

        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pdb.set_trace()

        pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                     ignore_id=self.ignore_id)
        pdb.set_trace()
        # -1. bias encoder
        if self.use_decoder_embedding:
            hw_embed = self.decoder.embed(hotword_pad)
        else:
            hw_embed = self.bias_embed(hotword_pad)
        pdb.set_trace()

        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        pdb.set_trace()
        _ind = np.arange(0, hotword_pad.shape[0]).tolist()
        selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
        contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
        pdb.set_trace()

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
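Note: the bias encoder above runs padded hotword embeddings through an LSTM, then gathers each hotword's output at its last valid timestep via paired index lists. The gather in isolation (shapes invented for illustration):

import torch

# Pick the last valid timestep of a padded sequence output, per sequence.
batch, max_len, dim = 3, 5, 8               # illustrative sizes
outputs = torch.randn(batch, max_len, dim)  # e.g. LSTM outputs over padded hotwords
lengths = torch.tensor([5, 2, 4])           # true length of each hotword

rows = torch.arange(batch)
last_valid = outputs[rows, lengths - 1]     # (batch, dim): one embedding per hotword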
@@ -201,7 +197,7 @@ class ContextualParaformer(Paraformer):
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds
            pdb.set_trace()

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@@ -217,7 +213,7 @@ class ContextualParaformer(Paraformer):
            loss_ideal = None
        '''
        loss_ideal = None
        pdb.set_trace()

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
@@ -294,11 +290,11 @@ class ContextualParaformer(Paraformer):
                                                      enforce_sorted=False)
            _, (h_n, _) = self.bias_encoder(hw_embed)
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        pdb.set_trace()

        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
        )
        pdb.set_trace()

        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
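Note: in this inference path the hotword embeddings are packed with enforce_sorted=False and only the LSTM's final hidden state h_n is kept, then repeated across the batch as the contextual bias. A minimal sketch of that pattern (all sizes are made up):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

num_hotwords, max_len, embed_dim, hidden = 4, 6, 16, 16  # illustrative sizes
lstm = torch.nn.LSTM(embed_dim, hidden, batch_first=True)

hw_embed = torch.randn(num_hotwords, max_len, embed_dim)  # padded hotword embeddings
lengths = torch.tensor([6, 3, 5, 2])                      # true lengths (unsorted is fine)

packed = pack_padded_sequence(hw_embed, lengths, batch_first=True, enforce_sorted=False)
_, (h_n, _) = lstm(packed)        # h_n: (num_layers, num_hotwords, hidden)
bias = h_n.repeat(2, 1, 1)        # broadcast the summary over a batch of 2 utterances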
@@ -363,14 +359,11 @@ class ContextualParaformer(Paraformer):
                                    clas_scale=kwargs.get("clas_scale", 1.0))
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        pdb.set_trace()
        results = []
        b, n, d = decoder_out.size()
        pdb.set_trace()
        for i in range(b):
            x = encoder_out[i, :encoder_out_lens[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
            pdb.set_trace()
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
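Note: decoding is per utterance. Each iteration trims the i-th encoder sequence to its real length and the i-th posterior matrix to its predicted token count before beam search sees them. The trimming step alone, with invented shapes:

import torch

b, t_max, n_max, d, vocab = 2, 100, 20, 256, 512  # illustrative sizes
encoder_out = torch.randn(b, t_max, d)
decoder_out = torch.randn(b, n_max, vocab)
encoder_out_lens = torch.tensor([80, 100])
pre_token_length = torch.tensor([12, 17])

for i in range(b):
    x = encoder_out[i, :encoder_out_lens[i], :]          # (T_i, d): valid frames only
    am_scores = decoder_out[i, :pre_token_length[i], :]  # (N_i, vocab): predicted tokens only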
@@ -32,7 +32,7 @@ from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank


import pdb
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
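Note: import pdb is added at module level solely for the temporary traces in this commit. As an aside, distutils' LooseVersion (used for the autocast gate) is deprecated; an equivalent gate with packaging, offered as a hedged alternative rather than part of this change:

import torch
from packaging.version import Version

# Strip any local build tag such as "+cu118" before comparing.
if Version(torch.__version__.split("+")[0]) >= Version("1.6.0"):
    from torch.cuda.amp import autocast
else:
    from contextlib import contextmanager

    @contextmanager
    def autocast(enabled=True):
        # No-op stand-in for older torch versions.
        yield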
@@ -130,7 +130,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")


        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel
@@ -212,58 +212,87 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
                               nfilter=50,
                               seaco_weight=1.0):
        # decoder forward
        pdb.set_trace()
        decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
        pdb.set_trace()
        decoder_pred = torch.log_softmax(decoder_out, dim=-1)
        if hw_list is not None:
            pdb.set_trace()
            hw_lengths = [len(i) for i in hw_list]
            hw_list_ = [torch.Tensor(i).long() for i in hw_list]
            hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
            pdb.set_trace()
            selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device))
            pdb.set_trace()
            contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
            pdb.set_trace()
            num_hot_word = contextual_info.shape[1]
            _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)

            pdb.set_trace()
            # ASF Core
            if nfilter > 0 and nfilter < num_hot_word:
                for dec in self.seaco_decoder.decoders:
                    dec.reserve_attn = True
                pdb.set_trace()
                # cif_attended, _ = self.decoder2(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
                dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
                # cif_filter = torch.topk(self.decoder2.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1], min(nfilter, num_hot_word-1))[1].tolist()
                pdb.set_trace()
                hotword_scores = self.seaco_decoder.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1]
                # hotword_scores /= torch.sqrt(torch.tensor(hw_lengths)[:-1].float()).to(hotword_scores.device)
                pdb.set_trace()
                dec_filter = torch.topk(hotword_scores, min(nfilter, num_hot_word-1))[1].tolist()
                pdb.set_trace()
                add_filter = dec_filter
                pdb.set_trace()
                add_filter.append(len(hw_list_pad)-1)
                # filter hotword embedding
                pdb.set_trace()
                selected = selected[add_filter]
                # again
                pdb.set_trace()
                contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
                pdb.set_trace()
                num_hot_word = contextual_info.shape[1]
                _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
                pdb.set_trace()
                for dec in self.seaco_decoder.decoders:
                    dec.attn_mat = []
                    dec.reserve_attn = False

            pdb.set_trace()
            # SeACo Core
            cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
            pdb.set_trace()
            dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
            pdb.set_trace()
            merged = self._merge(cif_attended, dec_attended)

            pdb.set_trace()

            dha_output = self.hotword_output_layer(merged)  # remove the last token in loss calculation
            pdb.set_trace()
            dha_pred = torch.log_softmax(dha_output, dim=-1)
            pdb.set_trace()
            def _merge_res(dec_output, dha_output):
                pdb.set_trace()
                lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
                pdb.set_trace()
                dha_ids = dha_output.max(-1)[-1]  # [0]
                pdb.set_trace()
                dha_mask = (dha_ids == 8377).int().unsqueeze(-1)
                pdb.set_trace()
                a = (1 - lmbd) / lmbd
                b = 1 / lmbd
                pdb.set_trace()
                a, b = a.to(dec_output.device), b.to(dec_output.device)
                pdb.set_trace()
                dha_mask = (dha_mask + a.reshape(-1, 1, 1)) / b.reshape(-1, 1, 1)
                # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
                pdb.set_trace()
                logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
                return logits

            merged_pred = _merge_res(decoder_pred, dha_pred)
            pdb.set_trace()
            # import pdb; pdb.set_trace()
            return merged_pred
        else:
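Note on _merge_res above: with w = seaco_weight, the transformed mask works out to w*mask + (1 - w), so positions where the DHA branch predicts id 8377 (presumably this model's "no bias word" placeholder; not verified here) fall back entirely to the main decoder, while all other positions blend as (1 - w)*dec + w*dha. A standalone sketch of the arithmetic:

import torch

# Standalone sketch of the _merge_res blending; 8377 stands in for the
# model-specific "no bias" placeholder id (an assumption, not verified).
def merge_posteriors(dec_output, dha_output, seaco_weight=1.0, no_bias_id=8377):
    lmbd = torch.full((dha_output.shape[0],), seaco_weight)
    dha_mask = (dha_output.max(-1)[-1] == no_bias_id).int().unsqueeze(-1)
    a = (1 - lmbd) / lmbd                 # requires seaco_weight > 0
    b = 1 / lmbd
    # (mask + (1-w)/w) / (1/w)  ==  w*mask + (1 - w)
    dha_mask = (dha_mask + a.reshape(-1, 1, 1)) / b.reshape(-1, 1, 1)
    return dec_output * dha_mask + dha_output * (1 - dha_mask)

vocab = 8404                              # illustrative vocabulary size
dec = torch.log_softmax(torch.randn(1, 5, vocab), dim=-1)
dha = torch.log_softmax(torch.randn(1, 5, vocab), dim=-1)
merged = merge_posteriors(dec, dha, seaco_weight=1.0)  # weight 1.0 -> hard switch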
@@ -318,7 +347,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        pdb.set_trace()
        meta_data = {}

        # extract fbank feats
@@ -326,6 +355,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        pdb.set_trace()
        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                               frontend=frontend)
        time3 = time.perf_counter()
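Note: each inference stage is wall-clock timed with time.perf_counter() and recorded as a formatted string in meta_data. The pattern in isolation (key names illustrative):

import time

meta_data = {}
time1 = time.perf_counter()
# ... load audio ...
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"      # seconds, 3 decimals
# ... extract fbank features ...
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"   # key name assumed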
@@ -336,14 +366,18 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        pdb.set_trace()
        # hotword
        self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)

        pdb.set_trace()
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]


        pdb.set_trace()
        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
@@ -352,15 +386,16 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
        if torch.max(pre_token_length) < 1:
            return []


        pdb.set_trace()
        decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
                                                  pre_acoustic_embeds,
                                                  pre_token_length,
                                                  hw_list=self.hotword_list)
        pdb.set_trace()
        # decoder_out, _ = decoder_outs[0], decoder_outs[1]
        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
                                                                  pre_token_length)

        pdb.set_trace()
        results = []
        b, n, d = decoder_out.size()
        for i in range(b):