import re
import logging
import math
from collections import defaultdict
from typing import List, Tuple

import torch

# Punctuation/whitespace (both ASCII and full-width CJK forms) to strip
# from keyword labels.
symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~\s]+'
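
# A hedged example of what symbol_str strips (not from the original file):
#   re.sub(symbol_str, '', 'hello, world!')  ->  'helloworld'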


def split_mixed_label(input_str):
    """Split a mixed Chinese/English label: runs of Latin letters (plus a
    few in-word symbols) stay whole, every other character becomes its own
    token."""
    tokens = []
    s = input_str.lower()
    while len(s) > 0:
        match = re.match(r'[A-Za-z!?,<>()\']+', s)
        if match is not None:
            # An English word (or run of allowed symbols) at the head of s.
            word = match.group(0)
        else:
            # Otherwise consume a single character (e.g. one CJK character).
            word = s[0:1]
        tokens.append(word)
        s = s.replace(word, '', 1).strip(' ')
    return tokens
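
# Hedged usage sketch (hypothetical input, not from the original file):
#   split_mixed_label('你好hello world')  ->  ['你', '好', 'hello', 'world']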


def query_token_set(txt, symbol_table, lexicon_table):
    """Map a keyword string to parallel tuples of token strings and token
    ids, using symbol_table (token -> id) and lexicon_table (word -> token
    sequence)."""
    tokens_str = tuple()
    tokens_idx = tuple()

    # The whole keyword is already a modeling unit: return it directly.
    if txt in symbol_table:
        tokens_str = tokens_str + (txt, )
        tokens_idx = tokens_idx + (symbol_table[txt], )
        return tokens_str, tokens_idx

    parts = split_mixed_label(txt)
    for part in parts:
        if part == '!sil' or part == '(sil)' or part == '<sil>':
            tokens_str = tokens_str + ('!sil', )
        elif part == '<blank>':
            tokens_str = tokens_str + ('<blank>', )
        elif part in ('(noise)', 'noise)', '(noise', '<noise>'):
            tokens_str = tokens_str + ('<unk>', )
        elif part in symbol_table:
            tokens_str = tokens_str + (part, )
        elif part in lexicon_table:
            # Expand an out-of-vocabulary word through the lexicon.
            for ch in lexicon_table[part]:
                tokens_str = tokens_str + (ch, )
        else:
            # Strip punctuation and fall back to per-character tokens.
            part = re.sub(symbol_str, '', part)
            for ch in part:
                tokens_str = tokens_str + (ch, )

    for ch in tokens_str:
        if ch in symbol_table:
            tokens_idx = tokens_idx + (symbol_table[ch], )
        elif ch == '!sil':
            if 'sil' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['sil'], )
            else:
                tokens_idx = tokens_idx + (symbol_table['<blank>'], )
        elif ch == '<unk>':
            if '<unk>' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['<unk>'], )
            else:
                tokens_idx = tokens_idx + (symbol_table['<blank>'], )
        else:
            if '<unk>' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['<unk>'], )
                logging.info(f"'{ch}' is not in the token set, replaced with <unk>")
            else:
                tokens_idx = tokens_idx + (symbol_table['<blank>'], )
                logging.info(f"'{ch}' is not in the token set, replaced with <blank>")

    return tokens_str, tokens_idx
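
# Hedged usage sketch (hypothetical tables, not from the original file):
#   symbol_table = {'<blank>': 0, '<unk>': 1, '你': 2, '好': 3, 'hi': 4}
#   query_token_set('你好', symbol_table, {})    ->  (('你', '好'), (2, 3))
#   query_token_set('hi你好', symbol_table, {})  ->  (('hi', '你', '好'), (4, 2, 3))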


class KwsCtcPrefixDecoder:
    """Decoder wrapper running CTC prefix beam search for keyword spotting."""

    def __init__(
        self,
        ctc: torch.nn.Module,
        keywords: str,
        token_list: list,
        seg_dict: dict,
    ):
        """Initialize class.

        Args:
            ctc (torch.nn.Module): The CTC implementation.
                For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
            keywords (str): comma-separated keyword strings.
            token_list (list): modeling units; a token's index is its id.
            seg_dict (dict): lexicon mapping words to token sequences.
        """
        self.ctc = ctc
        self.token_list = token_list

        # token -> id lookup; enumerate avoids the O(n^2) list.index() scan.
        token_table = {token: idx for idx, token in enumerate(token_list)}

        self.keywords_idxset = {0}  # always keep the blank id (0)
        self.keywords_token = {}
        self.keywords_str = keywords
        keywords_list = self.keywords_str.strip().replace(' ', '').split(',')
        for keyword in keywords_list:
            strs, indexs = query_token_set(keyword, token_table, seg_dict)
            self.keywords_token[keyword] = {
                'token_id': indexs,
                'token_str': ''.join('%s ' % str(i) for i in indexs),
            }
            self.keywords_idxset.update(indexs)
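
    # Hedged construction sketch (hypothetical arguments, not from the
    # original file); note that spaces in `keywords` are stripped before
    # splitting on commas:
    #   decoder = KwsCtcPrefixDecoder(ctc=model.ctc, keywords='你好问问,嗨小问',
    #                                 token_list=tokens, seg_dict=lexicon)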
|
||
|
||
def beam_search(
|
||
self,
|
||
logits: torch.Tensor,
|
||
logits_lengths: torch.Tensor,
|
||
keywords_tokenset: set = None,
|
||
score_beam_size: int = 3,
|
||
path_beam_size: int = 20,
|
||
) -> Tuple[List[List[int]], torch.Tensor]:
|
||
""" CTC prefix beam search inner implementation
|
||
|
||
Args:
|
||
logits (torch.Tensor): (1, max_len, vocab_size)
|
||
logits_lengths (torch.Tensor): (1, )
|
||
keywords_tokenset (set): token set for filtering score
|
||
score_beam_size (int): beam size for score
|
||
path_beam_size (int): beam size for path
|
||
|
||
Returns:
|
||
List[List[int]]: nbest results
|
||
"""
|
||
|
||
maxlen = logits.size(0)
|
||
ctc_probs = logits
|
||
cur_hyps = [(tuple(), (1.0, 0.0, []))]
|
||
|
||
# CTC beam search step by step
|
||
for t in range(0, maxlen):
|
||
probs = ctc_probs[t] # (vocab_size,)
|
||
# key: prefix, value (pb, pnb), default value(-inf, -inf)
|
||
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
|
||
|
||
# 2.1 First beam prune: select topk best
|
||
top_k_probs, top_k_index = probs.topk(
|
||
score_beam_size) # (score_beam_size,)
|
||
|
||
# filter prob score that is too small
|
||
filter_probs = []
|
||
filter_index = []
|
||
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
|
||
if keywords_tokenset is not None:
|
||
if prob > 0.05 and idx in keywords_tokenset:
|
||
filter_probs.append(prob)
|
||
filter_index.append(idx)
|
||
else:
|
||
if prob > 0.05:
|
||
filter_probs.append(prob)
|
||
filter_index.append(idx)
|
||
|
||
if len(filter_index) == 0:
|
||
continue
|
||
|
||
for s in filter_index:
|
||
ps = probs[s].item()
|
||
# print(f'frame:{t}, token:{s}, score:{ps}')
|
||
|
||
for prefix, (pb, pnb, cur_nodes) in cur_hyps:
|
||
last = prefix[-1] if len(prefix) > 0 else None
|
||
if s == 0: # blank
|
||
n_pb, n_pnb, nodes = next_hyps[prefix]
|
||
n_pb = n_pb + pb * ps + pnb * ps
|
||
nodes = cur_nodes.copy()
|
||
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
||
elif s == last:
|
||
if not math.isclose(pnb, 0.0, abs_tol=0.000001):
|
||
# Update *ss -> *s;
|
||
n_pb, n_pnb, nodes = next_hyps[prefix]
|
||
n_pnb = n_pnb + pnb * ps
|
||
nodes = cur_nodes.copy()
|
||
if ps > nodes[-1]['prob']: # update frame and prob
|
||
nodes[-1]['prob'] = ps
|
||
nodes[-1]['frame'] = t
|
||
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
||
|
||
if not math.isclose(pb, 0.0, abs_tol=0.000001):
|
||
# Update *s-s -> *ss, - is for blank
|
||
n_prefix = prefix + (s, )
|
||
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||
n_pnb = n_pnb + pb * ps
|
||
nodes = cur_nodes.copy()
|
||
nodes.append(dict(token=s, frame=t,
|
||
prob=ps)) # to record token prob
|
||
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||
else:
|
||
n_prefix = prefix + (s, )
|
||
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||
if nodes:
|
||
if ps > nodes[-1]['prob']: # update frame and prob
|
||
nodes[-1]['prob'] = ps
|
||
nodes[-1]['frame'] = t
|
||
else:
|
||
nodes = cur_nodes.copy()
|
||
nodes.append(dict(token=s, frame=t,
|
||
prob=ps)) # to record token prob
|
||
n_pnb = n_pnb + pb * ps + pnb * ps
|
||
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||
|
||
# 2.2 Second beam prune
|
||
next_hyps = sorted(next_hyps.items(),
|
||
key=lambda x: (x[1][0] + x[1][1]),
|
||
reverse=True)
|
||
|
||
cur_hyps = next_hyps[:path_beam_size]
|
||
|
||
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
|
||
return hyps
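
    # Hedged summary of the recurrence implemented above (standard CTC prefix
    # beam search; this notation is not from the original file). With pb/pnb
    # the probabilities that prefix l ends in blank / non-blank at frame t:
    #   pb(l, t)    = (pb(l, t-1) + pnb(l, t-1)) * p(blank, t)
    #   pnb(l, t)   =  pnb(l, t-1) * p(last(l), t)           # repeat last token
    #   pnb(l+c, t) =  pb(l, t-1) * p(c, t)                  # if c == last(l)
    #               = (pb(l, t-1) + pnb(l, t-1)) * p(c, t)   # otherwise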

    def is_sublist(self, main_list, check_list):
        """Return the offset of check_list inside main_list, or -1."""
        if len(main_list) < len(check_list):
            return -1

        if len(main_list) == len(check_list):
            return 0 if main_list == check_list else -1

        # +1 so a match ending at the last element is not missed.
        for i in range(len(main_list) - len(check_list) + 1):
            if main_list[i] == check_list[0]:
                for j in range(len(check_list)):
                    if main_list[i + j] != check_list[j]:
                        break
                else:
                    return i
        return -1
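
    # Hedged examples (not from the original file):
    #   is_sublist((5, 2, 3, 9), (3, 9))  ->  2
    #   is_sublist((5, 2, 3, 9), (9, 3))  ->  -1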

    def _decode_inside(
        self,
        logits: torch.Tensor,
        logits_lengths: torch.Tensor,
    ):
        hyps = self.beam_search(logits, logits_lengths, self.keywords_idxset)

        hit_keyword = None
        hit_score = 1.0
        for one_hyp in hyps:
            prefix_ids = one_hyp[0]
            # one_hyp[1] is the accumulated path score (unused here)
            prefix_nodes = one_hyp[2]
            assert len(prefix_ids) == len(prefix_nodes)
            for word in self.keywords_token.keys():
                lab = self.keywords_token[word]['token_id']
                offset = self.is_sublist(prefix_ids, lab)
                if offset != -1:
                    hit_keyword = word
                    # Score the hit by the per-token posteriors along the path.
                    for idx in range(offset, offset + len(lab)):
                        hit_score *= prefix_nodes[idx]['prob']
                    break
            if hit_keyword is not None:
                hit_score = math.sqrt(hit_score)
                break

        if hit_keyword is not None:
            return True, hit_keyword, hit_score
        return False, None, None

    def decode(self, x: torch.Tensor):
        """Run keyword spotting on the encoded features.

        Args:
            x (torch.Tensor): the encoded feature tensor, (max_len, feat_dim)

        Returns:
            (is_hit, keyword, score) as produced by _decode_inside
        """
        raw_logp = self.ctc.softmax(x.unsqueeze(0)).detach().squeeze(0).cpu()
        # Number of frames, i.e. dim 0 after squeezing the batch dimension.
        xlen = torch.tensor([raw_logp.size(0)])

        return self._decode_inside(raw_logp, xlen)
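

if __name__ == '__main__':
    # Minimal smoke test; a hedged sketch, not part of the original file.
    # _DummyCTC and all shapes/values below are hypothetical: a 4-token
    # vocabulary where id 0 is <blank>, and a single keyword 'ab'.
    class _DummyCTC(torch.nn.Module):
        def softmax(self, x):
            # Stands in for a CTC head; here the "encoder output" is
            # already vocab-sized logits, so just normalize them.
            return torch.nn.functional.softmax(x, dim=-1)

    token_list = ['<blank>', '<unk>', 'a', 'b']
    decoder = KwsCtcPrefixDecoder(ctc=_DummyCTC(), keywords='ab',
                                  token_list=token_list, seg_dict={})

    # Four frames that strongly suggest 'a', blank, 'b', blank.
    logits = torch.full((4, 4), -5.0)
    logits[0, 2] = 5.0
    logits[1, 0] = 5.0
    logits[2, 3] = 5.0
    logits[3, 0] = 5.0

    print(decoder.decode(logits))  # expected: (True, 'ab', score near 1.0)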