Dev gzf exp (#1670)

* resume from step

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch
This commit is contained in:
zhifu gao 2024-04-28 15:14:57 +08:00 committed by GitHub
parent 1cdb3cc28d
commit 93ef505e2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 116 additions and 59 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 178 KiB

After

Width:  |  Height:  |  Size: 182 KiB

View File

@ -13,7 +13,7 @@ from io import BytesIO
from contextlib import nullcontext from contextlib import nullcontext
import torch.distributed as dist import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
@ -99,7 +99,7 @@ def main(**kwargs):
if freeze_param is not None: if freeze_param is not None:
if "," in freeze_param: if "," in freeze_param:
freeze_param = eval(freeze_param) freeze_param = eval(freeze_param)
if not isinstance(freeze_param, Sequence): if not isinstance(freeze_param, (list, tuple)):
freeze_param = (freeze_param,) freeze_param = (freeze_param,)
logging.info("freeze_param is not None: %s", freeze_param) logging.info("freeze_param is not None: %s", freeze_param)
for t in freeze_param: for t in freeze_param:
@ -193,7 +193,7 @@ def main(**kwargs):
try: try:
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
except: except:
writer = None writer = None
@ -206,6 +206,7 @@ def main(**kwargs):
epoch, data_split_i=data_split_i, start_step=trainer.start_step epoch, data_split_i=data_split_i, start_step=trainer.start_step
) )
trainer.start_step = 0 trainer.start_step = 0
trainer.train_epoch( trainer.train_epoch(
model=model, model=model,
optim=optim, optim=optim,
@ -225,7 +226,7 @@ def main(**kwargs):
model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
) )
scheduler.step() scheduler.step()
trainer.step_cur_in_epoch = 0
trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler) trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
time2 = time.perf_counter() time2 = time.perf_counter()

View File

