游雁 2024-02-22 23:52:22 +08:00
parent aaf61bbb64
commit 54b6ff5764
12 changed files with 1094 additions and 9 deletions

View File

View File

@@ -0,0 +1,130 @@
import torch
import copy
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
@tables.register("dataset_classes", "AudioLLMDataset")
class AudioLLMDataset(torch.utils.data.Dataset):
"""
AudioLLMDataset
"""
def __init__(self,
path,
index_ds: str = None,
frontend=None,
tokenizer=None,
int_pad_value: int = -1,
float_pad_value: float = 0.0,
**kwargs):
super().__init__()
index_ds_class = tables.index_ds_classes.get(index_ds)
self.index_ds = index_ds_class(path, **kwargs)
preprocessor_speech = kwargs.get("preprocessor_speech", None)
if preprocessor_speech:
preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
self.preprocessor_speech = preprocessor_speech
preprocessor_text = kwargs.get("preprocessor_text", None)
if preprocessor_text:
preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
self.preprocessor_text = preprocessor_text
self.frontend = frontend
self.fs = 16000 if frontend is None else frontend.fs
self.data_type = "sound"
self.tokenizer = tokenizer
self.int_pad_value = int_pad_value
self.float_pad_value = float_pad_value
self.IGNORE_INDEX = -100  # label value ignored by the LLM loss (see labels_ids below)
self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
self.prompt)  # "USER: \nINSTRUCTION: {}\nINPUT: {}\nASSISTANT: "
self.prompt_af = ""
def get_source_len(self, index):
item = self.index_ds[index]
return self.index_ds.get_source_len(item)
def get_target_len(self, index):
item = self.index_ds[index]
return self.index_ds.get_target_len(item)
def __len__(self):
return len(self.index_ds)
def __getitem__(self, index):
item = self.index_ds[index]
# import pdb;
# pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
speech = speech.squeeze(0)
target = item["target"]
if self.preprocessor_text:
target = self.preprocessor_text(target)
prompt_ids_pre = self.tokenizer.encode(self.prompt_pre) # [bos,prompt]
prompt_pre_length = len(prompt_ids_pre)
prompt_input = "{}{}".format(self.prompt_pre, target)
prompt_input_ids = self.tokenizer.encode(prompt_input)
audio_length = len(prompt_input_ids) - prompt_pre_length
input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
input_ids = torch.tensor(input_ids, dtype=torch.int64) #[bos, prompt, input, pad]
input_ids[prompt_pre_length:] = -1 # [bos, prompt,-1,-1]
attention_mask = input_ids.ge(-1) # [true, true, true, true], length mask
prompt_answer = "{}{}".format(self.prompt_pre, target)
prompt_answer_ids = self.tokenizer.encode(prompt_answer)
answer_length = len(prompt_answer_ids) - prompt_pre_length
labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
labels_ids = torch.tensor(labels_ids, dtype=torch.int64) # [bos, prompt, input, eos]
labels_ids[:prompt_pre_length] = -1 # [-1, -1, input, eos]
label_mask = labels_ids.ge(0) # [False,False,True,True]
labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,input,eos]
audio_mask = [0] * prompt_pre_length + [1] * audio_length
audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
ids = self.tokenizer.encode(target)
text = torch.tensor(ids, dtype=torch.int64)
text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
return {"speech": speech,
"speech_lengths": speech_lengths,
"text": text,
"text_lengths": text_lengths,
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels_ids": labels_ids,
"label_mask": label_mask,
"audio_mask": audio_mask,
}
def collator(self, samples: list=None):
outputs = {}
for sample in samples:
for key in sample.keys():
if key not in outputs:
outputs[key] = []
outputs[key].append(sample[key])
for key, data_list in outputs.items():
if isinstance(data_list[0], torch.Tensor):
if data_list[0].dtype == torch.int64:
pad_value = self.int_pad_value
else:
pad_value = self.float_pad_value
outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
return outputs
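Usage note: a minimal sketch of how this dataset and its collator would be wired into a DataLoader. The frontend construction, checkpoint id and jsonl path below are assumptions for illustration, not part of this commit.

from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from funasr.register import tables

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")           # placeholder checkpoint
frontend = tables.frontend_classes.get("WavFrontend")(fs=16000, n_mels=80)  # assumed WavFrontend args

dataset = tables.dataset_classes.get("AudioLLMDataset")(
    "audio_datasets.jsonl",      # jsonl produced by scp2jsonl (see below)
    index_ds="IndexDSJsonl",
    frontend=frontend,
    tokenizer=tokenizer,
)
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collator, shuffle=True)
batch = next(iter(loader))       # padded dict: speech, input_ids, labels_ids, audio_mask, ...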

View File

@@ -0,0 +1,51 @@
import os
import json
import torch
import logging
import concurrent.futures
import librosa
import torch.distributed as dist
from typing import Collection
import torchaudio
from torch import nn
import random
import re
from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables
@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
def __init__(self, speed_perturb: list=None, **kwargs):
super().__init__()
self.speed_perturb = speed_perturb
def forward(self, waveform, fs, **kwargs):
if self.speed_perturb is None:
return waveform
speed = random.choice(self.speed_perturb)
if speed != 1.0:
if not isinstance(waveform, torch.Tensor):
waveform = torch.tensor(waveform)
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
waveform.view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
waveform = waveform.view(-1)
return waveform
@tables.register("preprocessor_classes", "TextPreprocessSegDict")
class TextPreprocessSegDict(nn.Module):
def __init__(self, seg_dict: str = None,
text_cleaner: Collection[str] = None,
split_with_space: bool = False,
**kwargs):
super().__init__()
self.text_cleaner = TextCleaner(text_cleaner)
def forward(self, text, **kwargs):
text = self.text_cleaner(text)
return text
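Usage note: a quick sanity check of the speed perturbation above; at speed 0.9 the output is roughly 1/0.9 times longer. Requires a torchaudio build with the sox backend.

import torch

perturb = SpeechPreprocessSpeedPerturb(speed_perturb=[0.9, 1.0, 1.1])
wav = torch.randn(16000)        # 1 s of fake 16 kHz audio
out = perturb(wav, fs=16000)    # length is about 16000 / chosen speed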

View File

@@ -0,0 +1,277 @@
import torch
import numpy as np
import logging
import torch.distributed as dist
from funasr.register import tables
@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):
def __init__(self, dataset,
batch_type: str = "example",
batch_size: int = 100,
buffer_size: int = 30,
drop_last: bool = False,
shuffle: bool = True,
is_training: bool = True,
**kwargs):
self.drop_last = drop_last
self.pre_idx = -1
self.dataset = dataset
self.total_samples = len(dataset)
self.batch_type = batch_type
self.batch_size = int(batch_size)
self.buffer_size = buffer_size
self.max_token_length = kwargs.get("max_token_length", 5000)
self.shuffle_idx = np.arange(self.total_samples)
self.shuffle = shuffle and is_training
self.length_scale_source = kwargs.get("length_scale_source", 1.0)
def __len__(self):
return (self.total_samples-1) // self.batch_size + 1
def set_epoch(self, epoch):
np.random.seed(epoch)
def __iter__(self):
if self.shuffle:
np.random.shuffle(self.shuffle_idx)
batch = []
max_token = 0
num_sample = 0
iter_num = (self.total_samples - 1) // self.buffer_size + 1
# print("iter_num: ", iter_num)
for iter in range(self.pre_idx + 1, iter_num):
datalen_with_index = []
for i in range(self.buffer_size):
idx = iter * self.buffer_size + i
if idx >= self.total_samples:
continue
idx_map = self.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
sample_len_cur = source_len + target_len
datalen_with_index.append([idx, sample_len_cur])
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
for item in datalen_with_index_sort:
idx, sample_len_cur_raw = item
if sample_len_cur_raw > self.max_token_length:
continue
max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample
if self.batch_type != 'example':
max_token_padding *= max_token_cur
if max_token_padding <= self.batch_size:
batch.append(idx)
max_token = max_token_cur
num_sample += 1
else:
yield batch
batch = [idx]
max_token = sample_len_cur_raw
num_sample = 1
@tables.register("batch_sampler_classes", "BatchSampler")
@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
def __init__(self, dataset,
batch_type: str = "example",
batch_size: int = 100,
buffer_size: int = 30,
drop_last: bool = True,
shuffle: bool = True,
is_training: bool = True,
**kwargs):
self.drop_last = drop_last
self.pre_idx = -1
self.dataset = dataset
self.total_samples = len(dataset)
self.batch_type = batch_type
self.batch_size = int(batch_size)
self.buffer_size = buffer_size
self.max_token_length = kwargs.get("max_token_length", 1500)
self.shuffle_idx = np.arange(self.total_samples)
self.shuffle = shuffle and is_training
self.length_scale_source = kwargs.get("length_scale_source", 1.0)
try:
rank = dist.get_rank()
world_size = dist.get_world_size()
except Exception:
rank = 0
world_size = 1
self.rank = rank
self.world_size = world_size
def __len__(self):
return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
def set_epoch(self, epoch):
np.random.seed(epoch)
def __iter__(self):
batch_size_total = self.batch_size * self.world_size
if self.shuffle:
np.random.shuffle(self.shuffle_idx)
batch = []
max_token = 0
num_sample = 0
iter_num = (self.total_samples - 1) // self.buffer_size + 1
# print("iter_num: ", iter_num)
for iter in range(self.pre_idx + 1, iter_num):
# if iter == iter_num -1 and self.drop_last:
# continue
datalen_with_index = []
for i in range(self.buffer_size):
idx = iter * self.buffer_size + i
if idx >= self.total_samples:
continue
idx_map = self.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
sample_len_cur = source_len + target_len
datalen_with_index.append([idx, sample_len_cur])
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
for item in datalen_with_index_sort:
idx, sample_len_cur_raw = item
if sample_len_cur_raw > self.max_token_length:
continue
max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample
# if self.batch_type != 'example':
# max_token_padding *= max_token_cur
if max_token_padding <= batch_size_total:
batch.append(idx)
max_token = max_token_cur
num_sample += 1
else:
batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
yield batch_rank
batch = [idx]
max_token = sample_len_cur_raw
num_sample = 1
@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
def __init__(self, dataset,
batch_type: str = "example",
batch_size: int = 100,
buffer_size: int = 30,
drop_last: bool = True,
shuffle: bool = True,
is_training: bool = True,
**kwargs):
self.drop_last = drop_last
self.pre_idx = -1
self.dataset = dataset
self.total_samples = len(dataset)
self.batch_type = batch_type
self.batch_size = int(batch_size)
self.buffer_size = buffer_size
self.max_token_length = kwargs.get("max_token_length", 1500)
self.shuffle_idx = np.arange(self.total_samples)
self.shuffle = shuffle and is_training
self.length_scale_source = kwargs.get("length_scale_source", 1.0)
try:
rank = dist.get_rank()
world_size = dist.get_world_size()
except Exception:
rank = 0
world_size = 1
self.rank = rank
self.world_size = world_size
def __len__(self):
return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
def set_epoch(self, epoch):
np.random.seed(epoch)
def __iter__(self):
batch_size_total = self.batch_size * self.world_size
if self.shuffle:
np.random.shuffle(self.shuffle_idx)
batch_list_all_rank = []
batch_list_cur = []
max_token = 0
num_sample = 0
iter_num = (self.total_samples - 1) // self.buffer_size + 1
# print("iter_num: ", iter_num)
for iter in range(self.pre_idx + 1, iter_num):
# if iter == iter_num - 1 and self.drop_last:
# continue
datalen_with_index = []
for i in range(self.buffer_size):
idx = iter * self.buffer_size + i
if idx >= self.total_samples:
continue
idx_map = self.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
sample_len_cur = source_len + target_len
datalen_with_index.append([idx, sample_len_cur])
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
for ii, item in enumerate(datalen_with_index_sort):
is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort) - 1
idx, sample_len_cur_raw = item
if sample_len_cur_raw > self.max_token_length:
continue
max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample
if self.batch_type != 'example':
max_token_padding *= max_token_cur
if len(batch_list_all_rank) < self.world_size:
if max_token_padding <= self.batch_size:
batch_list_cur.append(idx)
max_token = max_token_cur
num_sample += 1
else:
batch_list_all_rank.append(batch_list_cur)
batch_list_cur = [idx]
max_token = sample_len_cur_raw
num_sample = 1
else:
batch_rank = batch_list_all_rank[self.rank]
yield batch_rank
batch_list_all_rank = []
batch_list_cur = [idx]
max_token = sample_len_cur_raw
num_sample = 1
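Usage note: these samplers yield lists of dataset indices, so they are passed to a DataLoader via batch_sampler rather than batch_size; a minimal sketch, assuming the AudioLLMDataset instance from the first file of this commit.

from torch.utils.data import DataLoader
from funasr.register import tables

sampler_class = tables.batch_sampler_classes.get("DynamicBatchLocalShuffleSampler")
sampler = sampler_class(dataset, batch_type="length", batch_size=6000,
                        buffer_size=500, max_token_length=2048)
sampler.set_epoch(0)    # reseed the local shuffle at the start of each epoch
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=dataset.collator)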

View File

@@ -0,0 +1,96 @@
import os
import json
import torch
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
import concurrent.futures
import librosa
import torch.distributed as dist
def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out:str=None, **kwargs):
try:
rank = dist.get_rank()
world_size = dist.get_world_size()
except Exception:
rank = 0
world_size = 1
cpu_cores = os.cpu_count() or 1
print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
if rank == 0:
json_dict = {}
for data_type, data_file in zip(data_type_list, path):
json_dict[data_type] = {}
with open(data_file, "r") as f:
data_file_lists = f.readlines()
lines_for_each_th = (len(data_file_lists)-1)//cpu_cores + 1
task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
futures = [executor.submit(parse_context_length, data_file_lists[i*lines_for_each_th:(i+1)*lines_for_each_th], data_type) for i in range(task_num)]
for future in concurrent.futures.as_completed(futures):
json_dict[data_type].update(future.result())
# print(json_dict)
with open(jsonl_file_out, "w") as f:
for key in json_dict[data_type_list[0]].keys():
jsonl_line = {"key": key}
for data_file in data_type_list:
jsonl_line.update(json_dict[data_file][key])
jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
f.write(jsonl_line+"\n")
f.flush()
else:
pass
if world_size > 1:
dist.barrier()
def parse_context_length(data_list: list, data_type: str):
res = {}
for i, line in enumerate(data_list):
key, line = line.strip().split(maxsplit=1)
line = line.strip()
if os.path.exists(line):
waveform, _ = librosa.load(line, sr=16000)
sample_num = len(waveform)
context_len = int(sample_num//16000*1000/10)
else:
context_len = len(line.split()) if " " in line else len(line)
res[key] = {data_type: line, f"{data_type}_len": context_len}
return res
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
kwargs = OmegaConf.to_container(cfg, resolve=True)
scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
if isinstance(scp_file_list, str):
scp_file_list = eval(scp_file_list)
data_type_list = kwargs.get("data_type_list", ("source", "target"))
jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl")
gen_jsonl_from_wav_text_list(scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out)
"""
python -m funasr.datasets.audio_datasets.scp2jsonl \
++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
++data_type_list='["source", "target"]' \
++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
"""
if __name__ == "__main__":
main_hydra()
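Output note: each jsonl line pairs one utterance key with the source/target entries and their computed lengths (audio length in 10 ms frames, text length in words or characters). An illustrative line, with made-up values, and a quick read-back:

# {"key": "utt_0001", "source": "/data/wav/utt_0001.wav", "source_len": 512,
#  "target": "hello world", "target_len": 2}
import json
with open("audio_datasets.jsonl") as f:
    first = json.loads(f.readline())
print(first["key"], first["source_len"], first["target_len"])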

View File

@@ -21,3 +21,22 @@ def th_accuracy(pad_outputs, pad_targets, ignore_label):
)
denominator = torch.sum(mask)
return float(numerator) / float(denominator)
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Args:
pad_outputs (LongTensor): Prediction tensors (B, Lmax).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
)
denominator = torch.sum(mask)
return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type
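Usage note: a toy check of the masking behaviour, using -100 as the ignore label (the value the LLM loss ignores elsewhere in this commit).

import torch

preds   = torch.tensor([[1, 2, 3, 4]])
targets = torch.tensor([[1, 2, 9, -100]])   # last position is ignored
acc = compute_accuracy(preds, targets, ignore_label=-100)
print(acc)                                  # tensor(0.6667): 2 of 3 valid positions match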

View File

View File

@@ -0,0 +1,29 @@
import torch
import torch.nn as nn
from funasr.register import tables
@tables.register("adaptor_classes", "Linear")
class Linear(nn.Module):
def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
super().__init__()
self.k = downsample_rate
self.encoder_dim = encoder_dim
self.llm_dim = llm_dim
self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
def forward(self, x):
batch_size, seq_len, dim = x.size()
num_frames_to_discard = seq_len % self.k
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(batch_size, seq_len // self.k, dim * self.k)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
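Shape note: with downsample_rate k, the adaptor turns T frames of encoder_dim features into roughly T/k frames of llm_dim features; the dimensions below are small placeholders.

import torch

adaptor = Linear(downsample_rate=2, encoder_dim=512, llm_dim=4096)
x = torch.randn(3, 101, 512)    # (batch, frames, encoder_dim); one trailing frame is discarded
y = adaptor(x)
print(y.shape)                  # torch.Size([3, 50, 4096])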

View File

@@ -0,0 +1,353 @@
import logging
from typing import Union, Dict, List, Tuple, Optional
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy, compute_accuracy
# from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
@tables.register("model_classes", "LLMASR")
class LLMASR(nn.Module):
""" """
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
llm: str = None,
llm_conf: dict = None,
adaptor: str = None,
adaptor_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
# extract_feats_in_collect_stats: bool = True,
share_embedding: bool = False,
# preencoder: Optional[AbsPreEncoder] = None,
# postencoder: Optional[AbsPostEncoder] = None,
**kwargs,
):
super().__init__()
if specaug is not None:
specaug_class = tables.specaug_classes.get(specaug)
specaug = specaug_class(**specaug_conf)
if normalize is not None:
normalize_class = tables.normalize_classes.get(normalize)
normalize = normalize_class(**normalize_conf)
# audio encoder
hub = encoder_conf.get("hub", None)
if hub == "funasr":
from funasr import AutoModel
from funasr.models.scama.utils import sequence_mask
init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
model = AutoModel(model=init_param_path, model_revision="v2.0.4")
frontend = model.kwargs.get("frontend")
model.model.decoder = None
self.model = model.model
self.frontend = frontend
self.mask_fn = sequence_mask
elif hub == "hf":
pass
else:
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()
# llm
hub = llm_conf.get("hub", "hf")
self.llm = None
if hub == "hf":
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
model = AutoModelForCausalLM.from_pretrained(
init_param_path,
load_in_8bit=None,
device_map=None,
use_cache=None,
)
freeze_llm = llm_conf.get("freeze_llm", True)
if freeze_llm:
for name, param in model.named_parameters():
param.requires_grad = False
model.eval()
self.llm = model
# adaptor
adaptor_class = tables.adaptor_classes.get(adaptor)
adaptor = adaptor_class(**adaptor_conf)
self.adaptor = adaptor
self.blank_id = blank_id
self.sos = sos if sos is not None else vocab_size - 1
self.eos = eos if eos is not None else vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.specaug = specaug
self.normalize = normalize
self.encoder = encoder
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
#
# if report_cer or report_wer:
# self.error_calculator = ErrorCalculator(
# token_list, sym_space, sym_blank, report_cer, report_wer
# )
#
self.error_calculator = None
self.length_normalized_loss = length_normalized_loss
self.metric = kwargs.get("metric", True)  # whether to compute token accuracy during training
self.beam_search = None
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask:torch.Tensor,
labels_ids:torch.Tensor,
label_mask: torch.Tensor,
audio_mask:torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
# adaptor
encoder_out = self.adaptor(encoder_out)
if input_ids is not None:
input_ids[input_ids == -1] = 0
if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(input_ids)
elif hasattr(self.llm.model.model, "embed_tokens"):
inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
else:
inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
if audio_mask is not None:
batch_size, token_num, dims = inputs_embeds.shape
_, l, _ = encoder_out.shape
encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0 - audio_mask[:, :, None])
inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
loss = model_outputs.loss
acc_att = -1
if self.metric:
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
stats = {}
# Collect Attn branch stats
stats["acc"] = acc_att.detach()
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
audio_mask = kwargs.get("audio_mask")
audio_token_lengths = audio_mask.sum(-1)
batch = {"speech": speech, "speech_lengths": speech_lengths}
enc, enc_lens = self.model.encode(**batch)
enc_mask = self.mask_fn(enc_lens, enc.size(1), device=enc.device)[:, None, :]
pre_acoustic_embeds, pre_token_length, _, _ = self.model.predictor(enc,
mask=enc_mask,
target_label_length=audio_token_lengths,
)
return pre_acoustic_embeds, pre_token_length
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att
def inference(self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
if self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = speech.shape[1]
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
b, n, d = encoder_out.size()
for i in range(b):
for nbest_idx, hyp in enumerate(nbest_hyps):
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
text = tokenizer.tokens2text(token)
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
result_i = {"key": key[i], "token": token, "text": text_postprocessed}
results.append(result_i)
if ibest_writer is not None:
ibest_writer["token"][key[i]] = " ".join(token)
ibest_writer["text"][key[i]] = text_postprocessed
return results, meta_data
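Implementation note: the core of forward above is splicing the adaptor output into the token embeddings at the positions marked by audio_mask. A self-contained toy illustration of that padding/mixing arithmetic, with made-up shapes and random tensors in place of real model outputs:

import torch
import torch.nn.functional as F

batch, token_num, dims = 1, 10, 4
prompt_len, audio_len = 3, 6                         # audio occupies positions 3..8, last slot stays pad
inputs_embeds = torch.randn(batch, token_num, dims)  # token embeddings from the LLM
encoder_out = torch.randn(batch, audio_len, dims)    # adaptor output of length l

audio_mask = torch.tensor([[0.0] * prompt_len + [1.0] * audio_len + [0.0]])
# left-pad the encoder output so its frames line up with the audio positions
encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num - audio_len - 1, 1, 0, 0), value=0.0)
mixed = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0 - audio_mask[:, :, None])
print(mixed.shape)                                   # torch.Size([1, 10, 4])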

View File

@@ -0,0 +1,90 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: LLMASR
model_conf:
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: true
# encoder
encoder: Paraformer
encoder_conf:
hub: funasr
init_param_path: "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
llm: Vicuna
llm_conf:
hub: hf
init_param_path: null
freeze_llm: true
adaptor: Linear
adaptor_conf:
downsample_rate: 1
llm_dim: 4096
encoder_dim: 2048
# frontend related
frontend: WavFrontend
frontend_conf:
fs: 16000
window: hamming
n_mels: 80
frame_length: 25
frame_shift: 10
dither: 0.0
lfr_m: 1
lfr_n: 1
specaug: SpecAug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 30
num_freq_mask: 2
apply_time_mask: true
time_mask_width_range:
- 0
- 40
num_time_mask: 2
train_conf:
accum_grad: 1
grad_clip: 5
max_epoch: 150
keep_nbest_models: 10
log_interval: 50
optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 35000
dataset: AudioLLMDataset
dataset_conf:
index_ds: IndexDSJsonl
batch_sampler: RankFullLocalShuffleBatchSampler
batch_type: example # example or length
batch_size: 4 # if batch_type is example, batch_size is the number of samples per batch; if length, it is the token budget (source_token_len + target_token_len) per batch
max_token_length: 2048 # filter out samples whose source_token_len + target_token_len exceeds max_token_length
buffer_size: 500
shuffle: True
num_workers: 4
tokenizer: HuggingfaceTokenizer
tokenizer_conf:
unk_symbol: <unk>
init_param_path: null
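Usage note: this YAML is consumed through Hydra/OmegaConf like the other FunASR configs; a minimal sketch that loads it and inspects the dataset section (the file path is a placeholder).

from omegaconf import OmegaConf

cfg = OmegaConf.load("conf/llm_asr_vicuna.yaml")   # placeholder path to this file
print(cfg.model, cfg.dataset)                      # LLMASR AudioLLMDataset
print(OmegaConf.to_container(cfg.dataset_conf, resolve=True))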

View File

@@ -0,0 +1,15 @@
try:
from transformers import AutoTokenizer
except ImportError:
print("If you want to use the Huggingface tokenizer, please run `pip install -U transformers`")
from funasr.register import tables
@tables.register("tokenizer_classes", "HuggingfaceTokenizer")
def HuggingfaceTokenizer(init_param_path, **kwargs):
tokenizer = AutoTokenizer.from_pretrained(init_param_path)
return tokenizer
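Usage note: the registered tokenizer is simply the Hugging Face AutoTokenizer, so the usual encode/eos_token_id interface applies; the checkpoint id below is a placeholder.

tokenizer = HuggingfaceTokenizer("lmsys/vicuna-7b-v1.5")   # placeholder init_param_path
ids = tokenizer.encode("USER: \nINSTRUCTION: Transcribe speech to text.\nINPUT: ")
print(len(ids), tokenizer.eos_token_id)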

View File

@@ -5,7 +5,8 @@ import logging
from tqdm import tqdm
from datetime import datetime
import torch.distributed as dist
from contextlib import nullcontext
from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext, contextmanager
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from pathlib import Path
@@ -14,6 +15,14 @@ from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
@contextmanager
def maybe_autocast(enabled):
if enabled:
with autocast():
yield
else:
yield
class Trainer:
"""
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
@@ -36,8 +45,9 @@ class Trainer:
dataloader_train,
dataloader_val,
local_rank,
use_ddp=False,
use_fsdp=False,
use_ddp: bool = False,
use_fsdp: bool = False,
use_fp16: bool = False,
output_dir: str="./",
**kwargs):
"""
@@ -72,6 +82,9 @@
self.kwargs = kwargs
self.log_interval = kwargs.get("log_interval", 50)
self.batch_total = 0
self.use_fp16 = use_fp16
self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
self.scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
try:
@@ -103,6 +116,8 @@
'optimizer': self.optim.state_dict(),
'scheduler': self.scheduler.state_dict(),
}
if self.scaler:
state["scaler_state"] = self.scaler.state_dict()
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
@@ -141,6 +156,8 @@
self.model.load_state_dict(dst_state)
self.optim.load_state_dict(checkpoint['optimizer'])
self.scheduler.load_state_dict(checkpoint['scheduler'])
if self.scaler and 'scaler_state' in checkpoint:
self.scaler.load_state_dict(checkpoint['scaler_state'])
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
print(f"No checkpoint found at '{ckpt}', starting from scratch")
@@ -221,9 +238,10 @@
my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
with my_context():
time2 = time.perf_counter()
retval = self.model(**batch)
torch.cuda.empty_cache()
with maybe_autocast(self.use_fp16):
retval = self.model(**batch)
if self.disable_gpu_cache: torch.cuda.empty_cache()
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
@@ -241,7 +259,10 @@
loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss / accum_grad
loss.backward()
if self.use_fp16:
self.scaler.scale(loss).backward()
else:
loss.backward()
time4 = time.perf_counter()
speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
@@ -264,10 +285,14 @@
# Execute an optimization step (update model parameters)
if self.use_ddp or self.use_fsdp:
dist.barrier()
self.optim.step()
if self.use_fp16:
self.scaler.step(self.optim)
self.scaler.update()
else:
self.optim.step()
self.scheduler.step()
# Clear gradients for the next accumulation stage
self.optim.zero_grad()
self.optim.zero_grad(set_to_none=True)
total_time = f"{time.perf_counter() - time5:0.3f}"
time5 = time.perf_counter()
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"