Dev gzf exp (#1647)

* sensevoice finetune

* bugfix

* update with main (#1631)

* update seaco finetune

* v1.0.24

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* update with main (#1638)

* update seaco finetune

* v1.0.24

* update rwkv template

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* sense voice

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
zhifu gao 2024-04-23 18:08:57 +08:00 committed by GitHub
parent 0a4a1d5257
commit 2ac38adbe5
6 changed files with 194 additions and 135 deletions

File 1 of 6

@@ -176,15 +176,12 @@ def main(**kwargs):
     except:
         writer = None
-    # if use_ddp or use_fsdp:
-    #     context = Join([model])
-    # else:
-    #     context = nullcontext()
-    context = nullcontext()
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
-        with context:
-            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
+        for data_split_i in range(dataloader.data_split_num):
+            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
             trainer.train_epoch(
                 model=model,
                 optim=optim,
@@ -193,15 +190,17 @@ def main(**kwargs):
                 dataloader_train=dataloader_tr,
                 dataloader_val=dataloader_val,
                 epoch=epoch,
-                writer=writer
+                writer=writer,
+                data_split_i=data_split_i,
+                data_split_num=dataloader.data_split_num,
             )
-        with context:
             trainer.validate_epoch(
                 model=model,
                 dataloader_val=dataloader_val,
                 epoch=epoch,
                 writer=writer
             )
         scheduler.step()
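
Read straight through, the new loop replaces the commented-out Join/nullcontext wrapper with an inner pass over dataset slices. A self-contained toy sketch of that control flow, using a dummy dataloader rather than FunASR's real objects (names and values here are illustrative only):

class _ToyDataloader:
    # stands in for the real dataloader; data_split_num would come from dataset_conf
    data_split_num = 4

    def build_iter(self, epoch, data_split_i=0):
        # the real build_iter returns (train_iter, val_iter); strings keep the toy runnable
        return f"train_slice_{data_split_i}", "val_iter"

dataloader = _ToyDataloader()
for epoch in range(1, 3):
    for data_split_i in range(dataloader.data_split_num):
        dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
        print(f"epoch {epoch}, slice {data_split_i + 1}/{dataloader.data_split_num}: {dataloader_tr}")
    # scheduler.step() would run here, once per epoch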

File 2 of 6

@@ -2,8 +2,9 @@ import os
 import json
 import torch
 import logging
+import concurrent.futures
 import librosa
+import random
 import torch.distributed as dist
 from funasr.register import tables
@@ -44,7 +45,7 @@ from funasr.register import tables
 #     except:
 #         rank = 0
 #         world_size = 1
-#         logging.warning("distributed is not initialized, only single shard")
+#         logging.info("distributed is not initialized, only single shard")
 #     num_per_rank = total_num // world_size
 #
 #     # rank = 0
@@ -72,6 +73,7 @@ from funasr.register import tables
 @tables.register("index_ds_classes", "IndexDSJsonl")
 @tables.register("index_ds_classes", "IndexDSJsonlRankFull")
+@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
 class IndexDSJsonlRankFull(torch.utils.data.Dataset):

     def __init__(self, path: str, **kwargs):
@@ -80,83 +82,27 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
         self.min_source_length = kwargs.get("min_source_length", 0)
         self.max_target_length = kwargs.get("max_target_length", 2048)
         self.min_target_length = kwargs.get("min_target_length", 0)
-        if isinstance(path, (list, tuple)):  # wav.scp, text.txt/text.trans
-            from funasr.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list
-            jsonl_outdir = os.path.dirname(path[0])
-            jsonl_name = "datalist_train.jsonl" if kwargs.get("is_training", True) else "datalist_val.jsonl"
-            jsonl_file_out = os.path.join(jsonl_outdir, jsonl_name)
-            if not os.path.exists(jsonl_file_out):
-                print(f"datalist is: {path}, generate jsonl from it")
-                gen_jsonl_from_wav_text_list(path, jsonl_file_out=jsonl_file_out, **kwargs)
-            path = jsonl_file_out
-        contents = []
-        with open(path, encoding='utf-8') as fin:
-            for line in fin:
-                data = json.loads(line.strip())
-                if "text" in data:  # for sft
-                    contents.append(data['text'])
-                if "source" in data:  # for speech lab pretrain
-                    prompt = data.get("prompt", "<ASR>")
-                    source = data["source"]
-                    target = data["target"]
-                    source_len = data.get("source_len", 1)
-                    target_len = data.get("target_len", 0)
-                    if "aishell" in source:
-                        target = target.replace(" ", "")
-                    if source_len < self.min_source_length or source_len > self.max_source_length:
-                        continue
-                    if target_len < self.min_target_length or target_len > self.max_target_length:
-                        continue
-                    contents_i = {"source": source,
-                                  "prompt": prompt,
-                                  "target": target,
-                                  "source_len": source_len,
-                                  "target_len": target_len,
-                                  }
-                    text_language = data.get("text_language", None)
-                    if text_language is not None:
-                        contents_i["text_language"] = text_language
-                    audio_language = data.get("audio_language", None)
-                    if audio_language is not None:
-                        contents_i["audio_language"] = audio_language
-                    contents.append(contents_i)
-        self.contents = contents
-
-        logging.info(
-            "total_num of samplers across ranks: {}".format(len(self.contents)))
-
-    def __len__(self):
-        return len(self.contents)
-
-    def __getitem__(self, index):
-        try:
-            data = self.contents[index]
-        except:
-            print(index)
-        return data
-
-    def get_source_len(self, data_dict):
-        return data_dict.get("source_len", 1)
-
-    def get_target_len(self, data_dict):
-        return data_dict.get("target_len", 0)
-
-
-@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
-class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
-
-    def __init__(self, path: str, **kwargs):
-        super().__init__()
-        self.max_source_length = kwargs.get("max_source_length", 2048)
-        self.min_source_length = kwargs.get("min_source_length", 0)
-        self.max_target_length = kwargs.get("max_target_length", 2048)
-        self.min_target_length = kwargs.get("min_target_length", 0)
-
-        with open(path, encoding='utf-8') as fin:
-            file_list = fin.readlines()
+        is_training = kwargs.get("is_training", True)
+        if not (path.endswith(".jsonl") or path.endswith(".json")):
+            # jsonl list file
+            data_split_num = kwargs.get("data_split_num", 1)
+            data_split_i = kwargs.get("data_split_i", 0)
+            if not is_training:
+                data_split_num = 1
+                data_split_i = 0
+            with open(path, encoding='utf-8') as fin:
+                file_list_all = fin.readlines()
+
+            num_per_slice = len(file_list_all) // data_split_num
+            file_list = file_list_all[data_split_i * num_per_slice:(data_split_i + 1) * num_per_slice]
+            logging.info(
+                f"is_training: {is_training}, data_split_num: {data_split_num}, data_split_i: {data_split_i}, \nfile_list: {file_list}, \nfile_list_all: {file_list_all}")
+        else:
+            file_list = [path]

         total_num = len(file_list)
         try:
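
The slicing above only applies when path is a jsonl list file (one jsonl shard path per line) rather than a single .jsonl/.json file. A small standalone sketch of that slice-selection arithmetic (a plain function, not the FunASR class; note that trailing lines beyond num_per_slice * data_split_num are simply not selected):

def select_slice(file_list_all, data_split_num=1, data_split_i=0, is_training=True):
    # validation always reads the full list as a single slice
    if not is_training:
        data_split_num, data_split_i = 1, 0
    num_per_slice = len(file_list_all) // data_split_num
    return file_list_all[data_split_i * num_per_slice:(data_split_i + 1) * num_per_slice]

shards = [f"shard_{i}.jsonl" for i in range(10)]
print(select_slice(shards, data_split_num=3, data_split_i=1))
# ['shard_3.jsonl', 'shard_4.jsonl', 'shard_5.jsonl']
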
@@ -165,16 +111,30 @@ class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
         except:
             rank = 0
             world_size = 1
-            logging.warning("distributed is not initialized, only single shard")
+            logging.info("distributed is not initialized, only single shard")
+
+        if not kwargs.get("rank_split", False):
+            logging.info(f"Warning, rank_split disenabled, batch and shuffle data in global")
+            rank = 0
+            world_size = 1
+
         num_per_rank = total_num // world_size
         if num_per_rank * world_size < total_num:
-            logging.warning(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
+            logging.info(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
+            total_num_needed = num_per_rank * world_size
+            extra_num = total_num_needed - total_num
+            file_list_tmp = random.choices(file_list, k=extra_num)
+            file_list += file_list_tmp
+            logging.info(f"Warning, after random choices: {file_list}")

         file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]
+        logging.info(
+            f"is_training: {is_training}, file_list_rank: {file_list_rank}")

         contents = []
         for file_json in file_list_rank:
             with open(file_json.strip(), encoding='utf-8') as fin:
                 for line in fin:
                     data = json.loads(line.strip())
@@ -182,41 +142,42 @@ class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
                         contents.append(data['text'])
                     if "source" in data:  # for speech lab pretrain
                         prompt = data.get("prompt", "<ASR>")
-                        source = data["source"].replace("/cpfs01", "/cpfs_speech/data")
+                        source = data["source"].replace("/cpfs01", "/cpfs_speech/data")  # only use in alibaba gpu group: .replace("/cpfs01", "/cpfs_speech/data")
                         target = data["target"]
                         source_len = data.get("source_len", 1)
                         target_len = data.get("target_len", 0)
-                        if "aishell" in source:
-                            target = target.replace(" ", "")
                         if source_len < self.min_source_length or source_len > self.max_source_length:
                             continue
                         if target_len < self.min_target_length or target_len > self.max_target_length:
                             continue
                         contents_i = {"source": source,
                                       "prompt": prompt,
                                       "target": target,
                                       "source_len": source_len,
                                       "target_len": target_len,
                                       }
                         text_language = data.get("text_language", None)
                         if text_language is not None:
                             contents_i["text_language"] = text_language
-                        audio_language = data.get("audio_language", None)
-                        if audio_language is not None:
-                            contents_i["audio_language"] = audio_language
+                        # audio_language = data.get("audio_language", None)
+                        # if audio_language is not None:
+                        #     contents_i["audio_language"] = audio_language
                         contents.append(contents_i)

         self.contents = contents

-        logging.info(f"total_num: {len(self.contents)} of samplers in ranks: {rank}")
+        logging.info(
+            "total_num of samplers: {}, {}".format(len(self.contents), path))

     def __len__(self):
         return len(self.contents)

     def __getitem__(self, index):
-        data = self.contents[index]
+        try:
+            data = self.contents[index]
+        except:
+            print(index)
         return data

     def get_source_len(self, data_dict):
@@ -225,3 +186,96 @@ class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
     def get_target_len(self, data_dict):
         return data_dict.get("target_len", 0)
+
+
+#
+# @tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
+# class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
+#
+#     def __init__(self, path: str, **kwargs):
+#         super().__init__()
+#         logging.info("building IndexDS")
+#         self.max_source_length = kwargs.get("max_source_length", 2048)
+#         self.min_source_length = kwargs.get("min_source_length", 0)
+#         self.max_target_length = kwargs.get("max_target_length", 2048)
+#         self.min_target_length = kwargs.get("min_target_length", 0)
+#
+#         data_split_num = kwargs.get("data_split_num", 1)
+#         data_split_i = kwargs.get("data_split_i", 0)
+#         if not kwargs.get("is_training", True):
+#             data_split_num = 1
+#             data_split_i = 0
+#         with open(path, encoding='utf-8') as fin:
+#             file_list_all = fin.readlines()
+#
+#         num_per_slice = len(file_list_all) // data_split_num
+#         file_list = file_list_all[data_split_i * num_per_slice:(data_split_i + 1) * num_per_slice]
+#         logging.info(f"data_split_num: {data_split_num}, data_split_i: {data_split_i}, file_list: {file_list}, file_list_all: {file_list_all}")
+#
+#
+#         total_num = len(file_list)
+#         try:
+#             rank = dist.get_rank()
+#             world_size = dist.get_world_size()
+#         except:
+#             rank = 0
+#             world_size = 1
+#             logging.info("distributed is not initialized, only single shard")
+#         num_per_rank = total_num // world_size
+#         if num_per_rank * world_size < total_num:
+#             logging.info(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
+#
+#         file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]
+#
+#         contents = []
+#         for file_json in file_list_rank:
+#
+#             with open(file_json.strip(), encoding='utf-8') as fin:
+#                 for line in fin:
+#                     data = json.loads(line.strip())
+#                     if "text" in data:  # for sft
+#                         contents.append(data['text'])
+#                     if "source" in data:  # for speech lab pretrain
+#                         prompt = data.get("prompt", "<ASR>")
+#                         source = data["source"].replace("/cpfs01", "/cpfs_speech/data")
+#                         target = data["target"]
+#                         source_len = data.get("source_len", 1)
+#                         target_len = data.get("target_len", 0)
+#
+#                         if source_len < self.min_source_length or source_len > self.max_source_length:
+#                             continue
+#                         if target_len < self.min_target_length or target_len > self.max_target_length:
+#                             continue
+#                         contents_i = {"source": source,
+#                                       "prompt": prompt,
+#                                       "target": target,
+#                                       "source_len": source_len,
+#                                       "target_len": target_len,
+#                                       }
+#                         text_language = data.get("text_language", None)
+#                         if text_language is not None:
+#                             contents_i["text_language"] = text_language
+#                         # audio_language = data.get("audio_language", None)
+#                         # if audio_language is not None:
+#                         #     contents_i["audio_language"] = audio_language
+#                         contents.append(contents_i)
+#
+#         self.contents = contents
+#
+#         logging.info(f"total_num: {len(self.contents)} of samplers in ranks: {rank}, file_list_rank: {file_list_rank}")
+#
+#     def __len__(self):
+#         return len(self.contents)
+#
+#     def __getitem__(self, index):
+#         try:
+#             data = self.contents[index]
+#         except:
+#             print(index)
+#         return data
+#
+#     def get_source_len(self, data_dict):
+#         return data_dict.get("source_len", 1)
+#
+#     def get_target_len(self, data_dict):
+#
+#         return data_dict.get("target_len", 0)
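
After the per-epoch slice is chosen, the list is sharded again across distributed ranks, and random.choices is used to top the list up when it does not divide evenly by world_size (with rank_split disabled, the dataset forces rank=0 and world_size=1 so every rank keeps the full list and sharding is left to the batch sampler). The following standalone sketch shows the sharding idea with an explicit ceil-based pad count; the helper name and the rounding are ours, while the committed code derives its pad count from the floored num_per_rank:

import random

def shard_for_rank(file_list, rank, world_size, seed=0):
    # pad with randomly re-drawn entries so every rank receives the same number of files
    random.seed(seed)  # identical padding on every rank
    num_per_rank = -(-len(file_list) // world_size)  # ceil division
    extra_num = num_per_rank * world_size - len(file_list)
    padded = file_list + random.choices(file_list, k=extra_num)
    return padded[rank * num_per_rank:(rank + 1) * num_per_rank]

files = [f"data_{i}.jsonl" for i in range(7)]
print([shard_for_rank(files, r, world_size=3) for r in range(3)])  # three shards of 3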

File 3 of 6

@@ -301,6 +301,7 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
         batch_type="token",
         num_replicas=None,
         rank=None,
+        rank_split=False,
         shuffle=True,
         drop_last=False,
         is_training: bool = True,
@@ -314,6 +315,12 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
         except:
             rank = 0
             num_replicas = 1
+
+        if rank_split:
+            logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank")
+            rank = 0
+            num_replicas = 1
+
         self.rank = rank
         self.num_replicas = num_replicas
         self.dataset = dataset
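
The sampler side mirrors that switch: with rank_split=True the dataset already holds only this rank's shard, so the sampler stops re-dividing indices by rank and batches/shuffles locally. A minimal sketch of just that constructor bookkeeping (a simplified stand-in, not the full CustomDistributedBufferDynamicBatchSampler):

import logging
import torch.distributed as dist

class _TinySamplerSketch:
    def __init__(self, dataset, rank=None, num_replicas=None, rank_split=False):
        if rank is None or num_replicas is None:
            try:
                rank = dist.get_rank()
                num_replicas = dist.get_world_size()
            except Exception:
                rank, num_replicas = 0, 1
        if rank_split:
            # the dataset is already per-rank, so sample it as if single-process
            logging.info(f"rank_split: {rank_split}, batch and shuffle data in local rank")
            rank, num_replicas = 0, 1
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas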

File 4 of 6

@@ -40,7 +40,21 @@ class DataloaderMapStyle:
         self.dataset_val = dataset_val
         self.kwargs = kwargs

-    def build_iter(self, epoch=0):
+        # split dataset
+        self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
+        self.dataset_class = dataset_class
+        self.frontend = frontend
+        self.tokenizer = tokenizer
+        self.kwargs = kwargs
+
+    def build_iter(self, epoch=0, data_split_i=0, **kwargs):
+        # reload dataset slice
+        if self.data_split_num > 1:
+            del self.dataset_tr
+            self.dataset_tr = self.dataset_class(self.kwargs.get("train_data_set_list"), frontend=self.frontend, tokenizer=self.tokenizer,
+                                                 is_training=True, **self.kwargs.get("dataset_conf"), data_split_i=data_split_i)
+
         # dataloader
         batch_sampler = self.kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
         batch_sampler_val = None
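
Putting it together, the feature is switched on from dataset_conf; only the keys that actually appear in this diff are shown below, and the values are illustrative (everything else in a real recipe config stays as it is):

kwargs = {
    "train_data_set_list": "path/to/train_jsonl_list.txt",  # a list file, one jsonl shard per line
    "dataset_conf": {
        "batch_sampler": "BatchSampler",
        "data_split_num": 8,  # > 1: build_iter() reloads one slice of the list per call
    },
}

# per epoch and slice, the caller then does roughly:
#   dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=i)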

File 5 of 6

@@ -245,29 +245,7 @@ class SenseVoiceDecoder(nn.Module):
             self.register_buffer("mask", mask, persistent=False)
         self.use_padmask = kwargs.get("use_padmask", True)

-    # def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-    #     """
-    #     x : torch.LongTensor, shape = (batch_size, <= n_ctx)
-    #         the text tokens
-    #     xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
-    #         the encoded audio features to be attended on
-    #     """
-    #     offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-    #     x = (
-    #         self.token_embedding(x)
-    #         + self.positional_embedding[offset: offset + x.shape[-1]]
-    #     )
-    #     x = x.to(xa.dtype)
-    #
-    #     for block in self.blocks:
-    #         x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
-    #
-    #     x = self.ln(x)
-    #     logits = (
-    #         x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-    #     ).float()
-    #
-    #     return logits
-
     def forward(

File 6 of 6

@@ -252,6 +252,7 @@ class Trainer:
         dataloader_val=None,
         epoch=None,
         writer=None,
+        **kwargs,
     ):
         """
         Defines the training process for a single epoch with gradient accumulation.
@@ -374,6 +375,8 @@ class Trainer:
                     stats=stats,
                     writer=writer,
                     tag="train",
+                    data_split_i=kwargs.get("data_split_i", 0),
+                    data_split_num=kwargs.get("data_split_num", 1),
                 )

             if (batch_idx + 1) % self.validate_interval == 0:
@@ -507,6 +510,9 @@ class Trainer:
         stats=None,
         writer=None,
         tag="train",
+        data_split_i=0,
+        data_split_num=1,
+        **kwargs,
     ):

         if (batch_idx + 1) % self.log_interval == 0:
@@ -526,6 +532,7 @@ class Trainer:
                 f"{tag}, "
                 f"rank: {self.local_rank}, "
                 f"epoch: {epoch}/{self.max_epoch}, "
+                f"data_slice: {data_split_i}/{data_split_num}, "
                 f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
                 f"(loss_avg_rank: {loss:.3f}), "
                 f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "