set random seed

This commit is contained in:
志浩 2024-07-12 11:31:15 +08:00
parent 158c22ca4c
commit 686eed1231

View File

@ -23,6 +23,7 @@ from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables from funasr.register import tables
from funasr.train_utils.device_funcs import to_device from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.train_utils.set_all_random_seed import set_all_random_seed
import traceback import traceback
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
@ -2203,6 +2204,8 @@ class LLMASR5(nn.Module):
self.llm = self.llm.to(dtype_map[llm_dtype]) self.llm = self.llm.to(dtype_map[llm_dtype])
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
# set random seed for reproduce
set_all_random_seed(0)
generated_ids = self.llm.generate( generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_length", 512), max_new_tokens=kwargs.get("max_length", 512),
@ -2241,6 +2244,8 @@ class LLMASR5(nn.Module):
outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1)) outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1))
hidden_states_select = self.fusion_act(self.fusion_norm(outs)) hidden_states_select = self.fusion_act(self.fusion_norm(outs))
# set random seed for reproduce
set_all_random_seed(0)
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[ speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
:, :, 0 :, :, 0
] # 1xlx1: 2,10,1023 ] # 1xlx1: 2,10,1023