Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Merge branch 'dev_gzf_deepspeed' of http://gitlab.alibaba-inc.com/zhifu.gzf/FunASR into dev_gzf_deepspeed
Commit 1b9300a4c3
@@ -2297,11 +2297,11 @@ class LLMASRXvecSlotTTS(nn.Module):
         # e2e tts model related
         from funasr.models.llm_asr.tts_models.e2e_model import UCTDXvecSlotModel

-        self.tts_model = UCTDXvecSlotModel(**kwargs.get("tts_model_conf", {}))
+        self.tts_model = UCTDXvecSlotModel(**kwargs.get("tts_model_conf", {})).to(torch.float32)
         # vocoder related
         vocoder_name = kwargs.get("vocoder", None)
         vocoder_conf = kwargs.get("vocoder_conf", None)
-        self.vocoder = self.build_vocoder(name=vocoder_name, conf=vocoder_conf)
+        self.vocoder = self.build_vocoder(name=vocoder_name, conf=vocoder_conf).to(torch.float32)

         import os
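
Note on the change above: casting the TTS model and vocoder to torch.float32 pins their weights to full precision even when the surrounding LLM runs in half precision, avoiding dtype mismatches in the audio path. A minimal sketch of the idiom (TinyVocoder is a hypothetical stand-in, not FunASR code):

    import torch
    import torch.nn as nn

    # Hypothetical stand-in module; only the .to(torch.float32) idiom matches the diff.
    class TinyVocoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(80, 256)

        def forward(self, mel):
            return self.proj(mel)

    vocoder = TinyVocoder().half().to(torch.float32)  # cast back to full precision
    mel = torch.randn(1, 10, 80)                      # float32 mel features
    print(vocoder(mel).dtype)                         # torch.float32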
@@ -2935,6 +2935,38 @@ class LLMASRXvecSlotTTS(nn.Module):

         return results, meta_data

+    def prepare_k_v_cache(self, inputs_embeds, contents, batch, source_ids, meta_data, **kwargs):
+
+        # tts-related inference requires the kv cache of the llm's last layer for the current inputs only
+        # TODO: select kv cache of the current turn inputs
+        attention_mask = batch.get("attention_mask", None)
+        model_outputs = self.llm(
+            inputs_embeds=inputs_embeds,
+            attention_mask=None,
+            labels=None,
+        )
+        hidden_states = model_outputs.hidden_states[-1].float()
+
+        llm_cur_kv_cache, llm_cur_kv_cache_len = None, None
+
+        input_mask_beg = batch.get("input_mask_beg")[-1][None, :]
+        input_mask_beg[input_mask_beg < 0] = 0
+        input_mask = batch.get("input_mask")[-1][None, :]
+        input_mask[input_mask < 0] = 0
+
+        for turn_id_cum in range(input_mask.shape[0]):
+            beg = input_mask_beg[turn_id_cum].sum(-1)
+            end = input_mask[turn_id_cum].sum(-1)
+            llm_cur_kv_cache = hidden_states[:, beg:end, :]
+            llm_cur_kv_cache_len = torch.tensor(
+                [
+                    end - beg,
+                ],
+                dtype=torch.int32,
+            ).to(inputs_embeds.device)
+
+        return llm_cur_kv_cache, llm_cur_kv_cache_len
+
     def generate_speech(self, text, llm_cur_kv_cache, llm_cur_kv_cache_len, llm_dtype):
         # self.tts_text_tokenizer = self.tts_text_tokenizer
         self.vocoder.to(llm_dtype)
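
Despite its name, prepare_k_v_cache as written slices the last-layer hidden states rather than literal key/value tensors: the cumulative 0/1 masks give the begin and end offsets of the current turn, and the states are sliced to that span. A self-contained sketch of the slicing, using toy tensors rather than FunASR data:

    import torch

    # Toy shapes: batch 1, sequence length 12, hidden size 4.
    hidden_states = torch.randn(1, 12, 4)
    # Cumulative 0/1 masks: ones up to the start / end of the current turn.
    input_mask_beg = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
    input_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])

    beg = int(input_mask_beg[0].sum(-1))   # 3 tokens precede the current turn
    end = int(input_mask[0].sum(-1))       # the current turn ends at offset 8
    cur_states = hidden_states[:, beg:end, :]               # (1, 5, 4)
    cur_len = torch.tensor([end - beg], dtype=torch.int32)  # tensor([5])
    print(cur_states.shape, cur_len)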
@@ -3058,12 +3090,12 @@ class LLMASRXvecSlotTTS(nn.Module):

     def convert_wav_to_mp3(self, wav: torch.Tensor):
         wav = wav.detach().cpu().numpy()
-        wav = (wav * (2**15-1) * 0.8).astype(np.int16)
+        wav = (wav * (2**15 - 1) * 0.8).astype(np.int16)
         mp3 = AudioSegment(
             wav.tobytes(),
             sample_width=16 // 8,  # Sample width in bytes
             frame_rate=22050,
-            channels=1
+            channels=1,
         )
         mp3_buffer = BytesIO()
         mp3.export(mp3_buffer, format="mp3", bitrate="48k")
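
convert_wav_to_mp3 relies on pydub's AudioSegment, which in turn needs an ffmpeg binary for mp3 encoding. A standalone version of the same conversion, using a generated sine tone in place of model output:

    from io import BytesIO

    import numpy as np
    from pydub import AudioSegment  # needs ffmpeg installed for mp3 export

    # One second of 440 Hz at 22.05 kHz, scaled to 16-bit PCM with 0.8
    # headroom, exactly as in the diff above.
    wav = np.sin(2 * np.pi * 440 * np.arange(22050) / 22050)
    pcm = (wav * (2**15 - 1) * 0.8).astype(np.int16)
    segment = AudioSegment(
        pcm.tobytes(),
        sample_width=2,    # 16 // 8 bytes per sample
        frame_rate=22050,  # must match the vocoder's output rate
        channels=1,
    )
    buffer = BytesIO()
    segment.export(buffer, format="mp3", bitrate="48k")
    mp3_bytes = buffer.getvalue()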
@@ -3124,9 +3156,61 @@ class LLMASRXvecSlotTTS(nn.Module):
         speech_tokens = torch.cat(token_list, dim=1)
         mel_feats = torch.cat(feat_list, dim=2)
         wav = torch.cat(wav_list, dim=1)
-        mp3 = b''.join(mp3_list)
+        mp3 = b"".join(mp3_list)
         return speech_tokens, mel_feats, wav, mp3

+    def reset_generate_states(self, states={}):
+
+        if states is None:
+            states = {}
+        states["new_text"] = ""
+        states["last_t_size"] = 0
+        states["prompt_token"] = [None, None]
+        states["prompt_audio"] = [None, None]
+        states["chunk_idx"] = 0
+
+    def streaming_generate_speech(
+        self,
+        preds,
+        states,
+        llm_cur_kv_cache,
+        llm_cur_kv_cache_len,
+        is_last=False,
+        text_chunk_size=8,
+        format="mp3",
+    ):
+
+        new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = (
+            states["new_text"],
+            states["last_t_size"],
+            states["prompt_token"],
+            states["prompt_audio"],
+            states["chunk_idx"],
+        )
+        new_text = new_text + preds
+        rt_value, states_ret = self.generate_speech_one_step(
+            new_text,
+            last_t_size,
+            llm_cur_kv_cache,
+            llm_cur_kv_cache_len,
+            prompt_token,
+            prompt_audio,
+            text_chunk_size,
+            chunk_idx,
+            is_last,
+        )
+        cur_token, feat, wav = rt_value
+        new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states_ret
+        states["new_text"] = new_text
+        states["last_t_size"] = last_t_size
+        states["prompt_token"] = prompt_token
+        states["prompt_audio"] = prompt_audio
+        states["chunk_idx"] = chunk_idx
+        if format == "mp3":
+            if cur_token is not None:
+                wav = self.convert_wav_to_mp3(wav)
+        return cur_token, feat, wav
+
     def write_mel_wav(self, output_dir, feat, wav, mp3, key):
         out_dir = os.path.join(output_dir, "1best_recog", "mels")
         os.makedirs(out_dir, exist_ok=True)
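
The two new methods implement a caller-owned state dict: reset_generate_states initializes it, and streaming_generate_speech consumes incremental text, buffering until roughly text_chunk_size worth of new text is ready before synthesizing. (Note that the states={} default is a mutable object shared across calls in Python, so callers should pass their own dict or None.) A minimal standalone sketch of this buffering pattern, with a stub in place of the real synthesis step:

    # Stub pattern only; none of these helpers are FunASR code.
    def reset_states(states=None):
        # None default instead of a mutable {} so state is never shared across calls.
        if states is None:
            states = {}
        states.update(new_text="", last_t_size=0, chunk_idx=0)
        return states

    def stream_step(piece, states, text_chunk_size=8, is_last=False):
        states["new_text"] += piece
        pending = states["new_text"][states["last_t_size"]:]
        if len(pending) >= text_chunk_size or (is_last and pending):
            states["last_t_size"] = len(states["new_text"])
            states["chunk_idx"] += 1
            return pending  # stand-in for the (cur_token, feat, wav) outputs
        return None  # not enough text buffered yet

    states = reset_states()
    stream = ["Hello, ", "this is ", "a stream."]
    for i, piece in enumerate(stream):
        out = stream_step(piece, states, is_last=(i == len(stream) - 1))
        if out is not None:
            print(states["chunk_idx"], repr(out))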