This commit is contained in:
游雁 2024-08-19 14:57:25 +08:00
parent 3e0bd69f83
commit f33138ab2e
2 changed files with 17 additions and 12 deletions

View File

@@ -168,14 +168,13 @@ def main(**kwargs):
for epoch in range(trainer.start_epoch, trainer.max_epoch):
time1 = time.perf_counter()
if trainer.kwargs.get("do_train", True):
for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
time_slice_i = time.perf_counter()
dataloader_tr, dataloader_val = dataloader.build_iter(
epoch, data_split_i=data_split_i, start_step=trainer.start_step
)
for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
time_slice_i = time.perf_counter()
dataloader_tr, dataloader_val = dataloader.build_iter(
epoch, data_split_i=data_split_i, start_step=trainer.start_step
)
if trainer.kwargs.get("do_train", True):
trainer.train_epoch(
model=model,
optim=optim,

View File

@@ -1592,6 +1592,10 @@ class LLMASR4_extract_kv(nn.Module):
self.llm = model.to(dtype_map[self.llm_dtype])
llm_dim = model.get_input_embeddings().weight.shape[-1]
self.kv_cache_outdir = llm_conf.get("kv_cache_outdir", None)
if self.kv_cache_outdir is not None:
import os
os.makedirs(self.kv_cache_outdir, exist_ok=True)
# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
@@ -1714,14 +1718,16 @@ class LLMASR4_extract_kv(nn.Module):
input_mask[input_mask < 0] = 0
hidden_states = model_outputs.hidden_states[-1].float()
key = kwargs.get("key")
key = kwargs.get("key")[0]
kv_cache_outdir = self.kv_cache_outdir
savemat(f"{kv_cache_outdir}/{key}.mat", {"kv_cache": hidden_states[0]})
mat_file = f"{kv_cache_outdir}/{key}.mat"
savemat(mat_file, {"kv_cache": hidden_states[0].cpu()})
with open(f"{kv_cache_outdir}/{key}.txt", "w") as f:
f.write(f"{kv_cache_outdir}/{key}.mat")
for turn_id_cum in range(input_mask.shape[1]):
for turn_id_cum in range(input_mask.shape[0]):
end = input_mask[turn_id_cum].sum(-1)
f.write(f"\t{end}")
line = f"{key}.assistent.{turn_id_cum} {mat_file} {end}"
f.write(line)
f.flush()
stats = {}
with torch.no_grad():