This commit is contained in:
游雁 2023-12-19 22:53:18 +08:00
parent ea4453cc88
commit 00ea1186f9
8 changed files with 265 additions and 25 deletions

View File

@ -81,7 +81,7 @@ def main_hydra(kwargs: DictConfig):
class AutoModel:
def __init__(self, **kwargs):
registry_tables.print_register_tables()
registry_tables.print()
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
@ -108,6 +108,7 @@ class AutoModel:
frontend_class = registry_tables.frontend_classes.get(frontend.lower())
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
kwargs["input_size"] = frontend.output_size()
# build model
model_class = registry_tables.model_classes.get(kwargs["model"].lower())

View File

@ -39,7 +39,7 @@ def main(**kwargs):
# preprocess_config(kwargs)
# import pdb; pdb.set_trace()
# set random seed
registry_tables.print_register_tables()
registry_tables.print()
set_all_random_seed(kwargs.get("seed", 0))
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
@ -72,6 +72,7 @@ def main(**kwargs):
frontend_class = registry_tables.frontend_classes.get(frontend.lower())
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
kwargs["input_size"] = frontend.output_size()
# import pdb;
# pdb.set_trace()

View File

@ -0,0 +1,117 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.utils.register import registry_tables
# registry_tables.print()
# network architecture
#model: funasr.models.paraformer.model:Paraformer
model: Transformer
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
# encoder
encoder: ConformerEncoder
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder architecture type
normalize_before: true
pos_enc_layer_type: rel_pos
selfattention_layer_type: rel_selfattn
activation_type: swish
macaron_style: true
use_cnn_module: true
cnn_module_kernel: 15
# decoder
decoder: TransformerDecoder
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# frontend related
frontend: WavFrontend
frontend_conf:
fs: 16000
window: hamming
n_mels: 80
frame_length: 25
frame_shift: 10
lfr_m: 1
lfr_n: 1
specaug: SpecAug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 30
num_freq_mask: 2
apply_time_mask: true
time_mask_width_range:
- 0
- 40
num_time_mask: 2
train_conf:
accum_grad: 1
grad_clip: 5
max_epoch: 150
val_scheduler_criterion:
- valid
- acc
best_model_criterion:
- - valid
- acc
- max
keep_nbest_models: 10
log_interval: 50
optim: adam
optim_conf:
lr: 0.0005
scheduler: warmuplr
scheduler_conf:
warmup_steps: 30000
dataset: AudioDataset
dataset_conf:
index_ds: IndexDSJsonl
batch_sampler: DynamicBatchLocalShuffleSampler
batch_type: example # example or length
batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
buffer_size: 500
shuffle: True
num_workers: 0
tokenizer: CharTokenizer
tokenizer_conf:
unk_symbol: <unk>
split_with_space: true
ctc_conf:
dropout_rate: 0.0
ctc_type: builtin
reduce: true
ignore_nan_grad: true
normalize: null

View File

@ -1,6 +1,12 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.utils.register import registry_tables
# registry_tables.print()
# network architecture
model: funasr.cli.models.paraformer:Paraformer
model: NeatContextualParaformer
model_conf:
ctc_weight: 0.0
lsm_weight: 0.1
@ -8,9 +14,10 @@ model_conf:
predictor_weight: 1.0
predictor_bias: 1
sampling_ratio: 0.75
inner_dim: 512
# encoder
encoder: sanm
encoder: SANMEncoder
encoder_conf:
output_size: 512
attention_heads: 4
@ -26,8 +33,9 @@ encoder_conf:
sanm_shfit: 0
selfattention_layer_type: sanm
# decoder
decoder: paraformer_decoder_sanm
decoder: ContextualParaformerDecoder
decoder_conf:
attention_heads: 4
linear_units: 2048
@ -40,7 +48,7 @@ decoder_conf:
kernel_size: 11
sanm_shfit: 0
predictor: cif_predictor_v2
predictor: CifPredictorV2
predictor_conf:
idim: 512
threshold: 1.0
@ -49,7 +57,7 @@ predictor_conf:
tail_threshold: 0.45
# frontend related
frontend: wav_frontend
frontend: WavFrontend
frontend_conf:
fs: 16000
window: hamming
@ -59,7 +67,7 @@ frontend_conf:
lfr_m: 7
lfr_n: 6
specaug: specaug_lfr
specaug: SpecAugLFR
specaug_conf:
apply_time_warp: false
time_warp_window: 5
@ -97,21 +105,22 @@ scheduler: warmuplr
scheduler_conf:
warmup_steps: 30000
dataset: AudioDataset
dataset_conf:
data_names: speech,text
data_types: sound,text
index_ds: IndexDSJsonl
batch_sampler: DynamicBatchLocalShuffleSampler
batch_type: example # example or length
batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
buffer_size: 500
shuffle: True
shuffle_conf:
shuffle_size: 2048
sort_size: 500
batch_conf:
batch_type: example
batch_size: 2
num_workers: 8
num_workers: 0
tokenizer: CharTokenizer
tokenizer_conf:
unk_symbol: <unk>
split_with_space: true
split_with_space: true
input_size: 560
ctc_conf:
dropout_rate: 0.0
ctc_type: builtin

View File

@ -39,8 +39,6 @@ class Paraformer(nn.Module):
def __init__(
self,
# token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[str] = None,
frontend_conf: Optional[Dict] = None,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,

View File

@ -1,6 +1,10 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.utils.register import registry_tables
# registry_tables.print()
# network architecture
#model: funasr.models.paraformer.model:Paraformer
model: Paraformer
@ -117,7 +121,6 @@ tokenizer_conf:
split_with_space: true
input_size: 560
ctc_conf:
dropout_rate: 0.0
ctc_type: builtin

View File

@ -0,0 +1,111 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.utils.register import registry_tables
# registry_tables.print()
# network architecture
#model: funasr.models.paraformer.model:Paraformer
model: Transformer
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
# encoder
encoder: TransformerEncoder
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder architecture type
normalize_before: true
# decoder
decoder: TransformerDecoder
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# frontend related
frontend: WavFrontend
frontend_conf:
fs: 16000
window: hamming
n_mels: 80
frame_length: 25
frame_shift: 10
lfr_m: 1
lfr_n: 1
specaug: SpecAug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 30
num_freq_mask: 2
apply_time_mask: true
time_mask_width_range:
- 0
- 40
num_time_mask: 2
train_conf:
accum_grad: 1
grad_clip: 5
max_epoch: 150
val_scheduler_criterion:
- valid
- acc
best_model_criterion:
- - valid
- acc
- max
keep_nbest_models: 10
log_interval: 50
optim: adam
optim_conf:
lr: 0.002
scheduler: warmuplr
scheduler_conf:
warmup_steps: 30000
dataset: AudioDataset
dataset_conf:
index_ds: IndexDSJsonl
batch_sampler: DynamicBatchLocalShuffleSampler
batch_type: example # example or length
batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
buffer_size: 500
shuffle: True
num_workers: 0
tokenizer: CharTokenizer
tokenizer_conf:
unk_symbol: <unk>
split_with_space: true
ctc_conf:
dropout_rate: 0.0
ctc_type: builtin
reduce: true
ignore_nan_grad: true
normalize: null

View File

@ -1,6 +1,6 @@
import logging
import inspect
from dataclasses import dataclass, fields
from dataclasses import dataclass
@dataclass
@ -19,7 +19,7 @@ class ClassRegistryTables:
dataset_classes = {}
index_ds_classes = {}
def print_register_tables(self,):
def print(self,):
print("\nregister_tables: \n")
fields = vars(self)
for classes_key, classes_dict in fields.items():