From 686eed1231227b09f5e128cebd188a91c06f530d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Fri, 12 Jul 2024 11:31:15 +0800 Subject: [PATCH] set random seed --- funasr/models/llm_asr/model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 3382302e2..65abd27ce 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -23,6 +23,7 @@ from funasr.utils.datadir_writer import DatadirWriter from funasr.register import tables from funasr.train_utils.device_funcs import to_device from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list +from funasr.train_utils.set_all_random_seed import set_all_random_seed import traceback dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} @@ -2203,6 +2204,8 @@ class LLMASR5(nn.Module): self.llm = self.llm.to(dtype_map[llm_dtype]) inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) + # set random seed for reproduce + set_all_random_seed(0) generated_ids = self.llm.generate( inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512), @@ -2241,6 +2244,8 @@ class LLMASR5(nn.Module): outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1)) hidden_states_select = self.fusion_act(self.fusion_norm(outs)) + # set random seed for reproduce + set_all_random_seed(0) speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[ :, :, 0 ] # 1xlx1: 2,10,1023