Add batch support for token extraction

This commit is contained in:
志浩 2024-09-24 17:59:02 +08:00
parent 752abbb3ca
commit 4b840fd668
2 changed files with 6 additions and 1 deletions

View File

@ -330,6 +330,8 @@ class AutoModel:
if run_mode == "extract_token" and hasattr(model, "writer"):
model.writer.close()
if hasattr(model, "len_writer"):
model.len_writer.close()
if pbar:
# pbar.update(1)

View File

@ -2050,16 +2050,19 @@ class SenseVoiceL(nn.Module):
results.append(result_i)
ark_writer = None
ark_writer, len_writer = None, None
if kwargs.get("output_dir") is not None:
out_dir = kwargs.get("output_dir")
os.makedirs(out_dir, exist_ok=True)
if not hasattr(self, "writer"):
out_path = os.path.join(out_dir, f"enc_token")
self.writer = kaldiio.WriteHelper(f"ark,scp,f:{out_path}.ark,{out_path}.scp")
self.len_writer = open(out_path+"_len.txt", "wt")
ark_writer = self.writer
len_writer = self.len_writer
if ark_writer is not None:
for k, v, l in zip(key, tokens.detach().cpu().numpy(), out_lens):
ark_writer(k, v[:l])
len_writer.write(f"{k}\t{l}\n")
return results, meta_data