Merge branch 'dev_gzf_deepspeed' of http://gitlab.alibaba-inc.com/zhifu.gzf/FunASR into dev_gzf_deepspeed

This commit is contained in:
木守 2024-09-12 19:03:07 +08:00
commit 1b9300a4c3

View File

@ -2297,11 +2297,11 @@ class LLMASRXvecSlotTTS(nn.Module):
# e2e tts model related
from funasr.models.llm_asr.tts_models.e2e_model import UCTDXvecSlotModel
self.tts_model = UCTDXvecSlotModel(**kwargs.get("tts_model_conf", {}))
self.tts_model = UCTDXvecSlotModel(**kwargs.get("tts_model_conf", {})).to(torch.float32)
# vocoder related
vocoder_name = kwargs.get("vocoder", None)
vocoder_conf = kwargs.get("vocoder_conf", None)
self.vocoder = self.build_vocoder(name=vocoder_name, conf=vocoder_conf)
self.vocoder = self.build_vocoder(name=vocoder_name, conf=vocoder_conf).to(torch.float32)
import os
@ -2935,6 +2935,38 @@ class LLMASRXvecSlotTTS(nn.Module):
return results, meta_data
def prepare_k_v_cache(self, inputs_embeds, contents, batch, source_ids, meta_data, **kwargs):
    """Slice the LLM's last hidden states down to the span selected by the masks.

    ``self.llm`` is run on ``inputs_embeds`` (no attention mask, no labels) and
    the last entry of ``hidden_states`` is cast to float.  The slice bounds along
    dim 1 are the sums of the last rows of ``batch["input_mask_beg"]`` (start)
    and ``batch["input_mask"]`` (stop), after negative entries are set to 0.

    Args:
        inputs_embeds: Embeddings passed to ``self.llm``; its device is used for
            the returned length tensor.
        contents: Not used by this method.
        batch: Mapping providing ``"input_mask_beg"`` and ``"input_mask"``.
        source_ids: Not used by this method.
        meta_data: Not used by this method.
        **kwargs: Not used by this method.

    Returns:
        Tuple ``(cache, cache_len)``: the sliced hidden states and a one-element
        int32 tensor holding ``stop - start``.
    """
    # tts related inference, require the kv cache of llm last layer for only the current inputs
    # TODO: select kv cache of the current turn inputs
    # The batch's attention mask is read here but the LLM is called without it.
    unused_attention_mask = batch.get("attention_mask", None)

    llm_out = self.llm(
        inputs_embeds=inputs_embeds,
        attention_mask=None,
        labels=None,
    )
    last_hidden = llm_out.hidden_states[-1].float()

    cache, cache_len = None, None

    # Keep only the last row of each mask and add a leading axis of size 1.
    # NOTE(review): the masked assignments below write in place; if the indexing
    # returns views, the tensors stored in ``batch`` are modified too — confirm
    # that callers expect this.
    span_begs = batch.get("input_mask_beg")[-1][None, :]
    span_begs[span_begs < 0] = 0
    span_ends = batch.get("input_mask")[-1][None, :]
    span_ends[span_ends < 0] = 0

    for beg_row, end_row in zip(span_begs, span_ends):
        start = beg_row.sum(-1)
        stop = end_row.sum(-1)
        cache = last_hidden[:, start:stop, :]
        span_len = torch.tensor([stop - start], dtype=torch.int32)
        cache_len = span_len.to(inputs_embeds.device)

    return cache, cache_len
