mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
kv cache
This commit is contained in:
parent
3e0bd69f83
commit
f33138ab2e
@ -168,14 +168,13 @@ def main(**kwargs):
|
||||
for epoch in range(trainer.start_epoch, trainer.max_epoch):
|
||||
time1 = time.perf_counter()
|
||||
|
||||
if trainer.kwargs.get("do_train", True):
|
||||
for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
|
||||
time_slice_i = time.perf_counter()
|
||||
|
||||
dataloader_tr, dataloader_val = dataloader.build_iter(
|
||||
epoch, data_split_i=data_split_i, start_step=trainer.start_step
|
||||
)
|
||||
for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
|
||||
time_slice_i = time.perf_counter()
|
||||
|
||||
dataloader_tr, dataloader_val = dataloader.build_iter(
|
||||
epoch, data_split_i=data_split_i, start_step=trainer.start_step
|
||||
)
|
||||
if trainer.kwargs.get("do_train", True):
|
||||
trainer.train_epoch(
|
||||
model=model,
|
||||
optim=optim,
|
||||
|
||||
@ -1592,6 +1592,10 @@ class LLMASR4_extract_kv(nn.Module):
|
||||
self.llm = model.to(dtype_map[self.llm_dtype])
|
||||
llm_dim = model.get_input_embeddings().weight.shape[-1]
|
||||
self.kv_cache_outdir = llm_conf.get("kv_cache_outdir", None)
|
||||
if self.kv_cache_outdir is not None:
|
||||
import os
|
||||
|
||||
os.makedirs(self.kv_cache_outdir, exist_ok=True)
|
||||
|
||||
# adaptor
|
||||
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
|
||||
@ -1714,14 +1718,16 @@ class LLMASR4_extract_kv(nn.Module):
|
||||
input_mask[input_mask < 0] = 0
|
||||
|
||||
hidden_states = model_outputs.hidden_states[-1].float()
|
||||
key = kwargs.get("key")
|
||||
key = kwargs.get("key")[0]
|
||||
kv_cache_outdir = self.kv_cache_outdir
|
||||
savemat(f"{kv_cache_outdir}/{key}.mat", {"kv_cache": hidden_states[0]})
|
||||
mat_file = f"{kv_cache_outdir}/{key}.mat"
|
||||
savemat(mat_file, {"kv_cache": hidden_states[0].cpu()})
|
||||
with open(f"{kv_cache_outdir}/{key}.txt", "w") as f:
|
||||
f.write(f"{kv_cache_outdir}/{key}.mat")
|
||||
for turn_id_cum in range(input_mask.shape[1]):
|
||||
for turn_id_cum in range(input_mask.shape[0]):
|
||||
end = input_mask[turn_id_cum].sum(-1)
|
||||
f.write(f"\t{end}")
|
||||
line = f"{key}.assistent.{turn_id_cum} {mat_file} {end}"
|
||||
f.write(line)
|
||||
f.flush()
|
||||
|
||||
stats = {}
|
||||
with torch.no_grad():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user