mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'dev_gzf_deepspeed' of github.com:alibaba-damo-academy/FunASR into dev_gzf_deepspeed
merge
This commit is contained in:
commit
4272292ffd
@ -2175,6 +2175,7 @@ class LLMASR5(nn.Module):
|
||||
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
|
||||
data_in, data_lengths, key, tokenizer, frontend, **kwargs
|
||||
)
|
||||
rand_seed = kwargs.get("rand_seed", 0)
|
||||
|
||||
llm_dtype = kwargs.get("llm_dtype", "fp32")
|
||||
if llm_dtype == "fp32":
|
||||
@ -2189,7 +2190,7 @@ class LLMASR5(nn.Module):
|
||||
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
|
||||
|
||||
# set random seed for reproduce
|
||||
set_all_random_seed(0)
|
||||
set_all_random_seed(rand_seed)
|
||||
generated_ids = self.llm.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
max_new_tokens=kwargs.get("max_length", 512),
|
||||
@ -2229,7 +2230,7 @@ class LLMASR5(nn.Module):
|
||||
hidden_states_select = self.fusion_act(self.fusion_norm(outs))
|
||||
|
||||
# set random seed for reproduce
|
||||
set_all_random_seed(0)
|
||||
set_all_random_seed(rand_seed)
|
||||
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
|
||||
:, :, 0
|
||||
] # 1xlx1: 2,10,1023
|
||||
|
||||
Loading…
Reference in New Issue
Block a user