Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
Merge branch 'dev_gzf_deepspeed' of gitlab.alibaba-inc.com:zhifu.gzf/FunASR into dev_gzf_deepspeed

Commit message: merge
Commit: 89c1dd5f08
@@ -5,6 +5,7 @@ from funasr.models.llm_asr.diffusion_models.matcha_decoder import (Decoder, Cond
 import logging
 from funasr.utils.hinter import hint_once
 import time
+from funasr.train_utils.set_all_random_seed import set_all_random_seed


 class BASECFM(torch.nn.Module, ABC):
@@ -47,6 +48,7 @@ class BASECFM(torch.nn.Module, ABC):
             sample: generated mel-spectrogram
                 shape: (batch_size, n_feats, mel_timesteps)
         """
+        set_all_random_seed(0)
         z = torch.randn_like(mu) * temperature
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
         if self.t_scheduler == 'cosine':
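Note: this hunk (and the matching ones in the UpsampleCtcTokenDiffModel/UCTDXvecSlotModel file further down) pins the sampler's randomness by reseeding right before the initial noise is drawn, which makes flow-matching inference reproducible across calls. A minimal sketch of the idea; the body of set_all_random_seed is assumed here, not copied from FunASR:

    import random

    import numpy as np
    import torch

    def set_all_random_seed(seed: int):
        # Assumed behavior: seed every RNG that inference may touch.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # With the seed fixed, the initial noise z is identical on every call,
    # so the ODE solve from z to the mel-spectrogram is deterministic.
    set_all_random_seed(0)
    mu = torch.zeros(1, 80, 100)        # stand-in for the conditioning mean
    z1 = torch.randn_like(mu) * 1.0     # temperature = 1.0
    set_all_random_seed(0)
    z2 = torch.randn_like(mu) * 1.0
    assert torch.equal(z1, z2)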
@@ -3005,9 +3005,10 @@ class LLMASRXvecSlotTTS(nn.Module):
         text_token = text_token[:-1]
         return text_token

     @torch.no_grad()
     def generate_speech_one_step(
         self,
-        text: str,
+        text: str, preds: str,
         last_t_size,
         llm_cur_kv_cache,
         llm_cur_kv_cache_len,
@@ -3016,24 +3017,40 @@ class LLMASRXvecSlotTTS(nn.Module):
         tts_text_chunk_size,
         chunk_idx,
         is_last,
+        para_len=30,
+        para_phone_len=200,
     ):
         device = llm_cur_kv_cache.device
         pounc = ["。", "?", "!", ";", ":", ".", "?", "!", ";", "\n"]

         # remove duplicated punctuations
-        normed_text = []
-        for i, c in enumerate(text):
-            if i > 0 and text[i - 1] in pounc and text[i] in pounc:
+        normed_preds = []
+        for i, c in enumerate(preds):
+            if i > 0 and preds[i - 1] in pounc and preds[i] in pounc:
                 continue
-            normed_text.append(c)
-        text = "".join(normed_text)
+            normed_preds.append(c)
+        normed_preds = "".join(normed_preds)
+        idx = -1
+        for p in pounc:
+            str_idx = normed_preds.find(p)
+            if str_idx > -1:
+                preds = self.split_characters_and_words(normed_preds[:str_idx])
+                idx = len(preds)
+                preds.append(normed_preds[str_idx])
+                preds.extend(self.split_characters_and_words(normed_preds[str_idx+1:]))
+                break
+
+        _text = f"<|endofprompt|><|sil|>{text+normed_preds}" + ("<|sil|>" if is_last else "")
+        para_end = False
+        if idx > -1 and not is_last:
+            pre_part = "".join(preds[:idx+1])
+            if len(self.tts_tokenizer_warpper(text+pre_part)) >= para_phone_len:
+                _text = f"<|endofprompt|><|sil|>{text+pre_part}<|sil|>"
+                para_end = True

         cur_token, feat, wav = None, None, None
-        _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
         text_token = self.tts_tokenizer_warpper(_text)
         t_size = len(text_token)
-        if (t_size - last_t_size) >= tts_text_chunk_size or is_last:
+        if (t_size - last_t_size) >= tts_text_chunk_size or is_last or para_end:
            text_token = torch.tensor([text_token], dtype=torch.long, device=device)
            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
            cur_token, feat = self.tts_model.streaming_one_step(
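Note: the rewritten block normalizes the incremental LLM prediction preds (collapsing runs of punctuation) and looks for the first sentence-final mark, so a long paragraph can be flushed early once text plus the pre-punctuation part exceeds para_phone_len tokens. A standalone sketch of the split step; split_characters_and_words is simplified here to per-character splitting (the real helper presumably keeps latin words intact):

    PUNC = ["。", "?", "!", ";", ":", ".", "?", "!", ";", "\n"]

    def split_characters_and_words(s):
        # Simplified stand-in for the method used in the diff.
        return list(s)

    def dedup_and_split(preds):
        # Collapse duplicated punctuation, then split at the first
        # sentence-final mark found (searched in PUNC order, as above).
        normed = []
        for i, c in enumerate(preds):
            if i > 0 and preds[i - 1] in PUNC and c in PUNC:
                continue
            normed.append(c)
        normed = "".join(normed)
        pieces, idx = split_characters_and_words(normed), -1
        for p in PUNC:
            str_idx = normed.find(p)
            if str_idx > -1:
                pieces = split_characters_and_words(normed[:str_idx])
                idx = len(pieces)
                pieces.append(normed[str_idx])
                pieces.extend(split_characters_and_words(normed[str_idx + 1:]))
                break
        return normed, pieces, idx

    print(dedup_and_split("你好。。世界!"))
    # ('你好。世界!', ['你', '好', '。', '世', '界', '!'], 2)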
@@ -3072,19 +3089,17 @@ class LLMASRXvecSlotTTS(nn.Module):
             cur_token, feat, wav = None, None, None

         # post process
-        last_t_size = t_size
+        if not para_end:
+            last_t_size = t_size
         # restart a new paragraph
-        # char_words = self.split_characters_and_words(text)
-        # if len(char_words) > para_len:
-        #     # find the last pounc to split paragraph
-        #     idx = -1
-        #     for i in range(len(char_words)-1, -1, -1):
-        #         if char_words[i] in pounc:
-        #             idx = i
-        #             break
-        #     if idx > 0:
-        #         text = text[idx+1:]
-        #         last_t_size = len(self.tts_tokenizer_warpper(text))
+        if para_end:
+            text = "".join(preds[idx + 1:])
+            last_t_size = 0
+            prompt_token, prompt_audio = [None, None], [None, None]
+            wav = torch.cat([wav, torch.zeros([1, 2205]).to(wav)], dim=1)
+            chunk_idx = 0
+        else:
+            text = text + normed_preds

         return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
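Note: when para_end fires, the finished paragraph is dropped from the rolling text, the prompt slots and counters are reset, and the last waveform chunk is padded with 2205 zero samples; assuming 22,050 Hz output audio (not stated in the diff), that is a 100 ms silence gap between paragraphs. A quick check of the padding arithmetic:

    import torch

    wav = torch.randn(1, 8000)            # stand-in for the last waveform chunk
    pad = torch.zeros([1, 2205]).to(wav)  # matches wav's dtype and device
    wav = torch.cat([wav, pad], dim=1)
    print(wav.shape, 2205 / 22050)        # torch.Size([1, 10205]) 0.1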
@@ -3098,7 +3113,7 @@ class LLMASRXvecSlotTTS(nn.Module):
             channels=1,
         )
         mp3_buffer = BytesIO()
-        mp3.export(mp3_buffer, format="mp3", bitrate="48k")
+        mp3.export(mp3_buffer, format="mp3", bitrate="192k")
         # we should return this to web page.
         mp3_bytes_data = mp3_buffer.getvalue()
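Note: the MP3 export bitrate is raised from 48 kbit/s to 192 kbit/s, trading bandwidth for audio quality on the web client. A minimal pydub sketch of the same export path (the AudioSegment parameters are assumed from the context lines; ffmpeg must be on PATH for MP3 export):

    from io import BytesIO

    from pydub import AudioSegment

    mp3 = AudioSegment(
        data=b"\x00\x00" * 22050,  # one second of silent 16-bit PCM as a stand-in
        sample_width=2,
        frame_rate=22050,          # assumed sample rate, not shown in the diff
        channels=1,
    )
    mp3_buffer = BytesIO()
    mp3.export(mp3_buffer, format="mp3", bitrate="192k")  # was "48k"
    mp3_bytes_data = mp3_buffer.getvalue()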
@@ -3127,9 +3142,9 @@ class LLMASRXvecSlotTTS(nn.Module):
             )[0]
             is_last = st + chunk_size >= preds.shape[1]

-            new_text = new_text + _resp
+            # new_text = new_text + _resp
             rt_value, states = self.generate_speech_one_step(
-                new_text,
+                new_text, _resp,
                 last_t_size,
                 llm_cur_kv_cache,
                 llm_cur_kv_cache_len,
@@ -3187,9 +3202,10 @@ class LLMASRXvecSlotTTS(nn.Module):
                 states["prompt_audio"],
                 states["chunk_idx"],
             )
-            new_text = new_text + preds
+            # new_text = new_text + preds
+            with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
                 rt_value, states_ret = self.generate_speech_one_step(
-                    new_text,
+                    new_text, preds,
                     last_t_size,
                     llm_cur_kv_cache,
                     llm_cur_kv_cache_len,
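Note: wrapping the TTS step in torch.cuda.amp.autocast(enabled=False, dtype=torch.float32) shields it from any mixed-precision context opened higher in the stack, so the flow-matching math runs in full float32. A small demonstration of the effect:

    import torch

    if torch.cuda.is_available():
        x = torch.randn(4, 4, device="cuda")
        with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16):
            y_amp = x @ x                    # autocast runs this in float16
            with torch.cuda.amp.autocast(enabled=False):
                y_fp32 = x @ x               # inner region stays in float32
        print(y_amp.dtype, y_fp32.dtype)     # torch.float16 torch.float32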
@@ -10,6 +10,7 @@ from funasr.utils.hinter import hint_once
 from funasr.models.transformer.utils.nets_utils import pad_list
 import numpy as np
 import random
+from funasr.train_utils.set_all_random_seed import set_all_random_seed


 def norm_and_sample_xvec(xvec, xvec_lengths):
@@ -335,6 +336,7 @@ class UpsampleCtcTokenDiffModel(nn.Module):
         )

         # forward FM model
+        set_all_random_seed(0)
         feat = self.fm_model.inference(
             aligned_token_emb, aligned_token_lens,
             prompt=dict(
@@ -1001,6 +1003,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
         cur_token_len = cur_token_len - token_hop_len

         # forward FM model
+        set_all_random_seed(0)
         feat = self.fm_model.inference(
             cur_token, cur_token_len,
             xvec, xvec_lengths,
@@ -1011,7 +1014,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
             **kwargs,
         )
         feat = self.rms_rescale_feat(feat)
-        print_token = tokens.cpu().squeeze().tolist()
+        print_token = tokens.cpu().squeeze(0).squeeze(-1).tolist()
         logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
                      f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")
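Note: the logging fix replaces the blanket .squeeze() with explicit .squeeze(0).squeeze(-1). With a batch of exactly one token, squeeze() also drops the time axis, so .tolist() returns a bare int and the slicing inside the logging call would fail. A small repro of the failure mode:

    import torch

    tokens = torch.tensor([[[7]]])  # shape [batch=1, time=1, 1]
    bad = tokens.cpu().squeeze().tolist()                # squeezes every dim -> 7
    good = tokens.cpu().squeeze(0).squeeze(-1).tolist()  # keeps the time axis -> [7]
    print(bad, good)  # 7 [7]
    # len(bad) raises TypeError; len(good) works as the logging call expects.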
@@ -612,6 +612,15 @@ async def ws_serve(websocket, path):
                 websocket.streaming_state["previous_asr_text"] = ""
                 websocket.streaming_state["previous_s2tt_text"] = ""
+                if not websocket.is_speaking:
+                    if is_alpha_ending(websocket.streaming_state["onscreen_asr_res"]):
+                        websocket.streaming_state["onscreen_asr_res"] += "."
+                    elif is_chinese_ending(websocket.streaming_state["onscreen_asr_res"]):
+                        websocket.streaming_state["onscreen_asr_res"] += "。"
+
+                    if is_alpha_ending(websocket.streaming_state["onscreen_s2tt_res"]):
+                        websocket.streaming_state["onscreen_s2tt_res"] += "."
+                    elif is_chinese_ending(websocket.streaming_state["onscreen_s2tt_res"]):
+                        websocket.streaming_state["onscreen_s2tt_res"] += "。"
                 message = json.dumps(
                     {
                         "mode": "online",
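Note: when the client stops speaking, the server closes any dangling on-screen ASR/S2TT text with a sentence-final mark: "." after latin script, "。" after CJK. is_alpha_ending and is_chinese_ending are not shown in this diff; the sketch below is a hypothetical reading of what they check, inferred only from their call sites:

    def is_alpha_ending(text):
        # Hypothetical: true when the text ends in an ASCII letter or digit.
        return bool(text) and text[-1].isascii() and text[-1].isalnum()

    def is_chinese_ending(text):
        # Hypothetical: true when the text ends in a CJK unified ideograph.
        return bool(text) and "\u4e00" <= text[-1] <= "\u9fff"

    res = "今天天气不错"
    if is_alpha_ending(res):
        res += "."
    elif is_chinese_ending(res):
        res += "。"
    print(res)  # 今天天气不错。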