Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
kv
This commit is contained in:
parent c75040f1be
commit ed9fd49d46
@@ -1332,13 +1332,13 @@ class OpenAIDatasetMultiTurnCodecMel2(torch.utils.data.Dataset):
|
||||
fbank.append(speech[0, :, :])
|
||||
fbank_lens.append(speech_lengths)
|
||||
|
||||
# filter
|
||||
# if i == multiturn_num - 1:
|
||||
ratio = codec_i_len / len(token_num_tts)
|
||||
if ratio < 1 or ratio > 7:
|
||||
badcase_flag = True
|
||||
if codec_i_len + len(token_num_tts) > 1500:
|
||||
badcase_flag = True
|
||||
# # filter
|
||||
# # if i == multiturn_num - 1:
|
||||
# ratio = codec_i_len / len(token_num_tts)
|
||||
# if ratio < 1 or ratio > 7:
|
||||
# badcase_flag = True
|
||||
# if codec_i_len + len(token_num_tts) > 1500:
|
||||
# badcase_flag = True
|
||||
|
||||
if badcase_flag:
|
||||
continue
|
||||
|
||||
@@ -1597,6 +1597,8 @@ class LLMASR4_extract_kv(nn.Module):
|
||||
import os
|
||||
|
||||
os.makedirs(self.kv_cache_outdir, exist_ok=True)
|
||||
os.makedirs(f"{self.kv_cache_outdir}/mat", exist_ok=True)
|
||||
os.makedirs(f"{self.kv_cache_outdir}/txt", exist_ok=True)
|
||||
|
||||
# adaptor
|
||||
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
|
||||
@@ -1624,6 +1626,7 @@ class LLMASR4_extract_kv(nn.Module):
|
||||
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
logging.info(f"rank: {rank}, model is builded.")
|
||||
self.fo = open(f"{self.kv_cache_outdir}/txt/{rank}.txt", "w")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -1727,14 +1730,15 @@ class LLMASR4_extract_kv(nn.Module):
|
||||
hidden_states = model_outputs.hidden_states[-1].float()
|
||||
key = kwargs.get("key")[0]
|
||||
kv_cache_outdir = self.kv_cache_outdir
|
||||
mat_file = f"{kv_cache_outdir}/{key}.mat"
|
||||
mat_file = f"{kv_cache_outdir}/mat/{key}.mat"
|
||||
savemat(mat_file, {"kv_cache": hidden_states[0].cpu()})
|
||||
with open(f"{kv_cache_outdir}/{key}.txt", "w") as f:
|
||||
for turn_id_cum in range(input_mask.shape[0]):
|
||||
end = input_mask[turn_id_cum].sum(-1)
|
||||
line = f"{key}.assistent.{turn_id_cum} {mat_file} {end}\n"
|
||||
f.write(line)
|
||||
f.flush()
|
||||
|
||||
for turn_id_cum in range(input_mask.shape[0]):
|
||||
end = input_mask[turn_id_cum].sum(-1)
|
||||
uttid = f"{key}_assistant_{turn_id_cum:02d}"
|
||||
line = f"{uttid} {mat_file} {end}\n"
|
||||
self.fo.write(line)
|
||||
self.fo.flush()
|
||||
|
||||
with torch.no_grad():
|
||||
preds = torch.argmax(model_outputs.logits, -1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user