Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
Funasr1.0 (#1297)
* fix add_file bug (#1296)

Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* funasr1.0 uniasr

* funasr1.0 uniasr

---------

Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
parent 09372cb279
commit e4035edb46
examples/industrial_data_pretraining/uniasr/demo.py (new file, 29 lines)
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

from funasr import AutoModel

model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online", model_revision="v2.0.4",
                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  # vad_model_revision="v2.0.4",
                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  # punc_model_revision="v2.0.4",
                  )

res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)


''' can not use currently
from funasr import AutoFrontend

frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")

fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)

for batch_idx, fbank_dict in enumerate(fbanks):
    res = model.generate(**fbank_dict)
    print(res)
'''
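The commented-out arguments in demo.py hint at the full pipeline. The following is an editor's sketch, not part of this commit: the ModelScope model id is an assumption inferred from the local path above, and the VAD/punctuation ids and revisions are taken verbatim from the commented-out lines.

from funasr import AutoModel

# Hypothetical configuration: all model ids and revisions below are assumptions.
model = AutoModel(
    model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online",
    model_revision="v2.0.4",
    vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_model_revision="v2.0.4",
    punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    punc_model_revision="v2.0.4",
)
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)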
examples/industrial_data_pretraining/uniasr/infer.sh (new file, 11 lines)
@@ -0,0 +1,11 @@

model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_revision="v2.0.4"

python funasr/bin/inference.py \
+model=${model} \
+model_revision=${model_revision} \
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
+output_dir="./outputs/debug" \
+device="cpu" \

@@ -348,7 +348,7 @@ class Transformer(nn.Module):
        scorers["ngram"] = ngram

        weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
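The one-line change above fixes a crash rather than altering behaviour: dict.get("decoding_ctc_weight") returns None when the key is absent, and subtracting None from 1.0 raises a TypeError, so decoding failed whenever no CTC weight was configured. A minimal sketch of the failure mode and the fix (editor's illustration, not repository code):

kwargs = {}  # no decoding_ctc_weight configured

# Old form: kwargs.get("decoding_ctc_weight") is None, so this raises
#   TypeError: unsupported operand type(s) for -: 'float' and 'NoneType'
# decoder_weight = 1.0 - kwargs.get("decoding_ctc_weight")

# Fixed form: falls back to 0.0, giving the attention decoder the full weight of 1.0
decoder_weight = 1.0 - kwargs.get("decoding_ctc_weight", 0.0)
print(decoder_weight)  # 1.0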
funasr/models/uniasr/beam_search.py (new file, 496 lines)
@@ -0,0 +1,496 @@
|
||||
"""Beam search module."""
|
||||
|
||||
from itertools import chain
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import NamedTuple
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from funasr.metrics.common import end_detect
|
||||
from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
|
||||
from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
|
||||
|
||||
|
||||
class Hypothesis(NamedTuple):
|
||||
"""Hypothesis data type."""
|
||||
|
||||
yseq: torch.Tensor
|
||||
score: Union[float, torch.Tensor] = 0
|
||||
scores: Dict[str, Union[float, torch.Tensor]] = dict()
|
||||
states: Dict[str, Any] = dict()
|
||||
|
||||
def asdict(self) -> dict:
|
||||
"""Convert data to JSON-friendly dict."""
|
||||
return self._replace(
|
||||
yseq=self.yseq.tolist(),
|
||||
score=float(self.score),
|
||||
scores={k: float(v) for k, v in self.scores.items()},
|
||||
)._asdict()
|
||||
|
||||
|
||||
|
||||
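Editor's illustration (not part of the committed file): a Hypothesis keeps the token prefix as a tensor together with the accumulated score, the per-scorer scores, and the per-scorer states, and asdict() flattens it into a JSON-friendly dict:

import torch
from funasr.models.uniasr.beam_search import Hypothesis

hyp = Hypothesis(
    yseq=torch.tensor([1, 42, 7, 2]),        # <sos>, two tokens, <eos>
    score=torch.tensor(-3.2),
    scores={"decoder": -3.0, "ctc": -0.2},
    states={},                               # per-scorer decoding states
)
print(hyp.asdict())
# {'yseq': [1, 42, 7, 2], 'score': -3.2..., 'scores': {'decoder': -3.0, 'ctc': -0.2}, 'states': {}}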
class BeamSearchScama(torch.nn.Module):
|
||||
"""Beam search implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scorers: Dict[str, ScorerInterface],
|
||||
weights: Dict[str, float],
|
||||
beam_size: int,
|
||||
vocab_size: int,
|
||||
sos: int,
|
||||
eos: int,
|
||||
token_list: List[str] = None,
|
||||
pre_beam_ratio: float = 1.5,
|
||||
pre_beam_score_key: str = None,
|
||||
):
|
||||
"""Initialize beam search.
|
||||
|
||||
Args:
|
||||
scorers (dict[str, ScorerInterface]): Dict of decoder modules
|
||||
e.g., Decoder, CTCPrefixScorer, LM
|
||||
The scorer will be ignored if it is `None`
|
||||
weights (dict[str, float]): Dict of weights for each scorers
|
||||
The scorer will be ignored if its weight is 0
|
||||
beam_size (int): The number of hypotheses kept during search
|
||||
vocab_size (int): The number of vocabulary
|
||||
sos (int): Start of sequence id
|
||||
eos (int): End of sequence id
|
||||
token_list (list[str]): List of tokens for debug log
|
||||
pre_beam_score_key (str): key of scores to perform pre-beam search
|
||||
pre_beam_ratio (float): beam size in the pre-beam search
|
||||
will be `int(pre_beam_ratio * beam_size)`
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
# set scorers
|
||||
self.weights = weights
|
||||
self.scorers = dict()
|
||||
self.full_scorers = dict()
|
||||
self.part_scorers = dict()
|
||||
# this module dict is required for recursive cast
|
||||
# `self.to(device, dtype)` in `recog.py`
|
||||
self.nn_dict = torch.nn.ModuleDict()
|
||||
for k, v in scorers.items():
|
||||
w = weights.get(k, 0)
|
||||
if w == 0 or v is None:
|
||||
continue
|
||||
assert isinstance(
|
||||
v, ScorerInterface
|
||||
), f"{k} ({type(v)}) does not implement ScorerInterface"
|
||||
self.scorers[k] = v
|
||||
if isinstance(v, PartialScorerInterface):
|
||||
self.part_scorers[k] = v
|
||||
else:
|
||||
self.full_scorers[k] = v
|
||||
if isinstance(v, torch.nn.Module):
|
||||
self.nn_dict[k] = v
|
||||
|
||||
# set configurations
|
||||
self.sos = sos
|
||||
self.eos = eos
|
||||
self.token_list = token_list
|
||||
self.pre_beam_size = int(pre_beam_ratio * beam_size)
|
||||
self.beam_size = beam_size
|
||||
self.n_vocab = vocab_size
|
||||
if (
|
||||
pre_beam_score_key is not None
|
||||
and pre_beam_score_key != "full"
|
||||
and pre_beam_score_key not in self.full_scorers
|
||||
):
|
||||
raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
|
||||
self.pre_beam_score_key = pre_beam_score_key
|
||||
self.do_pre_beam = (
|
||||
self.pre_beam_score_key is not None
|
||||
and self.pre_beam_size < self.n_vocab
|
||||
and len(self.part_scorers) > 0
|
||||
)
|
||||
|
||||
def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
|
||||
"""Get an initial hypothesis data.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The encoder output feature
|
||||
|
||||
Returns:
|
||||
Hypothesis: The initial hypothesis.
|
||||
|
||||
"""
|
||||
init_states = dict()
|
||||
init_scores = dict()
|
||||
for k, d in self.scorers.items():
|
||||
init_states[k] = d.init_state(x)
|
||||
init_scores[k] = 0.0
|
||||
return [
|
||||
Hypothesis(
|
||||
score=0.0,
|
||||
scores=init_scores,
|
||||
states=init_states,
|
||||
yseq=torch.tensor([self.sos], device=x.device),
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
|
||||
"""Append new token to prefix tokens.
|
||||
|
||||
Args:
|
||||
xs (torch.Tensor): The prefix token
|
||||
x (int): The new token to append
|
||||
|
||||
Returns:
|
||||
torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
|
||||
|
||||
"""
|
||||
x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
|
||||
return torch.cat((xs, x))
|
||||
|
||||
def score_full(
|
||||
self, hyp: Hypothesis,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor = None,
|
||||
pre_acoustic_embeds: torch.Tensor = None,
|
||||
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
|
||||
"""Score new hypothesis by `self.full_scorers`.
|
||||
|
||||
Args:
|
||||
hyp (Hypothesis): Hypothesis with prefix tokens to score
|
||||
x (torch.Tensor): Corresponding input feature
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
|
||||
score dict of `hyp` that has string keys of `self.full_scorers`
|
||||
and tensor score values of shape: `(self.n_vocab,)`,
|
||||
and state dict that has string keys
|
||||
and state values of `self.full_scorers`
|
||||
|
||||
"""
|
||||
scores = dict()
|
||||
states = dict()
|
||||
for k, d in self.full_scorers.items():
|
||||
scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
|
||||
return scores, states
|
||||
|
||||
def score_partial(
|
||||
self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
|
||||
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
|
||||
"""Score new hypothesis by `self.part_scorers`.
|
||||
|
||||
Args:
|
||||
hyp (Hypothesis): Hypothesis with prefix tokens to score
|
||||
ids (torch.Tensor): 1D tensor of new partial tokens to score
|
||||
x (torch.Tensor): Corresponding input feature
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
|
||||
score dict of `hyp` that has string keys of `self.part_scorers`
|
||||
and tensor score values of shape: `(len(ids),)`,
|
||||
and state dict that has string keys
|
||||
and state values of `self.part_scorers`
|
||||
|
||||
"""
|
||||
scores = dict()
|
||||
states = dict()
|
||||
for k, d in self.part_scorers.items():
|
||||
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
|
||||
return scores, states
|
||||
|
||||
def beam(
|
||||
self, weighted_scores: torch.Tensor, ids: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute topk full token ids and partial token ids.
|
||||
|
||||
Args:
|
||||
weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
|
||||
Its shape is `(self.n_vocab,)`.
|
||||
ids (torch.Tensor): The partial token ids to compute topk
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
The topk full token ids and partial token ids.
|
||||
Their shapes are `(self.beam_size,)`
|
||||
|
||||
"""
|
||||
# no pre beam performed
|
||||
if weighted_scores.size(0) == ids.size(0):
|
||||
top_ids = weighted_scores.topk(self.beam_size)[1]
|
||||
return top_ids, top_ids
|
||||
|
||||
# mask pruned in pre-beam not to select in topk
|
||||
tmp = weighted_scores[ids]
|
||||
weighted_scores[:] = -float("inf")
|
||||
weighted_scores[ids] = tmp
|
||||
top_ids = weighted_scores.topk(self.beam_size)[1]
|
||||
local_ids = weighted_scores[ids].topk(self.beam_size)[1]
|
||||
return top_ids, local_ids
|
||||
|
||||
@staticmethod
|
||||
def merge_scores(
|
||||
prev_scores: Dict[str, float],
|
||||
next_full_scores: Dict[str, torch.Tensor],
|
||||
full_idx: int,
|
||||
next_part_scores: Dict[str, torch.Tensor],
|
||||
part_idx: int,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Merge scores for new hypothesis.
|
||||
|
||||
Args:
|
||||
prev_scores (Dict[str, float]):
|
||||
The previous hypothesis scores by `self.scorers`
|
||||
next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
|
||||
full_idx (int): The next token id for `next_full_scores`
|
||||
next_part_scores (Dict[str, torch.Tensor]):
|
||||
scores of partial tokens by `self.part_scorers`
|
||||
part_idx (int): The new token id for `next_part_scores`
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: The new score dict.
|
||||
Its keys are names of `self.full_scorers` and `self.part_scorers`.
|
||||
Its values are scalar tensors by the scorers.
|
||||
|
||||
"""
|
||||
new_scores = dict()
|
||||
for k, v in next_full_scores.items():
|
||||
new_scores[k] = prev_scores[k] + v[full_idx]
|
||||
for k, v in next_part_scores.items():
|
||||
new_scores[k] = prev_scores[k] + v[part_idx]
|
||||
return new_scores
|
||||
|
||||
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
|
||||
"""Merge states for new hypothesis.
|
||||
|
||||
Args:
|
||||
states: states of `self.full_scorers`
|
||||
part_states: states of `self.part_scorers`
|
||||
part_idx (int): The new token id for `part_scores`
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: The new score dict.
|
||||
Its keys are names of `self.full_scorers` and `self.part_scorers`.
|
||||
Its values are states of the scorers.
|
||||
|
||||
"""
|
||||
new_states = dict()
|
||||
for k, v in states.items():
|
||||
new_states[k] = v
|
||||
for k, d in self.part_scorers.items():
|
||||
new_states[k] = d.select_state(part_states[k], part_idx)
|
||||
return new_states
|
||||
|
||||
def search(
|
||||
self, running_hyps: List[Hypothesis],
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor = None,
|
||||
pre_acoustic_embeds: torch.Tensor = None,
|
||||
) -> List[Hypothesis]:
|
||||
"""Search new tokens for running hypotheses and encoded speech x.
|
||||
|
||||
Args:
|
||||
running_hyps (List[Hypothesis]): Running hypotheses on beam
|
||||
x (torch.Tensor): Encoded speech feature (T, D)
|
||||
|
||||
Returns:
|
||||
List[Hypotheses]: Best sorted hypotheses
|
||||
|
||||
"""
|
||||
best_hyps = []
|
||||
part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
|
||||
for hyp in running_hyps:
|
||||
# scoring
|
||||
weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
|
||||
scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
|
||||
for k in self.full_scorers:
|
||||
weighted_scores += self.weights[k] * scores[k]
|
||||
# partial scoring
|
||||
if self.do_pre_beam:
|
||||
pre_beam_scores = (
|
||||
weighted_scores
|
||||
if self.pre_beam_score_key == "full"
|
||||
else scores[self.pre_beam_score_key]
|
||||
)
|
||||
part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
|
||||
part_scores, part_states = self.score_partial(hyp, part_ids, x)
|
||||
for k in self.part_scorers:
|
||||
weighted_scores[part_ids] += self.weights[k] * part_scores[k]
|
||||
# add previous hyp score
|
||||
weighted_scores += hyp.score
|
||||
|
||||
# update hyps
|
||||
for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
|
||||
# will be (2 x beam at most)
|
||||
best_hyps.append(
|
||||
Hypothesis(
|
||||
score=weighted_scores[j],
|
||||
yseq=self.append_token(hyp.yseq, j),
|
||||
scores=self.merge_scores(
|
||||
hyp.scores, scores, j, part_scores, part_j
|
||||
),
|
||||
states=self.merge_states(states, part_states, part_j),
|
||||
)
|
||||
)
|
||||
|
||||
# sort and prune 2 x beam -> beam
|
||||
best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
|
||||
: min(len(best_hyps), self.beam_size)
|
||||
]
|
||||
return best_hyps
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor,
|
||||
scama_mask: torch.Tensor = None,
|
||||
pre_acoustic_embeds: torch.Tensor = None,
|
||||
maxlenratio: float = 0.0,
|
||||
minlenratio: float = 0.0,
|
||||
maxlen: int = None,
|
||||
minlen: int = 0,
|
||||
) -> List[Hypothesis]:
|
||||
"""Perform beam search.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Encoded speech feature (T, D)
|
||||
maxlenratio (float): Input length ratio to obtain max output length.
|
||||
If maxlenratio=0.0 (default), it uses a end-detect function
|
||||
to automatically find maximum hypothesis lengths
|
||||
If maxlenratio<0.0, its absolute value is interpreted
|
||||
as a constant max output length.
|
||||
minlenratio (float): Input length ratio to obtain min output length.
|
||||
|
||||
Returns:
|
||||
list[Hypothesis]: N-best decoding results
|
||||
|
||||
"""
|
||||
if maxlen is None:
|
||||
# set length bounds
|
||||
if maxlenratio == 0:
|
||||
maxlen = x.shape[0]
|
||||
elif maxlenratio < 0:
|
||||
maxlen = -1 * int(maxlenratio)
|
||||
else:
|
||||
maxlen = max(1, int(maxlenratio * x.size(0)))
|
||||
minlen = int(minlenratio * x.size(0))
|
||||
|
||||
logging.info("decoder input length: " + str(x.shape[0]))
|
||||
logging.info("max output length: " + str(maxlen))
|
||||
logging.info("min output length: " + str(minlen))
|
||||
|
||||
# main loop of prefix search
|
||||
running_hyps = self.init_hyp(x)
|
||||
ended_hyps = []
|
||||
for i in range(maxlen):
|
||||
logging.debug("position " + str(i))
|
||||
mask_enc = None
|
||||
if scama_mask is not None:
|
||||
token_num_predictor = scama_mask.size(1)
|
||||
token_id_slice = min(i, token_num_predictor-1)
|
||||
mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
|
||||
# if mask_enc.size(1) == 0:
|
||||
# mask_enc = scama_mask[:, -2:-1, :]
|
||||
# # mask_enc = torch.zeros_like(mask_enc)
|
||||
pre_acoustic_embeds_cur = None
|
||||
if pre_acoustic_embeds is not None:
|
||||
b, t, d = pre_acoustic_embeds.size()
|
||||
pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
|
||||
pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
|
||||
token_id_slice = min(i, t)
|
||||
pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
|
||||
|
||||
best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur)
|
||||
# post process of one iteration
|
||||
running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
|
||||
# end detection
|
||||
if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
|
||||
logging.info(f"end detected at {i}")
|
||||
break
|
||||
if len(running_hyps) == 0:
|
||||
logging.info("no hypothesis. Finish decoding.")
|
||||
break
|
||||
else:
|
||||
logging.debug(f"remained hypotheses: {len(running_hyps)}")
|
||||
|
||||
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
|
||||
# check the number of hypotheses reaching to eos
|
||||
if len(nbest_hyps) == 0:
|
||||
logging.warning(
|
||||
"there is no N-best results, perform recognition "
|
||||
"again with smaller minlenratio."
|
||||
)
|
||||
return (
|
||||
[]
|
||||
if minlenratio < 0.1
|
||||
else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
|
||||
)
|
||||
|
||||
# report the best result
|
||||
for x in nbest_hyps:
|
||||
yseq = "".join([self.token_list[x] for x in x.yseq])
|
||||
logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
|
||||
best = nbest_hyps[0]
|
||||
for k, v in best.scores.items():
|
||||
logging.info(
|
||||
f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
|
||||
)
|
||||
logging.info(f"total log probability: {best.score:.2f}")
|
||||
logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
|
||||
logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
|
||||
if self.token_list is not None:
|
||||
logging.info(
|
||||
"best hypo: "
|
||||
+ "".join([self.token_list[x] for x in best.yseq[1:-1]])
|
||||
+ "\n"
|
||||
)
|
||||
return nbest_hyps
|
||||
|
||||
def post_process(
|
||||
self,
|
||||
i: int,
|
||||
maxlen: int,
|
||||
maxlenratio: float,
|
||||
running_hyps: List[Hypothesis],
|
||||
ended_hyps: List[Hypothesis],
|
||||
) -> List[Hypothesis]:
|
||||
"""Perform post-processing of beam search iterations.
|
||||
|
||||
Args:
|
||||
i (int): The length of hypothesis tokens.
|
||||
maxlen (int): The maximum length of tokens in beam search.
|
||||
maxlenratio (int): The maximum length ratio in beam search.
|
||||
running_hyps (List[Hypothesis]): The running hypotheses in beam search.
|
||||
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
|
||||
|
||||
Returns:
|
||||
List[Hypothesis]: The new running hypotheses.
|
||||
|
||||
"""
|
||||
logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
|
||||
if self.token_list is not None:
|
||||
logging.debug(
|
||||
"best hypo: "
|
||||
+ "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
|
||||
)
|
||||
# add eos in the final loop to avoid that there are no ended hyps
|
||||
if i == maxlen - 1:
|
||||
logging.info("adding <eos> in the last position in the loop")
|
||||
running_hyps = [
|
||||
h._replace(yseq=self.append_token(h.yseq, self.eos))
|
||||
for h in running_hyps
|
||||
]
|
||||
|
||||
# add ended hypotheses to a final list, and removed them from current hypotheses
|
||||
# (this will be a problem, number of hyps < beam)
|
||||
remained_hyps = []
|
||||
for hyp in running_hyps:
|
||||
if hyp.yseq[-1] == self.eos:
|
||||
# e.g., Word LM needs to add final <eos> score
|
||||
for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
|
||||
s = d.final_score(hyp.states[k])
|
||||
hyp.scores[k] += s
|
||||
hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
|
||||
ended_hyps.append(hyp)
|
||||
else:
|
||||
remained_hyps.append(hyp)
|
||||
return remained_hyps
|
||||
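For orientation, here is how this class ends up being wired. The sketch below is the editor's illustration modelled on the init_beam_search() and inference() code added further down in this commit; model, token_list, encoder_out, scama_mask, pre_acoustic_embeds, maxlen and minlen are placeholders rather than real variables.

from funasr.models.uniasr.beam_search import BeamSearchScama
from funasr.models.transformer.scorers.length_bonus import LengthBonus

# Scorers whose weight is 0 (or whose value is None) are skipped by the constructor,
# so only the decoder drives the search here; a CTCPrefixScorer could be added as "ctc".
scorers = dict(
    decoder=model.decoder,
    length_bonus=LengthBonus(len(token_list)),
)
weights = dict(decoder=1.0, ctc=0.0, lm=0.0, ngram=0.0, length_bonus=0.0)

beam_search = BeamSearchScama(
    beam_size=5,
    weights=weights,
    scorers=scorers,
    sos=model.sos,
    eos=model.eos,
    vocab_size=len(token_list),
    token_list=token_list,
    pre_beam_score_key="full",
)

# x is the encoder output (T, D) for one utterance; scama_mask and pre_acoustic_embeds
# come from the predictor, and maxlen/minlen are derived from the predicted token count.
nbest_hyps = beam_search(
    x=encoder_out,
    scama_mask=scama_mask,
    pre_acoustic_embeds=pre_acoustic_embeds,
    maxlen=int(maxlen),
    minlen=int(minlen),
)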
@@ -14,14 +14,13 @@ from funasr.models.ctc.ctc import CTC
|
||||
from funasr.utils import postprocess_utils
|
||||
from funasr.metrics.compute_acc import th_accuracy
|
||||
from funasr.utils.datadir_writer import DatadirWriter
|
||||
from funasr.models.paraformer.search import Hypothesis
|
||||
from funasr.models.paraformer.cif_predictor import mae_loss
|
||||
from funasr.train_utils.device_funcs import force_gatherable
|
||||
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
||||
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
|
||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||
|
||||
from funasr.models.scama.utils import sequence_mask
|
||||
|
||||
@tables.register("model_classes", "UniASR")
|
||||
class UniASR(torch.nn.Module):
|
||||
@@ -31,19 +30,37 @@ class UniASR(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
specaug: Optional[str] = None,
|
||||
specaug_conf: Optional[Dict] = None,
|
||||
specaug: str = None,
|
||||
specaug_conf: dict = None,
|
||||
normalize: str = None,
|
||||
normalize_conf: Optional[Dict] = None,
|
||||
normalize_conf: dict = None,
|
||||
encoder: str = None,
|
||||
encoder_conf: Optional[Dict] = None,
|
||||
encoder_conf: dict = None,
|
||||
encoder2: str = None,
|
||||
encoder2_conf: dict = None,
|
||||
decoder: str = None,
|
||||
decoder_conf: Optional[Dict] = None,
|
||||
ctc: str = None,
|
||||
ctc_conf: Optional[Dict] = None,
|
||||
decoder_conf: dict = None,
|
||||
decoder2: str = None,
|
||||
decoder2_conf: dict = None,
|
||||
predictor: str = None,
|
||||
predictor_conf: Optional[Dict] = None,
|
||||
predictor_conf: dict = None,
|
||||
predictor_bias: int = 0,
|
||||
predictor_weight: float = 0.0,
|
||||
predictor2: str = None,
|
||||
predictor2_conf: dict = None,
|
||||
predictor2_bias: int = 0,
|
||||
predictor2_weight: float = 0.0,
|
||||
ctc: str = None,
|
||||
ctc_conf: dict = None,
|
||||
ctc_weight: float = 0.5,
|
||||
ctc2: str = None,
|
||||
ctc2_conf: dict = None,
|
||||
ctc2_weight: float = 0.5,
|
||||
decoder_attention_chunk_type: str = 'chunk',
|
||||
decoder_attention_chunk_type2: str = 'chunk',
|
||||
stride_conv=None,
|
||||
stride_conv_conf: dict = None,
|
||||
loss_weight_model1: float = 0.5,
|
||||
input_size: int = 80,
|
||||
vocab_size: int = -1,
|
||||
ignore_id: int = -1,
|
||||
@@ -52,60 +69,72 @@ class UniASR(torch.nn.Module):
|
||||
eos: int = 2,
|
||||
lsm_weight: float = 0.0,
|
||||
length_normalized_loss: bool = False,
|
||||
# report_cer: bool = True,
|
||||
# report_wer: bool = True,
|
||||
# sym_space: str = "<space>",
|
||||
# sym_blank: str = "<blank>",
|
||||
# extract_feats_in_collect_stats: bool = True,
|
||||
# predictor=None,
|
||||
predictor_weight: float = 0.0,
|
||||
predictor_bias: int = 0,
|
||||
sampling_ratio: float = 0.2,
|
||||
share_embedding: bool = False,
|
||||
# preencoder: Optional[AbsPreEncoder] = None,
|
||||
# postencoder: Optional[AbsPostEncoder] = None,
|
||||
use_1st_decoder_loss: bool = False,
|
||||
encoder1_encoder2_joint_training: bool = True,
|
||||
**kwargs,
|
||||
|
||||
):
|
||||
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
|
||||
assert 0.0 <= interctc_weight < 1.0, interctc_weight
|
||||
|
||||
super().__init__()
|
||||
self.blank_id = 0
|
||||
self.sos = 1
|
||||
self.eos = 2
|
||||
|
||||
if specaug is not None:
|
||||
specaug_class = tables.specaug_classes.get(specaug)
|
||||
specaug = specaug_class(**specaug_conf)
|
||||
if normalize is not None:
|
||||
normalize_class = tables.normalize_classes.get(normalize)
|
||||
normalize = normalize_class(**normalize_conf)
|
||||
|
||||
encoder_class = tables.encoder_classes.get(encoder)
|
||||
encoder = encoder_class(input_size=input_size, **encoder_conf)
|
||||
encoder_output_size = encoder.output_size()
|
||||
|
||||
decoder_class = tables.decoder_classes.get(decoder)
|
||||
decoder = decoder_class(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
**decoder_conf,
|
||||
)
|
||||
predictor_class = tables.predictor_classes.get(predictor)
|
||||
predictor = predictor_class(**predictor_conf)
|
||||
|
||||
|
||||
|
||||
from funasr.models.transformer.utils.subsampling import Conv1dSubsampling
|
||||
stride_conv = Conv1dSubsampling(**stride_conv_conf, idim=input_size + encoder_output_size,
|
||||
odim=input_size + encoder_output_size)
|
||||
stride_conv_output_size = stride_conv.output_size()
|
||||
|
||||
encoder_class = tables.encoder_classes.get(encoder2)
|
||||
encoder2 = encoder_class(input_size=stride_conv_output_size, **encoder2_conf)
|
||||
encoder2_output_size = encoder2.output_size()
|
||||
|
||||
decoder_class = tables.decoder_classes.get(decoder2)
|
||||
decoder2 = decoder_class(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder2_output_size,
|
||||
**decoder2_conf,
|
||||
)
|
||||
predictor_class = tables.predictor_classes.get(predictor2)
|
||||
predictor2 = predictor_class(**predictor2_conf)
|
||||
|
||||
|
||||
|
||||
self.blank_id = blank_id
|
||||
self.sos = sos
|
||||
self.eos = eos
|
||||
self.vocab_size = vocab_size
|
||||
self.ignore_id = ignore_id
|
||||
self.ctc_weight = ctc_weight
|
||||
self.interctc_weight = interctc_weight
|
||||
self.token_list = token_list.copy()
|
||||
self.ctc2_weight = ctc2_weight
|
||||
|
||||
self.frontend = frontend
|
||||
self.specaug = specaug
|
||||
self.normalize = normalize
|
||||
self.preencoder = preencoder
|
||||
self.postencoder = postencoder
|
||||
|
||||
self.encoder = encoder
|
||||
|
||||
if not hasattr(self.encoder, "interctc_use_conditioning"):
|
||||
self.encoder.interctc_use_conditioning = False
|
||||
if self.encoder.interctc_use_conditioning:
|
||||
self.encoder.conditioning_layer = torch.nn.Linear(
|
||||
vocab_size, self.encoder.output_size()
|
||||
)
|
||||
|
||||
self.error_calculator = None
|
||||
|
||||
# we set self.decoder = None in the CTC mode since
|
||||
# self.decoder parameters were never used and PyTorch complained
|
||||
# and threw an Exception in the multi-GPU experiment.
|
||||
# thanks Jeff Farris for pointing out the issue.
|
||||
if ctc_weight == 1.0:
|
||||
self.decoder = None
|
||||
else:
|
||||
self.decoder = decoder
|
||||
self.decoder = decoder
|
||||
self.ctc = None
|
||||
self.ctc2 = None
|
||||
|
||||
self.criterion_att = LabelSmoothingLoss(
|
||||
size=vocab_size,
|
||||
@@ -113,22 +142,13 @@ class UniASR(torch.nn.Module):
|
||||
smoothing=lsm_weight,
|
||||
normalize_length=length_normalized_loss,
|
||||
)
|
||||
|
||||
if report_cer or report_wer:
|
||||
self.error_calculator = ErrorCalculator(
|
||||
token_list, sym_space, sym_blank, report_cer, report_wer
|
||||
)
|
||||
|
||||
if ctc_weight == 0.0:
|
||||
self.ctc = None
|
||||
else:
|
||||
self.ctc = ctc
|
||||
|
||||
self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
|
||||
|
||||
self.predictor = predictor
|
||||
self.predictor_weight = predictor_weight
|
||||
self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
|
||||
self.step_cur = 0
|
||||
self.encoder1_encoder2_joint_training = kwargs.get("encoder1_encoder2_joint_training", True)
|
||||
|
||||
|
||||
if self.encoder.overlap_chunk_cls is not None:
|
||||
from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
|
||||
self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
|
||||
@@ -136,14 +156,10 @@ class UniASR(torch.nn.Module):
|
||||
|
||||
self.encoder2 = encoder2
|
||||
self.decoder2 = decoder2
|
||||
self.ctc_weight2 = ctc_weight2
|
||||
if ctc_weight2 == 0.0:
|
||||
self.ctc2 = None
|
||||
else:
|
||||
self.ctc2 = ctc2
|
||||
self.interctc_weight2 = interctc_weight2
|
||||
self.ctc2_weight = ctc2_weight
|
||||
|
||||
self.predictor2 = predictor2
|
||||
self.predictor_weight2 = predictor_weight2
|
||||
self.predictor2_weight = predictor2_weight
|
||||
self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
|
||||
self.stride_conv = stride_conv
|
||||
self.loss_weight_model1 = loss_weight_model1
|
||||
@@ -152,10 +168,10 @@ class UniASR(torch.nn.Module):
|
||||
self.build_scama_mask_for_cross_attention_decoder_fn2 = build_scama_mask_for_cross_attention_decoder
|
||||
self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
|
||||
|
||||
self.enable_maas_finetune = enable_maas_finetune
|
||||
self.freeze_encoder2 = freeze_encoder2
|
||||
self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
|
||||
self.length_normalized_loss = length_normalized_loss
|
||||
self.enable_maas_finetune = kwargs.get("enable_maas_finetune", False)
|
||||
self.freeze_encoder2 = kwargs.get("freeze_encoder2", False)
|
||||
self.beam_search = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -163,7 +179,7 @@ class UniASR(torch.nn.Module):
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
decoding_ind: int = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
Args:
|
||||
@@ -172,19 +188,14 @@ class UniASR(torch.nn.Module):
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
assert text_lengths.dim() == 1, text_lengths.shape
|
||||
# Check that batch_size is unified
|
||||
assert (
|
||||
speech.shape[0]
|
||||
== speech_lengths.shape[0]
|
||||
== text.shape[0]
|
||||
== text_lengths.shape[0]
|
||||
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
||||
decoding_ind = kwargs.get("decoding_ind", None)
|
||||
if len(text_lengths.size()) > 1:
|
||||
text_lengths = text_lengths[:, 0]
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size = speech.shape[0]
|
||||
|
||||
# for data-parallel
|
||||
text = text[:, : text_lengths.max()]
|
||||
speech = speech[:, :speech_lengths.max()]
|
||||
|
||||
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
|
||||
# 1. Encoder
|
||||
@@ -194,10 +205,6 @@ class UniASR(torch.nn.Module):
|
||||
else:
|
||||
speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
|
||||
|
||||
intermediate_outs = None
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
loss_att, acc_att, cer_att, wer_att = None, None, None, None
|
||||
loss_ctc, cer_ctc = None, None
|
||||
@@ -210,62 +217,12 @@ class UniASR(torch.nn.Module):
|
||||
# 1. CTC branch
|
||||
if self.enable_maas_finetune:
|
||||
with torch.no_grad():
|
||||
if self.ctc_weight != 0.0:
|
||||
if self.encoder.overlap_chunk_cls is not None:
|
||||
encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
|
||||
encoder_out_lens,
|
||||
chunk_outs=None)
|
||||
loss_ctc, cer_ctc = self._calc_ctc_loss(
|
||||
encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
|
||||
)
|
||||
|
||||
# Collect CTC branch stats
|
||||
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
|
||||
stats["cer_ctc"] = cer_ctc
|
||||
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# Intermediate CTC (optional)
|
||||
loss_interctc = 0.0
|
||||
if self.interctc_weight != 0.0 and intermediate_outs is not None:
|
||||
for layer_idx, intermediate_out in intermediate_outs:
|
||||
# we assume intermediate_out has the same length & padding
|
||||
# as those of encoder_out
|
||||
if self.encoder.overlap_chunk_cls is not None:
|
||||
encoder_out_ctc, encoder_out_lens_ctc = \
|
||||
self.encoder.overlap_chunk_cls.remove_chunk(
|
||||
intermediate_out,
|
||||
encoder_out_lens,
|
||||
chunk_outs=None)
|
||||
loss_ic, cer_ic = self._calc_ctc_loss(
|
||||
encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
|
||||
)
|
||||
loss_interctc = loss_interctc + loss_ic
|
||||
|
||||
# Collect Intermediate CTC stats
|
||||
stats["loss_interctc_layer{}".format(layer_idx)] = (
|
||||
loss_ic.detach() if loss_ic is not None else None
|
||||
)
|
||||
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
|
||||
|
||||
loss_interctc = loss_interctc / len(intermediate_outs)
|
||||
|
||||
# calculate whole encoder loss
|
||||
loss_ctc = (
|
||||
1 - self.interctc_weight
|
||||
) * loss_ctc + self.interctc_weight * loss_interctc
|
||||
|
||||
# 2b. Attention decoder branch
|
||||
if self.ctc_weight != 1.0:
|
||||
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# 3. CTC-Att loss definition
|
||||
if self.ctc_weight == 0.0:
|
||||
loss = loss_att + loss_pre * self.predictor_weight
|
||||
elif self.ctc_weight == 1.0:
|
||||
loss = loss_ctc
|
||||
else:
|
||||
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
|
||||
loss = loss_att + loss_pre * self.predictor_weight
|
||||
|
||||
# Collect Attn branch stats
|
||||
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
|
||||
@@ -274,62 +231,13 @@ class UniASR(torch.nn.Module):
|
||||
stats["wer"] = wer_att
|
||||
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
|
||||
else:
|
||||
if self.ctc_weight != 0.0:
|
||||
if self.encoder.overlap_chunk_cls is not None:
|
||||
encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
|
||||
encoder_out_lens,
|
||||
chunk_outs=None)
|
||||
loss_ctc, cer_ctc = self._calc_ctc_loss(
|
||||
encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
|
||||
)
|
||||
|
||||
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# Collect CTC branch stats
|
||||
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
|
||||
stats["cer_ctc"] = cer_ctc
|
||||
|
||||
# Intermediate CTC (optional)
|
||||
loss_interctc = 0.0
|
||||
if self.interctc_weight != 0.0 and intermediate_outs is not None:
|
||||
for layer_idx, intermediate_out in intermediate_outs:
|
||||
# we assume intermediate_out has the same length & padding
|
||||
# as those of encoder_out
|
||||
if self.encoder.overlap_chunk_cls is not None:
|
||||
encoder_out_ctc, encoder_out_lens_ctc = \
|
||||
self.encoder.overlap_chunk_cls.remove_chunk(
|
||||
intermediate_out,
|
||||
encoder_out_lens,
|
||||
chunk_outs=None)
|
||||
loss_ic, cer_ic = self._calc_ctc_loss(
|
||||
encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
|
||||
)
|
||||
loss_interctc = loss_interctc + loss_ic
|
||||
|
||||
# Collect Intermediate CTC stats
|
||||
stats["loss_interctc_layer{}".format(layer_idx)] = (
|
||||
loss_ic.detach() if loss_ic is not None else None
|
||||
)
|
||||
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
|
||||
|
||||
loss_interctc = loss_interctc / len(intermediate_outs)
|
||||
|
||||
# calculate whole encoder loss
|
||||
loss_ctc = (
|
||||
1 - self.interctc_weight
|
||||
) * loss_ctc + self.interctc_weight * loss_interctc
|
||||
|
||||
# 2b. Attention decoder branch
|
||||
if self.ctc_weight != 1.0:
|
||||
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# 3. CTC-Att loss definition
|
||||
if self.ctc_weight == 0.0:
|
||||
loss = loss_att + loss_pre * self.predictor_weight
|
||||
elif self.ctc_weight == 1.0:
|
||||
loss = loss_ctc
|
||||
else:
|
||||
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
|
||||
loss = loss_att + loss_pre * self.predictor_weight
|
||||
|
||||
# Collect Attn branch stats
|
||||
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
|
||||
@@ -354,67 +262,14 @@ class UniASR(torch.nn.Module):
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
# CTC2
|
||||
if self.ctc_weight2 != 0.0:
|
||||
if self.encoder2.overlap_chunk_cls is not None:
|
||||
encoder_out_ctc, encoder_out_lens_ctc = \
|
||||
self.encoder2.overlap_chunk_cls.remove_chunk(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
chunk_outs=None,
|
||||
)
|
||||
loss_ctc, cer_ctc = self._calc_ctc_loss2(
|
||||
encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
|
||||
)
|
||||
|
||||
# Collect CTC branch stats
|
||||
stats["loss_ctc2"] = loss_ctc.detach() if loss_ctc is not None else None
|
||||
stats["cer_ctc2"] = cer_ctc
|
||||
|
||||
# Intermediate CTC (optional)
|
||||
loss_interctc = 0.0
|
||||
if self.interctc_weight2 != 0.0 and intermediate_outs is not None:
|
||||
for layer_idx, intermediate_out in intermediate_outs:
|
||||
# we assume intermediate_out has the same length & padding
|
||||
# as those of encoder_out
|
||||
if self.encoder2.overlap_chunk_cls is not None:
|
||||
encoder_out_ctc, encoder_out_lens_ctc = \
|
||||
self.encoder2.overlap_chunk_cls.remove_chunk(
|
||||
intermediate_out,
|
||||
encoder_out_lens,
|
||||
chunk_outs=None)
|
||||
loss_ic, cer_ic = self._calc_ctc_loss2(
|
||||
encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
|
||||
)
|
||||
loss_interctc = loss_interctc + loss_ic
|
||||
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# Collect Intermediate CTC stats
|
||||
stats["loss_interctc_layer{}2".format(layer_idx)] = (
|
||||
loss_ic.detach() if loss_ic is not None else None
|
||||
)
|
||||
stats["cer_interctc_layer{}2".format(layer_idx)] = cer_ic
|
||||
|
||||
loss_interctc = loss_interctc / len(intermediate_outs)
|
||||
|
||||
# calculate whole encoder loss
|
||||
loss_ctc = (
|
||||
1 - self.interctc_weight2
|
||||
) * loss_ctc + self.interctc_weight2 * loss_interctc
|
||||
|
||||
# 2b. Attention decoder branch
|
||||
if self.ctc_weight2 != 1.0:
|
||||
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# 3. CTC-Att loss definition
|
||||
if self.ctc_weight2 == 0.0:
|
||||
loss = loss_att + loss_pre * self.predictor_weight2
|
||||
elif self.ctc_weight2 == 1.0:
|
||||
loss = loss_ctc
|
||||
else:
|
||||
loss = self.ctc_weight2 * loss_ctc + (
|
||||
1 - self.ctc_weight2) * loss_att + loss_pre * self.predictor_weight2
|
||||
loss = loss_att + loss_pre * self.predictor2_weight
|
||||
|
||||
# Collect Attn branch stats
|
||||
stats["loss_att2"] = loss_att.detach() if loss_att is not None else None
|
||||
@@ -422,6 +277,7 @@ class UniASR(torch.nn.Module):
|
||||
stats["cer2"] = cer_att
|
||||
stats["wer2"] = wer_att
|
||||
stats["loss_pre2"] = loss_pre.detach().cpu() if loss_pre is not None else None
|
||||
|
||||
loss2 = loss
|
||||
|
||||
loss = loss1 * self.loss_weight_model1 + loss2 * (1 - self.loss_weight_model1)
|
||||
@@ -456,61 +312,31 @@ class UniASR(torch.nn.Module):
|
||||
return {"feats": feats, "feats_lengths": feats_lengths}
|
||||
|
||||
def encode(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
|
||||
):
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
"""
|
||||
ind = kwargs.get("ind", 0)
|
||||
with autocast(False):
|
||||
# 1. Extract feats
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
|
||||
# 2. Data augmentation
|
||||
# Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
|
||||
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
speech, speech_lengths = self.specaug(speech, speech_lengths)
|
||||
|
||||
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
if self.normalize is not None:
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
speech_raw = feats.clone().to(feats.device)
|
||||
# Pre-encoder, e.g. used for raw input data
|
||||
if self.preencoder is not None:
|
||||
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||
speech, speech_lengths = self.normalize(speech, speech_lengths)
|
||||
|
||||
speech_raw = speech.clone().to(speech.device)
|
||||
|
||||
|
||||
# 4. Forward encoder
|
||||
# feats: (Batch, Length, Dim)
|
||||
# -> encoder_out: (Batch, Length2, Dim2)
|
||||
if self.encoder.interctc_use_conditioning:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(
|
||||
feats, feats_lengths, ctc=self.ctc, ind=ind
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
|
||||
intermediate_outs = None
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ind=ind)
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
# Post-encoder, e.g. NLU
|
||||
if self.postencoder is not None:
|
||||
encoder_out, encoder_out_lens = self.postencoder(
|
||||
encoder_out, encoder_out_lens
|
||||
)
|
||||
|
||||
assert encoder_out.size(0) == speech.size(0), (
|
||||
encoder_out.size(),
|
||||
speech.size(0),
|
||||
)
|
||||
assert encoder_out.size(1) <= encoder_out_lens.max(), (
|
||||
encoder_out.size(),
|
||||
encoder_out_lens.max(),
|
||||
)
|
||||
|
||||
if intermediate_outs is not None:
|
||||
return (encoder_out, intermediate_outs), encoder_out_lens
|
||||
|
||||
return speech_raw, encoder_out, encoder_out_lens
|
||||
|
||||
def encode2(
|
||||
@@ -519,28 +345,15 @@ class UniASR(torch.nn.Module):
|
||||
encoder_out_lens: torch.Tensor,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
ind: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
**kwargs,
|
||||
):
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
"""
|
||||
# with autocast(False):
|
||||
# # 1. Extract feats
|
||||
# feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
#
|
||||
# # 2. Data augmentation
|
||||
# if self.specaug is not None and self.training:
|
||||
# feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
#
|
||||
# # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
# if self.normalize is not None:
|
||||
# feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# Pre-encoder, e.g. used for raw input data
|
||||
# if self.preencoder is not None:
|
||||
# feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||
ind = kwargs.get("ind", 0)
|
||||
encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
@@ -557,55 +370,14 @@ class UniASR(torch.nn.Module):
|
||||
# 4. Forward encoder
|
||||
# feats: (Batch, Length, Dim)
|
||||
# -> encoder_out: (Batch, Length2, Dim2)
|
||||
if self.encoder2.interctc_use_conditioning:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder2(
|
||||
speech, speech_lengths, ctc=self.ctc2, ind=ind
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
|
||||
intermediate_outs = None
|
||||
|
||||
encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
# # Post-encoder, e.g. NLU
|
||||
# if self.postencoder is not None:
|
||||
# encoder_out, encoder_out_lens = self.postencoder(
|
||||
# encoder_out, encoder_out_lens
|
||||
# )
|
||||
|
||||
assert encoder_out.size(0) == speech.size(0), (
|
||||
encoder_out.size(),
|
||||
speech.size(0),
|
||||
)
|
||||
assert encoder_out.size(1) <= encoder_out_lens.max(), (
|
||||
encoder_out.size(),
|
||||
encoder_out_lens.max(),
|
||||
)
|
||||
|
||||
if intermediate_outs is not None:
|
||||
return (encoder_out, intermediate_outs), encoder_out_lens
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||
|
||||
# for data-parallel
|
||||
speech = speech[:, : speech_lengths.max()]
|
||||
|
||||
if self.frontend is not None:
|
||||
# Frontend
|
||||
# e.g. STFT and Feature extract
|
||||
# data_loader may send time-domain signal in this case
|
||||
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
|
||||
feats, feats_lengths = self.frontend(speech, speech_lengths)
|
||||
else:
|
||||
# No frontend and no feature extract
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
return feats, feats_lengths
|
||||
|
||||
def nll(
|
||||
self,
|
||||
@@ -1024,36 +796,152 @@ class UniASR(torch.nn.Module):
|
||||
|
||||
return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
|
||||
|
||||
def _calc_ctc_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
# Calc CTC loss
|
||||
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
|
||||
def init_beam_search(self,
|
||||
**kwargs,
|
||||
):
|
||||
from funasr.models.uniasr.beam_search import BeamSearchScama
|
||||
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
|
||||
from funasr.models.transformer.scorers.length_bonus import LengthBonus
|
||||
|
||||
# Calc CER using CTC
|
||||
cer_ctc = None
|
||||
if not self.training and self.error_calculator is not None:
|
||||
ys_hat = self.ctc.argmax(encoder_out).data
|
||||
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
|
||||
return loss_ctc, cer_ctc
|
||||
decoding_mode = kwargs.get("decoding_mode", "model1")
|
||||
if decoding_mode == "model1":
|
||||
decoder = self.decoder
|
||||
else:
|
||||
decoder = self.decoder2
|
||||
# 1. Build ASR model
|
||||
scorers = {}
|
||||
|
||||
if self.ctc != None:
|
||||
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
|
||||
scorers.update(
|
||||
ctc=ctc
|
||||
)
|
||||
token_list = kwargs.get("token_list")
|
||||
scorers.update(
|
||||
decoder=decoder,
|
||||
length_bonus=LengthBonus(len(token_list)),
|
||||
)
|
||||
|
||||
# 3. Build ngram model
|
||||
# ngram is not supported now
|
||||
ngram = None
|
||||
scorers["ngram"] = ngram
|
||||
|
||||
weights = dict(
|
||||
decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
|
||||
ctc=kwargs.get("decoding_ctc_weight", 0.0),
|
||||
lm=kwargs.get("lm_weight", 0.0),
|
||||
ngram=kwargs.get("ngram_weight", 0.0),
|
||||
length_bonus=kwargs.get("penalty", 0.0),
|
||||
)
|
||||
beam_search = BeamSearchScama(
|
||||
beam_size=kwargs.get("beam_size", 5),
|
||||
weights=weights,
|
||||
scorers=scorers,
|
||||
sos=self.sos,
|
||||
eos=self.eos,
|
||||
vocab_size=len(token_list),
|
||||
token_list=token_list,
|
||||
pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
|
||||
)
|
||||
|
||||
self.beam_search = beam_search
|
||||
|
||||
def _calc_ctc_loss2(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
# Calc CTC loss
|
||||
loss_ctc = self.ctc2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
|
||||
def inference(self,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list = None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
# Calc CER using CTC
|
||||
cer_ctc = None
|
||||
if not self.training and self.error_calculator is not None:
|
||||
ys_hat = self.ctc2.argmax(encoder_out).data
|
||||
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
|
||||
return loss_ctc, cer_ctc
|
||||
decoding_model = kwargs.get("decoding_model", "normal")
|
||||
token_num_relax = kwargs.get("token_num_relax", 5)
|
||||
if decoding_model == "fast":
|
||||
decoding_ind = 0
|
||||
decoding_mode = "model1"
|
||||
elif decoding_model == "offline":
|
||||
decoding_ind = 1
|
||||
decoding_mode = "model2"
|
||||
else:
|
||||
decoding_ind = 0
|
||||
decoding_mode = "model2"
|
||||
# init beamsearch
|
||||
|
||||
if self.beam_search is None:
|
||||
logging.info("enable beam_search")
|
||||
self.init_beam_search(decoding_mode=decoding_mode, **kwargs)
|
||||
self.nbest = kwargs.get("nbest", 1)
|
||||
|
||||
meta_data = {}
|
||||
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
|
||||
speech, speech_lengths = data_in, data_lengths
|
||||
if len(speech.shape) < 3:
|
||||
speech = speech[None, :, :]
|
||||
if speech_lengths is None:
|
||||
speech_lengths = speech.shape[1]
|
||||
else:
|
||||
# extract fbank feats
|
||||
time1 = time.perf_counter()
|
||||
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
|
||||
data_type=kwargs.get("data_type", "sound"),
|
||||
tokenizer=tokenizer)
|
||||
time2 = time.perf_counter()
|
||||
meta_data["load_data"] = f"{time2 - time1:0.3f}"
|
||||
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
|
||||
frontend=frontend)
|
||||
time3 = time.perf_counter()
|
||||
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
||||
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
|
||||
|
||||
speech = speech.to(device=kwargs["device"])
|
||||
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
||||
speech_raw = speech.clone().to(device=kwargs["device"])
|
||||
# Encoder
|
||||
_, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=decoding_ind)
|
||||
if decoding_mode == "model1":
|
||||
predictor_outs = self.calc_predictor_mask(encoder_out, encoder_out_lens)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=decoding_ind)
|
||||
predictor_outs = self.calc_predictor_mask2(encoder_out, encoder_out_lens)
|
||||
|
||||
|
||||
scama_mask = predictor_outs[4]
|
||||
pre_token_length = predictor_outs[1]
|
||||
pre_acoustic_embeds = predictor_outs[0]
|
||||
maxlen = pre_token_length.sum().item() + token_num_relax
|
||||
minlen = max(0, pre_token_length.sum().item() - token_num_relax)
|
||||
# c. Passed the encoder result and the beam search
|
||||
nbest_hyps = self.beam_search(
|
||||
x=encoder_out[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=0.0,
|
||||
minlenratio=0.0, maxlen=int(maxlen), minlen=int(minlen),
|
||||
)
|
||||
|
||||
nbest_hyps = nbest_hyps[: self.nbest]
|
||||
|
||||
results = []
|
||||
for hyp in nbest_hyps:
|
||||
|
||||
# remove sos/eos and get results
|
||||
last_pos = -1
|
||||
if isinstance(hyp.yseq, list):
|
||||
token_int = hyp.yseq[1:last_pos]
|
||||
else:
|
||||
token_int = hyp.yseq[1:last_pos].tolist()
|
||||
|
||||
# remove blank symbol id, which is assumed to be 0
|
||||
token_int = list(filter(lambda x: x != 0, token_int))
|
||||
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = tokenizer.ids2tokens(token_int)
|
||||
text_postprocessed = tokenizer.tokens2text(token)
|
||||
if not hasattr(tokenizer, "bpemodel"):
|
||||
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
|
||||
|
||||
|
||||
result_i = {"key": key[0], "text": text_postprocessed}
|
||||
results.append(result_i)
|
||||
|
||||
return results, meta_data
|
||||
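Taken together with demo.py, the new inference() method above gives UniASR two decoding paths: decoding_model="fast" decodes with the first-pass (streaming) model only, "offline" runs the full two-pass path with decoding_ind=1, and the default also decodes with the second-pass model but with decoding_ind=0. A hedged end-to-end sketch follows; the ModelScope model id is assumed from demo.py, and it is an assumption that AutoModel.generate() forwards extra keyword arguments such as decoding_model through to inference().

from funasr import AutoModel

# Sketch only: the model id is assumed, and "decoding_model" is the kwarg read by
# UniASR.inference() above ("fast" = first pass, "offline" = second pass).
model = AutoModel(
    model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online",
    model_revision="v2.0.4",
)
wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"

res_fast = model.generate(input=wav, decoding_model="fast")        # low-latency first pass
res_offline = model.generate(input=wav, decoding_model="offline")  # second pass, higher accuracy
print(res_fast)
print(res_offline)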