mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add batch support for token extraction
This commit is contained in:
parent
752abbb3ca
commit
4b840fd668
@ -330,6 +330,8 @@ class AutoModel:
|
||||
|
||||
if run_mode == "extract_token" and hasattr(model, "writer"):
|
||||
model.writer.close()
|
||||
if hasattr(model, "len_writer"):
|
||||
model.len_writer.close()
|
||||
|
||||
if pbar:
|
||||
# pbar.update(1)
|
||||
|
||||
@ -2050,16 +2050,19 @@ class SenseVoiceL(nn.Module):
|
||||
|
||||
results.append(result_i)
|
||||
|
||||
ark_writer = None
|
||||
ark_writer, len_writer = None, None
|
||||
if kwargs.get("output_dir") is not None:
|
||||
out_dir = kwargs.get("output_dir")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
if not hasattr(self, "writer"):
|
||||
out_path = os.path.join(out_dir, f"enc_token")
|
||||
self.writer = kaldiio.WriteHelper(f"ark,scp,f:{out_path}.ark,{out_path}.scp")
|
||||
self.len_writer = open(out_path+"_len.txt", "wt")
|
||||
ark_writer = self.writer
|
||||
len_writer = self.len_writer
|
||||
if ark_writer is not None:
|
||||
for k, v, l in zip(key, tokens.detach().cpu().numpy(), out_lens):
|
||||
ark_writer(k, v[:l])
|
||||
len_writer.write(f"{k}\t{l}\n")
|
||||
|
||||
return results, meta_data
|
||||
|
||||
Loading…
Reference in New Issue
Block a user