Dev gzf exp (#1670)

* resume from step

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch
This commit is contained in:
zhifu gao 2024-04-28 15:14:57 +08:00 committed by GitHub
parent 1cdb3cc28d
commit 93ef505e2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 116 additions and 59 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 178 KiB

After

Width:  |  Height:  |  Size: 182 KiB

View File

@@ -13,7 +13,7 @@ from io import BytesIO
from contextlib import nullcontext
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -99,7 +99,7 @@ def main(**kwargs):
if freeze_param is not None:
if "," in freeze_param:
freeze_param = eval(freeze_param)
if not isinstance(freeze_param, Sequence):
if not isinstance(freeze_param, (list, tuple)):
freeze_param = (freeze_param,)
logging.info("freeze_param is not None: %s", freeze_param)
for t in freeze_param:
@@ -193,7 +193,7 @@ def main(**kwargs):
try:
from tensorboardX import SummaryWriter
writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
except:
writer = None
@@ -206,6 +206,7 @@ def main(**kwargs):
epoch, data_split_i=data_split_i, start_step=trainer.start_step
)
trainer.start_step = 0
trainer.train_epoch(
model=model,
optim=optim,
@@ -225,7 +226,7 @@ def main(**kwargs):
model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
)
scheduler.step()
trainer.step_cur_in_epoch = 0
trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
time2 = time.perf_counter()

View File

