mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'dev_gzf_deepspeed' of github.com:alibaba-damo-academy/FunASR into dev_gzf_deepspeed
merge
This commit is contained in:
commit
4272292ffd
@ -2175,6 +2175,7 @@ class LLMASR5(nn.Module):
|
|||||||
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
|
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
|
||||||
data_in, data_lengths, key, tokenizer, frontend, **kwargs
|
data_in, data_lengths, key, tokenizer, frontend, **kwargs
|
||||||
)
|
)
|
||||||
|
rand_seed = kwargs.get("rand_seed", 0)
|
||||||
|
|
||||||
llm_dtype = kwargs.get("llm_dtype", "fp32")
|
llm_dtype = kwargs.get("llm_dtype", "fp32")
|
||||||
if llm_dtype == "fp32":
|
if llm_dtype == "fp32":
|
||||||
@ -2189,7 +2190,7 @@ class LLMASR5(nn.Module):
|
|||||||
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
|
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
|
||||||
|
|
||||||
# set random seed for reproduce
|
# set random seed for reproduce
|
||||||
set_all_random_seed(0)
|
set_all_random_seed(rand_seed)
|
||||||
generated_ids = self.llm.generate(
|
generated_ids = self.llm.generate(
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
max_new_tokens=kwargs.get("max_length", 512),
|
max_new_tokens=kwargs.get("max_length", 512),
|
||||||
@ -2229,7 +2230,7 @@ class LLMASR5(nn.Module):
|
|||||||
hidden_states_select = self.fusion_act(self.fusion_norm(outs))
|
hidden_states_select = self.fusion_act(self.fusion_norm(outs))
|
||||||
|
|
||||||
# set random seed for reproduce
|
# set random seed for reproduce
|
||||||
set_all_random_seed(0)
|
set_all_random_seed(rand_seed)
|
||||||
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
|
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
|
||||||
:, :, 0
|
:, :, 0
|
||||||
] # 1xlx1: 2,10,1023
|
] # 1xlx1: 2,10,1023
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user