mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
8f6d2787f0
commit
ef5ea9b05f
@ -610,6 +610,8 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
|
||||
fake_token_len_i = 0
|
||||
fbank_beg_i = -1
|
||||
fbank_lens_i = []
|
||||
speech = []
|
||||
speech_lengths = []
|
||||
for k, sub_str in enumerate(splits):
|
||||
if not sub_str.startswith("<|startofspeech|>"):
|
||||
sub_token = self.tokenizer.encode(sub_str)
|
||||
@ -688,9 +690,11 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
fbank.append(speech[0, :, :])
|
||||
|
||||
fbank_mask += fbank_mask_i
|
||||
fbank_lens.append(speech_lengths)
|
||||
if len(speech) > 0:
|
||||
fbank.append(speech[0, :, :])
|
||||
fbank_lens.append(speech_lengths)
|
||||
|
||||
if badcase_flag:
|
||||
continue
|
||||
@ -706,8 +710,6 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
|
||||
fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
|
||||
|
||||
output = {
|
||||
"speech": fbank,
|
||||
"speech_lengths": fbank_lens,
|
||||
"fbank_mask": fbank_mask,
|
||||
"fbank_beg": fbank_beg,
|
||||
"fake_token_len": fake_token_len,
|
||||
@ -719,6 +721,10 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
|
||||
codec_len = torch.tensor(codec_len, dtype=torch.int32)
|
||||
output["codec"] = codec
|
||||
output["codec_len"] = codec_len
|
||||
if len(fbank) > 0:
|
||||
output["speech"] = fbank
|
||||
output["speech_lengths"] = fbank_lens
|
||||
|
||||
break
|
||||
|
||||
return output
|
||||
|
||||
@ -974,13 +974,13 @@ class LLMASR4(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
labels_ids: torch.Tensor,
|
||||
fbank_beg: torch.Tensor,
|
||||
fbank_mask: torch.Tensor,
|
||||
speech: torch.Tensor = None,
|
||||
speech_lengths: torch.Tensor = None,
|
||||
input_ids: torch.Tensor = None,
|
||||
attention_mask: torch.Tensor = None,
|
||||
labels_ids: torch.Tensor = None,
|
||||
fbank_beg: torch.Tensor = None,
|
||||
fbank_mask: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Encoder + Decoder + Calc loss
|
||||
@ -993,55 +993,55 @@ class LLMASR4(nn.Module):
|
||||
# import pdb
|
||||
#
|
||||
# pdb.set_trace()
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size_speech, frames, _ = speech.shape
|
||||
batch_size, token_num = input_ids.shape
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
# audio encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
# audio_adaptor
|
||||
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
|
||||
|
||||
input_ids[input_ids < 0] = 0
|
||||
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
||||
if speech is not None:
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size, token_num, dims = inputs_embeds.shape
|
||||
fake_token_len = kwargs.get("fake_token_len")
|
||||
fake_token_len[fake_token_len < 0] = 0
|
||||
fbank_beg[fbank_beg < 0] = 0
|
||||
batch_size_speech, frames, _ = speech.shape
|
||||
batch_size, token_num = input_ids.shape
|
||||
|
||||
speech_idx = 0
|
||||
for batch_idx in range(batch_size):
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
# audio encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
for turn_id in range(fbank_beg.shape[1]):
|
||||
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
|
||||
if fbank_beg_idx > 0:
|
||||
speech_token_len = fake_token_len[batch_idx, turn_id]
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
# audio_adaptor
|
||||
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
|
||||
|
||||
try:
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
except Exception as e:
|
||||
#
|
||||
logging.error(f"{str(e)}, {traceback.format_exc()}")
|
||||
logging.info(
|
||||
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
|
||||
)
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
speech_token_len = encoder_out_lens[speech_idx].item()
|
||||
batch_size, token_num, dims = inputs_embeds.shape
|
||||
fake_token_len = kwargs.get("fake_token_len")
|
||||
fake_token_len[fake_token_len < 0] = 0
|
||||
fbank_beg[fbank_beg < 0] = 0
|
||||
|
||||
speech_idx = 0
|
||||
for batch_idx in range(batch_size):
|
||||
|
||||
for turn_id in range(fbank_beg.shape[1]):
|
||||
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
|
||||
if fbank_beg_idx > 0:
|
||||
speech_token_len = fake_token_len[batch_idx, turn_id]
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
|
||||
speech_idx += 1
|
||||
try:
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
except Exception as e:
|
||||
#
|
||||
logging.error(f"{str(e)}, {traceback.format_exc()}")
|
||||
logging.info(
|
||||
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
|
||||
)
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
speech_token_len = encoder_out_lens[speech_idx].item()
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
|
||||
speech_idx += 1
|
||||
|
||||
with torch.cuda.amp.autocast(
|
||||
enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user