Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

commit fde48a8652
parent 8a16e1c13e

    update egs_modelscope paraformer-large-en
@@ -16,7 +16,7 @@ def modelscope_infer(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', type=str, default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+    parser.add_argument('--model', type=str, default="damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020")
     parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp")
     parser.add_argument('--output_dir', type=str, default="./results/")
     parser.add_argument('--decoding_mode', type=str, default="normal")
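
For reference, a minimal sketch of what the new default points at: loading the English paraformer through the ModelScope pipeline and decoding a single file. The audio path and output handling are illustrative, not part of this commit.

# Sketch only (not part of this commit): decode one file with the new
# English default model via the ModelScope pipeline this recipe wraps.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model="damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
)
# "example.wav" is a placeholder; the recipe itself reads a wav.scp list instead.
rec_result = inference_pipeline(audio_in="example.wav")
print(rec_result)
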
@@ -6,7 +6,7 @@ set -o pipefail
 
 stage=1
 stop_stage=2
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020"
 data_dir="./data/test"
 output_dir="./results"
 batch_size=64
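
Both the script and the recipe above expect data_dir to contain a Kaldi-style wav.scp ("<utt_id> <wav_path>" per line). A sketch of preparing one, with made-up utterance IDs and paths, might look like:

# Sketch only: write a Kaldi-style wav.scp ("<utt_id> <wav_path>" per line)
# under the recipe's data_dir. Utterance IDs and wav paths are made up.
from pathlib import Path

wavs = {
    "utt_0001": "/path/to/english/utt_0001.wav",
    "utt_0002": "/path/to/english/utt_0002.wav",
}
scp = Path("./data/test/wav.scp")
scp.parent.mkdir(parents=True, exist_ok=True)
scp.write_text("".join(f"{utt} {wav}\n" for utt, wav in wavs.items()))
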
@@ -1918,6 +1918,8 @@ class Speech2TextWhisper:
         nbest: int = 1,
         streaming: bool = False,
         frontend_conf: dict = None,
+        language: str = None,
+        task: str = "transcribe",
         **kwargs,
     ):
 
@@ -1960,6 +1962,8 @@ class Speech2TextWhisper:
         self.device = device
         self.dtype = dtype
         self.frontend = frontend
+        self.language = language
+        self.task = task
 
     @torch.no_grad()
     def __call__(
@@ -1986,10 +1990,10 @@ class Speech2TextWhisper:
         mel = log_mel_spectrogram(speech).to(self.device)
 
         if self.asr_model.is_multilingual:
-            options = DecodingOptions(fp16=False)
+            options = DecodingOptions(fp16=False, language=self.language, task=self.task)
             asr_res = decode(self.asr_model, mel, options)
             text = asr_res.text
-            language = asr_res.language
+            language = self.language if self.language else asr_res.language
         else:
             asr_res = transcribe(self.asr_model, speech, fp16=False)
             text = asr_res["text"]
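
The DecodingOptions fields set here come from the openai-whisper package; outside FunASR, the same language/task control looks roughly like the sketch below (model size and audio file are placeholders).

# Standalone openai-whisper sketch of the language/task control wired in above.
# "base" and "sample.wav" are placeholders.
import whisper
from whisper import DecodingOptions, decode, log_mel_spectrogram, pad_or_trim

model = whisper.load_model("base")
audio = pad_or_trim(whisper.load_audio("sample.wav"))  # pad/trim to 30 s
mel = log_mel_spectrogram(audio).to(model.device)

# language=None lets the model autodetect; task is "transcribe" or "translate".
options = DecodingOptions(fp16=False, language="en", task="transcribe")
result = decode(model, mel, options)
print(result.text, result.language)
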
@@ -2056,6 +2056,8 @@ def inference_whisper(
 
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
+    language = param_dict.get("language", None)
+    task = param_dict.get("task", "transcribe")
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
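
With this change a caller can steer the Whisper backend through param_dict; the keys read here are "language" and "task". A hedged sketch, assuming the usual FunASR ModelScope calling convention and a placeholder Whisper model ID:

# Sketch only: passing the new keys through param_dict. The model ID below is a
# placeholder; substitute the Whisper model actually registered on ModelScope.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model="damo/<whisper-model-id>",  # placeholder
)
param_dict = {"language": "en", "task": "transcribe"}
rec_result = inference_pipeline(audio_in="example.wav", param_dict=param_dict)
print(rec_result)
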
@@ -2099,6 +2101,8 @@ def inference_whisper(
         penalty=penalty,
         nbest=nbest,
         streaming=streaming,
+        language=language,
+        task=task,
     )
     logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
     speech2text = Speech2TextWhisper(**speech2text_kwargs)