mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

commit 45d9ccafef ("train finetune"), parent 58b6154a73
@@ -105,7 +105,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
   echo "stage 4: ASR Training"
 
   mkdir -p ${exp_dir}/exp/${model_dir}
-  log_file="${exp_dir}/exp/${model_dir}/train.log.txt"
+  current_time=$(date "+%Y-%m-%d_%H-%M")
+  log_file="${exp_dir}/exp/${model_dir}/train.log.txt.${current_time}"
   echo "log_file: ${log_file}"
 
   gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
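Note (not from the diff): the run script now stamps each training log with the launch time, so repeated runs no longer overwrite train.log.txt. A minimal Python sketch of the same naming scheme, with a hypothetical experiment path:

from datetime import datetime
from pathlib import Path

exp_dir = Path("exp/demo_model")  # hypothetical experiment directory
exp_dir.mkdir(parents=True, exist_ok=True)

# Mirror the shell's date "+%Y-%m-%d_%H-%M" so each run gets its own log file.
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M")
log_file = exp_dir / f"train.log.txt.{current_time}"
print(f"log_file: {log_file}")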
@@ -26,10 +26,11 @@ class SpeechPreprocessSpeedPerturb(nn.Module):
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            with torch.no_grad():
-                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
-                    torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
-                waveform = waveform.view(-1)
+                if not isinstance(waveform, torch.Tensor):
+                    waveform = torch.tensor(waveform)
+                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+                    waveform.view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
+                waveform = waveform.view(-1)
 
        return waveform
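The refactor above only guards against wrapping an existing tensor in torch.tensor() again. For reference, a self-contained sketch of the speed-perturbation step itself, assuming a torchaudio build with sox support (sampling rate and perturbation factors are illustrative, not from the commit):

import random
import torch
import torchaudio

fs = 16000
waveform = torch.randn(fs)              # 1 s of fake mono audio
speed = random.choice([0.9, 1.0, 1.1])  # typical 3-way speed perturbation

if speed != 1.0:
    with torch.no_grad():
        if not isinstance(waveform, torch.Tensor):
            waveform = torch.tensor(waveform)
        # 'speed' changes duration and pitch; 'rate' resamples back to fs.
        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform.view(1, -1), fs,
            [["speed", str(speed)], ["rate", str(fs)]])
        waveform = waveform.view(-1)

print(waveform.shape)  # roughly fs / speed samples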
@@ -70,6 +70,7 @@ class Trainer:
        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
        self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)
+        self.batch_total = 0
 
 
        try:
@@ -196,7 +197,9 @@ class Trainer:
        self.optim.zero_grad()
        speed_stats = {}
        time5 = time.perf_counter()
+
        for batch_idx, batch in enumerate(self.dataloader_train):
+            self.batch_total += 1
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1-time5:0.3f}"
@@ -205,25 +208,10 @@ class Trainer:
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                time2 = time.perf_counter()
-                # print("before, GPU, memory: {:.3f} GB, "
-                #       "{:.3f} GB, "
-                #       "{:.3f} GB, "
-                #       "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
-                #                          torch.cuda.max_memory_allocated()/1024/1024/1024,
-                #                          torch.cuda.memory_reserved()/1024/1024/1024,
-                #                          torch.cuda.max_memory_reserved()/1024/1024/1024,
-                #                          ))
 
                retval = self.model(**batch)
-                torch.cuda.empty_cache()
-                # print("after, GPU, memory: {:.3f} GB, "
-                #       "{:.3f} GB, "
-                #       "{:.3f} GB, "
-                #       "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
-                #                          torch.cuda.max_memory_allocated()/1024/1024/1024,
-                #                          torch.cuda.memory_reserved()/1024/1024/1024,
-                #                          torch.cuda.max_memory_reserved()/1024/1024/1024,
-                #                          ))
 
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
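For context (not from the diff): the my_context line pairs DDP's no_sync() with gradient accumulation, skipping the gradient all-reduce on micro-batches inside an accumulation window. A runnable single-process sketch of that pattern, assuming the conventional pairing where synchronization and the optimizer step both fall on the last micro-batch of each window (accum_grad and the toy model are illustrative):

import os
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process "distributed" setup so the sketch runs as-is (CPU, gloo backend).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(torch.nn.Linear(8, 1))
optim = torch.optim.SGD(model.parameters(), lr=0.1)
accum_grad = 4  # micro-batches per optimizer step

for batch_idx, x in enumerate(torch.randn(8, 2, 8)):
    sync_step = (batch_idx + 1) % accum_grad == 0
    # no_sync() defers DDP's gradient all-reduce, so non-boundary micro-batches
    # only accumulate gradients locally.
    my_context = nullcontext if sync_step else model.no_sync
    with my_context():
        loss = model(x).pow(2).mean()
        (loss / accum_grad).backward()
    if sync_step:
        optim.step()
        optim.zero_grad()

dist.destroy_process_group()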
@@ -275,7 +263,7 @@ class Trainer:
 
 
 
-            if batch_idx % self.log_interval == 0 or batch_idx == len(self.dataloader_train) - 1:
+            if (batch_idx+1) % self.log_interval == 0 or (batch_idx+1) == len(self.dataloader_train):
                pbar.update(self.log_interval)
                gpu_info = "GPU, memory: {:.3f} GB, " \
                           "{:.3f} GB, " \
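The shifted condition changes which batches produce a log line: with 0-based batch_idx, the old test fired at batch 0 before any work had been reported, while the new one fires after every log_interval-th completed batch. A quick illustrative check (log_interval and batch count are made up):

log_interval, n_batches = 50, 120  # illustrative values

old = [i for i in range(n_batches)
       if i % log_interval == 0 or i == n_batches - 1]
new = [i for i in range(n_batches)
       if (i + 1) % log_interval == 0 or (i + 1) == n_batches]

print(old)  # [0, 50, 100, 119] -- includes batch 0
print(new)  # [49, 99, 119]     -- only after completed intervals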
@@ -287,22 +275,22 @@ class Trainer:
                    )
                description = (
                    f"rank: {self.local_rank}, "
-                    f"Train epoch: {epoch}/{self.max_epoch}, "
-                    f"step {batch_idx}/{len(self.dataloader_train)}, "
-                    f"{speed_stats}, "
+                    f"epoch: {epoch}/{self.max_epoch}, "
+                    f"step: {batch_idx}/{len(self.dataloader_train)}, total: {self.batch_total}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+                    f"{speed_stats}, "
                    f"{gpu_info}"
                )
                pbar.set_description(description)
                if self.writer:
-                    self.writer.add_scalar(f'rank{self.local_rank}, Loss/train', loss.item(),
+                    self.writer.add_scalar(f'rank{self.local_rank}_Loss/train', loss.item(),
                                           epoch*len(self.dataloader_train) + batch_idx)
                    for key, var in stats.items():
-                        self.writer.add_scalar(f'rank{self.local_rank}, {key}/train', var.item(),
+                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', var.item(),
                                               epoch * len(self.dataloader_train) + batch_idx)
                    for key, var in speed_stats.items():
-                        self.writer.add_scalar(f'rank{self.local_rank}, {key}/train', eval(var),
+                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var),
                                               epoch * len(self.dataloader_train) + batch_idx)
 
            # if batch_idx == 2:
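The scalar tags are renamed from f'rank{...}, {key}/train' to f'rank{...}_{key}/train', so rank and metric form a single tag without a comma or space before the '/train' suffix. A minimal sketch (not from the diff) of writing scalars with that naming scheme; log directory and values are illustrative:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="tb_demo")   # illustrative log dir
local_rank, epoch, batch_idx, steps_per_epoch = 0, 0, 49, 120
loss = 1.234
stats = {"acc": 0.71, "cer": 0.12}          # illustrative metric values

global_step = epoch * steps_per_epoch + batch_idx
writer.add_scalar(f"rank{local_rank}_Loss/train", loss, global_step)
for key, var in stats.items():
    writer.add_scalar(f"rank{local_rank}_{key}/train", var, global_step)
writer.close()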
@@ -348,24 +336,23 @@ class Trainer:
            time4 = time.perf_counter()
 
 
-            if batch_idx % self.log_interval == 0 or batch_idx == len(self.dataloader_train) - 1:
+            if (batch_idx+1) % self.log_interval == 0 or (batch_idx+1) == len(self.dataloader_val):
                pbar.update(self.log_interval)
                description = (
-                    f"rank: {self.local_rank}, "
                    f"validation epoch: {epoch}/{self.max_epoch}, "
-                    f"step {batch_idx}/{len(self.dataloader_train)}, "
-                    f"{speed_stats}, "
+                    f"step: {batch_idx}/{len(self.dataloader_val)}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+                    f"rank: {self.local_rank}"
+                    f"{speed_stats}, "
                )
                pbar.set_description(description)
                if self.writer:
-                    self.writer.add_scalar(f"rank{self.local_rank}, Loss/val", loss.item(),
-                                           epoch*len(self.dataloader_train) + batch_idx)
+                    self.writer.add_scalar(f"rank{self.local_rank}_Loss/val", loss.item(),
+                                           epoch*len(self.dataloader_val) + batch_idx)
                    for key, var in stats.items():
-                        self.writer.add_scalar(f'rank{self.local_rank}, {key}/val', var.item(),
-                                               epoch * len(self.dataloader_train) + batch_idx)
+                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', var.item(),
+                                               epoch * len(self.dataloader_val) + batch_idx)
                    for key, var in speed_stats.items():
-                        self.writer.add_scalar(f'rank{self.local_rank}, {key}/val', eval(var),
-                                               epoch * len(self.dataloader_train) + batch_idx)
+                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var),
+                                               epoch * len(self.dataloader_val) + batch_idx)