mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
decoding
This commit is contained in:
parent
0033151b62
commit
b01c9f1c25
@ -300,9 +300,9 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
|
||||
return len(self.index_ds)
|
||||
|
||||
def __getitem__(self, index):
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
# import pdb
|
||||
#
|
||||
# pdb.set_trace()
|
||||
|
||||
output = None
|
||||
|
||||
@ -397,6 +397,7 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
|
||||
labels += source_mask + target_ids
|
||||
fbank.append(speech[0, :, :])
|
||||
fbank_mask += fbank_mask_i
|
||||
fbank_lens.append(speech_lengths)
|
||||
|
||||
if len(input_ids) > self.max_token_length:
|
||||
logging.info(
|
||||
@ -410,7 +411,7 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
|
||||
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
|
||||
|
||||
# fbank = speech[0, :, :]
|
||||
fbank_lens = speech_lengths
|
||||
# fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
|
||||
fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
|
||||
fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
|
||||
fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
|
||||
|
||||
@ -990,12 +990,14 @@ class LLMASR4(nn.Module):
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size, frames, _ = speech.shape
|
||||
batch_size_speech, frames, _ = speech.shape
|
||||
batch_size, token_num = input_ids.shape
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
# audio encoder
|
||||
@ -1008,38 +1010,34 @@ class LLMASR4(nn.Module):
|
||||
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
||||
|
||||
batch_size, token_num, dims = inputs_embeds.shape
|
||||
fbank_mask[fbank_mask < 0] = 0
|
||||
fbank_fake_lens = fbank_mask.sum(-1).to(torch.int32)
|
||||
# _, l, _ = encoder_out.shape
|
||||
fake_token_len = kwargs.get("fake_token_len")
|
||||
fake_token_len[fake_token_len < 0] = 0
|
||||
fbank_beg[fbank_beg < 0] = 0
|
||||
speech_idx = 0
|
||||
for batch_idx in range(batch_size):
|
||||
|
||||
for turn_id in range(fbank_beg.shape[1]):
|
||||
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
|
||||
if fbank_beg[turn_id] > 0:
|
||||
if fbank_beg[batch_idx, turn_id] > 0:
|
||||
speech_token_len = fake_token_len[batch_idx, turn_id]
|
||||
speech_token = encoder_out[batch_idx + turn_id, turn_id, :speech_token_len, :]
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
|
||||
fbank_fake_len = fbank_fake_lens[batch_idx].item()
|
||||
fbank_beg_idx = fbank_beg[batch_idx, 0].item()
|
||||
min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
|
||||
try:
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
except Exception as e:
|
||||
logging.error(f"{str(e)}, {traceback.format_exc()}")
|
||||
logging.info(
|
||||
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens[speech_idx].item()}"
|
||||
)
|
||||
speech_token_len = encoder_out_lens[speech_idx].item()
|
||||
speech_token = encoder_out[speech_idx, turn_id, :speech_token_len, :]
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
|
||||
try:
|
||||
inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
|
||||
batch_idx, :min_len, :
|
||||
]
|
||||
except Exception as e:
|
||||
logging.error(f"{str(e)}, {traceback.format_exc()}")
|
||||
logging.info(
|
||||
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, min_len: {min_len}, fbank_fake_len: {fbank_fake_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens[batch_idx].item()}"
|
||||
)
|
||||
fbank_fake_len = encoder_out_lens[batch_idx].item()
|
||||
min_len = min(fbank_fake_len, min_len)
|
||||
inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
|
||||
batch_idx, :min_len, :
|
||||
]
|
||||
speech_idx += 1
|
||||
|
||||
with torch.cuda.amp.autocast(
|
||||
enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
|
||||
@ -1061,7 +1059,8 @@ class LLMASR4(nn.Module):
|
||||
|
||||
stats["loss"] = torch.clone(loss.detach())
|
||||
stats["batch_size"] = batch_size
|
||||
stats["batch_size_x_frames"] = frames * batch_size
|
||||
stats["batch_size_speech"] = batch_size_speech
|
||||
stats["batch_size_x_frames"] = frames * batch_size_speech
|
||||
stats["batch_size_real_frames"] = speech_lengths.sum().item()
|
||||
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
|
||||
stats["batch_size_x_tokens"] = token_num * batch_size
|
||||
|
||||
Loading…
Reference in New Issue
Block a user