Dev gzf exp (#1678)

* resume from step

* batch

* batch

* batch
zhifu gao 2024-04-29 14:52:20 +08:00 committed by GitHub
parent b7ae3d5268
commit 11cf10e433
4 changed files with 38 additions and 21 deletions


@@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf
 import concurrent.futures
 import librosa
 import torch.distributed as dist
+from tqdm import tqdm


 def gen_jsonl_from_wav_text_list(
@@ -28,6 +29,7 @@ def gen_jsonl_from_wav_text_list(
     with open(data_file, "r") as f:
         data_file_lists = f.readlines()
+    print("")
     lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
     task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
     # import pdb;pdb.set_trace()
@@ -41,6 +43,7 @@ def gen_jsonl_from_wav_text_list(
                     i * lines_for_each_th : (i + 1) * lines_for_each_th
                 ],
                 data_type,
+                i,
             )
             for i in range(task_num)
         ]
@@ -69,11 +72,15 @@ def gen_jsonl_from_wav_text_list(
     dist.barrier()


-def parse_context_length(data_list: list, data_type: str):
+def parse_context_length(data_list: list, data_type: str, id=0):
+    pbar = tqdm(total=len(data_list), dynamic_ncols=True)
     res = {}
     for i, line in enumerate(data_list):
-        key, line = line.strip().split(maxsplit=1)
+        pbar.update(1)
+        pbar.set_description(f"cpu: {id}")
+        lines = line.strip().split(maxsplit=1)
+        key = lines[0]
+        line = lines[1] if len(lines) > 1 else ""
+        line = line.strip()
         if os.path.exists(line):
             waveform, _ = librosa.load(line, sr=16000)
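Note: the hunk above threads the worker index i (added to the executor.submit call earlier) into parse_context_length so each process can label its own tqdm bar, and it replaces the strict two-way unpack with a tolerant split, so a manifest line that carries a key but no payload no longer raises ValueError. A minimal standalone sketch of the new parsing behavior (local stand-in names, not the full file):

    # Sketch: tolerant manifest parsing with a per-worker progress bar.
    # worker_id stands in for the diff's `id` parameter.
    from tqdm import tqdm

    def parse_lines(data_list, worker_id=0):
        pbar = tqdm(total=len(data_list), dynamic_ncols=True)
        res = {}
        for line in data_list:
            pbar.update(1)
            pbar.set_description(f"cpu: {worker_id}")
            parts = line.strip().split(maxsplit=1)
            key = parts[0]
            payload = parts[1] if len(parts) > 1 else ""  # key-only lines no longer raise
            res[key] = payload.strip()
        return res

    print(parse_lines(["utt1 /data/a.wav", "utt2"]))
    # {'utt1': '/data/a.wav', 'utt2': ''}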


@@ -329,6 +329,8 @@ class SenseVoiceRWKV(nn.Module):
         stats["loss"] = torch.clone(loss.detach())
         stats["batch_size"] = batch_size
         stats["batch_size_x_frames"] = frames * batch_size
+        stats["batch_size_real_frames"] = speech_lengths.sum().item()
+        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]

         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
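Note: the two added stats make padding overhead observable per batch: batch_size_x_frames counts padded frames, batch_size_real_frames counts actual speech frames, and their difference is compute spent on padding. A toy illustration of the arithmetic (assumes frames is the padded per-utterance length, as the stat names suggest):

    import torch

    # Toy batch: three utterances padded to the longest one.
    speech_lengths = torch.tensor([480, 350, 500])   # real frames per utterance
    batch_size = speech_lengths.numel()
    frames = int(speech_lengths.max())               # padded length (assumption)

    batch_size_x_frames = frames * batch_size                        # 1500
    batch_size_real_frames = speech_lengths.sum().item()             # 1330
    padding_frames = batch_size_x_frames - batch_size_real_frames    # 170
    print(f"padding ratio: {padding_frames / batch_size_x_frames:.1%}")  # 11.3%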


@@ -2,28 +2,36 @@ import torch
 from torch.optim.lr_scheduler import _LRScheduler


 # class CustomLambdaLR(_LRScheduler):
 #     def __init__(self, optimizer, warmup_steps, last_epoch=-1):
 #         self.warmup_steps = warmup_steps
 #         super().__init__(optimizer, last_epoch)
 #
 #     def get_lr(self):
 #         if self.last_epoch < self.warmup_steps:
 #             return [
 #                 base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
 #             ]
 #         else:
 #             return [base_lr for base_lr in self.base_lrs]


 class CustomLambdaLR(_LRScheduler):
-    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+    def __init__(
+        self,
+        optimizer,
+        warmup_steps: int = 25000,
+        total_steps: int = 500000,
+        last_epoch=-1,
+        verbose=False,
+    ):
         self.warmup_steps = warmup_steps
-        super().__init__(optimizer, last_epoch)
+        self.total_steps = total_steps
+        super().__init__(optimizer, last_epoch, verbose)

     def get_lr(self):
         if self.last_epoch < self.warmup_steps:
             return [
                 base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
             ]
         else:
             return [base_lr for base_lr in self.base_lrs]


 class CustomLambdaLR(_LRScheduler):
     def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
         self.warmup_steps = train_config.warmup_steps
         self.total_steps = train_config.total_steps
         super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)

     def get_lr(self):
-        step = self._step_count
+        step = self.last_epoch + 1
         if step < self.warmup_steps:
             lr_scale = step / self.warmup_steps
         else:
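Note: the step = self.last_epoch + 1 change lines up with the "resume from step" item in the commit message. If the resume path rebuilds the scheduler and passes the saved step as last_epoch (an assumption about the training loop; that code is not in this diff), _step_count restarts at 1 while last_epoch carries the true position, so counting from last_epoch continues the schedule instead of re-running warmup. A toy check with a standalone scheduler, not the repo's class:

    import torch
    from torch.optim.lr_scheduler import _LRScheduler

    class WarmupLR(_LRScheduler):
        # Toy linear-warmup scheduler using the resume-safe counter.
        def __init__(self, optimizer, warmup_steps=25000, last_epoch=-1, verbose=False):
            self.warmup_steps = warmup_steps
            super().__init__(optimizer, last_epoch, verbose)

        def get_lr(self):
            step = self.last_epoch + 1  # survives rebuilding with last_epoch=<saved step>
            scale = min(step / self.warmup_steps, 1.0)
            return [base_lr * scale for base_lr in self.base_lrs]

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=1e-3)
    opt.param_groups[0]["initial_lr"] = 1e-3  # required by _LRScheduler when last_epoch >= 0
    sched = WarmupLR(opt, warmup_steps=25000, last_epoch=9999)  # resume near step 10000
    print(sched.get_last_lr())   # ~4e-4: warmup continues from step 10000
    print(sched._step_count)     # 1: counting from _step_count would restart warmup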


@@ -62,7 +62,7 @@ class BaseTokenizer(ABC):
             raise RuntimeError(f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list")
         self.unk_id = self.token2id[self.unk_symbol]

-    def encode(self, text):
+    def encode(self, text, **kwargs):
         tokens = self.text2tokens(text)
         text_ints = self.tokens2ids(tokens)
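Note: widening the base encode to accept **kwargs gives every tokenizer a uniform call site: callers can always pass optional arguments (a language tag, say), and implementations that do not use them simply ignore them. A hedged sketch of the pattern (illustrative classes, not FunASR's):

    class BaseTokenizer:
        def encode(self, text, **kwargs):      # extra options are simply ignored
            return [ord(c) for c in text]      # stand-in for text2tokens + tokens2ids

    class LangTokenizer(BaseTokenizer):
        def encode(self, text, lang="auto", **kwargs):
            prefix = [] if lang == "auto" else [0]   # e.g. a language-tag id
            return prefix + super().encode(text)

    for tok in (BaseTokenizer(), LangTokenizer()):
        print(tok.encode("hi", lang="en"))     # same kwargs work for both classes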