@ -51,6 +51,7 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
self.batch_size = kwargs.get("batch_size") self.batch_size = kwargs.get("batch_size")
self.batch_type = kwargs.get("batch_type") self.batch_type = kwargs.get("batch_type")
self.prompt_ids_len = 0 self.prompt_ids_len = 0
self.retry = kwargs.get("retry", 5)
def get_source_len(self, index): def get_source_len(self, index):
item = self.index_ds[index] item = self.index_ds[index]
@ -64,59 +65,75 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
return len(self.index_ds) return len(self.index_ds)
def __getitem__(self, index): def __getitem__(self, index):
item = self.index_ds[index]
# import pdb; # import pdb;
# pdb.set_trace() # pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(
data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
) # speech: [b, T, d]
if speech_lengths > self.batch_size: output = None
return None for idx in range(self.retry):
speech = speech.permute(0, 2, 1) if idx == 0:
target = item["target"] index_cur = index
if self.preprocessor_text: else:
target = self.preprocessor_text(target) if index <= self.retry:
index_cur = index + idx
else:
index_cur = torch.randint(0, index, ()).item()
task = item.get("prompt", "<|ASR|>") item = self.index_ds[index_cur]
text_language = item.get("text_language", "<|zh|>")
prompt = f"{self.sos}{task}{text_language}" source = item["source"]
prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") data_src = load_audio_text_image_video(source, fs=self.fs)
prompt_ids_len = len(prompt_ids) - 1 # [sos, task] if self.preprocessor_speech:
self.prompt_ids_len = prompt_ids_len data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(
data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
) # speech: [b, T, d]
target_ids = self.tokenizer.encode(target, allowed_special="all") if speech_lengths > self.batch_size:
target_ids_len = len(target_ids) + 1 # [lid, text] continue
if target_ids_len > 200: speech = speech.permute(0, 2, 1)
return None target = item["target"]
if self.preprocessor_text:
target = self.preprocessor_text(target)
eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] task = item.get("prompt", "<|ASR|>")
text_language = item.get("text_language", "<|zh|>")
ids = prompt_ids + target_ids + eos prompt = f"{self.sos}{task}{text_language}"
ids_lengths = len(ids) prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
self.prompt_ids_len = prompt_ids_len
text = torch.tensor(ids, dtype=torch.int64) target_ids = self.tokenizer.encode(target, allowed_special="all")
text_lengths = torch.tensor([ids_lengths], dtype=torch.int32) target_ids_len = len(target_ids) + 1 # [lid, text]
if target_ids_len > 200:
continue
target_mask = ( eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
[0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] ids = prompt_ids + target_ids + eos
target_mask_lengths = len(target_mask) ids_lengths = len(ids)
target_mask = torch.tensor(target_mask, dtype=torch.float32)
target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32) text = torch.tensor(ids, dtype=torch.int64)
return { text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
"speech": speech[0, :, :],
"speech_lengths": speech_lengths, target_mask = (
"text": text, [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
"text_lengths": text_lengths, ) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
"target_mask": target_mask, target_mask_lengths = len(target_mask)
"target_mask_lengths": target_mask_lengths, target_mask = torch.tensor(target_mask, dtype=torch.float32)
} target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32)
output = {
"speech": speech[0, :, :],
"speech_lengths": speech_lengths,
"text": text,
"text_lengths": text_lengths,
"target_mask": target_mask,
"target_mask_lengths": target_mask_lengths,
}
break
return output
def collator(self, samples: list = None): def collator(self, samples: list = None):
outputs = {} outputs = {}
@ -129,13 +146,30 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
outputs[key].append(sample[key]) outputs[key].append(sample[key])
if len(outputs) < 1: if len(outputs) < 1:
logging.info(f"ERROR: data is empty!") logging.error(f"ERROR: data is empty!")
outputs = { outputs = {
"speech": torch.rand((10, 128), dtype=torch.float32), "speech": torch.rand((10, 128), dtype=torch.float32)[None, :, :],
"speech_lengths": torch.tensor([10], dtype=torch.int32), "speech_lengths": torch.tensor(
"text": torch.tensor([58836], dtype=torch.int32), [
"text_lengths": torch.tensor([1], dtype=torch.int32), 10,
"target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]]), ],
dtype=torch.int32,
)[:, None],
"text": torch.tensor(
[
58836,
],
dtype=torch.int32,
)[None, :],
"text_lengths": torch.tensor(
[
1,
],
dtype=torch.int32,
)[:, None],
"target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]])[
None, :
],
} }
return outputs return outputs
@ -159,7 +193,7 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
def _filter_badcase(self, outputs, i=0): def _filter_badcase(self, outputs, i=0):
b, t, _ = outputs["speech"].shape b, t, _ = outputs["speech"].shape
if b * t > self.batch_size * 1.25: if b * t > self.batch_size * 1.25:
beg = torch.randint(0, 2, ()).item() beg = torch.randint(0, 2, ()).item()
if b < 2: if b < 2:
@ -170,7 +204,6 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
for key, data_list in outputs.items(): for key, data_list in outputs.items():
outputs[key] = outputs[key][beg : beg + b : 2] outputs[key] = outputs[key][beg : beg + b : 2]
speech_lengths_max = outputs["speech_lengths"].max().item() speech_lengths_max = outputs["speech_lengths"].max().item()
outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :] outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
text_lengths_max = outputs["text_lengths"].max().item() text_lengths_max = outputs["text_lengths"].max().item()

View File

