Dev gzf exp (#1647)

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* bugfix

* update with main (#1631)

* update seaco finetune

* v1.0.24

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* sensevoice

* sensevoice

* update with main (#1638)

* update seaco finetune

* v1.0.24

* update rwkv template

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
zhifu gao 2024-04-23 18:08:57 +08:00 committed by GitHub
parent 0a4a1d5257
commit 2ac38adbe5
6 changed files with 194 additions and 135 deletions

View File

@@ -176,15 +176,12 @@ def main(**kwargs):
     except:
         writer = None
 
-    # if use_ddp or use_fsdp:
-    #     context = Join([model])
-    # else:
-    #     context = nullcontext()
-    context = nullcontext()
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
-        with context:
-            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
+        for data_split_i in range(dataloader.data_split_num):
+            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
 
             trainer.train_epoch(
                 model=model,
                 optim=optim,
@@ -193,9 +190,11 @@ def main(**kwargs):
                 dataloader_train=dataloader_tr,
                 dataloader_val=dataloader_val,
                 epoch=epoch,
-                writer=writer
+                writer=writer,
+                data_split_i=data_split_i,
+                data_split_num=dataloader.data_split_num,
             )
 
-        with context:
             trainer.validate_epoch(
                 model=model,
                 dataloader_val=dataloader_val,
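In short, each epoch now walks over data_split_num slices of the data list instead of making one pass, rebuilding the train iterator for every slice. A rough, self-contained sketch of that control flow (the sample counts and the build_iter stub are assumptions for illustration, not FunASR code):

data_split_num = 3
max_epoch = 2

def build_iter(epoch, data_split_i=0):
    # Stand-in for DataloaderMapStyle.build_iter: return the samples of one slice.
    all_samples = list(range(12))
    per_slice = len(all_samples) // data_split_num
    return all_samples[data_split_i * per_slice:(data_split_i + 1) * per_slice]

for epoch in range(1, max_epoch + 1):
    for data_split_i in range(data_split_num):
        slice_samples = build_iter(epoch, data_split_i=data_split_i)
        # train_epoch would run here; the log line then reports data_slice i/num
        print(f"epoch {epoch}, data_slice {data_split_i}/{data_split_num}: {slice_samples}")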

View File

@@ -2,8 +2,9 @@ import os
 import json
 import torch
 import logging
+import concurrent.futures
 import librosa
+import random
 import torch.distributed as dist
 
 from funasr.register import tables
@@ -44,7 +45,7 @@ from funasr.register import tables
 # except:
 #     rank = 0
 #     world_size = 1
-#     logging.warning("distributed is not initialized, only single shard")
+#     logging.info("distributed is not initialized, only single shard")
 # num_per_rank = total_num // world_size
 #
 # # rank = 0
@@ -72,6 +73,7 @@ from funasr.register import tables
 
 @tables.register("index_ds_classes", "IndexDSJsonl")
 @tables.register("index_ds_classes", "IndexDSJsonlRankFull")
+@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
 class IndexDSJsonlRankFull(torch.utils.data.Dataset):
 
     def __init__(self, path: str, **kwargs):
@@ -80,25 +82,67 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
         self.min_source_length = kwargs.get("min_source_length", 0)
         self.max_target_length = kwargs.get("max_target_length", 2048)
         self.min_target_length = kwargs.get("min_target_length", 0)
-        if isinstance(path, (list, tuple)):  # wav.scp, text.txt/text.trans
-            from funasr.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list
-            jsonl_outdir = os.path.dirname(path[0])
-            jsonl_name = "datalist_train.jsonl" if kwargs.get("is_training", True) else "datalist_val.jsonl"
-            jsonl_file_out = os.path.join(jsonl_outdir, jsonl_name)
-            if not os.path.exists(jsonl_file_out):
-                print(f"datalist is: {path}, generate jsonl from it")
-                gen_jsonl_from_wav_text_list(path, jsonl_file_out=jsonl_file_out, **kwargs)
-            path = jsonl_file_out
+        is_training = kwargs.get("is_training", True)
+        if not (path.endswith(".jsonl") or path.endswith(".json")):
+            # jsonl list file
+            data_split_num = kwargs.get("data_split_num", 1)
+            data_split_i = kwargs.get("data_split_i", 0)
+            if not is_training:
+                data_split_num = 1
+                data_split_i = 0
+            with open(path, encoding='utf-8') as fin:
+                file_list_all = fin.readlines()
+            num_per_slice = len(file_list_all) // data_split_num
+            file_list = file_list_all[data_split_i * num_per_slice:(data_split_i + 1) * num_per_slice]
+            logging.info(
+                f"is_training: {is_training}, data_split_num: {data_split_num}, data_split_i: {data_split_i}, \nfile_list: {file_list}, \nfile_list_all: {file_list_all}")
+        else:
+            file_list = [path]
+
+        total_num = len(file_list)
+        try:
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        except:
+            rank = 0
+            world_size = 1
+            logging.info("distributed is not initialized, only single shard")
+        if not kwargs.get("rank_split", False):
+            logging.info(f"Warning, rank_split disenabled, batch and shuffle data in global")
+            rank = 0
+            world_size = 1
+        num_per_rank = total_num // world_size
+        if num_per_rank * world_size < total_num:
+            logging.info(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
+            total_num_needed = num_per_rank * world_size
+            extra_num = total_num_needed - total_num
+            file_list_tmp = random.choices(file_list, k=extra_num)
+            file_list += file_list_tmp
+            logging.info(f"Warning, after random choices: {file_list}")
+        file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]
+        logging.info(
+            f"is_training: {is_training}, file_list_rank: {file_list_rank}")
 
         contents = []
-        with open(path, encoding='utf-8') as fin:
+        for file_json in file_list_rank:
+            with open(file_json.strip(), encoding='utf-8') as fin:
                 for line in fin:
                     data = json.loads(line.strip())
                     if "text" in data:  # for sft
                         contents.append(data['text'])
                     if "source" in data:  # for speech lab pretrain
                         prompt = data.get("prompt", "<ASR>")
-                        source = data["source"]
+                        source = data["source"].replace("/cpfs01", "/cpfs_speech/data")  # only use in alibaba gpu group: .replace("/cpfs01", "/cpfs_speech/data")
                         target = data["target"]
                         source_len = data.get("source_len", 1)
                         target_len = data.get("target_len", 0)
@@ -117,24 +161,23 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
                         text_language = data.get("text_language", None)
                         if text_language is not None:
                             contents_i["text_language"] = text_language
-                        audio_language = data.get("audio_language", None)
-                        if audio_language is not None:
-                            contents_i["audio_language"] = audio_language
+                        # audio_language = data.get("audio_language", None)
+                        # if audio_language is not None:
+                        #     contents_i["audio_language"] = audio_language
                         contents.append(contents_i)
 
         self.contents = contents
 
         logging.info(
-            "total_num of samplers across ranks: {}".format(len(self.contents)))
+            "total_num of samplers: {}, {}".format(len(self.contents), path))
 
     def __len__(self):
         return len(self.contents)
 
     def __getitem__(self, index):
+        try:
             data = self.contents[index]
+        except:
+            print(index)
 
         return data
 
     def get_source_len(self, data_dict):
@@ -144,84 +187,95 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
         return data_dict.get("target_len", 0)
 
-@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
-class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
-
-    def __init__(self, path: str, **kwargs):
-        super().__init__()
-        self.max_source_length = kwargs.get("max_source_length", 2048)
-        self.min_source_length = kwargs.get("min_source_length", 0)
-        self.max_target_length = kwargs.get("max_target_length", 2048)
-        self.min_target_length = kwargs.get("min_target_length", 0)
-
-        with open(path, encoding='utf-8') as fin:
-            file_list = fin.readlines()
-
-        total_num = len(file_list)
-        try:
-            rank = dist.get_rank()
-            world_size = dist.get_world_size()
-        except:
-            rank = 0
-            world_size = 1
-            logging.warning("distributed is not initialized, only single shard")
-        num_per_rank = total_num // world_size
-        if num_per_rank * world_size < total_num:
-            logging.warning(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
-
-        file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]
-
-        contents = []
-        for file_json in file_list_rank:
-
-            with open(file_json.strip(), encoding='utf-8') as fin:
-                for line in fin:
-                    data = json.loads(line.strip())
-                    if "text" in data:  # for sft
-                        contents.append(data['text'])
-                    if "source" in data:  # for speech lab pretrain
-                        prompt = data.get("prompt", "<ASR>")
-                        source = data["source"].replace("/cpfs01", "/cpfs_speech/data")
-                        target = data["target"]
-                        source_len = data.get("source_len", 1)
-                        target_len = data.get("target_len", 0)
-
-                        if source_len < self.min_source_length or source_len > self.max_source_length:
-                            continue
-                        if target_len < self.min_target_length or target_len > self.max_target_length:
-                            continue
-                        contents_i = {"source": source,
-                                      "prompt": prompt,
-                                      "target": target,
-                                      "source_len": source_len,
-                                      "target_len": target_len,
-                                      }
-                        text_language = data.get("text_language", None)
-                        if text_language is not None:
-                            contents_i["text_language"] = text_language
-                        audio_language = data.get("audio_language", None)
-                        if audio_language is not None:
-                            contents_i["audio_language"] = audio_language
-                        contents.append(contents_i)
-
-        self.contents = contents
-
-        logging.info(f"total_num: {len(self.contents)} of samplers in ranks: {rank}")
-
-    def __len__(self):
-        return len(self.contents)
-
-    def __getitem__(self, index):
-        try:
-            data = self.contents[index]
-        except:
-            print(index)
-        return data
-
-    def get_source_len(self, data_dict):
-        return data_dict.get("source_len", 1)
-
-    def get_target_len(self, data_dict):
-
-        return data_dict.get("target_len", 0)
+#
+# @tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
+# class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
+#
+#     def __init__(self, path: str, **kwargs):
+#         super().__init__()
+#         logging.info("building IndexDS")
+#         self.max_source_length = kwargs.get("max_source_length", 2048)
+#         self.min_source_length = kwargs.get("min_source_length", 0)
+#         self.max_target_length = kwargs.get("max_target_length", 2048)
+#         self.min_target_length = kwargs.get("min_target_length", 0)
+#
+#         data_split_num = kwargs.get("data_split_num", 1)
+#         data_split_i = kwargs.get("data_split_i", 0)
+#         if not kwargs.get("is_training", True):
+#             data_split_num = 1
+#             data_split_i = 0
+#         with open(path, encoding='utf-8') as fin:
+#             file_list_all = fin.readlines()
+#
+#         num_per_slice = len(file_list_all) // data_split_num
+#         file_list = file_list_all[data_split_i * num_per_slice:(data_split_i + 1) * num_per_slice]
+#         logging.info(f"data_split_num: {data_split_num}, data_split_i: {data_split_i}, file_list: {file_list}, file_list_all: {file_list_all}")
+#
+#         total_num = len(file_list)
+#         try:
+#             rank = dist.get_rank()
+#             world_size = dist.get_world_size()
+#         except:
+#             rank = 0
+#             world_size = 1
+#             logging.info("distributed is not initialized, only single shard")
+#         num_per_rank = total_num // world_size
+#         if num_per_rank * world_size < total_num:
+#             logging.info(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
+#
+#         file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]
+#
+#         contents = []
+#         for file_json in file_list_rank:
+#
+#             with open(file_json.strip(), encoding='utf-8') as fin:
+#                 for line in fin:
+#                     data = json.loads(line.strip())
+#                     if "text" in data:  # for sft
+#                         contents.append(data['text'])
+#                     if "source" in data:  # for speech lab pretrain
+#                         prompt = data.get("prompt", "<ASR>")
+#                         source = data["source"].replace("/cpfs01", "/cpfs_speech/data")
+#                         target = data["target"]
+#                         source_len = data.get("source_len", 1)
+#                         target_len = data.get("target_len", 0)
+#
+#                         if source_len < self.min_source_length or source_len > self.max_source_length:
+#                             continue
+#                         if target_len < self.min_target_length or target_len > self.max_target_length:
+#                             continue
+#                         contents_i = {"source": source,
+#                                       "prompt": prompt,
+#                                       "target": target,
+#                                       "source_len": source_len,
+#                                       "target_len": target_len,
+#                                       }
+#                         text_language = data.get("text_language", None)
+#                         if text_language is not None:
+#                             contents_i["text_language"] = text_language
+#                         # audio_language = data.get("audio_language", None)
+#                         # if audio_language is not None:
+#                         #     contents_i["audio_language"] = audio_language
+#                         contents.append(contents_i)
+#
+#         self.contents = contents
+#
+#         logging.info(f"total_num: {len(self.contents)} of samplers in ranks: {rank}, file_list_rank: {file_list_rank}")
+#
+#     def __len__(self):
+#         return len(self.contents)
+#
+#     def __getitem__(self, index):
+#         try:
+#             data = self.contents[index]
+#         except:
+#             print(index)
+#         return data
+#
+#     def get_source_len(self, data_dict):
+#         return data_dict.get("source_len", 1)
+#
+#     def get_target_len(self, data_dict):
+#
+#         return data_dict.get("target_len", 0)

View File

@@ -301,6 +301,7 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
         batch_type="token",
         num_replicas=None,
         rank=None,
+        rank_split=False,
         shuffle=True,
         drop_last=False,
         is_training: bool = True,
@@ -314,6 +315,12 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
         except:
             rank = 0
             num_replicas = 1
+
+        if rank_split:
+            logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank")
+            rank = 0
+            num_replicas = 1
+
         self.rank = rank
         self.num_replicas = num_replicas
         self.dataset = dataset
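The new rank_split flag tells the batch sampler that the data were already sharded per rank at the file level by the index dataset above, so it should batch and shuffle only its local shard instead of re-sharding globally. A minimal illustration of the fallback (the rank and world-size values are assumed):

rank, num_replicas = 3, 4   # assumed values from torch.distributed
rank_split = True

if rank_split:
    # Files were split per rank upstream; sample over the local shard only.
    rank, num_replicas = 0, 1

print(f"sampler shards with rank={rank}, num_replicas={num_replicas}")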

View File

@@ -40,7 +40,21 @@ class DataloaderMapStyle:
         self.dataset_val = dataset_val
         self.kwargs = kwargs
 
-    def build_iter(self, epoch=0):
+        # split dataset
+        self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
+        self.dataset_class = dataset_class
+        self.frontend = frontend
+        self.tokenizer = tokenizer
+        self.kwargs = kwargs
+
+    def build_iter(self, epoch=0, data_split_i=0, **kwargs):
+
+        # reload dataset slice
+        if self.data_split_num > 1:
+            del self.dataset_tr
+            self.dataset_tr = self.dataset_class(self.kwargs.get("train_data_set_list"), frontend=self.frontend, tokenizer=self.tokenizer,
+                                                 is_training=True, **self.kwargs.get("dataset_conf"), data_split_i=data_split_i)
+
         # dataloader
         batch_sampler = self.kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
         batch_sampler_val = None
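Putting the pieces together, the knobs live under dataset_conf, which DataloaderMapStyle reads above. A hedged configuration sketch: data_split_num and batch_sampler are read from dataset_conf in this diff; placing rank_split and index_ds here, and the concrete values, are assumptions for illustration only:

dataset_conf = {
    "index_ds": "IndexDSJsonlRankSplit",   # registered alias of IndexDSJsonlRankFull above
    "batch_sampler": "CustomDistributedBufferDynamicBatchSampler",
    "data_split_num": 4,     # build_iter reloads 1/4 of the jsonl list per inner loop
    "rank_split": True,      # shard jsonl files per rank; the sampler then batches locally
}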

View File

@@ -245,29 +245,7 @@ class SenseVoiceDecoder(nn.Module):
             self.register_buffer("mask", mask, persistent=False)
         self.use_padmask = kwargs.get("use_padmask", True)
 
-    # def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-    #     """
-    #     x : torch.LongTensor, shape = (batch_size, <= n_ctx)
-    #         the text tokens
-    #     xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
-    #         the encoded audio features to be attended on
-    #     """
-    #     offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-    #     x = (
-    #         self.token_embedding(x)
-    #         + self.positional_embedding[offset: offset + x.shape[-1]]
-    #     )
-    #     x = x.to(xa.dtype)
-    #
-    #     for block in self.blocks:
-    #         x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
-    #
-    #     x = self.ln(x)
-    #     logits = (
-    #         x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-    #     ).float()
-    #
-    #     return logits
 
     def forward(

View File

@@ -252,6 +252,7 @@ class Trainer:
         dataloader_val=None,
         epoch=None,
         writer=None,
+        **kwargs,
     ):
         """
         Defines the training process for a single epoch with gradient accumulation.
@@ -374,6 +375,8 @@ class Trainer:
                     stats=stats,
                     writer=writer,
                     tag="train",
+                    data_split_i=kwargs.get("data_split_i", 0),
+                    data_split_num=kwargs.get("data_split_num", 1),
                 )
 
             if (batch_idx + 1) % self.validate_interval == 0:
@@ -507,6 +510,9 @@ class Trainer:
         stats=None,
         writer=None,
         tag="train",
+        data_split_i=0,
+        data_split_num=1,
+        **kwargs,
     ):
 
         if (batch_idx + 1) % self.log_interval == 0:
@@ -526,6 +532,7 @@ class Trainer:
                     f"{tag}, "
                     f"rank: {self.local_rank}, "
                     f"epoch: {epoch}/{self.max_epoch}, "
+                    f"data_slice: {data_split_i}/{data_split_num}, "
                    f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
                     f"(loss_avg_rank: {loss:.3f}), "
                     f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "