From 27256ed429c95ed8868a01f8555610393dd7b3a1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Thu, 6 Jun 2024 15:45:32 +0800
Subject: [PATCH] auto frontend

---
 funasr/datasets/openai_datasets/__init__.py   |   0
 funasr/datasets/openai_datasets/datasets.py   | 216 ++++++++++++
 funasr/datasets/openai_datasets/index_ds.py   |  95 ++++++
 .../datasets/sense_voice_datasets/datasets.py |   1 +
 funasr/models/llm_asr/model.py                | 318 ++++++++++++++++++
 5 files changed, 630 insertions(+)
 create mode 100644 funasr/datasets/openai_datasets/__init__.py
 create mode 100644 funasr/datasets/openai_datasets/datasets.py
 create mode 100644 funasr/datasets/openai_datasets/index_ds.py

diff --git a/funasr/datasets/openai_datasets/__init__.py b/funasr/datasets/openai_datasets/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py
new file mode 100644
index 000000000..9a542adb6
--- /dev/null
+++ b/funasr/datasets/openai_datasets/datasets.py
@@ -0,0 +1,216 @@
+import logging
+import re
+import torch
+import random
+import traceback
+from funasr.register import tables
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
+
+
+@tables.register("dataset_classes", "OpenAIDataset")
+class OpenAIDataset(torch.utils.data.Dataset):
+    """
+    OpenAIDataset: speech-text chat dataset in the OpenAI "messages" format.
+    """
+
+    def __init__(
+        self,
+        path,
+        index_ds: str = None,
+        frontend=None,
+        tokenizer=None,
+        int_pad_value: int = -1,
+        float_pad_value: float = 0.0,
+        **kwargs,
+    ):
+        super().__init__()
+        index_ds_class = tables.index_ds_classes.get(index_ds)
+        self.index_ds = index_ds_class(path, **kwargs)
+        preprocessor_speech = kwargs.get("preprocessor_speech", None)
+        if preprocessor_speech:
+            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+            preprocessor_speech = preprocessor_speech_class(
+                **kwargs.get("preprocessor_speech_conf")
+            )
+        self.preprocessor_speech = preprocessor_speech
+        preprocessor_text = kwargs.get("preprocessor_text", None)
+        if preprocessor_text:
+            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+        self.preprocessor_text = preprocessor_text
+
+        self.frontend = frontend
+        self.fs = 16000 if frontend is None else frontend.fs
+        self.data_type = "sound"
+        self.tokenizer = tokenizer
+
+        self.int_pad_value = int_pad_value
+        self.float_pad_value = float_pad_value
+        self.sos = kwargs.get("sos", "<|startoftranscript|>")
+        self.eos = kwargs.get("eos", "<|endoftext|>")
+        self.batch_size = kwargs.get("batch_size")
+        self.batch_type = kwargs.get("batch_type")
+        self.prompt_ids_len = 0
+        self.retry = kwargs.get("retry", 5)
+
+        # Whisper fbank features come out as (b, d, T); permute them to (b, T, d).
+        self.permute = False
+        from funasr.frontends.whisper_frontend import WhisperFrontend
+
+        if isinstance(self.frontend, WhisperFrontend):
+            self.permute = True
+
+        # matches inline audio segments: <|startofspeech|>...<|endofspeech|>
+        self.pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
+
+    def get_source_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_source_len(item)
+
+    def get_target_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_target_len(item)
+
+    def __len__(self):
+        return len(self.index_ds)
+
+    def __getitem__(self, index):
+        output = None
+        # retry with a randomly drawn sample if the current one is unusable
+        for idx in range(self.retry):
+            if idx == 0:
+                index_cur = index
+            else:
+                index_cur = torch.randint(0, len(self.index_ds), ()).item()
+
+            item = self.index_ds[index_cur]
+
+            system = item["system"]
+            user = item["user"]
+            assistant = item["assistant"]
+
+            input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg = [], [], [], [], [], []
+
+            for i, (system_prompt, user_prompt, target_out) in enumerate(
+                zip(system, user, assistant)
+            ):
+
+                source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+
+                splits = self.pattern.split(source_input)
+                source_ids = []
+                fbank_mask_i = []
+                fbank_beg_i = []
+                fbank_lens_i = []
+                for k, sub_str in enumerate(splits):
+                    if not sub_str.startswith("<|startofspeech|>"):
+                        sub_token = self.tokenizer.encode(sub_str)
+                        source_ids += sub_token
+                        fbank_mask_i += [0] * len(sub_token)
+                    else:
+                        sub_str = sub_str.replace("<|startofspeech|>", "").replace(
+                            "<|endofspeech|>", ""
+                        )
+                        if sub_str.startswith("!"):
+                            # "!" marks a wav path: load the audio and extract fbank
+                            data_src = load_audio_text_image_video(sub_str[1:], fs=self.fs)
+
+                            speech, speech_lengths = extract_fbank(
+                                data_src,
+                                data_type=self.data_type,
+                                frontend=self.frontend,
+                                is_final=True,
+                            )  # speech: [b, T, d]
+                            if self.permute:
+                                speech = speech.permute(0, 2, 1)
+                            # skip audio segments that exceed the frame budget
+                            if speech_lengths[0].item() > self.batch_size:
+                                continue
+
+                            # number of placeholder tokens the audio occupies after
+                            # the encoder (two conv subsamplings) and the adaptor
+                            frame_len = speech_lengths[0].item()
+                            olens = 1 + (frame_len - 3 + 2 * 1) // 2
+                            olens = 1 + (olens - 3 + 2 * 1) // 2
+                            sub_token_len = (olens - 1) // 2 + 1
+                            sub_token = [0] * sub_token_len
+                            fbank_beg_i = [len(source_ids)]
+                            source_ids += sub_token
+                            fbank_mask_i += [1] * len(sub_token)
+
+                source_mask = [-100] * len(source_ids)
+                target_out = f"{target_out}<|im_end|>"
+                target_ids = self.tokenizer.encode(target_out)
+                input_ids += source_ids + target_ids
+                labels += source_mask + target_ids
+                fbank_mask += fbank_mask_i
+                fbank_beg.append(fbank_beg_i)
+
+            input_ids = torch.tensor(input_ids, dtype=torch.int64)
+            attention_mask = torch.tensor([len(input_ids)], dtype=torch.int32)
+            labels = torch.tensor(labels, dtype=torch.int64)
+
+            # NOTE: only the last extracted audio segment is kept; a single audio
+            # segment per sample is assumed here
+            fbank = speech[0, :, :]
+            fbank_lens = speech_lengths
+            fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
+            fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
+
+            output = {
+                "speech": fbank,
+                "speech_lengths": fbank_lens,
+                "fbank_mask": fbank_mask,
+                "fbank_beg": fbank_beg,
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "labels_ids": labels,
+            }
+            break
+
+        return output
+
+    def collator(self, samples: list = None):
+        outputs = {}
+        for sample in samples:
+            if sample is None:
+                continue
+            for key in sample.keys():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(sample[key])
+
+        for key, data_list in outputs.items():
+            if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
+                    pad_value = self.int_pad_value
+                else:
+                    pad_value = self.float_pad_value
+
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(
+                    data_list, batch_first=True, padding_value=pad_value
+                )
+
+        if self.batch_type != "example":
+            for i in range(10):
+                outputs = self._filter_badcase(outputs, i=i)
+
+        return outputs
+
+    def _filter_badcase(self, outputs, i=0):
+        b, t, _ = outputs["speech"].shape
+
+        if b * t > self.batch_size * 1.25:
+            beg = torch.randint(0, 2, ()).item()
+            if b < 2:
+                beg = 0
+            logging.info(
+                f"Warning, b * t: {b * t} > {self.batch_size}, drop half data {i}th, beg:{beg}"
+            )
+            for key, data_list in outputs.items():
+                outputs[key] = outputs[key][beg : beg + b : 2]
+
+            speech_lengths_max = outputs["speech_lengths"].max().item()
+            outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
+            # trim the text-side tensors to the longest remaining sample
+            token_lengths_max = outputs["attention_mask"].max().item()
+            outputs["input_ids"] = outputs["input_ids"][:, :token_lengths_max]
+            outputs["labels_ids"] = outputs["labels_ids"][:, :token_lengths_max]
+            outputs["fbank_mask"] = outputs["fbank_mask"][:, :token_lengths_max]
+
+        return outputs
diff --git a/funasr/datasets/openai_datasets/index_ds.py b/funasr/datasets/openai_datasets/index_ds.py
new file mode 100644
index 000000000..1c48cd241
--- /dev/null
+++ b/funasr/datasets/openai_datasets/index_ds.py
@@ -0,0 +1,95 @@
+import os
+import json
+import torch
+import logging
+
+import librosa
+import random
+import torch.distributed as dist
+
+from funasr.register import tables
+
+
+@tables.register("index_ds_classes", "OpenAIIndexDSJsonl")
+class OpenAIIndexDSJsonl(torch.utils.data.Dataset):
+
+    def __init__(self, path: str, **kwargs):
+        super().__init__()
+        self.max_source_length = kwargs.get("max_source_length", 2048)
+        self.min_source_length = kwargs.get("min_source_length", 0)
+        self.max_target_length = kwargs.get("max_target_length", 2048)
+        self.min_target_length = kwargs.get("min_target_length", 0)
+        self.max_token_length = kwargs.get("max_token_length", 2200)
+
+        is_training = kwargs.get("is_training", True)
+        if not (path.endswith(".jsonl") or path.endswith(".json")):
+            # path is a file listing jsonl files; split it across data workers
+            data_split_num = kwargs.get("data_split_num", 1)
+            data_split_i = kwargs.get("data_split_i", 0)
+
+            if not is_training:
+                data_split_num = 1
+                data_split_i = 0
+            with open(path, encoding="utf-8") as fin:
+                file_list_all = fin.readlines()
+
+            num_per_slice = (len(file_list_all) - 1) // data_split_num + 1
+            file_list = file_list_all[
+                data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice
+            ]
+            logging.info(
+                f"is_training: {is_training}, data_split_num: {data_split_num}, data_split_i: {data_split_i}, \nfile_list: {file_list}, \nfile_list_all: {file_list_all}"
+            )
+
+        else:
+            file_list = [path]
+
+        contents = []
+        for file_json in file_list:
+            with open(file_json.strip(), encoding="utf-8") as fin:
+                for line in fin:
+                    data = json.loads(line.strip())["messages"]
+
+                    system, user, assistant = [], [], []
+                    for i, item in enumerate(data):
+                        role = item["role"]
+                        content = item["content"]
+                        if role == "system":
+                            system.append(content)
+                        elif role == "user":
+                            user.append(content)
+                        elif role == "assistant":
+                            assistant.append(content)
+
+                    # repeat the system prompt so it aligns with each user turn
+                    system = system * len(user)
+
+                    contents_i = {"system": system, "user": user, "assistant": assistant}
+                    contents.append(contents_i)
+
+        self.contents = contents
+
+        logging.info("total number of samples: {}, {}".format(len(self.contents), path))
+
+    def __len__(self):
+        return len(self.contents)
+
+    def __getitem__(self, index):
+        data = self.contents[index]
+        return data
+
+    def get_source_len(self, data_dict):
+        return len(data_dict["system"]) + len(data_dict["user"])
+
+    def get_target_len(self, data_dict):
+        return len(data_dict["assistant"])
+
+
+if __name__ == "__main__":
+    index_ds = OpenAIIndexDSJsonl(
+        path="/Users/zhifu/funasr1.0/test_local/data_tmp/tmp_wav_10.jsonl"
+    )
+    print(index_ds.contents)
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index 690a1c56c..c0beda102 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -1,5 +1,6 @@
 import logging
 
+import re
 import torch
 import random
 import traceback
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 4345f696d..11db0096d 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -341,3 +341,321 @@ class LLMASR(nn.Module):
             ibest_writer["text"][key[0]] = text
 
         return results, meta_data
+
+
+@tables.register("model_classes", "LLMASR2")
+class LLMASR2(nn.Module):
+    """Audio encoder + adaptor + causal LLM for speech-to-text chat training."""
+
+    def __init__(
+        self,
+        specaug: str = None,
+        specaug_conf: dict = None,
+        normalize: str = None,
+        normalize_conf: dict = None,
+        audio_encoder: str = None,
+        audio_encoder_conf: dict = None,
+        audio_adaptor: str = None,
+        audio_adaptor_conf: dict = None,
+        decoder: str = None,
+        decoder_conf: dict = None,
+        ctc: str = None,
+        ctc_conf: dict = None,
+        ctc_weight: float = 0.5,
+        llm: str = None,
+        llm_conf: dict = None,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        share_embedding: bool = False,
+        **kwargs,
+    ):
+
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+
+        # audio encoder
+        hub = audio_encoder_conf.get("hub", None)
+        if hub == "ms":
+            from funasr import AutoModel
+
+            model = AutoModel(model=audio_encoder, model_revision="master")
+            audio_encoder_output_size = model.model.encoder_output_size
+            audio_encoder = model.model.model.encoder
+        elif hub == "hf":
+            # hf audio encoders are not wired up in this patch
+            pass
+        else:
+            encoder_class = tables.encoder_classes.get(audio_encoder)
+            audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
+            audio_encoder_output_size = audio_encoder.output_size()
+        freeze = audio_encoder_conf.get("freeze", True)
+        if freeze:
+            for name, param in audio_encoder.named_parameters():
+                param.requires_grad = False
+            audio_encoder.eval()
+
+        self.audio_encoder = audio_encoder
+
+        # llm
+        hub = llm_conf.get("hub", "hf")
+        self.llm = None
+        if hub == "hf":
+            from transformers import AutoModelForCausalLM
+
+            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+
+            model = AutoModelForCausalLM.from_pretrained(
+                init_param_path,
+                load_in_8bit=None,
+                device_map=None,
+                use_cache=None,
+            )
+            freeze = llm_conf.get("freeze", True)
+            if freeze:
+                for name, param in model.named_parameters():
+                    param.requires_grad = False
+                model.eval()
+            self.llm = model
+
+        # adaptor: projects audio encoder outputs into the LLM embedding space
+        adaptor_class = tables.adaptor_classes.get(audio_adaptor)
+        audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
+        audio_adaptor = adaptor_class(**audio_adaptor_conf)
+
+        self.audio_adaptor = audio_adaptor
+
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.specaug = specaug
+        self.normalize = normalize
+
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        self.error_calculator = None
+
+        self.length_normalized_loss = length_normalized_loss
+        self.beam_search = None
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        labels_ids: torch.Tensor,
+        fbank_beg: torch.Tensor,
+        fbank_mask: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Audio encoder + adaptor + LLM forward and loss calculation.
+        Args:
+            speech: (Batch, Length, Dim)
+            speech_lengths: (Batch,)
+            input_ids: (Batch, Length) prompt + target ids with audio placeholders
+            attention_mask: (Batch, 1) per-sample token counts from the dataset
+            labels_ids: (Batch, Length) target ids, -100 on prompt positions
+            fbank_beg: start indices of the audio placeholder span per sample
+            fbank_mask: 1 on audio placeholder positions
+        """
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size = speech.shape[0]
+
+        # audio encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # audio_adaptor
+        encoder_out = self.audio_adaptor(encoder_out)
+
+        # map padding values back to a valid token id before the embedding lookup
+        input_ids[input_ids == -1] = 0
+        input_ids[input_ids == -100] = 0
+        # the LLM loss ignores -100; remap the collator's -1 padding as well
+        labels_ids[labels_ids == -1] = -100
+        if hasattr(self.llm.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.embed_tokens(input_ids)
+        elif hasattr(self.llm.model.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+
+        # splice the adapted audio embeddings over the placeholder positions
+        batch_size, token_num, dims = inputs_embeds.shape
+        _, l, _ = encoder_out.shape
+        for batch_idx in range(batch_size):
+            fbank_beg_idx = fbank_beg[batch_idx, 0].item()
+            inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + l, :] = encoder_out[
+                batch_idx, :l, :
+            ]
+
+        # the dataset packs per-sample token counts into `attention_mask`;
+        # expand them into a standard [batch, token_num] 0/1 mask for the LLM
+        if attention_mask.dim() == 2 and attention_mask.shape[1] != token_num:
+            attention_mask = (
+                torch.arange(token_num, device=input_ids.device)[None, :]
+                < attention_mask[:, :1]
+            ).to(torch.long)
+
+        model_outputs = self.llm(
+            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+        )
+        loss = model_outputs.loss
+
+        stats = {}
+        with torch.no_grad():
+            preds = torch.argmax(model_outputs.logits, -1)
+            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+            stats["acc"] = acc_att
+
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int(attention_mask.sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs,
+    ):
+        # the whisper-style encoder expects (b, d, T)
+        speech = speech.permute(0, 2, 1)
+        res = self.audio_encoder(speech)
+        if isinstance(res, (list, tuple)):
+            encoder_out, encoder_out_lens = res[0], res[1]
+        else:
+            encoder_out, encoder_out_lens = res, speech_lengths
+        return encoder_out, encoder_out_lens
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+
+        prompt = kwargs.get("prompt", "Transcribe speech to text.")
+
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+
+        meta_data = {}
+        if (
+            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+        ):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = (
+                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+            )
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # adaptor
+        encoder_out = self.audio_adaptor(encoder_out)
+
+        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
+        prompt_ids = tokenizer.encode(prompt_pre)
+        prompt_length = len(prompt_ids)
+        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+
+        if hasattr(self.llm.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+        elif hasattr(self.llm.model.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+        inputs_embeds = torch.cat(
+            (inputs_embeds[None, :, :], encoder_out), dim=1
+        )  # [prompt, audio]
+        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
+            kwargs["device"]
+        )
+
+        preds = self.llm.generate(
+            inputs_embeds=inputs_embeds,
+            max_length=kwargs.get("max_length", 200),
+            max_new_tokens=kwargs.get("max_new_tokens", 200),
+            num_beams=kwargs.get("num_beams", 4),
+            do_sample=kwargs.get("do_sample", False),
+            min_length=kwargs.get("min_length", 1),
+            top_p=kwargs.get("top_p", 1.0),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+            length_penalty=kwargs.get("length_penalty", 1.0),
+            temperature=kwargs.get("temperature", 1.0),
+            attention_mask=attention_mask,
+            bos_token_id=tokenizer.bos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=tokenizer.pad_token_id,
+        )
+
+        text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+
+        text = text[0].split(": ")[-1]
+        text = text.strip()
+
+        ibest_writer = None
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = self.writer["1best_recog"]
+
+        results = []
+        result_i = {"key": key[0], "text": text}
+        results.append(result_i)
+
+        if ibest_writer is not None:
+            ibest_writer["text"][key[0]] = text
+
+        return results, meta_data
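
Usage sketch (illustrative, not part of the patch): a minimal example of the jsonl "messages" format consumed by OpenAIIndexDSJsonl, and of looking up OpenAIDataset through the registry. The wav path, prompt texts, and config values below are placeholders; a real run needs the tokenizer and frontend from the actual training config.

    # sketch.py -- illustrative only; config values here are assumptions
    import json
    from funasr.register import tables
    import funasr.datasets.openai_datasets.index_ds   # noqa: F401  registers OpenAIIndexDSJsonl
    import funasr.datasets.openai_datasets.datasets   # noqa: F401  registers OpenAIDataset

    # One "messages" sample: the audio is referenced inside
    # <|startofspeech|>!<wav-path><|endofspeech|> markers in the user turn,
    # which is exactly the pattern OpenAIDataset splits on.
    sample = {
        "messages": [
            {"role": "system", "content": "You are a speech recognition assistant."},
            {"role": "user", "content": "<|startofspeech|>!/path/to/utt1.wav<|endofspeech|> Transcribe the speech."},
            {"role": "assistant", "content": "hello world"},
        ]
    }
    with open("train.jsonl", "w", encoding="utf-8") as fout:
        fout.write(json.dumps(sample, ensure_ascii=False) + "\n")

    # Build the dataset through the registry, as a training entry point would.
    # frontend/tokenizer must come from the real training config; None is only a
    # placeholder and __getitem__ will not work without them.
    dataset_class = tables.dataset_classes.get("OpenAIDataset")
    dataset = dataset_class(
        "train.jsonl",
        index_ds="OpenAIIndexDSJsonl",
        frontend=None,
        tokenizer=None,
        batch_type="token",
        batch_size=6000,
    )
    print(len(dataset))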