Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Funasr1.0 (#1275)
* funasr1.0 finetune
* funasr1.0 pbar
* update with main (#1260): Update websocket_protocol_zh.md; update
* update with main (#1264): merges Funasr1.0 (#1261) and a bug fix
* funasr1.0 sanm scama
* funasr1.0 infer_after_finetune
* funasr1.0 fsmn-vad bug fix
* funasr1.0 fsmn-vad bug fix
* funasr1.0 fsmn-vad bug fix
* funasr1.0 finetune
* funasr1.0 finetune
* funasr1.0 finetune

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
This commit is contained in:
parent 12496e559f
commit 2cca8104d2
@@ -11,9 +11,9 @@ python funasr/bin/train.py \
 +model_revision="v2.0.2" \
 +train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
 +valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
-++dataset_conf.batch_size=2 \
+++dataset_conf.batch_size=64 \
 ++dataset_conf.batch_type="example" \
 ++train_conf.max_epoch=2 \
+++dataset_conf.num_workers=4 \
 +output_dir="outputs/debug/ckpt/funasr2/exp2" \
-+device="cpu" \
 +debug="true"
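The `batch_size` and `num_workers` values above are hydra-style dotlist overrides: a single `+` appends a new config key, `++` appends or force-overrides one. Below is a minimal sketch of how such overrides merge into a base config, using OmegaConf (the config library underneath hydra); the base values here are illustrative, not FunASR's actual defaults.

```python
# Sketch: merging dotlist overrides such as ++dataset_conf.batch_size=64
# into a base config. Hydra strips its +/++ prefixes before the merge;
# OmegaConf only ever sees bare key=value pairs.
from omegaconf import OmegaConf

base = OmegaConf.create({"dataset_conf": {"batch_size": 2, "batch_type": "example"}})
overrides = OmegaConf.from_dotlist(
    ["dataset_conf.batch_size=64", "dataset_conf.num_workers=4"]
)
cfg = OmegaConf.merge(base, overrides)

assert cfg.dataset_conf.batch_size == 64   # override wins
assert cfg.dataset_conf.num_workers == 4   # new key appended
```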
@@ -132,7 +132,7 @@ class AutoModel:
         self.punc_kwargs = punc_kwargs
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
-        self.model_path = kwargs["model_path"]
+        self.model_path = kwargs.get("model_path", "./")
 
 
     def build_model(self, **kwargs):
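The `kwargs.get` change makes `model_path` optional: plain indexing raises `KeyError` when a caller never passes the key, while `.get` falls back to a default. A toy illustration (function name hypothetical):

```python
def resolve_model_path(**kwargs):
    # kwargs["model_path"] would raise KeyError when the key is absent;
    # .get() substitutes a fallback directory instead.
    return kwargs.get("model_path", "./")

assert resolve_model_path(model_path="/models/paraformer") == "/models/paraformer"
assert resolve_model_path() == "./"
```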
@@ -40,7 +40,7 @@ def main_hydra(kwargs: DictConfig):
 
 
 def main(**kwargs):
+    print(kwargs)
     # set random seed
     tables.print()
     set_all_random_seed(kwargs.get("seed", 0))
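`set_all_random_seed` is called with `kwargs.get("seed", 0)`, so runs are reproducible by default. As an assumption about what such a helper typically does (FunASR's actual implementation may differ), a sketch:

```python
import random

import numpy as np
import torch

def set_all_random_seed(seed: int) -> None:
    # Seed every RNG a training run touches so results are repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op when CUDA is unavailable
```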
@@ -28,7 +28,7 @@ class BatchSampler(torch.utils.data.BatchSampler):
         self.shuffle = shuffle and is_training
 
     def __len__(self):
-        return self.total_samples
+        return (self.total_samples-1) // self.batch_size + 1
 
     def set_epoch(self, epoch):
         np.random.seed(epoch)
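The `__len__` fix is the substantive change in this hunk: the sampler yields batches, not samples, so its length must be the batch count, i.e. a ceiling division of `total_samples` by `batch_size`. A quick check of the formula:

```python
def num_batches(total_samples: int, batch_size: int) -> int:
    # Ceiling division without floats: a final partial batch still counts.
    return (total_samples - 1) // batch_size + 1

assert num_batches(10, 3) == 4  # 3 full batches + 1 partial batch
assert num_batches(9, 3) == 3   # divides evenly
assert num_batches(1, 64) == 1  # a single sample is still one batch
```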
@@ -255,7 +255,6 @@ class Stats(object):
         self.waveform = None
         self.last_drop_frames = 0
 
-
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
     """
@@ -500,7 +499,6 @@ class FsmnVADStreaming(nn.Module):
         # # reset class variables and clear the dict for the next query
         # self.AllResetDetection()
         return segments
 
-
     def init_cache(self, cache: dict = {}, **kwargs):
 
@@ -147,9 +147,17 @@ class Trainer:
         for epoch in range(self.start_epoch, self.max_epoch + 1):
 
             self._train_epoch(epoch)
 
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
             self._validate_epoch(epoch)
 
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
+
             if self.rank == 0:
                 self._save_checkpoint(epoch)
 
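The added `dist.barrier()` calls keep ranks in lockstep at epoch boundaries under DDP/FSDP: no rank starts validating while another is still training, and checkpointing on rank 0 sees a consistent state. A condensed sketch of the resulting control flow (the `trainer` attributes mirror the diff; this is not the full Trainer):

```python
import torch.distributed as dist

def run_training(trainer) -> None:
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        trainer._train_epoch(epoch)
        if trainer.use_ddp or trainer.use_fsdp:
            dist.barrier()  # all ranks finish training before validation
        trainer._validate_epoch(epoch)
        if trainer.use_ddp or trainer.use_fsdp:
            dist.barrier()  # all ranks finish validation before checkpointing
        if trainer.rank == 0:
            trainer._save_checkpoint(epoch)  # only rank 0 writes to disk
```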
@@ -164,7 +172,9 @@ class Trainer:
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        self.writer.close()
+
+        if self.writer:
+            self.writer.close()
 
 
     def _train_epoch(self, epoch):
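Guarding `self.writer.close()` matters because a TensorBoard writer is usually created only on rank 0; on every other rank `self.writer` is `None` and the unguarded call would raise `AttributeError`. A sketch of the pattern (assuming rank-0-only writer creation, as is conventional):

```python
from torch.utils.tensorboard import SummaryWriter

def make_writer(rank: int, log_dir: str = "outputs/tb"):
    # Only rank 0 logs; every other rank gets None.
    return SummaryWriter(log_dir=log_dir) if rank == 0 else None

writer = make_writer(rank=1)
if writer:          # safe on every rank
    writer.close()  # an unguarded writer.close() would crash on rank != 0
```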
@@ -230,6 +240,8 @@ class Trainer:
                 continue
 
             # Execute an optimization step (update model parameters)
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
             self.optim.step()
             self.scheduler.step()
             # Clear gradients for the next accumulation stage
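The surrounding comments refer to gradient accumulation: parameters update only every `accum_grad` mini-batches, and the new barrier aligns ranks right before that update. A stripped-down sketch of the loop shape (hypothetical names, not FunASR's exact code):

```python
def train_steps(model, dataloader, optim, scheduler, accum_grad: int = 4):
    for batch_idx, batch in enumerate(dataloader):
        loss = model(**batch) / accum_grad  # scale so gradients average out
        loss.backward()                     # gradients accumulate across calls
        if (batch_idx + 1) % accum_grad != 0:
            continue  # keep accumulating until the step boundary
        optim.step()       # execute an optimization step
        scheduler.step()
        optim.zero_grad()  # clear gradients for the next accumulation stage
```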
@@ -244,7 +256,7 @@ class Trainer:
             pbar.update(1)
             if self.local_rank == 0:
                 description = (
-                    f"Epoch: {epoch}/{self.max_epoch}, "
+                    f"Train epoch: {epoch}/{self.max_epoch}, "
                     f"step {batch_idx}/{len(self.dataloader_train)}, "
                     f"{speed_stats}, "
                     f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -306,7 +318,7 @@ class Trainer:
             pbar.update(1)
             if self.local_rank == 0:
                 description = (
-                    f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                    f"validation epoch: {epoch}/{self.max_epoch}, "
                     f"step {batch_idx}/{len(self.dataloader_train)}, "
                     f"{speed_stats}, "
                     f"(loss: {loss.detach().cpu().item():.3f}), "
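Both string changes label the progress bar by phase, so interleaved train and validation output stays distinguishable; dropping the `\n` also keeps the description on a single line, as progress-bar libraries expect. A minimal sketch of the pattern (hypothetical wrapper, assuming the tqdm API that the diff's `pbar` suggests):

```python
from tqdm import tqdm

def describe_progress(dataloader, epoch: int, max_epoch: int, phase: str = "Train"):
    pbar = tqdm(total=len(dataloader))
    for batch_idx, _batch in enumerate(dataloader):
        pbar.update(1)
        # One-line, phase-prefixed label, e.g. "Train epoch: 1/2, step 0/8"
        pbar.set_description(
            f"{phase} epoch: {epoch}/{max_epoch}, step {batch_idx}/{len(dataloader)}"
        )
    pbar.close()
```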