From 3f487c42904a27deeae4ab48cf8ccc45537263d1 Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Wed, 31 Jan 2024 16:21:10 +0800 Subject: [PATCH] funasr1.0.5 (#1328) --- .../conformer/demo.py | 13 +++ .../conformer/infer.sh | 11 +++ funasr/models/conformer/template.yaml | 12 +-- funasr/models/transformer/model.py | 87 +++++++++---------- funasr/models/transformer/search.py | 2 +- funasr/models/transformer/template.yaml | 1 - 6 files changed, 74 insertions(+), 52 deletions(-) create mode 100644 examples/industrial_data_pretraining/conformer/demo.py create mode 100644 examples/industrial_data_pretraining/conformer/infer.sh diff --git a/examples/industrial_data_pretraining/conformer/demo.py b/examples/industrial_data_pretraining/conformer/demo.py new file mode 100644 index 000000000..358a1f800 --- /dev/null +++ b/examples/industrial_data_pretraining/conformer/demo.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +from funasr import AutoModel + +model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch", model_revision="v2.0.4", + ) + +res = model.generate(input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav") +print(res) + diff --git a/examples/industrial_data_pretraining/conformer/infer.sh b/examples/industrial_data_pretraining/conformer/infer.sh new file mode 100644 index 000000000..c259799f3 --- /dev/null +++ b/examples/industrial_data_pretraining/conformer/infer.sh @@ -0,0 +1,11 @@ + +model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch" +model_revision="v2.0.4" + +python funasr/bin/inference.py \ ++model=${model} \ ++model_revision=${model_revision} \ ++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \ ++output_dir="./outputs/debug" \ ++device="cpu" \ + diff --git a/funasr/models/conformer/template.yaml b/funasr/models/conformer/template.yaml index 4cbeca46f..f646acc9d 100644 --- a/funasr/models/conformer/template.yaml +++ b/funasr/models/conformer/template.yaml @@ -6,8 +6,7 @@ # tables.print() # network architecture -#model: funasr.models.paraformer.model:Paraformer -model: Transformer +model: Conformer model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option @@ -16,14 +15,14 @@ model_conf: # encoder encoder: ConformerEncoder encoder_conf: - output_size: 256 # dimension of attention + output_size: 256 attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks + linear_units: 2048 + num_blocks: 12 dropout_rate: 0.1 positional_dropout_rate: 0.1 attention_dropout_rate: 0.0 - input_layer: conv2d # encoder architecture type + input_layer: conv2d normalize_before: true pos_enc_layer_type: rel_pos selfattention_layer_type: rel_selfattn @@ -52,6 +51,7 @@ frontend_conf: n_mels: 80 frame_length: 25 frame_shift: 10 + dither: 0.0 lfr_m: 1 lfr_n: 1 diff --git a/funasr/models/transformer/model.py b/funasr/models/transformer/model.py index 7e40060dc..4ad466b4f 100644 --- a/funasr/models/transformer/model.py +++ b/funasr/models/transformer/model.py @@ -24,18 +24,16 @@ class Transformer(nn.Module): def __init__( self, - frontend: Optional[str] = None, - frontend_conf: Optional[Dict] = None, - specaug: Optional[str] = None, - specaug_conf: Optional[Dict] = None, + specaug: str = None, + specaug_conf: dict = None, normalize: str = None, - normalize_conf: Optional[Dict] = None, + normalize_conf: dict = None, encoder: str = None, - encoder_conf: Optional[Dict] = None, + encoder_conf: dict = None, decoder: str = None, - decoder_conf: Optional[Dict] = None, + decoder_conf: dict = None, ctc: str = None, - ctc_conf: Optional[Dict] = None, + ctc_conf: dict = None, ctc_weight: float = 0.5, interctc_weight: float = 0.0, input_size: int = 80, @@ -59,20 +57,17 @@ class Transformer(nn.Module): super().__init__() - if frontend is not None: - frontend_class = tables.frontend_classes.get_class(frontend) - frontend = frontend_class(**frontend_conf) if specaug is not None: - specaug_class = tables.specaug_classes.get_class(specaug) + specaug_class = tables.specaug_classes.get(specaug) specaug = specaug_class(**specaug_conf) if normalize is not None: - normalize_class = tables.normalize_classes.get_class(normalize) + normalize_class = tables.normalize_classes.get(normalize) normalize = normalize_class(**normalize_conf) - encoder_class = tables.encoder_classes.get_class(encoder) + encoder_class = tables.encoder_classes.get(encoder) encoder = encoder_class(input_size=input_size, **encoder_conf) encoder_output_size = encoder.output_size() if decoder is not None: - decoder_class = tables.decoder_classes.get_class(decoder) + decoder_class = tables.decoder_classes.get(decoder) decoder = decoder_class( vocab_size=vocab_size, encoder_output_size=encoder_output_size, @@ -93,7 +88,6 @@ class Transformer(nn.Module): self.vocab_size = vocab_size self.ignore_id = ignore_id self.ctc_weight = ctc_weight - self.frontend = frontend self.specaug = specaug self.normalize = normalize self.encoder = encoder @@ -338,6 +332,7 @@ class Transformer(nn.Module): ) token_list = kwargs.get("token_list") scorers.update( + decoder=self.decoder, length_bonus=LengthBonus(len(token_list)), ) @@ -348,14 +343,14 @@ class Transformer(nn.Module): scorers["ngram"] = ngram weights = dict( - decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0), - ctc=kwargs.get("decoding_ctc_weight", 0.0), + decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5), + ctc=kwargs.get("decoding_ctc_weight", 0.5), lm=kwargs.get("lm_weight", 0.0), ngram=kwargs.get("ngram_weight", 0.0), length_bonus=kwargs.get("penalty", 0.0), ) beam_search = BeamSearch( - beam_size=kwargs.get("beam_size", 2), + beam_size=kwargs.get("beam_size", 10), weights=weights, scorers=scorers, sos=self.sos, @@ -364,17 +359,15 @@ class Transformer(nn.Module): token_list=token_list, pre_beam_score_key=None if self.ctc_weight == 1.0 else "full", ) - # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() - # for scorer in scorers.values(): - # if isinstance(scorer, torch.nn.Module): - # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() + self.beam_search = beam_search - def generate(self, - data_in: list, - data_lengths: list=None, + def inference(self, + data_in, + data_lengths=None, key: list=None, tokenizer=None, + frontend=None, **kwargs, ): @@ -382,27 +375,34 @@ class Transformer(nn.Module): raise NotImplementedError("batch decoding is not implemented") # init beamsearch - is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None - is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None - if self.beam_search is None and (is_use_lm or is_use_ctc): + if self.beam_search is None: logging.info("enable beam_search") self.init_beam_search(**kwargs) self.nbest = kwargs.get("nbest", 1) - + meta_data = {} - # extract fbank feats - time1 = time.perf_counter() - audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)) - time2 = time.perf_counter() - meta_data["load_data"] = f"{time2 - time1:0.3f}" - speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend) - time3 = time.perf_counter() - meta_data["extract_feat"] = f"{time3 - time2:0.3f}" - meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000 - + if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), + frontend=frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + speech = speech.to(device=kwargs["device"]) speech_lengths = speech_lengths.to(device=kwargs["device"]) - # Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) if isinstance(encoder_out, tuple): @@ -439,14 +439,13 @@ class Transformer(nn.Module): token = tokenizer.ids2tokens(token_int) text = tokenizer.tokens2text(token) - text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) - result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed} + # text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) + result_i = {"key": key[i], "token": token, "text": text} results.append(result_i) if ibest_writer is not None: ibest_writer["token"][key[i]] = " ".join(token) ibest_writer["text"][key[i]] = text - ibest_writer["text_postprocessed"][key[i]] = text_postprocessed return results, meta_data diff --git a/funasr/models/transformer/search.py b/funasr/models/transformer/search.py index 39c4f8c48..ab7ac7d78 100644 --- a/funasr/models/transformer/search.py +++ b/funasr/models/transformer/search.py @@ -9,7 +9,7 @@ from typing import Union import torch -from funasr.metrics import end_detect +from funasr.metrics.common import end_detect from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface from funasr.models.transformer.scorers.scorer_interface import ScorerInterface diff --git a/funasr/models/transformer/template.yaml b/funasr/models/transformer/template.yaml index c9228f433..87814dc3b 100644 --- a/funasr/models/transformer/template.yaml +++ b/funasr/models/transformer/template.yaml @@ -6,7 +6,6 @@ # tables.print() # network architecture -#model: funasr.models.paraformer.model:Paraformer model: Transformer model_conf: ctc_weight: 0.3