This commit is contained in:
游雁 2024-08-20 13:56:52 +08:00
parent c75040f1be
commit ed9fd49d46
2 changed files with 18 additions and 14 deletions

View File

@ -1332,13 +1332,13 @@ class OpenAIDatasetMultiTurnCodecMel2(torch.utils.data.Dataset):
fbank.append(speech[0, :, :])
fbank_lens.append(speech_lengths)
# filter
# if i == multiturn_num - 1:
ratio = codec_i_len / len(token_num_tts)
if ratio < 1 or ratio > 7:
badcase_flag = True
if codec_i_len + len(token_num_tts) > 1500:
badcase_flag = True
# # filter
# # if i == multiturn_num - 1:
# ratio = codec_i_len / len(token_num_tts)
# if ratio < 1 or ratio > 7:
# badcase_flag = True
# if codec_i_len + len(token_num_tts) > 1500:
# badcase_flag = True
if badcase_flag:
continue

View File

@ -1597,6 +1597,8 @@ class LLMASR4_extract_kv(nn.Module):
import os
os.makedirs(self.kv_cache_outdir, exist_ok=True)
os.makedirs(f"{self.kv_cache_outdir}/mat", exist_ok=True)
os.makedirs(f"{self.kv_cache_outdir}/txt", exist_ok=True)
# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
@ -1624,6 +1626,7 @@ class LLMASR4_extract_kv(nn.Module):
rank = int(os.environ.get("RANK", 0))
logging.info(f"rank: {rank}, model is builded.")
self.fo = open(f"{self.kv_cache_outdir}/txt/{rank}.txt", "w")
def forward(
self,
@ -1727,14 +1730,15 @@ class LLMASR4_extract_kv(nn.Module):
hidden_states = model_outputs.hidden_states[-1].float()
key = kwargs.get("key")[0]
kv_cache_outdir = self.kv_cache_outdir
mat_file = f"{kv_cache_outdir}/{key}.mat"
mat_file = f"{kv_cache_outdir}/mat/{key}.mat"
savemat(mat_file, {"kv_cache": hidden_states[0].cpu()})
with open(f"{kv_cache_outdir}/{key}.txt", "w") as f:
for turn_id_cum in range(input_mask.shape[0]):
end = input_mask[turn_id_cum].sum(-1)
line = f"{key}.assistent.{turn_id_cum} {mat_file} {end}\n"
f.write(line)
f.flush()
for turn_id_cum in range(input_mask.shape[0]):
end = input_mask[turn_id_cum].sum(-1)
uttid = f"{key}_assistant_{turn_id_cum:02d}"
line = f"{uttid} {mat_file} {end}\n"
self.fo.write(line)
self.fo.flush()
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)