mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
set random seed
This commit is contained in:
parent
158c22ca4c
commit
686eed1231
@ -23,6 +23,7 @@ from funasr.utils.datadir_writer import DatadirWriter
|
|||||||
from funasr.register import tables
|
from funasr.register import tables
|
||||||
from funasr.train_utils.device_funcs import to_device
|
from funasr.train_utils.device_funcs import to_device
|
||||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
|
||||||
|
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
|
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
|
||||||
@ -2203,6 +2204,8 @@ class LLMASR5(nn.Module):
|
|||||||
self.llm = self.llm.to(dtype_map[llm_dtype])
|
self.llm = self.llm.to(dtype_map[llm_dtype])
|
||||||
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
|
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
|
||||||
|
|
||||||
|
# set random seed for reproducibility
|
||||||
|
set_all_random_seed(0)
|
||||||
generated_ids = self.llm.generate(
|
generated_ids = self.llm.generate(
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
max_new_tokens=kwargs.get("max_length", 512),
|
max_new_tokens=kwargs.get("max_length", 512),
|
||||||
@ -2241,6 +2244,8 @@ class LLMASR5(nn.Module):
|
|||||||
outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1))
|
outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1))
|
||||||
hidden_states_select = self.fusion_act(self.fusion_norm(outs))
|
hidden_states_select = self.fusion_act(self.fusion_norm(outs))
|
||||||
|
|
||||||
|
# set random seed for reproducibility
|
||||||
|
set_all_random_seed(0)
|
||||||
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
|
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
|
||||||
:, :, 0
|
:, :, 0
|
||||||
] # 1xlx1: 2,10,1023
|
] # 1xlx1: 2,10,1023
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user