diff --git a/docs/images/wechat.png b/docs/images/wechat.png index 2e2aa6bcb..a0ee69308 100644 Binary files a/docs/images/wechat.png and b/docs/images/wechat.png differ diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py index 023063896..00bc85b5d 100644 --- a/funasr/models/sense_voice/model.py +++ b/funasr/models/sense_voice/model.py @@ -514,6 +514,20 @@ class SenseVoiceRWKV(nn.Module): self.beam_search.sos = sos_int self.beam_search.eos = eos_int[0] + # Paramterts for rich decoding + self.beam_search.emo_unk = tokenizer.encode( + DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all")[0] + self.beam_search.emo_unk_score = 1 + self.beam_search.emo_tokens = tokenizer.encode( + DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"), allowed_special="all") + self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1]) + + self.beam_search.event_bg_token = tokenizer.encode( + DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"), allowed_special="all") + self.beam_search.event_ed_token = tokenizer.encode( + DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"), allowed_special="all") + self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1]) + encoder_out, encoder_out_lens = self.encode( speech[None, :, :].permute(0, 2, 1), speech_lengths ) @@ -843,6 +857,20 @@ class SenseVoiceFSMN(nn.Module): self.beam_search.sos = sos_int self.beam_search.eos = eos_int[0] + # Paramterts for rich decoding + self.beam_search.emo_unk = tokenizer.encode( + DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all")[0] + self.beam_search.emo_unk_score = 1 + self.beam_search.emo_tokens = tokenizer.encode( + DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"), allowed_special="all") + self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1]) + + self.beam_search.event_bg_token = tokenizer.encode( + DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"), allowed_special="all") + self.beam_search.event_ed_token = tokenizer.encode( + DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"), allowed_special="all") + self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1]) + encoder_out, encoder_out_lens = self.encode( speech[None, :, :].permute(0, 2, 1), speech_lengths ) diff --git a/funasr/models/sense_voice/search.py b/funasr/models/sense_voice/search.py index 694e569ea..4400ce75d 100644 --- a/funasr/models/sense_voice/search.py +++ b/funasr/models/sense_voice/search.py @@ -1,4 +1,5 @@ from itertools import chain +from dataclasses import field import logging from typing import Any from typing import Dict @@ -8,6 +9,7 @@ from typing import Tuple from typing import Union import torch +import numpy as np from funasr.metrics.common import end_detect from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface @@ -42,6 +44,17 @@ class BeamSearch(torch.nn.Module): vocab_size: int, sos=None, eos=None, + # NOTE add rich decoding parameters + # [SPECIAL_TOKEN_1, HAPPY, SAD, ANGRY, NEUTRAL] + emo_unk: int = 58964, + emo_unk_score: float = 1.0, + emo_tokens: List[int] = field(default_factory=lambda: [58954, 58955, 58956, 58957]), + emo_scores: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1, 0.1]), + # [Speech, BGM, Laughter, Applause] + event_bg_token: List[int] = field(default_factory=lambda: [58946, 58948, 58950, 58952]), + event_ed_token: List[int] = field(default_factory=lambda: [58947, 58949, 58951, 58953]), + event_score_ga: List[float] = field(default_factory=lambda: [1, 1, 5, 25]), + token_list: List[str] = None, pre_beam_ratio: float = 1.5, pre_beam_score_key: str = None, @@ -110,6 +123,14 @@ class BeamSearch(torch.nn.Module): and len(self.part_scorers) > 0 ) + self.emo_unk = emo_unk + self.emo_unk_score = emo_unk_score + self.emo_tokens = emo_tokens + self.emo_scores = emo_scores + self.event_bg_token = event_bg_token + self.event_ed_token = event_ed_token + self.event_score_ga = event_score_ga + def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]: """Get an initial hypothesis data. @@ -170,10 +191,48 @@ class BeamSearch(torch.nn.Module): """ scores = dict() states = dict() + + def get_score(yseq, sp1, sp2): + score = [0, 0] + last_token = yseq[-1] + last_token2 = yseq[-2] if len(yseq) > 1 else yseq[-1] + sum_sp1 = sum([1 if x == sp1 else 0 for x in yseq]) + sum_sp2 = sum([1 if x == sp2 else 0 for x in yseq]) + if sum_sp1 > sum_sp2 or last_token in [sp1, sp2]: + score[0] = -np.inf + if sum_sp2 >= sum_sp1: + score[1] = -np.inf + return score + + def struct_score(yseq, score): + import math + + last_token = yseq[-1] + if last_token in self.emo_tokens + [self.emo_unk]: + # prevent output event after emotation token + score[self.event_bg_token] = -np.inf + + for eve_bg, eve_ed, eve_ga in zip(self.event_bg_token, self.event_ed_token, self.event_score_ga): + score_offset = get_score(yseq, eve_bg, eve_ed) + score[eve_bg] += score_offset[0] + score[eve_ed] += score_offset[1] + score[eve_bg] += math.log(eve_ga) + + + score[self.emo_unk] += math.log(self.emo_unk_score) + for emo, emo_th in zip(self.emo_tokens, self.emo_scores): + if score.argmax() == emo and score[emo] < math.log(emo_th): + score[self.emo_unk] = max(score[emo], score[self.emo_unk]) + score[emo] = -np.inf + return score + for k, d in self.full_scorers.items(): scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x) + scores[k] = struct_score(hyp.yseq, scores[k]) + return scores, states + def score_partial( self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: