mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf exp (#1593)
* update * update with main (#1582) * update * Expose the max_end_silence_time to the user (#1532) * update * update * update * update * update * update * update * update * update * finetune * finetune * finetune * finetune * finetune * finetune * fix: resolve IndexError when using spk model and the audio contains only 1 segment (#1535) * install requirements automatically * install requirements automatically * install requirements automatically * install requirements automatically * install requirements automatically * install requirements automatically * install requirements automatically * v1.0.19 * v1.0.19 * train * train * docs * update * update * update * update * update * update * update * train update * bugfix seg_dict_file * bugfix seg_dict_file * train * train * train (#1548) * Dev gzf new (#1551) * train * train * <funasr>: <punc online> (#1552) 1.修正添加标点时英文首单词和第二个单词被错误合并的问题。 Co-authored-by: carl.che <carl.che@cloudminds.com> * Dev gzf new (#1553) * train * train * train * train * train * train * train * train * Dev gzf new (#1554) * train * train * train * train * train * train * train * train * train * train * train * Dev gzf new (#1555) * train * train * train * train * train * train * train * train * train * train * train * train * train * 修正commit87b62d6895引入的英文整句标点预测导致末尾两个单词中间的空格被删除的问题。 (#1556) * <funasr>: <punc online> 1.修正添加标点时英文首单词和第二个单词被错误合并的问题。 * <funasr>: <punc online> 1.修正commit87b62d6895引入的英文整句标点预测导致末尾两个单词中间的空格被删除的问题。 --------- Co-authored-by: carl.che <carl.che@cloudminds.com> * Dev gzf new (#1557) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * Dev gzf new (#1559) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * Dev gzf new (#1561) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * Dev gzf new (#1562) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * Dev gzf new (#1567) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * whisper_lib for sense voice * aishell recipe * sense voice (#1568) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * whisper_lib for sense voice * aishell recipe * sense voice * Dev gzf new (#1574) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * whisper_lib for sense voice * aishell recipe * sense voice * docs * bugfix (#1580) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * whisper_lib for sense voice * aishell recipe * sense voice * docs * bugfix * v1.0.20 --------- Co-authored-by: BOBOTANG <tzfjobmail@gmail.com> Co-authored-by: Atomie CHEN <atomic_cwh@163.com> Co-authored-by: Carl <415692979@qq.com> Co-authored-by: carl.che <carl.che@cloudminds.com> * ctc * ctc * ctc * ctc * update with main (#1592) * update * Expose the max_end_silence_time to the user (#1532) * update * update * update * update * update * update * update * update * update * finetune * finetune * finetune * finetune * finetune * finetune * fix: resolve IndexError when using spk model and the audio contains only 1 segment (#1535) * install requirements automatically * install requirements automatically * install requirements automatically * install requirements automatically * install requirements automatically * install requirements automatically * install requirements automatically * v1.0.19 * v1.0.19 * train * train * docs * update * update * update * update * update * update * update * train update * bugfix seg_dict_file * bugfix seg_dict_file * train * train * train (#1548) * Dev gzf new (#1551) * train * train * <funasr>: <punc online> (#1552) 1.修正添加标点时英文首单词和第二个单词被错误合并的问题。 Co-authored-by: carl.che <carl.che@cloudminds.com> * Dev gzf new (#1553) * train * train * train * train * train * train * train * train * Dev gzf new (#1554) * train * train * train * train * train * train * train * train * train * train * train * Dev gzf new (#1555) * train * train * train * train * train * train * train * train * train * train * train * train * train * 修正commit87b62d6895引入的英文整句标点预测导致末尾两个单词中间的空格被删除的问题。 (#1556) * <funasr>: <punc online> 1.修正添加标点时英文首单词和第二个单词被错误合并的问题。 * <funasr>: <punc online> 1.修正commit87b62d6895引入的英文整句标点预测导致末尾两个单词中间的空格被删除的问题。 --------- Co-authored-by: carl.che <carl.che@cloudminds.com> * Dev gzf new (#1557) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * Dev gzf new (#1559) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * Dev gzf new (#1561) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * Dev gzf new (#1562) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * Dev gzf new (#1567) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * whisper_lib for sense voice * aishell recipe * sense voice (#1568) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * whisper_lib for sense voice * aishell recipe * sense voice * Dev gzf new (#1574) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * whisper_lib for sense voice * aishell recipe * sense voice * docs * bugfix (#1580) * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * train * whisper_lib for sense voice * aishell recipe * sense voice * docs * bugfix * v1.0.20 * update demo page (#1585) * commit web page vue * optimize web page * optimize web page * remove other private component * modify web page * Update index.vue * Update lxwjzxfw.vue * Update sstx.vue * update static file --------- Co-authored-by: BOBOTANG <tzfjobmail@gmail.com> Co-authored-by: Atomie CHEN <atomic_cwh@163.com> Co-authored-by: Carl <415692979@qq.com> Co-authored-by: carl.che <carl.che@cloudminds.com> Co-authored-by: bltcn <blt@tom.com> * sensevoice * sensevoice --------- Co-authored-by: BOBOTANG <tzfjobmail@gmail.com> Co-authored-by: Atomie CHEN <atomic_cwh@163.com> Co-authored-by: Carl <415692979@qq.com> Co-authored-by: carl.che <carl.che@cloudminds.com> Co-authored-by: bltcn <blt@tom.com>
This commit is contained in:
parent
c1a492a96e
commit
d19f48e174
@ -38,7 +38,13 @@ class WhisperFrontend(nn.Module):
|
||||
if whisper_model == "large-v3" or whisper_model == "large":
|
||||
self.n_mels = 128
|
||||
|
||||
self.mel_filters = whisper.audio.mel_filters
|
||||
filters_path = kwargs.get("filters_path", None)
|
||||
self.filters_path = filters_path
|
||||
if filters_path is not None:
|
||||
from funasr.models.sense_voice.whisper_lib.audio import mel_filters
|
||||
self.mel_filters = mel_filters
|
||||
else:
|
||||
self.mel_filters = whisper.audio.mel_filters
|
||||
self.do_pad_trim = do_pad_trim
|
||||
if do_pad_trim:
|
||||
self.pad_or_trim = whisper.pad_or_trim
|
||||
@ -61,8 +67,10 @@ class WhisperFrontend(nn.Module):
|
||||
|
||||
# whisper deletes the last frame by default (Shih-Lun)
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
|
||||
filters = self.mel_filters(audio.device, self.n_mels)
|
||||
if self.filters_path is not None:
|
||||
filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
|
||||
else:
|
||||
filters = self.mel_filters(audio.device, self.n_mels)
|
||||
mel_spec = filters @ magnitudes
|
||||
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
@ -86,6 +94,7 @@ class WhisperFrontend(nn.Module):
|
||||
batch_size = input.size(0)
|
||||
feats = []
|
||||
feats_lens = []
|
||||
input = input.to(torch.float32)
|
||||
for i in range(batch_size):
|
||||
if self.do_pad_trim:
|
||||
feat = self.pad_or_trim(input[i], self.pad_samples)
|
||||
|
||||
@ -366,7 +366,7 @@ class LLMASRNARPrompt(nn.Module):
|
||||
decoder_conf: dict = None,
|
||||
ctc: str = None,
|
||||
ctc_conf: dict = None,
|
||||
ctc_weight: float = 0.5,
|
||||
ctc_weight: float = 0.0,
|
||||
llm: str = None,
|
||||
llm_conf: dict = None,
|
||||
adaptor: str = None,
|
||||
@ -473,6 +473,15 @@ class LLMASRNARPrompt(nn.Module):
|
||||
|
||||
self.length_normalized_loss = length_normalized_loss
|
||||
self.beam_search = None
|
||||
if ctc_weight > 0.0:
|
||||
if ctc_conf is None:
|
||||
ctc_conf = {}
|
||||
|
||||
ctc = CTC(
|
||||
odim=vocab_size, encoder_output_size=adaptor_conf["encoder_dim"], **ctc_conf
|
||||
)
|
||||
self.ctc_weight = ctc_weight
|
||||
self.ctc = ctc
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -502,9 +511,23 @@ class LLMASRNARPrompt(nn.Module):
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size = speech.shape[0]
|
||||
|
||||
|
||||
stats = {}
|
||||
# audio encoder
|
||||
encoder_out, encoder_out_lens, loss_pre = self.encode(speech, speech_lengths, audio_mask=audio_mask)
|
||||
outs = self.encode(speech, speech_lengths, audio_mask=audio_mask)
|
||||
enc, enc_lens = outs[0], outs[1]
|
||||
encoder_out, encoder_out_lens, loss_pre = outs[2], outs[3], outs[4]
|
||||
|
||||
|
||||
# decoder: CTC branch
|
||||
|
||||
if self.ctc_weight != 0.0:
|
||||
loss_ctc, cer_ctc = self._calc_ctc_loss(
|
||||
enc, enc_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# Collect CTC branch stats
|
||||
stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
|
||||
|
||||
# adaptor
|
||||
encoder_out = self.adaptor(encoder_out)
|
||||
@ -536,17 +559,19 @@ class LLMASRNARPrompt(nn.Module):
|
||||
# labels_ids[1:] -> [prompt, input, target, eos] -> [-1, input, target, eos];
|
||||
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
|
||||
loss_llm = model_outputs.loss
|
||||
stats["loss_llm"] = torch.clone(loss_llm.detach())
|
||||
if self.ctc_weight > 0.0:
|
||||
loss_llm = self.ctc_weight * loss_ctc + loss_llm
|
||||
loss = loss_llm + loss_pre * self.predictor_weight
|
||||
stats = {}
|
||||
|
||||
with torch.no_grad():
|
||||
preds = torch.argmax(model_outputs.logits, -1)
|
||||
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
|
||||
stats["acc"] = acc_att
|
||||
|
||||
|
||||
stats["loss_pre"] = torch.clone(loss_pre.detach())
|
||||
stats["loss_llm"] = torch.clone(loss_llm.detach())
|
||||
stats["loss"] = torch.clone(loss.detach())
|
||||
stats["batch_size"] = batch_size
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
if self.length_normalized_loss:
|
||||
@ -576,7 +601,24 @@ class LLMASRNARPrompt(nn.Module):
|
||||
if audio_token_lengths is not None:
|
||||
loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
|
||||
|
||||
return pre_acoustic_embeds, pre_token_length, loss_pre
|
||||
return enc, enc_lens, pre_acoustic_embeds, pre_token_length, loss_pre
|
||||
|
||||
def _calc_ctc_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
# Calc CTC loss
|
||||
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
|
||||
|
||||
# Calc CER using CTC
|
||||
cer_ctc = None
|
||||
if not self.training and self.error_calculator is not None:
|
||||
ys_hat = self.ctc.argmax(encoder_out).data
|
||||
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
|
||||
return loss_ctc, cer_ctc
|
||||
|
||||
def inference(self,
|
||||
data_in,
|
||||
@ -648,7 +690,8 @@ class LLMASRNARPrompt(nn.Module):
|
||||
else:
|
||||
inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
|
||||
|
||||
inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio]
|
||||
# inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio, pad]
|
||||
inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
|
||||
attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
|
||||
|
||||
# model_outputs = self.llm.generate(
|
||||
|
||||
@ -91,7 +91,11 @@ class SenseVoice(nn.Module):
|
||||
# decode the audio
|
||||
|
||||
# initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
|
||||
options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
|
||||
|
||||
vocab_path = kwargs.get("vocab_path", None)
|
||||
options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt, vocab_path=vocab_path)
|
||||
|
||||
|
||||
result = whisper.decode(self.model, speech, options)
|
||||
|
||||
results = []
|
||||
|
||||
@ -89,7 +89,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def mel_filters(device, n_mels: int) -> torch.Tensor:
|
||||
def mel_filters(device, n_mels: int, filters_path: str=None) -> torch.Tensor:
|
||||
"""
|
||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||
Allows decoupling librosa dependency; saved using:
|
||||
@ -101,8 +101,8 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
|
||||
)
|
||||
"""
|
||||
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
|
||||
|
||||
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
if filters_path is None:
|
||||
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
with np.load(filters_path, allow_pickle=False) as f:
|
||||
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||
|
||||
|
||||
@ -119,6 +119,7 @@ class DecodingOptions:
|
||||
|
||||
# FIX(funasr): sense vocie
|
||||
initial_prompt: str = None
|
||||
vocab_path: str = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -527,6 +528,7 @@ class DecodingTask:
|
||||
num_languages=model.num_languages,
|
||||
language=language,
|
||||
task=options.task,
|
||||
vocab_path=options.vocab_path
|
||||
)
|
||||
self.tokenizer: Tokenizer = tokenizer
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
@ -616,10 +618,13 @@ class DecodingTask:
|
||||
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||
+ tokens
|
||||
)
|
||||
#FIX(gzf): sense vocie
|
||||
#FIX(funasr): sense vocie
|
||||
if initial_prompt := self.options.initial_prompt:
|
||||
tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
|
||||
if self.options.language is None:
|
||||
if self.options.language is not None:
|
||||
initial_prompt = f"{initial_prompt}<|{self.options.language}|>"
|
||||
tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
|
||||
else:
|
||||
tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
|
||||
tokens += [0]
|
||||
|
||||
|
||||
@ -691,6 +696,7 @@ class DecodingTask:
|
||||
if self.options.language is None:
|
||||
# tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
||||
languages = "".join([f"<|{language}|>" for language in languages])
|
||||
|
||||
n_audio = audio_features.shape[0]
|
||||
lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
|
||||
audio_features.device) # [n_audio, 1]
|
||||
|
||||
@ -363,8 +363,10 @@ class Tokenizer:
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
||||
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||
def get_encoding(name: str = "gpt2", num_languages: int = 99, vocab_path:str=None):
|
||||
if vocab_path is None:
|
||||
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||
|
||||
ranks = {
|
||||
base64.b64decode(token): int(rank)
|
||||
for token, rank in (line.split() for line in open(vocab_path) if line)
|
||||
@ -423,6 +425,7 @@ def get_tokenizer(
|
||||
language: Optional[str] = None,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
encoding_path: Optional[str] = None,
|
||||
vocab_path: Optional[str] = None,
|
||||
) -> Tokenizer:
|
||||
if language is not None:
|
||||
language = language.lower()
|
||||
@ -443,7 +446,9 @@ def get_tokenizer(
|
||||
if encoding_path is not None:
|
||||
encoding_name = encoding_path
|
||||
|
||||
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
||||
|
||||
encoding = get_encoding(name=encoding_name, num_languages=num_languages, vocab_path=vocab_path)
|
||||
|
||||
|
||||
return Tokenizer(
|
||||
encoding=encoding, num_languages=num_languages, language=language, task=task
|
||||
|
||||
@ -1 +1 @@
|
||||
1.0.20
|
||||
1.0.22
|
||||
Loading…
Reference in New Issue
Block a user