Dev gzf exp (#1593)

* update

* update with main (#1582)

* update

* Expose the max_end_silence_time to the user (#1532)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* finetune

* finetune

* finetune

* finetune

* finetune

* finetune

* fix: resolve IndexError when using spk model and the audio contains only 1 segment (#1535)

* install requirements automatically

* install requirements automatically

* install requirements automatically

* install requirements automatically

* install requirements automatically

* install requirements automatically

* install requirements automatically

* v1.0.19

* v1.0.19

* train

* train

* docs

* update

* update

* update

* update

* update

* update

* update

* train update

* bugfix seg_dict_file

* bugfix seg_dict_file

* train

* train

* train (#1548)

* Dev gzf new (#1551)

* train

* train

* <funasr>: <punc online> (#1552)

1.修正添加标点时英文首单词和第二个单词被错误合并的问题。

Co-authored-by: carl.che <carl.che@cloudminds.com>

* Dev gzf new (#1553)

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1554)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1555)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* 修正commit 87b62d6895 引入的英文整句标点预测导致末尾两个单词中间的空格被删除的问题。 (#1556)

* <funasr>: <punc online>

1.修正添加标点时英文首单词和第二个单词被错误合并的问题。

* <funasr>: <punc online>

1.修正commit 87b62d6895 引入的英文整句标点预测导致末尾两个单词中间的空格被删除的问题。

---------

Co-authored-by: carl.che <carl.che@cloudminds.com>

* Dev gzf new (#1557)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1559)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1561)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1562)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1567)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* whisper_lib for sense voice

* aishell recipe

* sense voice (#1568)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* whisper_lib for sense voice

* aishell recipe

* sense voice

* Dev gzf new (#1574)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* whisper_lib for sense voice

* aishell recipe

* sense voice

* docs

* bugfix (#1580)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* whisper_lib for sense voice

* aishell recipe

* sense voice

* docs

* bugfix

* v1.0.20

---------

Co-authored-by: BOBOTANG <tzfjobmail@gmail.com>
Co-authored-by: Atomie CHEN <atomic_cwh@163.com>
Co-authored-by: Carl <415692979@qq.com>
Co-authored-by: carl.che <carl.che@cloudminds.com>

* ctc

* ctc

* ctc

* ctc

* update with main (#1592)

* update

* Expose the max_end_silence_time to the user (#1532)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* finetune

* finetune

* finetune

* finetune

* finetune

* finetune

* fix: resolve IndexError when using spk model and the audio contains only 1 segment (#1535)

* install requirements automatically

* install requirements automatically

* install requirements automatically

* install requirements automatically

* install requirements automatically

* install requirements automatically

* install requirements automatically

* v1.0.19

* v1.0.19

* train

* train

* docs

* update

* update

* update

* update

* update

* update

* update

* train update

* bugfix seg_dict_file

* bugfix seg_dict_file

* train

* train

* train (#1548)

* Dev gzf new (#1551)

* train

* train

* <funasr>: <punc online> (#1552)

1.修正添加标点时英文首单词和第二个单词被错误合并的问题。

Co-authored-by: carl.che <carl.che@cloudminds.com>

* Dev gzf new (#1553)

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1554)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1555)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* 修正commit 87b62d6895 引入的英文整句标点预测导致末尾两个单词中间的空格被删除的问题。 (#1556)

* <funasr>: <punc online>

1.修正添加标点时英文首单词和第二个单词被错误合并的问题。

* <funasr>: <punc online>

1.修正commit 87b62d6895 引入的英文整句标点预测导致末尾两个单词中间的空格被删除的问题。

---------

Co-authored-by: carl.che <carl.che@cloudminds.com>

* Dev gzf new (#1557)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1559)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1561)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1562)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* Dev gzf new (#1567)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* whisper_lib for sense voice

* aishell recipe

* sense voice (#1568)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* whisper_lib for sense voice

* aishell recipe

* sense voice

* Dev gzf new (#1574)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* whisper_lib for sense voice

* aishell recipe

* sense voice

* docs

* bugfix (#1580)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* whisper_lib for sense voice

* aishell recipe

* sense voice

* docs

* bugfix

* v1.0.20

* update demo page (#1585)

* commit web page vue

* optimize web page

* optimize web page

* remove other private component

* modify web page

* Update index.vue

* Update lxwjzxfw.vue

* Update sstx.vue

* update static file

---------

Co-authored-by: BOBOTANG <tzfjobmail@gmail.com>
Co-authored-by: Atomie CHEN <atomic_cwh@163.com>
Co-authored-by: Carl <415692979@qq.com>
Co-authored-by: carl.che <carl.che@cloudminds.com>
Co-authored-by: bltcn <blt@tom.com>

* sensevoice

* sensevoice

---------

Co-authored-by: BOBOTANG <tzfjobmail@gmail.com>
Co-authored-by: Atomie CHEN <atomic_cwh@163.com>
Co-authored-by: Carl <415692979@qq.com>
Co-authored-by: carl.che <carl.che@cloudminds.com>
Co-authored-by: bltcn <blt@tom.com>
This commit is contained in:
zhifu gao 2024-04-08 18:51:53 +08:00 committed by GitHub
parent c1a492a96e
commit d19f48e174
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 89 additions and 22 deletions

View File

@ -38,7 +38,13 @@ class WhisperFrontend(nn.Module):
if whisper_model == "large-v3" or whisper_model == "large":
self.n_mels = 128
self.mel_filters = whisper.audio.mel_filters
filters_path = kwargs.get("filters_path", None)
self.filters_path = filters_path
if filters_path is not None:
from funasr.models.sense_voice.whisper_lib.audio import mel_filters
self.mel_filters = mel_filters
else:
self.mel_filters = whisper.audio.mel_filters
self.do_pad_trim = do_pad_trim
if do_pad_trim:
self.pad_or_trim = whisper.pad_or_trim
@ -61,8 +67,10 @@ class WhisperFrontend(nn.Module):
# whisper deletes the last frame by default (Shih-Lun)
magnitudes = stft[..., :-1].abs() ** 2
filters = self.mel_filters(audio.device, self.n_mels)
if self.filters_path is not None:
filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
else:
filters = self.mel_filters(audio.device, self.n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
@ -86,6 +94,7 @@ class WhisperFrontend(nn.Module):
batch_size = input.size(0)
feats = []
feats_lens = []
input = input.to(torch.float32)
for i in range(batch_size):
if self.do_pad_trim:
feat = self.pad_or_trim(input[i], self.pad_samples)

View File

@ -366,7 +366,7 @@ class LLMASRNARPrompt(nn.Module):
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
ctc_weight: float = 0.0,
llm: str = None,
llm_conf: dict = None,
adaptor: str = None,
@ -473,6 +473,15 @@ class LLMASRNARPrompt(nn.Module):
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
if ctc_weight > 0.0:
if ctc_conf is None:
ctc_conf = {}
ctc = CTC(
odim=vocab_size, encoder_output_size=adaptor_conf["encoder_dim"], **ctc_conf
)
self.ctc_weight = ctc_weight
self.ctc = ctc
def forward(
self,
@ -502,9 +511,23 @@ class LLMASRNARPrompt(nn.Module):
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
stats = {}
# audio encoder
encoder_out, encoder_out_lens, loss_pre = self.encode(speech, speech_lengths, audio_mask=audio_mask)
outs = self.encode(speech, speech_lengths, audio_mask=audio_mask)
enc, enc_lens = outs[0], outs[1]
encoder_out, encoder_out_lens, loss_pre = outs[2], outs[3], outs[4]
# decoder: CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
enc, enc_lens, text, text_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
# adaptor
encoder_out = self.adaptor(encoder_out)
@ -536,17 +559,19 @@ class LLMASRNARPrompt(nn.Module):
# labels_ids[1:] -> [prompt, input, target, eos] -> [-1, input, target, eos];
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
loss_llm = model_outputs.loss
stats["loss_llm"] = torch.clone(loss_llm.detach())
if self.ctc_weight > 0.0:
loss_llm = self.ctc_weight * loss_ctc + loss_llm
loss = loss_llm + loss_pre * self.predictor_weight
stats = {}
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
stats["acc"] = acc_att
stats["loss_pre"] = torch.clone(loss_pre.detach())
stats["loss_llm"] = torch.clone(loss_llm.detach())
stats["loss"] = torch.clone(loss.detach())
stats["batch_size"] = batch_size
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
@ -576,7 +601,24 @@ class LLMASRNARPrompt(nn.Module):
if audio_token_lengths is not None:
loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
return pre_acoustic_embeds, pre_token_length, loss_pre
return enc, enc_lens, pre_acoustic_embeds, pre_token_length, loss_pre
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
def inference(self,
data_in,
@ -648,7 +690,8 @@ class LLMASRNARPrompt(nn.Module):
else:
inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio]
# inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio, pad]
inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
# model_outputs = self.llm.generate(

View File

@ -91,7 +91,11 @@ class SenseVoice(nn.Module):
# decode the audio
# initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
vocab_path = kwargs.get("vocab_path", None)
options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt, vocab_path=vocab_path)
result = whisper.decode(self.model, speech, options)
results = []

View File

@ -89,7 +89,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
def mel_filters(device, n_mels: int, filters_path: str=None) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
@ -101,8 +101,8 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
)
"""
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
if filters_path is None:
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
with np.load(filters_path, allow_pickle=False) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)

View File

@ -119,6 +119,7 @@ class DecodingOptions:
# FIX(funasr): sense vocie
initial_prompt: str = None
vocab_path: str = None
@dataclass(frozen=True)
@ -527,6 +528,7 @@ class DecodingTask:
num_languages=model.num_languages,
language=language,
task=options.task,
vocab_path=options.vocab_path
)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
@ -616,10 +618,13 @@ class DecodingTask:
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
+ tokens
)
#FIX(gzf): sense vocie
#FIX(funasr): sense vocie
if initial_prompt := self.options.initial_prompt:
tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
if self.options.language is None:
if self.options.language is not None:
initial_prompt = f"{initial_prompt}<|{self.options.language}|>"
tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
else:
tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
tokens += [0]
@ -691,6 +696,7 @@ class DecodingTask:
if self.options.language is None:
# tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
languages = "".join([f"<|{language}|>" for language in languages])
n_audio = audio_features.shape[0]
lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
audio_features.device) # [n_audio, 1]

View File

@ -363,8 +363,10 @@ class Tokenizer:
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
def get_encoding(name: str = "gpt2", num_languages: int = 99, vocab_path:str=None):
if vocab_path is None:
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in open(vocab_path) if line)
@ -423,6 +425,7 @@ def get_tokenizer(
language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
encoding_path: Optional[str] = None,
vocab_path: Optional[str] = None,
) -> Tokenizer:
if language is not None:
language = language.lower()
@ -443,7 +446,9 @@ def get_tokenizer(
if encoding_path is not None:
encoding_name = encoding_path
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
encoding = get_encoding(name=encoding_name, num_languages=num_languages, vocab_path=vocab_path)
return Tokenizer(
encoding=encoding, num_languages=num_languages, language=language, task=task

View File

@ -1 +1 @@
1.0.20
1.0.22