Merge branch 'funasr1.0' of github.com:alibaba-damo-academy/FunASR into funasr1.0

add
游雁 2024-01-10 17:43:50 +08:00
commit d342c642fa
3 changed files with 14 additions and 15 deletions

View File

@@ -1,7 +1,7 @@
 from funasr_onnx import ContextualParaformer
 from pathlib import Path
-model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
+model_dir = "../export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" # your export dir
 model = ContextualParaformer(model_dir, batch_size=1)
 wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
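For orientation, such a demo is typically completed by calling the model on the wav list together with a hotword string. The snippet below is only an illustrative sketch: the hotword text is made up, and the exact call signature of ContextualParaformer is not shown in this diff, so treat it as an assumption.

hotwords = "魔搭"  # illustrative; separate multiple hotwords with spaces
result = model(wav_path, hotwords)
print(result)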

View File

@@ -7,7 +7,6 @@ from pathlib import Path
 from typing import List, Union, Tuple
 import copy
-import torch
 import librosa
 import numpy as np
@@ -18,7 +17,7 @@ from .utils.postprocess_utils import (sentence_postprocess,
                                       sentence_postprocess_sentencepiece)
 from .utils.frontend import WavFrontend
 from .utils.timestamp_utils import time_stamp_lfr6_onnx
-from .utils.utils import pad_list, make_pad_mask
+from .utils.utils import pad_list
 logging = get_logger()
@@ -309,7 +308,7 @@ class ContextualParaformer(Paraformer):
         # index from bias_embed
         bias_embed = bias_embed.transpose(1, 0, 2)
         _ind = np.arange(0, len(hotwords)).tolist()
-        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
+        bias_embed = bias_embed[_ind, hotwords_length.tolist()]
         waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
         waveform_nums = len(waveform_list)
         asr_res = []
@@ -336,7 +335,7 @@ class ContextualParaformer(Paraformer):
         hotwords = hotwords.split(" ")
         hotwords_length = [len(i) - 1 for i in hotwords]
         hotwords_length.append(0)
-        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
+        hotwords_length = np.array(hotwords_length)
         # hotwords.append('<s>')
         def word_map(word):
             hotwords = []
@@ -346,11 +345,12 @@ class ContextualParaformer(Paraformer):
                     logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
                 else:
                     hotwords.append(self.vocab[c])
-            return torch.tensor(hotwords)
+            return np.array(hotwords)
         hotword_int = [word_map(i) for i in hotwords]
         # import pdb; pdb.set_trace()
-        hotword_int.append(torch.tensor([1]))
+        hotword_int.append(np.array([1]))
         hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
+        # import pdb; pdb.set_trace()
         return hotwords, hotwords_length
     def bb_infer(self, feats: np.ndarray,
@@ -359,7 +359,7 @@ class ContextualParaformer(Paraformer):
         return outputs
     def eb_infer(self, hotwords, hotwords_length):
-        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
+        outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
         return outputs
     def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
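To make the torch-to-numpy conversion above concrete, here is a small self-contained sketch of the hotword preprocessing path (character-to-id mapping plus length bookkeeping) using only numpy. The vocab, the OOV id, and the hotword strings are invented for illustration; only the overall shape of the logic mirrors this diff.

import numpy as np

vocab = {"魔": 10, "搭": 11}   # toy vocab: character -> token id (illustrative)
unk_id = 0                     # placeholder OOV id for this sketch

def word_map(word):
    # one id per character, returned as a numpy array (as in the converted code)
    return np.array([vocab.get(c, unk_id) for c in word])

hotwords = "魔搭 阿里".split(" ")
hotwords_length = np.array([len(w) - 1 for w in hotwords] + [0])
hotword_int = [word_map(w) for w in hotwords] + [np.array([1])]
# hotword_int would then be padded to a fixed width via pad_list(..., pad_value=0, max_len=10)

print(hotwords_length)                     # [1 1 0]
print([a.tolist() for a in hotword_int])   # [[10, 11], [0, 0], [1]]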

View File

@@ -2,12 +2,10 @@
 import functools
 import logging
-import pickle
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
 import re
-import torch
 import numpy as np
 import yaml
 try:
@@ -27,14 +25,15 @@ def pad_list(xs, pad_value, max_len=None):
     n_batch = len(xs)
     if max_len is None:
         max_len = max(x.size(0) for x in xs)
-    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+    # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+    # numpy format
+    pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
     for i in range(n_batch):
-        pad[i, : xs[i].size(0)] = xs[i]
+        pad[i, : xs[i].shape[0]] = xs[i]
     return pad
+'''
 def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
     if length_dim == 0:
         raise ValueError("length_dim cannot be 0: {}".format(length_dim))
@@ -67,7 +66,7 @@ def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
         )
         mask = mask[ind].expand_as(xs).to(xs.device)
     return mask
+'''
 class TokenIDConverter():
     def __init__(self, token_list: Union[List, str],
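For reference, a standalone sketch of the rewritten numpy pad_list with arbitrary inputs, showing the padded int32 matrix it produces. The example arrays are made up; note that this sketch uses shape[0] in the max_len=None branch, since the inputs here are numpy arrays, whereas the committed context line above still reads x.size(0).

import numpy as np

def pad_list(xs, pad_value, max_len=None):
    # mirrors the rewritten pad_list: int32 matrix filled with pad_value,
    # each row starting with one of the input arrays
    n_batch = len(xs)
    if max_len is None:
        max_len = max(x.shape[0] for x in xs)  # shape[0] assumed for numpy inputs
    pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
    for i in range(n_batch):
        pad[i, : xs[i].shape[0]] = xs[i]
    return pad

xs = [np.array([10, 11]), np.array([7]), np.array([1])]
print(pad_list(xs, pad_value=0, max_len=10))
# [[10 11  0  0  0  0  0  0  0  0]
#  [ 7  0  0  0  0  0  0  0  0  0]
#  [ 1  0  0  0  0  0  0  0  0  0]]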