From 11cf10e433c173efd892766b669e0bba57253fed Mon Sep 17 00:00:00 2001
From: zhifu gao
Date: Mon, 29 Apr 2024 14:52:20 +0800
Subject: [PATCH] Dev gzf exp (#1678)

* resume from step

* batch

* batch

* batch
---
 funasr/datasets/audio_datasets/scp2jsonl.py | 13 +++++--
 funasr/models/sense_voice/model.py          |  2 +
 funasr/schedulers/lambdalr_cus.py           | 42 ++++++++++++---------
 funasr/tokenizer/abs_tokenizer.py           |  2 +-
 4 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
index f6ceb6977..f16717301 100644
--- a/funasr/datasets/audio_datasets/scp2jsonl.py
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf
 import concurrent.futures
 import librosa
 import torch.distributed as dist
+from tqdm import tqdm
 
 
 def gen_jsonl_from_wav_text_list(
@@ -28,6 +29,7 @@ def gen_jsonl_from_wav_text_list(
         with open(data_file, "r") as f:
             data_file_lists = f.readlines()
 
+        print("")
         lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
         task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
         # import pdb;pdb.set_trace()
@@ -41,6 +43,7 @@ def gen_jsonl_from_wav_text_list(
                         i * lines_for_each_th : (i + 1) * lines_for_each_th
                     ],
                     data_type,
+                    i,
                 )
                 for i in range(task_num)
             ]
@@ -69,11 +72,15 @@ def gen_jsonl_from_wav_text_list(
         dist.barrier()
 
 
-def parse_context_length(data_list: list, data_type: str):
-
+def parse_context_length(data_list: list, data_type: str, id=0):
+    pbar = tqdm(total=len(data_list), dynamic_ncols=True)
     res = {}
     for i, line in enumerate(data_list):
-        key, line = line.strip().split(maxsplit=1)
+        pbar.update(1)
+        pbar.set_description(f"cpu: {id}")
+        lines = line.strip().split(maxsplit=1)
+        key = lines[0]
+        line = lines[1] if len(lines) > 1 else ""
         line = line.strip()
         if os.path.exists(line):
             waveform, _ = librosa.load(line, sr=16000)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 07fb4eb58..ae20902cf 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -329,6 +329,8 @@ class SenseVoiceRWKV(nn.Module):
         stats["loss"] = torch.clone(loss.detach())
         stats["batch_size"] = batch_size
         stats["batch_size_x_frames"] = frames * batch_size
+        stats["batch_size_real_frames"] = speech_lengths.sum().item()
+        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 19ad7a8a7..e3bb1fb47 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -2,28 +2,36 @@ import torch
 from torch.optim.lr_scheduler import _LRScheduler
 
 
+# class CustomLambdaLR(_LRScheduler):
+#     def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+#         self.warmup_steps = warmup_steps
+#         super().__init__(optimizer, last_epoch)
+#
+#     def get_lr(self):
+#         if self.last_epoch < self.warmup_steps:
+#             return [
+#                 base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
+#             ]
+#         else:
+#             return [base_lr for base_lr in self.base_lrs]
+
+
 class CustomLambdaLR(_LRScheduler):
-    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+    def __init__(
+        self,
+        optimizer,
+        warmup_steps: int = 25000,
+        total_steps: int = 500000,
+        last_epoch=-1,
+        verbose=False,
+    ):
         self.warmup_steps = warmup_steps
-        super().__init__(optimizer, last_epoch)
+        self.total_steps = total_steps
+        super().__init__(optimizer, last_epoch, verbose)
 
     def get_lr(self):
-        if self.last_epoch < self.warmup_steps:
-            return [
-                base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
-            ]
-        else:
-            return [base_lr for base_lr in self.base_lrs]
-
-class CustomLambdaLR(_LRScheduler):
-    def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
-        self.warmup_steps = train_config.warmup_steps
-        self.total_steps = train_config.total_steps
-        super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
-
-    def get_lr(self):
-        step = self._step_count
+        step = self.last_epoch + 1
         if step < self.warmup_steps:
             lr_scale = step / self.warmup_steps
         else:
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index a629e94ff..e125d292b 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -62,7 +62,7 @@ class BaseTokenizer(ABC):
             raise RuntimeError(f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list")
         self.unk_id = self.token2id[self.unk_symbol]
 
-    def encode(self, text):
+    def encode(self, text, **kwargs):
         tokens = self.text2tokens(text)
         text_ints = self.tokens2ids(tokens)
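
Not part of the patch: a minimal usage sketch of the reworked CustomLambdaLR above. It assumes the unchanged remainder of get_lr() (outside the hunk) still scales each base_lr by lr_scale as in the pre-existing code, so only the linear-warmup branch is exercised; the model and optimizer are stand-ins chosen for illustration.

# Illustrative only; nn.Linear and AdamW are placeholder choices.
import torch
from funasr.schedulers.lambdalr_cus import CustomLambdaLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = CustomLambdaLR(optimizer, warmup_steps=25000, total_steps=500000)

for step in range(3):
    optimizer.step()                      # one training step (gradients omitted in this sketch)
    scheduler.step()                      # advances last_epoch; get_lr() now reads last_epoch + 1
    print(step, scheduler.get_last_lr())  # during warmup the lr grows linearly toward the base lr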