Merge branch 'dev_gzf_exp' of github.com:alibaba-damo-academy/FunASR into dev_gzf_exp

merge
游雁 2024-04-28 21:18:45 +08:00
commit b76af7be8c
4 changed files with 123 additions and 66 deletions

Binary file not shown. (Image changed: 178 KiB before → 182 KiB after.)

View File

@@ -13,7 +13,7 @@ from io import BytesIO
from contextlib import nullcontext
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -99,7 +99,7 @@ def main(**kwargs):
if freeze_param is not None:
if "," in freeze_param:
freeze_param = eval(freeze_param)
if not isinstance(freeze_param, Sequence):
if not isinstance(freeze_param, (list, tuple)):
freeze_param = (freeze_param,)
logging.info("freeze_param is not None: %s", freeze_param)
for t in freeze_param:
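
A plausible reading of the isinstance change above (inferred, not stated in the commit): a str is itself a collections.abc.Sequence, so a single parameter name passed without a comma was never wrapped into a tuple, and the loop below then iterated over its characters. Checking against (list, tuple) wraps bare strings as intended. A quick self-contained check:

from collections.abc import Sequence

# Old check: a bare string already counts as a Sequence, so it is not wrapped,
# and `for t in freeze_param` walks over single characters.
print(isinstance("encoder", Sequence))        # True
# New check: a bare string fails, so it becomes ("encoder",) as intended.
print(isinstance("encoder", (list, tuple)))   # False
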
@@ -193,7 +193,7 @@ def main(**kwargs):
try:
from tensorboardX import SummaryWriter
writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
except:
writer = None
@@ -206,6 +206,7 @@ def main(**kwargs):
epoch, data_split_i=data_split_i, start_step=trainer.start_step
)
trainer.start_step = 0
trainer.train_epoch(
model=model,
optim=optim,
@@ -222,11 +223,13 @@ def main(**kwargs):
torch.cuda.empty_cache()
trainer.validate_epoch(
model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
)
scheduler.step()
trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
trainer.step_in_epoch = 0
trainer.save_checkpoint(
epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
)
time2 = time.perf_counter()
time_escaped = (time2 - time1) / 3600.0
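
Taken together with the Trainer change further down (self.start_epoch = checkpoint["epoch"] without the old + 1), the epoch + 1 arguments here move validation tags and checkpoints to counting completed epochs. A minimal, self-contained toy of that convention (hypothetical helpers, not the FunASR API):

import torch

def save_checkpoint(path, epochs_completed):
    # The checkpoint now records how many epochs have finished ...
    torch.save({"epoch": epochs_completed}, path)

def resume_start_epoch(path):
    # ... so resume starts directly at checkpoint["epoch"], no "+ 1" needed.
    return torch.load(path, map_location="cpu")["epoch"]

save_checkpoint("/tmp/toy_ckpt.pt", epochs_completed=1)  # finished epoch index 0
assert resume_start_epoch("/tmp/toy_ckpt.pt") == 1       # next epoch index to run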

View File

@@ -51,6 +51,7 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
self.batch_size = kwargs.get("batch_size")
self.batch_type = kwargs.get("batch_type")
self.prompt_ids_len = 0
self.retry = kwargs.get("retry", 5)
def get_source_len(self, index):
item = self.index_ds[index]
@@ -64,59 +65,75 @@
return len(self.index_ds)
def __getitem__(self, index):
item = self.index_ds[index]
# import pdb;
# pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(
data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
) # speech: [b, T, d]
if speech_lengths > self.batch_size:
return None
speech = speech.permute(0, 2, 1)
target = item["target"]
if self.preprocessor_text:
target = self.preprocessor_text(target)
output = None
for idx in range(self.retry):
if idx == 0:
index_cur = index
else:
if index <= self.retry:
index_cur = index + idx
else:
index_cur = torch.randint(0, index, ()).item()
task = item.get("prompt", "<|ASR|>")
text_language = item.get("text_language", "<|zh|>")
item = self.index_ds[index_cur]
prompt = f"{self.sos}{task}{text_language}"
prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
self.prompt_ids_len = prompt_ids_len
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(
data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
) # speech: [b, T, d]
target_ids = self.tokenizer.encode(target, allowed_special="all")
target_ids_len = len(target_ids) + 1 # [lid, text]
if target_ids_len > 200:
return None
if speech_lengths > self.batch_size:
continue
speech = speech.permute(0, 2, 1)
target = item["target"]
if self.preprocessor_text:
target = self.preprocessor_text(target)
eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
task = item.get("prompt", "<|ASR|>")
text_language = item.get("text_language", "<|zh|>")
ids = prompt_ids + target_ids + eos
ids_lengths = len(ids)
prompt = f"{self.sos}{task}{text_language}"
prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
self.prompt_ids_len = prompt_ids_len
text = torch.tensor(ids, dtype=torch.int64)
text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
target_ids = self.tokenizer.encode(target, allowed_special="all")
target_ids_len = len(target_ids) + 1 # [lid, text]
if target_ids_len > 200:
continue
target_mask = (
[0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
target_mask_lengths = len(target_mask)
target_mask = torch.tensor(target_mask, dtype=torch.float32)
target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32)
return {
"speech": speech[0, :, :],
"speech_lengths": speech_lengths,
"text": text,
"text_lengths": text_lengths,
"target_mask": target_mask,
"target_mask_lengths": target_mask_lengths,
}
eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
ids = prompt_ids + target_ids + eos
ids_lengths = len(ids)
text = torch.tensor(ids, dtype=torch.int64)
text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
target_mask = (
[0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
target_mask_lengths = len(target_mask)
target_mask = torch.tensor(target_mask, dtype=torch.float32)
target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32)
output = {
"speech": speech[0, :, :],
"speech_lengths": speech_lengths,
"text": text,
"text_lengths": text_lengths,
"target_mask": target_mask,
"target_mask_lengths": target_mask_lengths,
}
break
return output
def collator(self, samples: list = None):
outputs = {}
@@ -129,13 +146,30 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
outputs[key].append(sample[key])
if len(outputs) < 1:
logging.info(f"ERROR: data is empty!")
logging.error(f"ERROR: data is empty!")
outputs = {
"speech": torch.rand((10, 128), dtype=torch.float32),
"speech_lengths": torch.tensor([10], dtype=torch.int32),
"text": torch.tensor([58836], dtype=torch.int32),
"text_lengths": torch.tensor([1], dtype=torch.int32),
"target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]]),
"speech": torch.rand((10, 128), dtype=torch.float32)[None, :, :],
"speech_lengths": torch.tensor(
[
10,
],
dtype=torch.int32,
)[:, None],
"text": torch.tensor(
[
58836,
],
dtype=torch.int32,
)[None, :],
"text_lengths": torch.tensor(
[
1,
],
dtype=torch.int32,
)[:, None],
"target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]])[
None, :
],
}
return outputs
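
The rewritten __getitem__ above replaces "return None on a bad sample" with a bounded retry: up to self.retry attempts, probing forward from small indices and otherwise drawing a random earlier index, skipping samples whose fbank length exceeds batch_size or whose target exceeds 200 tokens. A minimal, self-contained sketch of the pattern (the None check is a hypothetical stand-in for those length filters):

import torch

def getitem_with_retry(samples, index, retry=5):
    # samples: any indexable collection; a "bad" item is modeled as None here.
    for idx in range(retry):
        if idx == 0:
            index_cur = index                                # first: the requested item
        elif index <= retry:
            index_cur = index + idx                          # near the start: probe forward
        else:
            index_cur = torch.randint(0, index, ()).item()   # else: random earlier item
        item = samples[index_cur]                            # stand-in for load + filters
        if item is not None:
            return item
    return None  # every retry failed; the collator must cope

print(getitem_with_retry([None, None, {"x": 1}], 0))  # -> {'x': 1} on the third try

Relatedly, the collator's empty-batch fallback now logs at error level instead of info, and its dummy tensors carry explicit batch dimensions ([None, :] / [:, None]) so that even a fully filtered batch matches the shapes the model expects.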

View File

@@ -116,6 +116,7 @@ class Trainer:
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0
self.start_step = 0
self.step_in_epoch = 0
self.use_wandb = kwargs.get("use_wandb", False)
if self.use_wandb:
wandb.login(key=kwargs.get("wandb_token"))
@@ -137,6 +138,8 @@
optim=None,
scheduler=None,
scaler=None,
step_in_epoch=None,
**kwargs,
):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
@@ -147,6 +150,7 @@
epoch (int): The epoch number at which the checkpoint is being saved.
"""
step_in_epoch = None if step is None else step_in_epoch
if self.rank == 0:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1
@@ -161,7 +165,12 @@
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
"step_in_epoch": step_in_epoch,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
"batch_total": self.batch_total,
}
step = step_in_epoch
if hasattr(model, "module"):
state["state_dict"] = model.module.state_dict()
@@ -195,7 +204,7 @@
)
else:
logging.info(
f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}"
f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
elif self.avg_keep_nbest_models_type == "loss":
if (
@@ -210,7 +219,7 @@
)
else:
logging.info(
f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}"
f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
else:
print("Undo")
@@ -251,7 +260,7 @@
ckpt = os.path.join(self.output_dir, "model.pt")
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt, map_location="cpu")
self.start_epoch = checkpoint["epoch"] + 1
self.start_epoch = checkpoint["epoch"]
# self.model.load_state_dict(checkpoint['state_dict'])
src_state = checkpoint["state_dict"]
dst_state = model.state_dict()
@@ -288,11 +297,15 @@
checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
)
self.start_data_split_i = (
checkpoint["start_data_split_i"] if "start_data_split_i" in checkpoint else 0
checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
)
self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
self.start_step = checkpoint["step"] if "step" in checkpoint else 0
self.start_step = 0 if self.start_step is None else self.start_step
self.step_in_epoch = (
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
)
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'")
@@ -321,7 +334,7 @@
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
model.train()
# Set the number of steps for gradient accumulation
@@ -341,6 +354,7 @@
if iterator_stop > 0:
break
self.batch_total += 1
self.step_in_epoch += 1
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
@@ -443,6 +457,7 @@
self.log(
epoch,
batch_idx,
step_in_epoch=self.step_in_epoch,
batch_num_epoch=batch_num_epoch,
lr=lr,
loss=loss.detach().cpu().item(),
@@ -454,16 +469,17 @@
data_split_num=kwargs.get("data_split_num", 1),
)
if (batch_idx + 1) % self.validate_interval == 0:
if self.step_in_epoch % self.validate_interval == 0:
self.validate_epoch(
model=model,
dataloader_val=dataloader_val,
epoch=epoch,
writer=writer,
step=batch_idx + 1,
step_in_epoch=self.step_in_epoch,
)
if (batch_idx + 1) % self.save_checkpoint_interval == 0:
if self.step_in_epoch % self.save_checkpoint_interval == 0:
self.save_checkpoint(
epoch,
model=model,
@@ -471,6 +487,9 @@
scheduler=scheduler,
scaler=scaler,
step=batch_idx + 1,
step_in_epoch=self.step_in_epoch,
data_split_i=kwargs.get("data_split_i", 0),
data_split_num=kwargs.get("data_split_num", 1),
)
time_beg = time.perf_counter()
@@ -500,7 +519,7 @@
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
model.eval()
with torch.no_grad():
@@ -578,10 +597,10 @@
iterator_stop.fill_(1)
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if kwargs.get("step", None) is None:
if kwargs.get("step_in_epoch", None) is None:
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step")}'
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
model.train()
@@ -594,6 +613,7 @@
self,
epoch=0,
batch_idx=0,
step_in_epoch=0,
batch_num_epoch=-1,
lr=0.0,
loss=0.0,
@@ -627,7 +647,7 @@
f"rank: {self.rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
f"data_slice: {data_split_i}/{data_split_num}, "
f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
f"(loss_avg_rank: {loss:.3f}), "
f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), "