Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Funasr1.0 (#1362)
* funasr1.0.5
* funasr1.0.5 audio samples input
* batch_type token
* huggingface model zoo
* dataloader
* fbank input
* vad is_final=True bugfix
This commit is contained in:
parent 410a85402d
commit d92cd5ae03
@@ -171,7 +171,7 @@ class AutoModel:
        # build model
        model_class = tables.model_classes.get(kwargs["model"])
        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
        model.eval()

        model.to(device)

        # init_param
@@ -206,6 +206,7 @@ class AutoModel:
        kwargs = self.kwargs if kwargs is None else kwargs
        kwargs.update(cfg)
        model = self.model if model is None else model
        model.eval()

        batch_size = kwargs.get("batch_size", 1)
        # if kwargs.get("device", "cpu") == "cpu":
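Taken together, the two AutoModel hunks above build the model once (eval mode, moved to the target device) and then merge per-call options over the stored kwargs via kwargs.update(cfg), so anything passed at call time overrides the build-time defaults. A minimal usage sketch under that reading, assuming the FunASR 1.0 AutoModel API; the model name and wav path are illustrative:

    from funasr import AutoModel

    model = AutoModel(model="paraformer-zh")   # built once: model.eval() + model.to(device)
    res = model.generate(input="example.wav",  # call-time options land in cfg and
                         batch_size=4)         # override stored kwargs via kwargs.update(cfg)
    print(res)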
@@ -6,8 +6,8 @@ import torch.distributed as dist
from funasr.register import tables


@tables.register("index_ds_classes", "IndexDSJsonl")
class IndexDSJsonl(torch.utils.data.Dataset):
@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
class IndexDSJsonlRankSplit(torch.utils.data.Dataset):

    def __init__(self, path):
        super().__init__()
@@ -66,3 +66,53 @@ class IndexDSJsonl(torch.utils.data.Dataset):
    def get_target_len(self, data_dict):
        return data_dict["target_len"] if "target_len" in data_dict else 0


@tables.register("index_ds_classes", "IndexDSJsonl")
@tables.register("index_ds_classes", "IndexDSJsonlRankFull")
class IndexDSJsonlRankFull(torch.utils.data.Dataset):

    def __init__(self, path):
        super().__init__()

        contents = []
        with open(path, encoding="utf-8") as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft
                    contents.append(data["text"])  # fix: was self.contents, which is not assigned yet
                if "source" in data:  # for speech lab pretrain
                    prompt = data.get("prompt", "<ASR>")
                    source = data["source"]
                    target = data["target"]
                    source_len = data.get("source_len", 1)
                    target_len = data.get("target_len", 0)

                    contents.append({"source": source,
                                     "prompt": prompt,
                                     "target": target,
                                     "source_len": source_len,
                                     "target_len": target_len,
                                     })

        self.contents = contents

        logging.info("total_num of samples across ranks: {}".format(len(self.contents)))

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        try:
            data = self.contents[index]
        except IndexError:  # fix: bare except hid the error and left `data` unbound
            logging.error("index out of range: {}".format(index))
            raise
        return data

    def get_source_len(self, data_dict):
        return data_dict.get("source_len", 1)

    def get_target_len(self, data_dict):
        return data_dict.get("target_len", 0)
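For reference, each line of the jsonl index consumed by IndexDSJsonlRankFull is one standalone JSON object. The field names below come straight from __init__ above; the paths, lengths, and text are made-up illustrations:

    import json

    samples = [
        # "source"/"target" entry, the speech-lab pretrain case
        {"source": "wav/example_001.wav", "target": "hello world",
         "prompt": "<ASR>", "source_len": 320, "target_len": 2},
        # "text" entry, the sft case
        {"text": "plain text sample for sft"},
    ]
    with open("train.jsonl", "w", encoding="utf-8") as fout:
        for s in samples:
            fout.write(json.dumps(s, ensure_ascii=False) + "\n")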
@@ -1,5 +1,7 @@
import torch
import numpy as np
import logging
import torch.distributed as dist

from funasr.register import tables
@@ -82,3 +84,194 @@ class BatchSampler(torch.utils.data.BatchSampler):
                    max_token = sample_len_cur_raw
                    num_sample = 1


@tables.register("batch_sampler_classes", "BatchSampler")
@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):

    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):

        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except (RuntimeError, ValueError):  # fix: bare except; dist raises when not initialized (single-process run)
            rank = 0
            world_size = 1
        self.rank = rank
        self.world_size = world_size

    def __len__(self):
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)

    def __iter__(self):

        batch_size_total = self.batch_size * self.world_size

        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        # print("iter_num: ", iter_num)
        for buf_idx in range(self.pre_idx + 1, iter_num):  # renamed from `iter` to avoid shadowing the builtin
            # if buf_idx == iter_num - 1 and self.drop_last:
            #     continue
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = buf_idx * self.buffer_size + i
                if idx >= self.total_samples:
                    continue

                idx_map = self.shuffle_idx[idx]
                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]

                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == "length" else 0.0
                sample_len_cur = source_len + target_len

                datalen_with_index.append([idx, sample_len_cur])

            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                # if self.batch_type != 'example':
                #     max_token_padding *= max_token_cur
                if max_token_padding <= batch_size_total:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    batch_rank = batch[self.rank * self.batch_size: (self.rank + 1) * self.batch_size]
                    yield batch_rank
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
|
||||
class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
|
||||
|
||||
def __init__(self, dataset,
|
||||
batch_type: str = "example",
|
||||
batch_size: int = 100,
|
||||
buffer_size: int = 30,
|
||||
drop_last: bool = True,
|
||||
shuffle: bool = True,
|
||||
is_training: bool = True,
|
||||
**kwargs):
|
||||
|
||||
self.drop_last = drop_last
|
||||
self.pre_idx = -1
|
||||
self.dataset = dataset
|
||||
self.total_samples = len(dataset)
|
||||
self.batch_type = batch_type
|
||||
self.batch_size = int(batch_size)
|
||||
self.buffer_size = buffer_size
|
||||
self.max_token_length = kwargs.get("max_token_length", 1500)
|
||||
self.shuffle_idx = np.arange(self.total_samples)
|
||||
self.shuffle = shuffle and is_training
|
||||
self.length_scale_source = kwargs.get("length_scale_source", 1.0)
|
||||
|
||||
try:
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
except:
|
||||
rank = 0
|
||||
world_size = 1
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
|
||||
def __len__(self):
|
||||
return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
np.random.seed(epoch)
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
batch_size_total = self.batch_size * self.world_size
|
||||
if self.shuffle:
|
||||
np.random.shuffle(self.shuffle_idx)
|
||||
|
||||
batch_list_all_rank = []
|
||||
batch_list_cur = []
|
||||
max_token = 0
|
||||
num_sample = 0
|
||||
|
||||
iter_num = (self.total_samples - 1) // self.buffer_size + 1
|
||||
# print("iter_num: ", iter_num)
|
||||
for iter in range(self.pre_idx + 1, iter_num):
|
||||
# if iter == iter_num - 1 and self.drop_last:
|
||||
# continue
|
||||
datalen_with_index = []
|
||||
for i in range(self.buffer_size):
|
||||
idx = iter * self.buffer_size + i
|
||||
if idx >= self.total_samples:
|
||||
continue
|
||||
|
||||
idx_map = self.shuffle_idx[idx]
|
||||
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
|
||||
|
||||
source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
|
||||
target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
|
||||
sample_len_cur = source_len + target_len
|
||||
|
||||
datalen_with_index.append([idx, sample_len_cur])
|
||||
|
||||
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
|
||||
for ii, item in enumerate(datalen_with_index_sort):
|
||||
is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort)
|
||||
idx, sample_len_cur_raw = item
|
||||
if sample_len_cur_raw > self.max_token_length:
|
||||
continue
|
||||
|
||||
max_token_cur = max(max_token, sample_len_cur_raw)
|
||||
max_token_padding = 1 + num_sample
|
||||
|
||||
if self.batch_type != 'example':
|
||||
max_token_padding *= max_token_cur
|
||||
if len(batch_list_all_rank) < self.world_size:
|
||||
|
||||
if max_token_padding <= self.batch_size:
|
||||
batch_list_cur.append(idx)
|
||||
max_token = max_token_cur
|
||||
num_sample += 1
|
||||
else:
|
||||
batch_list_all_rank.append(batch_list_cur)
|
||||
batch_list_cur = []
|
||||
else:
|
||||
batch_rank = batch_list_all_rank[self.rank]
|
||||
yield batch_rank
|
||||
batch_list_all_rank = [idx]
|
||||
max_token = sample_len_cur_raw
|
||||
num_sample = 1
|
||||
|
||||
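These samplers implement the "batch_type token" item from the commit message: when batch_type is not "example", the budget check becomes (num_sample + 1) * max_len <= batch_size, i.e. the padded footprint of the batch if every sample were padded to the longest one. A standalone toy sketch of that decision rule, with made-up lengths and no torch or rank handling:

    def bucket_by_padded_tokens(lengths, batch_size_tokens):
        """Group lengths so that num_samples * max_len stays within the token budget."""
        batches, cur, max_len = [], [], 0
        for n in sorted(lengths):
            padded = (len(cur) + 1) * max(max_len, n)  # footprint if we add this sample
            if padded <= batch_size_tokens:
                cur.append(n)
                max_len = max(max_len, n)
            else:
                batches.append(cur)
                cur, max_len = [n], n
        if cur:
            batches.append(cur)
        return batches

    print(bucket_by_padded_tokens([3, 9, 4, 10, 2, 8], batch_size_tokens=20))
    # [[2, 3, 4], [8, 9], [10]] — short samples batch together; long ones go alone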
@@ -575,7 +575,8 @@ class FsmnVADStreaming(nn.Module):

        time1 = time.perf_counter()
        is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True)
        cfg = {"is_final": kwargs.get("is_final", False), "is_streaming_input": is_streaming_input}
        is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True)
        cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input}
        audio_sample_list = load_audio_text_image_video(data_in,
                                                        fs=frontend.fs,
                                                        audio_fs=kwargs.get("fs", 16000),
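This is the "vad is_final=True bugfix" from the commit message: for offline input (chunk_size >= 15000, or is_streaming_input=False) the final flag now defaults to True, so a one-shot call flushes the last segment without the caller passing is_final. A hedged sketch of the streaming side, assuming the FunASR 1.0 AutoModel API; the model name, chunk size, and random audio are illustrative:

    import numpy as np
    from funasr import AutoModel

    vad = AutoModel(model="fsmn-vad")          # model name is illustrative
    fs, chunk_ms = 16000, 200                  # offline calls (>= 15000 ms) now default to is_final=True
    speech = np.random.randn(fs * 3).astype(np.float32)  # stand-in for real audio
    stride = fs * chunk_ms // 1000

    cache = {}
    n_chunks = int(np.ceil(len(speech) / stride))
    for i in range(n_chunks):
        chunk = speech[i * stride:(i + 1) * stride]
        res = vad.generate(input=chunk, cache=cache,
                           is_final=(i == n_chunks - 1),  # last chunk flushes the final segment
                           chunk_size=chunk_ms)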
@@ -186,7 +186,7 @@ class CifPredictorV2(torch.nn.Module):
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
            target_length = target_label_length.squeeze(-1)
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
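The one-line change above replaces the raw (B, 1) label-length tensor with its squeezed (B,) form, matching the shape produced by the `(target_label != ignore_id).float().sum(-1)` branch. A tiny shape check showing why, with made-up sizes:

    import torch

    target_label_length = torch.tensor([[4], [7]])   # (B, 1), as often passed from a collator
    print(target_label_length.shape)                 # torch.Size([2, 1])
    print(target_label_length.squeeze(-1).shape)     # torch.Size([2]) — matches the sum(-1) branch

    ignore_id = -1
    target_label = torch.tensor([[5, 6, 7, -1], [8, 9, -1, -1]])
    print((target_label != ignore_id).float().sum(-1))  # tensor([3., 2.]) — also shape (B,)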
@@ -491,6 +491,8 @@ class Paraformer(torch.nn.Module):
        b, n, d = decoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        if len(key) < b:
            key = key * b
        for i in range(b):
            x = encoder_out[i, :encoder_out_lens[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
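The two added lines let a single sample key serve a whole batch: list repetition guarantees at least b entries, so key[i] is valid inside the loop; exact one-to-one pairing assumes the single-key case (for 1 < len(key) < b the repeated list is longer than b). A toy illustration — note this is list repetition, not numeric multiplication:

    key = ["utt_0"]      # one key, e.g. when a raw waveform was passed without ids
    b = 3                # batch size from decoder_out.size()
    if len(key) < b:
        key = key * b    # ["utt_0", "utt_0", "utt_0"]
    print(key[:b])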
@@ -204,7 +204,25 @@ class Trainer:
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                time2 = time.perf_counter()
                print("before, GPU, memory: {:.1f} GB, "
                      "{:.1f} GB, "
                      "{:.1f} GB, "
                      "{:.1f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                                         torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                                         torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                                         torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                                         ))  # fix: values are bytes / 1024**3 (GB, not MB) and need the "f" spec

                retval = self.model(**batch)
                torch.cuda.empty_cache()
                print("after, GPU, memory: {:.1f} GB, "
                      "{:.1f} GB, "
                      "{:.1f} GB, "
                      "{:.1f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                                         torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                                         torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                                         torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                                         ))
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
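The `my_context` line above is the standard DDP gradient-accumulation trick: skip the gradient all-reduce (model.no_sync) on every micro-step except the one that completes an accumulation window. A minimal standalone sketch of that pattern; `model` (DistributedDataParallel), `dataloader`, `optimizer`, and `accum_grad` are assumed to exist in the surrounding loop:

    from contextlib import nullcontext

    for batch_idx, batch in enumerate(dataloader, start=1):
        # all-reduce only on the step that completes the accumulation window
        sync_ctx = model.no_sync if batch_idx % accum_grad != 0 else nullcontext
        with sync_ctx():
            loss, stats, weight = model(**batch)
            (loss / accum_grad).backward()  # scale so the effective gradient matches one big batch
        if batch_idx % accum_grad == 0:
            optimizer.step()
            optimizer.zero_grad()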