This commit is contained in:
游雁 2024-02-20 16:42:09 +08:00
parent bc19499b48
commit cb8b09e085
2 changed files with 10 additions and 5 deletions

View File

@ -26,9 +26,10 @@ class SpeechPreprocessSpeedPerturb(nn.Module):
return waveform
speed = random.choice(self.speed_perturb)
if speed != 1.0:
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
waveform = waveform.view(-1)
with torch.no_grad():
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
waveform = waveform.view(-1)
return waveform

View File

@ -273,8 +273,9 @@ class Trainer:
speed_stats["total_time"] = total_time
pbar.update(1)
if self.local_rank == 0:
pbar.update(1)
gpu_info = "GPU, memory: {:.3f} GB, " \
"{:.3f} GB, "\
"{:.3f} GB, "\
@ -290,6 +291,7 @@ class Trainer:
f"(loss: {loss.detach().cpu().item():.3f}), "
f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
f"{gpu_info}"
f"rank: {self.local_rank}"
)
pbar.set_description(description)
if self.writer:
@ -344,14 +346,16 @@ class Trainer:
loss = loss
time4 = time.perf_counter()
pbar.update(1)
if self.local_rank == 0:
pbar.update(1)
description = (
f"validation epoch: {epoch}/{self.max_epoch}, "
f"step {batch_idx}/{len(self.dataloader_train)}, "
f"{speed_stats}, "
f"(loss: {loss.detach().cpu().item():.3f}), "
f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
f"rank: {self.local_rank}"
)
pbar.set_description(description)
if self.writer: