mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
batch
This commit is contained in:
parent
0de8b6447c
commit
9a8086bdf5
@ -205,7 +205,6 @@ def main(**kwargs):
|
||||
dataloader_tr, dataloader_val = dataloader.build_iter(
|
||||
epoch, data_split_i=data_split_i, start_step=trainer.start_step
|
||||
)
|
||||
trainer.start_step = 0
|
||||
|
||||
trainer.train_epoch(
|
||||
model=model,
|
||||
@ -218,7 +217,9 @@ def main(**kwargs):
|
||||
writer=writer,
|
||||
data_split_i=data_split_i,
|
||||
data_split_num=dataloader.data_split_num,
|
||||
start_step=trainer.start_step,
|
||||
)
|
||||
trainer.start_step = 0
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@ -50,8 +50,8 @@ def update_data(lines, i):
|
||||
sample_num = len(waveform)
|
||||
source_len = int(sample_num / 16000 * 1000 / 10)
|
||||
source_len_old = data["source_len"]
|
||||
if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
|
||||
logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
|
||||
# if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
|
||||
# logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
|
||||
data["source_len"] = source_len
|
||||
data["source"] = wav_path
|
||||
jsonl_line = json.dumps(data, ensure_ascii=False)
|
||||
|
||||
@ -456,7 +456,7 @@ class Trainer:
|
||||
batch_num_epoch = len(dataloader_train)
|
||||
self.log(
|
||||
epoch,
|
||||
batch_idx,
|
||||
batch_idx + kwargs.get("start_step", 0),
|
||||
step_in_epoch=self.step_in_epoch,
|
||||
batch_num_epoch=batch_num_epoch,
|
||||
lr=lr,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user