* funasr1.0 finetune

* funasr1.0 pbar

* update with main (#1260)

* Update websocket_protocol_zh.md

* update

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

zhifu gao 2024-01-17 18:28:28 +08:00 committed by GitHub
parent b1857837dd
commit 9a9c3b75b5
10 changed files with 298 additions and 147 deletions

View File

@@ -9,9 +9,11 @@
 python funasr/bin/train.py \
 +model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
 +model_revision="v2.0.2" \
-+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
++train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
++valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
 ++dataset_conf.batch_size=2 \
 ++dataset_conf.batch_type="example" \
+++train_conf.max_epoch=2 \
 +output_dir="outputs/debug/ckpt/funasr2/exp2" \
 +device="cpu" \
 +debug="true"
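For reference, each line of the files passed as train_data_set_list/valid_data_set_list is one JSON object per utterance. The exact schema is not shown in this commit; the only field the dataset code later in this diff reads explicitly is source_len, so the other keys in the sketch below are illustrative assumptions, not the documented format.

import json

# Hypothetical jsonl entry; only "source_len" is referenced by IndexDSJsonl
# in this commit, the remaining keys are assumptions for illustration.
entry = {
    "key": "utt_0001",                   # assumed utterance id
    "source": "/data/wav/utt_0001.wav",  # assumed audio path
    "source_len": 560,                   # length field read by get_source_len()
    "target": "hello world",             # assumed transcript
    "target_len": 2,                     # assumed token count
}

with open("asr_task_debug_len_10.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(entry, ensure_ascii=False) + "\n")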

View File

@@ -15,6 +15,6 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
                   spk_model_revision="v2.0.2",
                   )
-res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                      hotword='达摩院 魔搭')
 print(res)

View File

@@ -221,7 +221,8 @@ class AutoModel:
         speed_stats = {}
         asr_result_list = []
         num_samples = len(data_list)
-        pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True)
+        disable_pbar = kwargs.get("disable_pbar", False)
+        pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True) if not disable_pbar else None
         time_speech_total = 0.0
         time_escape_total = 0.0
         for beg_idx in range(0, num_samples, batch_size):
@@ -239,8 +240,7 @@
             time2 = time.perf_counter()
             asr_result_list.extend(results)
-            pbar.update(1)
             # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
             batch_data_time = meta_data.get("batch_data_time", -1)
             time_escape = time2 - time1
@@ -252,12 +252,15 @@
             description = (
                 f"{speed_stats}, "
             )
-            pbar.set_description(description)
+            if pbar:
+                pbar.update(1)
+                pbar.set_description(description)
             time_speech_total += batch_data_time
             time_escape_total += time_escape
 
-        pbar.update(1)
-        pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+        if pbar:
+            pbar.update(1)
+            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
         torch.cuda.empty_cache()
         return asr_result_list
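With this change a caller can suppress the per-batch progress bar entirely. A minimal usage sketch, assuming that AutoModel.generate forwards extra keyword arguments through to inference (the model name and audio file are illustrative):

from funasr import AutoModel

# Any AutoModel works the same way; this name mirrors the finetune example above.
model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  model_revision="v2.0.2")

# disable_pbar=True is read via kwargs.get("disable_pbar", False), so no tqdm bar
# is created and the rtf_avg description is skipped.
res = model.generate(input="asr_example.wav", disable_pbar=True)
print(res)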
@@ -309,8 +312,11 @@
             time_speech_total_per_sample = speech_lengths/16000
             time_speech_total_all_samples += time_speech_total_per_sample
+            pbar_sample = tqdm(colour="blue", total=n + 1, dynamic_ncols=True)
+
             all_segments = []
             for j, _ in enumerate(range(0, n)):
+                pbar_sample.update(1)
                 batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                 if j < n - 1 and (
                     batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
@@ -319,13 +325,14 @@
                     batch_size_ms_cum = 0
                     end_idx = j + 1
                     speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
-                    results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
+                    results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg)
                     if self.spk_model is not None:
                         # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                         for _b in range(len(speech_j)):
-                            vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \
-                                             sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \
+                            vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0,
+                                             sorted_data[beg_idx:end_idx][_b][0][1]/1000.0,
                                              speech_j[_b]]]
                             segments = sv_chunk(vad_segments)
                             all_segments.extend(segments)
@@ -338,12 +345,13 @@
             results_sorted.extend(results)
+            pbar_total.update(1)
 
             end_asr_total = time.time()
             time_escape_total_per_sample = end_asr_total - beg_asr_total
-            pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+            pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                                        f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
                                        f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
 
             restored_data = [0] * n
             for j in range(n):

View File

@@ -141,30 +141,37 @@ def main(**kwargs):
     scheduler_class = scheduler_classes.get(scheduler)
     scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
 
-    # import pdb;
-    # pdb.set_trace()
     # dataset
     dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
     dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
+    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
+                                **kwargs.get("dataset_conf"))
 
     # dataloader
     batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+    batch_sampler_val = None
     if batch_sampler is not None:
+        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
         batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+        batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf"))
     dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                 collate_fn=dataset_tr.collator,
                                                 batch_sampler=batch_sampler,
                                                 num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                                 pin_memory=True)
+    dataloader_val = torch.utils.data.DataLoader(dataset_val,
+                                                 collate_fn=dataset_val.collator,
+                                                 batch_sampler=batch_sampler_val,
+                                                 num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
+                                                 pin_memory=True)
 
     trainer = Trainer(
         model=model,
         optim=optim,
         scheduler=scheduler,
         dataloader_train=dataloader_tr,
-        dataloader_val=None,
+        dataloader_val=dataloader_val,
         local_rank=local_rank,
         use_ddp=use_ddp,
         use_fsdp=use_fsdp,

View File

@@ -54,7 +54,11 @@ class IndexDSJsonl(torch.utils.data.Dataset):
         return len(self.contents)
 
     def __getitem__(self, index):
-        return self.contents[index]
+        try:
+            data = self.contents[index]
+        except:
+            print(index)
+        return data
 
     def get_source_len(self, data_dict):
         return data_dict["source_len"]
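As committed, the bare except only prints the failing index and then falls through to return data, which is unbound on that path and raises UnboundLocalError; it also swallows unrelated exceptions. A more defensive variant (a sketch, not part of this commit) would surface the index and re-raise:

    def __getitem__(self, index):
        # Sketch of a safer variant: report the bad index but keep the
        # original exception instead of returning an unbound variable.
        try:
            return self.contents[index]
        except IndexError:
            print(f"bad index: {index} (dataset size: {len(self.contents)})")
            raise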

View File

@@ -13,6 +13,7 @@ class BatchSampler(torch.utils.data.BatchSampler):
                  buffer_size: int = 30,
                  drop_last: bool = False,
                  shuffle: bool = True,
+                 is_training: bool = True,
                  **kwargs):
 
         self.drop_last = drop_last
@@ -24,7 +25,7 @@ class BatchSampler(torch.utils.data.BatchSampler):
         self.buffer_size = buffer_size
         self.max_token_length = kwargs.get("max_token_length", 5000)
         self.shuffle_idx = np.arange(self.total_samples)
-        self.shuffle = shuffle
+        self.shuffle = shuffle and is_training
 
     def __len__(self):
         return self.total_samples
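The new flag lets train.py reuse the same dataset_conf for the validation sampler while forcing shuffling off; the interaction is a plain boolean and, as the small sketch below shows (values mirror how batch_sampler_val is built above):

# Effect of the new flag inside BatchSampler.__init__:
# shuffle comes from dataset_conf, is_training from the caller.
shuffle, is_training = True, False            # validation sampler in train.py
effective_shuffle = shuffle and is_training   # -> False: validation keeps dataset order
print(effective_shuffle)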

View File

@@ -164,6 +164,7 @@ class Paraformer(torch.nn.Module):
         self.use_1st_decoder_loss = use_1st_decoder_loss
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
+        self.error_calculator = None
 
     def forward(
         self,

View File

@@ -95,6 +95,7 @@ train_conf:
     - acc
     - max
   keep_nbest_models: 10
+  avg_nbest_model: 5
   log_interval: 50
 
 optim: adam

View File

@@ -9,117 +9,173 @@ from io import BytesIO
 import torch
 from typing import Collection
import os
import torch
import re
from collections import OrderedDict
from functools import cmp_to_key
from funasr.train.reporter import Reporter
# @torch.no_grad()
# def average_nbest_models(
# output_dir: Path,
# best_model_criterion: Sequence[Sequence[str]],
# nbest: Union[Collection[int], int],
# suffix: Optional[str] = None,
# oss_bucket=None,
# pai_output_dir=None,
# ) -> None:
# """Generate averaged model from n-best models
#
# Args:
# output_dir: The directory contains the model file for each epoch
# reporter: Reporter instance
# best_model_criterion: Give criterions to decide the best model.
# e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
# nbest: Number of best model files to be averaged
# suffix: A suffix added to the averaged model file name
# """
# if isinstance(nbest, int):
# nbests = [nbest]
# else:
# nbests = list(nbest)
# if len(nbests) == 0:
# warnings.warn("At least 1 nbest values are required")
# nbests = [1]
# if suffix is not None:
# suffix = suffix + "."
# else:
# suffix = ""
#
# # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
# nbest_epochs = [
# (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
# for ph, k, m in best_model_criterion
# if reporter.has(ph, k)
# ]
#
# _loaded = {}
# for ph, cr, epoch_and_values in nbest_epochs:
# _nbests = [i for i in nbests if i <= len(epoch_and_values)]
# if len(_nbests) == 0:
# _nbests = [1]
#
# for n in _nbests:
# if n == 0:
# continue
# elif n == 1:
# # The averaged model is same as the best model
# e, _ = epoch_and_values[0]
# op = output_dir / f"{e}epoch.pb"
# sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
# if sym_op.is_symlink() or sym_op.exists():
# sym_op.unlink()
# sym_op.symlink_to(op.name)
# else:
# op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
# logging.info(
# f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
# )
#
# avg = None
# # 2.a. Averaging model
# for e, _ in epoch_and_values[:n]:
# if e not in _loaded:
# if oss_bucket is None:
# _loaded[e] = torch.load(
# output_dir / f"{e}epoch.pb",
# map_location="cpu",
# )
# else:
# buffer = BytesIO(
# oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
# _loaded[e] = torch.load(buffer)
# states = _loaded[e]
#
# if avg is None:
# avg = states
# else:
# # Accumulated
# for k in avg:
# avg[k] = avg[k] + states[k]
# for k in avg:
# if str(avg[k].dtype).startswith("torch.int"):
# # For int type, not averaged, but only accumulated.
# # e.g. BatchNorm.num_batches_tracked
# # (If there are any cases that requires averaging
# # or the other reducing method, e.g. max/min, for integer type,
# # please report.)
# pass
# else:
# avg[k] = avg[k] / n
#
# # 2.b. Save the ave model and create a symlink
# if oss_bucket is None:
# torch.save(avg, op)
# else:
# buffer = BytesIO()
# torch.save(avg, buffer)
# oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
# buffer.getvalue())
#
# # 3. *.*.ave.pb is a symlink to the max ave model
# if oss_bucket is None:
# op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
# sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
# if sym_op.is_symlink() or sym_op.exists():
# sym_op.unlink()
# sym_op.symlink_to(op.name)
def _get_checkpoint_paths(output_dir: str, last_n: int=5):
"""
Get the paths of the last 'last_n' checkpoints by parsing filenames
in the output directory.
"""
# List all files in the output directory
files = os.listdir(output_dir)
# Filter out checkpoint files and extract epoch numbers
checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
# Sort files by epoch number in descending order
checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
# Get the last 'last_n' checkpoint paths
checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
return checkpoint_paths
-@torch.no_grad()
-def average_nbest_models(
-    output_dir: Path,
-    reporter: Reporter,
-    best_model_criterion: Sequence[Sequence[str]],
-    nbest: Union[Collection[int], int],
-    suffix: Optional[str] = None,
-    oss_bucket=None,
-    pai_output_dir=None,
-) -> None:
-    """Generate averaged model from n-best models
-
-    Args:
-        output_dir: The directory contains the model file for each epoch
-        reporter: Reporter instance
-        best_model_criterion: Give criterions to decide the best model.
-            e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
-        nbest: Number of best model files to be averaged
-        suffix: A suffix added to the averaged model file name
-    """
-    if isinstance(nbest, int):
-        nbests = [nbest]
-    else:
-        nbests = list(nbest)
-    if len(nbests) == 0:
-        warnings.warn("At least 1 nbest values are required")
-        nbests = [1]
-    if suffix is not None:
-        suffix = suffix + "."
-    else:
-        suffix = ""
-
-    # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
-    nbest_epochs = [
-        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
-        for ph, k, m in best_model_criterion
-        if reporter.has(ph, k)
-    ]
-
-    _loaded = {}
-    for ph, cr, epoch_and_values in nbest_epochs:
-        _nbests = [i for i in nbests if i <= len(epoch_and_values)]
-        if len(_nbests) == 0:
-            _nbests = [1]
-
-        for n in _nbests:
-            if n == 0:
-                continue
-            elif n == 1:
-                # The averaged model is same as the best model
-                e, _ = epoch_and_values[0]
-                op = output_dir / f"{e}epoch.pb"
-                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
-                if sym_op.is_symlink() or sym_op.exists():
-                    sym_op.unlink()
-                sym_op.symlink_to(op.name)
-            else:
-                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
-                logging.info(
-                    f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
-                )
-
-                avg = None
-                # 2.a. Averaging model
-                for e, _ in epoch_and_values[:n]:
-                    if e not in _loaded:
-                        if oss_bucket is None:
-                            _loaded[e] = torch.load(
-                                output_dir / f"{e}epoch.pb",
-                                map_location="cpu",
-                            )
-                        else:
-                            buffer = BytesIO(
-                                oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
-                            _loaded[e] = torch.load(buffer)
-                    states = _loaded[e]
-
-                    if avg is None:
-                        avg = states
-                    else:
-                        # Accumulated
-                        for k in avg:
-                            avg[k] = avg[k] + states[k]
-                for k in avg:
-                    if str(avg[k].dtype).startswith("torch.int"):
-                        # For int type, not averaged, but only accumulated.
-                        # e.g. BatchNorm.num_batches_tracked
-                        # (If there are any cases that requires averaging
-                        #  or the other reducing method, e.g. max/min, for integer type,
-                        #  please report.)
-                        pass
-                    else:
-                        avg[k] = avg[k] / n
-
-                # 2.b. Save the ave model and create a symlink
-                if oss_bucket is None:
-                    torch.save(avg, op)
-                else:
-                    buffer = BytesIO()
-                    torch.save(avg, buffer)
-                    oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
-                                          buffer.getvalue())
-
-        # 3. *.*.ave.pb is a symlink to the max ave model
-        if oss_bucket is None:
-            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
-            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
-            if sym_op.is_symlink() or sym_op.exists():
-                sym_op.unlink()
-            sym_op.symlink_to(op.name)
+
+@torch.no_grad()
+def average_checkpoints(output_dir: str, last_n: int=5):
+    """
+    Average the last 'last_n' checkpoints' model state_dicts.
+    If a tensor is of type torch.int, perform sum instead of average.
+    """
+    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
+    state_dicts = []
+
+    # Load state_dicts from checkpoints
+    for path in checkpoint_paths:
+        if os.path.isfile(path):
+            state_dicts.append(torch.load(path, map_location='cpu')['state_dict'])
+        else:
+            print(f"Checkpoint file {path} not found.")
+            continue
+
+    # Check if we have any state_dicts to average
+    if not state_dicts:
+        raise RuntimeError("No checkpoints found for averaging.")
+
+    # Average or sum weights
+    avg_state_dict = OrderedDict()
+    for key in state_dicts[0].keys():
+        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
+        # Check the type of the tensor
+        if str(tensors[0].dtype).startswith("torch.int"):
+            # Perform sum for integer tensors
+            summed_tensor = sum(tensors)
+            avg_state_dict[key] = summed_tensor
+        else:
+            # Perform average for other types of tensors
+            stacked_tensors = torch.stack(tensors)
+            avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
+
+    torch.save({'state_dict': avg_state_dict}, os.path.join(output_dir, f"model.pt.avg{last_n}"))
+    return avg_state_dict
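A minimal sketch of using the new helper after training: it looks for files named model.pt.ep{epoch} in the output directory, averages the last_n most recent ones, writes model.pt.avg{last_n}, and returns the averaged state_dict (the import path is the one the Trainer below uses; loading it back into a model is an illustrative extra step).

from funasr.train_utils.average_nbest_models import average_checkpoints

output_dir = "outputs/debug/ckpt/funasr2/exp2"   # same dir as the finetune example above
avg_state_dict = average_checkpoints(output_dir, last_n=5)

# Hypothetical reuse: load the averaged weights into an already-built model.
# model.load_state_dict(avg_state_dict)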

View File

@@ -7,10 +7,11 @@ import torch.distributed as dist
 from contextlib import nullcontext
 # from torch.utils.tensorboard import SummaryWriter
 from tensorboardX import SummaryWriter
+from pathlib import Path
 
 from funasr.train_utils.device_funcs import to_device
 from funasr.train_utils.recursive_op import recursive_average
+from funasr.train_utils.average_nbest_models import average_checkpoints
 
 class Trainer:
     """
@@ -66,10 +67,9 @@
         self.use_ddp = use_ddp
         self.use_fsdp = use_fsdp
         self.device = next(model.parameters()).device
+        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
         self.kwargs = kwargs
-        if self.resume:
-            self._resume_checkpoint(self.resume)
 
         try:
             rank = dist.get_rank()
@@ -102,9 +102,17 @@
         }
         # Create output directory if it does not exist
         os.makedirs(self.output_dir, exist_ok=True)
-        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
+        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
         torch.save(state, filename)
         print(f'Checkpoint saved to {filename}')
+        latest = Path(os.path.join(self.output_dir, f'model.pt'))
+        try:
+            latest.unlink()
+        except:
+            pass
+        latest.symlink_to(filename)
 
     def _resume_checkpoint(self, resume_path):
         """
@@ -114,29 +122,50 @@
         Args:
             resume_path (str): The file path to the checkpoint to resume from.
         """
-        if os.path.isfile(resume_path):
-            checkpoint = torch.load(resume_path)
+        ckpt = os.path.join(resume_path, "model.pt")
+        if os.path.isfile(ckpt):
+            checkpoint = torch.load(ckpt)
             self.start_epoch = checkpoint['epoch'] + 1
             self.model.load_state_dict(checkpoint['state_dict'])
             self.optim.load_state_dict(checkpoint['optimizer'])
             self.scheduler.load_state_dict(checkpoint['scheduler'])
-            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
+            print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
-            print(f"No checkpoint found at '{resume_path}', starting from scratch")
+            print(f"No checkpoint found at '{ckpt}', starting from scratch")
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
 
     def run(self):
         """
         Starts the training process, iterating over epochs, training the model,
         and saving checkpoints at the end of each epoch.
         """
+        if self.resume:
+            self._resume_checkpoint(self.output_dir)
         for epoch in range(self.start_epoch, self.max_epoch + 1):
             self._train_epoch(epoch)
-            # self._validate_epoch(epoch)
+            self._validate_epoch(epoch)
             if self.rank == 0:
                 self._save_checkpoint(epoch)
-            self.scheduler.step()
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+            self.scheduler.step()
+
+        if self.rank == 0:
+            average_checkpoints(self.output_dir, self.avg_nbest_model)
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
 
         self.writer.close()
 
     def _train_epoch(self, epoch):
         """
@@ -157,8 +186,7 @@
         for batch_idx, batch in enumerate(self.dataloader_train):
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time5:0.3f}"
-            # import pdb;
-            # pdb.set_trace()
             batch = to_device(batch, self.device)
 
             my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
@@ -211,13 +239,12 @@
             speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
             speed_stats["total_time"] = total_time
-            # import pdb;
-            # pdb.set_trace()
             pbar.update(1)
             if self.local_rank == 0:
                 description = (
-                    f"Epoch: {epoch + 1}/{self.max_epoch}, "
+                    f"Epoch: {epoch}/{self.max_epoch}, "
                     f"step {batch_idx}/{len(self.dataloader_train)}, "
                     f"{speed_stats}, "
                     f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -248,6 +275,50 @@
         """
         self.model.eval()
         with torch.no_grad():
-            for data, target in self.dataloader_val:
-                # Implement the model validation steps here
-                pass
+            pbar = tqdm(colour="red", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_val),
+                        dynamic_ncols=True)
+            speed_stats = {}
+            time5 = time.perf_counter()
+            for batch_idx, batch in enumerate(self.dataloader_val):
+                time1 = time.perf_counter()
+                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
+                batch = to_device(batch, self.device)
+                time2 = time.perf_counter()
+                retval = self.model(**batch)
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                loss, stats, weight = retval
+                stats = {k: v for k, v in stats.items() if v is not None}
+                if self.use_ddp or self.use_fsdp:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+                    # if distributed, this method can also apply all_reduce()
+                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # Now weight is summation over all workers
+                    loss /= weight
+                    # Multiply world_size because DistributedDataParallel
+                    # automatically normalizes the gradient by world_size.
+                    loss *= self.world_size
+                # Scale the loss since we're not updating for every mini-batch
+                loss = loss
+                time4 = time.perf_counter()
+
+                pbar.update(1)
+                if self.local_rank == 0:
+                    description = (
+                        f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                        f"step {batch_idx}/{len(self.dataloader_train)}, "
+                        f"{speed_stats}, "
+                        f"(loss: {loss.detach().cpu().item():.3f}), "
+                        f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+                    )
+                    pbar.set_description(description)
+                    if self.writer:
+                        self.writer.add_scalar('Loss/val', loss.item(),
+                                               epoch*len(self.dataloader_train) + batch_idx)
+                        for key, var in stats.items():
+                            self.writer.add_scalar(f'{key}/val', var.item(),
+                                                   epoch * len(self.dataloader_train) + batch_idx)
+                        for key, var in speed_stats.items():
+                            self.writer.add_scalar(f'{key}/val', eval(var),
+                                                   epoch * len(self.dataloader_train) + batch_idx)
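Taken together, the checkpointing changes above imply a simple on-disk layout in output_dir; the sketch below only inspects what the code in this commit writes (the directory path reuses the finetune example and is illustrative):

import os

output_dir = "outputs/debug/ckpt/funasr2/exp2"   # same dir as the finetune example above

# Per-epoch checkpoints written by _save_checkpoint(): model.pt.ep{epoch}
epochs = sorted(f for f in os.listdir(output_dir) if f.startswith("model.pt.ep"))

# "model.pt" is a symlink that _save_checkpoint() re-points at the newest epoch;
# _resume_checkpoint(output_dir) loads exactly this file when resume is enabled,
# and average_checkpoints() adds model.pt.avg{last_n} at the end of run().
latest = os.path.join(output_dir, "model.pt")
print(epochs, os.path.realpath(latest))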