diff --git a/examples/industrial_data_pretraining/conformer/demo.py b/examples/industrial_data_pretraining/conformer/demo.py index 43cf67de6..c2d7682bb 100644 --- a/examples/industrial_data_pretraining/conformer/demo.py +++ b/examples/industrial_data_pretraining/conformer/demo.py @@ -8,6 +8,7 @@ from funasr import AutoModel model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch") res = model.generate( - input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav" + input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav", + decoding_ctc_weight=0.0, ) print(res) diff --git a/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py b/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py new file mode 100644 index 000000000..b94cba7f4 --- /dev/null +++ b/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +from funasr import AutoModel + +model = AutoModel( + model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscopeFSMN", + vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", + vad_kwargs={"max_single_segment_time": 30000}, +) + + +input_wav = ( + "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" +) + +DecodingOptions = { + "task": ("ASR", "AED", "SER"), + "language": "auto", + "fp16": True, + "gain_event": True, +} + +res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions) +print(res) diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py index f5b882595..8a4a2ce6a 100644 --- a/funasr/models/sense_voice/decoder.py +++ b/funasr/models/sense_voice/decoder.py @@ -15,6 +15,7 @@ import numpy as np import torch import torch.nn.functional as F from torch import Tensor, nn +from funasr.models.transformer.utils.mask import subsequent_mask class LayerNorm(nn.LayerNorm): @@ -443,9 +444,19 @@ class ResidualAttentionBlockFSMN(nn.Module): kv_cache: Optional[dict] = None, **kwargs, ): + cache = kwargs.get("cache", {}) + layer = kwargs.get("layer", 0) is_pad_mask = kwargs.get("is_pad_mask", False) is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False) - x = x + self.attn(self.attn_ln(x), mask=None, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0] + + fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 else None + # if fsmn_cache is not None: + # x = x[:, -1:] + att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache) + # if len(cache)>1: + # cache[layer]["fsmn_cache"] = fsmn_cache + # x = x[:, -1:] + x = x + att_res if self.cross_attn: x = ( x @@ -510,10 +521,9 @@ class SenseVoiceDecoderFSMN(nn.Module): ys_in_lens = kwargs.get("ys_in_lens", None) - offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 tgt, memory = x, xa tgt[tgt == -1] = 0 - tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)] + tgt = self.token_embedding(tgt) + self.positional_embedding[: tgt.size(1)] # tgt = self.dropout(tgt) x = tgt.to(memory.dtype) @@ -531,9 +541,40 @@ class SenseVoiceDecoderFSMN(nn.Module): memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True, + cache=kwargs.get("cache", None), + layer=layer, ) x = self.ln(x) x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() return x + + def init_state(self, x): + state = {} + for layer, block in enumerate(self.blocks): + state[layer] = { + "fsmn_cache": None, + "memory_key": None, + "memory_value": None, + } + + return state + + def final_score(self, state) -> float: + """Score eos (optional). + + Args: + state: Scorer state for prefix tokens + + Returns: + float: final score + + """ + return 0.0 + + def score(self, ys, state, x): + """Score.""" + ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) + logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state) + return logp.squeeze(0)[-1, :], state diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py index c12107e7a..82ccc55e1 100644 --- a/funasr/models/sense_voice/model.py +++ b/funasr/models/sense_voice/model.py @@ -15,6 +15,7 @@ from funasr.losses.label_smoothing_loss import LabelSmoothingLoss from funasr.train_utils.device_funcs import force_gatherable from . import whisper_lib as whisper from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +from funasr.utils.datadir_writer import DatadirWriter from funasr.register import tables @@ -497,12 +498,14 @@ class SenseVoiceFSMN(nn.Module): # decoder del model.decoder decoder = kwargs.get("decoder", "SenseVoiceDecoder") - decoder_conf = kwargs.get("decoder_conf", {}) decoder_class = tables.decoder_classes.get(decoder) decoder = decoder_class( - vocab_size=dims.n_vocab, - encoder_output_size=dims.n_audio_state, - **decoder_conf, + n_vocab=dims.n_vocab, + n_ctx=dims.n_text_ctx, + n_state=dims.n_text_state, + n_head=dims.n_text_head, + n_layer=dims.n_text_layer, + **kwargs.get("decoder_conf"), ) model.decoder = decoder @@ -512,7 +515,7 @@ class SenseVoiceFSMN(nn.Module): self.activation_checkpoint = kwargs.get("activation_checkpoint", False) self.ignore_id = kwargs.get("ignore_id", -1) - self.vocab_size = kwargs.get("vocab_size", -1) + self.vocab_size = dims.n_vocab self.length_normalized_loss = kwargs.get("length_normalized_loss", True) self.criterion_att = LabelSmoothingLoss( size=self.vocab_size, @@ -630,6 +633,42 @@ class SenseVoiceFSMN(nn.Module): return loss_att, acc_att, None, None + def init_beam_search( + self, + **kwargs, + ): + from .search import BeamSearch + + from funasr.models.transformer.scorers.length_bonus import LengthBonus + + # 1. Build ASR model + scorers = {} + + scorers.update( + decoder=self.model.decoder, + length_bonus=LengthBonus(self.vocab_size), + ) + + weights = dict( + decoder=1.0, + ctc=0.0, + lm=0.0, + ngram=0.0, + length_bonus=kwargs.get("penalty", 0.0), + ) + beam_search = BeamSearch( + beam_size=kwargs.get("beam_size", 5), + weights=weights, + scorers=scorers, + sos=None, + eos=None, + vocab_size=self.vocab_size, + token_list=None, + pre_beam_score_key="full", + ) + + self.beam_search = beam_search + def inference( self, data_in, @@ -642,6 +681,12 @@ class SenseVoiceFSMN(nn.Module): if kwargs.get("batch_size", 1) > 1: raise NotImplementedError("batch decoding is not implemented") + # init beamsearch + if not hasattr(self, "beam_search") or self.beam_search is None: + logging.info("enable beam_search") + self.init_beam_search(**kwargs) + self.nbest = kwargs.get("nbest", 1) + if frontend is None and not hasattr(self, "frontend"): frontend_class = tables.frontend_classes.get("WhisperFrontend") frontend = frontend_class( @@ -690,24 +735,64 @@ class SenseVoiceFSMN(nn.Module): task = [task] task = "".join([f"<|{x}|>" for x in task]) initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}") - DecodingOptions["initial_prompt"] = initial_prompt language = DecodingOptions.get("language", None) language = None if language == "auto" else language - DecodingOptions["language"] = language - DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None) + sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt + sos_int = tokenizer.encode(sos, allowed_special="all") + eos = kwargs.get("model_conf").get("eos") + eos_int = tokenizer.encode(eos, allowed_special="all") + self.beam_search.sos = sos_int + self.beam_search.eos = eos_int[0] - if "without_timestamps" not in DecodingOptions: - DecodingOptions["without_timestamps"] = True + encoder_out, encoder_out_lens = self.encode( + speech[None, :, :].permute(0, 2, 1), speech_lengths + ) - options = whisper.DecodingOptions(**DecodingOptions) + # c. Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + x=encoder_out[0], + maxlenratio=kwargs.get("maxlenratio", 0.0), + minlenratio=kwargs.get("minlenratio", 0.0), + ) + + nbest_hyps = nbest_hyps[: self.nbest] - result = whisper.decode(self.model, speech, options) - text = f"{result.text}" results = [] - result_i = {"key": key[0], "text": text} + b, n, d = encoder_out.size() + for i in range(b): - results.append(result_i) + for nbest_idx, hyp in enumerate(nbest_hyps): + ibest_writer = None + if kwargs.get("output_dir") is not None: + if not hasattr(self, "writer"): + self.writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # # remove blank symbol id, which is assumed to be 0 + # token_int = list( + # filter( + # lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int + # ) + # ) + + # Change integer-ids to tokens + # token = tokenizer.ids2tokens(token_int) + text = tokenizer.decode(token_int) + + result_i = {"key": key[i], "text": text} + results.append(result_i) + + if ibest_writer is not None: + # ibest_writer["token"][key[i]] = " ".join(token) + ibest_writer["text"][key[i]] = text return results, meta_data diff --git a/funasr/models/sense_voice/search.py b/funasr/models/sense_voice/search.py new file mode 100644 index 000000000..98d02db9d --- /dev/null +++ b/funasr/models/sense_voice/search.py @@ -0,0 +1,453 @@ +from itertools import chain +import logging +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Tuple +from typing import Union + +import torch + +from funasr.metrics.common import end_detect +from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface +from funasr.models.transformer.scorers.scorer_interface import ScorerInterface + + +class Hypothesis(NamedTuple): + """Hypothesis data type.""" + + yseq: torch.Tensor + score: Union[float, torch.Tensor] = 0 + scores: Dict[str, Union[float, torch.Tensor]] = dict() + states: Dict[str, Any] = dict() + + def asdict(self) -> dict: + """Convert data to JSON-friendly dict.""" + return self._replace( + yseq=self.yseq.tolist(), + score=float(self.score), + scores={k: float(v) for k, v in self.scores.items()}, + )._asdict() + + +class BeamSearch(torch.nn.Module): + """Beam search implementation.""" + + def __init__( + self, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + beam_size: int, + vocab_size: int, + sos=None, + eos=None, + token_list: List[str] = None, + pre_beam_ratio: float = 1.5, + pre_beam_score_key: str = None, + ): + """Initialize beam search. + + Args: + scorers (dict[str, ScorerInterface]): Dict of decoder modules + e.g., Decoder, CTCPrefixScorer, LM + The scorer will be ignored if it is `None` + weights (dict[str, float]): Dict of weights for each scorers + The scorer will be ignored if its weight is 0 + beam_size (int): The number of hypotheses kept during search + vocab_size (int): The number of vocabulary + sos (int): Start of sequence id + eos (int): End of sequence id + token_list (list[str]): List of tokens for debug log + pre_beam_score_key (str): key of scores to perform pre-beam search + pre_beam_ratio (float): beam size in the pre-beam search + will be `int(pre_beam_ratio * beam_size)` + + """ + super().__init__() + # set scorers + self.weights = weights + self.scorers = dict() + self.full_scorers = dict() + self.part_scorers = dict() + # this module dict is required for recursive cast + # `self.to(device, dtype)` in `recog.py` + self.nn_dict = torch.nn.ModuleDict() + for k, v in scorers.items(): + w = weights.get(k, 0) + if w == 0 or v is None: + continue + # assert isinstance( + # v, ScorerInterface + # ), f"{k} ({type(v)}) does not implement ScorerInterface" + self.scorers[k] = v + if isinstance(v, PartialScorerInterface): + self.part_scorers[k] = v + else: + self.full_scorers[k] = v + if isinstance(v, torch.nn.Module): + self.nn_dict[k] = v + + # set configurations + self.sos = sos + self.eos = eos + if isinstance(self.eos, (list, tuple)): + self.eos = eos[0] + self.token_list = token_list + self.pre_beam_size = int(pre_beam_ratio * beam_size) + self.beam_size = beam_size + self.n_vocab = vocab_size + if ( + pre_beam_score_key is not None + and pre_beam_score_key != "full" + and pre_beam_score_key not in self.full_scorers + ): + raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}") + self.pre_beam_score_key = pre_beam_score_key + self.do_pre_beam = ( + self.pre_beam_score_key is not None + and self.pre_beam_size < self.n_vocab + and len(self.part_scorers) > 0 + ) + + def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]: + """Get an initial hypothesis data. + + Args: + x (torch.Tensor): The encoder output feature + + Returns: + Hypothesis: The initial hypothesis. + + """ + init_states = dict() + init_scores = dict() + for k, d in self.scorers.items(): + init_states[k] = d.init_state(x) + init_scores[k] = 0.0 + if not isinstance(self.sos, (list, tuple)): + self.sos = [self.sos] + return [ + Hypothesis( + score=0.0, + scores=init_scores, + states=init_states, + yseq=torch.tensor(self.sos, device=x.device), + ) + ] + + @staticmethod + def append_token(xs: torch.Tensor, x: int) -> torch.Tensor: + """Append new token to prefix tokens. + + Args: + xs (torch.Tensor): The prefix token + x (int): The new token to append + + Returns: + torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device + + """ + x = torch.tensor([x], dtype=xs.dtype, device=xs.device) + return torch.cat((xs, x)) + + def score_full( + self, hyp: Hypothesis, x: torch.Tensor + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.full_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.full_scorers` + and tensor score values of shape: `(self.n_vocab,)`, + and state dict that has string keys + and state values of `self.full_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.full_scorers.items(): + scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x) + return scores, states + + def score_partial( + self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.part_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + ids (torch.Tensor): 1D tensor of new partial tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.part_scorers` + and tensor score values of shape: `(len(ids),)`, + and state dict that has string keys + and state values of `self.part_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.part_scorers.items(): + scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x) + return scores, states + + def beam( + self, weighted_scores: torch.Tensor, ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute topk full token ids and partial token ids. + + Args: + weighted_scores (torch.Tensor): The weighted sum scores for each tokens. + Its shape is `(self.n_vocab,)`. + ids (torch.Tensor): The partial token ids to compute topk + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + The topk full token ids and partial token ids. + Their shapes are `(self.beam_size,)` + + """ + # no pre beam performed + if weighted_scores.size(0) == ids.size(0): + top_ids = weighted_scores.topk(self.beam_size)[1] + return top_ids, top_ids + + # mask pruned in pre-beam not to select in topk + tmp = weighted_scores[ids] + weighted_scores[:] = -float("inf") + weighted_scores[ids] = tmp + top_ids = weighted_scores.topk(self.beam_size)[1] + local_ids = weighted_scores[ids].topk(self.beam_size)[1] + return top_ids, local_ids + + @staticmethod + def merge_scores( + prev_scores: Dict[str, float], + next_full_scores: Dict[str, torch.Tensor], + full_idx: int, + next_part_scores: Dict[str, torch.Tensor], + part_idx: int, + ) -> Dict[str, torch.Tensor]: + """Merge scores for new hypothesis. + + Args: + prev_scores (Dict[str, float]): + The previous hypothesis scores by `self.scorers` + next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers` + full_idx (int): The next token id for `next_full_scores` + next_part_scores (Dict[str, torch.Tensor]): + scores of partial tokens by `self.part_scorers` + part_idx (int): The new token id for `next_part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are scalar tensors by the scorers. + + """ + new_scores = dict() + for k, v in next_full_scores.items(): + new_scores[k] = prev_scores[k] + v[full_idx] + for k, v in next_part_scores.items(): + new_scores[k] = prev_scores[k] + v[part_idx] + return new_scores + + def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any: + """Merge states for new hypothesis. + + Args: + states: states of `self.full_scorers` + part_states: states of `self.part_scorers` + part_idx (int): The new token id for `part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are states of the scorers. + + """ + new_states = dict() + for k, v in states.items(): + new_states[k] = v + for k, d in self.part_scorers.items(): + new_states[k] = d.select_state(part_states[k], part_idx) + return new_states + + def search(self, running_hyps: List[Hypothesis], x: torch.Tensor) -> List[Hypothesis]: + """Search new tokens for running hypotheses and encoded speech x. + + Args: + running_hyps (List[Hypothesis]): Running hypotheses on beam + x (torch.Tensor): Encoded speech feature (T, D) + + Returns: + List[Hypotheses]: Best sorted hypotheses + + """ + best_hyps = [] + part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam + for hyp in running_hyps: + # scoring + weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device) + scores, states = self.score_full(hyp, x) + for k in self.full_scorers: + weighted_scores += self.weights[k] * scores[k] + # partial scoring + if self.do_pre_beam: + pre_beam_scores = ( + weighted_scores + if self.pre_beam_score_key == "full" + else scores[self.pre_beam_score_key] + ) + part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1] + part_scores, part_states = self.score_partial(hyp, part_ids, x) + for k in self.part_scorers: + weighted_scores[part_ids] += self.weights[k] * part_scores[k] + # add previous hyp score + weighted_scores += hyp.score + + # update hyps + for j, part_j in zip(*self.beam(weighted_scores, part_ids)): + # will be (2 x beam at most) + best_hyps.append( + Hypothesis( + score=weighted_scores[j], + yseq=self.append_token(hyp.yseq, j), + scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j), + states=self.merge_states(states, part_states, part_j), + ) + ) + + # sort and prune 2 x beam -> beam + best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[ + : min(len(best_hyps), self.beam_size) + ] + return best_hyps + + def forward( + self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0 + ) -> List[Hypothesis]: + """Perform beam search. + + Args: + x (torch.Tensor): Encoded speech feature (T, D) + maxlenratio (float): Input length ratio to obtain max output length. + If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths + If maxlenratio<0.0, its absolute value is interpreted + as a constant max output length. + minlenratio (float): Input length ratio to obtain min output length. + + Returns: + list[Hypothesis]: N-best decoding results + + """ + # set length bounds + if maxlenratio == 0: + maxlen = x.shape[0] + elif maxlenratio < 0: + maxlen = -1 * int(maxlenratio) + else: + maxlen = max(1, int(maxlenratio * x.size(0))) + minlen = int(minlenratio * x.size(0)) + logging.info("decoder input length: " + str(x.shape[0])) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # main loop of prefix search + running_hyps = self.init_hyp(x) + ended_hyps = [] + for i in range(maxlen): + logging.debug("position " + str(i)) + best = self.search(running_hyps, x) + # post process of one iteration + running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) + # end detection + if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i): + logging.info(f"end detected at {i}") + break + if len(running_hyps) == 0: + logging.info("no hypothesis. Finish decoding.") + break + else: + logging.debug(f"remained hypotheses: {len(running_hyps)}") + + nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True) + # check the number of hypotheses reaching to eos + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, perform recognition " "again with smaller minlenratio." + ) + return ( + [] + if minlenratio < 0.1 + else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1)) + ) + + # report the best result + best = nbest_hyps[0] + for k, v in best.scores.items(): + logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}") + logging.info(f"total log probability: {best.score:.2f}") + logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}") + logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}") + if self.token_list is not None: + logging.info( + "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n" + ) + return nbest_hyps + + def post_process( + self, + i: int, + maxlen: int, + maxlenratio: float, + running_hyps: List[Hypothesis], + ended_hyps: List[Hypothesis], + ) -> List[Hypothesis]: + """Perform post-processing of beam search iterations. + + Args: + i (int): The length of hypothesis tokens. + maxlen (int): The maximum length of tokens in beam search. + maxlenratio (int): The maximum length ratio in beam search. + running_hyps (List[Hypothesis]): The running hypotheses in beam search. + ended_hyps (List[Hypothesis]): The ended hypotheses in beam search. + + Returns: + List[Hypothesis]: The new running hypotheses. + + """ + logging.debug(f"the number of running hypotheses: {len(running_hyps)}") + if self.token_list is not None: + logging.debug( + "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]) + ) + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last position in the loop") + running_hyps = [ + h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps + ] + + # add ended hypotheses to a final list, and removed them from current hypotheses + # (this will be a problem, number of hyps < beam) + remained_hyps = [] + for hyp in running_hyps: + if hyp.yseq[-1] == self.eos: + # e.g., Word LM needs to add final score + for k, d in chain(self.full_scorers.items(), self.part_scorers.items()): + s = d.final_score(hyp.states[k]) + hyp.scores[k] += s + hyp = hyp._replace(score=hyp.score + self.weights[k] * s) + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + return remained_hyps