commit d342c642fa

    Merge branch 'funasr1.0' of github.com:alibaba-damo-academy/FunASR into funasr1.0

    add
@@ -1,7 +1,7 @@
 from funasr_onnx import ContextualParaformer
 from pathlib import Path
 
-model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
+model_dir = "../export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" # your export dir
 model = ContextualParaformer(model_dir, batch_size=1)
 
 wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
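For orientation, this demo is normally finished by passing the wav list together with a space-separated hotword string to the model. The snippet below is a hedged sketch of that call; the hotword text, call signature and printing are illustrative and not part of this diff.

    # Hypothetical continuation of the demo above (not part of this commit).
    # Assumes ContextualParaformer.__call__ accepts (wav_paths, hotwords).
    hotwords = "魔搭 达摩院"  # space-separated hotwords to bias decoding toward
    result = model(wav_path, hotwords)
    print(result)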
@@ -7,7 +7,6 @@ from pathlib import Path
 from typing import List, Union, Tuple
 
 import copy
-import torch
 import librosa
 import numpy as np
 
@@ -18,7 +17,7 @@ from .utils.postprocess_utils import (sentence_postprocess,
                                       sentence_postprocess_sentencepiece)
 from .utils.frontend import WavFrontend
 from .utils.timestamp_utils import time_stamp_lfr6_onnx
-from .utils.utils import pad_list, make_pad_mask
+from .utils.utils import pad_list
 
 logging = get_logger()
 
@@ -309,7 +308,7 @@ class ContextualParaformer(Paraformer):
         # index from bias_embed
         bias_embed = bias_embed.transpose(1, 0, 2)
         _ind = np.arange(0, len(hotwords)).tolist()
-        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
+        bias_embed = bias_embed[_ind, hotwords_length.tolist()]
         waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
         waveform_nums = len(waveform_list)
         asr_res = []
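The changed line above selects, for each hotword, the encoder output at its last character position; since hotwords_length is now a plain numpy array (see the next hunk), the .cpu().numpy() round trip is no longer needed. A standalone numpy sketch of that fancy-indexing step, with illustrative shapes and values:

    import numpy as np

    num_hotwords, max_len, dim = 3, 10, 4
    bias_embed = np.random.rand(num_hotwords, max_len, dim)  # after transpose: (hotword, position, dim)
    hotwords_length = np.array([2, 4, 0])                    # index of each hotword's last character
    _ind = np.arange(0, num_hotwords).tolist()

    picked = bias_embed[_ind, hotwords_length.tolist()]      # one vector per hotword
    assert picked.shape == (num_hotwords, dim)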
@@ -336,7 +335,7 @@ class ContextualParaformer(Paraformer):
         hotwords = hotwords.split(" ")
         hotwords_length = [len(i) - 1 for i in hotwords]
         hotwords_length.append(0)
-        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
+        hotwords_length = np.array(hotwords_length)
         # hotwords.append('<s>')
         def word_map(word):
             hotwords = []
@@ -346,11 +345,12 @@ class ContextualParaformer(Paraformer):
                     logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
                 else:
                     hotwords.append(self.vocab[c])
-            return torch.tensor(hotwords)
+            return np.array(hotwords)
         hotword_int = [word_map(i) for i in hotwords]
         # import pdb; pdb.set_trace()
-        hotword_int.append(torch.tensor([1]))
+        hotword_int.append(np.array([1]))
         hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
+        # import pdb; pdb.set_trace()
         return hotwords, hotwords_length
 
     def bb_infer(self, feats: np.ndarray,
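Taken together, the two hunks above port the hotword preprocessing from torch tensors to numpy arrays. The following is a self-contained sketch of the same flow with a toy vocabulary; the character-to-id table, hotword strings, <unk> id and the fixed pad length of 10 are illustrative, not taken from the commit.

    import numpy as np

    vocab = {"魔": 10, "搭": 11, "热": 12, "词": 13}  # toy character-to-id table
    UNK_ID = 0                                         # placeholder <unk> id for illustration

    def word_map(word):
        return np.array([vocab.get(c, UNK_ID) for c in word])

    hotwords = "魔搭 热词".split(" ")
    hotwords_length = np.array([len(w) - 1 for w in hotwords] + [0])

    hotword_int = [word_map(w) for w in hotwords]
    hotword_int.append(np.array([1]))                  # trailing sentinel entry, as in the code above

    # pad_list(hotword_int, pad_value=0, max_len=10) then does the equivalent of:
    padded = np.zeros((len(hotword_int), 10), dtype=np.int32)
    for i, ids in enumerate(hotword_int):
        padded[i, : ids.shape[0]] = ids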
@@ -359,7 +359,7 @@ class ContextualParaformer(Paraformer):
         return outputs
 
     def eb_infer(self, hotwords, hotwords_length):
-        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
+        outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
         return outputs
 
     def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
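With hotwords and hotwords_length now numpy arrays, eb_infer only needs an astype(np.int32) before handing them to the ONNX session wrapper. As a rough illustration of the same call made against onnxruntime directly (the model path, input ordering and shapes are assumptions; ort_infer_eb itself wraps the session):

    import numpy as np
    import onnxruntime as ort

    hotwords = np.zeros((3, 10), dtype=np.int64)       # padded hotword ids, as produced above
    hotwords_length = np.array([2, 4, 0])

    sess = ort.InferenceSession("model_eb.onnx")       # hypothetical bias-encoder export
    feeds = {
        inp.name: arr.astype(np.int32)
        for inp, arr in zip(sess.get_inputs(), (hotwords, hotwords_length))
    }
    outputs = sess.run(None, feeds)                    # hotword embeddings consumed by bb_infer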
@@ -2,12 +2,10 @@
 
 import functools
 import logging
-import pickle
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
 
 import re
-import torch
 import numpy as np
 import yaml
 try:
@@ -27,14 +25,15 @@ def pad_list(xs, pad_value, max_len=None):
     n_batch = len(xs)
     if max_len is None:
         max_len = max(x.size(0) for x in xs)
-    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+    # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+    # numpy format
+    pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
     for i in range(n_batch):
-        pad[i, : xs[i].size(0)] = xs[i]
+        pad[i, : xs[i].shape[0]] = xs[i]
 
     return pad
 
+'''
 def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
     if length_dim == 0:
         raise ValueError("length_dim cannot be 0: {}".format(length_dim))
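pad_list now builds the padded batch with numpy instead of tensor.new(...).fill_() and reads lengths via .shape[0]. Note that the unchanged max_len branch still calls x.size(0), a torch idiom that would fail on numpy arrays; the caller in this commit always passes max_len=10, so that branch is not hit. A standalone repetition of the new logic, with made-up values, so the behavior can be checked in isolation:

    import numpy as np

    def pad_list_np(xs, pad_value, max_len=None):
        # mirrors the numpy version above, repeated here so the example runs standalone
        n_batch = len(xs)
        if max_len is None:
            max_len = max(x.shape[0] for x in xs)
        pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
        for i in range(n_batch):
            pad[i, : xs[i].shape[0]] = xs[i]
        return pad

    print(pad_list_np([np.array([3, 4, 5]), np.array([7])], pad_value=0, max_len=4))
    # [[3 4 5 0]
    #  [7 0 0 0]]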
@@ -67,7 +66,7 @@ def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
         )
     mask = mask[ind].expand_as(xs).to(xs.device)
     return mask
+'''
 
 class TokenIDConverter():
     def __init__(self, token_list: Union[List, str],
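The two ''' markers in the last two hunks fence off make_pad_mask rather than deleting it: it relies on torch and is no longer imported by the ONNX runtime code. If a numpy-only pad mask were ever needed on this path, a minimal equivalent of the simple case (no xs, default length_dim) might look like the sketch below; this is an assumption for illustration, not part of the commit.

    import numpy as np

    def make_pad_mask_np(lengths, maxlen=None):
        lengths = np.asarray(lengths)
        maxlen = int(lengths.max()) if maxlen is None else maxlen
        # True at padded positions, mirroring the torch helper's simple case
        return np.arange(maxlen)[None, :] >= lengths[:, None]

    # make_pad_mask_np([2, 3, 1], maxlen=4)
    # -> [[False False  True  True]
    #     [False False False  True]
    #     [False  True  True  True]]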