mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf exp (#1670)
* resume from step
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
This commit is contained in:
parent 1cdb3cc28d
commit 93ef505e2d
Binary file not shown.
Before: Size 178 KiB | After: Size 182 KiB
@@ -13,7 +13,7 @@ from io import BytesIO
from contextlib import nullcontext
import torch.distributed as dist
from collections.abc import Sequence

from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -99,7 +99,7 @@ def main(**kwargs):
if freeze_param is not None:
if "," in freeze_param:
freeze_param = eval(freeze_param)
if not isinstance(freeze_param, Sequence):
if not isinstance(freeze_param, (list, tuple)):
freeze_param = (freeze_param,)
logging.info("freeze_param is not None: %s", freeze_param)
for t in freeze_param:
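This hunk relaxes the freeze_param check from collections.abc.Sequence to a plain (list, tuple). A minimal sketch of the surrounding parse-and-freeze pattern, assuming name-prefix matching and using str.split in place of the eval call in the diff; the helper name is illustrative:

```python
import logging

import torch.nn as nn


def freeze_params(model: nn.Module, freeze_param):
    """Freeze every parameter whose name starts with one of the given prefixes."""
    if freeze_param is None:
        return
    # "encoder,ctc" -> ("encoder", "ctc"); the diff uses eval() for this step
    if isinstance(freeze_param, str) and "," in freeze_param:
        freeze_param = tuple(p.strip() for p in freeze_param.split(","))
    if not isinstance(freeze_param, (list, tuple)):
        freeze_param = (freeze_param,)
    logging.info("freeze_param is not None: %s", freeze_param)
    for prefix in freeze_param:
        for name, param in model.named_parameters():
            if name.startswith(prefix):
                param.requires_grad = False
```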
@@ -193,7 +193,7 @@ def main(**kwargs):
try:
from tensorboardX import SummaryWriter

writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
except:
writer = None
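The TensorBoard writer is now created on every rank instead of only rank 0. A rough, self-contained sketch of both variants; the per-rank subdirectory is an assumption added here to keep event files from different processes apart, it is not part of the diff:

```python
import os

tensorboard_dir = "./outputs/tensorboard"  # illustrative path
rank = int(os.environ.get("RANK", "0"))    # torchrun exports RANK per process

try:
    from tensorboardX import SummaryWriter

    # old behaviour: only rank 0 logs
    # writer = SummaryWriter(tensorboard_dir) if rank == 0 else None
    # new behaviour: every rank gets a writer (here under a per-rank subdir)
    writer = SummaryWriter(os.path.join(tensorboard_dir, f"rank{rank}"))
except ImportError:
    writer = None
```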
@@ -206,6 +206,7 @@ def main(**kwargs):
epoch, data_split_i=data_split_i, start_step=trainer.start_step
)
trainer.start_step = 0

trainer.train_epoch(
model=model,
optim=optim,
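trainer.start_step is handed to the dataloader factory and then cleared, so only the first data split after a resume skips ahead; that is the "resume from step" piece of this commit. The diff does the skipping inside the dataset/dataloader build; a generic, self-contained way to get the same effect is to drop already-consumed batches at iteration time:

```python
from itertools import islice

import torch
from torch.utils.data import DataLoader, TensorDataset


def batches_from(loader: DataLoader, start_step: int = 0):
    """Yield batches, skipping the first `start_step` of them (resume-from-step sketch)."""
    yield from islice(iter(loader), start_step, None)


# toy usage: resume at step 3 of a 10-batch epoch
data = TensorDataset(torch.arange(100).float())
loader = DataLoader(data, batch_size=10)
for step, (batch,) in enumerate(batches_from(loader, start_step=3), start=3):
    pass  # the training step for batches 3..9 would go here
start_step = 0  # clear the marker so later epochs/splits start from batch 0
```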
@@ -225,7 +226,7 @@ def main(**kwargs):
model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
)
scheduler.step()

trainer.step_cur_in_epoch = 0
trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)

time2 = time.perf_counter()
@@ -51,6 +51,7 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
self.batch_size = kwargs.get("batch_size")
self.batch_type = kwargs.get("batch_type")
self.prompt_ids_len = 0
self.retry = kwargs.get("retry", 5)

def get_source_len(self, index):
item = self.index_ds[index]
@@ -64,59 +65,75 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
return len(self.index_ds)

def __getitem__(self, index):
item = self.index_ds[index]
# import pdb;
# pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(
data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
) # speech: [b, T, d]

if speech_lengths > self.batch_size:
return None
speech = speech.permute(0, 2, 1)
target = item["target"]
if self.preprocessor_text:
target = self.preprocessor_text(target)
output = None
for idx in range(self.retry):
if idx == 0:
index_cur = index
else:
if index <= self.retry:
index_cur = index + idx
else:
index_cur = torch.randint(0, index, ()).item()

task = item.get("prompt", "<|ASR|>")
text_language = item.get("text_language", "<|zh|>")
item = self.index_ds[index_cur]

prompt = f"{self.sos}{task}{text_language}"
prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
self.prompt_ids_len = prompt_ids_len
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(
data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
) # speech: [b, T, d]

target_ids = self.tokenizer.encode(target, allowed_special="all")
target_ids_len = len(target_ids) + 1 # [lid, text]
if target_ids_len > 200:
return None
if speech_lengths > self.batch_size:
continue
speech = speech.permute(0, 2, 1)
target = item["target"]
if self.preprocessor_text:
target = self.preprocessor_text(target)

eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
task = item.get("prompt", "<|ASR|>")
text_language = item.get("text_language", "<|zh|>")

ids = prompt_ids + target_ids + eos
ids_lengths = len(ids)
prompt = f"{self.sos}{task}{text_language}"
prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
self.prompt_ids_len = prompt_ids_len

text = torch.tensor(ids, dtype=torch.int64)
text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
target_ids = self.tokenizer.encode(target, allowed_special="all")
target_ids_len = len(target_ids) + 1 # [lid, text]
if target_ids_len > 200:
continue

target_mask = (
[0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
target_mask_lengths = len(target_mask)
target_mask = torch.tensor(target_mask, dtype=torch.float32)
target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32)
return {
"speech": speech[0, :, :],
"speech_lengths": speech_lengths,
"text": text,
"text_lengths": text_lengths,
"target_mask": target_mask,
"target_mask_lengths": target_mask_lengths,
}
eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]

ids = prompt_ids + target_ids + eos
ids_lengths = len(ids)

text = torch.tensor(ids, dtype=torch.int64)
text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)

target_mask = (
[0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
target_mask_lengths = len(target_mask)
target_mask = torch.tensor(target_mask, dtype=torch.float32)
target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32)

output = {
"speech": speech[0, :, :],
"speech_lengths": speech_lengths,
"text": text,
"text_lengths": text_lengths,
"target_mask": target_mask,
"target_mask_lengths": target_mask_lengths,
}
break

return output

def collator(self, samples: list = None):
outputs = {}
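The rewritten __getitem__ above no longer bails out with None the first time a sample is too long; it retries up to self.retry times, stepping to a neighbouring index for small indices and otherwise drawing a random earlier index, and only returns None if every attempt fails. A stripped-down, self-contained sketch of that retry pattern (the class, the length filter and the index clamp are illustrative, not FunASR code):

```python
import torch
from torch.utils.data import Dataset


class RetryDataset(Dataset):
    """Resample a nearby or random index when an item fails a length filter."""

    def __init__(self, items, max_len=200, retry=5):
        self.items = items
        self.max_len = max_len
        self.retry = retry

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        output = None
        for idx in range(self.retry):
            if idx == 0:
                index_cur = index
            elif index <= self.retry:
                # early indices: step forward (clamped so the sketch never overruns)
                index_cur = min(index + idx, len(self.items) - 1)
            else:
                # otherwise: pick a random earlier index
                index_cur = torch.randint(0, index, ()).item()
            item = self.items[index_cur]
            if len(item) > self.max_len:
                continue  # bad case, try another index
            output = item
            break
        return output  # None only if every retry failed


data = ["short", "x" * 300, "also short"]  # the middle item fails the filter
ds = RetryDataset(data, max_len=200, retry=3)
print(ds[1])  # resampled to a nearby valid item -> "also short"
```

In the diff, samples that still come back as None are dropped by the collator, which is why the fallback in the next hunk exists.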
@@ -129,13 +146,30 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
outputs[key].append(sample[key])

if len(outputs) < 1:
logging.info(f"ERROR: data is empty!")
logging.error(f"ERROR: data is empty!")
outputs = {
"speech": torch.rand((10, 128), dtype=torch.float32),
"speech_lengths": torch.tensor([10], dtype=torch.int32),
"text": torch.tensor([58836], dtype=torch.int32),
"text_lengths": torch.tensor([1], dtype=torch.int32),
"target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]]),
"speech": torch.rand((10, 128), dtype=torch.float32)[None, :, :],
"speech_lengths": torch.tensor(
[
10,
],
dtype=torch.int32,
)[:, None],
"text": torch.tensor(
[
58836,
],
dtype=torch.int32,
)[None, :],
"text_lengths": torch.tensor(
[
1,
],
dtype=torch.int32,
)[:, None],
"target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]])[
None, :
],
}
return outputs
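When every sample in a batch has been filtered out, the collator now logs an error and fabricates a placeholder batch whose tensors carry an explicit batch dimension (the [None, :] / [:, None] indexing), matching the shape contract of a real batch so that, for example, DDP ranks keep stepping together. A compact sketch of that fallback; 58836 is the token id used in the diff, the feature size and lengths are illustrative:

```python
import torch


def dummy_batch(prompt_ids_len: int, feat_dim: int = 128, token_id: int = 58836):
    """Placeholder batch with batch dimension 1, mirroring the collator fallback."""
    return {
        "speech": torch.rand((10, feat_dim), dtype=torch.float32)[None, :, :],  # [1, T, D]
        "speech_lengths": torch.tensor([10], dtype=torch.int32)[:, None],       # [1, 1]
        "text": torch.tensor([token_id], dtype=torch.int32)[None, :],           # [1, 1]
        "text_lengths": torch.tensor([1], dtype=torch.int32)[:, None],          # [1, 1]
        "target_mask": torch.tensor([[0] * prompt_ids_len + [1, 1]])[None, :],  # [1, 1, L]
    }


batch = dummy_batch(prompt_ids_len=3)
assert batch["speech"].shape == (1, 10, 128)
```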
@@ -159,7 +193,7 @@ class SenseVoiceDataset(torch.utils.data.Dataset):

def _filter_badcase(self, outputs, i=0):
b, t, _ = outputs["speech"].shape


if b * t > self.batch_size * 1.25:
beg = torch.randint(0, 2, ()).item()
if b < 2:
@@ -170,7 +204,6 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
for key, data_list in outputs.items():
outputs[key] = outputs[key][beg : beg + b : 2]


speech_lengths_max = outputs["speech_lengths"].max().item()
outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
text_lengths_max = outputs["text_lengths"].max().item()
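_filter_badcase now works against the token budget: if the padded batch volume b * t exceeds roughly 1.25x batch_size, it keeps every second sample starting from a random offset, then trims padding down to the longest remaining sequence. A self-contained sketch of that logic, with only the speech tensors shown (text is handled the same way in the diff) and the shapes invented for the example:

```python
import torch


def filter_badcase(outputs: dict, batch_size: int) -> dict:
    """Halve an oversized batch, then trim padding to the longest kept sequence."""
    b, t, _ = outputs["speech"].shape
    if b * t > batch_size * 1.25 and b >= 2:
        beg = torch.randint(0, 2, ()).item()                # random even/odd offset
        for key in outputs:
            outputs[key] = outputs[key][beg : beg + b : 2]  # keep every second sample
    speech_lengths_max = outputs["speech_lengths"].max().item()
    outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
    return outputs


batch = {
    "speech": torch.rand(4, 50, 80),
    "speech_lengths": torch.tensor([30, 50, 20, 10]),
}
trimmed = filter_badcase(batch, batch_size=100)  # 4 * 50 > 125 -> subsample + trim
```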
@@ -116,6 +116,7 @@ class Trainer:
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0
self.start_step = 0
self.step_cur_in_epoch = 0
self.use_wandb = kwargs.get("use_wandb", False)
if self.use_wandb:
wandb.login(key=kwargs.get("wandb_token"))
@@ -137,6 +138,8 @@ class Trainer:
optim=None,
scheduler=None,
scaler=None,
step_cur_in_epoch=None,
**kwargs,
):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
@@ -147,6 +150,7 @@ class Trainer:
epoch (int): The epoch number at which the checkpoint is being saved.
"""

step_cur_in_epoch = None if step is None else step_cur_in_epoch
if self.rank == 0:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1
@@ -161,7 +165,12 @@ class Trainer:
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
"step_cur_in_epoch": step_cur_in_epoch,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
"batch_total": self.batch_total,
}
step = step_cur_in_epoch
if hasattr(model, "module"):
state["state_dict"] = model.module.state_dict()
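The checkpoint state dict now records step, step_cur_in_epoch, data_split_i / data_split_num and batch_total next to the weights; this is what lets training resume in the middle of an epoch rather than only at epoch boundaries. A minimal sketch of writing such a checkpoint (the field names follow the diff; the function, path and optimizer key are illustrative):

```python
import torch


def save_resumable_checkpoint(path, model, optim, epoch, step_cur_in_epoch,
                              data_split_i=0, data_split_num=1, batch_total=0):
    """Persist enough bookkeeping to resume mid-epoch, not just at epoch boundaries."""
    state = {
        "epoch": epoch,
        "step_cur_in_epoch": step_cur_in_epoch,
        "data_split_i": data_split_i,
        "data_split_num": data_split_num,
        "batch_total": batch_total,
        # unwrap DDP the same way the diff does
        "state_dict": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
        "optimizer": optim.state_dict(),
    }
    torch.save(state, path)
```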
@@ -293,6 +302,12 @@ class Trainer:
self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
self.start_step = checkpoint["step"] if "step" in checkpoint else 0
self.start_step = 0 if self.start_step is None else self.start_step
self.step_cur_in_epoch = (
checkpoint["step_cur_in_epoch"] if "step_cur_in_epoch" in checkpoint else 0
)
self.step_cur_in_epoch = (
0 if self.step_cur_in_epoch is None else self.step_cur_in_epoch
)

model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'")
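On load, each of the new fields falls back to 0 when it is missing or stored as None, so checkpoints written before this change still resume cleanly. The same guard can be written more compactly with dict.get; this is an equivalent sketch, not the literal diff code:

```python
def restore_progress(trainer, checkpoint: dict) -> None:
    """Restore mid-epoch counters, defaulting to 0 for old or partial checkpoints."""
    trainer.batch_total = checkpoint.get("batch_total") or 0
    trainer.start_step = checkpoint.get("step") or 0
    trainer.step_cur_in_epoch = checkpoint.get("step_cur_in_epoch") or 0
```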
@@ -321,7 +336,7 @@ class Trainer:
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
model.train()

# Set the number of steps for gradient accumulation
@@ -341,6 +356,7 @@ class Trainer:
if iterator_stop > 0:
break
self.batch_total += 1
self.step_cur_in_epoch += 1
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
@@ -443,6 +459,7 @@ class Trainer:
self.log(
epoch,
batch_idx,
step_cur_in_epoch=self.step_cur_in_epoch,
batch_num_epoch=batch_num_epoch,
lr=lr,
loss=loss.detach().cpu().item(),
@@ -461,6 +478,7 @@ class Trainer:
epoch=epoch,
writer=writer,
step=batch_idx + 1,
step_cur_in_epoch=self.step_cur_in_epoch,
)

if (batch_idx + 1) % self.save_checkpoint_interval == 0:
@@ -471,6 +489,9 @@ class Trainer:
scheduler=scheduler,
scaler=scaler,
step=batch_idx + 1,
step_cur_in_epoch=self.step_cur_in_epoch,
data_split_i=kwargs.get("data_split_i", 0),
data_split_num=kwargs.get("data_split_num", 1),
)

time_beg = time.perf_counter()
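Within the training loop, both the periodic validation and the periodic checkpoint now receive step_cur_in_epoch (and the data split indices), and a checkpoint is written every save_checkpoint_interval batches. A bare-bones, self-contained sketch of that cadence; the callback stands in for the real save_checkpoint method:

```python
def run_epoch(batches, save_checkpoint, save_checkpoint_interval=100, start_step=0):
    """Iterate one epoch, checkpointing every N batches with the in-epoch step recorded."""
    step_cur_in_epoch = start_step
    for batch_idx, batch in enumerate(batches):
        step_cur_in_epoch += 1
        # ... forward / backward / optimizer step would go here ...
        if (batch_idx + 1) % save_checkpoint_interval == 0:
            save_checkpoint(step=batch_idx + 1, step_cur_in_epoch=step_cur_in_epoch)


# toy usage: the callback fires after batches 3 and 6
run_epoch(range(7), save_checkpoint=lambda **kw: print("ckpt", kw), save_checkpoint_interval=3)
```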
@@ -500,7 +521,7 @@ class Trainer:
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
model.eval()

with torch.no_grad():
@@ -578,10 +599,10 @@ class Trainer:
iterator_stop.fill_(1)
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)

if kwargs.get("step", None) is None:
if kwargs.get("step_cur_in_epoch", None) is None:
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step")}'
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_cur_in_epoch")}'
self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
model.train()
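Validation bookkeeping is now keyed by step_cur_in_epoch instead of the per-split step, so mid-epoch checkpoint names do not collide when one epoch is trained over several data splits. The naming rule in isolation:

```python
def ckpt_name(epoch: int, step_cur_in_epoch=None) -> str:
    """model.pt.ep{epoch} at epoch end, model.pt.ep{epoch}.{step} for mid-epoch saves."""
    if step_cur_in_epoch is None:
        return f"model.pt.ep{epoch}"
    return f"model.pt.ep{epoch}.{step_cur_in_epoch}"


assert ckpt_name(3) == "model.pt.ep3"
assert ckpt_name(3, 4000) == "model.pt.ep3.4000"
```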
@@ -594,6 +615,7 @@ class Trainer:
self,
epoch=0,
batch_idx=0,
step_cur_in_epoch=0,
batch_num_epoch=-1,
lr=0.0,
loss=0.0,
@@ -626,6 +648,7 @@ class Trainer:
f"{tag}, "
f"rank: {self.rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
f"step_cur_in_epoch: {step_cur_in_epoch}, "
f"data_slice: {data_split_i}/{data_split_num}, "
f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
f"(loss_avg_rank: {loss:.3f}), "