funasr1.0.5 (#1328)
This commit is contained in:
  parent 85e658a0f6
  commit 3f487c4290
examples/industrial_data_pretraining/conformer/demo.py  (new file, 13 lines)
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch", model_revision="v2.0.4",
+                  )
+
+res = model.generate(input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav")
+print(res)
+
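For reference, demo.py above is runnable as-is; the same AutoModel entry point should also accept a local audio file in place of the URL. A minimal sketch, assuming a 16 kHz mono WAV at a hypothetical local path, and assuming results come back as a list of dicts with a "text" field (as in the model-level diff further down):

from funasr import AutoModel

model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch",
                  model_revision="v2.0.4")
res = model.generate(input="/path/to/local_16k_mono.wav")  # hypothetical local file
print(res[0]["text"])  # assumes list-of-dicts results with a "text" key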
examples/industrial_data_pretraining/conformer/infer.sh  (new file, 11 lines)
@@ -0,0 +1,11 @@
+
+model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch"
+model_revision="v2.0.4"
+
+python funasr/bin/inference.py \
++model=${model} \
++model_revision=${model_revision} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
++output_dir="./outputs/debug" \
++device="cpu" \
+
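The `+model=${model}`-style arguments passed to funasr/bin/inference.py read as Hydra-style command-line overrides, where `+key=value` adds a key to the run config. Under that assumption, the same run can presumably be reproduced from Python with the AutoModel API shown in demo.py; a sketch (the device kwarg is an assumption here):

from funasr import AutoModel

model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch",
                  model_revision="v2.0.4", device="cpu")  # device kwarg assumed
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)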
@@ -6,8 +6,7 @@
 # tables.print()
 
 # network architecture
-#model: funasr.models.paraformer.model:Paraformer
-model: Transformer
+model: Conformer
 model_conf:
     ctc_weight: 0.3
     lsm_weight: 0.1     # label smoothing option

@@ -16,14 +15,14 @@ model_conf:
 # encoder
 encoder: ConformerEncoder
 encoder_conf:
-    output_size: 256    # dimension of attention
+    output_size: 256
     attention_heads: 4
-    linear_units: 2048  # the number of units of position-wise feed forward
-    num_blocks: 12      # the number of encoder blocks
+    linear_units: 2048
+    num_blocks: 12
     dropout_rate: 0.1
     positional_dropout_rate: 0.1
     attention_dropout_rate: 0.0
-    input_layer: conv2d # encoder architecture type
+    input_layer: conv2d
     normalize_before: true
     pos_enc_layer_type: rel_pos
     selfattention_layer_type: rel_selfattn

@@ -52,6 +51,7 @@ frontend_conf:
     n_mels: 80
     frame_length: 25
     frame_shift: 10
     dither: 0.0
+    lfr_m: 1
+    lfr_n: 1
 
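The config hunks above strip inline comments, switch the model type from Transformer to Conformer, and spell out lfr_m/lfr_n in frontend_conf. To sanity-check such a YAML outside FunASR, a generic sketch (the config path is hypothetical; requires PyYAML):

import yaml  # generic YAML inspection, not a FunASR API

with open("examples/industrial_data_pretraining/conformer/conf.yaml") as f:  # hypothetical path
    cfg = yaml.safe_load(f)
print(cfg["model"])                        # expected: Conformer
print(cfg["encoder_conf"]["output_size"])  # expected: 256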
@@ -24,18 +24,16 @@ class Transformer(nn.Module):
 
     def __init__(
         self,
         frontend: Optional[str] = None,
         frontend_conf: Optional[Dict] = None,
-        specaug: Optional[str] = None,
-        specaug_conf: Optional[Dict] = None,
+        specaug: str = None,
+        specaug_conf: dict = None,
         normalize: str = None,
-        normalize_conf: Optional[Dict] = None,
+        normalize_conf: dict = None,
         encoder: str = None,
-        encoder_conf: Optional[Dict] = None,
+        encoder_conf: dict = None,
         decoder: str = None,
-        decoder_conf: Optional[Dict] = None,
+        decoder_conf: dict = None,
         ctc: str = None,
-        ctc_conf: Optional[Dict] = None,
+        ctc_conf: dict = None,
         ctc_weight: float = 0.5,
         interctc_weight: float = 0.0,
         input_size: int = 80,
@@ -59,20 +57,17 @@ class Transformer(nn.Module):
 
         super().__init__()
 
         if frontend is not None:
             frontend_class = tables.frontend_classes.get_class(frontend)
             frontend = frontend_class(**frontend_conf)
         if specaug is not None:
-            specaug_class = tables.specaug_classes.get_class(specaug)
+            specaug_class = tables.specaug_classes.get(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get_class(normalize)
+            normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get_class(encoder)
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
         if decoder is not None:
-            decoder_class = tables.decoder_classes.get_class(decoder)
+            decoder_class = tables.decoder_classes.get(decoder)
             decoder = decoder_class(
                 vocab_size=vocab_size,
                 encoder_output_size=encoder_output_size,
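The `get_class` -> `get` changes suggest the registry tables standardized on a shorter lookup name in this release. A minimal sketch of how such a name-to-class registry typically works (illustrative only; FunASR's actual tables module may differ):

class ClassRegistry:
    """Maps registered names to classes; stand-in for tables.*_classes."""
    def __init__(self):
        self._classes = {}

    def register(self, name):
        def deco(cls):
            self._classes[name] = cls
            return cls
        return deco

    def get(self, name):  # the lookup name this commit standardizes on
        return self._classes[name]

encoder_classes = ClassRegistry()

@encoder_classes.register("ConformerEncoder")
class DummyConformerEncoder:  # stand-in for the real encoder class
    def __init__(self, input_size=80, **conf):
        self.input_size = input_size

encoder = encoder_classes.get("ConformerEncoder")(input_size=80)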
@@ -93,7 +88,6 @@ class Transformer(nn.Module):
         self.vocab_size = vocab_size
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
-        self.frontend = frontend
         self.specaug = specaug
         self.normalize = normalize
         self.encoder = encoder
@@ -338,6 +332,7 @@ class Transformer(nn.Module):
         )
+        token_list = kwargs.get("token_list")
         scorers.update(
             decoder=self.decoder,
             length_bonus=LengthBonus(len(token_list)),
         )
 
@@ -348,14 +343,14 @@ class Transformer(nn.Module):
             scorers["ngram"] = ngram
 
         weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
-            ctc=kwargs.get("decoding_ctc_weight", 0.0),
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5),
+            ctc=kwargs.get("decoding_ctc_weight", 0.5),
             lm=kwargs.get("lm_weight", 0.0),
             ngram=kwargs.get("ngram_weight", 0.0),
             length_bonus=kwargs.get("penalty", 0.0),
         )
         beam_search = BeamSearch(
-            beam_size=kwargs.get("beam_size", 2),
+            beam_size=kwargs.get("beam_size", 10),
             weights=weights,
             scorers=scorers,
             sos=self.sos,
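With the new defaults, omitting decoding_ctc_weight now yields joint attention/CTC decoding weighted 0.5/0.5 instead of attention-only, and the default beam widens from 2 to 10. A quick check of the weight arithmetic as written above:

kwargs = {}  # caller supplies no decoding options
ctc_w = kwargs.get("decoding_ctc_weight", 0.5)
weights = dict(decoder=1.0 - ctc_w, ctc=ctc_w)
assert weights == {"decoder": 0.5, "ctc": 0.5}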
@@ -364,17 +359,15 @@
             token_list=token_list,
             pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
         )
         # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
         # for scorer in scorers.values():
         #     if isinstance(scorer, torch.nn.Module):
         #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
 
         self.beam_search = beam_search
 
-    def generate(self,
-                 data_in: list,
-                 data_lengths: list=None,
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
                   key: list=None,
                   tokenizer=None,
                   frontend=None,
                   **kwargs,
                   ):
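Note that demo.py above still calls model.generate() on the AutoModel wrapper even though the model-level method is renamed to inference() here, which suggests AutoModel.generate dispatches to the underlying model's inference(). That dispatch is an assumption; a sketch of the inferred shape:

class AutoModelLike:  # assumption sketch, not FunASR's actual AutoModel
    def __init__(self, model):
        self.model = model

    def generate(self, input, **kwargs):
        # presumed forwarding of the public generate() call to model.inference()
        return self.model.inference(input, **kwargs)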
@@ -382,27 +375,34 @@
             raise NotImplementedError("batch decoding is not implemented")
 
         # init beamsearch
-        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
-        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-        if self.beam_search is None and (is_use_lm or is_use_ctc):
+        if self.beam_search is None:
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)
 
         meta_data = {}
-        # extract fbank feats
-        time1 = time.perf_counter()
-        audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
-        time2 = time.perf_counter()
-        meta_data["load_data"] = f"{time2 - time1:0.3f}"
-        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
-        time3 = time.perf_counter()
-        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-        meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
 
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
 
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         if isinstance(encoder_out, tuple):
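The rewritten front half of inference() adds a fast path: a caller that already has fbank features (data_type="fbank") skips loading and feature extraction, while the audio path now uses the frontend argument instead of the deleted self.frontend. A standalone re-run of the fast-path logic, with shapes assumed to be (frames, n_mels) for one utterance:

import torch

data_in = torch.randn(200, 80)  # hypothetical utterance: 200 frames x 80 mel bins
speech, speech_lengths = data_in, None
if len(speech.shape) < 3:
    speech = speech[None, :, :]  # add batch dim -> (1, 200, 80)
if speech_lengths is None:
    speech_lengths = speech.shape[1]  # 200; stays a plain int here, per the diff
print(speech.shape, speech_lengths)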
@@ -439,14 +439,13 @@
             token = tokenizer.ids2tokens(token_int)
             text = tokenizer.tokens2text(token)
 
-            text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-            result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+            # text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+            result_i = {"key": key[i], "token": token, "text": text}
             results.append(result_i)
 
             if ibest_writer is not None:
                 ibest_writer["token"][key[i]] = " ".join(token)
                 ibest_writer["text"][key[i]] = text
-                ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
 
         return results, meta_data
 
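After this change each result entry carries only key, token, and text; the text_postprocessed field disappears from both the returned dicts and the ibest writer output, so downstream code indexing it would need adjusting. An illustrative consumer-side shape (values made up):

results = [{"key": "utt1", "token": ["你", "好"], "text": "你 好"}]  # illustrative only
for r in results:
    # r["text_postprocessed"]  # no longer present after this commit
    print(r["key"], r["text"])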
@@ -9,7 +9,7 @@ from typing import Union
 
 import torch
 
-from funasr.metrics import end_detect
+from funasr.metrics.common import end_detect
 from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
 from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
 
@@ -6,7 +6,6 @@
 # tables.print()
 
 # network architecture
-#model: funasr.models.paraformer.model:Paraformer
 model: Transformer
 model_conf:
     ctc_weight: 0.3