Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Funasr1.0 (#1261)
* funasr1.0 finetune
* funasr1.0 pbar
* update with main (#1260)
  * Update websocket_protocol_zh.md
  * update

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
parent b1857837dd
commit 9a9c3b75b5
(In the hunk below, only the first column is the diff marker; the "+" and "++" inside the command itself are config-override CLI syntax, Hydra-style: "+key=..." adds a key, "++key=..." adds or overrides one.)

@@ -9,9 +9,11 @@
 python funasr/bin/train.py \
 +model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
 +model_revision="v2.0.2" \
-+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
++train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
++valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
 ++dataset_conf.batch_size=2 \
 ++dataset_conf.batch_type="example" \
+++train_conf.max_epoch=2 \
 +output_dir="outputs/debug/ckpt/funasr2/exp2" \
 +device="cpu" \
 +debug="true"
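For orientation, these overrides roughly correspond to the keyword arguments consumed by the main(**kwargs) entry point in the funasr/bin/train.py hunk further down. A sketch of that kwargs dict, assuming the CLI parser produces nested dicts for dotted keys (an assumption, not confirmed by this diff):

    # Rough kwargs equivalent of the command above; illustrative only.
    kwargs = {
        "model": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        "model_revision": "v2.0.2",
        "train_data_set_list": "/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl",
        "valid_data_set_list": "/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl",
        "dataset_conf": {"batch_size": 2, "batch_type": "example"},
        "train_conf": {"max_epoch": 2},
        "output_dir": "outputs/debug/ckpt/funasr2/exp2",
        "device": "cpu",
        "debug": "true",
    }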
@@ -15,6 +15,6 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
     spk_model_revision="v2.0.2",
 )
 
-res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                      hotword='达摩院 魔搭')
 print(res)
@@ -221,7 +221,8 @@ class AutoModel:
         speed_stats = {}
         asr_result_list = []
         num_samples = len(data_list)
-        pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True)
+        disable_pbar = kwargs.get("disable_pbar", False)
+        pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True) if not disable_pbar else None
         time_speech_total = 0.0
         time_escape_total = 0.0
         for beg_idx in range(0, num_samples, batch_size):
@@ -239,8 +240,7 @@ class AutoModel:
             time2 = time.perf_counter()
 
             asr_result_list.extend(results)
-            pbar.update(1)
 
             # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
             batch_data_time = meta_data.get("batch_data_time", -1)
             time_escape = time2 - time1
@@ -252,12 +252,15 @@ class AutoModel:
             description = (
                 f"{speed_stats}, "
             )
-            pbar.set_description(description)
+            if pbar:
+                pbar.update(1)
+                pbar.set_description(description)
             time_speech_total += batch_data_time
             time_escape_total += time_escape
 
-        pbar.update(1)
-        pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+        if pbar:
+            pbar.update(1)
+            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
         torch.cuda.empty_cache()
         return asr_result_list
 
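Taken together, these two hunks make the progress bar optional. A minimal usage sketch, assuming generate() forwards extra keyword arguments through to inference() (which the disable_pbar=True call later in this commit suggests), with the model name taken from the training example above:

    from funasr import AutoModel  # assumed import path, matching FunASR's demo code

    model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                      model_revision="v2.0.2")
    res = model.generate(
        input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
        disable_pbar=True,  # new in this commit: inference() then never constructs tqdm
    )
    print(res)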
@@ -309,8 +312,11 @@ class AutoModel:
             time_speech_total_per_sample = speech_lengths/16000
             time_speech_total_all_samples += time_speech_total_per_sample
 
+            pbar_sample = tqdm(colour="blue", total=n + 1, dynamic_ncols=True)
+
             all_segments = []
             for j, _ in enumerate(range(0, n)):
+                pbar_sample.update(1)
                 batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                 if j < n - 1 and (
                         batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
@@ -319,13 +325,14 @@ class AutoModel:
                     batch_size_ms_cum = 0
                     end_idx = j + 1
                     speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
-                    results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
+                    results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg)
                     if self.spk_model is not None:
 
                         # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                         for _b in range(len(speech_j)):
-                            vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \
-                                             sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \
+                            vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0,
+                                             sorted_data[beg_idx:end_idx][_b][0][1]/1000.0,
                                              speech_j[_b]]]
                             segments = sv_chunk(vad_segments)
                             all_segments.extend(segments)
@@ -338,12 +345,13 @@ class AutoModel:
             results_sorted.extend(results)
 
 
-            pbar_total.update(1)
             end_asr_total = time.time()
             time_escape_total_per_sample = end_asr_total - beg_asr_total
-            pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+            pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                                         f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
                                         f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
 
 
             restored_data = [0] * n
             for j in range(n):
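For reference, the rtf_avg these progress bars report is a real-time factor: wall-clock decoding time divided by the duration of the audio decoded. A toy computation with made-up numbers:

    time_escape_total = 5.0    # illustrative: seconds spent decoding
    time_speech_total = 100.0  # illustrative: seconds of audio processed
    print(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")  # 0.050, i.e. 20x faster than real time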
@@ -141,30 +141,37 @@ def main(**kwargs):
     scheduler_class = scheduler_classes.get(scheduler)
     scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
 
-    # import pdb;
-    # pdb.set_trace()
     # dataset
     dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
     dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
+    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
+                                **kwargs.get("dataset_conf"))
 
     # dataloader
     batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+    batch_sampler_val = None
     if batch_sampler is not None:
+        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
         batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
     dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                 collate_fn=dataset_tr.collator,
                                                 batch_sampler=batch_sampler,
                                                 num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                                 pin_memory=True)
 
+    dataloader_val = torch.utils.data.DataLoader(dataset_val,
+                                                 collate_fn=dataset_val.collator,
+                                                 batch_sampler=batch_sampler_val,
+                                                 num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
+                                                 pin_memory=True)
     trainer = Trainer(
         model=model,
         optim=optim,
         scheduler=scheduler,
         dataloader_train=dataloader_tr,
-        dataloader_val=None,
+        dataloader_val=dataloader_val,
         local_rank=local_rank,
         use_ddp=use_ddp,
         use_fsdp=use_fsdp,
@@ -54,7 +54,11 @@ class IndexDSJsonl(torch.utils.data.Dataset):
         return len(self.contents)
 
     def __getitem__(self, index):
-        return self.contents[index]
+        try:
+            return self.contents[index]
+        except IndexError:
+            print(index)
+            raise
 
     def get_source_len(self, data_dict):
         return data_dict["source_len"]
@@ -13,6 +13,7 @@ class BatchSampler(torch.utils.data.BatchSampler):
                  buffer_size: int = 30,
                  drop_last: bool = False,
                  shuffle: bool = True,
+                 is_training: bool = True,
                  **kwargs):
 
         self.drop_last = drop_last
@@ -24,7 +25,7 @@ class BatchSampler(torch.utils.data.BatchSampler):
         self.buffer_size = buffer_size
         self.max_token_length = kwargs.get("max_token_length", 5000)
         self.shuffle_idx = np.arange(self.total_samples)
-        self.shuffle = shuffle
+        self.shuffle = shuffle and is_training
 
     def __len__(self):
         return self.total_samples
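The effect of the new is_training flag is that a validation sampler keeps a deterministic order even when shuffle=True is passed. A standalone toy sketch of just that logic (not the real BatchSampler):

    import numpy as np

    class TinySampler:
        def __init__(self, total_samples, shuffle=True, is_training=True):
            self.shuffle_idx = np.arange(total_samples)
            # Validation samplers stay deterministic regardless of `shuffle`.
            self.shuffle = shuffle and is_training

        def __iter__(self):
            if self.shuffle:
                np.random.shuffle(self.shuffle_idx)
            return iter(self.shuffle_idx.tolist())

    print(list(TinySampler(5, is_training=False)))  # always [0, 1, 2, 3, 4]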
@@ -164,6 +164,7 @@ class Paraformer(torch.nn.Module):
         self.use_1st_decoder_loss = use_1st_decoder_loss
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
+        self.error_calculator = None
 
     def forward(
         self,
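Initializing self.error_calculator to None mirrors self.beam_search just above it: both are placeholders for helpers constructed lazily on first use. A generic, runnable sketch of that pattern (the holder class and factory here are illustrative, not from this diff):

    class LazyHolder:
        def __init__(self):
            self.error_calculator = None  # placeholder, as in Paraformer.__init__

        def get_error_calculator(self, factory):
            # Build the helper on first use; reuse it afterwards.
            if self.error_calculator is None:
                self.error_calculator = factory()
            return self.error_calculator

    holder = LazyHolder()
    calc = holder.get_error_calculator(lambda: object())      # built once
    assert calc is holder.get_error_calculator(lambda: object())  # cached thereafter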
@@ -95,6 +95,7 @@ train_conf:
     - acc
     - max
   keep_nbest_models: 10
+  avg_nbest_model: 5
   log_interval: 50
 
 optim: adam
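The new avg_nbest_model key is consumed by the Trainer later in this commit (kwargs.get("avg_nbest_model", 5)) and passed to average_checkpoints at the end of run(). A minimal sketch of reading it from a config of this shape (PyYAML and the file name are assumptions):

    import yaml  # assumption: these config files are YAML

    with open("config.yaml") as f:
        conf = yaml.safe_load(f)
    last_n = conf["train_conf"].get("avg_nbest_model", 5)  # same default the Trainer uses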
@@ -9,117 +9,173 @@ from io import BytesIO
 
 import torch
 from typing import Collection
+import os
+import torch
+import re
+from collections import OrderedDict
+from functools import cmp_to_key
 
-from funasr.train.reporter import Reporter
 
[~100 added lines omitted here: the previous average_nbest_models implementation, kept as a commented-out block identical to the deleted function below.]

+def _get_checkpoint_paths(output_dir: str, last_n: int=5):
+    """
+    Get the paths of the last 'last_n' checkpoints by parsing filenames
+    in the output directory.
+    """
+    # List all files in the output directory
+    files = os.listdir(output_dir)
+    # Filter out checkpoint files and extract epoch numbers
+    checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
+    # Sort files by epoch number in descending order
+    checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
+    # Get the last 'last_n' checkpoint paths
+    checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
+    return checkpoint_paths
+
 @torch.no_grad()
-def average_nbest_models(
-    output_dir: Path,
-    reporter: Reporter,
-    best_model_criterion: Sequence[Sequence[str]],
-    nbest: Union[Collection[int], int],
-    suffix: Optional[str] = None,
-    oss_bucket=None,
-    pai_output_dir=None,
-) -> None:
-    """Generate averaged model from n-best models
-
-    Args:
-        output_dir: The directory contains the model file for each epoch
-        reporter: Reporter instance
-        best_model_criterion: Give criterions to decide the best model.
-            e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
-        nbest: Number of best model files to be averaged
-        suffix: A suffix added to the averaged model file name
-    """
-    if isinstance(nbest, int):
-        nbests = [nbest]
-    else:
-        nbests = list(nbest)
-    if len(nbests) == 0:
-        warnings.warn("At least 1 nbest values are required")
-        nbests = [1]
-    if suffix is not None:
-        suffix = suffix + "."
-    else:
-        suffix = ""
-
-    # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
-    nbest_epochs = [
-        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
-        for ph, k, m in best_model_criterion
-        if reporter.has(ph, k)
-    ]
-
-    _loaded = {}
-    for ph, cr, epoch_and_values in nbest_epochs:
-        _nbests = [i for i in nbests if i <= len(epoch_and_values)]
-        if len(_nbests) == 0:
-            _nbests = [1]
-
-        for n in _nbests:
-            if n == 0:
-                continue
-            elif n == 1:
-                # The averaged model is same as the best model
-                e, _ = epoch_and_values[0]
-                op = output_dir / f"{e}epoch.pb"
-                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
-                if sym_op.is_symlink() or sym_op.exists():
-                    sym_op.unlink()
-                sym_op.symlink_to(op.name)
-            else:
-                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
-                logging.info(
-                    f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
-                )
-
-                avg = None
-                # 2.a. Averaging model
-                for e, _ in epoch_and_values[:n]:
-                    if e not in _loaded:
-                        if oss_bucket is None:
-                            _loaded[e] = torch.load(
-                                output_dir / f"{e}epoch.pb",
-                                map_location="cpu",
-                            )
-                        else:
-                            buffer = BytesIO(
-                                oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
-                            _loaded[e] = torch.load(buffer)
-                    states = _loaded[e]
-
-                    if avg is None:
-                        avg = states
-                    else:
-                        # Accumulated
-                        for k in avg:
-                            avg[k] = avg[k] + states[k]
-                for k in avg:
-                    if str(avg[k].dtype).startswith("torch.int"):
-                        # For int type, not averaged, but only accumulated.
-                        # e.g. BatchNorm.num_batches_tracked
-                        # (If there are any cases that requires averaging
-                        # or the other reducing method, e.g. max/min, for integer type,
-                        # please report.)
-                        pass
-                    else:
-                        avg[k] = avg[k] / n
-
-                # 2.b. Save the ave model and create a symlink
-                if oss_bucket is None:
-                    torch.save(avg, op)
-                else:
-                    buffer = BytesIO()
-                    torch.save(avg, buffer)
-                    oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
-                                          buffer.getvalue())
-
-                # 3. *.*.ave.pb is a symlink to the max ave model
-                if oss_bucket is None:
-                    op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
-                    sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
-                    if sym_op.is_symlink() or sym_op.exists():
-                        sym_op.unlink()
-                    sym_op.symlink_to(op.name)
+def average_checkpoints(output_dir: str, last_n: int=5):
+    """
+    Average the last 'last_n' checkpoints' model state_dicts.
+    If a tensor is of type torch.int, perform sum instead of average.
+    """
+    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
+    state_dicts = []
+
+    # Load state_dicts from checkpoints
+    for path in checkpoint_paths:
+        if os.path.isfile(path):
+            state_dicts.append(torch.load(path, map_location='cpu')['state_dict'])
+        else:
+            print(f"Checkpoint file {path} not found.")
+            continue
+
+    # Check if we have any state_dicts to average
+    if not state_dicts:
+        raise RuntimeError("No checkpoints found for averaging.")
+
+    # Average or sum weights
+    avg_state_dict = OrderedDict()
+    for key in state_dicts[0].keys():
+        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
+        # Check the type of the tensor
+        if str(tensors[0].dtype).startswith("torch.int"):
+            # Perform sum for integer tensors
+            summed_tensor = sum(tensors)
+            avg_state_dict[key] = summed_tensor
+        else:
+            # Perform average for other types of tensors
+            stacked_tensors = torch.stack(tensors)
+            avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
+
+    torch.save({'state_dict': avg_state_dict}, os.path.join(output_dir, f"model.pt.avg{last_n}"))
+    return avg_state_dict
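A usage sketch for the new helper (the directory is the output_dir from the training example above; checkpoint names follow the model.pt.ep{epoch} pattern that _save_checkpoint writes below, which matches the model.pt.e prefix filter in _get_checkpoint_paths):

    from funasr.train_utils.average_nbest_models import average_checkpoints

    # Given model.pt.ep1 ... model.pt.ep10 in the directory, this averages the five
    # checkpoints with the highest epoch numbers and writes model.pt.avg5 alongside them.
    avg_state = average_checkpoints("outputs/debug/ckpt/funasr2/exp2", last_n=5)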
@@ -7,10 +7,11 @@ import torch.distributed as dist
 from contextlib import nullcontext
 # from torch.utils.tensorboard import SummaryWriter
 from tensorboardX import SummaryWriter
+from pathlib import Path
 
 from funasr.train_utils.device_funcs import to_device
 from funasr.train_utils.recursive_op import recursive_average
+from funasr.train_utils.average_nbest_models import average_checkpoints
 
 class Trainer:
     """
@@ -66,10 +67,9 @@ class Trainer:
         self.use_ddp = use_ddp
         self.use_fsdp = use_fsdp
         self.device = next(model.parameters()).device
+        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
         self.kwargs = kwargs
 
-        if self.resume:
-            self._resume_checkpoint(self.resume)
 
         try:
             rank = dist.get_rank()
@@ -102,9 +102,17 @@ class Trainer:
         }
         # Create output directory if it does not exist
         os.makedirs(self.output_dir, exist_ok=True)
-        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
+        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
         torch.save(state, filename)
 
         print(f'Checkpoint saved to {filename}')
+        latest = Path(os.path.join(self.output_dir, 'model.pt'))
+        try:
+            latest.unlink()
+        except FileNotFoundError:
+            pass
+
+        latest.symlink_to(filename)
 
     def _resume_checkpoint(self, resume_path):
         """
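After a couple of epochs, the checkpoint directory therefore looks roughly like this (illustrative listing):

    outputs/debug/ckpt/funasr2/exp2/
        model.pt.ep1
        model.pt.ep2
        model.pt -> model.pt.ep2   (symlink, refreshed on every save)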
@@ -114,29 +122,50 @@ class Trainer:
         Args:
             resume_path (str): The file path to the checkpoint to resume from.
         """
-        if os.path.isfile(resume_path):
-            checkpoint = torch.load(resume_path)
+        ckpt = os.path.join(resume_path, "model.pt")
+        if os.path.isfile(ckpt):
+            checkpoint = torch.load(ckpt)
             self.start_epoch = checkpoint['epoch'] + 1
             self.model.load_state_dict(checkpoint['state_dict'])
             self.optim.load_state_dict(checkpoint['optimizer'])
             self.scheduler.load_state_dict(checkpoint['scheduler'])
-            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
+            print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
-            print(f"No checkpoint found at '{resume_path}', starting from scratch")
+            print(f"No checkpoint found at '{ckpt}', starting from scratch")
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
 
     def run(self):
         """
         Starts the training process, iterating over epochs, training the model,
         and saving checkpoints at the end of each epoch.
         """
+        if self.resume:
+            self._resume_checkpoint(self.output_dir)
+
         for epoch in range(self.start_epoch, self.max_epoch + 1):
+
             self._train_epoch(epoch)
-            # self._validate_epoch(epoch)
+            self._validate_epoch(epoch)
+
             if self.rank == 0:
                 self._save_checkpoint(epoch)
-            self.scheduler.step()
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
+            self.scheduler.step()
+
+        if self.rank == 0:
+            average_checkpoints(self.output_dir, self.avg_nbest_model)
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
         self.writer.close()
 
     def _train_epoch(self, epoch):
         """
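The reordered run() loop follows a standard distributed pattern: only rank 0 writes the checkpoint, then every rank blocks on a barrier before proceeding. A condensed, standalone sketch of that pattern (function and argument names here are illustrative, and an initialized process group is assumed):

    import torch.distributed as dist

    def save_and_sync(trainer, epoch, distributed=False):
        # Rank 0 writes model.pt.ep{epoch} and refreshes the model.pt symlink...
        if not distributed or dist.get_rank() == 0:
            trainer._save_checkpoint(epoch)
        # ...and no rank moves on until the checkpoint exists on shared storage.
        if distributed:
            dist.barrier()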
@@ -157,8 +186,7 @@ class Trainer:
         for batch_idx, batch in enumerate(self.dataloader_train):
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time5:0.3f}"
-            # import pdb;
-            # pdb.set_trace()
             batch = to_device(batch, self.device)
 
             my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
@@ -211,13 +239,12 @@ class Trainer:
             speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
 
             speed_stats["total_time"] = total_time
 
-            # import pdb;
-            # pdb.set_trace()
             pbar.update(1)
             if self.local_rank == 0:
                 description = (
-                    f"Epoch: {epoch + 1}/{self.max_epoch}, "
+                    f"Epoch: {epoch}/{self.max_epoch}, "
                     f"step {batch_idx}/{len(self.dataloader_train)}, "
                     f"{speed_stats}, "
                     f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -248,6 +275,48 @@ class Trainer:
         """
         self.model.eval()
         with torch.no_grad():
-            for data, target in self.dataloader_val:
-                # Implement the model validation steps here
-                pass
+            pbar = tqdm(colour="red", desc=f"Validation Epoch: {epoch + 1}", total=len(self.dataloader_val),
+                        dynamic_ncols=True)
+            speed_stats = {}
+            time5 = time.perf_counter()
+            for batch_idx, batch in enumerate(self.dataloader_val):
+                time1 = time.perf_counter()
+                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
+                batch = to_device(batch, self.device)
+                time2 = time.perf_counter()
+                retval = self.model(**batch)
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                loss, stats, weight = retval
+                stats = {k: v for k, v in stats.items() if v is not None}
+                if self.use_ddp or self.use_fsdp:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+                    # if distributed, this method can also apply all_reduce()
+                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # Now weight is summation over all workers
+                    loss /= weight
+                    # Multiply world_size because DistributedDataParallel
+                    # automatically normalizes the gradient by world_size.
+                    loss *= self.world_size
+                time4 = time.perf_counter()
+
+                pbar.update(1)
+                if self.local_rank == 0:
+                    description = (
+                        f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                        f"step {batch_idx}/{len(self.dataloader_val)}, "
+                        f"{speed_stats}, "
+                        f"(loss: {loss.detach().cpu().item():.3f}), "
+                        f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+                    )
+                    pbar.set_description(description)
+                    if self.writer:
+                        self.writer.add_scalar('Loss/val', loss.item(),
+                                               epoch*len(self.dataloader_train) + batch_idx)
+                        for key, var in stats.items():
+                            self.writer.add_scalar(f'{key}/val', var.item(),
+                                                   epoch * len(self.dataloader_train) + batch_idx)
+                        for key, var in speed_stats.items():
+                            self.writer.add_scalar(f'{key}/val', eval(var),
+                                                   epoch * len(self.dataloader_train) + batch_idx)