def generate_speech(self, text, llm_cur_kv_cache, llm_cur_kv_cache_len, llm_dtype):
# self.tts_text_tokenizer = self.tts_text_tokenizer
self.vocoder.to(llm_dtype)
@ -3058,12 +3090,12 @@ class LLMASRXvecSlotTTS(nn.Module):
def convert_wav_to_mp3(self, wav: torch.Tensor):
wav = wav.detach().cpu().numpy()
wav = (wav * (2**15-1) * 0.8).astype(np.int16)
wav = (wav * (2**15 - 1) * 0.8).astype(np.int16)
mp3 = AudioSegment(
wav.tobytes(),
sample_width=16 // 8, # Sample width in bytes
frame_rate=22050,
channels=1
channels=1,
)
mp3_buffer = BytesIO()
mp3.export(mp3_buffer, format="mp3", bitrate="48k")
@ -3124,9 +3156,61 @@ class LLMASRXvecSlotTTS(nn.Module):
speech_tokens = torch.cat(token_list, dim=1)
mel_feats = torch.cat(feat_list, dim=2)
wav = torch.cat(wav_list, dim=1)
mp3 = b''.join(mp3_list)
mp3 = b"".join(mp3_list)
return speech_tokens, mel_feats, wav, mp3
def reset_generate_states(self, states=None):
    """Reset the streaming speech-generation state to its initial values.

    Args:
        states: Dict to reset in place.  When ``None`` (the default) a fresh
            dict is created, so separate calls never share state through a
            mutable default argument.

    Returns:
        The reset state dict.  This is the same object as ``states`` when one
        was supplied, so existing callers that rely on in-place mutation and
        ignore the return value keep working.
    """
    if states is None:
        states = {}
    states["new_text"] = ""
    states["last_t_size"] = 0
    states["prompt_token"] = [None, None]
    states["prompt_audio"] = [None, None]
    states["chunk_idx"] = 0
    # Returned so that the dict created for ``states=None`` is reachable by the caller.
    return states
def streaming_generate_speech(
    self,
    preds,
    states,
    llm_cur_kv_cache,
    llm_cur_kv_cache_len,
    is_last=False,
    text_chunk_size=8,
    format="mp3",
):
    """Run one streaming generation step and write the new state into ``states``.

    ``preds`` is appended to ``states["new_text"]`` and the result is passed,
    with the remaining state fields, to ``self.generate_speech_one_step``.
    The state tuple that call returns is stored back into ``states``.

    Args:
        preds: Newly predicted text to append to the accumulated text.
        states: Dict with keys ``"new_text"``, ``"last_t_size"``,
            ``"prompt_token"``, ``"prompt_audio"`` and ``"chunk_idx"``;
            updated in place.
        llm_cur_kv_cache: Forwarded unchanged to ``generate_speech_one_step``.
        llm_cur_kv_cache_len: Forwarded unchanged to ``generate_speech_one_step``.
        is_last: Forwarded unchanged to ``generate_speech_one_step``.
        text_chunk_size: Forwarded unchanged to ``generate_speech_one_step``.
        format: When equal to ``"mp3"`` and a token was produced, the waveform
            is passed through ``self.convert_wav_to_mp3`` before returning.

    Returns:
        Tuple ``(cur_token, feat, wav)`` where ``wav`` is the converted output
        in the mp3 case described above.
    """
    state_keys = ("new_text", "last_t_size", "prompt_token", "prompt_audio", "chunk_idx")
    text_so_far, t_size, tok_prompt, audio_prompt, step_idx = (states[k] for k in state_keys)

    text_so_far = text_so_far + preds
    step_outputs, step_states = self.generate_speech_one_step(
        text_so_far,
        t_size,
        llm_cur_kv_cache,
        llm_cur_kv_cache_len,
        tok_prompt,
        audio_prompt,
        text_chunk_size,
        step_idx,
        is_last,
    )
    cur_token, feat, wav = step_outputs
    text_so_far, t_size, tok_prompt, audio_prompt, step_idx = step_states

    # Persist the updated state for the next streaming call.
    updated = (text_so_far, t_size, tok_prompt, audio_prompt, step_idx)
    for key, value in zip(state_keys, updated):
        states[key] = value

    # Conversion is skipped when this step produced no token.
    if format == "mp3" and cur_token is not None:
        wav = self.convert_wav_to_mp3(wav)

    return cur_token, feat, wav
def write_mel_wav(self, output_dir, feat, wav, mp3, key):
out_dir = os.path.join(output_dir, "1best_recog", "mels")
os.makedirs(out_dir, exist_ok=True)