This commit is contained in:
游雁 2024-04-30 10:48:31 +08:00
parent 0de8b6447c
commit 9a8086bdf5
3 changed files with 5 additions and 4 deletions

View File

@ -205,7 +205,6 @@ def main(**kwargs):
dataloader_tr, dataloader_val = dataloader.build_iter(
epoch, data_split_i=data_split_i, start_step=trainer.start_step
)
trainer.start_step = 0
trainer.train_epoch(
model=model,
@ -218,7 +217,9 @@ def main(**kwargs):
writer=writer,
data_split_i=data_split_i,
data_split_num=dataloader.data_split_num,
start_step=trainer.start_step,
)
trainer.start_step = 0
torch.cuda.empty_cache()

View File

@ -50,8 +50,8 @@ def update_data(lines, i):
sample_num = len(waveform)
source_len = int(sample_num / 16000 * 1000 / 10)
source_len_old = data["source_len"]
if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
# if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
# logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
data["source_len"] = source_len
data["source"] = wav_path
jsonl_line = json.dumps(data, ensure_ascii=False)

View File

@ -456,7 +456,7 @@ class Trainer:
batch_num_epoch = len(dataloader_train)
self.log(
epoch,
batch_idx,
batch_idx + kwargs.get("start_step", 0),
step_in_epoch=self.step_in_epoch,
batch_num_epoch=batch_num_epoch,
lr=lr,