* funasr1.0 finetune

* funasr1.0 pbar

* update with main (#1260)

* Update websocket_protocol_zh.md

* update

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* update with main (#1264)

* Funasr1.0 (#1261)

* funasr1.0 finetune

* funasr1.0 pbar

* update with main (#1260)

* Update websocket_protocol_zh.md

* update

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* bug fix

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* funasr1.0 sanm scama

* funasr1.0 infer_after_finetune

* funasr1.0 fsmn-vad bug fix

* funasr1.0 fsmn-vad bug fix

* funasr1.0 fsmn-vad bug fix

* funasr1.0 finetune

* funasr1.0 finetune

* funasr1.0 finetune

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
Author: zhifu gao, 2024-01-19 17:05:08 +08:00 (committed by GitHub)
parent 12496e559f
commit 2cca8104d2
6 changed files with 21 additions and 11 deletions


@@ -11,9 +11,9 @@ python funasr/bin/train.py \
 +model_revision="v2.0.2" \
 +train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
 +valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
-++dataset_conf.batch_size=2 \
+++dataset_conf.batch_size=64 \
 ++dataset_conf.batch_type="example" \
 ++train_conf.max_epoch=2 \
+++dataset_conf.num_workers=4 \
 +output_dir="outputs/debug/ckpt/funasr2/exp2" \
-+device="cpu" \
 +debug="true"
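The +key=value and ++key.subkey=value prefixes in this script are Hydra override syntax (main_hydra below takes a DictConfig): a single + adds a key absent from the base config, while ++ adds or overrides. A minimal sketch of how such overrides merge, using OmegaConf, the config library underneath Hydra; the keys and values here mirror the diff and are otherwise illustrative:

from omegaconf import OmegaConf

# Base config before overrides (values taken from the old side of the diff).
base = OmegaConf.create({"dataset_conf": {"batch_size": 2},
                         "train_conf": {"max_epoch": 2}})
# Dotlist overrides, analogous to ++dataset_conf.batch_size=64 and
# ++dataset_conf.num_workers=4 on the command line.
overrides = OmegaConf.from_dotlist(["dataset_conf.batch_size=64",
                                    "dataset_conf.num_workers=4"])
merged = OmegaConf.merge(base, overrides)  # later config wins, new keys added
assert merged.dataset_conf.batch_size == 64
assert merged.dataset_conf.num_workers == 4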


@@ -132,7 +132,7 @@ class AutoModel:
         self.punc_kwargs = punc_kwargs
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
-        self.model_path = kwargs["model_path"]
+        self.model_path = kwargs.get("model_path", "./")

     def build_model(self, **kwargs):
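The AutoModel change swaps a hard failure for a default: dict.get returns the fallback when the key is missing, whereas subscripting raises KeyError. A two-line illustration of why the old line broke when a caller supplied no model_path:

kwargs = {}  # caller supplied no model_path
# kwargs["model_path"] would raise KeyError here
model_path = kwargs.get("model_path", "./")  # falls back to the current dir
assert model_path == "./"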


@@ -40,7 +40,7 @@ def main_hydra(kwargs: DictConfig):
 def main(**kwargs):
+    print(kwargs)
     # set random seed
     tables.print()
     set_all_random_seed(kwargs.get("seed", 0))
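set_all_random_seed(kwargs.get("seed", 0)) pins every random number generator the training stack touches, so a rerun with the same seed reproduces batch order and weight init. The helper itself is not part of this diff; a typical implementation of such a helper (an assumption, not verbatim FunASR code) looks like:

import random

import numpy as np
import torch

def set_all_random_seed(seed: int) -> None:
    # Seed Python, NumPy, and Torch (CPU plus all visible GPUs) alike.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable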


@@ -28,7 +28,7 @@ class BatchSampler(torch.utils.data.BatchSampler):
         self.shuffle = shuffle and is_training

     def __len__(self):
-        return self.total_samples
+        return (self.total_samples-1) // self.batch_size + 1

     def set_epoch(self, epoch):
         np.random.seed(epoch)
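The BatchSampler fix makes __len__ report the number of batches rather than the number of samples; (n - 1) // b + 1 is integer ceiling division, so a final partial batch is still counted. For instance:

import math

def num_batches(total_samples: int, batch_size: int) -> int:
    # Equivalent to math.ceil(total_samples / batch_size), without floats.
    return (total_samples - 1) // batch_size + 1

assert num_batches(100, 64) == math.ceil(100 / 64) == 2
assert num_batches(128, 64) == 2  # exact multiple: no extra batch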


@@ -255,7 +255,6 @@ class Stats(object):
         self.waveform = None
         self.last_drop_frames = 0

-
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
     """
@@ -500,7 +499,6 @@ class FsmnVADStreaming(nn.Module):
         # # reset class variables and clear the dict for the next query
         # self.AllResetDetection()
         return segments

-
     def init_cache(self, cache: dict = {}, **kwargs):


@@ -147,9 +147,17 @@ class Trainer:
         for epoch in range(self.start_epoch, self.max_epoch + 1):
             self._train_epoch(epoch)
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
             self._validate_epoch(epoch)
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
             if self.rank == 0:
                 self._save_checkpoint(epoch)
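The added dist.barrier() calls keep all ranks in lockstep at epoch boundaries: no rank starts validation, and rank 0 does not start checkpointing, until every rank has finished the preceding phase. A minimal guarded-barrier sketch; the availability guard here is an assumption, the diff relies on self.use_ddp / self.use_fsdp instead:

import torch.distributed as dist

def sync_ranks() -> None:
    # Block until every process in the default group reaches this point;
    # skip quietly in single-process runs where dist was never initialized.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()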
@@ -164,7 +172,9 @@ class Trainer:
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        self.writer.close()
+
+        if self.writer:
+            self.writer.close()

     def _train_epoch(self, epoch):
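Guarding self.writer.close() matters under the common pattern where only rank 0 owns a TensorBoard SummaryWriter and every other rank holds None; that rank-0-only setup is an assumption here, not shown in the diff:

from torch.utils.tensorboard import SummaryWriter

rank = 0  # stand-in for the trainer's rank attribute
# Assumption: only rank 0 creates a writer; other ranks keep None.
writer = SummaryWriter("outputs/tb") if rank == 0 else None
# An unguarded writer.close() would raise AttributeError on non-zero ranks.
if writer:
    writer.close()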
@@ -230,6 +240,8 @@ class Trainer:
                         continue

                 # Execute an optimization step (update model parameters)
+                if self.use_ddp or self.use_fsdp:
+                    dist.barrier()
                 self.optim.step()
                 self.scheduler.step()
                 # Clear gradients for the next accumulation stage
@@ -244,7 +256,7 @@ class Trainer:
                 pbar.update(1)
                 if self.local_rank == 0:
                     description = (
-                        f"Epoch: {epoch}/{self.max_epoch}, "
+                        f"Train epoch: {epoch}/{self.max_epoch}, "
                         f"step {batch_idx}/{len(self.dataloader_train)}, "
                         f"{speed_stats}, "
                         f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -306,7 +318,7 @@ class Trainer:
                 pbar.update(1)
                 if self.local_rank == 0:
                     description = (
-                        f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                        f"validation epoch: {epoch}/{self.max_epoch}, "
                         f"step {batch_idx}/{len(self.dataloader_train)}, "
                         f"{speed_stats}, "
                         f"(loss: {loss.detach().cpu().item():.3f}), "