mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf exp (#1626)
* sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune
This commit is contained in:
parent
eaf9dda9e4
commit
824377d2aa
@ -55,6 +55,8 @@ def main(**kwargs):
|
||||
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
|
||||
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
|
||||
torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
|
||||
# open tf32
|
||||
torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
|
||||
|
||||
local_rank = int(os.environ.get('LOCAL_RANK', 0))
|
||||
if local_rank == 0:
|
||||
|
||||
@ -61,6 +61,7 @@ class EspnetStyleBatchSampler(DistributedSampler):
|
||||
self.epoch = 0
|
||||
self.sort_size = sort_size * num_replicas
|
||||
self.max_token_length = kwargs.get("max_token_length", 2048)
|
||||
self.min_token_length = kwargs.get("min_token_length", 0)
|
||||
self.length_scale_source = kwargs.get("length_scale_source", 1.0)
|
||||
|
||||
|
||||
@ -85,7 +86,7 @@ class EspnetStyleBatchSampler(DistributedSampler):
|
||||
|
||||
for idx in sorted_indices:
|
||||
original_sample_length = self.dataset.get_source_len(idx)
|
||||
if original_sample_length > self.max_token_length: # Skip samples that exceed the max length
|
||||
if original_sample_length < self.min_token_length or original_sample_length > self.max_token_length: # Skip samples that exceed the max length
|
||||
continue
|
||||
# Set sample_length based on the batch type
|
||||
sample_length = 1 if self.batch_type == "example" else original_sample_length
|
||||
|
||||
@ -76,7 +76,10 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, path: str, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.max_source_length = kwargs.get("max_source_length", 2048)
|
||||
self.min_source_length = kwargs.get("min_source_length", 0)
|
||||
self.max_target_length = kwargs.get("max_target_length", 2048)
|
||||
self.min_target_length = kwargs.get("min_target_length", 0)
|
||||
if isinstance(path, (list, tuple)): # wav.scp, text.txt/text.trans
|
||||
from funasr.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list
|
||||
jsonl_outdir = os.path.dirname(path[0])
|
||||
@ -101,7 +104,10 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
|
||||
target_len = data.get("target_len", 0)
|
||||
if "aishell" in source:
|
||||
target = target.replace(" ", "")
|
||||
|
||||
if source_len < self.min_source_length or source_len > self.max_source_length:
|
||||
continue
|
||||
if target_len < self.min_target_length or target_len > self.max_target_length:
|
||||
continue
|
||||
contents_i = {"source": source,
|
||||
"prompt": prompt,
|
||||
"target": target,
|
||||
|
||||
@ -4,42 +4,85 @@
|
||||
# to print the register_table:
|
||||
# from funasr.register import tables
|
||||
# tables.print()
|
||||
|
||||
# network architecture
|
||||
model: SenseVoice
|
||||
model_conf:
|
||||
lsm_weight: 0.1
|
||||
length_normalized_loss: true
|
||||
hub: funasr
|
||||
activation_checkpoint: true
|
||||
sos: "<|startoftranscript|>"
|
||||
eos: "<|endoftext|>"
|
||||
downsample_rate: 4
|
||||
use_padmask: true
|
||||
|
||||
|
||||
|
||||
# only use for hub == funasr,
|
||||
# if hub == openai, dims is automaticall download
|
||||
dims:
|
||||
n_mels: 128
|
||||
n_vocab: 51866
|
||||
n_audio_ctx: 1500
|
||||
n_audio_state: 1280
|
||||
n_audio_head: 20
|
||||
n_audio_layer: 32
|
||||
n_text_ctx: 448
|
||||
n_text_state: 1280
|
||||
n_text_head: 20
|
||||
n_text_layer: 32
|
||||
dims:
|
||||
n_mels: 128
|
||||
n_vocab: 60515
|
||||
n_audio_ctx: 1500
|
||||
n_audio_state: 1280
|
||||
n_audio_head: 20
|
||||
n_audio_layer: 32
|
||||
n_text_ctx: 448
|
||||
n_text_state: 1280
|
||||
n_text_head: 20
|
||||
n_text_layer: 32
|
||||
|
||||
# frontend related
|
||||
frontend: WhisperFrontend
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
n_mels: ${dims.n_mels}
|
||||
do_pad_trim: true
|
||||
n_mels: ${model_conf.dims.n_mels}
|
||||
do_pad_trim: false
|
||||
|
||||
tokenizer: WhisperTokenizer
|
||||
tokenizer: SenseVoiceTokenizer
|
||||
tokenizer_conf:
|
||||
language: null
|
||||
task: transcribe
|
||||
vocab_path: null
|
||||
is_multilingual: true
|
||||
num_languages: 100
|
||||
num_languages: 8749
|
||||
|
||||
scope_map: [none, "model."]
|
||||
dataset: SenseVoiceDataset
|
||||
dataset_conf:
|
||||
index_ds: IndexDSJsonl
|
||||
batch_sampler: EspnetStyleBatchSampler
|
||||
batch_type: length # example or length
|
||||
batch_size: 7000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
|
||||
max_token_length: 2000 # filter samples if source_token_len+target_token_len > max_token_length,
|
||||
min_token_length: 60
|
||||
shuffle: True
|
||||
num_workers: 4
|
||||
sos: ${model_conf.sos}
|
||||
eos: ${model_conf.eos}
|
||||
|
||||
train_conf:
|
||||
accum_grad: 2
|
||||
grad_clip: 5
|
||||
max_epoch: 20
|
||||
keep_nbest_models: 20
|
||||
avg_nbest_model: ${train_conf.keep_nbest_models}
|
||||
log_interval: 50
|
||||
|
||||
optim: adamw
|
||||
optim_conf:
|
||||
lr: 0.00002
|
||||
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 10000
|
||||
|
||||
specaug: SpecAug
|
||||
specaug_conf:
|
||||
apply_time_warp: true
|
||||
time_warp_window: 5
|
||||
time_warp_mode: bicubic
|
||||
apply_freq_mask: true
|
||||
freq_mask_width_range:
|
||||
- 0
|
||||
- 40
|
||||
num_freq_mask: 2
|
||||
apply_time_mask: true
|
||||
time_mask_width_ratio_range:
|
||||
- 0.0
|
||||
- 0.12
|
||||
num_time_mask: 2
|
||||
|
||||
scope_map: ['encoder.encoders', 'model.encoder', 'decoder.decoders', 'model.decoder']
|
||||
@ -268,10 +268,12 @@ class Trainer:
|
||||
# Initialize the gradient accumulation
|
||||
optim.zero_grad()
|
||||
speed_stats = {}
|
||||
time5 = time.perf_counter()
|
||||
|
||||
iterator_stop = torch.tensor(0).to(self.device)
|
||||
|
||||
dataloader_train.batch_sampler.set_epoch(epoch)
|
||||
time_beg = time.perf_counter()
|
||||
time5 = time_beg
|
||||
for batch_idx, batch in enumerate(dataloader_train):
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
|
||||
@ -279,11 +281,13 @@ class Trainer:
|
||||
break
|
||||
self.batch_total += 1
|
||||
time1 = time.perf_counter()
|
||||
speed_stats["data_load"] = f"{time1-time5:0.3f}"
|
||||
speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
|
||||
|
||||
batch = to_device(batch, self.device)
|
||||
|
||||
my_context = model.no_sync if batch_idx % accum_grad != 0 else nullcontext
|
||||
|
||||
my_context = nullcontext
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
|
||||
with my_context():
|
||||
time2 = time.perf_counter()
|
||||
with maybe_autocast(self.use_fp16):
|
||||
@ -384,6 +388,7 @@ class Trainer:
|
||||
if (batch_idx+1) % self.save_checkpoint_interval == 0:
|
||||
self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
|
||||
|
||||
time_beg = time.perf_counter()
|
||||
else:
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
iterator_stop.fill_(1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user