mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
funasr2
This commit is contained in:
parent
ea4453cc88
commit
00ea1186f9
@ -81,7 +81,7 @@ def main_hydra(kwargs: DictConfig):
|
||||
|
||||
class AutoModel:
|
||||
def __init__(self, **kwargs):
|
||||
registry_tables.print_register_tables()
|
||||
registry_tables.print()
|
||||
assert "model" in kwargs
|
||||
if "model_conf" not in kwargs:
|
||||
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
|
||||
@ -108,6 +108,7 @@ class AutoModel:
|
||||
frontend_class = registry_tables.frontend_classes.get(frontend.lower())
|
||||
frontend = frontend_class(**kwargs["frontend_conf"])
|
||||
kwargs["frontend"] = frontend
|
||||
kwargs["input_size"] = frontend.output_size()
|
||||
|
||||
# build model
|
||||
model_class = registry_tables.model_classes.get(kwargs["model"].lower())
|
||||
|
||||
@ -39,7 +39,7 @@ def main(**kwargs):
|
||||
# preprocess_config(kwargs)
|
||||
# import pdb; pdb.set_trace()
|
||||
# set random seed
|
||||
registry_tables.print_register_tables()
|
||||
registry_tables.print()
|
||||
set_all_random_seed(kwargs.get("seed", 0))
|
||||
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
|
||||
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
|
||||
@ -72,6 +72,7 @@ def main(**kwargs):
|
||||
frontend_class = registry_tables.frontend_classes.get(frontend.lower())
|
||||
frontend = frontend_class(**kwargs["frontend_conf"])
|
||||
kwargs["frontend"] = frontend
|
||||
kwargs["input_size"] = frontend.output_size()
|
||||
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
|
||||
117
funasr/models/conformer/template.yaml
Normal file
117
funasr/models/conformer/template.yaml
Normal file
@ -0,0 +1,117 @@
|
||||
# This is an example that demonstrates how to configure a model file.
|
||||
# You can modify the configuration according to your own requirements.
|
||||
|
||||
# to print the register_table:
|
||||
# from funasr.utils.register import registry_tables
|
||||
# registry_tables.print()
|
||||
|
||||
# network architecture
|
||||
#model: funasr.models.paraformer.model:Paraformer
|
||||
model: Transformer
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
# encoder
|
||||
encoder: ConformerEncoder
|
||||
encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder architecture type
|
||||
normalize_before: true
|
||||
pos_enc_layer_type: rel_pos
|
||||
selfattention_layer_type: rel_selfattn
|
||||
activation_type: swish
|
||||
macaron_style: true
|
||||
use_cnn_module: true
|
||||
cnn_module_kernel: 15
|
||||
|
||||
# decoder
|
||||
decoder: TransformerDecoder
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
|
||||
# frontend related
|
||||
frontend: WavFrontend
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
window: hamming
|
||||
n_mels: 80
|
||||
frame_length: 25
|
||||
frame_shift: 10
|
||||
lfr_m: 1
|
||||
lfr_n: 1
|
||||
|
||||
specaug: SpecAug
|
||||
specaug_conf:
|
||||
apply_time_warp: true
|
||||
time_warp_window: 5
|
||||
time_warp_mode: bicubic
|
||||
apply_freq_mask: true
|
||||
freq_mask_width_range:
|
||||
- 0
|
||||
- 30
|
||||
num_freq_mask: 2
|
||||
apply_time_mask: true
|
||||
time_mask_width_range:
|
||||
- 0
|
||||
- 40
|
||||
num_time_mask: 2
|
||||
|
||||
train_conf:
|
||||
accum_grad: 1
|
||||
grad_clip: 5
|
||||
max_epoch: 150
|
||||
val_scheduler_criterion:
|
||||
- valid
|
||||
- acc
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- acc
|
||||
- max
|
||||
keep_nbest_models: 10
|
||||
log_interval: 50
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.0005
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 30000
|
||||
|
||||
dataset: AudioDataset
|
||||
dataset_conf:
|
||||
index_ds: IndexDSJsonl
|
||||
batch_sampler: DynamicBatchLocalShuffleSampler
|
||||
batch_type: example # example or length
|
||||
batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
|
||||
max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
|
||||
buffer_size: 500
|
||||
shuffle: True
|
||||
num_workers: 0
|
||||
|
||||
tokenizer: CharTokenizer
|
||||
tokenizer_conf:
|
||||
unk_symbol: <unk>
|
||||
split_with_space: true
|
||||
|
||||
|
||||
ctc_conf:
|
||||
dropout_rate: 0.0
|
||||
ctc_type: builtin
|
||||
reduce: true
|
||||
ignore_nan_grad: true
|
||||
normalize: null
|
||||
@ -1,6 +1,12 @@
|
||||
# This is an example that demonstrates how to configure a model file.
|
||||
# You can modify the configuration according to your own requirements.
|
||||
|
||||
# to print the register_table:
|
||||
# from funasr.utils.register import registry_tables
|
||||
# registry_tables.print()
|
||||
|
||||
# network architecture
|
||||
model: funasr.cli.models.paraformer:Paraformer
|
||||
model: NeatContextualParaformer
|
||||
model_conf:
|
||||
ctc_weight: 0.0
|
||||
lsm_weight: 0.1
|
||||
@ -8,9 +14,10 @@ model_conf:
|
||||
predictor_weight: 1.0
|
||||
predictor_bias: 1
|
||||
sampling_ratio: 0.75
|
||||
inner_dim: 512
|
||||
|
||||
# encoder
|
||||
encoder: sanm
|
||||
encoder: SANMEncoder
|
||||
encoder_conf:
|
||||
output_size: 512
|
||||
attention_heads: 4
|
||||
@ -26,8 +33,9 @@ encoder_conf:
|
||||
sanm_shfit: 0
|
||||
selfattention_layer_type: sanm
|
||||
|
||||
|
||||
# decoder
|
||||
decoder: paraformer_decoder_sanm
|
||||
decoder: ContextualParaformerDecoder
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
@ -40,7 +48,7 @@ decoder_conf:
|
||||
kernel_size: 11
|
||||
sanm_shfit: 0
|
||||
|
||||
predictor: cif_predictor_v2
|
||||
predictor: CifPredictorV2
|
||||
predictor_conf:
|
||||
idim: 512
|
||||
threshold: 1.0
|
||||
@ -49,7 +57,7 @@ predictor_conf:
|
||||
tail_threshold: 0.45
|
||||
|
||||
# frontend related
|
||||
frontend: wav_frontend
|
||||
frontend: WavFrontend
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
window: hamming
|
||||
@ -59,7 +67,7 @@ frontend_conf:
|
||||
lfr_m: 7
|
||||
lfr_n: 6
|
||||
|
||||
specaug: specaug_lfr
|
||||
specaug: SpecAugLFR
|
||||
specaug_conf:
|
||||
apply_time_warp: false
|
||||
time_warp_window: 5
|
||||
@ -97,21 +105,22 @@ scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 30000
|
||||
|
||||
|
||||
dataset: AudioDataset
|
||||
dataset_conf:
|
||||
data_names: speech,text
|
||||
data_types: sound,text
|
||||
index_ds: IndexDSJsonl
|
||||
batch_sampler: DynamicBatchLocalShuffleSampler
|
||||
batch_type: example # example or length
|
||||
batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
|
||||
max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
|
||||
buffer_size: 500
|
||||
shuffle: True
|
||||
shuffle_conf:
|
||||
shuffle_size: 2048
|
||||
sort_size: 500
|
||||
batch_conf:
|
||||
batch_type: example
|
||||
batch_size: 2
|
||||
num_workers: 8
|
||||
num_workers: 0
|
||||
|
||||
tokenizer: CharTokenizer
|
||||
tokenizer_conf:
|
||||
unk_symbol: <unk>
|
||||
split_with_space: true
|
||||
|
||||
split_with_space: true
|
||||
input_size: 560
|
||||
ctc_conf:
|
||||
dropout_rate: 0.0
|
||||
ctc_type: builtin
|
||||
@ -39,8 +39,6 @@ class Paraformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
# token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[str] = None,
|
||||
frontend_conf: Optional[Dict] = None,
|
||||
specaug: Optional[str] = None,
|
||||
specaug_conf: Optional[Dict] = None,
|
||||
normalize: str = None,
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
# This is an example that demonstrates how to configure a model file.
|
||||
# You can modify the configuration according to your own requirements.
|
||||
|
||||
# to print the register_table:
|
||||
# from funasr.utils.register import registry_tables
|
||||
# registry_tables.print()
|
||||
|
||||
# network architecture
|
||||
#model: funasr.models.paraformer.model:Paraformer
|
||||
model: Paraformer
|
||||
@ -117,7 +121,6 @@ tokenizer_conf:
|
||||
split_with_space: true
|
||||
|
||||
|
||||
input_size: 560
|
||||
ctc_conf:
|
||||
dropout_rate: 0.0
|
||||
ctc_type: builtin
|
||||
|
||||
111
funasr/models/transformer/template.yaml
Normal file
111
funasr/models/transformer/template.yaml
Normal file
@ -0,0 +1,111 @@
|
||||
# This is an example that demonstrates how to configure a model file.
|
||||
# You can modify the configuration according to your own requirements.
|
||||
|
||||
# to print the register_table:
|
||||
# from funasr.utils.register import registry_tables
|
||||
# registry_tables.print()
|
||||
|
||||
# network architecture
|
||||
#model: funasr.models.paraformer.model:Paraformer
|
||||
model: Transformer
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
# encoder
|
||||
encoder: TransformerEncoder
|
||||
encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder architecture type
|
||||
normalize_before: true
|
||||
|
||||
# decoder
|
||||
decoder: TransformerDecoder
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
|
||||
# frontend related
|
||||
frontend: WavFrontend
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
window: hamming
|
||||
n_mels: 80
|
||||
frame_length: 25
|
||||
frame_shift: 10
|
||||
lfr_m: 1
|
||||
lfr_n: 1
|
||||
|
||||
specaug: SpecAug
|
||||
specaug_conf:
|
||||
apply_time_warp: true
|
||||
time_warp_window: 5
|
||||
time_warp_mode: bicubic
|
||||
apply_freq_mask: true
|
||||
freq_mask_width_range:
|
||||
- 0
|
||||
- 30
|
||||
num_freq_mask: 2
|
||||
apply_time_mask: true
|
||||
time_mask_width_range:
|
||||
- 0
|
||||
- 40
|
||||
num_time_mask: 2
|
||||
|
||||
train_conf:
|
||||
accum_grad: 1
|
||||
grad_clip: 5
|
||||
max_epoch: 150
|
||||
val_scheduler_criterion:
|
||||
- valid
|
||||
- acc
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- acc
|
||||
- max
|
||||
keep_nbest_models: 10
|
||||
log_interval: 50
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.002
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 30000
|
||||
|
||||
dataset: AudioDataset
|
||||
dataset_conf:
|
||||
index_ds: IndexDSJsonl
|
||||
batch_sampler: DynamicBatchLocalShuffleSampler
|
||||
batch_type: example # example or length
|
||||
batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
|
||||
max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
|
||||
buffer_size: 500
|
||||
shuffle: True
|
||||
num_workers: 0
|
||||
|
||||
tokenizer: CharTokenizer
|
||||
tokenizer_conf:
|
||||
unk_symbol: <unk>
|
||||
split_with_space: true
|
||||
|
||||
|
||||
ctc_conf:
|
||||
dropout_rate: 0.0
|
||||
ctc_type: builtin
|
||||
reduce: true
|
||||
ignore_nan_grad: true
|
||||
normalize: null
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import inspect
|
||||
from dataclasses import dataclass, fields
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -19,7 +19,7 @@ class ClassRegistryTables:
|
||||
dataset_classes = {}
|
||||
index_ds_classes = {}
|
||||
|
||||
def print_register_tables(self,):
|
||||
def print(self,):
|
||||
print("\nregister_tables: \n")
|
||||
fields = vars(self)
|
||||
for classes_key, classes_dict in fields.items():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user