Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
Commit 11cf10e433 (parent b7ae3d5268)
@@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf
 import concurrent.futures
 import librosa
 import torch.distributed as dist
+from tqdm import tqdm


 def gen_jsonl_from_wav_text_list(
@@ -28,6 +29,7 @@ def gen_jsonl_from_wav_text_list(
     with open(data_file, "r") as f:

         data_file_lists = f.readlines()
+    print("")
     lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
     task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
     # import pdb;pdb.set_trace()
@@ -41,6 +43,7 @@ def gen_jsonl_from_wav_text_list(
                     i * lines_for_each_th : (i + 1) * lines_for_each_th
                 ],
                 data_type,
+                i,
             )
             for i in range(task_num)
         ]
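For readers following the data-prep hunks above, here is a minimal, self-contained sketch of how the wav/text list appears to be split into per-CPU chunks and dispatched, with the chunk index passed through so each worker can label its own progress bar. The slicing expression, cpu_cores, data_type, and the extra i argument mirror the diff; the ProcessPoolExecutor choice, the parse_chunk placeholder, and the result merging are assumptions, not code from the commit. (The sketch divides by task_num rather than cpu_cores so the single-worker case still covers every line.)

import concurrent.futures

def parse_chunk(chunk, data_type, id=0):
    # Placeholder for parse_context_length-style work; data_type is unused here.
    return {line.split(maxsplit=1)[0]: len(line) for line in chunk if line.strip()}

def dispatch(data_file_lists, data_type, cpu_cores=8):
    task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
    lines_for_each_th = (len(data_file_lists) - 1) // task_num + 1
    res = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=task_num) as executor:
        futures = [
            executor.submit(
                parse_chunk,
                data_file_lists[i * lines_for_each_th : (i + 1) * lines_for_each_th],
                data_type,
                i,  # chunk index, so each worker can tag its progress bar
            )
            for i in range(task_num)
        ]
        for f in concurrent.futures.as_completed(futures):
            res.update(f.result())
    return res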
@@ -69,11 +72,15 @@ def gen_jsonl_from_wav_text_list(
     dist.barrier()


-def parse_context_length(data_list: list, data_type: str):
+def parse_context_length(data_list: list, data_type: str, id=0):
+    pbar = tqdm(total=len(data_list), dynamic_ncols=True)
     res = {}
     for i, line in enumerate(data_list):
-        key, line = line.strip().split(maxsplit=1)
+        pbar.update(1)
+        pbar.set_description(f"cpu: {id}")
+        lines = line.strip().split(maxsplit=1)
+        key = lines[0]
+        line = lines[1] if len(lines) > 1 else ""
         line = line.strip()
         if os.path.exists(line):
            waveform, _ = librosa.load(line, sr=16000)
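A hedged sketch of what the reworked parsing in parse_context_length accomplishes: splitting with maxsplit=1 and defaulting the text to an empty string tolerates lines that carry only a key, which the old key, line = line.strip().split(maxsplit=1) rejects with a ValueError. The 16 kHz load mirrors the librosa call above; the parse_entry name, the empty-line guard, and the fallback return value are assumptions.

import os
import librosa

def parse_entry(raw_line: str, sr: int = 16000):
    lines = raw_line.strip().split(maxsplit=1)
    if not lines:                       # empty line: nothing to parse
        return None, None
    key = lines[0]
    line = lines[1].strip() if len(lines) > 1 else ""
    if os.path.exists(line):
        waveform, _ = librosa.load(line, sr=sr)
        return key, len(waveform) / sr  # audio duration in seconds
    return key, line                    # not a path: keep it as plain text (e.g., a transcript)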
@@ -329,6 +329,8 @@ class SenseVoiceRWKV(nn.Module):
         stats["loss"] = torch.clone(loss.detach())
         stats["batch_size"] = batch_size
         stats["batch_size_x_frames"] = frames * batch_size
+        stats["batch_size_real_frames"] = speech_lengths.sum().item()
+        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]

         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
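A small worked example (not from the commit) showing how the two new statistics relate, assuming frames is the padded frame length of the batch and speech_lengths holds the true per-utterance frame counts; padding_frames is then the number of wasted, padded-out frames in the batch.

import torch

speech_lengths = torch.tensor([180, 250, 97])   # hypothetical per-utterance frame counts
batch_size = speech_lengths.size(0)
frames = int(speech_lengths.max())              # padded length of the batch

batch_size_x_frames = frames * batch_size                       # 750
batch_size_real_frames = int(speech_lengths.sum())              # 527
padding_frames = batch_size_x_frames - batch_size_real_frames   # 223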
@@ -2,28 +2,36 @@ import torch
 from torch.optim.lr_scheduler import _LRScheduler


 # class CustomLambdaLR(_LRScheduler):
 #     def __init__(self, optimizer, warmup_steps, last_epoch=-1):
 #         self.warmup_steps = warmup_steps
 #         super().__init__(optimizer, last_epoch)
 #
 #     def get_lr(self):
 #         if self.last_epoch < self.warmup_steps:
 #             return [
 #                 base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
 #             ]
 #         else:
 #             return [base_lr for base_lr in self.base_lrs]


 class CustomLambdaLR(_LRScheduler):
-    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+    def __init__(
+        self,
+        optimizer,
+        warmup_steps: int = 25000,
+        total_steps: int = 500000,
+        last_epoch=-1,
+        verbose=False,
+    ):
         self.warmup_steps = warmup_steps
-        super().__init__(optimizer, last_epoch)
+        self.total_steps = total_steps
+        super().__init__(optimizer, last_epoch, verbose)

     def get_lr(self):
         if self.last_epoch < self.warmup_steps:
             return [
                 base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
             ]
         else:
             return [base_lr for base_lr in self.base_lrs]


 class CustomLambdaLR(_LRScheduler):
     def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
         self.warmup_steps = train_config.warmup_steps
         self.total_steps = train_config.total_steps
         super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)

     def get_lr(self):
-        step = self._step_count
+        step = self.last_epoch + 1
         if step < self.warmup_steps:
             lr_scale = step / self.warmup_steps
         else:
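A minimal, runnable sketch (not from the commit) of the warmup behaviour visible in the diff: the learning rate ramps linearly from 0 to base_lr over warmup_steps updates and then holds. The else branch of the second class, presumably a decay that uses total_steps, is cut off in the excerpt, so only the warmup part is reproduced; the LinearWarmupLR name and the usage loop below are assumptions.

import torch
from torch.optim.lr_scheduler import _LRScheduler

class LinearWarmupLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps=25000, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Same counting convention as the diff: step = last_epoch + 1
        scale = min((self.last_epoch + 1) / self.warmup_steps, 1.0)
        return [base_lr * scale for base_lr in self.base_lrs]

# usage
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = LinearWarmupLR(optimizer, warmup_steps=100)
for _ in range(10):
    optimizer.step()      # optimizer.step() before scheduler.step()
    scheduler.step()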
@@ -62,7 +62,7 @@ class BaseTokenizer(ABC):
             raise RuntimeError(f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list")
         self.unk_id = self.token2id[self.unk_symbol]

-    def encode(self, text):
+    def encode(self, text, **kwargs):
        tokens = self.text2tokens(text)
        text_ints = self.tokens2ids(tokens)
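A small sketch (not from the commit) of why accepting **kwargs in encode is convenient: callers can pass extra hints without breaking tokenizers that ignore them. The DummyTokenizer class and the language keyword below are hypothetical.

class DummyTokenizer:
    def text2tokens(self, text):
        return text.split()

    def tokens2ids(self, tokens):
        return [len(t) for t in tokens]   # toy id mapping

    def encode(self, text, **kwargs):
        # Extra keyword arguments are accepted but simply ignored here.
        tokens = self.text2tokens(text)
        return self.tokens2ids(tokens)

ids = DummyTokenizer().encode("hello world", language="en")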