Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Dev gzf exp (#1705)
* resume from step
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* train_loss_avg train_acc_avg
* train_loss_avg train_acc_avg
* train_loss_avg train_acc_avg
* log step
* wav is not exist
* wav is not exist
* decoding
* decoding
* decoding
* wechat
* decoding key
* decoding key
* decoding key
* decoding key
* decoding key
* Gcf (#1704)
* add rich-text decoding constraints
* special token
* bug fix
* fix
---------
Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
* decoding key
---------
Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
This commit is contained in:
parent d42c45c60a
commit a7bc099548
Binary file not shown (image updated: 184 KiB before, 158 KiB after).
@@ -514,6 +514,20 @@ class SenseVoiceRWKV(nn.Module):
        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]

        # Parameters for rich decoding
        self.beam_search.emo_unk = tokenizer.encode(
            DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all")[0]
        self.beam_search.emo_unk_score = 1
        self.beam_search.emo_tokens = tokenizer.encode(
            DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"), allowed_special="all")
        self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])

        self.beam_search.event_bg_token = tokenizer.encode(
            DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"), allowed_special="all")
        self.beam_search.event_ed_token = tokenizer.encode(
            DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"), allowed_special="all")
        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])

        encoder_out, encoder_out_lens = self.encode(
            speech[None, :, :].permute(0, 2, 1), speech_lengths
        )
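For reference, a minimal sketch of the options mapping these lookups read from. The key names and fallback values come straight from the diff above; wrapping them in a plain dict named decoding_options is only an assumed usage pattern, not the repository's documented API.

# Sketch of the rich-decoding options consumed by the DecodingOptions.get(...) calls above.
# Key names and defaults mirror the diff; the dict itself is an assumption for illustration.
decoding_options = {
    "emo_unk_token": "<|SPECIAL_TOKEN_1|>",                             # fallback emotion token
    "emo_target_tokens": "<|HAPPY|><|SAD|><|ANGRY|>",                   # candidate emotion tokens
    "emo_target_threshold": [0.1, 0.1, 0.1],                            # per-emotion acceptance thresholds
    "gain_tokens_bg": "<|Speech|><|BGM|><|Applause|><|Laughter|>",      # event-begin tokens
    "gain_tokens_ed": "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>",  # matching event-end tokens
    "gain_tokens_score": [1, 1, 1, 1],                                  # multiplicative gain per event-begin token
}
# Any key that is missing falls back to the defaults shown in the diff.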
@@ -843,6 +857,20 @@ class SenseVoiceFSMN(nn.Module):
        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]

        # Parameters for rich decoding
        self.beam_search.emo_unk = tokenizer.encode(
            DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all")[0]
        self.beam_search.emo_unk_score = 1
        self.beam_search.emo_tokens = tokenizer.encode(
            DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"), allowed_special="all")
        self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])

        self.beam_search.event_bg_token = tokenizer.encode(
            DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"), allowed_special="all")
        self.beam_search.event_ed_token = tokenizer.encode(
            DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"), allowed_special="all")
        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])

        encoder_out, encoder_out_lens = self.encode(
            speech[None, :, :].permute(0, 2, 1), speech_lengths
        )
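The tokenizer.encode(..., allowed_special="all") calls above return one id per <|...|> marker, which is why the single-token emo_unk lookup takes element [0] and why emo_scores and event_score_ga must line up positionally with the token strings. A toy stand-in tokenizer (ids are made up; this is not FunASR's tokenizer) makes the shapes explicit:

# Toy stand-in tokenizer, for illustration only; the ids are arbitrary.
import re

class DummySpecialTokenizer:
    def __init__(self, vocab):
        self.vocab = vocab

    def encode(self, text, allowed_special="all"):
        # Return one id per <|...|> marker, in order of appearance.
        return [self.vocab[tok] for tok in re.findall(r"<\|/?[^|]+\|>", text)]

tok = DummySpecialTokenizer({"<|SPECIAL_TOKEN_1|>": 100, "<|HAPPY|>": 101, "<|SAD|>": 102, "<|ANGRY|>": 103})
emo_unk = tok.encode("<|SPECIAL_TOKEN_1|>", allowed_special="all")[0]        # single id -> 100
emo_tokens = tok.encode("<|HAPPY|><|SAD|><|ANGRY|>", allowed_special="all")  # [101, 102, 103]
emo_scores = [0.1, 0.1, 0.1]  # aligned by position with emo_tokens
assert len(emo_tokens) == len(emo_scores)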
@@ -1,4 +1,5 @@
from itertools import chain
from dataclasses import field
import logging
from typing import Any
from typing import Dict
@@ -8,6 +9,7 @@ from typing import Tuple
from typing import Union

import torch
import numpy as np

from funasr.metrics.common import end_detect
from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
@@ -42,6 +44,17 @@ class BeamSearch(torch.nn.Module):
        vocab_size: int,
        sos=None,
        eos=None,
        # NOTE add rich decoding parameters
        # [SPECIAL_TOKEN_1, HAPPY, SAD, ANGRY, NEUTRAL]
        emo_unk: int = 58964,
        emo_unk_score: float = 1.0,
        emo_tokens: List[int] = field(default_factory=lambda: [58954, 58955, 58956, 58957]),
        emo_scores: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1, 0.1]),
        # [Speech, BGM, Laughter, Applause]
        event_bg_token: List[int] = field(default_factory=lambda: [58946, 58948, 58950, 58952]),
        event_ed_token: List[int] = field(default_factory=lambda: [58947, 58949, 58951, 58953]),
        event_score_ga: List[float] = field(default_factory=lambda: [1, 1, 5, 25]),

        token_list: List[str] = None,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
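One detail worth knowing when reusing this signature: dataclasses.field(default_factory=...) only has its special meaning inside a @dataclass; in a plain __init__ signature it evaluates once to a dataclasses.Field object, which then becomes the literal default value. The defaults above therefore only behave as written when callers pass these arguments explicitly, as the model code above does. A hedged alternative sketch (illustrative only, using a hypothetical RichDecodingConfig class rather than the repository's code) would use None sentinels:

# Illustrative sketch only, not FunASR's implementation: None sentinels instead of
# dataclasses.field(), which has no default-factory effect outside a @dataclass.
from typing import List, Optional

class RichDecodingConfig:
    def __init__(
        self,
        emo_unk: int = 58964,
        emo_tokens: Optional[List[int]] = None,
        emo_scores: Optional[List[float]] = None,
        event_bg_token: Optional[List[int]] = None,
        event_ed_token: Optional[List[int]] = None,
        event_score_ga: Optional[List[float]] = None,
    ):
        # Fall back to the defaults from the diff only when the caller passed nothing.
        self.emo_unk = emo_unk
        self.emo_tokens = emo_tokens if emo_tokens is not None else [58954, 58955, 58956, 58957]
        self.emo_scores = emo_scores if emo_scores is not None else [0.1, 0.1, 0.1, 0.1]
        self.event_bg_token = event_bg_token if event_bg_token is not None else [58946, 58948, 58950, 58952]
        self.event_ed_token = event_ed_token if event_ed_token is not None else [58947, 58949, 58951, 58953]
        self.event_score_ga = event_score_ga if event_score_ga is not None else [1, 1, 5, 25]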
@@ -110,6 +123,14 @@ class BeamSearch(torch.nn.Module):
            and len(self.part_scorers) > 0
        )

        self.emo_unk = emo_unk
        self.emo_unk_score = emo_unk_score
        self.emo_tokens = emo_tokens
        self.emo_scores = emo_scores
        self.event_bg_token = event_bg_token
        self.event_ed_token = event_ed_token
        self.event_score_ga = event_score_ga

    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        """Get an initial hypothesis data.

@@ -170,10 +191,48 @@ class BeamSearch(torch.nn.Module):
        """
        scores = dict()
        states = dict()

        def get_score(yseq, sp1, sp2):
            score = [0, 0]
            last_token = yseq[-1]
            last_token2 = yseq[-2] if len(yseq) > 1 else yseq[-1]
            sum_sp1 = sum([1 if x == sp1 else 0 for x in yseq])
            sum_sp2 = sum([1 if x == sp2 else 0 for x in yseq])
            if sum_sp1 > sum_sp2 or last_token in [sp1, sp2]:
                score[0] = -np.inf
            if sum_sp2 >= sum_sp1:
                score[1] = -np.inf
            return score

        def struct_score(yseq, score):
            import math

            last_token = yseq[-1]
            if last_token in self.emo_tokens + [self.emo_unk]:
                # prevent emitting an event token right after an emotion token
                score[self.event_bg_token] = -np.inf

            for eve_bg, eve_ed, eve_ga in zip(self.event_bg_token, self.event_ed_token, self.event_score_ga):
                score_offset = get_score(yseq, eve_bg, eve_ed)
                score[eve_bg] += score_offset[0]
                score[eve_ed] += score_offset[1]
                score[eve_bg] += math.log(eve_ga)

            score[self.emo_unk] += math.log(self.emo_unk_score)
            for emo, emo_th in zip(self.emo_tokens, self.emo_scores):
                if score.argmax() == emo and score[emo] < math.log(emo_th):
                    score[self.emo_unk] = max(score[emo], score[self.emo_unk])
                    score[emo] = -np.inf
            return score

        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
            scores[k] = struct_score(hyp.yseq, scores[k])

        return scores, states

    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
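To make the constraints above concrete, here is a small self-contained demo that applies the same get_score / struct_score logic to a dummy log-probability vector. The token ids and scores are toy values, not real model output: an event-end token is blocked until its matching begin token has appeared, begin tokens receive a log(gain) boost, and an emotion token whose score falls below log(threshold) is folded back into the emo_unk token.

# Toy demonstration of the rich-decoding constraints (illustrative ids and scores only).
import math
import numpy as np
import torch

# Hypothetical token ids for a 10-token toy vocabulary
EMO_UNK, HAPPY, SAD = 4, 5, 6
SPEECH_BG, SPEECH_ED = 7, 8

emo_tokens, emo_scores = [HAPPY, SAD], [0.1, 0.1]
event_bg_token, event_ed_token, event_score_ga = [SPEECH_BG], [SPEECH_ED], [5.0]
emo_unk_score = 1.0

def get_score(yseq, sp1, sp2):
    # Allow sp1 (begin) only when no event is open, sp2 (end) only when one is open,
    # and never directly after another begin/end token.
    score = [0, 0]
    last_token = yseq[-1]
    sum_sp1 = sum(1 for x in yseq if x == sp1)
    sum_sp2 = sum(1 for x in yseq if x == sp2)
    if sum_sp1 > sum_sp2 or last_token in [sp1, sp2]:
        score[0] = -np.inf
    if sum_sp2 >= sum_sp1:
        score[1] = -np.inf
    return score

def struct_score(yseq, score):
    if yseq[-1] in emo_tokens + [EMO_UNK]:
        score[event_bg_token] = -np.inf      # no event token right after an emotion token
    for bg, ed, ga in zip(event_bg_token, event_ed_token, event_score_ga):
        off = get_score(yseq, bg, ed)
        score[bg] += off[0] + math.log(ga)   # boost begin tokens by log(gain)
        score[ed] += off[1]
    score[EMO_UNK] += math.log(emo_unk_score)
    for emo, th in zip(emo_tokens, emo_scores):
        if score.argmax() == emo and score[emo] < math.log(th):
            score[EMO_UNK] = max(score[EMO_UNK], score[emo])
            score[emo] = -np.inf             # fold a weak emotion prediction into emo_unk
    return score

log_probs = torch.full((10,), math.log(0.05))
print(struct_score([1, 2, 3], log_probs.clone()))          # no <|Speech|> yet: end token blocked, begin boosted
print(struct_score([1, SPEECH_BG, 3], log_probs.clone()))  # inside an event: begin blocked, end allowed

weak_happy = log_probs.clone()
weak_happy[HAPPY] = math.log(0.08)                          # best candidate, but below the 0.1 threshold
print(struct_score([1, SPEECH_BG, SPEECH_ED], weak_happy))  # HAPPY suppressed, its score folded into EMO_UNK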