@@ -51,6 +51,7 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
self.batch_size = kwargs.get("batch_size")
self.batch_type = kwargs.get("batch_type")
self.prompt_ids_len = 0
self.retry = kwargs.get("retry", 5)
def get_source_len(self, index):
item = self.index_ds[index]
@@ -64,59 +65,75 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
return len(self.index_ds)
def __getitem__(self, index):
item = self.index_ds[index]
# import pdb;
# pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(
data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
) # speech: [b, T, d]
if speech_lengths > self.batch_size:
return None
speech = speech.permute(0, 2, 1)
target = item["target"]
if self.preprocessor_text:
target = self.preprocessor_text(target)
output = None
for idx in range(self.retry):
if idx == 0:
index_cur = index
else:
if index <= self.retry:
index_cur = index + idx
else:
index_cur = torch.randint(0, index, ()).item()
task = item.get("prompt", "<|ASR|>")
text_language = item.get("text_language", "<|zh|>")
item = self.index_ds[index_cur]
prompt = f"{self.sos}{task}{text_language}"
prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
self.prompt_ids_len = prompt_ids_len
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(
data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
) # speech: [b, T, d]
target_ids = self.tokenizer.encode(target, allowed_special="all")
target_ids_len = len(target_ids) + 1 # [lid, text]
if target_ids_len > 200:
return None
if speech_lengths > self.batch_size:
continue
speech = speech.permute(0, 2, 1)
target = item["target"]
if self.preprocessor_text:
target = self.preprocessor_text(target)
eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
task = item.get("prompt", "<|ASR|>")
text_language = item.get("text_language", "<|zh|>")
ids = prompt_ids + target_ids + eos
ids_lengths = len(ids)
prompt = f"{self.sos}{task}{text_language}"
prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
self.prompt_ids_len = prompt_ids_len
text = torch.tensor(ids, dtype=torch.int64)
text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
target_ids = self.tokenizer.encode(target, allowed_special="all")
target_ids_len = len(target_ids) + 1 # [lid, text]
if target_ids_len > 200:
continue
target_mask = (
[0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
target_mask_lengths = len(target_mask)
target_mask = torch.tensor(target_mask, dtype=torch.float32)
target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32)
return {
"speech": speech[0, :, :],
"speech_lengths": speech_lengths,
"text": text,
"text_lengths": text_lengths,
"target_mask": target_mask,
"target_mask_lengths": target_mask_lengths,
}
eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
ids = prompt_ids + target_ids + eos
ids_lengths = len(ids)
text = torch.tensor(ids, dtype=torch.int64)
text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
target_mask = (
[0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
target_mask_lengths = len(target_mask)
target_mask = torch.tensor(target_mask, dtype=torch.float32)
target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32)
output = {
"speech": speech[0, :, :],
"speech_lengths": speech_lengths,
"text": text,
"text_lengths": text_lengths,
"target_mask": target_mask,
"target_mask_lengths": target_mask_lengths,
}
break
return output
def collator(self, samples: list = None):
outputs = {}
@@ -129,13 +146,30 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
outputs[key].append(sample[key])
if len(outputs) < 1:
logging.info(f"ERROR: data is empty!")
logging.error(f"ERROR: data is empty!")
outputs = {
"speech": torch.rand((10, 128), dtype=torch.float32),
"speech_lengths": torch.tensor([10], dtype=torch.int32),
"text": torch.tensor([58836], dtype=torch.int32),
"text_lengths": torch.tensor([1], dtype=torch.int32),
"target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]]),
"speech": torch.rand((10, 128), dtype=torch.float32)[None, :, :],
"speech_lengths": torch.tensor(
[
10,
],
dtype=torch.int32,
)[:, None],
"text": torch.tensor(
[
58836,
],
dtype=torch.int32,
)[None, :],
"text_lengths": torch.tensor(
[
1,
],
dtype=torch.int32,
)[:, None],
"target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]])[
None, :
],
}
return outputs
@@ -159,7 +193,7 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
def _filter_badcase(self, outputs, i=0):
b, t, _ = outputs["speech"].shape
if b * t > self.batch_size * 1.25:
beg = torch.randint(0, 2, ()).item()
if b < 2:
@@ -170,7 +204,6 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
for key, data_list in outputs.items():
outputs[key] = outputs[key][beg : beg + b : 2]
speech_lengths_max = outputs["speech_lengths"].max().item()
outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
text_lengths_max = outputs["text_lengths"].max().item()

View File

@@ -116,6 +116,7 @@ class Trainer:
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0
self.start_step = 0
self.step_cur_in_epoch = 0
self.use_wandb = kwargs.get("use_wandb", False)
if self.use_wandb:
wandb.login(key=kwargs.get("wandb_token"))
@@ -137,6 +138,8 @@ class Trainer:
optim=None,
scheduler=None,
scaler=None,
step_cur_in_epoch=None,
**kwargs,
):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
@@ -147,6 +150,7 @@ class Trainer:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
step_cur_in_epoch = None if step is None else step_cur_in_epoch
if self.rank == 0:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1
@@ -161,7 +165,12 @@ class Trainer:
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
"step_cur_in_epoch": step_cur_in_epoch,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
"batch_total": self.batch_total,
}
step = step_cur_in_epoch
if hasattr(model, "module"):
state["state_dict"] = model.module.state_dict()
@@ -293,6 +302,12 @@ class Trainer:
self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
self.start_step = checkpoint["step"] if "step" in checkpoint else 0
self.start_step = 0 if self.start_step is None else self.start_step
self.step_cur_in_epoch = (
checkpoint["step_cur_in_epoch"] if "step_cur_in_epoch" in checkpoint else 0
)
self.step_cur_in_epoch = (
0 if self.step_cur_in_epoch is None else self.step_cur_in_epoch
)
model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'")
@@ -321,7 +336,7 @@ class Trainer:
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
model.train()
# Set the number of steps for gradient accumulation
@@ -341,6 +356,7 @@ class Trainer:
if iterator_stop > 0:
break
self.batch_total += 1
self.step_cur_in_epoch += 1
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
@@ -443,6 +459,7 @@ class Trainer:
self.log(
epoch,
batch_idx,
step_cur_in_epoch=self.step_cur_in_epoch,
batch_num_epoch=batch_num_epoch,
lr=lr,
loss=loss.detach().cpu().item(),
@@ -461,6 +478,7 @@ class Trainer:
epoch=epoch,
writer=writer,
step=batch_idx + 1,
step_cur_in_epoch=self.step_cur_in_epoch,
)
if (batch_idx + 1) % self.save_checkpoint_interval == 0:
@@ -471,6 +489,9 @@ class Trainer:
scheduler=scheduler,
scaler=scaler,
step=batch_idx + 1,
step_cur_in_epoch=self.step_cur_in_epoch,
data_split_i=kwargs.get("data_split_i", 0),
data_split_num=kwargs.get("data_split_num", 1),
)
time_beg = time.perf_counter()
@@ -500,7 +521,7 @@ class Trainer:
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
model.eval()
with torch.no_grad():
@@ -578,10 +599,10 @@ class Trainer:
iterator_stop.fill_(1)
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if kwargs.get("step", None) is None:
if kwargs.get("step_cur_in_epoch", None) is None:
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step")}'
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_cur_in_epoch")}'
self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
model.train()
@@ -594,6 +615,7 @@ class Trainer:
self,
epoch=0,
batch_idx=0,
step_cur_in_epoch=0,
batch_num_epoch=-1,
lr=0.0,
loss=0.0,
@@ -626,6 +648,7 @@ class Trainer:
f"{tag}, "
f"rank: {self.rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
f"step_cur_in_epoch: {step_cur_in_epoch}, "
f"data_slice: {data_split_i}/{data_split_num}, "
f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
f"(loss_avg_rank: {loss:.3f}), "