@ -116,6 +116,7 @@ class Trainer:
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False) self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0 self.start_data_split_i = 0
self.start_step = 0 self.start_step = 0
self.step_cur_in_epoch = 0
self.use_wandb = kwargs.get("use_wandb", False) self.use_wandb = kwargs.get("use_wandb", False)
if self.use_wandb: if self.use_wandb:
wandb.login(key=kwargs.get("wandb_token")) wandb.login(key=kwargs.get("wandb_token"))
@ -137,6 +138,8 @@ class Trainer:
optim=None, optim=None,
scheduler=None, scheduler=None,
scaler=None, scaler=None,
step_cur_in_epoch=None,
**kwargs,
): ):
""" """
Saves a checkpoint containing the model's state, the optimizer's state, Saves a checkpoint containing the model's state, the optimizer's state,
@ -147,6 +150,7 @@ class Trainer:
epoch (int): The epoch number at which the checkpoint is being saved. epoch (int): The epoch number at which the checkpoint is being saved.
""" """
step_cur_in_epoch = None if step is None else step_cur_in_epoch
if self.rank == 0: if self.rank == 0:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n") logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1 # self.step_or_epoch += 1
@ -161,7 +165,12 @@ class Trainer:
"best_step_or_epoch": self.best_step_or_epoch, "best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type, "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step, "step": step,
"step_cur_in_epoch": step_cur_in_epoch,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
"batch_total": self.batch_total,
} }
step = step_cur_in_epoch
if hasattr(model, "module"): if hasattr(model, "module"):
state["state_dict"] = model.module.state_dict() state["state_dict"] = model.module.state_dict()
@ -293,6 +302,12 @@ class Trainer:
self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0 self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
self.start_step = checkpoint["step"] if "step" in checkpoint else 0 self.start_step = checkpoint["step"] if "step" in checkpoint else 0
self.start_step = 0 if self.start_step is None else self.start_step self.start_step = 0 if self.start_step is None else self.start_step
self.step_cur_in_epoch = (
checkpoint["step_cur_in_epoch"] if "step_cur_in_epoch" in checkpoint else 0
)
self.step_cur_in_epoch = (
0 if self.step_cur_in_epoch is None else self.step_cur_in_epoch
)
model.to(self.device) model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'") print(f"Checkpoint loaded successfully from '{ckpt}'")
@ -321,7 +336,7 @@ class Trainer:
""" """
if self.use_ddp or self.use_fsdp: if self.use_ddp or self.use_fsdp:
dist.barrier() dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n") logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
model.train() model.train()
# Set the number of steps for gradient accumulation # Set the number of steps for gradient accumulation
@ -341,6 +356,7 @@ class Trainer:
if iterator_stop > 0: if iterator_stop > 0:
break break
self.batch_total += 1 self.batch_total += 1
self.step_cur_in_epoch += 1
time1 = time.perf_counter() time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time_beg:0.3f}" speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
@ -443,6 +459,7 @@ class Trainer:
self.log( self.log(
epoch, epoch,
batch_idx, batch_idx,
step_cur_in_epoch=self.step_cur_in_epoch,
batch_num_epoch=batch_num_epoch, batch_num_epoch=batch_num_epoch,
lr=lr, lr=lr,
loss=loss.detach().cpu().item(), loss=loss.detach().cpu().item(),
@ -461,6 +478,7 @@ class Trainer:
epoch=epoch, epoch=epoch,
writer=writer, writer=writer,
step=batch_idx + 1, step=batch_idx + 1,
step_cur_in_epoch=self.step_cur_in_epoch,
) )
if (batch_idx + 1) % self.save_checkpoint_interval == 0: if (batch_idx + 1) % self.save_checkpoint_interval == 0:
@ -471,6 +489,9 @@ class Trainer:
scheduler=scheduler, scheduler=scheduler,
scaler=scaler, scaler=scaler,
step=batch_idx + 1, step=batch_idx + 1,
step_cur_in_epoch=self.step_cur_in_epoch,
data_split_i=kwargs.get("data_split_i", 0),
data_split_num=kwargs.get("data_split_num", 1),
) )
time_beg = time.perf_counter() time_beg = time.perf_counter()
@ -500,7 +521,7 @@ class Trainer:
""" """
if self.use_ddp or self.use_fsdp: if self.use_ddp or self.use_fsdp:
dist.barrier() dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n") logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
@ -578,10 +599,10 @@ class Trainer:
iterator_stop.fill_(1) iterator_stop.fill_(1)
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM) dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if kwargs.get("step", None) is None: if kwargs.get("step_cur_in_epoch", None) is None:
ckpt_name = f"model.pt.ep{epoch}" ckpt_name = f"model.pt.ep{epoch}"
else: else:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step")}' ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_cur_in_epoch")}'
self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
model.train() model.train()
@ -594,6 +615,7 @@ class Trainer:
self, self,
epoch=0, epoch=0,
batch_idx=0, batch_idx=0,
step_cur_in_epoch=0,
batch_num_epoch=-1, batch_num_epoch=-1,
lr=0.0, lr=0.0,
loss=0.0, loss=0.0,
@ -626,6 +648,7 @@ class Trainer:
f"{tag}, " f"{tag}, "
f"rank: {self.rank}, " f"rank: {self.rank}, "
f"epoch: {epoch}/{self.max_epoch}, " f"epoch: {epoch}/{self.max_epoch}, "
f"step_cur_in_epoch: {step_cur_in_epoch}, "
f"data_slice: {data_split_i}/{data_split_num}, " f"data_slice: {data_split_i}/{data_split_num}, "
f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, " f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
f"(loss_avg_rank: {loss:.3f}), " f"(loss_avg_rank: {loss:.3f}), "