From 54b6ff57647e28bbe88d8df81f2b112f127660e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 22 Feb 2024 23:52:22 +0800 Subject: [PATCH] fp16 --- funasr/datasets/llm_datasets/__init__.py | 0 funasr/datasets/llm_datasets/datasets.py | 130 +++++++ funasr/datasets/llm_datasets/preprocessor.py | 51 +++ funasr/datasets/llm_datasets/samplers.py | 277 +++++++++++++++ funasr/datasets/llm_datasets/scp2jsonl.py | 96 +++++ funasr/metrics/compute_acc.py | 19 + funasr/models/llm_asr/__init__.py | 0 funasr/models/llm_asr/adaptor.py | 29 ++ funasr/models/llm_asr/model.py | 353 +++++++++++++++++++ funasr/models/llm_asr/template.yaml | 90 +++++ funasr/tokenizer/hf_tokenizer.py | 15 + funasr/train_utils/trainer.py | 43 ++- 12 files changed, 1094 insertions(+), 9 deletions(-) create mode 100644 funasr/datasets/llm_datasets/__init__.py create mode 100644 funasr/datasets/llm_datasets/datasets.py create mode 100644 funasr/datasets/llm_datasets/preprocessor.py create mode 100644 funasr/datasets/llm_datasets/samplers.py create mode 100644 funasr/datasets/llm_datasets/scp2jsonl.py create mode 100644 funasr/models/llm_asr/__init__.py create mode 100644 funasr/models/llm_asr/adaptor.py create mode 100644 funasr/models/llm_asr/model.py create mode 100644 funasr/models/llm_asr/template.yaml create mode 100644 funasr/tokenizer/hf_tokenizer.py diff --git a/funasr/datasets/llm_datasets/__init__.py b/funasr/datasets/llm_datasets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py new file mode 100644 index 000000000..20eb8aa7c --- /dev/null +++ b/funasr/datasets/llm_datasets/datasets.py @@ -0,0 +1,130 @@ +import torch +import copy + +from funasr.register import tables +from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video + + +@tables.register("dataset_classes", "AudioLLMDataset") +class AudioLLMDataset(torch.utils.data.Dataset): + """ + AudioLLMDataset + """ + def __init__(self, + path, + index_ds: str = None, + frontend=None, + tokenizer=None, + int_pad_value: int = -1, + float_pad_value: float = 0.0, + **kwargs): + super().__init__() + index_ds_class = tables.index_ds_classes.get(index_ds) + self.index_ds = index_ds_class(path, **kwargs) + preprocessor_speech = kwargs.get("preprocessor_speech", None) + if preprocessor_speech: + preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech) + preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf")) + self.preprocessor_speech = preprocessor_speech + preprocessor_text = kwargs.get("preprocessor_text", None) + if preprocessor_text: + preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text) + preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf")) + self.preprocessor_text = preprocessor_text + + self.frontend = frontend + self.fs = 16000 if frontend is None else frontend.fs + self.data_type = "sound" + self.tokenizer = tokenizer + + self.int_pad_value = int_pad_value + self.float_pad_value = float_pad_value + self.prompt = kwargs.get("prompt", "Transcribe speech to text.") + self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format( + self.prompt) # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: " + self.prompt_af = "" + + def get_source_len(self, index): + item = self.index_ds[index] + return self.index_ds.get_source_len(item) + + def get_target_len(self, index): + item = self.index_ds[index] + return 
self.index_ds.get_target_len(item)
+
+    def __len__(self):
+        return len(self.index_ds)
+
+    def __getitem__(self, index):
+        item = self.index_ds[index]
+        source = item["source"]
+        data_src = load_audio_text_image_video(source, fs=self.fs)
+        if self.preprocessor_speech:
+            data_src = self.preprocessor_speech(data_src, fs=self.fs)
+        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True)  # speech: [b, T, d]
+        speech = speech.squeeze(0)
+
+        target = item["target"]
+        if self.preprocessor_text:
+            target = self.preprocessor_text(target)
+
+        prompt_ids_pre = self.tokenizer.encode(self.prompt_pre)  # [bos, prompt]
+        prompt_pre_length = len(prompt_ids_pre)
+
+        prompt_input = "{}{}".format(self.prompt_pre, target)
+        prompt_input_ids = self.tokenizer.encode(prompt_input)
+        audio_length = len(prompt_input_ids) - prompt_pre_length
+        input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
+        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [bos, prompt, input, pad]
+        input_ids[prompt_pre_length:] = -1  # [bos, prompt, -1, -1]
+        attention_mask = input_ids.ge(-1)  # [True, True, True, True], length mask
+
+        prompt_answer = "{}{}".format(self.prompt_pre, target)
+        prompt_answer_ids = self.tokenizer.encode(prompt_answer)
+        answer_length = len(prompt_answer_ids) - prompt_pre_length
+        labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
+        labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, input, eos]
+        labels_ids[:prompt_pre_length] = -1  # [-1, -1, input, eos]
+        label_mask = labels_ids.ge(0)  # [False, False, True, True]
+        labels_ids[~label_mask] = -100  # [-100, -100, input, eos]
+
+        audio_mask = [0] * prompt_pre_length + [1] * audio_length
+        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
+
+        ids = self.tokenizer.encode(target)
+        text = torch.tensor(ids, dtype=torch.int64)
+        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
+
+        return {"speech": speech,
+                "speech_lengths": speech_lengths,
+                "text": text,
+                "text_lengths": text_lengths,
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "labels_ids": labels_ids,
+                "label_mask": label_mask,
+                "audio_mask": audio_mask,
+                }
+
+    def collator(self, samples: list=None):
+        outputs = {}
+        for sample in samples:
+            for key in sample.keys():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(sample[key])
+
+        for key, data_list in outputs.items():
+            if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype == torch.int64:
+                    pad_value = self.int_pad_value
+                else:
+                    pad_value = self.float_pad_value
+
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+        return outputs
diff --git a/funasr/datasets/llm_datasets/preprocessor.py b/funasr/datasets/llm_datasets/preprocessor.py
new file mode 100644
index 000000000..ab751401b
--- /dev/null
+++ b/funasr/datasets/llm_datasets/preprocessor.py
@@ -0,0 +1,51 @@
+import os
+import json
+import torch
+import logging
+import concurrent.futures
+import librosa
+import torch.distributed as dist
+from typing import Collection
+import torchaudio
+from torch import nn
+import random
+import re
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.register import tables
+
+
+@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
+class SpeechPreprocessSpeedPerturb(nn.Module):
+    def __init__(self, speed_perturb: list=None, **kwargs):
+        super().__init__()
+        
self.speed_perturb = speed_perturb + + def forward(self, waveform, fs, **kwargs): + if self.speed_perturb is None: + return waveform + speed = random.choice(self.speed_perturb) + if speed != 1.0: + if not isinstance(waveform, torch.Tensor): + waveform = torch.tensor(waveform) + waveform, _ = torchaudio.sox_effects.apply_effects_tensor( + waveform.view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]]) + waveform = waveform.view(-1) + + return waveform + + +@tables.register("preprocessor_classes", "TextPreprocessSegDict") +class TextPreprocessSegDict(nn.Module): + def __init__(self, seg_dict: str = None, + text_cleaner: Collection[str] = None, + split_with_space: bool = False, + **kwargs): + super().__init__() + + self.text_cleaner = TextCleaner(text_cleaner) + + def forward(self, text, **kwargs): + text = self.text_cleaner(text) + + return text diff --git a/funasr/datasets/llm_datasets/samplers.py b/funasr/datasets/llm_datasets/samplers.py new file mode 100644 index 000000000..914e77692 --- /dev/null +++ b/funasr/datasets/llm_datasets/samplers.py @@ -0,0 +1,277 @@ +import torch +import numpy as np +import logging +import torch.distributed as dist + +from funasr.register import tables + + +@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler") +class BatchSampler(torch.utils.data.BatchSampler): + + def __init__(self, dataset, + batch_type: str = "example", + batch_size: int = 100, + buffer_size: int = 30, + drop_last: bool = False, + shuffle: bool = True, + is_training: bool = True, + **kwargs): + + self.drop_last = drop_last + self.pre_idx = -1 + self.dataset = dataset + self.total_samples = len(dataset) + self.batch_type = batch_type + self.batch_size = int(batch_size) + self.buffer_size = buffer_size + self.max_token_length = kwargs.get("max_token_length", 5000) + self.shuffle_idx = np.arange(self.total_samples) + self.shuffle = shuffle and is_training + self.length_scale_source = kwargs.get("length_scale_source", 1.0) + + + def __len__(self): + return (self.total_samples-1) // self.batch_size + 1 + + def set_epoch(self, epoch): + np.random.seed(epoch) + + def __iter__(self): + + if self.shuffle: + np.random.shuffle(self.shuffle_idx) + + batch = [] + max_token = 0 + num_sample = 0 + + iter_num = (self.total_samples - 1) // self.buffer_size + 1 + # print("iter_num: ", iter_num) + for iter in range(self.pre_idx + 1, iter_num): + datalen_with_index = [] + for i in range(self.buffer_size): + idx = iter * self.buffer_size + i + if idx >= self.total_samples: + continue + + idx_map = self.shuffle_idx[idx] + # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] + target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0 + source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source + sample_len_cur = source_len + target_len + + + datalen_with_index.append([idx, sample_len_cur]) + + datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1]) + for item in datalen_with_index_sort: + idx, sample_len_cur_raw = item + if sample_len_cur_raw > self.max_token_length: + continue + + max_token_cur = max(max_token, sample_len_cur_raw) + max_token_padding = 1 + num_sample + if self.batch_type != 'example': + max_token_padding *= max_token_cur + if max_token_padding <= self.batch_size: + batch.append(idx) + max_token = max_token_cur + num_sample += 1 + else: + yield batch + batch = [idx] + max_token = sample_len_cur_raw + num_sample = 1 + + +@tables.register("batch_sampler_classes", "BatchSampler") 
+@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler") +class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler): + + def __init__(self, dataset, + batch_type: str = "example", + batch_size: int = 100, + buffer_size: int = 30, + drop_last: bool = True, + shuffle: bool = True, + is_training: bool = True, + **kwargs): + + self.drop_last = drop_last + self.pre_idx = -1 + self.dataset = dataset + self.total_samples = len(dataset) + self.batch_type = batch_type + self.batch_size = int(batch_size) + self.buffer_size = buffer_size + self.max_token_length = kwargs.get("max_token_length", 1500) + self.shuffle_idx = np.arange(self.total_samples) + self.shuffle = shuffle and is_training + self.length_scale_source = kwargs.get("length_scale_source", 1.0) + + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + except: + rank = 0 + world_size = 1 + self.rank = rank + self.world_size = world_size + + def __len__(self): + return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1 + + def set_epoch(self, epoch): + np.random.seed(epoch) + + def __iter__(self): + + batch_size_total = self.batch_size * self.world_size + + if self.shuffle: + np.random.shuffle(self.shuffle_idx) + + batch = [] + max_token = 0 + num_sample = 0 + + iter_num = (self.total_samples - 1) // self.buffer_size + 1 + # print("iter_num: ", iter_num) + for iter in range(self.pre_idx + 1, iter_num): + # if iter == iter_num -1 and self.drop_last: + # continue + datalen_with_index = [] + for i in range(self.buffer_size): + idx = iter * self.buffer_size + i + if idx >= self.total_samples: + continue + + idx_map = self.shuffle_idx[idx] + # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] + + source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source + target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0 + sample_len_cur = source_len + target_len + + datalen_with_index.append([idx, sample_len_cur]) + + datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1]) + for item in datalen_with_index_sort: + idx, sample_len_cur_raw = item + if sample_len_cur_raw > self.max_token_length: + continue + + max_token_cur = max(max_token, sample_len_cur_raw) + max_token_padding = 1 + num_sample + # if self.batch_type != 'example': + # max_token_padding *= max_token_cur + if max_token_padding <= batch_size_total: + batch.append(idx) + max_token = max_token_cur + num_sample += 1 + else: + batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size] + yield batch_rank + batch = [idx] + max_token = sample_len_cur_raw + num_sample = 1 + + +@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler") +class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler): + + def __init__(self, dataset, + batch_type: str = "example", + batch_size: int = 100, + buffer_size: int = 30, + drop_last: bool = True, + shuffle: bool = True, + is_training: bool = True, + **kwargs): + + self.drop_last = drop_last + self.pre_idx = -1 + self.dataset = dataset + self.total_samples = len(dataset) + self.batch_type = batch_type + self.batch_size = int(batch_size) + self.buffer_size = buffer_size + self.max_token_length = kwargs.get("max_token_length", 1500) + self.shuffle_idx = np.arange(self.total_samples) + self.shuffle = shuffle and is_training + self.length_scale_source = kwargs.get("length_scale_source", 1.0) + + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + except: 
+ rank = 0 + world_size = 1 + self.rank = rank + self.world_size = world_size + + def __len__(self): + return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1 + + def set_epoch(self, epoch): + np.random.seed(epoch) + + def __iter__(self): + + batch_size_total = self.batch_size * self.world_size + if self.shuffle: + np.random.shuffle(self.shuffle_idx) + + batch_list_all_rank = [] + batch_list_cur = [] + max_token = 0 + num_sample = 0 + + iter_num = (self.total_samples - 1) // self.buffer_size + 1 + # print("iter_num: ", iter_num) + for iter in range(self.pre_idx + 1, iter_num): + # if iter == iter_num - 1 and self.drop_last: + # continue + datalen_with_index = [] + for i in range(self.buffer_size): + idx = iter * self.buffer_size + i + if idx >= self.total_samples: + continue + + idx_map = self.shuffle_idx[idx] + # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] + + source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source + target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0 + sample_len_cur = source_len + target_len + + datalen_with_index.append([idx, sample_len_cur]) + + datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1]) + for ii, item in enumerate(datalen_with_index_sort): + is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort) + idx, sample_len_cur_raw = item + if sample_len_cur_raw > self.max_token_length: + continue + + max_token_cur = max(max_token, sample_len_cur_raw) + max_token_padding = 1 + num_sample + + if self.batch_type != 'example': + max_token_padding *= max_token_cur + if len(batch_list_all_rank) < self.world_size: + + if max_token_padding <= self.batch_size: + batch_list_cur.append(idx) + max_token = max_token_cur + num_sample += 1 + else: + batch_list_all_rank.append(batch_list_cur) + batch_list_cur = [] + else: + batch_rank = batch_list_all_rank[self.rank] + yield batch_rank + batch_list_all_rank = [idx] + max_token = sample_len_cur_raw + num_sample = 1 diff --git a/funasr/datasets/llm_datasets/scp2jsonl.py b/funasr/datasets/llm_datasets/scp2jsonl.py new file mode 100644 index 000000000..e09a84a61 --- /dev/null +++ b/funasr/datasets/llm_datasets/scp2jsonl.py @@ -0,0 +1,96 @@ +import os +import json +import torch +import logging +import hydra +from omegaconf import DictConfig, OmegaConf +import concurrent.futures +import librosa +import torch.distributed as dist + + + +def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out:str=None, **kwargs): + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + except: + rank = 0 + world_size = 1 + + cpu_cores = os.cpu_count() or 1 + print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}") + if rank == 0: + json_dict = {} + for data_type, data_file in zip(data_type_list, path): + json_dict[data_type] = {} + with open(data_file, "r") as f: + + data_file_lists = f.readlines() + lines_for_each_th = (len(data_file_lists)-1)//cpu_cores + 1 + task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1 + with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor: + + futures = [executor.submit(parse_context_length, data_file_lists[i*lines_for_each_th:(i+1)*lines_for_each_th], data_type) for i in range(task_num)] + + for future in concurrent.futures.as_completed(futures): + + json_dict[data_type].update(future.result()) + # print(json_dict) + + with open(jsonl_file_out, "w") as f: + for key in json_dict[data_type_list[0]].keys(): + 
jsonl_line = {"key": key} + for data_file in data_type_list: + jsonl_line.update(json_dict[data_file][key]) + jsonl_line = json.dumps(jsonl_line, ensure_ascii=False) + f.write(jsonl_line+"\n") + f.flush() + + else: + pass + + if world_size > 1: + dist.barrier() + + +def parse_context_length(data_list: list, data_type: str): + + res = {} + for i, line in enumerate(data_list): + key, line = line.strip().split(maxsplit=1) + line = line.strip() + if os.path.exists(line): + waveform, _ = librosa.load(line, sr=16000) + sample_num = len(waveform) + context_len = int(sample_num//16000*1000/10) + else: + context_len = len(line.split()) if " " in line else len(line) + res[key] = {data_type: line, f"{data_type}_len": context_len} + return res + + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + + kwargs = OmegaConf.to_container(cfg, resolve=True) + + scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt")) + if isinstance(scp_file_list, str): + scp_file_list = eval(scp_file_list) + data_type_list = kwargs.get("data_type_list", ("source", "target")) + jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl") + gen_jsonl_from_wav_text_list(scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out) + + +""" +python -m funasr.datasets.audio_datasets.scp2jsonl \ +++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \ +++data_type_list='["source", "target"]' \ +++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl +""" + +if __name__ == "__main__": + main_hydra() + + \ No newline at end of file diff --git a/funasr/metrics/compute_acc.py b/funasr/metrics/compute_acc.py index 9d16e1f3b..73545c0ee 100644 --- a/funasr/metrics/compute_acc.py +++ b/funasr/metrics/compute_acc.py @@ -21,3 +21,22 @@ def th_accuracy(pad_outputs, pad_targets, ignore_label): ) denominator = torch.sum(mask) return float(numerator) / float(denominator) + +def compute_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. + + Args: + pad_outputs (LongTensor): Prediction tensors (B, Lmax). + pad_targets (LongTensor): Target label tensors (B, Lmax). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). 
+ + """ + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_outputs.masked_select(mask) == pad_targets.masked_select(mask) + ) + denominator = torch.sum(mask) + return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type \ No newline at end of file diff --git a/funasr/models/llm_asr/__init__.py b/funasr/models/llm_asr/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py new file mode 100644 index 000000000..0676e7deb --- /dev/null +++ b/funasr/models/llm_asr/adaptor.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn + +from funasr.register import tables + +@tables.register("adaptor_classes", "Linear") +class Linear(nn.Module): + def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs): + super().__init__() + self.k = downsample_rate + self.encoder_dim = encoder_dim + self.llm_dim = llm_dim + self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(ffn_dim, self.llm_dim) + + def forward(self, x): + batch_size, seq_len, dim = x.size() + num_frames_to_discard = seq_len % self.k + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = x.view(batch_size, seq_len // self.k, dim * self.k) + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py new file mode 100644 index 000000000..fcb301d07 --- /dev/null +++ b/funasr/models/llm_asr/model.py @@ -0,0 +1,353 @@ +import logging +from typing import Union, Dict, List, Tuple, Optional + +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast + +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.models.ctc.ctc import CTC +from funasr.models.transformer.utils.add_sos_eos import add_sos_eos +from funasr.metrics.compute_acc import th_accuracy, compute_accuracy +# from funasr.models.e2e_asr_common import ErrorCalculator +from funasr.train_utils.device_funcs import force_gatherable +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +from funasr.utils import postprocess_utils +from funasr.utils.datadir_writer import DatadirWriter +from funasr.register import tables + + +@tables.register("model_classes", "LLMASR") +class LLMASR(nn.Module): + """ """ + + def __init__( + self, + specaug: str = None, + specaug_conf: dict = None, + normalize: str = None, + normalize_conf: dict = None, + encoder: str = None, + encoder_conf: dict = None, + decoder: str = None, + decoder_conf: dict = None, + ctc: str = None, + ctc_conf: dict = None, + ctc_weight: float = 0.5, + llm: str = None, + llm_conf: dict = None, + adaptor: str = None, + adaptor_conf: dict = None, + input_size: int = 80, + vocab_size: int = -1, + ignore_id: int = -1, + blank_id: int = 0, + sos: int = 1, + eos: int = 2, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = "", + sym_blank: str = "", + # extract_feats_in_collect_stats: bool = True, + share_embedding: bool = False, + # preencoder: Optional[AbsPreEncoder] = None, + # postencoder: Optional[AbsPostEncoder] = None, + **kwargs, + ): + + super().__init__() + + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**specaug_conf) 
+ if normalize is not None: + normalize_class = tables.normalize_classes.get(normalize) + normalize = normalize_class(**normalize_conf) + + # audio encoder + hub = encoder_conf.get("hub", None) + if hub == "funasr": + from funasr import AutoModel + from funasr.models.scama.utils import sequence_mask + init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch") + model = AutoModel(model=init_param_path, model_revision="v2.0.4") + frontend = model.kwargs.get("frontend") + model.model.decoder = None + + self.model = model.model + self.frontend = frontend + self.mask_fn = sequence_mask + + elif hub == "hf": + pass + else: + encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(input_size=input_size, **encoder_conf) + encoder_output_size = encoder.output_size() + + # llm + hub = llm_conf.get("hub", "hf") + self.llm = None + if hub == "hf": + from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig + + init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5") + model = AutoModelForCausalLM.from_pretrained( + init_param_path, + load_in_8bit=None, + device_map=None, + use_cache=None, + ) + freeze_llm = llm_conf.get("freeze_llm", True) + if freeze_llm: + for name, param in model.named_parameters(): + param.requires_grad = False + model.eval() + self.llm = model + + # adaptor + adaptor_class = tables.adaptor_classes.get(adaptor) + adaptor = adaptor_class(**adaptor_conf) + + self.adaptor = adaptor + + + self.blank_id = blank_id + self.sos = sos if sos is not None else vocab_size - 1 + self.eos = eos if eos is not None else vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.specaug = specaug + self.normalize = normalize + self.encoder = encoder + + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + # + # if report_cer or report_wer: + # self.error_calculator = ErrorCalculator( + # token_list, sym_space, sym_blank, report_cer, report_wer + # ) + # + self.error_calculator = None + + self.length_normalized_loss = length_normalized_loss + self.beam_search = None + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + input_ids: torch.Tensor, + attention_mask:torch.Tensor, + labels_ids:torch.Tensor, + label_mask: torch.Tensor, + audio_mask:torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) 
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size = speech.shape[0]
+
+        # audio encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+
+        # adaptor
+        encoder_out = self.adaptor(encoder_out)
+
+        if input_ids is not None:
+            input_ids[input_ids == -1] = 0
+            if hasattr(self.llm.model, "embed_tokens"):
+                inputs_embeds = self.llm.model.embed_tokens(input_ids)
+            elif hasattr(self.llm.model.model, "embed_tokens"):
+                inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+            else:
+                inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+
+        if audio_mask is not None:
+            batch_size, token_num, dims = inputs_embeds.shape
+            _, l, _ = encoder_out.shape
+            encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num - l - 1, 1, 0, 0), value=0.0)
+            inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0 - audio_mask[:, :, None])
+            inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
+
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
+        loss = model_outputs.loss
+
+        acc_att = -1
+        if getattr(self, "metric", True):
+            with torch.no_grad():
+                preds = torch.argmax(model_outputs.logits, -1)
+                acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+
+        stats = {}
+        # Collect Attn branch stats
+        stats["acc"] = acc_att.detach()
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        audio_mask = kwargs.get("audio_mask")
+        audio_token_lengths = audio_mask.sum(-1)
+
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        enc, enc_lens = self.model.encode(**batch)
+        enc_mask = self.mask_fn(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+        pre_acoustic_embeds, pre_token_length, _, _ = self.model.predictor(enc,
+                                                                           mask=enc_mask,
+                                                                           target_label_length=audio_token_lengths,
+                                                                           )
+
+        return pre_acoustic_embeds, pre_token_length
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        ys_in_lens = ys_pad_lens + 1
+
+        # 1. Forward decoder
+        decoder_out, _ = self.decoder(
+            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+        )
+
+        # 2.
Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att + + + def inference(self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + if kwargs.get("batch_size", 1) > 1: + raise NotImplementedError("batch decoding is not implemented") + + # init beamsearch + if self.beam_search is None: + logging.info("enable beam_search") + self.init_beam_search(**kwargs) + self.nbest = kwargs.get("nbest", 1) + + meta_data = {} + if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), + frontend=frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + + speech = speech.to(device=kwargs["device"]) + speech_lengths = speech_lengths.to(device=kwargs["device"]) + # Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + # c. 
Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0) + ) + + nbest_hyps = nbest_hyps[: self.nbest] + + results = [] + b, n, d = encoder_out.size() + for i in range(b): + + for nbest_idx, hyp in enumerate(nbest_hyps): + ibest_writer = None + if kwargs.get("output_dir") is not None: + if not hasattr(self, "writer"): + self.writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) + + # Change integer-ids to tokens + token = tokenizer.ids2tokens(token_int) + text = tokenizer.tokens2text(token) + + text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) + result_i = {"key": key[i], "token": token, "text": text_postprocessed} + results.append(result_i) + + if ibest_writer is not None: + ibest_writer["token"][key[i]] = " ".join(token) + ibest_writer["text"][key[i]] = text_postprocessed + + return results, meta_data + diff --git a/funasr/models/llm_asr/template.yaml b/funasr/models/llm_asr/template.yaml new file mode 100644 index 000000000..8b564cd40 --- /dev/null +++ b/funasr/models/llm_asr/template.yaml @@ -0,0 +1,90 @@ +# This is an example that demonstrates how to configure a model file. +# You can modify the configuration according to your own requirements. + +# to print the register_table: +# from funasr.register import tables +# tables.print() + +# network architecture +model: LLMASR +model_conf: + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: true + +# encoder +encoder: Paraformer +encoder_conf: + hub: funasr + init_param_path: "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" + +llm: Vicuna +llm_conf: + hub: hf + init_param_path: null + freeze_llm: true + +adaptor: linear +adaptor_conf: + downsample_rate: 1 + llm_dim: 4096 + encoder_dim: 2048 + +# frontend related +frontend: WavFrontend +frontend_conf: + fs: 16000 + window: hamming + n_mels: 80 + frame_length: 25 + frame_shift: 10 + dither: 0.0 + lfr_m: 1 + lfr_n: 1 + +specaug: SpecAug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_range: + - 0 + - 40 + num_time_mask: 2 + +train_conf: + accum_grad: 1 + grad_clip: 5 + max_epoch: 150 + keep_nbest_models: 10 + log_interval: 50 + +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 0.000001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 35000 + +dataset: AudioLLMDataset +dataset_conf: + index_ds: IndexDSJsonl + batch_sampler: RankFullLocalShuffleBatchSampler + batch_type: example # example or length + batch_size: 4 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; + max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, + buffer_size: 500 + shuffle: True + num_workers: 4 + +tokenizer: HuggingfaceTokenizer +tokenizer_conf: + unk_symbol: + init_param_path: null + diff --git a/funasr/tokenizer/hf_tokenizer.py b/funasr/tokenizer/hf_tokenizer.py new 
file mode 100644 index 000000000..c856b3d5d --- /dev/null +++ b/funasr/tokenizer/hf_tokenizer.py @@ -0,0 +1,15 @@ + +try: + from transformers import AutoTokenizer +except: + print("If you want to use hugging, please `pip install -U transformers`") + +from funasr.register import tables + +@tables.register("tokenizer_classes", "HuggingfaceTokenizer") +def HuggingfaceTokenizer(init_param_path, **kwargs): + + tokenizer = AutoTokenizer.from_pretrained(init_param_path) + + return tokenizer + diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py index d175fbeb8..5b280bf18 100644 --- a/funasr/train_utils/trainer.py +++ b/funasr/train_utils/trainer.py @@ -5,7 +5,8 @@ import logging from tqdm import tqdm from datetime import datetime import torch.distributed as dist -from contextlib import nullcontext +from torch.cuda.amp import autocast, GradScaler +from contextlib import nullcontext, contextmanager # from torch.utils.tensorboard import SummaryWriter from tensorboardX import SummaryWriter from pathlib import Path @@ -14,6 +15,14 @@ from funasr.train_utils.device_funcs import to_device from funasr.train_utils.recursive_op import recursive_average from funasr.train_utils.average_nbest_models import average_checkpoints +@contextmanager +def maybe_autocast(enabled): + if enabled: + with autocast(): + yield + else: + yield + class Trainer: """ A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch, @@ -36,8 +45,9 @@ class Trainer: dataloader_train, dataloader_val, local_rank, - use_ddp=False, - use_fsdp=False, + use_ddp: bool = False, + use_fsdp: bool = False, + use_fp16: bool = False, output_dir: str="./", **kwargs): """ @@ -72,6 +82,9 @@ class Trainer: self.kwargs = kwargs self.log_interval = kwargs.get("log_interval", 50) self.batch_total = 0 + self.use_fp16 = use_fp16 + self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True) + self.scaler = GradScaler(enabled=use_fp16) if use_fp16 else None try: @@ -103,6 +116,8 @@ class Trainer: 'optimizer': self.optim.state_dict(), 'scheduler': self.scheduler.state_dict(), } + if self.scaler: + state["scaler_state"] = self.scaler.state_dict() # Create output directory if it does not exist os.makedirs(self.output_dir, exist_ok=True) filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}') @@ -141,6 +156,8 @@ class Trainer: self.model.load_state_dict(dst_state) self.optim.load_state_dict(checkpoint['optimizer']) self.scheduler.load_state_dict(checkpoint['scheduler']) + if self.scaler and 'scaler_state' in checkpoint: + self.scaler.load_state_dict(checkpoint['scaler_state']) print(f"Checkpoint loaded successfully from '{ckpt}'") else: print(f"No checkpoint found at '{ckpt}', starting from scratch") @@ -221,9 +238,10 @@ class Trainer: my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext with my_context(): time2 = time.perf_counter() - - retval = self.model(**batch) - torch.cuda.empty_cache() + with maybe_autocast(self.use_fp16): + retval = self.model(**batch) + + if self.disable_gpu_cache: torch.cuda.empty_cache() time3 = time.perf_counter() speed_stats["forward_time"] = f"{time3 - time2:0.3f}" @@ -241,7 +259,10 @@ class Trainer: loss *= self.world_size # Scale the loss since we're not updating for every mini-batch loss = loss / accum_grad - loss.backward() + if self.use_fp16: + self.scaler.scale(loss).backward() + else: + loss.backward() time4 = time.perf_counter() speed_stats["backward_time"] = f"{time4 - time3:0.3f}" @@ -264,10 +285,14 @@ class Trainer: # Execute 
an optimization step (update model parameters) if self.use_ddp or self.use_fsdp: dist.barrier() - self.optim.step() + if self.use_fp16: + self.scaler.step(self.optim) + self.scaler.update() + else: + self.optim.step() self.scheduler.step() # Clear gradients for the next accumulation stage - self.optim.zero_grad() + self.optim.zero_grad(set_to_none=True) total_time = f"{time.perf_counter() - time5:0.3f}" time5 = time.perf_counter() speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
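Note on usage (not part of the patch): the trainer changes above follow the standard torch.cuda.amp recipe, keeping one GradScaler alive across steps and wrapping only the forward pass in autocast. A minimal sketch of that pattern is below; the names model/optim/scheduler/batch are placeholders, and how use_fp16 is plumbed from the training config into Trainer is an assumption, not something this hunk shows.

    import torch
    from torch.cuda.amp import autocast, GradScaler

    scaler = GradScaler(enabled=True)  # created once, like self.scaler in Trainer.__init__

    def train_step(model, optim, scheduler, batch, use_fp16=True):
        # forward under autocast so matmuls/convs run in half precision
        with autocast(enabled=use_fp16):
            loss, stats, weight = model(**batch)
        if use_fp16:
            scaler.scale(loss).backward()   # backward on the scaled loss
            scaler.step(optim)              # unscales grads; skips the step on inf/NaN
            scaler.update()
        else:
            loss.backward()
            optim.step()
        scheduler.step()
        optim.zero_grad(set_to_none=True)
        return stats

If gradient clipping runs between backward and optim.step (the clipping code sits outside this hunk), the usual AMP recipe calls scaler.unscale_(optim) before clip_grad_norm_ so the norm is computed on unscaled gradients.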