Merge branch 'funasr1.0' of github.com:alibaba-damo-academy/FunASR into funasr1.0

add
This commit is contained in:
游雁 2024-01-10 17:43:50 +08:00
commit d342c642fa
3 changed files with 14 additions and 15 deletions

View File

@@ -1,7 +1,7 @@
from funasr_onnx import ContextualParaformer
from pathlib import Path
model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
model_dir = "../export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" # your export dir
model = ContextualParaformer(model_dir, batch_size=1)
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
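The demo snippet stops at building the input list; for context, the exported ContextualParaformer is driven with a space-separated hotword string (the hotword handling in the file below splits on spaces). A minimal usage sketch, with placeholder hotword values that are not part of this commit and a call signature assumed from the __call__ path shown below:

hotwords = "魔搭 阿里巴巴"          # space-separated hotwords; placeholder values for illustration
result = model(wav_path, hotwords)  # assumed call signature, not shown in this diff
print(result)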

View File

@@ -7,7 +7,6 @@ from pathlib import Path
from typing import List, Union, Tuple
import copy
import torch
import librosa
import numpy as np
@@ -18,7 +17,7 @@ from .utils.postprocess_utils import (sentence_postprocess,
sentence_postprocess_sentencepiece)
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask
from .utils.utils import pad_list
logging = get_logger()
@@ -309,7 +308,7 @@ class ContextualParaformer(Paraformer):
# index from bias_embed
bias_embed = bias_embed.transpose(1, 0, 2)
_ind = np.arange(0, len(hotwords)).tolist()
bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
bias_embed = bias_embed[_ind, hotwords_length.tolist()]
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
asr_res = []
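The dropped .cpu().numpy() round-trip works because hotwords_length is now a plain numpy array (see the proc_hotword change below), so the last-token embedding of each hotword can be picked with ordinary numpy fancy indexing. A minimal sketch with illustrative shapes:

import numpy as np

num_hotwords, max_len, dim = 3, 10, 512                  # illustrative shapes, not from the model
bias_embed = np.random.rand(num_hotwords, max_len, dim)  # after transpose(1, 0, 2): (hotword, position, dim)
hotwords_length = np.array([2, 4, 0])                    # index of each hotword's last character
_ind = np.arange(0, num_hotwords).tolist()
picked = bias_embed[_ind, hotwords_length.tolist()]      # fancy indexing -> shape (num_hotwords, dim)
print(picked.shape)                                      # (3, 512)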
@@ -336,7 +335,7 @@ class ContextualParaformer(Paraformer):
hotwords = hotwords.split(" ")
hotwords_length = [len(i) - 1 for i in hotwords]
hotwords_length.append(0)
hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
hotwords_length = np.array(hotwords_length)
# hotwords.append('<s>')
def word_map(word):
hotwords = []
@@ -346,11 +345,12 @@ class ContextualParaformer(Paraformer):
logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
else:
hotwords.append(self.vocab[c])
return torch.tensor(hotwords)
return np.array(hotwords)
hotword_int = [word_map(i) for i in hotwords]
# import pdb; pdb.set_trace()
hotword_int.append(torch.tensor([1]))
hotword_int.append(np.array([1]))
hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
# import pdb; pdb.set_trace()
return hotwords, hotwords_length
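To make the resulting shapes concrete, here is a small illustration of the numpy hotword preprocessing above; the vocabulary is a toy stand-in for self.vocab, and pad_list is the numpy version shown further down in utils:

import numpy as np

vocab = {"魔": 7, "搭": 8}                   # toy stand-in for self.vocab
hotwords = "魔搭".split(" ")                 # one hotword -> ["魔搭"]
hotwords_length = np.array([len(w) - 1 for w in hotwords] + [0])  # last-char index per hotword, plus 0 for the sentinel
hotword_int = [np.array([vocab[c] for c in w]) for w in hotwords]
hotword_int.append(np.array([1]))            # sentinel id, mirroring the code above
# pad_list(hotword_int, pad_value=0, max_len=10) then returns an int32 array of shape (2, 10)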
def bb_infer(self, feats: np.ndarray,
@@ -359,7 +359,7 @@ class ContextualParaformer(Paraformer):
return outputs
def eb_infer(self, hotwords, hotwords_length):
outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
return outputs
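The astype(np.int32) casts matter because ONNX Runtime takes numpy inputs and matches the declared input dtypes exactly, which is why no torch tensors are needed here. A generic sketch using onnxruntime directly rather than FunASR's session wrapper; the model path and placeholder arrays are assumptions:

import numpy as np
import onnxruntime as ort

hotwords = np.zeros((2, 10), dtype=np.int64)        # placeholder arrays shaped like proc_hotword's output
hotwords_length = np.array([1, 0], dtype=np.int64)

sess = ort.InferenceSession("contextual_eb.onnx")   # hypothetical model path
outputs = sess.run(None, {
    sess.get_inputs()[0].name: hotwords.astype(np.int32),
    sess.get_inputs()[1].name: hotwords_length.astype(np.int32),
})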
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:

View File

@@ -2,12 +2,10 @@
import functools
import logging
import pickle
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import re
import torch
import numpy as np
import yaml
try:
@@ -27,14 +25,15 @@ def pad_list(xs, pad_value, max_len=None):
n_batch = len(xs)
if max_len is None:
max_len = max(x.size(0) for x in xs)
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
# pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
# numpy format
pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
for i in range(n_batch):
pad[i, : xs[i].size(0)] = xs[i]
pad[i, : xs[i].shape[0]] = xs[i]
return pad
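A quick usage check of the numpy pad_list above, assuming it is importable as funasr_onnx.utils.utils.pad_list:

import numpy as np
from funasr_onnx.utils.utils import pad_list   # module path is an assumption

xs = [np.array([7, 8, 9]), np.array([1])]
padded = pad_list(xs, pad_value=0, max_len=10)
# padded: int32 array of shape (2, 10)
# [[7 8 9 0 0 0 0 0 0 0]
#  [1 0 0 0 0 0 0 0 0 0]]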
'''
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
@@ -67,7 +66,7 @@ def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
)
mask = mask[ind].expand_as(xs).to(xs.device)
return mask
'''
class TokenIDConverter():
def __init__(self, token_list: Union[List, str],