Merge branch 'dev_gzf_deepspeed' of gitlab.alibaba-inc.com:zhifu.gzf/FunASR into dev_gzf_deepspeed

merge
commit 89c1dd5f08
Author: 游雁
Date: 2024-09-13 14:36:24 +08:00
4 changed files with 68 additions and 38 deletions

View File

@@ -5,6 +5,7 @@ from funasr.models.llm_asr.diffusion_models.matcha_decoder import (Decoder, Cond
import logging
from funasr.utils.hinter import hint_once
import time
from funasr.train_utils.set_all_random_seed import set_all_random_seed
class BASECFM(torch.nn.Module, ABC):
@@ -47,6 +48,7 @@ class BASECFM(torch.nn.Module, ABC):
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
set_all_random_seed(0)
z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
if self.t_scheduler == 'cosine':
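Both hunks in this file serve one purpose: the flow-matching sampler draws its initial noise with torch.randn_like(mu), so seeding the RNG immediately before that call makes inference reproducible run to run (and chunk to chunk in streaming mode). A minimal sketch of what a set_all_random_seed helper typically does; the actual body lives in funasr/train_utils/set_all_random_seed.py and may differ:

import random

import numpy as np
import torch

def set_all_random_seed(seed: int):
    # Seed every RNG that inference might touch, so the CFM
    # sampler's initial noise z is identical on every call.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)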

View File

@@ -3005,9 +3005,10 @@ class LLMASRXvecSlotTTS(nn.Module):
text_token = text_token[:-1]
return text_token
@torch.no_grad()
def generate_speech_one_step(
self,
text: str,
text: str, preds: str,
last_t_size,
llm_cur_kv_cache,
llm_cur_kv_cache_len,
@@ -3016,24 +3017,40 @@
tts_text_chunk_size,
chunk_idx,
is_last,
para_len=30,
para_phone_len=200,
):
device = llm_cur_kv_cache.device
pounc = ["。", "？", "！", "；", "，", ".", "?", "!", ";", "\n"]
# remove duplicated punctuation
normed_text = []
for i, c in enumerate(text):
if i > 0 and text[i - 1] in pounc and text[i] in pounc:
normed_preds = []
for i, c in enumerate(preds):
if i > 0 and preds[i - 1] in pounc and preds[i] in pounc:
continue
normed_text.append(c)
text = "".join(normed_text)
normed_preds.append(c)
normed_preds = "".join(normed_preds)
idx = -1
for p in pounc:
str_idx = normed_preds.find(p)
if str_idx > -1:
preds = self.split_characters_and_words(normed_preds[:str_idx])
idx = len(preds)
preds.append(normed_preds[str_idx])
preds.extend(self.split_characters_and_words(normed_preds[str_idx+1:]))
break
_text = f"<|endofprompt|><|sil|>{text+normed_preds}" + ("<|sil|>" if is_last else "")
para_end = False
if idx > -1 and not is_last:
pre_part = "".join(preds[:idx+1])
if len(self.tts_tokenizer_warpper(text+pre_part)) >= para_phone_len:
_text = f"<|endofprompt|><|sil|>{text+pre_part}<|sil|>"
para_end = True
cur_token, feat, wav = None, None, None
_text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
text_token = self.tts_tokenizer_warpper(_text)
t_size = len(text_token)
if (t_size - last_t_size) >= tts_text_chunk_size or is_last:
if (t_size - last_t_size) >= tts_text_chunk_size or is_last or para_end:
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
cur_token, feat = self.tts_model.streaming_one_step(
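This hunk changes when a TTS step fires: besides accumulating tts_text_chunk_size new text tokens or hitting the end of the stream, synthesis is now also forced when the text up to the first punctuation mark in preds reaches para_phone_len phonemes, which closes the current paragraph (para_end). A condensed sketch of that decision, assuming tts_tokenizer_warpper acts as a text-to-phoneme-token callable:

def should_flush(text, pre_part, t_size, last_t_size,
                 tts_text_chunk_size, is_last, tokenizer,
                 para_phone_len=200):
    # Force a paragraph break once the text up to (and including)
    # the first punctuation mark already spans para_phone_len phonemes.
    para_end = (not is_last) and len(tokenizer(text + pre_part)) >= para_phone_len
    # Synthesize when a full chunk has accumulated, the stream
    # has ended, or the current paragraph has grown too long.
    flush = (t_size - last_t_size) >= tts_text_chunk_size or is_last or para_end
    return flush, para_end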
@@ -3072,19 +3089,17 @@ class LLMASRXvecSlotTTS(nn.Module):
cur_token, feat, wav = None, None, None
# post process
if not para_end:
last_t_size = t_size
# restart a new paragraph
# char_words = self.split_characters_and_words(text)
# if len(char_words) > para_len:
# # find the last pounc to split paragraph
# idx = -1
# for i in range(len(char_words)-1, -1, -1):
# if char_words[i] in pounc:
# idx = i
# break
# if idx > 0:
# text = text[idx+1:]
# last_t_size = len(self.tts_tokenizer_warpper(text))
if para_end:
text = "".join(preds[idx + 1:])
last_t_size = 0
prompt_token, prompt_audio = [None, None], [None, None]
wav = torch.cat([wav, torch.zeros([1, 2205]).to(wav)], dim=1)
chunk_idx = 0
else:
text = text + normed_preds
return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
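On a paragraph boundary (para_end) the state is reset: the text buffer keeps only the remainder after the splitting punctuation, the prompt caches are cleared, chunk_idx restarts at 0, and 2205 samples of silence (0.1 s if the vocoder runs at 22.05 kHz) are appended to mark the break; otherwise the new predictions are simply appended to the rolling text. A hypothetical driver loop consuming the returned tuple, loosely mirroring the call sites below:

def stream_tts(model, text_stream, **kv_cache_kwargs):
    # text_stream yields (incremental_text, is_last) pairs from the LLM.
    states = dict(text="", last_t_size=0, prompt_token=[None, None],
                  prompt_audio=[None, None], chunk_idx=0)
    for _resp, is_last in text_stream:
        (cur_token, feat, wav), packed = model.generate_speech_one_step(
            states["text"], _resp, states["last_t_size"],
            is_last=is_last, **kv_cache_kwargs)
        (states["text"], states["last_t_size"], states["prompt_token"],
         states["prompt_audio"], states["chunk_idx"]) = packed
        if wav is not None:
            yield wav  # hand the synthesized chunk to the client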
@@ -3098,7 +3113,7 @@ class LLMASRXvecSlotTTS(nn.Module):
channels=1,
)
mp3_buffer = BytesIO()
mp3.export(mp3_buffer, format="mp3", bitrate="48k")
mp3.export(mp3_buffer, format="mp3", bitrate="192k")
# we should return this to the web page.
mp3_bytes_data = mp3_buffer.getvalue()
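The MP3 bitrate goes up from 48 kbit/s to 192 kbit/s. For reference, a self-contained sketch of the pydub path from a mono float waveform to MP3 bytes; the 22.05 kHz rate is an assumption, the real code uses whatever rate the vocoder emits:

from io import BytesIO

import numpy as np
from pydub import AudioSegment

def wav_to_mp3_bytes(wav: np.ndarray, sample_rate: int = 22050) -> bytes:
    # Float waveform in [-1, 1] -> 16-bit PCM.
    pcm = (np.clip(wav, -1.0, 1.0) * 32767.0).astype(np.int16)
    seg = AudioSegment(
        data=pcm.tobytes(),
        sample_width=2,  # bytes per sample (16-bit)
        frame_rate=sample_rate,
        channels=1,
    )
    buf = BytesIO()
    seg.export(buf, format="mp3", bitrate="192k")
    return buf.getvalue()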
@@ -3127,9 +3142,9 @@ class LLMASRXvecSlotTTS(nn.Module):
)[0]
is_last = st + chunk_size >= preds.shape[1]
new_text = new_text + _resp
# new_text = new_text + _resp
rt_value, states = self.generate_speech_one_step(
new_text,
new_text, _resp,
last_t_size,
llm_cur_kv_cache,
llm_cur_kv_cache_len,
@@ -3187,9 +3202,10 @@ class LLMASRXvecSlotTTS(nn.Module):
states["prompt_audio"],
states["chunk_idx"],
)
new_text = new_text + preds
# new_text = new_text + preds
with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
rt_value, states_ret = self.generate_speech_one_step(
new_text,
new_text, preds,
last_t_size,
llm_cur_kv_cache,
llm_cur_kv_cache_len,
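Note the autocast wrapper around this call site: mixed precision is explicitly disabled so the flow-matching model runs in full float32 even if the surrounding LLM decodes under an AMP context, since iterative samplers accumulate half-precision rounding error. The pattern in isolation, as a sketch:

import torch

@torch.no_grad()
def fp32_inference(fm_model, *args, **kwargs):
    # Escape any enclosing mixed-precision context; CFM/diffusion
    # sampling is sensitive to float16 rounding error.
    with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
        return fm_model.inference(*args, **kwargs)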

View File

@@ -10,6 +10,7 @@ from funasr.utils.hinter import hint_once
from funasr.models.transformer.utils.nets_utils import pad_list
import numpy as np
import random
from funasr.train_utils.set_all_random_seed import set_all_random_seed
def norm_and_sample_xvec(xvec, xvec_lengths):
@@ -335,6 +336,7 @@ class UpsampleCtcTokenDiffModel(nn.Module):
)
# forward FM model
set_all_random_seed(0)
feat = self.fm_model.inference(
aligned_token_emb, aligned_token_lens,
prompt=dict(
@@ -1001,6 +1003,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
cur_token_len = cur_token_len - token_hop_len
# forward FM model
set_all_random_seed(0)
feat = self.fm_model.inference(
cur_token, cur_token_len,
xvec, xvec_lengths,
@@ -1011,7 +1014,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
**kwargs,
)
feat = self.rms_rescale_feat(feat)
print_token = tokens.cpu().squeeze().tolist()
print_token = tokens.cpu().squeeze(0).squeeze(-1).tolist()
logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")

View File

@@ -612,6 +612,15 @@ async def ws_serve(websocket, path):
websocket.streaming_state["previous_asr_text"] = ""
websocket.streaming_state["previous_s2tt_text"] = ""
if not websocket.is_speaking:
if is_alpha_ending(websocket.streaming_state["onscreen_asr_res"]):
websocket.streaming_state["onscreen_asr_res"] += "."
elif is_chinese_ending(websocket.streaming_state["onscreen_asr_res"]):
websocket.streaming_state["onscreen_asr_res"] += ""
if is_alpha_ending(websocket.streaming_state["onscreen_s2tt_res"]):
websocket.streaming_state["onscreen_s2tt_res"] += "."
elif is_chinese_ending(websocket.streaming_state["onscreen_s2tt_res"]):
websocket.streaming_state["onscreen_s2tt_res"] += ""
message = json.dumps(
{
"mode": "online",