diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 9272f62bf..7db1eb606 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -330,6 +330,8 @@ class AutoModel: if run_mode == "extract_token" and hasattr(model, "writer"): model.writer.close() + if hasattr(model, "len_writer"): + model.len_writer.close() if pbar: # pbar.update(1) diff --git a/funasr/models/sense_voice/model_small.py b/funasr/models/sense_voice/model_small.py index 0d27b4d25..91d4a754f 100644 --- a/funasr/models/sense_voice/model_small.py +++ b/funasr/models/sense_voice/model_small.py @@ -2050,16 +2050,19 @@ class SenseVoiceL(nn.Module): results.append(result_i) - ark_writer = None + ark_writer, len_writer = None, None if kwargs.get("output_dir") is not None: out_dir = kwargs.get("output_dir") os.makedirs(out_dir, exist_ok=True) if not hasattr(self, "writer"): out_path = os.path.join(out_dir, f"enc_token") self.writer = kaldiio.WriteHelper(f"ark,scp,f:{out_path}.ark,{out_path}.scp") + self.len_writer = open(out_path+"_len.txt", "wt") ark_writer = self.writer + len_writer = self.len_writer if ark_writer is not None: for k, v, l in zip(key, tokens.detach().cpu().numpy(), out_lens): ark_writer(k, v[:l]) + len_writer.write(f"{k}\t{l}\n") return results, meta_data