mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'dev_gzf_deepspeed' of http://gitlab.alibaba-inc.com/zhifu.gzf/FunASR into dev_gzf_deepspeed
This commit is contained in:
commit
4ca208e061
@ -28,6 +28,8 @@ from funasr.train_utils.device_funcs import to_device
|
|||||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
|
||||||
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
||||||
import traceback
|
import traceback
|
||||||
|
from pydub import AudioSegment
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -2790,7 +2792,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
] = speech_token
|
] = speech_token
|
||||||
|
|
||||||
speech_idx += 1
|
speech_idx += 1
|
||||||
return inputs_embeds, contents, batch, source_ids, meta_data, output
|
return inputs_embeds, contents, batch, source_ids, meta_data
|
||||||
|
|
||||||
def inference(
|
def inference(
|
||||||
self,
|
self,
|
||||||
@ -2802,7 +2804,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
inputs_embeds, contents, batch, source_ids, meta_data, outputs = self.inference_prepare(
|
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
|
||||||
data_in, data_lengths, key, tokenizer, frontend, **kwargs
|
data_in, data_lengths, key, tokenizer, frontend, **kwargs
|
||||||
)
|
)
|
||||||
rand_seed = kwargs.get("rand_seed", 0)
|
rand_seed = kwargs.get("rand_seed", 0)
|
||||||
@ -2926,10 +2928,10 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
# speech_tokens, mel, wav = self.generate_speech(
|
# speech_tokens, mel, wav = self.generate_speech(
|
||||||
# response, llm_cur_kv_cache, llm_cur_kv_cache_len, dtype_map[tts_dtype]
|
# response, llm_cur_kv_cache, llm_cur_kv_cache_len, dtype_map[tts_dtype]
|
||||||
# )
|
# )
|
||||||
speech_tokens, mel, wav = self.simulate_streaming_generate_speech(
|
speech_tokens, mel, wav, mp3 = self.simulate_streaming_generate_speech(
|
||||||
target_ids, llm_cur_kv_cache, llm_cur_kv_cache_len, dtype_map[tts_dtype], tokenizer
|
target_ids, llm_cur_kv_cache, llm_cur_kv_cache_len, dtype_map[tts_dtype], tokenizer
|
||||||
)
|
)
|
||||||
self.write_mel_wav(kwargs.get("output_dir"), mel, wav, key[0])
|
self.write_mel_wav(kwargs.get("output_dir"), mel, wav, mp3, key[0])
|
||||||
|
|
||||||
return results, meta_data
|
return results, meta_data
|
||||||
|
|
||||||
@ -2959,7 +2961,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
|
|
||||||
def split_characters_and_words(self, input_string):
|
def split_characters_and_words(self, input_string):
|
||||||
# 定义正则表达式模式
|
# 定义正则表达式模式
|
||||||
pattern = r'[\u4e00-\u9fff]|[\w]+|[^\w\s]'
|
pattern = r"[\u4e00-\u9fff]|[\w]+|[^\w\s]"
|
||||||
# 使用 re.findall 找到所有匹配的字符和单词
|
# 使用 re.findall 找到所有匹配的字符和单词
|
||||||
results = re.findall(pattern, input_string)
|
results = re.findall(pattern, input_string)
|
||||||
return results
|
return results
|
||||||
@ -2973,13 +2975,19 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
|
|
||||||
def generate_speech_one_step(
|
def generate_speech_one_step(
|
||||||
self,
|
self,
|
||||||
text: str, last_t_size,
|
text: str,
|
||||||
llm_cur_kv_cache, llm_cur_kv_cache_len,
|
last_t_size,
|
||||||
prompt_token, prompt_audio, tts_text_chunk_size,
|
llm_cur_kv_cache,
|
||||||
chunk_idx, is_last, para_len=30,
|
llm_cur_kv_cache_len,
|
||||||
|
prompt_token,
|
||||||
|
prompt_audio,
|
||||||
|
tts_text_chunk_size,
|
||||||
|
chunk_idx,
|
||||||
|
is_last,
|
||||||
|
para_len=30,
|
||||||
):
|
):
|
||||||
device = llm_cur_kv_cache.device
|
device = llm_cur_kv_cache.device
|
||||||
pounc = ['。', '?', '!', ';', ':', '.', '?', '!', ';', '\n']
|
pounc = ["。", "?", "!", ";", ":", ".", "?", "!", ";", "\n"]
|
||||||
|
|
||||||
# remove duplicated pounctuations
|
# remove duplicated pounctuations
|
||||||
normed_text = []
|
normed_text = []
|
||||||
@ -2997,8 +3005,10 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
|
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
|
||||||
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
|
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
|
||||||
cur_token, feat = self.tts_model.streaming_one_step(
|
cur_token, feat = self.tts_model.streaming_one_step(
|
||||||
text_token, text_token_len,
|
text_token,
|
||||||
xvec=None, xvec_lengths=None,
|
text_token_len,
|
||||||
|
xvec=None,
|
||||||
|
xvec_lengths=None,
|
||||||
prompt_dict={
|
prompt_dict={
|
||||||
"prompt_token": prompt_token,
|
"prompt_token": prompt_token,
|
||||||
"prompt_audio": prompt_audio,
|
"prompt_audio": prompt_audio,
|
||||||
@ -3011,8 +3021,14 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
|
if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
|
||||||
# process first package, token in B,T,D, feat in B,F,T
|
# process first package, token in B,T,D, feat in B,F,T
|
||||||
if prompt_token[0] is None:
|
if prompt_token[0] is None:
|
||||||
prompt_token = [cur_token, torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device)]
|
prompt_token = [
|
||||||
prompt_audio = [feat.transpose(1, 2), torch.tensor([feat.shape[2]], dtype=torch.long, device=device)]
|
cur_token,
|
||||||
|
torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device),
|
||||||
|
]
|
||||||
|
prompt_audio = [
|
||||||
|
feat.transpose(1, 2),
|
||||||
|
torch.tensor([feat.shape[2]], dtype=torch.long, device=device),
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
prompt_token[1] = prompt_token[1] + cur_token.shape[1]
|
prompt_token[1] = prompt_token[1] + cur_token.shape[1]
|
||||||
prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1)
|
prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1)
|
||||||
@ -3038,10 +3054,27 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
# text = text[idx+1:]
|
# text = text[idx+1:]
|
||||||
# last_t_size = len(self.tts_tokenizer_warpper(text))
|
# last_t_size = len(self.tts_tokenizer_warpper(text))
|
||||||
|
|
||||||
return ((cur_token, feat, wav),
|
return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
|
||||||
(text, last_t_size, prompt_token, prompt_audio, chunk_idx))
|
|
||||||
|
|
||||||
def simulate_streaming_generate_speech(self, preds, llm_cur_kv_cache, llm_cur_kv_cache_len, llm_dtype, llm_tokenizer):
|
def convert_wav_to_mp3(self, wav: torch.Tensor):
|
||||||
|
wav = wav.detach().cpu().numpy()
|
||||||
|
wav = (wav * (2**15-1) * 0.8).astype(np.int16)
|
||||||
|
mp3 = AudioSegment(
|
||||||
|
wav.tobytes(),
|
||||||
|
sample_width=16 // 8, # Sample width in bytes
|
||||||
|
frame_rate=22050,
|
||||||
|
channels=1
|
||||||
|
)
|
||||||
|
mp3_buffer = BytesIO()
|
||||||
|
mp3.export(mp3_buffer, format="mp3", bitrate="48k")
|
||||||
|
# we should return this to web page.
|
||||||
|
mp3_bytes_data = mp3_buffer.getvalue()
|
||||||
|
|
||||||
|
return mp3_bytes_data
|
||||||
|
|
||||||
|
def simulate_streaming_generate_speech(
|
||||||
|
self, preds, llm_cur_kv_cache, llm_cur_kv_cache_len, llm_dtype, llm_tokenizer
|
||||||
|
):
|
||||||
# self.tts_text_tokenizer = self.tts_text_tokenizer
|
# self.tts_text_tokenizer = self.tts_text_tokenizer
|
||||||
self.vocoder.to(llm_dtype)
|
self.vocoder.to(llm_dtype)
|
||||||
self.tts_model.to(llm_dtype)
|
self.tts_model.to(llm_dtype)
|
||||||
@ -3049,7 +3082,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
text_chunk_size = 8
|
text_chunk_size = 8
|
||||||
given_rtf = 0.5
|
given_rtf = 0.5
|
||||||
|
|
||||||
token_list, feat_list, wav_list = [], [], []
|
token_list, feat_list, wav_list, mp3_list = [], [], [], []
|
||||||
prompt_token, prompt_audio = [None, None], [None, None]
|
prompt_token, prompt_audio = [None, None], [None, None]
|
||||||
new_text, last_t_size, chunk_idx = "", 0, 0
|
new_text, last_t_size, chunk_idx = "", 0, 0
|
||||||
st, count = 0, 0
|
st, count = 0, 0
|
||||||
@ -3060,15 +3093,19 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
add_special_tokens=False,
|
add_special_tokens=False,
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
)[0]
|
)[0]
|
||||||
is_last = (st + chunk_size >= preds.shape[1])
|
is_last = st + chunk_size >= preds.shape[1]
|
||||||
|
|
||||||
new_text = new_text + _resp
|
new_text = new_text + _resp
|
||||||
rt_value, states = self.generate_speech_one_step(
|
rt_value, states = self.generate_speech_one_step(
|
||||||
new_text, last_t_size,
|
new_text,
|
||||||
llm_cur_kv_cache, llm_cur_kv_cache_len,
|
last_t_size,
|
||||||
prompt_token, prompt_audio,
|
llm_cur_kv_cache,
|
||||||
|
llm_cur_kv_cache_len,
|
||||||
|
prompt_token,
|
||||||
|
prompt_audio,
|
||||||
text_chunk_size,
|
text_chunk_size,
|
||||||
chunk_idx, is_last,
|
chunk_idx,
|
||||||
|
is_last,
|
||||||
)
|
)
|
||||||
cur_token, feat, wav = rt_value
|
cur_token, feat, wav = rt_value
|
||||||
new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states
|
new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states
|
||||||
@ -3076,7 +3113,10 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
if cur_token is not None and feat is not None and wav is not None:
|
if cur_token is not None and feat is not None and wav is not None:
|
||||||
token_list.append(cur_token)
|
token_list.append(cur_token)
|
||||||
feat_list.append(feat)
|
feat_list.append(feat)
|
||||||
|
# we should return this data to web page for playing.
|
||||||
|
mp3_data = self.convert_wav_to_mp3(wav)
|
||||||
wav_list.append(wav)
|
wav_list.append(wav)
|
||||||
|
mp3_list.append(mp3_data)
|
||||||
|
|
||||||
st += chunk_size
|
st += chunk_size
|
||||||
count += 1
|
count += 1
|
||||||
@ -3084,9 +3124,10 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
speech_tokens = torch.cat(token_list, dim=1)
|
speech_tokens = torch.cat(token_list, dim=1)
|
||||||
mel_feats = torch.cat(feat_list, dim=2)
|
mel_feats = torch.cat(feat_list, dim=2)
|
||||||
wav = torch.cat(wav_list, dim=1)
|
wav = torch.cat(wav_list, dim=1)
|
||||||
return speech_tokens, mel_feats, wav
|
mp3 = b''.join(mp3_list)
|
||||||
|
return speech_tokens, mel_feats, wav, mp3
|
||||||
|
|
||||||
def write_mel_wav(self, output_dir, feat, wav, key):
|
def write_mel_wav(self, output_dir, feat, wav, mp3, key):
|
||||||
out_dir = os.path.join(output_dir, "1best_recog", "mels")
|
out_dir = os.path.join(output_dir, "1best_recog", "mels")
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
if feat is not None:
|
if feat is not None:
|
||||||
@ -3104,6 +3145,11 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
encoding="PCM_S",
|
encoding="PCM_S",
|
||||||
bits_per_sample=16,
|
bits_per_sample=16,
|
||||||
)
|
)
|
||||||
|
if mp3 is not None:
|
||||||
|
path = os.path.join(out_dir, f"{key}.mp3")
|
||||||
|
fd = open(path, "wb")
|
||||||
|
fd.write(mp3)
|
||||||
|
fd.close()
|
||||||
|
|
||||||
|
|
||||||
class Swish(torch.nn.Module):
|
class Swish(torch.nn.Module):
|
||||||
|
|||||||
@ -80,6 +80,7 @@ class NlsTtsSynthesizer:
|
|||||||
self.started = True
|
self.started = True
|
||||||
|
|
||||||
def send_text(self, text):
|
def send_text(self, text):
|
||||||
|
if len(text) > 0:
|
||||||
self.sdk.sendStreamInputTts(text)
|
self.sdk.sendStreamInputTts(text)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user