mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf decoding (#1695)
* resume from step * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * train_loss_avg train_acc_avg * train_loss_avg train_acc_avg * train_loss_avg train_acc_avg * log step * wav is not exist * wav is not exist * decoding * decoding
This commit is contained in:
parent
48a8c95334
commit
00d0df3a10
@ -8,6 +8,7 @@ from funasr import AutoModel
|
|||||||
model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch")
|
model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch")
|
||||||
|
|
||||||
res = model.generate(
|
res = model.generate(
|
||||||
input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav"
|
input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav",
|
||||||
|
decoding_ctc_weight=0.0,
|
||||||
)
|
)
|
||||||
print(res)
|
print(res)
|
||||||
|
|||||||
@ -0,0 +1,27 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||||
|
# MIT License (https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
from funasr import AutoModel
|
||||||
|
|
||||||
|
# Demo: run a SenseVoice model with FSMN VAD on one remote test utterance and
# print the raw result returned by AutoModel.generate.
#
# NOTE(review): the ASR model path is an absolute path on the author's machine;
# replace it with a path that exists in your environment before running.
model = AutoModel(
    model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscopeFSMN",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs=dict(max_single_segment_time=30000),
)

# Audio to transcribe (fetched by URL).
input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"

# Options handed to the model through the `DecodingOptions` keyword of generate().
DecodingOptions = dict(
    task=("ASR", "AED", "SER"),
    language="auto",
    fp16=True,
    gain_event=True,
)

res = model.generate(
    input=input_wav,
    batch_size_s=0,
    DecodingOptions=DecodingOptions,
    beam_size=5,
)
print(res)
|
||||||
@ -46,16 +46,17 @@ def update_data(lines, i):
|
|||||||
data = json.loads(line.strip())
|
data = json.loads(line.strip())
|
||||||
|
|
||||||
wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data")
|
wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data")
|
||||||
waveform, _ = librosa.load(wav_path, sr=16000)
|
if os.path.exists(wav_path):
|
||||||
sample_num = len(waveform)
|
waveform, _ = librosa.load(wav_path, sr=16000)
|
||||||
source_len = int(sample_num / 16000 * 1000 / 10)
|
sample_num = len(waveform)
|
||||||
source_len_old = data["source_len"]
|
source_len = int(sample_num / 16000 * 1000 / 10)
|
||||||
# if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
|
source_len_old = data["source_len"]
|
||||||
# logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
|
# if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
|
||||||
data["source_len"] = source_len
|
# logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
|
||||||
data["source"] = wav_path
|
data["source_len"] = source_len
|
||||||
jsonl_line = json.dumps(data, ensure_ascii=False)
|
data["source"] = wav_path
|
||||||
lines[i] = jsonl_line
|
jsonl_line = json.dumps(data, ensure_ascii=False)
|
||||||
|
lines[i] = jsonl_line
|
||||||
|
|
||||||
|
|
||||||
def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
|
def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import logging
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import random
|
import random
|
||||||
|
import traceback
|
||||||
from funasr.register import tables
|
from funasr.register import tables
|
||||||
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
|
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
|
||||||
|
|
||||||
@ -73,15 +73,17 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
|
|||||||
if idx == 0:
|
if idx == 0:
|
||||||
index_cur = index
|
index_cur = index
|
||||||
else:
|
else:
|
||||||
if index <= self.retry:
|
index_cur = torch.randint(0, len(self.index_ds), ()).item()
|
||||||
index_cur = index + idx
|
|
||||||
else:
|
|
||||||
index_cur = torch.randint(0, index, ()).item()
|
|
||||||
|
|
||||||
item = self.index_ds[index_cur]
|
item = self.index_ds[index_cur]
|
||||||
|
|
||||||
source = item["source"]
|
source = item["source"]
|
||||||
data_src = load_audio_text_image_video(source, fs=self.fs)
|
try:
|
||||||
|
data_src = load_audio_text_image_video(source, fs=self.fs)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
|
||||||
|
continue
|
||||||
|
|
||||||
if self.preprocessor_speech:
|
if self.preprocessor_speech:
|
||||||
data_src = self.preprocessor_speech(data_src, fs=self.fs)
|
data_src = self.preprocessor_speech(data_src, fs=self.fs)
|
||||||
speech, speech_lengths = extract_fbank(
|
speech, speech_lengths = extract_fbank(
|
||||||
@ -186,7 +188,7 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.batch_type != "example":
|
if self.batch_type != "example":
|
||||||
for i in range(3):
|
for i in range(10):
|
||||||
outputs = self._filter_badcase(outputs, i=i)
|
outputs = self._filter_badcase(outputs, i=i)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@ -15,6 +15,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from funasr.models.transformer.utils.mask import subsequent_mask
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.LayerNorm):
|
class LayerNorm(nn.LayerNorm):
|
||||||
@ -336,6 +337,29 @@ class SenseVoiceDecoder(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def init_state(self, x):
    """Return the initial scorer state for a new hypothesis.

    Args:
        x: encoder output handed in by the beam search; not used here.

    Returns:
        dict: a fresh, empty state dictionary.
    """
    return dict()
|
||||||
|
|
||||||
|
def final_score(self, state) -> float:
    """Return the extra score applied when a hypothesis is finished.

    This scorer contributes nothing at the end of a sequence, so the
    result is the constant ``0.0`` regardless of ``state``.

    Args:
        state: scorer state of the finished prefix; ignored.

    Returns:
        float: always ``0.0``.
    """
    return 0.0
|
||||||
|
|
||||||
|
def score(self, ys, state, x):
    """Score the next token for one hypothesis.

    Args:
        ys: prefix token ids of the hypothesis (presumably a 1-D tensor);
            a leading batch axis is added before calling the decoder.
        state: scorer state; forwarded to ``self.forward`` as ``cache`` and
            returned as-is.
        x: encoder output for one utterance; a leading batch axis is added
            before calling the decoder.

    Returns:
        tuple: ``(scores, state)`` where ``scores`` is the decoder output at
        the last prefix position for the single batch item.
    """
    # No attention mask is built here: ``forward`` is called without one, so
    # constructing a subsequent mask on every step would be wasted work.
    # NOTE(review): the decoder output is used as the token score directly;
    # confirm whether ``forward`` already returns log-probabilities.
    decoder_out = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
    return decoder_out.squeeze(0)[-1, :], state
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadedAttentionSANMDecoder(nn.Module):
|
class MultiHeadedAttentionSANMDecoder(nn.Module):
|
||||||
"""Multi-Head Attention layer.
|
"""Multi-Head Attention layer.
|
||||||
@ -443,9 +467,19 @@ class ResidualAttentionBlockFSMN(nn.Module):
|
|||||||
kv_cache: Optional[dict] = None,
|
kv_cache: Optional[dict] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
cache = kwargs.get("cache", {})
|
||||||
|
layer = kwargs.get("layer", 0)
|
||||||
is_pad_mask = kwargs.get("is_pad_mask", False)
|
is_pad_mask = kwargs.get("is_pad_mask", False)
|
||||||
is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
|
is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
|
||||||
x = x + self.attn(self.attn_ln(x), mask=None, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
|
|
||||||
|
fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 else None
|
||||||
|
# if fsmn_cache is not None:
|
||||||
|
# x = x[:, -1:]
|
||||||
|
att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
|
||||||
|
# if len(cache)>1:
|
||||||
|
# cache[layer]["fsmn_cache"] = fsmn_cache
|
||||||
|
# x = x[:, -1:]
|
||||||
|
x = x + att_res
|
||||||
if self.cross_attn:
|
if self.cross_attn:
|
||||||
x = (
|
x = (
|
||||||
x
|
x
|
||||||
@ -510,10 +544,9 @@ class SenseVoiceDecoderFSMN(nn.Module):
|
|||||||
|
|
||||||
ys_in_lens = kwargs.get("ys_in_lens", None)
|
ys_in_lens = kwargs.get("ys_in_lens", None)
|
||||||
|
|
||||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
|
||||||
tgt, memory = x, xa
|
tgt, memory = x, xa
|
||||||
tgt[tgt == -1] = 0
|
tgt[tgt == -1] = 0
|
||||||
tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
|
tgt = self.token_embedding(tgt) + self.positional_embedding[: tgt.size(1)]
|
||||||
# tgt = self.dropout(tgt)
|
# tgt = self.dropout(tgt)
|
||||||
|
|
||||||
x = tgt.to(memory.dtype)
|
x = tgt.to(memory.dtype)
|
||||||
@ -531,9 +564,40 @@ class SenseVoiceDecoderFSMN(nn.Module):
|
|||||||
memory_mask=memory_mask,
|
memory_mask=memory_mask,
|
||||||
is_pad_mask=False,
|
is_pad_mask=False,
|
||||||
is_pad_memory_mask=True,
|
is_pad_memory_mask=True,
|
||||||
|
cache=kwargs.get("cache", None),
|
||||||
|
layer=layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
x = self.ln(x)
|
x = self.ln(x)
|
||||||
x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def init_state(self, x):
    """Create the initial per-layer cache for a new hypothesis.

    Args:
        x: encoder output handed in by the beam search; not used here.

    Returns:
        dict: one entry per element of ``self.blocks``, keyed by the layer
        index, each holding ``fsmn_cache``, ``memory_key`` and
        ``memory_value`` set to ``None``.
    """
    return {
        layer: {"fsmn_cache": None, "memory_key": None, "memory_value": None}
        for layer, _ in enumerate(self.blocks)
    }
|
||||||
|
|
||||||
|
def final_score(self, state) -> float:
    """Return the extra score applied when a hypothesis is finished.

    This scorer contributes nothing at the end of a sequence, so the
    result is the constant ``0.0`` regardless of ``state``.

    Args:
        state: scorer state of the finished prefix; ignored.

    Returns:
        float: always ``0.0``.
    """
    return 0.0
|
||||||
|
|
||||||
|
def score(self, ys, state, x):
    """Score the next token for one hypothesis.

    Args:
        ys: prefix token ids of the hypothesis (presumably a 1-D tensor);
            a leading batch axis is added before calling the decoder.
        state: scorer state; forwarded to ``self.forward`` as ``cache`` and
            returned as-is.
        x: encoder output for one utterance; a leading batch axis is added
            before calling the decoder.

    Returns:
        tuple: ``(scores, state)`` where ``scores`` is the decoder output at
        the last prefix position for the single batch item.
    """
    # No attention mask is built here: ``forward`` is called without one, so
    # constructing a subsequent mask on every step would be wasted work.
    # NOTE(review): the decoder output is used as the token score directly;
    # confirm whether ``forward`` already returns log-probabilities.
    decoder_out = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
    return decoder_out.squeeze(0)[-1, :], state
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
|||||||
from funasr.train_utils.device_funcs import force_gatherable
|
from funasr.train_utils.device_funcs import force_gatherable
|
||||||
from . import whisper_lib as whisper
|
from . import whisper_lib as whisper
|
||||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||||
|
from funasr.utils.datadir_writer import DatadirWriter
|
||||||
|
|
||||||
from funasr.register import tables
|
from funasr.register import tables
|
||||||
|
|
||||||
@ -395,6 +396,42 @@ class SenseVoiceRWKV(nn.Module):
|
|||||||
|
|
||||||
return loss_att, acc_att, None, None
|
return loss_att, acc_att, None, None
|
||||||
|
|
||||||
|
def init_beam_search(self, **kwargs):
    """Create the beam-search object and store it on ``self.beam_search``.

    The search is built with two scorers: the model decoder and a
    ``LengthBonus`` sized by ``self.vocab_size``.

    Keyword Args:
        penalty: weight given to the length-bonus scorer (default ``0.0``).
        beam_size: number of hypotheses kept by the search (default ``5``).

    ``sos`` and ``eos`` are passed as ``None`` here; they are not known
    at construction time in this method.
    """
    from .search import BeamSearch

    from funasr.models.transformer.scorers.length_bonus import LengthBonus

    # Scorers that take part in the search, keyed by the name used in `weights`.
    scorers = {
        "decoder": self.model.decoder,
        "length_bonus": LengthBonus(self.vocab_size),
    }

    # Per-scorer weights; ctc/lm/ngram have no matching scorer in this method.
    weights = {
        "decoder": 1.0,
        "ctc": 0.0,
        "lm": 0.0,
        "ngram": 0.0,
        "length_bonus": kwargs.get("penalty", 0.0),
    }

    self.beam_search = BeamSearch(
        beam_size=kwargs.get("beam_size", 5),
        weights=weights,
        scorers=scorers,
        sos=None,
        eos=None,
        vocab_size=self.vocab_size,
        token_list=None,
        pre_beam_score_key="full",
    )
|
||||||
|
|
||||||
def inference(
|
def inference(
|
||||||
self,
|
self,
|
||||||
data_in,
|
data_in,
|
||||||
@ -407,6 +444,12 @@ class SenseVoiceRWKV(nn.Module):
|
|||||||
if kwargs.get("batch_size", 1) > 1:
|
if kwargs.get("batch_size", 1) > 1:
|
||||||
raise NotImplementedError("batch decoding is not implemented")
|
raise NotImplementedError("batch decoding is not implemented")
|
||||||
|
|
||||||
|
# init beamsearch
|
||||||
|
if not hasattr(self, "beam_search") or self.beam_search is None:
|
||||||
|
logging.info("enable beam_search")
|
||||||
|
self.init_beam_search(**kwargs)
|
||||||
|
self.nbest = kwargs.get("nbest", 1)
|
||||||
|
|
||||||
if frontend is None and not hasattr(self, "frontend"):
|
if frontend is None and not hasattr(self, "frontend"):
|
||||||
frontend_class = tables.frontend_classes.get("WhisperFrontend")
|
frontend_class = tables.frontend_classes.get("WhisperFrontend")
|
||||||
frontend = frontend_class(
|
frontend = frontend_class(
|
||||||
@ -455,25 +498,65 @@ class SenseVoiceRWKV(nn.Module):
|
|||||||
task = [task]
|
task = [task]
|
||||||
task = "".join([f"<|{x}|>" for x in task])
|
task = "".join([f"<|{x}|>" for x in task])
|
||||||
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
|
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
|
||||||
DecodingOptions["initial_prompt"] = initial_prompt
|
|
||||||
|
|
||||||
language = DecodingOptions.get("language", None)
|
language = DecodingOptions.get("language", None)
|
||||||
language = None if language == "auto" else language
|
language = None if language == "auto" else language
|
||||||
DecodingOptions["language"] = language
|
|
||||||
|
|
||||||
DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
|
sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
|
||||||
|
sos_int = tokenizer.encode(sos, allowed_special="all")
|
||||||
|
eos = kwargs.get("model_conf").get("eos")
|
||||||
|
eos_int = tokenizer.encode(eos, allowed_special="all")
|
||||||
|
self.beam_search.sos = sos_int
|
||||||
|
self.beam_search.eos = eos_int[0]
|
||||||
|
|
||||||
if "without_timestamps" not in DecodingOptions:
|
encoder_out, encoder_out_lens = self.encode(
|
||||||
DecodingOptions["without_timestamps"] = True
|
speech[None, :, :].permute(0, 2, 1), speech_lengths
|
||||||
|
)
|
||||||
|
|
||||||
options = whisper.DecodingOptions(**DecodingOptions)
|
# c. Passed the encoder result and the beam search
|
||||||
|
nbest_hyps = self.beam_search(
|
||||||
|
x=encoder_out[0],
|
||||||
|
maxlenratio=kwargs.get("maxlenratio", 0.0),
|
||||||
|
minlenratio=kwargs.get("minlenratio", 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
nbest_hyps = nbest_hyps[: self.nbest]
|
||||||
|
|
||||||
result = whisper.decode(self.model, speech, options)
|
|
||||||
text = f"{result.text}"
|
|
||||||
results = []
|
results = []
|
||||||
result_i = {"key": key[0], "text": text}
|
b, n, d = encoder_out.size()
|
||||||
|
for i in range(b):
|
||||||
|
|
||||||
results.append(result_i)
|
for nbest_idx, hyp in enumerate(nbest_hyps):
|
||||||
|
ibest_writer = None
|
||||||
|
if kwargs.get("output_dir") is not None:
|
||||||
|
if not hasattr(self, "writer"):
|
||||||
|
self.writer = DatadirWriter(kwargs.get("output_dir"))
|
||||||
|
ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
|
||||||
|
|
||||||
|
# remove sos/eos and get results
|
||||||
|
last_pos = -1
|
||||||
|
if isinstance(hyp.yseq, list):
|
||||||
|
token_int = hyp.yseq[1:last_pos]
|
||||||
|
else:
|
||||||
|
token_int = hyp.yseq[1:last_pos].tolist()
|
||||||
|
|
||||||
|
# # remove blank symbol id, which is assumed to be 0
|
||||||
|
# token_int = list(
|
||||||
|
# filter(
|
||||||
|
# lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
|
||||||
|
# Change integer-ids to tokens
|
||||||
|
# token = tokenizer.ids2tokens(token_int)
|
||||||
|
text = tokenizer.decode(token_int)
|
||||||
|
|
||||||
|
result_i = {"key": key[i], "text": text}
|
||||||
|
results.append(result_i)
|
||||||
|
|
||||||
|
if ibest_writer is not None:
|
||||||
|
# ibest_writer["token"][key[i]] = " ".join(token)
|
||||||
|
ibest_writer["text"][key[i]] = text
|
||||||
|
|
||||||
return results, meta_data
|
return results, meta_data
|
||||||
|
|
||||||
@ -497,12 +580,14 @@ class SenseVoiceFSMN(nn.Module):
|
|||||||
# decoder
|
# decoder
|
||||||
del model.decoder
|
del model.decoder
|
||||||
decoder = kwargs.get("decoder", "SenseVoiceDecoder")
|
decoder = kwargs.get("decoder", "SenseVoiceDecoder")
|
||||||
decoder_conf = kwargs.get("decoder_conf", {})
|
|
||||||
decoder_class = tables.decoder_classes.get(decoder)
|
decoder_class = tables.decoder_classes.get(decoder)
|
||||||
decoder = decoder_class(
|
decoder = decoder_class(
|
||||||
vocab_size=dims.n_vocab,
|
n_vocab=dims.n_vocab,
|
||||||
encoder_output_size=dims.n_audio_state,
|
n_ctx=dims.n_text_ctx,
|
||||||
**decoder_conf,
|
n_state=dims.n_text_state,
|
||||||
|
n_head=dims.n_text_head,
|
||||||
|
n_layer=dims.n_text_layer,
|
||||||
|
**kwargs.get("decoder_conf"),
|
||||||
)
|
)
|
||||||
model.decoder = decoder
|
model.decoder = decoder
|
||||||
|
|
||||||
@ -512,7 +597,7 @@ class SenseVoiceFSMN(nn.Module):
|
|||||||
|
|
||||||
self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
|
self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
|
||||||
self.ignore_id = kwargs.get("ignore_id", -1)
|
self.ignore_id = kwargs.get("ignore_id", -1)
|
||||||
self.vocab_size = kwargs.get("vocab_size", -1)
|
self.vocab_size = dims.n_vocab
|
||||||
self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
|
self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
|
||||||
self.criterion_att = LabelSmoothingLoss(
|
self.criterion_att = LabelSmoothingLoss(
|
||||||
size=self.vocab_size,
|
size=self.vocab_size,
|
||||||
@ -630,6 +715,42 @@ class SenseVoiceFSMN(nn.Module):
|
|||||||
|
|
||||||
return loss_att, acc_att, None, None
|
return loss_att, acc_att, None, None
|
||||||
|
|
||||||
|
def init_beam_search(self, **kwargs):
    """Create the beam-search object and store it on ``self.beam_search``.

    The search is built with two scorers: the model decoder and a
    ``LengthBonus`` sized by ``self.vocab_size``.

    Keyword Args:
        penalty: weight given to the length-bonus scorer (default ``0.0``).
        beam_size: number of hypotheses kept by the search (default ``5``).

    ``sos`` and ``eos`` are passed as ``None`` here; they are not known
    at construction time in this method.
    """
    from .search import BeamSearch

    from funasr.models.transformer.scorers.length_bonus import LengthBonus

    # Scorers that take part in the search, keyed by the name used in `weights`.
    scorers = {
        "decoder": self.model.decoder,
        "length_bonus": LengthBonus(self.vocab_size),
    }

    # Per-scorer weights; ctc/lm/ngram have no matching scorer in this method.
    weights = {
        "decoder": 1.0,
        "ctc": 0.0,
        "lm": 0.0,
        "ngram": 0.0,
        "length_bonus": kwargs.get("penalty", 0.0),
    }

    self.beam_search = BeamSearch(
        beam_size=kwargs.get("beam_size", 5),
        weights=weights,
        scorers=scorers,
        sos=None,
        eos=None,
        vocab_size=self.vocab_size,
        token_list=None,
        pre_beam_score_key="full",
    )
|
||||||
|
|
||||||
def inference(
|
def inference(
|
||||||
self,
|
self,
|
||||||
data_in,
|
data_in,
|
||||||
@ -642,6 +763,12 @@ class SenseVoiceFSMN(nn.Module):
|
|||||||
if kwargs.get("batch_size", 1) > 1:
|
if kwargs.get("batch_size", 1) > 1:
|
||||||
raise NotImplementedError("batch decoding is not implemented")
|
raise NotImplementedError("batch decoding is not implemented")
|
||||||
|
|
||||||
|
# init beamsearch
|
||||||
|
if not hasattr(self, "beam_search") or self.beam_search is None:
|
||||||
|
logging.info("enable beam_search")
|
||||||
|
self.init_beam_search(**kwargs)
|
||||||
|
self.nbest = kwargs.get("nbest", 1)
|
||||||
|
|
||||||
if frontend is None and not hasattr(self, "frontend"):
|
if frontend is None and not hasattr(self, "frontend"):
|
||||||
frontend_class = tables.frontend_classes.get("WhisperFrontend")
|
frontend_class = tables.frontend_classes.get("WhisperFrontend")
|
||||||
frontend = frontend_class(
|
frontend = frontend_class(
|
||||||
@ -690,24 +817,64 @@ class SenseVoiceFSMN(nn.Module):
|
|||||||
task = [task]
|
task = [task]
|
||||||
task = "".join([f"<|{x}|>" for x in task])
|
task = "".join([f"<|{x}|>" for x in task])
|
||||||
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
|
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
|
||||||
DecodingOptions["initial_prompt"] = initial_prompt
|
|
||||||
|
|
||||||
language = DecodingOptions.get("language", None)
|
language = DecodingOptions.get("language", None)
|
||||||
language = None if language == "auto" else language
|
language = None if language == "auto" else language
|
||||||
DecodingOptions["language"] = language
|
|
||||||
|
|
||||||
DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
|
sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
|
||||||
|
sos_int = tokenizer.encode(sos, allowed_special="all")
|
||||||
|
eos = kwargs.get("model_conf").get("eos")
|
||||||
|
eos_int = tokenizer.encode(eos, allowed_special="all")
|
||||||
|
self.beam_search.sos = sos_int
|
||||||
|
self.beam_search.eos = eos_int[0]
|
||||||
|
|
||||||
if "without_timestamps" not in DecodingOptions:
|
encoder_out, encoder_out_lens = self.encode(
|
||||||
DecodingOptions["without_timestamps"] = True
|
speech[None, :, :].permute(0, 2, 1), speech_lengths
|
||||||
|
)
|
||||||
|
|
||||||
options = whisper.DecodingOptions(**DecodingOptions)
|
# c. Passed the encoder result and the beam search
|
||||||
|
nbest_hyps = self.beam_search(
|
||||||
|
x=encoder_out[0],
|
||||||
|
maxlenratio=kwargs.get("maxlenratio", 0.0),
|
||||||
|
minlenratio=kwargs.get("minlenratio", 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
nbest_hyps = nbest_hyps[: self.nbest]
|
||||||
|
|
||||||
result = whisper.decode(self.model, speech, options)
|
|
||||||
text = f"{result.text}"
|
|
||||||
results = []
|
results = []
|
||||||
result_i = {"key": key[0], "text": text}
|
b, n, d = encoder_out.size()
|
||||||
|
for i in range(b):
|
||||||
|
|
||||||
results.append(result_i)
|
for nbest_idx, hyp in enumerate(nbest_hyps):
|
||||||
|
ibest_writer = None
|
||||||
|
if kwargs.get("output_dir") is not None:
|
||||||
|
if not hasattr(self, "writer"):
|
||||||
|
self.writer = DatadirWriter(kwargs.get("output_dir"))
|
||||||
|
ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
|
||||||
|
|
||||||
|
# remove sos/eos and get results
|
||||||
|
last_pos = -1
|
||||||
|
if isinstance(hyp.yseq, list):
|
||||||
|
token_int = hyp.yseq[1:last_pos]
|
||||||
|
else:
|
||||||
|
token_int = hyp.yseq[1:last_pos].tolist()
|
||||||
|
|
||||||
|
# # remove blank symbol id, which is assumed to be 0
|
||||||
|
# token_int = list(
|
||||||
|
# filter(
|
||||||
|
# lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
|
||||||
|
# Change integer-ids to tokens
|
||||||
|
# token = tokenizer.ids2tokens(token_int)
|
||||||
|
text = tokenizer.decode(token_int)
|
||||||
|
|
||||||
|
result_i = {"key": key[i], "text": text}
|
||||||
|
results.append(result_i)
|
||||||
|
|
||||||
|
if ibest_writer is not None:
|
||||||
|
# ibest_writer["token"][key[i]] = " ".join(token)
|
||||||
|
ibest_writer["text"][key[i]] = text
|
||||||
|
|
||||||
return results, meta_data
|
return results, meta_data
|
||||||
|
|||||||
453
funasr/models/sense_voice/search.py
Normal file
453
funasr/models/sense_voice/search.py
Normal file
@ -0,0 +1,453 @@
|
|||||||
|
from itertools import chain
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import NamedTuple
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from funasr.metrics.common import end_detect
|
||||||
|
from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
|
||||||
|
from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class Hypothesis(NamedTuple):
    """Immutable record describing one hypothesis held by the beam search.

    Fields:
        yseq: token ids of the hypothesis.
        score: overall score of the hypothesis.
        scores: per-scorer scores, keyed by scorer name.
        states: per-scorer states, keyed by scorer name.
    """

    yseq: torch.Tensor
    score: Union[float, torch.Tensor] = 0
    # NOTE: the two dict defaults are created once and shared by every
    # instance that does not pass its own value.
    scores: Dict[str, Union[float, torch.Tensor]] = {}
    states: Dict[str, Any] = {}

    def asdict(self) -> dict:
        """Return the hypothesis as a dict of plain Python values.

        ``yseq`` becomes a list, ``score`` and every entry of ``scores``
        become floats; ``states`` is carried over unchanged.
        """
        plain_scores = {}
        for name, value in self.scores.items():
            plain_scores[name] = float(value)
        converted = self._replace(
            yseq=self.yseq.tolist(),
            score=float(self.score),
            scores=plain_scores,
        )
        return converted._asdict()
|
||||||
|
|
||||||
|
|
||||||
|
class BeamSearch(torch.nn.Module):
|
||||||
|
"""Beam search implementation."""
|
||||||
|
|
||||||
|
def __init__(
    self,
    scorers: Dict[str, ScorerInterface],
    weights: Dict[str, float],
    beam_size: int,
    vocab_size: int,
    sos=None,
    eos=None,
    token_list: List[str] = None,
    pre_beam_ratio: float = 1.5,
    pre_beam_score_key: str = None,
):
    """Set up the scorers and search configuration.

    Args:
        scorers (dict[str, ScorerInterface]): scorer objects by name.
            A scorer is dropped when it is ``None``.
        weights (dict[str, float]): weight of each scorer by name.
            A scorer is dropped when its weight is missing or 0.
        beam_size (int): number of hypotheses kept during the search.
        vocab_size (int): size of the vocabulary.
        sos: start-of-sequence id(s); stored as given.
        eos: end-of-sequence id; when a list/tuple is given only its
            first element is kept.
        token_list (list[str]): token strings, stored for later use.
        pre_beam_ratio (float): the pre-beam size is
            ``int(pre_beam_ratio * beam_size)``.
        pre_beam_score_key (str): score key used for the pre-beam;
            must be ``None``, ``"full"`` or the name of a full scorer.

    Raises:
        KeyError: if ``pre_beam_score_key`` names a scorer that is not
            among the retained full scorers.
    """
    super().__init__()

    self.weights = weights
    self.scorers = {}
    self.full_scorers = {}
    self.part_scorers = {}
    # Registering torch modules here lets `.to(device, dtype)` on the search
    # object reach the scorers as well.
    self.nn_dict = torch.nn.ModuleDict()

    for name, scorer in scorers.items():
        weight = weights.get(name, 0)
        if weight == 0 or scorer is None:
            continue
        self.scorers[name] = scorer
        # Partial scorers and full scorers are tracked separately.
        if isinstance(scorer, PartialScorerInterface):
            bucket = self.part_scorers
        else:
            bucket = self.full_scorers
        bucket[name] = scorer
        if isinstance(scorer, torch.nn.Module):
            self.nn_dict[name] = scorer

    self.sos = sos
    # A list/tuple eos is reduced to its first element.
    self.eos = eos[0] if isinstance(eos, (list, tuple)) else eos
    self.token_list = token_list
    self.pre_beam_size = int(pre_beam_ratio * beam_size)
    self.beam_size = beam_size
    self.n_vocab = vocab_size

    names_a_scorer = pre_beam_score_key is not None and pre_beam_score_key != "full"
    if names_a_scorer and pre_beam_score_key not in self.full_scorers:
        raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
    self.pre_beam_score_key = pre_beam_score_key

    # Pre-beam is only worthwhile when it actually prunes the vocabulary and
    # there is at least one partial scorer to benefit from it.
    has_key = self.pre_beam_score_key is not None
    prunes_vocab = self.pre_beam_size < self.n_vocab
    has_partial = len(self.part_scorers) > 0
    self.do_pre_beam = has_key and prunes_vocab and has_partial
|
||||||
|
|
||||||
|
def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
    """Build the starting hypothesis list for a search.

    Every retained scorer is asked for its initial state and starts with
    a score of ``0.0``. ``self.sos`` is wrapped into a list (and stored
    back on ``self``) when it is not already a list or tuple.

    Args:
        x (torch.Tensor): the encoder output feature.

    Returns:
        List[Hypothesis]: a one-element list whose ``yseq`` holds the
        ``sos`` token(s) on the device of ``x``.
    """
    init_states = {name: scorer.init_state(x) for name, scorer in self.scorers.items()}
    init_scores = dict.fromkeys(self.scorers, 0.0)

    if not isinstance(self.sos, (list, tuple)):
        self.sos = [self.sos]

    start_tokens = torch.tensor(self.sos, device=x.device)
    first = Hypothesis(
        yseq=start_tokens,
        score=0.0,
        scores=init_scores,
        states=init_states,
    )
    return [first]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
|
||||||
|
"""Append new token to prefix tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xs (torch.Tensor): The prefix token
|
||||||
|
x (int): The new token to append
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
|
||||||
|
|
||||||
|
"""
|
||||||
|
x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
|
||||||
|
return torch.cat((xs, x))
|
||||||
|
|
||||||
|
def score_full(
    self, hyp: Hypothesis, x: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """Run every full scorer on a hypothesis.

    Args:
        hyp (Hypothesis): hypothesis whose prefix tokens are scored.
        x (torch.Tensor): corresponding input feature.

    Returns:
        Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: the scores and the
        new states returned by each scorer, both keyed by the names in
        ``self.full_scorers``.
    """
    scores = {}
    states = {}
    for name, scorer in self.full_scorers.items():
        token_scores, new_state = scorer.score(hyp.yseq, hyp.states[name], x)
        scores[name] = token_scores
        states[name] = new_state
    return scores, states
|
||||||
|
|
||||||
|
def score_partial(
    self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """Run every partial scorer on a hypothesis for a subset of tokens.

    Args:
        hyp (Hypothesis): hypothesis whose prefix tokens are scored.
        ids (torch.Tensor): 1D tensor of candidate token ids to score.
        x (torch.Tensor): corresponding input feature.

    Returns:
        Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: the scores and the
        new states returned by each scorer, both keyed by the names in
        ``self.part_scorers``.
    """
    scores = {}
    states = {}
    for name, scorer in self.part_scorers.items():
        token_scores, new_state = scorer.score_partial(hyp.yseq, ids, hyp.states[name], x)
        scores[name] = token_scores
        states[name] = new_state
    return scores, states
|
||||||
|
|
||||||
|
def beam(
    self, weighted_scores: torch.Tensor, ids: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the ``self.beam_size`` best tokens, globally and within ``ids``.

    Args:
        weighted_scores (torch.Tensor): weighted sum of scores per token.
            When a pre-beam was applied it is modified in place: every
            entry not listed in ``ids`` is set to ``-inf``.
        ids (torch.Tensor): candidate token ids kept by the pre-beam.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: indices of the best tokens in
        ``weighted_scores`` and the matching positions inside ``ids``.
        Both are the same tensor when no pre-beam was performed.
    """
    width = self.beam_size

    # Same length means nothing was pruned, so global and local ids coincide.
    if ids.size(0) == weighted_scores.size(0):
        best = weighted_scores.topk(width).indices
        return best, best

    # Keep only the pre-beam candidates; everything else can never win topk.
    kept = weighted_scores[ids]
    weighted_scores[:] = float("-inf")
    weighted_scores[ids] = kept

    global_best = weighted_scores.topk(width).indices
    local_best = weighted_scores[ids].topk(width).indices
    return global_best, local_best
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge_scores(
|
||||||
|
prev_scores: Dict[str, float],
|
||||||
|
next_full_scores: Dict[str, torch.Tensor],
|
||||||
|
full_idx: int,
|
||||||
|
next_part_scores: Dict[str, torch.Tensor],
|
||||||
|
part_idx: int,
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
"""Merge scores for new hypothesis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prev_scores (Dict[str, float]):
|
||||||
|
The previous hypothesis scores by `self.scorers`
|
||||||
|
next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
|
||||||
|
full_idx (int): The next token id for `next_full_scores`
|
||||||
|
next_part_scores (Dict[str, torch.Tensor]):
|
||||||
|
scores of partial tokens by `self.part_scorers`
|
||||||
|
part_idx (int): The new token id for `next_part_scores`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, torch.Tensor]: The new score dict.
|
||||||
|
Its keys are names of `self.full_scorers` and `self.part_scorers`.
|
||||||
|
Its values are scalar tensors by the scorers.
|
||||||
|
|
||||||
|
"""
|
||||||
|
new_scores = dict()
|
||||||
|
for k, v in next_full_scores.items():
|
||||||
|
new_scores[k] = prev_scores[k] + v[full_idx]
|
||||||
|
for k, v in next_part_scores.items():
|
||||||
|
new_scores[k] = prev_scores[k] + v[part_idx]
|
||||||
|
return new_scores
|
||||||
|
|
||||||
|
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
    """Build the state dict for a newly extended hypothesis.

    Args:
        states: states of `self.full_scorers`, copied over as they are
        part_states: states of `self.part_scorers`
        part_idx (int): The new token id for `part_scores`; it is handed to
            each partial scorer's `select_state` to pick that scorer's state

    Returns:
        Dict[str, Any]: The new state dict.
            Its keys are the keys of `states` plus the names in `self.part_scorers`.
            Its values are states of the scorers.

    """
    merged = {name: state for name, state in states.items()}
    for name, scorer in self.part_scorers.items():
        merged[name] = scorer.select_state(part_states[name], part_idx)
    return merged
|
||||||
|
|
||||||
|
def search(self, running_hyps: List[Hypothesis], x: torch.Tensor) -> List[Hypothesis]:
    """Extend every running hypothesis by one token and keep the best ones.

    Args:
        running_hyps (List[Hypothesis]): Running hypotheses on beam
        x (torch.Tensor): Encoded speech feature (T, D)

    Returns:
        List[Hypotheses]: Best sorted hypotheses

    """
    candidates = []
    # Without a pre-beam, every vocabulary entry is scored by the partial scorers.
    part_ids = torch.arange(self.n_vocab, device=x.device)
    for hyp in running_hyps:
        # Weighted sum of the full-scorer outputs over the whole vocabulary.
        total = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
        full_scores, full_states = self.score_full(hyp, x)
        for name in self.full_scorers:
            total += self.weights[name] * full_scores[name]

        # Optionally narrow down the tokens handed to the partial scorers.
        if self.do_pre_beam:
            if self.pre_beam_score_key == "full":
                pre_beam_scores = total
            else:
                pre_beam_scores = full_scores[self.pre_beam_score_key]
            part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
        part_scores, part_states = self.score_partial(hyp, part_ids, x)
        for name in self.part_scorers:
            total[part_ids] += self.weights[name] * part_scores[name]

        # Turn the per-token scores into accumulated hypothesis scores.
        total += hyp.score

        # Create one new hypothesis per selected token (2 x beam at most in the list).
        full_ids, local_ids = self.beam(total, part_ids)
        for j, part_j in zip(full_ids, local_ids):
            candidates.append(
                Hypothesis(
                    score=total[j],
                    yseq=self.append_token(hyp.yseq, j),
                    scores=self.merge_scores(hyp.scores, full_scores, j, part_scores, part_j),
                    states=self.merge_states(full_states, part_states, part_j),
                )
            )

        # Sort and prune 2 x beam -> beam.
        ranked = sorted(candidates, key=lambda h: h.score, reverse=True)
        candidates = ranked[: min(len(ranked), self.beam_size)]
    return candidates
|
||||||
|
|
||||||
|
def forward(
    self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
) -> List[Hypothesis]:
    """Run beam search over the encoded input.

    Args:
        x (torch.Tensor): Encoded speech feature (T, D)
        maxlenratio (float): Input length ratio to obtain max output length.
            If maxlenratio=0.0 (default), the max length is the input length and
            an end-detect function is used to stop the search early.
            If maxlenratio<0.0, its absolute value is interpreted
            as a constant max output length.
        minlenratio (float): Input length ratio to obtain min output length.
            In this method the resulting min length is only logged; the ratio
            is lowered by 0.1 on each retry when no hypothesis has ended.

    Returns:
        list[Hypothesis]: N-best decoding results, sorted by descending score.
            Empty when no hypothesis ended and `minlenratio` is below 0.1.

    """
    # set length bounds
    n_frames = x.size(0)
    if maxlenratio == 0:
        maxlen = x.shape[0]
    elif maxlenratio < 0:
        maxlen = -int(maxlenratio)
    else:
        maxlen = max(1, int(maxlenratio * n_frames))
    minlen = int(minlenratio * n_frames)
    logging.info(f"decoder input length: {x.shape[0]}")
    logging.info(f"max output length: {maxlen}")
    logging.info(f"min output length: {minlen}")

    # main loop of prefix search
    running_hyps = self.init_hyp(x)
    ended_hyps = []
    for i in range(maxlen):
        logging.debug(f"position {i}")
        best = self.search(running_hyps, x)
        # post_process appends finished hypotheses to ended_hyps and returns the rest
        running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
        # end detection
        if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
            logging.info(f"end detected at {i}")
            break
        if not running_hyps:
            logging.info("no hypothesis. Finish decoding.")
            break
        logging.debug(f"remained hypotheses: {len(running_hyps)}")

    nbest_hyps = sorted(ended_hyps, key=lambda h: h.score, reverse=True)
    # check the number of hypotheses reaching to eos
    if not nbest_hyps:
        logging.warning(
            "there is no N-best results, perform recognition again with smaller minlenratio."
        )
        if minlenratio < 0.1:
            return []
        return self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))

    # report the best result
    top = nbest_hyps[0]
    for name, value in top.scores.items():
        weight = self.weights[name]
        logging.info(f"{value:6.2f} * {weight:3} = {value * weight:6.2f} for {name}")
    logging.info(f"total log probability: {top.score:.2f}")
    logging.info(f"normalized log probability: {top.score / len(top.yseq):.2f}")
    logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
    if self.token_list is not None:
        # drop the first and last token of the sequence before printing
        tokens = [self.token_list[t] for t in top.yseq[1:-1]]
        logging.info("best hypo: " + "".join(tokens) + "\n")
    return nbest_hyps
|
||||||
|
|
||||||
|
def post_process(
    self,
    i: int,
    maxlen: int,
    maxlenratio: float,
    running_hyps: List[Hypothesis],
    ended_hyps: List[Hypothesis],
) -> List[Hypothesis]:
    """Perform post-processing of beam search iterations.

    Hypotheses whose last token is `self.eos` receive the final score of every
    scorer and are appended to `ended_hyps` (modified in place); all other
    hypotheses are returned to continue the search.

    Args:
        i (int): The length of hypothesis tokens.
        maxlen (int): The maximum length of tokens in beam search.
        maxlenratio (float): The maximum length ratio in beam search.
            It is not used by this implementation.
        running_hyps (List[Hypothesis]): The running hypotheses in beam search.
        ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

    Returns:
        List[Hypothesis]: The new running hypotheses.

    """
    logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
    # Only index the first hypothesis when there is one: an empty beam must not
    # raise IndexError just to emit a debug message.
    if self.token_list is not None and running_hyps:
        logging.debug(
            "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
        )
    # add eos in the final loop to avoid that there are no ended hyps
    if i == maxlen - 1:
        logging.info("adding <eos> in the last position in the loop")
        running_hyps = [
            h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
        ]

    # add ended hypotheses to a final list, and removed them from current hypotheses
    # (this will be a problem, number of hyps < beam)
    remained_hyps = []
    for hyp in running_hyps:
        if hyp.yseq[-1] == self.eos:
            # e.g., Word LM needs to add final <eos> score
            for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                s = d.final_score(hyp.states[k])
                # the per-scorer dict is updated in place, the total via a new tuple
                hyp.scores[k] += s
                hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
            ended_hyps.append(hyp)
        else:
            remained_hyps.append(hyp)
    return remained_hyps
|
||||||
@ -308,6 +308,7 @@ class Trainer:
|
|||||||
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
|
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
|
||||||
)
|
)
|
||||||
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
|
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
|
||||||
|
print(checkpoint["train_acc_avg"])
|
||||||
self.train_acc_avg = (
|
self.train_acc_avg = (
|
||||||
checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
|
checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
|
||||||
)
|
)
|
||||||
@ -464,7 +465,8 @@ class Trainer:
|
|||||||
batch_num_epoch = len(dataloader_train)
|
batch_num_epoch = len(dataloader_train)
|
||||||
self.log(
|
self.log(
|
||||||
epoch,
|
epoch,
|
||||||
batch_idx + kwargs.get("start_step", 0),
|
batch_idx,
|
||||||
|
log_step=batch_idx + kwargs.get("start_step", 0),
|
||||||
step_in_epoch=self.step_in_epoch,
|
step_in_epoch=self.step_in_epoch,
|
||||||
batch_num_epoch=batch_num_epoch,
|
batch_num_epoch=batch_num_epoch,
|
||||||
lr=lr,
|
lr=lr,
|
||||||
@ -633,11 +635,12 @@ class Trainer:
|
|||||||
tag="train",
|
tag="train",
|
||||||
data_split_i=0,
|
data_split_i=0,
|
||||||
data_split_num=1,
|
data_split_num=1,
|
||||||
|
log_step=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
if (batch_idx + 1) % self.log_interval == 0:
|
if (batch_idx + 1) % self.log_interval == 0:
|
||||||
|
batch_idx = log_step if log_step is not None else batch_idx
|
||||||
gpu_info = (
|
gpu_info = (
|
||||||
"GPU, memory: usage: {:.3f} GB, "
|
"GPU, memory: usage: {:.3f} GB, "
|
||||||
"peak: {:.3f} GB, "
|
"peak: {:.3f} GB, "
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user