From 35d04ba3570a9d455495f0da73e2a3b950a9f286 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 2 Jul 2024 11:37:22 +0800 Subject: [PATCH 01/24] update --- funasr/models/llm_asr/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 3d74632ce..545493310 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -1636,7 +1636,7 @@ class LLMASR5(nn.Module): target_ids_len_i = end_i - beg_i target_ids_len.append(target_ids_len_i) target_ids.append(target_ids_i) - hidden_states_i = hidden_states[batch_idx, beg_i:end_i, :] + hidden_states_i = hidden_states[batch_idx, beg_i - 1 : end_i - 1, :] hidden_states_select.append(hidden_states_i) beg_i = end_i continue From b6aad84db69c6c3ad57c76abdf1777a8f117b3f2 Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Tue, 2 Jul 2024 13:52:03 +0800 Subject: [PATCH 02/24] Dev dzh deepspeed (#1867) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add audio generator * update ar model --------- Co-authored-by: 志浩 --- funasr/models/llm_asr/model.py | 712 ++++++++++++++++++++++++ funasr/models/llm_asr/transformer_lm.py | 345 ++++++++++++ 2 files changed, 1057 insertions(+) create mode 100644 funasr/models/llm_asr/transformer_lm.py diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 545493310..cf554384a 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -19,6 +19,7 @@ from funasr.utils import postprocess_utils from funasr.utils.datadir_writer import DatadirWriter from funasr.register import tables from funasr.train_utils.device_funcs import to_device +from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list import traceback dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} @@ -2003,3 +2004,714 @@ class LLMASR5(nn.Module): ibest_writer["text_tn"][key[0]] = response_clean return results, meta_data + + +@tables.register("model_classes", "LLMASR5") +class LLMASR5(nn.Module): + """ """ + + def __init__( + self, + specaug: str = None, + specaug_conf: dict = None, + normalize: str = None, + normalize_conf: dict = None, + audio_encoder: str = None, + audio_encoder_conf: dict = None, + audio_adaptor: str = None, + audio_adaptor_conf: dict = None, + decoder: str = None, + decoder_conf: dict = None, + ctc: str = None, + ctc_conf: dict = None, + ctc_weight: float = 0.5, + llm: str = None, + llm_conf: dict = None, + input_size: int = 80, + vocab_size: int = -1, + ignore_id: int = -1, + blank_id: int = 0, + sos: int = 1, + eos: int = 2, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = "", + sym_blank: str = "", + # extract_feats_in_collect_stats: bool = True, + share_embedding: bool = False, + # preencoder: Optional[AbsPreEncoder] = None, + # postencoder: Optional[AbsPostEncoder] = None, + audio_decoder: str = None, + audio_decoder_conf: dict = None, + **kwargs, + ): + + super().__init__() + + # audio encoder + hub = audio_encoder_conf.get("hub", None) + if hub == "ms": + from funasr import AutoModel + + model = AutoModel(model=audio_encoder, model_revision="master") + # frontend = model.kwargs.get("frontend") + audio_encoder_output_size = model.model.encoder_output_size + + audio_encoder = ( + model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder + ) + + # 
self.frontend = frontend + + elif hub == "hf": + pass + else: + encoder_class = tables.encoder_classes.get(audio_encoder) + audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf) + audio_encoder_output_size = audio_encoder.output_size() + freeze = audio_encoder_conf.get("freeze", True) + freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1)) + # if freeze_layer_num > 0: + # freeze_layer_num = range(freeze_layer_num) + + if freeze: + for name, param in audio_encoder.named_parameters(): + if freeze_layer_num > 0: + idx = re.search(r"\.\d+\.", name) + if idx is not None: + beg, end = idx.regs[0] + layer_id = int(name[beg + 1: end - 1]) + if layer_id < freeze_layer_num: + param.requires_grad = False + elif "ln_post." not in name: + param.requires_grad = False + else: + param.requires_grad = False + + audio_encoder.eval() + + self.audio_encoder = audio_encoder + + # llm + self.llm = None + + from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig + + init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5") + + model = AutoModelForCausalLM.from_pretrained( + init_param_path, + load_in_8bit=None, + device_map=None, + use_cache=None, + ) + freeze = llm_conf.get("freeze", True) + if freeze: + for name, param in model.named_parameters(): + param.requires_grad = False + model.eval() + self.llm_dtype = llm_conf.get("llm_dtype", "fp32") + self.llm = model.to(dtype_map[self.llm_dtype]) + llm_dim = model.get_input_embeddings().weight.shape[-1] + + # adaptor + adaptor_class = tables.adaptor_classes.get(audio_adaptor) + audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size + audio_adaptor_conf["llm_dim"] = llm_dim + audio_adaptor = adaptor_class(**audio_adaptor_conf) + init_param_path = audio_adaptor_conf.get("init_param_path", None) + if init_param_path is not None: + src_state = torch.load(init_param_path, map_location="cpu") + flag = audio_adaptor.load_state_dict(src_state, strict=False) + logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}") + + self.audio_adaptor = audio_adaptor + + self.error_calculator = None + + self.length_normalized_loss = length_normalized_loss + self.beam_search = None + + # audio decoder related + self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf) + self.audio_decoder_in_proj = torch.nn.Linear(llm_dim, self.audio_decoder.embed_unit) + self.codebook_dim = audio_decoder_conf.pop("codebook_dim", 1024) + self.codebook_size = audio_decoder_conf.pop("codebook_size", 4096) + self.codec_embedder = torch.nn.Embedding(self.codebook_size, self.codebook_dim) + self.audio_decoder_embedding = torch.nn.Embedding(2, self.audio_decoder.embed_unit) + self.ad_sos_eos = 0 + self.ad_task_id = 1 + self.ad_ignore_id = -1 + + def build_audio_decoder(self, name, conf): + if name == "transformer": + from funasr.models.llm_asr.transformer_lm import TransformerEmbedLM + if "text_vocab_size" in conf: + lm_model = TransformerEmbedLM( + vocab_size=self.lm_out_voc_size, + **conf + ) + else: + lm_model = TransformerEmbedLM( + vocab_size=self.lm_out_voc_size, + text_vocab_size=self.lm_out_voc_size, + **conf + ) + else: + raise TypeError(f"Unknown codec decoder type {name}") + + conf["name"] = name + return lm_model + + def calc_dense_vector(self, codec, codec_lengths): + """ + Args: + codec: (B, T, Nq) + codec_lengths: (B, ) + """ + mask = codec != self.ad_ignore_id + return self.codec_embedder(codec * mask).sum(dim=-2) * mask + + def prepare_audio_decoder_io( + self, + text: 
torch.Tensor, + text_lengths: torch.Tensor, + codec: Optional[torch.Tensor] = None, + codec_lengths: Optional[torch.Tensor] = None, + need_targets: bool = True, + ): + """build inputs and targets for language model + + Normally, this function is called in batchify_nll. + Args: + text: (Batch, Length, Dim) + text_lengths: (Batch,) + codec: (Batch, Length) + codec_lengths: (Batch,) + need_targets: bool, whether provide targets + """ + + if need_targets: + assert codec is not None and codec_lengths is not None, \ + "need_target=True, but codec or codec_length is None" + + sos_eos_emb = self.audio_decoder_embedding(torch.tensor([self.ad_sos_eos], dtype=torch.int64, device=text.device)) + task_id_emb = self.audio_decoder_embedding(torch.tensor([self.ad_task_id], dtype=torch.int64, device=text.device)) + codec_emb = None + if codec is not None and codec_lengths is not None: + codec_emb = self.calc_dense_vector(codec, codec_lengths) + inputs_list = [] + for i, text_len in enumerate(text_lengths): + one_input = [sos_eos_emb, text[i, :text_len], task_id_emb] + if codec_emb is not None: + one_input.append(codec_emb[i, :codec_lengths[i]]) + inputs_list.append(torch.cat(one_input, dim=0)) + llm_inputs = pad_list(inputs_list, 0.0) + llm_lengths = text_lengths + 2 + if codec_emb is not None: + llm_lengths = llm_lengths + codec_lengths + + if not need_targets: + return llm_inputs, llm_lengths + + bb, tt = text.shape[0], codec_lengths.max() + 1 + llm_targets = -1 * torch.ones([bb, tt, self.predict_nq], dtype=torch.int64, device=text.device) + for i, codec_len in enumerate(codec_lengths): + llm_targets[i, :codec_len] = codec[i, :codec_len] + llm_targets[i, codec_len] = self.codebook_size + self.sos_eos + + return (llm_inputs, llm_targets), (llm_lengths, codec_lengths + 1) + + def nll( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + codec: Optional[torch.Tensor] = None, + codec_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute negative log likelihood(nll) + + Normally, this function is called in batchify_nll. + Args: + text: (Batch, Length, Dim) + text_lengths: (Batch,) + codec: (Batch, Length) + codec_lengths: (Batch,) + """ + batch_size = text.size(0) + # For data parallel + text = text[:, :text_lengths.max()] + codec = codec[:, :codec_lengths.max()] + text = self.audio_decoder_in_proj(text) + + # build inputs and targets for language model + with autocast(False): + (sequence, target), (x_lengths, y_lengths) = self.prepare_audio_decoder_io( + text, text_lengths, + codec, codec_lengths, + need_targets=True + ) + + # 2a. Forward Language model + # x: (Batch, Length) -> y: (Batch, Length, NVocab) + sequence = sequence[:, :x_lengths.max()] + target = target[:, :y_lengths.max()] + y, _ = self.audio_decoder(sequence, x_lengths, text_lengths+1) + bb, tt = y.shape[0], y.shape[1] + y = y.reshape(bb, tt, self.predict_nq, -1) + # 2b. Extract real logits + logits_list = [] + for i, (text_len, codec_len) in enumerate(zip(text_lengths, codec_lengths)): + logits_list.append(y[i, text_len + 1:text_len + 2 + codec_len]) + logits = pad_list(logits_list, 0.0) + + # 3. 
Calc negative log likelihood + tt = logits.shape[1] + nll = self.criterion_ce( + logits.reshape(bb, tt * self.predict_nq, -1), + target.reshape(bb, tt * self.predict_nq) + ) + nll = nll.sum(-1) + # nll: (BxL,) -> (BxL,) + nll.masked_fill_(make_pad_mask(y_lengths * self.predict_nq).to(nll.device).view(-1), 0.0) + # nll: (BxL,) -> (B, L) + nll = nll.reshape(batch_size, -1).reshape(batch_size, tt, self.predict_nq) + + return nll, logits, target, codec_lengths+1 + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + labels_ids: torch.Tensor, + fbank_beg: torch.Tensor, + fbank_mask: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + # import pdb + # + # pdb.set_trace() + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + + batch_size_speech, frames, _ = speech.shape + batch_size, token_num = input_ids.shape + + with torch.cuda.amp.autocast(enabled=False): + # audio encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + # audio_adaptor + encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) + + input_ids[input_ids < 0] = 0 + inputs_embeds = self.llm.model.get_input_embeddings()(input_ids) + + batch_size, token_num, dims = inputs_embeds.shape + fake_token_len = kwargs.get("fake_token_len") + fake_token_len[fake_token_len < 0] = 0 + fbank_beg[fbank_beg < 0] = 0 + + speech_idx = 0 + for batch_idx in range(batch_size): + + for turn_id in range(fbank_beg.shape[1]): + fbank_beg_idx = fbank_beg[batch_idx, turn_id].item() + if fbank_beg_idx > 0: + speech_token_len = fake_token_len[batch_idx, turn_id] + speech_token = encoder_out[speech_idx, :speech_token_len, :] + + try: + inputs_embeds[ + batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, : + ] = speech_token + except Exception as e: + # + logging.error(f"{str(e)}, {traceback.format_exc()}") + logging.info( + f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" + ) + # import pdb; + # pdb.set_trace() + speech_token_len = encoder_out_lens[speech_idx].item() + speech_token = encoder_out[speech_idx, :speech_token_len, :] + inputs_embeds[ + batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, : + ] = speech_token + + speech_idx += 1 + + with torch.cuda.amp.autocast( + enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype] + ): + labels_ids[labels_ids == -1] = -100 + attention_mask[attention_mask < 0] = 0 + model_outputs = self.llm( + inputs_embeds=inputs_embeds.to(dtype_map[self.llm_dtype]), + attention_mask=attention_mask, + labels=labels_ids, + ) + loss = model_outputs.loss + + stats = {} + with torch.no_grad(): + preds = torch.argmax(model_outputs.logits, -1) + acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100) + stats["acc"] = acc_att + + stats["loss"] = torch.clone(loss.detach()) + stats["batch_size"] = batch_size + stats["batch_size_speech"] = batch_size_speech + stats["batch_size_x_frames"] = frames * batch_size_speech + stats["batch_size_real_frames"] = speech_lengths.sum().item() + 
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"] + stats["batch_size_x_tokens"] = token_num * batch_size + stats["batch_size_real_tokens"] = attention_mask.sum().item() + stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] + + dialog_turns = (fbank_beg > 0).sum(-1) + dialog_turns_max = torch.max(dialog_turns).int().item() + dialog_turns_avg = dialog_turns.sum().item() / batch_size + stats["dialog_turns_max"] = dialog_turns_max + stats["dialog_turns_avg"] = dialog_turns_avg + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = int((labels_ids > 0 + 1).sum()) + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def encode(self, speech, speech_lengths): + # audio encoder + encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths) + + return encoder_out, encoder_out_lens + + def data_template(self, data): + system, user, assistant = [], [], [] + for i, item in enumerate(data): + role = item["role"] + content = item["content"] + if role == "system": + system.append(content) + elif role == "user": + user.append(content) + elif role == "assistant": + assistant.append(content) + + system = system * len(user) + + contents = { + "system": system, + "user": user, + "assistant": assistant, + } + + return contents + + def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs): + + system = contents["system"] + user = contents["user"] + assistant = contents["assistant"] + pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)") + + input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = ( + [], + [], + [], + [], + [], + [], + [], + ) + input_source_ids = [] + for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)): + if i >= kwargs.get("multiturn_num_max", 5): + break + if len(input_ids) > kwargs.get("max_token_length", 1500): + break + + if i == 0: + source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + else: + source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + + splits = pattern.split(source_input) + source_ids = [] + fbank_i = [] + fbank_mask_i = [] + fake_token_len_i = 0 + fbank_beg_i = -1 + fbank_lens_i = [] + speech, speech_lengths = [], [] + for k, sub_str in enumerate(splits): + if not sub_str.startswith("<|startofspeech|>"): + sub_token = tokenizer.encode(sub_str) + source_ids += sub_token + fbank_mask_i += [0] * len(sub_token) + else: + sub_str = sub_str.replace("<|startofspeech|>", "").replace( + "<|endofspeech|>", "" + ) + if sub_str.startswith("!"): + sub_str = sub_str[1:] + if sub_str.startswith("!"): # !!bytes + sub_str = eval(sub_str[1:]) + try: + time1 = time.perf_counter() + data_src = load_audio_text_image_video(sub_str, fs=frontend.fs) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + except Exception as e: + logging.error(f"Loading wav failed! 
{str(e)}, {traceback.format_exc()}") + + speech, speech_lengths = extract_fbank( + data_src, + data_type=kwargs.get("data_type", "sound"), + frontend=frontend, + is_final=True, + ) # speech: [b, T, d] + + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = ( + speech_lengths.sum().item() + * frontend.frame_shift + * frontend.lfr_n + / 1000 + ) + + if kwargs.get("permute", True): + speech = speech.permute(0, 2, 1) + if speech_lengths > kwargs.get("max_source_length", 5500): + # logging.info( + # f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}" + # ) + badcase_flag = True + + olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2 + olens = 1 + (olens - 3 + 2 * 1) // 2 + fake_token_len_i = (olens - 1) // 2 + 1 + fake_token = [0] * fake_token_len_i + fbank_beg_i = len(source_ids) + source_ids += fake_token + fbank_mask_i += [1] * len(fake_token) + + fbank_beg += [fbank_beg_i + len(input_ids)] + fake_token_len += [fake_token_len_i] + source_mask = [-100] * len(source_ids) + target_out = f"{target_out}<|im_end|>" + target_ids = tokenizer.encode(target_out) + input_source_ids = input_ids + source_ids + input_ids += source_ids + target_ids + labels += source_mask + target_ids + fbank_mask += fbank_mask_i + if len(speech) > 0: + fbank.append(speech[0, :, :]) + fbank_lens.append(speech_lengths) + + input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length] + attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32) + labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length] + + # fbank = speech[0, :, :] + # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32) + fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32) + fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32) + fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32) + source_ids = torch.tensor(input_source_ids, dtype=torch.int64) + target_ids = torch.tensor(target_ids, dtype=torch.int64) + + if len(fbank) > 0: + speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0) + speech_lengths = torch.nn.utils.rnn.pad_sequence( + fbank_lens, batch_first=True, padding_value=-1 + ) + else: + speech = [] + speech_lengths = [] + output = { + "speech": speech, + "speech_lengths": speech_lengths, + "fbank_mask": fbank_mask[None, :], + "fbank_beg": fbank_beg[None,], + "fake_token_len": fake_token_len[None, :], + "input_ids": input_ids[None,], + "attention_mask": attention_mask[None,], + "labels_ids": labels, + "source_ids": source_ids[None, :], + "target_ids": target_ids[None, :], + } + + return output + + def inference_prepare( + self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + meta_data = {} + prompt = kwargs.get("prompt", None) + + if kwargs.get("batch_size", 1) > 1: + raise NotImplementedError("batch decoding is not implemented") + + contents = self.data_template(data_in[0]) + output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs) + batch = to_device(output, kwargs["device"]) + + # audio encoder + speech = batch["speech"] + if len(speech) > 0: + speech_lengths = batch["speech_lengths"][:, 0] + # fp16 + if kwargs.get("fp16", False): + speech = speech.to(torch.float16) + elif kwargs.get("bf16", False): + speech = speech.to(torch.bfloat16) + # audio encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + # audio_adaptor + 
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) + + input_ids = batch["input_ids"] + source_ids = batch["source_ids"] + fbank_beg = batch["fbank_beg"] + fake_token_len = batch["fake_token_len"] + + if not kwargs.get("tearchforing", False): + input_ids = source_ids + + input_ids[input_ids < 0] = 0 + inputs_embeds = self.llm.model.get_input_embeddings()(input_ids) + + batch_size, token_num, dims = inputs_embeds.shape + + fake_token_len[fake_token_len < 0] = 0 + fbank_beg[fbank_beg < 0] = 0 + + speech_idx = 0 + for batch_idx in range(batch_size): + + for turn_id in range(fbank_beg.shape[1]): + fbank_beg_idx = fbank_beg[batch_idx, turn_id].item() + if fbank_beg_idx > 0: + speech_token_len = fake_token_len[batch_idx, turn_id] + speech_token = encoder_out[speech_idx, :speech_token_len, :] + + try: + inputs_embeds[ + batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, : + ] = speech_token + except Exception as e: + # + logging.error(f"{str(e)}, {traceback.format_exc()}") + logging.info( + f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" + ) + # import pdb; + # pdb.set_trace() + speech_token_len = encoder_out_lens[speech_idx].item() + speech_token = encoder_out[speech_idx, :speech_token_len, :] + inputs_embeds[ + batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, : + ] = speech_token + + speech_idx += 1 + return inputs_embeds, contents, batch, source_ids, meta_data + + def inference( + self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare( + data_in, data_lengths, key, tokenizer, frontend, **kwargs + ) + + llm_dtype = kwargs.get("llm_dtype", "fp32") + if llm_dtype == "fp32": + llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype + llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype + + with torch.cuda.amp.autocast( + enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] + ): + label = contents["assistant"][-1] + self.llm = self.llm.to(dtype_map[llm_dtype]) + inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) + + if not kwargs.get("tearchforing", False): + + generated_ids = self.llm.generate( + inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512) + ) + # generated_ids = [ + # output_ids[len(input_id) :] + # for input_id, output_ids in zip(input_ids, generated_ids) + # ] + response = tokenizer.batch_decode( + generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True) + )[0] + + loss = None + else: + + labels_ids = batch["labels_ids"] + labels_ids[labels_ids == -1] = -100 + attention_mask = batch.get("attention_mask", None) + # attention_mask = attention_mask.to(dtype_map[llm_dtype]) + model_outputs = self.llm( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids + ) + + preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]:] + response = tokenizer.batch_decode( + preds, + add_special_tokens=False, + skip_special_tokens=kwargs.get("skip_special_tokens", True), + )[0] + loss = model_outputs.loss.item() + + ibest_writer = None + if kwargs.get("output_dir") is not None: + if not hasattr(self, "writer"): + self.writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer 
= self.writer[f"{0 + 1}best_recog"] + + results = [] + response_clean = re.sub("[^\w\s\u3000\u4e00-\u9fff]+", "", response) + result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label} + if loss is not None: + result_i["loss"] = loss + results.append(result_i) + + if ibest_writer is not None: + ibest_writer["text"][key[0]] = response.replace("\n", " ") + ibest_writer["label"][key[0]] = label.replace("\n", " ") + ibest_writer["text_tn"][key[0]] = response_clean + + return results, meta_data diff --git a/funasr/models/llm_asr/transformer_lm.py b/funasr/models/llm_asr/transformer_lm.py new file mode 100644 index 000000000..e05568742 --- /dev/null +++ b/funasr/models/llm_asr/transformer_lm.py @@ -0,0 +1,345 @@ +from typing import Any +from typing import List +from typing import Tuple + +import torch +import torch.nn as nn + +from funasr.models.transformer.embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding, LegacyRelPositionalEncoding +from funasr.models.transformer.encoder import TransformerEncoder as Encoder +from funasr.models.transformer.utils.mask import subsequent_mask +from funasr.models.transformer.utils.nets_utils import make_pad_mask +import logging +from distutils.version import LooseVersion +from contextlib import contextmanager +if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + +class TransformerEmbedLM(nn.Module): + def __init__( + self, + vocab_size: int, + pos_enc: str = None, + embed_unit: int = 128, + att_unit: int = 256, + head: int = 2, + unit: int = 1024, + layer: int = 4, + dropout_rate: float = 0.5, + attention_dropout_rate: float = 0.0, + pe_type: str = "split", + bidirectional_inputs: bool = False, + text_vocab_size: int = 4000, + input_aug_conf: dict = None, + output_aug_conf: dict = None, + codec_groups: int = 4, + selfattention_layer_type: str = "selfattn", + input_normalize: bool = False, + use_decoder: bool = True, + encoder_type: str = "transformer", + **kwargs + ): + super().__init__() + if pos_enc == "sinusoidal": + pos_enc_class = PositionalEncoding + elif pos_enc == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc == "rel_pos": + assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + elif pos_enc == "legacy_rel_pos": + assert selfattention_layer_type == "legacy_rel_selfattn" + pos_enc_class = LegacyRelPositionalEncoding + logging.warning( + "Using legacy_rel_pos and it will be deprecated in the future." + ) + elif pos_enc is None: + + def pos_enc_class(*args, **kwargs): + return nn.Sequential() # indentity + + else: + raise ValueError(f"unknown pos-enc option: {pos_enc}") + + self.embed_unit = embed_unit + self.pe_type = pe_type + self.encoder_type = encoder_type + if encoder_type == "llama": + raise NotImplementedError("llama encoder has not been implemented") + # from cosyvoice.nets.encoder.llama_encoder import LlamaEncoder + # # set causal to false, using mask to control causal mode. 
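+            # # (The same mechanism applies to the transformer branch below: the
+            # # encoder is left non-causal and causality comes from the mask
+            # # built by _target_mask(); e.g. subsequent_mask(3) gives
+            # # [[True, False, False],
+            # #  [True, True,  False],
+            # #  [True, True,  True ]].)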
+ # self.encoder = LlamaEncoder( + # input_size=embed_unit, + # output_size=att_unit, + # attention_heads=head, + # num_blocks=layer, + # dropout_rate=dropout_rate, + # attention_dropout_rate=attention_dropout_rate, + # causal=False, + # linear_units=unit, + # ) + else: + self.encoder = Encoder( + idim=embed_unit, + attention_dim=att_unit, + attention_heads=head, + linear_units=unit, + num_blocks=layer, + dropout_rate=dropout_rate, + positional_dropout_rate=dropout_rate, + attention_dropout_rate=attention_dropout_rate, + input_layer="none" if pe_type == "split" else "linear", + pos_enc_class=pos_enc_class, + selfattention_layer_type=selfattention_layer_type, + ) + if use_decoder: + self.decoder = nn.Linear(att_unit, vocab_size) + else: + self.decoder = None + self.attn_unit = att_unit + self.pos_enc_func = None + if pe_type == "split": + assert pos_enc == "sinusoidal" or pos_enc == "abs_pos" or pos_enc == "scaled_abs_pos", \ + "Different positional embedding for inputs and outputs " \ + "only supports sinusoidal, abs_pos and scaled_abs_pos." + self.pos_enc_func = pos_enc_class(embed_unit, 0.1) + self.input_layer = torch.nn.Linear(embed_unit, att_unit) + self.bidirectional_inputs = bidirectional_inputs + self.text_vocab_size = text_vocab_size + self.codec_groups = codec_groups + self.input_aug = None + if input_aug_conf is not None: + from funasr.models.specaug.specaug import SpecAug + self.input_aug = SpecAug(**input_aug_conf) + + self.output_aug = None + if output_aug_conf is not None: + from funasr.models.specaug.specaug import SpecAug + self.output_aug = SpecAug(**output_aug_conf) + + self.normalize = None + if input_normalize: + from funasr.models.normalize.utterance_mvn import UtteranceMVN + self.normalize = UtteranceMVN() + + self.first_pack_mask_conf: dict = kwargs.get("first_pack_mask_conf", None) + + def output_size(self): + return self.attn_unit + + def _target_mask(self, lengths): + ys_mask = ~make_pad_mask(lengths) + m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) + return ys_mask.unsqueeze(-2) & m + + def clac_first_package_mask(self, mask, input_lengths, cond_lengths): + device = mask.device + mask_type = self.first_pack_mask_conf.get("mask_type", "first_pack") + fp_token_len = self.first_pack_mask_conf["fp_token_len"] + fp_text_len = self.first_pack_mask_conf["fp_text_len"] + # NOTE: fp_text_len excluding sos, xvec, only including text + # NOTE: cond_lengths including sos, xvec and text + if mask_type == "streaming": + for i, (seq_len, cond_len) in enumerate(zip(input_lengths, cond_lengths)): + # 1 for task_id + token_len = seq_len - cond_len - 1 + if token_len > 0: + target_text_len = torch.ceil(torch.arange(1, token_len+1, device=device) / fp_token_len) * fp_text_len + # 2 for sos and xvec, M -> M x 1 + target_text_len = torch.minimum(target_text_len + 2, cond_len).unsqueeze(1) + # 1 x N + pos_range = torch.arange(0, cond_len, device=device).unsqueeze(0) + # M x N + text_mask = pos_range < target_text_len + # 1 for + mask[i, cond_len+1:seq_len, :cond_len] = mask[i, cond_len+1:seq_len, :cond_len] * text_mask + else: + for i, (seq_len, cond_len) in enumerate(zip(input_lengths, cond_lengths)): + mask_token_end = min(cond_len+1+fp_token_len, seq_len) + mask[i, cond_len+1:mask_token_end, fp_text_len+2:cond_len] = 0 + + return mask + + def forward( + self, + input: torch.Tensor, + input_lengths: torch.Tensor, + cond_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute LM loss value from buffer sequences. 
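+
+        The attention mask built by _target_mask() is causal over the whole
+        sequence; when bidirectional_inputs is set, the first cond_lengths[i]
+        positions of each sample may additionally attend to each other, and an
+        optional first_pack_mask_conf further limits which condition positions
+        the earliest generated tokens can see.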
+ + Args: + input (torch.Tensor): Input ids. (batch, len, dim) + input_lengths (torch.Tensor): length of input. (batch,) + cond_lengths (torch.Tensor): length of conditions (including sos, excluding taskid). (batch,) + + """ + mask = self._target_mask(input_lengths).to(input.device) + if self.first_pack_mask_conf is not None: + mask = self.clac_first_package_mask(mask, input_lengths, cond_lengths) + if self.bidirectional_inputs: + for i, length in enumerate(cond_lengths): + mask[i, :length, :length] = True + pos_emb = None + if self.pe_type == "split": + pos_emb = torch.zeros((input.shape[0], input.shape[1]*2-1, self.attn_unit)).to(input) + kk = self.codec_groups + # with torch.no_grad(): + with autocast(False): + for i, length in enumerate(cond_lengths): + # perform specaug for each frame including multi-group. + raw_feat = input[i:i + 1, 1:length].clone() + bb, tt, dd = raw_feat.shape + raw_feat = raw_feat.reshape(bb, tt // kk, kk, dd).reshape(bb, tt // kk, kk * dd) + + if self.input_aug is not None and self.training: + raw_feat = self.input_aug(raw_feat, (cond_lengths[i:i+1] - 1) // kk)[0] + + if self.normalize is not None: + raw_feat = self.normalize(raw_feat, None)[0] + + input[i:i + 1, 1:length] = raw_feat.reshape(bb, tt//kk, kk, dd).reshape(bb, tt, dd) + + if self.output_aug is not None and self.training: + raw_feat = input[i:i + 1, length+1:].clone() + aug_feat = self.output_aug(raw_feat, input_lengths[i:i+1] - length - 2)[0] + input[i:i + 1, length + 1:] = aug_feat + + # add positional encoding + if self.pe_type == "split" and self.pos_enc_func is not None: + posed_input = self.pos_enc_func(input[i:i + 1, :length].clone()) + if isinstance(posed_input, tuple): + pos_emb[i:i+1, :length*2-1] = posed_input[1] + posed_input = posed_input[0] + input[i:i + 1, :length] = posed_input + + posed_output = self.pos_enc_func(input[i:i + 1, length + 1:].clone()) + if isinstance(posed_output, tuple): + pos_emb[i:i+1, length*2: length*2+posed_output[1].shape[1]] = posed_output[1] + posed_output = posed_output[0] + input[i:i + 1, length + 1:] = posed_output + + if self.pe_type == "split": + input = self.input_layer(input) + if isinstance(self.pos_enc_func, (RelPositionalEncoding, LegacyRelPositionalEncoding)): + input = (input, pos_emb) + # logging.info(f"shapes {input.shape} {mask.shape} {input_lengths}") + h, _ = self.encoder(input, mask) + if self.decoder is None: + return h, h + + y = self.decoder(h) + return y, h + + def init_state(self, x: torch.Tensor): + return None + + def score( + self, y: torch.Tensor, state: Any, x: torch.Tensor + ) -> Tuple[torch.Tensor, Any]: + """Score new token. + + Args: + y (torch.Tensor): 2D torch.float prefix embeddings. + state: Scorer state for prefix tokens + x (torch.Tensor): encoder feature that generates ys. + + Returns: + tuple[torch.Tensor, Any]: Tuple of + torch.float32 scores for next token (vocab_size) + and next state for ys + + """ + # this implementation is much faster than the blow!! 
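+        # torch.tril gives the causal (lower-triangular) boolean mask directly;
+        # for a single un-padded prefix of length L this is equivalent to the
+        # _target_mask(lengths) construction kept commented out below, just cheaper.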
+ mask = torch.tril(torch.ones((1, y.shape[0], y.shape[0]), device=y.device)).to(torch.bool) + y_emb = y.unsqueeze(0).to(x.device) + # lengths = y_emb.new_full([1], dtype=torch.long, fill_value=y_emb.size(1)) + # mask = self._target_mask(lengths).to(y_emb.device) + # x includes , feat, + input_length = x.shape[0] - 1 + if self.bidirectional_inputs: + mask[:1, :input_length, :input_length] = True + # if self.first_pack_mask_conf is not None: + # mask = self.clac_first_package_mask( + # mask, + # torch.tensor([y.shape[0]], device=y.device), + # torch.tensor([input_length], device=y.device), + # ) + if self.pe_type == "split" and self.pos_enc_func is not None: + pos_emb = torch.zeros((y_emb.shape[0], y_emb.shape[1], self.attn_unit)).to(y_emb) + + posed_input = self.pos_enc_func(y_emb[:1, :input_length]) + if isinstance(posed_input, tuple): + pos_emb[:1, :input_length] = posed_input[1] + posed_input = posed_input[0] + y_emb[:1, :input_length] = posed_input + + posed_output = self.pos_enc_func(y_emb[:1, input_length + 1:]) + if isinstance(posed_output, tuple): + pos_emb[:1, input_length + 1:] = posed_output[1] + posed_output = posed_output[0] + y_emb[:1, input_length + 1:] = posed_output + + if self.pe_type == "split": + y_emb = self.input_layer(y_emb) + if isinstance(self.pos_enc_func, (RelPositionalEncoding, LegacyRelPositionalEncoding)): + y_emb = (y_emb, pos_emb) + lm_hidden_states, _, cache = self.encoder.forward_one_step( + y_emb, mask, cache=state + ) + if self.decoder is None: + return lm_hidden_states[:, -1], cache + + h = self.decoder(lm_hidden_states[:, -1])[:, :self.text_vocab_size] + + logp = h.log_softmax(dim=-1).squeeze(0) + # return logp, cache + return logp, (cache, lm_hidden_states[:, -1]) + + def batch_score( + self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor + ) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch. + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, vocab_size)` + and next state list for ys. 
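+
+        Per-hypothesis states (lists indexed by layer) are stacked into
+        per-layer batch tensors before the single encoder.forward_one_step
+        call and split back into per-hypothesis lists afterwards.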
+ + """ + # merge states + n_batch = len(ys) + n_layers = len(self.encoder.encoders) + if states[0] is None: + batch_state = None + else: + # transpose state of [batch, layer] into [layer, batch] + batch_state = [ + torch.stack([states[b][i] for b in range(n_batch)]) + for i in range(n_layers) + ] + + # batch decoding + h, _, states = self.encoder.forward_one_step( + self.embed(ys), self._target_mask(ys), cache=batch_state + ) + h = self.decoder(h[:, -1]) + logp = h.log_softmax(dim=-1) + + # transpose state of [layer, batch] into [batch, layer] + state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] + return logp, state_list From 4784baf2afc6a11e39e521f33b850a2516b6e617 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 2 Jul 2024 14:10:57 +0800 Subject: [PATCH 03/24] update --- funasr/models/llm_asr/model.py | 241 ++++++++++++++++++--------------- 1 file changed, 131 insertions(+), 110 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index cf554384a..24837f0dd 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -1408,7 +1408,7 @@ class LLMASR4(nn.Module): return results, meta_data -@tables.register("model_classes", "LLMASR5") +# @tables.register("model_classes", "LLMASR5") class LLMASR5(nn.Module): """ """ @@ -2011,41 +2011,19 @@ class LLMASR5(nn.Module): """ """ def __init__( - self, - specaug: str = None, - specaug_conf: dict = None, - normalize: str = None, - normalize_conf: dict = None, - audio_encoder: str = None, - audio_encoder_conf: dict = None, - audio_adaptor: str = None, - audio_adaptor_conf: dict = None, - decoder: str = None, - decoder_conf: dict = None, - ctc: str = None, - ctc_conf: dict = None, - ctc_weight: float = 0.5, - llm: str = None, - llm_conf: dict = None, - input_size: int = 80, - vocab_size: int = -1, - ignore_id: int = -1, - blank_id: int = 0, - sos: int = 1, - eos: int = 2, - lsm_weight: float = 0.0, - length_normalized_loss: bool = False, - report_cer: bool = True, - report_wer: bool = True, - sym_space: str = "", - sym_blank: str = "", - # extract_feats_in_collect_stats: bool = True, - share_embedding: bool = False, - # preencoder: Optional[AbsPreEncoder] = None, - # postencoder: Optional[AbsPostEncoder] = None, - audio_decoder: str = None, - audio_decoder_conf: dict = None, - **kwargs, + self, + audio_encoder: str = None, + audio_encoder_conf: dict = None, + audio_adaptor: str = None, + audio_adaptor_conf: dict = None, + llm: str = None, + llm_conf: dict = None, + input_size: int = 80, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + audio_decoder: str = None, + audio_decoder_conf: dict = None, + **kwargs, ): super().__init__() @@ -2082,7 +2060,7 @@ class LLMASR5(nn.Module): idx = re.search(r"\.\d+\.", name) if idx is not None: beg, end = idx.regs[0] - layer_id = int(name[beg + 1: end - 1]) + layer_id = int(name[beg + 1 : end - 1]) if layer_id < freeze_layer_num: param.requires_grad = False elif "ln_post." 
not in name: @@ -2134,6 +2112,8 @@ class LLMASR5(nn.Module): self.length_normalized_loss = length_normalized_loss self.beam_search = None + self.eos = kwargs.get("eos", 151645) + # audio decoder related self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf) self.audio_decoder_in_proj = torch.nn.Linear(llm_dim, self.audio_decoder.embed_unit) @@ -2148,16 +2128,12 @@ class LLMASR5(nn.Module): def build_audio_decoder(self, name, conf): if name == "transformer": from funasr.models.llm_asr.transformer_lm import TransformerEmbedLM + if "text_vocab_size" in conf: - lm_model = TransformerEmbedLM( - vocab_size=self.lm_out_voc_size, - **conf - ) + lm_model = TransformerEmbedLM(vocab_size=self.lm_out_voc_size, **conf) else: lm_model = TransformerEmbedLM( - vocab_size=self.lm_out_voc_size, - text_vocab_size=self.lm_out_voc_size, - **conf + vocab_size=self.lm_out_voc_size, text_vocab_size=self.lm_out_voc_size, **conf ) else: raise TypeError(f"Unknown codec decoder type {name}") @@ -2175,30 +2151,35 @@ class LLMASR5(nn.Module): return self.codec_embedder(codec * mask).sum(dim=-2) * mask def prepare_audio_decoder_io( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - codec: Optional[torch.Tensor] = None, - codec_lengths: Optional[torch.Tensor] = None, - need_targets: bool = True, + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + codec: Optional[torch.Tensor] = None, + codec_lengths: Optional[torch.Tensor] = None, + need_targets: bool = True, ): """build inputs and targets for language model - Normally, this function is called in batchify_nll. - Args: - text: (Batch, Length, Dim) - text_lengths: (Batch,) - codec: (Batch, Length) - codec_lengths: (Batch,) - need_targets: bool, whether provide targets - """ + Normally, this function is called in batchify_nll. 
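+
+        Each input sequence is laid out as [sos/eos embedding, text embeddings,
+        task-id embedding, dense codec embeddings], so
+        llm_lengths = text_lengths + 2 (+ codec_lengths when codec is given);
+        the target is the codec ids followed by a single end-of-sequence entry.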
+ Args: + text: (Batch, Length, Dim) + text_lengths: (Batch,) + codec: (Batch, Length) + codec_lengths: (Batch,) + need_targets: bool, whether provide targets + """ if need_targets: - assert codec is not None and codec_lengths is not None, \ - "need_target=True, but codec or codec_length is None" + assert ( + codec is not None and codec_lengths is not None + ), "need_target=True, but codec or codec_length is None" - sos_eos_emb = self.audio_decoder_embedding(torch.tensor([self.ad_sos_eos], dtype=torch.int64, device=text.device)) - task_id_emb = self.audio_decoder_embedding(torch.tensor([self.ad_task_id], dtype=torch.int64, device=text.device)) + sos_eos_emb = self.audio_decoder_embedding( + torch.tensor([self.ad_sos_eos], dtype=torch.int64, device=text.device) + ) + task_id_emb = self.audio_decoder_embedding( + torch.tensor([self.ad_task_id], dtype=torch.int64, device=text.device) + ) codec_emb = None if codec is not None and codec_lengths is not None: codec_emb = self.calc_dense_vector(codec, codec_lengths) @@ -2206,7 +2187,7 @@ class LLMASR5(nn.Module): for i, text_len in enumerate(text_lengths): one_input = [sos_eos_emb, text[i, :text_len], task_id_emb] if codec_emb is not None: - one_input.append(codec_emb[i, :codec_lengths[i]]) + one_input.append(codec_emb[i, : codec_lengths[i]]) inputs_list.append(torch.cat(one_input, dim=0)) llm_inputs = pad_list(inputs_list, 0.0) llm_lengths = text_lengths + 2 @@ -2217,7 +2198,9 @@ class LLMASR5(nn.Module): return llm_inputs, llm_lengths bb, tt = text.shape[0], codec_lengths.max() + 1 - llm_targets = -1 * torch.ones([bb, tt, self.predict_nq], dtype=torch.int64, device=text.device) + llm_targets = -1 * torch.ones( + [bb, tt, self.predict_nq], dtype=torch.int64, device=text.device + ) for i, codec_len in enumerate(codec_lengths): llm_targets[i, :codec_len] = codec[i, :codec_len] llm_targets[i, codec_len] = self.codebook_size + self.sos_eos @@ -2242,36 +2225,33 @@ class LLMASR5(nn.Module): """ batch_size = text.size(0) # For data parallel - text = text[:, :text_lengths.max()] - codec = codec[:, :codec_lengths.max()] + text = text[:, : text_lengths.max()] + codec = codec[:, : codec_lengths.max()] text = self.audio_decoder_in_proj(text) # build inputs and targets for language model with autocast(False): (sequence, target), (x_lengths, y_lengths) = self.prepare_audio_decoder_io( - text, text_lengths, - codec, codec_lengths, - need_targets=True + text, text_lengths, codec, codec_lengths, need_targets=True ) # 2a. Forward Language model # x: (Batch, Length) -> y: (Batch, Length, NVocab) - sequence = sequence[:, :x_lengths.max()] - target = target[:, :y_lengths.max()] - y, _ = self.audio_decoder(sequence, x_lengths, text_lengths+1) + sequence = sequence[:, : x_lengths.max()] + target = target[:, : y_lengths.max()] + y, _ = self.audio_decoder(sequence, x_lengths, text_lengths + 1) bb, tt = y.shape[0], y.shape[1] y = y.reshape(bb, tt, self.predict_nq, -1) # 2b. Extract real logits logits_list = [] for i, (text_len, codec_len) in enumerate(zip(text_lengths, codec_lengths)): - logits_list.append(y[i, text_len + 1:text_len + 2 + codec_len]) + logits_list.append(y[i, text_len + 1 : text_len + 2 + codec_len]) logits = pad_list(logits_list, 0.0) # 3. 
Calc negative log likelihood tt = logits.shape[1] nll = self.criterion_ce( - logits.reshape(bb, tt * self.predict_nq, -1), - target.reshape(bb, tt * self.predict_nq) + logits.reshape(bb, tt * self.predict_nq, -1), target.reshape(bb, tt * self.predict_nq) ) nll = nll.sum(-1) # nll: (BxL,) -> (BxL,) @@ -2279,18 +2259,18 @@ class LLMASR5(nn.Module): # nll: (BxL,) -> (B, L) nll = nll.reshape(batch_size, -1).reshape(batch_size, tt, self.predict_nq) - return nll, logits, target, codec_lengths+1 + return nll, logits, target, codec_lengths + 1 def forward( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - labels_ids: torch.Tensor, - fbank_beg: torch.Tensor, - fbank_mask: torch.Tensor, - **kwargs, + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + labels_ids: torch.Tensor, + fbank_beg: torch.Tensor, + fbank_mask: torch.Tensor, + **kwargs, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Encoder + Decoder + Calc loss Args: @@ -2334,7 +2314,7 @@ class LLMASR5(nn.Module): try: inputs_embeds[ - batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, : + batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : ] = speech_token except Exception as e: # @@ -2347,13 +2327,13 @@ class LLMASR5(nn.Module): speech_token_len = encoder_out_lens[speech_idx].item() speech_token = encoder_out[speech_idx, :speech_token_len, :] inputs_embeds[ - batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, : + batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : ] = speech_token speech_idx += 1 with torch.cuda.amp.autocast( - enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype] + enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype] ): labels_ids[labels_ids == -1] = -100 attention_mask[attention_mask < 0] = 0 @@ -2364,6 +2344,47 @@ class LLMASR5(nn.Module): ) loss = model_outputs.loss + codec = kwargs.get("codec") + codec_len = kwargs.get("codec_len") + if len(codec_len.size()) > 1: + codec_len = codec_len[:, 0] + hidden_states = model_outputs.hidden_states[-1].float() + + target_ids = [] + target_ids_len = [] + hidden_states_select = [] + for batch_idx in range(labels_ids.shape[0]): + beg_i = 0 + end_i = 0 + for token_idx in range(labels_ids.shape[1]): + token_int = labels_ids[batch_idx, token_idx].item() + if token_int == self.eos: + target_ids_i = labels_ids[batch_idx, beg_i:end_i] + target_ids_len_i = end_i - beg_i + target_ids_len.append(target_ids_len_i) + target_ids.append(target_ids_i) + hidden_states_i = hidden_states[batch_idx, beg_i - 1 : end_i - 1, :] + hidden_states_select.append(hidden_states_i) + beg_i = end_i + continue + + end_i += 1 + if token_int <= 0: + beg_i += 1 + + target_ids = torch.nn.utils.rnn.pad_sequence( + target_ids, batch_first=True, padding_value=-100 + ) + hidden_states_select = torch.nn.utils.rnn.pad_sequence( + hidden_states_select, batch_first=True, padding_value=0.0 + ) + target_ids_len = torch.tensor(target_ids_len, dtype=torch.int32, device=input_ids.device) + target_ids = target_ids.to(device=input_ids.device) + hidden_states_select = hidden_states_select.to(device=input_ids.device) + loss, logits, target, codec_lengths = self.nll( + hidden_states_select, target_ids_len, codec, codec_len + ) + stats = {} with torch.no_grad(): preds = torch.argmax(model_outputs.logits, -1) @@ -2487,10 +2508,10 @@ class LLMASR5(nn.Module): time3 = 
time.perf_counter() meta_data["extract_feat"] = f"{time3 - time2:0.3f}" meta_data["batch_data_time"] = ( - speech_lengths.sum().item() - * frontend.frame_shift - * frontend.lfr_n - / 1000 + speech_lengths.sum().item() + * frontend.frame_shift + * frontend.lfr_n + / 1000 ) if kwargs.get("permute", True): @@ -2558,13 +2579,13 @@ class LLMASR5(nn.Module): return output def inference_prepare( - self, - data_in, - data_lengths=None, - key: list = None, - tokenizer=None, - frontend=None, - **kwargs, + self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, ): meta_data = {} @@ -2619,7 +2640,7 @@ class LLMASR5(nn.Module): try: inputs_embeds[ - batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, : + batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : ] = speech_token except Exception as e: # @@ -2632,20 +2653,20 @@ class LLMASR5(nn.Module): speech_token_len = encoder_out_lens[speech_idx].item() speech_token = encoder_out[speech_idx, :speech_token_len, :] inputs_embeds[ - batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, : + batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : ] = speech_token speech_idx += 1 return inputs_embeds, contents, batch, source_ids, meta_data def inference( - self, - data_in, - data_lengths=None, - key: list = None, - tokenizer=None, - frontend=None, - **kwargs, + self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, ): inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare( @@ -2658,7 +2679,7 @@ class LLMASR5(nn.Module): llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype with torch.cuda.amp.autocast( - enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] + enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] ): label = contents["assistant"][-1] self.llm = self.llm.to(dtype_map[llm_dtype]) @@ -2688,7 +2709,7 @@ class LLMASR5(nn.Module): inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids ) - preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]:] + preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :] response = tokenizer.batch_decode( preds, add_special_tokens=False, From 9d8a55c66a6e5bf72aa8f765ecacaf96af57e249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 2 Jul 2024 14:11:50 +0800 Subject: [PATCH 04/24] update --- funasr/models/llm_asr/model.py | 1190 ++++++++++++++++---------------- 1 file changed, 595 insertions(+), 595 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 24837f0dd..72a3d67b5 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -1409,601 +1409,601 @@ class LLMASR4(nn.Module): # @tables.register("model_classes", "LLMASR5") -class LLMASR5(nn.Module): - """ """ - - def __init__( - self, - specaug: str = None, - specaug_conf: dict = None, - normalize: str = None, - normalize_conf: dict = None, - audio_encoder: str = None, - audio_encoder_conf: dict = None, - audio_adaptor: str = None, - audio_adaptor_conf: dict = None, - decoder: str = None, - decoder_conf: dict = None, - ctc: str = None, - ctc_conf: dict = None, - ctc_weight: float = 0.5, - llm: str = None, - llm_conf: dict = None, - input_size: int = 80, - vocab_size: int = -1, - ignore_id: int = -1, - blank_id: int = 0, - lsm_weight: float = 0.0, - length_normalized_loss: bool = False, - report_cer: bool = True, - report_wer: 
bool = True, - sym_space: str = "", - sym_blank: str = "", - # extract_feats_in_collect_stats: bool = True, - share_embedding: bool = False, - # preencoder: Optional[AbsPreEncoder] = None, - # postencoder: Optional[AbsPostEncoder] = None, - **kwargs, - ): - - super().__init__() - - # audio encoder - hub = audio_encoder_conf.get("hub", None) - if hub == "ms": - from funasr import AutoModel - - model = AutoModel(model=audio_encoder, model_revision="master") - # frontend = model.kwargs.get("frontend") - audio_encoder_output_size = model.model.encoder_output_size - - audio_encoder = ( - model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder - ) - - # self.frontend = frontend - - elif hub == "hf": - pass - else: - encoder_class = tables.encoder_classes.get(audio_encoder) - audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf) - audio_encoder_output_size = audio_encoder.output_size() - freeze = audio_encoder_conf.get("freeze", True) - freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1)) - # if freeze_layer_num > 0: - # freeze_layer_num = range(freeze_layer_num) - - if freeze: - for name, param in audio_encoder.named_parameters(): - if freeze_layer_num > 0: - idx = re.search(r"\.\d+\.", name) - if idx is not None: - beg, end = idx.regs[0] - layer_id = int(name[beg + 1 : end - 1]) - if layer_id < freeze_layer_num: - param.requires_grad = False - elif "ln_post." not in name: - param.requires_grad = False - else: - param.requires_grad = False - - audio_encoder.eval() - - self.audio_encoder = audio_encoder - - # llm - self.llm = None - - from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig - - init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5") - - model = AutoModelForCausalLM.from_pretrained( - init_param_path, - load_in_8bit=None, - device_map=None, - use_cache=None, - ) - freeze = llm_conf.get("freeze", True) - if freeze: - for name, param in model.named_parameters(): - param.requires_grad = False - model.eval() - self.llm_dtype = llm_conf.get("llm_dtype", "fp32") - self.llm = model.to(dtype_map[self.llm_dtype]) - llm_dim = model.get_input_embeddings().weight.shape[-1] - - # adaptor - adaptor_class = tables.adaptor_classes.get(audio_adaptor) - audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size - audio_adaptor_conf["llm_dim"] = llm_dim - audio_adaptor = adaptor_class(**audio_adaptor_conf) - init_param_path = audio_adaptor_conf.get("init_param_path", None) - if init_param_path is not None: - src_state = torch.load(init_param_path, map_location="cpu") - flag = audio_adaptor.load_state_dict(src_state, strict=False) - logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}") - - self.audio_adaptor = audio_adaptor - - self.error_calculator = None - - self.length_normalized_loss = length_normalized_loss - self.beam_search = None - - self.eos = kwargs.get("eos", 151645) - - def forward( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - labels_ids: torch.Tensor, - fbank_beg: torch.Tensor, - fbank_mask: torch.Tensor, - **kwargs, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: - """Encoder + Decoder + Calc loss - Args: - speech: (Batch, Length, ...) 
- speech_lengths: (Batch, ) - text: (Batch, Length) - text_lengths: (Batch,) - """ - import pdb - - pdb.set_trace() - - if len(speech_lengths.size()) > 1: - speech_lengths = speech_lengths[:, 0] - - batch_size_speech, frames, _ = speech.shape - batch_size, token_num = input_ids.shape - - with torch.cuda.amp.autocast(enabled=False): - # audio encoder - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - - # audio_adaptor - encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) - - input_ids[input_ids < 0] = 0 - inputs_embeds = self.llm.model.get_input_embeddings()(input_ids) - - batch_size, token_num, dims = inputs_embeds.shape - fake_token_len = kwargs.get("fake_token_len") - fake_token_len[fake_token_len < 0] = 0 - fbank_beg[fbank_beg < 0] = 0 - - speech_idx = 0 - for batch_idx in range(batch_size): - - for turn_id in range(fbank_beg.shape[1]): - fbank_beg_idx = fbank_beg[batch_idx, turn_id].item() - if fbank_beg_idx > 0: - speech_token_len = fake_token_len[batch_idx, turn_id] - speech_token = encoder_out[speech_idx, :speech_token_len, :] - - try: - inputs_embeds[ - batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : - ] = speech_token - except Exception as e: - # - logging.error(f"{str(e)}, {traceback.format_exc()}") - logging.info( - f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" - ) - # import pdb; - # pdb.set_trace() - speech_token_len = encoder_out_lens[speech_idx].item() - speech_token = encoder_out[speech_idx, :speech_token_len, :] - inputs_embeds[ - batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : - ] = speech_token - - speech_idx += 1 - - with torch.cuda.amp.autocast( - enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype] - ): - labels_ids[labels_ids == -1] = -100 - attention_mask[attention_mask < 0] = 0 - model_outputs = self.llm( - inputs_embeds=inputs_embeds.to(dtype_map[self.llm_dtype]), - attention_mask=attention_mask, - labels=labels_ids, - ) - loss = model_outputs.loss - - codec = kwargs.get("codec") - codec_len = kwargs.get("codec_len") - if len(codec_len.size()) > 1: - codec_len = codec_len[:, 0] - hidden_states = model_outputs.hidden_states[-1].float() - - target_ids = [] - target_ids_len = [] - hidden_states_select = [] - for batch_idx in range(labels_ids.shape[0]): - beg_i = 0 - end_i = 0 - for token_idx in range(labels_ids.shape[1]): - token_int = labels_ids[batch_idx, token_idx].item() - if token_int == self.eos: - target_ids_i = labels_ids[batch_idx, beg_i:end_i] - target_ids_len_i = end_i - beg_i - target_ids_len.append(target_ids_len_i) - target_ids.append(target_ids_i) - hidden_states_i = hidden_states[batch_idx, beg_i - 1 : end_i - 1, :] - hidden_states_select.append(hidden_states_i) - beg_i = end_i - continue - - end_i += 1 - if token_int <= 0: - beg_i += 1 - - target_ids = torch.nn.utils.rnn.pad_sequence( - target_ids, batch_first=True, padding_value=-100 - ) - hidden_states_select = torch.nn.utils.rnn.pad_sequence( - hidden_states_select, batch_first=True, padding_value=0.0 - ) - - stats = {} - with torch.no_grad(): - preds = torch.argmax(model_outputs.logits, -1) - acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100) - stats["acc"] = acc_att - - stats["loss"] = torch.clone(loss.detach()) - 
stats["batch_size"] = batch_size - stats["batch_size_speech"] = batch_size_speech - stats["batch_size_x_frames"] = frames * batch_size_speech - stats["batch_size_real_frames"] = speech_lengths.sum().item() - stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"] - stats["batch_size_x_tokens"] = token_num * batch_size - stats["batch_size_real_tokens"] = attention_mask.sum().item() - stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] - - dialog_turns = (fbank_beg > 0).sum(-1) - dialog_turns_max = torch.max(dialog_turns).int().item() - dialog_turns_avg = dialog_turns.sum().item() / batch_size - stats["dialog_turns_max"] = dialog_turns_max - stats["dialog_turns_avg"] = dialog_turns_avg - - # force_gatherable: to-device and to-tensor if scalar for DataParallel - if self.length_normalized_loss: - batch_size = int((labels_ids > 0 + 1).sum()) - loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) - return loss, stats, weight - - def encode(self, speech, speech_lengths): - # audio encoder - encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths) - - return encoder_out, encoder_out_lens - - def data_template(self, data): - system, user, assistant = [], [], [] - for i, item in enumerate(data): - role = item["role"] - content = item["content"] - if role == "system": - system.append(content) - elif role == "user": - user.append(content) - elif role == "assistant": - assistant.append(content) - - system = system * len(user) - - contents = { - "system": system, - "user": user, - "assistant": assistant, - } - - return contents - - def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs): - - system = contents["system"] - user = contents["user"] - assistant = contents["assistant"] - pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)") - - input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = ( - [], - [], - [], - [], - [], - [], - [], - ) - input_source_ids = [] - for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)): - if i >= kwargs.get("multiturn_num_max", 5): - break - if len(input_ids) > kwargs.get("max_token_length", 1500): - break - - if i == 0: - source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" - else: - source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" - - splits = pattern.split(source_input) - source_ids = [] - fbank_i = [] - fbank_mask_i = [] - fake_token_len_i = 0 - fbank_beg_i = -1 - fbank_lens_i = [] - speech, speech_lengths = [], [] - for k, sub_str in enumerate(splits): - if not sub_str.startswith("<|startofspeech|>"): - sub_token = tokenizer.encode(sub_str) - source_ids += sub_token - fbank_mask_i += [0] * len(sub_token) - else: - sub_str = sub_str.replace("<|startofspeech|>", "").replace( - "<|endofspeech|>", "" - ) - if sub_str.startswith("!"): - sub_str = sub_str[1:] - if sub_str.startswith("!"): # !!bytes - sub_str = eval(sub_str[1:]) - try: - time1 = time.perf_counter() - data_src = load_audio_text_image_video(sub_str, fs=frontend.fs) - time2 = time.perf_counter() - meta_data["load_data"] = f"{time2 - time1:0.3f}" - except Exception as e: - logging.error(f"Loading wav failed! 
{str(e)}, {traceback.format_exc()}") - - speech, speech_lengths = extract_fbank( - data_src, - data_type=kwargs.get("data_type", "sound"), - frontend=frontend, - is_final=True, - ) # speech: [b, T, d] - - time3 = time.perf_counter() - meta_data["extract_feat"] = f"{time3 - time2:0.3f}" - meta_data["batch_data_time"] = ( - speech_lengths.sum().item() - * frontend.frame_shift - * frontend.lfr_n - / 1000 - ) - - if kwargs.get("permute", True): - speech = speech.permute(0, 2, 1) - if speech_lengths > kwargs.get("max_source_length", 5500): - # logging.info( - # f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}" - # ) - badcase_flag = True - - olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2 - olens = 1 + (olens - 3 + 2 * 1) // 2 - fake_token_len_i = (olens - 1) // 2 + 1 - fake_token = [0] * fake_token_len_i - fbank_beg_i = len(source_ids) - source_ids += fake_token - fbank_mask_i += [1] * len(fake_token) - - fbank_beg += [fbank_beg_i + len(input_ids)] - fake_token_len += [fake_token_len_i] - source_mask = [-100] * len(source_ids) - target_out = f"{target_out}<|im_end|>" - target_ids = tokenizer.encode(target_out) - input_source_ids = input_ids + source_ids - input_ids += source_ids + target_ids - labels += source_mask + target_ids - fbank_mask += fbank_mask_i - if len(speech) > 0: - fbank.append(speech[0, :, :]) - fbank_lens.append(speech_lengths) - - input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length] - attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32) - labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length] - - # fbank = speech[0, :, :] - # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32) - fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32) - fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32) - fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32) - source_ids = torch.tensor(input_source_ids, dtype=torch.int64) - target_ids = torch.tensor(target_ids, dtype=torch.int64) - - if len(fbank) > 0: - speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0) - speech_lengths = torch.nn.utils.rnn.pad_sequence( - fbank_lens, batch_first=True, padding_value=-1 - ) - else: - speech = [] - speech_lengths = [] - output = { - "speech": speech, - "speech_lengths": speech_lengths, - "fbank_mask": fbank_mask[None, :], - "fbank_beg": fbank_beg[None,], - "fake_token_len": fake_token_len[None, :], - "input_ids": input_ids[None,], - "attention_mask": attention_mask[None,], - "labels_ids": labels, - "source_ids": source_ids[None, :], - "target_ids": target_ids[None, :], - } - - return output - - def inference_prepare( - self, - data_in, - data_lengths=None, - key: list = None, - tokenizer=None, - frontend=None, - **kwargs, - ): - - meta_data = {} - prompt = kwargs.get("prompt", None) - - if kwargs.get("batch_size", 1) > 1: - raise NotImplementedError("batch decoding is not implemented") - - contents = self.data_template(data_in[0]) - output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs) - batch = to_device(output, kwargs["device"]) - - # audio encoder - speech = batch["speech"] - if len(speech) > 0: - speech_lengths = batch["speech_lengths"][:, 0] - # fp16 - if kwargs.get("fp16", False): - speech = speech.to(torch.float16) - elif kwargs.get("bf16", False): - speech = speech.to(torch.bfloat16) - # audio encoder - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - - # audio_adaptor - 
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) - - input_ids = batch["input_ids"] - source_ids = batch["source_ids"] - fbank_beg = batch["fbank_beg"] - fake_token_len = batch["fake_token_len"] - - if not kwargs.get("tearchforing", False): - input_ids = source_ids - - input_ids[input_ids < 0] = 0 - inputs_embeds = self.llm.model.get_input_embeddings()(input_ids) - - batch_size, token_num, dims = inputs_embeds.shape - - fake_token_len[fake_token_len < 0] = 0 - fbank_beg[fbank_beg < 0] = 0 - - speech_idx = 0 - for batch_idx in range(batch_size): - - for turn_id in range(fbank_beg.shape[1]): - fbank_beg_idx = fbank_beg[batch_idx, turn_id].item() - if fbank_beg_idx > 0: - speech_token_len = fake_token_len[batch_idx, turn_id] - speech_token = encoder_out[speech_idx, :speech_token_len, :] - - try: - inputs_embeds[ - batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : - ] = speech_token - except Exception as e: - # - logging.error(f"{str(e)}, {traceback.format_exc()}") - logging.info( - f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" - ) - # import pdb; - # pdb.set_trace() - speech_token_len = encoder_out_lens[speech_idx].item() - speech_token = encoder_out[speech_idx, :speech_token_len, :] - inputs_embeds[ - batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : - ] = speech_token - - speech_idx += 1 - return inputs_embeds, contents, batch, source_ids, meta_data - - def inference( - self, - data_in, - data_lengths=None, - key: list = None, - tokenizer=None, - frontend=None, - **kwargs, - ): - - inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare( - data_in, data_lengths, key, tokenizer, frontend, **kwargs - ) - - llm_dtype = kwargs.get("llm_dtype", "fp32") - if llm_dtype == "fp32": - llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype - llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype - - with torch.cuda.amp.autocast( - enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] - ): - label = contents["assistant"][-1] - self.llm = self.llm.to(dtype_map[llm_dtype]) - inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) - - if not kwargs.get("tearchforing", False): - - generated_ids = self.llm.generate( - inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512) - ) - # generated_ids = [ - # output_ids[len(input_id) :] - # for input_id, output_ids in zip(input_ids, generated_ids) - # ] - response = tokenizer.batch_decode( - generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True) - )[0] - - loss = None - else: - - labels_ids = batch["labels_ids"] - labels_ids[labels_ids == -1] = -100 - attention_mask = batch.get("attention_mask", None) - # attention_mask = attention_mask.to(dtype_map[llm_dtype]) - model_outputs = self.llm( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids - ) - - preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :] - response = tokenizer.batch_decode( - preds, - add_special_tokens=False, - skip_special_tokens=kwargs.get("skip_special_tokens", True), - )[0] - loss = model_outputs.loss.item() - - ibest_writer = None - if kwargs.get("output_dir") is not None: - if not hasattr(self, "writer"): - self.writer = DatadirWriter(kwargs.get("output_dir")) - 
ibest_writer = self.writer[f"{0 + 1}best_recog"] - - results = [] - response_clean = re.sub("[^\w\s\u3000\u4e00-\u9fff]+", "", response) - result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label} - if loss is not None: - result_i["loss"] = loss - results.append(result_i) - - if ibest_writer is not None: - ibest_writer["text"][key[0]] = response.replace("\n", " ") - ibest_writer["label"][key[0]] = label.replace("\n", " ") - ibest_writer["text_tn"][key[0]] = response_clean - - return results, meta_data +# class LLMASR5(nn.Module): +# """ """ +# +# def __init__( +# self, +# specaug: str = None, +# specaug_conf: dict = None, +# normalize: str = None, +# normalize_conf: dict = None, +# audio_encoder: str = None, +# audio_encoder_conf: dict = None, +# audio_adaptor: str = None, +# audio_adaptor_conf: dict = None, +# decoder: str = None, +# decoder_conf: dict = None, +# ctc: str = None, +# ctc_conf: dict = None, +# ctc_weight: float = 0.5, +# llm: str = None, +# llm_conf: dict = None, +# input_size: int = 80, +# vocab_size: int = -1, +# ignore_id: int = -1, +# blank_id: int = 0, +# lsm_weight: float = 0.0, +# length_normalized_loss: bool = False, +# report_cer: bool = True, +# report_wer: bool = True, +# sym_space: str = "", +# sym_blank: str = "", +# # extract_feats_in_collect_stats: bool = True, +# share_embedding: bool = False, +# # preencoder: Optional[AbsPreEncoder] = None, +# # postencoder: Optional[AbsPostEncoder] = None, +# **kwargs, +# ): +# +# super().__init__() +# +# # audio encoder +# hub = audio_encoder_conf.get("hub", None) +# if hub == "ms": +# from funasr import AutoModel +# +# model = AutoModel(model=audio_encoder, model_revision="master") +# # frontend = model.kwargs.get("frontend") +# audio_encoder_output_size = model.model.encoder_output_size +# +# audio_encoder = ( +# model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder +# ) +# +# # self.frontend = frontend +# +# elif hub == "hf": +# pass +# else: +# encoder_class = tables.encoder_classes.get(audio_encoder) +# audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf) +# audio_encoder_output_size = audio_encoder.output_size() +# freeze = audio_encoder_conf.get("freeze", True) +# freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1)) +# # if freeze_layer_num > 0: +# # freeze_layer_num = range(freeze_layer_num) +# +# if freeze: +# for name, param in audio_encoder.named_parameters(): +# if freeze_layer_num > 0: +# idx = re.search(r"\.\d+\.", name) +# if idx is not None: +# beg, end = idx.regs[0] +# layer_id = int(name[beg + 1 : end - 1]) +# if layer_id < freeze_layer_num: +# param.requires_grad = False +# elif "ln_post." 
not in name: +# param.requires_grad = False +# else: +# param.requires_grad = False +# +# audio_encoder.eval() +# +# self.audio_encoder = audio_encoder +# +# # llm +# self.llm = None +# +# from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +# +# init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5") +# +# model = AutoModelForCausalLM.from_pretrained( +# init_param_path, +# load_in_8bit=None, +# device_map=None, +# use_cache=None, +# ) +# freeze = llm_conf.get("freeze", True) +# if freeze: +# for name, param in model.named_parameters(): +# param.requires_grad = False +# model.eval() +# self.llm_dtype = llm_conf.get("llm_dtype", "fp32") +# self.llm = model.to(dtype_map[self.llm_dtype]) +# llm_dim = model.get_input_embeddings().weight.shape[-1] +# +# # adaptor +# adaptor_class = tables.adaptor_classes.get(audio_adaptor) +# audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size +# audio_adaptor_conf["llm_dim"] = llm_dim +# audio_adaptor = adaptor_class(**audio_adaptor_conf) +# init_param_path = audio_adaptor_conf.get("init_param_path", None) +# if init_param_path is not None: +# src_state = torch.load(init_param_path, map_location="cpu") +# flag = audio_adaptor.load_state_dict(src_state, strict=False) +# logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}") +# +# self.audio_adaptor = audio_adaptor +# +# self.error_calculator = None +# +# self.length_normalized_loss = length_normalized_loss +# self.beam_search = None +# +# self.eos = kwargs.get("eos", 151645) +# +# def forward( +# self, +# speech: torch.Tensor, +# speech_lengths: torch.Tensor, +# input_ids: torch.Tensor, +# attention_mask: torch.Tensor, +# labels_ids: torch.Tensor, +# fbank_beg: torch.Tensor, +# fbank_mask: torch.Tensor, +# **kwargs, +# ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: +# """Encoder + Decoder + Calc loss +# Args: +# speech: (Batch, Length, ...) 
+# speech_lengths: (Batch, ) +# text: (Batch, Length) +# text_lengths: (Batch,) +# """ +# import pdb +# +# pdb.set_trace() +# +# if len(speech_lengths.size()) > 1: +# speech_lengths = speech_lengths[:, 0] +# +# batch_size_speech, frames, _ = speech.shape +# batch_size, token_num = input_ids.shape +# +# with torch.cuda.amp.autocast(enabled=False): +# # audio encoder +# encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) +# +# # audio_adaptor +# encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) +# +# input_ids[input_ids < 0] = 0 +# inputs_embeds = self.llm.model.get_input_embeddings()(input_ids) +# +# batch_size, token_num, dims = inputs_embeds.shape +# fake_token_len = kwargs.get("fake_token_len") +# fake_token_len[fake_token_len < 0] = 0 +# fbank_beg[fbank_beg < 0] = 0 +# +# speech_idx = 0 +# for batch_idx in range(batch_size): +# +# for turn_id in range(fbank_beg.shape[1]): +# fbank_beg_idx = fbank_beg[batch_idx, turn_id].item() +# if fbank_beg_idx > 0: +# speech_token_len = fake_token_len[batch_idx, turn_id] +# speech_token = encoder_out[speech_idx, :speech_token_len, :] +# +# try: +# inputs_embeds[ +# batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : +# ] = speech_token +# except Exception as e: +# # +# logging.error(f"{str(e)}, {traceback.format_exc()}") +# logging.info( +# f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" +# ) +# # import pdb; +# # pdb.set_trace() +# speech_token_len = encoder_out_lens[speech_idx].item() +# speech_token = encoder_out[speech_idx, :speech_token_len, :] +# inputs_embeds[ +# batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : +# ] = speech_token +# +# speech_idx += 1 +# +# with torch.cuda.amp.autocast( +# enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype] +# ): +# labels_ids[labels_ids == -1] = -100 +# attention_mask[attention_mask < 0] = 0 +# model_outputs = self.llm( +# inputs_embeds=inputs_embeds.to(dtype_map[self.llm_dtype]), +# attention_mask=attention_mask, +# labels=labels_ids, +# ) +# loss = model_outputs.loss +# +# codec = kwargs.get("codec") +# codec_len = kwargs.get("codec_len") +# if len(codec_len.size()) > 1: +# codec_len = codec_len[:, 0] +# hidden_states = model_outputs.hidden_states[-1].float() +# +# target_ids = [] +# target_ids_len = [] +# hidden_states_select = [] +# for batch_idx in range(labels_ids.shape[0]): +# beg_i = 0 +# end_i = 0 +# for token_idx in range(labels_ids.shape[1]): +# token_int = labels_ids[batch_idx, token_idx].item() +# if token_int == self.eos: +# target_ids_i = labels_ids[batch_idx, beg_i:end_i] +# target_ids_len_i = end_i - beg_i +# target_ids_len.append(target_ids_len_i) +# target_ids.append(target_ids_i) +# hidden_states_i = hidden_states[batch_idx, beg_i - 1 : end_i - 1, :] +# hidden_states_select.append(hidden_states_i) +# beg_i = end_i +# continue +# +# end_i += 1 +# if token_int <= 0: +# beg_i += 1 +# +# target_ids = torch.nn.utils.rnn.pad_sequence( +# target_ids, batch_first=True, padding_value=-100 +# ) +# hidden_states_select = torch.nn.utils.rnn.pad_sequence( +# hidden_states_select, batch_first=True, padding_value=0.0 +# ) +# +# stats = {} +# with torch.no_grad(): +# preds = torch.argmax(model_outputs.logits, -1) +# acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], 
ignore_label=-100) +# stats["acc"] = acc_att +# +# stats["loss"] = torch.clone(loss.detach()) +# stats["batch_size"] = batch_size +# stats["batch_size_speech"] = batch_size_speech +# stats["batch_size_x_frames"] = frames * batch_size_speech +# stats["batch_size_real_frames"] = speech_lengths.sum().item() +# stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"] +# stats["batch_size_x_tokens"] = token_num * batch_size +# stats["batch_size_real_tokens"] = attention_mask.sum().item() +# stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] +# +# dialog_turns = (fbank_beg > 0).sum(-1) +# dialog_turns_max = torch.max(dialog_turns).int().item() +# dialog_turns_avg = dialog_turns.sum().item() / batch_size +# stats["dialog_turns_max"] = dialog_turns_max +# stats["dialog_turns_avg"] = dialog_turns_avg +# +# # force_gatherable: to-device and to-tensor if scalar for DataParallel +# if self.length_normalized_loss: +# batch_size = int((labels_ids > 0 + 1).sum()) +# loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) +# return loss, stats, weight +# +# def encode(self, speech, speech_lengths): +# # audio encoder +# encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths) +# +# return encoder_out, encoder_out_lens +# +# def data_template(self, data): +# system, user, assistant = [], [], [] +# for i, item in enumerate(data): +# role = item["role"] +# content = item["content"] +# if role == "system": +# system.append(content) +# elif role == "user": +# user.append(content) +# elif role == "assistant": +# assistant.append(content) +# +# system = system * len(user) +# +# contents = { +# "system": system, +# "user": user, +# "assistant": assistant, +# } +# +# return contents +# +# def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs): +# +# system = contents["system"] +# user = contents["user"] +# assistant = contents["assistant"] +# pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)") +# +# input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = ( +# [], +# [], +# [], +# [], +# [], +# [], +# [], +# ) +# input_source_ids = [] +# for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)): +# if i >= kwargs.get("multiturn_num_max", 5): +# break +# if len(input_ids) > kwargs.get("max_token_length", 1500): +# break +# +# if i == 0: +# source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" +# else: +# source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" +# +# splits = pattern.split(source_input) +# source_ids = [] +# fbank_i = [] +# fbank_mask_i = [] +# fake_token_len_i = 0 +# fbank_beg_i = -1 +# fbank_lens_i = [] +# speech, speech_lengths = [], [] +# for k, sub_str in enumerate(splits): +# if not sub_str.startswith("<|startofspeech|>"): +# sub_token = tokenizer.encode(sub_str) +# source_ids += sub_token +# fbank_mask_i += [0] * len(sub_token) +# else: +# sub_str = sub_str.replace("<|startofspeech|>", "").replace( +# "<|endofspeech|>", "" +# ) +# if sub_str.startswith("!"): +# sub_str = sub_str[1:] +# if sub_str.startswith("!"): # !!bytes +# sub_str = eval(sub_str[1:]) +# try: +# time1 = time.perf_counter() +# data_src = load_audio_text_image_video(sub_str, fs=frontend.fs) +# time2 = time.perf_counter() +# meta_data["load_data"] = f"{time2 - time1:0.3f}" +# except Exception as 
e: +# logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}") +# +# speech, speech_lengths = extract_fbank( +# data_src, +# data_type=kwargs.get("data_type", "sound"), +# frontend=frontend, +# is_final=True, +# ) # speech: [b, T, d] +# +# time3 = time.perf_counter() +# meta_data["extract_feat"] = f"{time3 - time2:0.3f}" +# meta_data["batch_data_time"] = ( +# speech_lengths.sum().item() +# * frontend.frame_shift +# * frontend.lfr_n +# / 1000 +# ) +# +# if kwargs.get("permute", True): +# speech = speech.permute(0, 2, 1) +# if speech_lengths > kwargs.get("max_source_length", 5500): +# # logging.info( +# # f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}" +# # ) +# badcase_flag = True +# +# olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2 +# olens = 1 + (olens - 3 + 2 * 1) // 2 +# fake_token_len_i = (olens - 1) // 2 + 1 +# fake_token = [0] * fake_token_len_i +# fbank_beg_i = len(source_ids) +# source_ids += fake_token +# fbank_mask_i += [1] * len(fake_token) +# +# fbank_beg += [fbank_beg_i + len(input_ids)] +# fake_token_len += [fake_token_len_i] +# source_mask = [-100] * len(source_ids) +# target_out = f"{target_out}<|im_end|>" +# target_ids = tokenizer.encode(target_out) +# input_source_ids = input_ids + source_ids +# input_ids += source_ids + target_ids +# labels += source_mask + target_ids +# fbank_mask += fbank_mask_i +# if len(speech) > 0: +# fbank.append(speech[0, :, :]) +# fbank_lens.append(speech_lengths) +# +# input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length] +# attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32) +# labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length] +# +# # fbank = speech[0, :, :] +# # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32) +# fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32) +# fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32) +# fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32) +# source_ids = torch.tensor(input_source_ids, dtype=torch.int64) +# target_ids = torch.tensor(target_ids, dtype=torch.int64) +# +# if len(fbank) > 0: +# speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0) +# speech_lengths = torch.nn.utils.rnn.pad_sequence( +# fbank_lens, batch_first=True, padding_value=-1 +# ) +# else: +# speech = [] +# speech_lengths = [] +# output = { +# "speech": speech, +# "speech_lengths": speech_lengths, +# "fbank_mask": fbank_mask[None, :], +# "fbank_beg": fbank_beg[None,], +# "fake_token_len": fake_token_len[None, :], +# "input_ids": input_ids[None,], +# "attention_mask": attention_mask[None,], +# "labels_ids": labels, +# "source_ids": source_ids[None, :], +# "target_ids": target_ids[None, :], +# } +# +# return output +# +# def inference_prepare( +# self, +# data_in, +# data_lengths=None, +# key: list = None, +# tokenizer=None, +# frontend=None, +# **kwargs, +# ): +# +# meta_data = {} +# prompt = kwargs.get("prompt", None) +# +# if kwargs.get("batch_size", 1) > 1: +# raise NotImplementedError("batch decoding is not implemented") +# +# contents = self.data_template(data_in[0]) +# output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs) +# batch = to_device(output, kwargs["device"]) +# +# # audio encoder +# speech = batch["speech"] +# if len(speech) > 0: +# speech_lengths = batch["speech_lengths"][:, 0] +# # fp16 +# if kwargs.get("fp16", False): +# speech = speech.to(torch.float16) +# elif kwargs.get("bf16", 
False): +# speech = speech.to(torch.bfloat16) +# # audio encoder +# encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) +# +# # audio_adaptor +# encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) +# +# input_ids = batch["input_ids"] +# source_ids = batch["source_ids"] +# fbank_beg = batch["fbank_beg"] +# fake_token_len = batch["fake_token_len"] +# +# if not kwargs.get("tearchforing", False): +# input_ids = source_ids +# +# input_ids[input_ids < 0] = 0 +# inputs_embeds = self.llm.model.get_input_embeddings()(input_ids) +# +# batch_size, token_num, dims = inputs_embeds.shape +# +# fake_token_len[fake_token_len < 0] = 0 +# fbank_beg[fbank_beg < 0] = 0 +# +# speech_idx = 0 +# for batch_idx in range(batch_size): +# +# for turn_id in range(fbank_beg.shape[1]): +# fbank_beg_idx = fbank_beg[batch_idx, turn_id].item() +# if fbank_beg_idx > 0: +# speech_token_len = fake_token_len[batch_idx, turn_id] +# speech_token = encoder_out[speech_idx, :speech_token_len, :] +# +# try: +# inputs_embeds[ +# batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : +# ] = speech_token +# except Exception as e: +# # +# logging.error(f"{str(e)}, {traceback.format_exc()}") +# logging.info( +# f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" +# ) +# # import pdb; +# # pdb.set_trace() +# speech_token_len = encoder_out_lens[speech_idx].item() +# speech_token = encoder_out[speech_idx, :speech_token_len, :] +# inputs_embeds[ +# batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : +# ] = speech_token +# +# speech_idx += 1 +# return inputs_embeds, contents, batch, source_ids, meta_data +# +# def inference( +# self, +# data_in, +# data_lengths=None, +# key: list = None, +# tokenizer=None, +# frontend=None, +# **kwargs, +# ): +# +# inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare( +# data_in, data_lengths, key, tokenizer, frontend, **kwargs +# ) +# +# llm_dtype = kwargs.get("llm_dtype", "fp32") +# if llm_dtype == "fp32": +# llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype +# llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype +# +# with torch.cuda.amp.autocast( +# enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] +# ): +# label = contents["assistant"][-1] +# self.llm = self.llm.to(dtype_map[llm_dtype]) +# inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) +# +# if not kwargs.get("tearchforing", False): +# +# generated_ids = self.llm.generate( +# inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512) +# ) +# # generated_ids = [ +# # output_ids[len(input_id) :] +# # for input_id, output_ids in zip(input_ids, generated_ids) +# # ] +# response = tokenizer.batch_decode( +# generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True) +# )[0] +# +# loss = None +# else: +# +# labels_ids = batch["labels_ids"] +# labels_ids[labels_ids == -1] = -100 +# attention_mask = batch.get("attention_mask", None) +# # attention_mask = attention_mask.to(dtype_map[llm_dtype]) +# model_outputs = self.llm( +# inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids +# ) +# +# preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :] +# response = tokenizer.batch_decode( +# preds, +# add_special_tokens=False, +# 
skip_special_tokens=kwargs.get("skip_special_tokens", True), +# )[0] +# loss = model_outputs.loss.item() +# +# ibest_writer = None +# if kwargs.get("output_dir") is not None: +# if not hasattr(self, "writer"): +# self.writer = DatadirWriter(kwargs.get("output_dir")) +# ibest_writer = self.writer[f"{0 + 1}best_recog"] +# +# results = [] +# response_clean = re.sub("[^\w\s\u3000\u4e00-\u9fff]+", "", response) +# result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label} +# if loss is not None: +# result_i["loss"] = loss +# results.append(result_i) +# +# if ibest_writer is not None: +# ibest_writer["text"][key[0]] = response.replace("\n", " ") +# ibest_writer["label"][key[0]] = label.replace("\n", " ") +# ibest_writer["text_tn"][key[0]] = response_clean +# +# return results, meta_data @tables.register("model_classes", "LLMASR5") From c37cf737b57f781992c2d8328f8fa34a2d1a9a27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 2 Jul 2024 14:48:42 +0800 Subject: [PATCH 05/24] update --- funasr/models/llm_asr/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 72a3d67b5..95939cd38 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2124,6 +2124,7 @@ class LLMASR5(nn.Module): self.ad_sos_eos = 0 self.ad_task_id = 1 self.ad_ignore_id = -1 + self.lm_out_voc_size = self.codebook_size + 1 def build_audio_decoder(self, name, conf): if name == "transformer": From 99c732badbc7a2135560ce23c5ac50a12c574bc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 2 Jul 2024 15:04:39 +0800 Subject: [PATCH 06/24] update --- funasr/models/llm_asr/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 95939cd38..b3d36cb7f 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2115,16 +2115,16 @@ class LLMASR5(nn.Module): self.eos = kwargs.get("eos", 151645) # audio decoder related + self.codebook_dim = audio_decoder_conf.get("codebook_dim", 1024) + self.codebook_size = audio_decoder_conf.get("codebook_size", 4096) + self.lm_out_voc_size = self.codebook_size + 1 self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf) self.audio_decoder_in_proj = torch.nn.Linear(llm_dim, self.audio_decoder.embed_unit) - self.codebook_dim = audio_decoder_conf.pop("codebook_dim", 1024) - self.codebook_size = audio_decoder_conf.pop("codebook_size", 4096) self.codec_embedder = torch.nn.Embedding(self.codebook_size, self.codebook_dim) self.audio_decoder_embedding = torch.nn.Embedding(2, self.audio_decoder.embed_unit) self.ad_sos_eos = 0 self.ad_task_id = 1 self.ad_ignore_id = -1 - self.lm_out_voc_size = self.codebook_size + 1 def build_audio_decoder(self, name, conf): if name == "transformer": From 988db3a2ed80ef129f167353e448614e7bd30282 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Tue, 2 Jul 2024 15:24:58 +0800 Subject: [PATCH 07/24] add TransformerEncoder_s0 --- funasr/models/llm_asr/transformer_encoder.py | 751 +++++++++++++++++++ funasr/models/llm_asr/transformer_lm.py | 2 +- 2 files changed, 752 insertions(+), 1 deletion(-) create mode 100644 funasr/models/llm_asr/transformer_encoder.py diff --git a/funasr/models/llm_asr/transformer_encoder.py b/funasr/models/llm_asr/transformer_encoder.py new file mode 100644 index 000000000..fc74c7306 --- /dev/null +++ 
b/funasr/models/llm_asr/transformer_encoder.py @@ -0,0 +1,751 @@ +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Transformer encoder definition.""" +from typing import List +from typing import Optional +from typing import Tuple +import torch +from torch import nn +import logging +from funasr.models.transformer.attention import ( + MultiHeadedAttention, + RelPositionMultiHeadedAttention, # noqa: H301 + LegacyRelPositionMultiHeadedAttention, # noqa: H301 +) +from funasr.models.transformer.embedding import ( + PositionalEncoding, # noqa: H301 + ScaledPositionalEncoding, # noqa: H301 + RelPositionalEncoding, # noqa: H301 + LegacyRelPositionalEncoding, # noqa: H301 +) +from funasr.models.transformer.layer_norm import LayerNorm +from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear +from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d +from funasr.models.transformer.utils.nets_utils import make_pad_mask +from funasr.models.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from funasr.models.transformer.utils.repeat import repeat +from funasr.models.transformer.utils.nets_utils import rename_state_dict +from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution +from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D +from funasr.models.transformer.utils.lightconv import LightweightConvolution +from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D +from funasr.models.transformer.utils.subsampling import Conv2dSubsampling +from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2 +from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6 +from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8 +from funasr.models.transformer.utils.subsampling import TooShortUttError +from funasr.models.transformer.utils.subsampling import check_short_utt + + +class EncoderLayer(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + stochastic_depth_rate (float): Proability to skip this layer. + During training, the layer may skip residual computation and return input + as-is with given probability. 
+ """ + + def __init__( + self, + size, + self_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + stochastic_depth_rate=0.0, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + self.stochastic_depth_rate = stochastic_depth_rate + + def forward(self, x, mask, cache=None): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + if isinstance(x, tuple): + x, pos_emb = x[0], x[1] + else: + x, pos_emb = x, None + + skip_layer = False + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. + stoch_layer_coeff = 1.0 + if self.training and self.stochastic_depth_rate > 0: + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + + if skip_layer: + if cache is not None: + x = torch.cat([cache, x], dim=1) + if pos_emb is not None: + return (x, pos_emb), mask + return x, mask + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if pos_emb is not None: + x_att = self.self_attn(x_q, x, x, pos_emb, mask) + else: + x_att = self.self_attn(x_q, x, x, mask) + + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + stoch_layer_coeff * self.concat_linear(x_concat) + else: + x = residual + stoch_layer_coeff * self.dropout(x_att) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + if pos_emb is not None: + return (x, pos_emb), mask + + return x, mask + + +class TransformerEncoder(nn.Module): + """Transformer encoder module. + + Args: + input_size: input dim + output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the number of units of position-wise feed forward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + attention_dropout_rate: dropout rate in attention + positional_dropout_rate: dropout rate after adding positional encoding + input_layer: input layer type + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: whether to use layer_norm before the first block + concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. + i.e. 
x -> x + att(x) + positionwise_layer_type: linear of conv1d + positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + padding_idx: padding_idx for input_layer=embed + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: Optional[str] = "conv2d", + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 1, + padding_idx: int = -1, + interctc_layer_idx: List[int] = [], + interctc_use_conditioning: bool = False, + causal_mode: str = "None", + ): + super().__init__() + self._output_size = output_size + self.causal_mode = causal_mode + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) + elif input_layer == "conv2d2": + self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate) + elif input_layer == "conv2d6": + self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate) + elif input_layer == "conv2d8": + self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer is None: + if input_size == output_size: + self.embed = None + else: + self.embed = torch.nn.Linear(input_size, output_size) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + output_size, + MultiHeadedAttention( + attention_heads, output_size, attention_dropout_rate + ), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + self.interctc_layer_idx = interctc_layer_idx + if len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ctc = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in 
tensor. + + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + prev_states: Not to be used now. + Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + if self.causal_mode == "None": + pass + elif self.causal_mode == "causal": + tt = xs_pad.shape[1] + pos_idx = torch.arange(tt) + causal_mask = torch.less_equal(pos_idx.unsqueeze(0), pos_idx.unsqueeze(1)) + causal_mask = causal_mask.unsqueeze(0).to(xs_pad.device) + masks = masks * causal_mask + + if self.embed is None: + xs_pad = xs_pad + elif ( + isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8) + ): + short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f"has {xs_pad.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + xs_pad, masks = self.encoders(xs_pad, masks) + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + xs_pad, masks = encoder_layer(xs_pad, masks) + + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens, None + + +def _pre_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + # https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563 + rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict) + # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563 + rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict) + + +class TransformerEncoder_s0(nn.Module): + """Transformer encoder module. + + Args: + idim (int): Input dimension. + attention_dim (int): Dimension of attention. + attention_heads (int): The number of heads of multi head attention. + conv_wshare (int): The number of kernel of convolution. Only used in + selfattention_layer_type == "lightconv*" or "dynamiconv*". + conv_kernel_length (Union[int, str]): Kernel size str of convolution + (e.g. 71_71_71_71_71_71). Only used in selfattention_layer_type + == "lightconv*" or "dynamiconv*". + conv_usebias (bool): Whether to use bias in convolution. Only used in + selfattention_layer_type == "lightconv*" or "dynamiconv*". + linear_units (int): The number of units of position-wise feed forward. + num_blocks (int): The number of decoder blocks. + dropout_rate (float): Dropout rate. + positional_dropout_rate (float): Dropout rate after adding positional encoding. + attention_dropout_rate (float): Dropout rate in attention. + input_layer (Union[str, torch.nn.Module]): Input layer type. 
+ pos_enc_class (torch.nn.Module): Positional encoding module class. + `PositionalEncoding `or `ScaledPositionalEncoding` + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + selfattention_layer_type (str): Encoder attention layer type. + padding_idx (int): Padding idx for input_layer=embed. + stochastic_depth_rate (float): Maximum probability to skip the encoder layer. + intermediate_layers (Union[List[int], None]): indices of intermediate CTC layer. + indices start from 1. + if not None, intermediate outputs are returned (which changes return type + signature.) + + """ + + def __init__( + self, + idim, + attention_dim=256, + attention_heads=4, + conv_wshare=4, + conv_kernel_length="11", + conv_usebias=False, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + pos_enc_class=PositionalEncoding, + normalize_before=True, + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + selfattention_layer_type="selfattn", + padding_idx=-1, + stochastic_depth_rate=0.0, + intermediate_layers=None, + ctc_softmax=None, + conditioning_layer_dim=None, + zero_triu: bool = False, + ): + """Construct an Encoder object.""" + super(TransformerEncoder_s0, self).__init__() + self._register_load_state_dict_pre_hook(_pre_hook) + + self.conv_subsampling_factor = 1 + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(idim, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate) + self.conv_subsampling_factor = 4 + elif input_layer == "conv2d-scaled-pos-enc": + self.embed = Conv2dSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + self.conv_subsampling_factor = 4 + elif input_layer == "conv2d6": + self.embed = Conv2dSubsampling6(idim, attention_dim, dropout_rate) + self.conv_subsampling_factor = 6 + elif input_layer == "conv2d8": + self.embed = Conv2dSubsampling8(idim, attention_dim, dropout_rate) + self.conv_subsampling_factor = 8 + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(attention_dim, positional_dropout_rate) + ) + elif input_layer == "none": + self.embed = torch.nn.Identity() + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + positionwise_layer, positionwise_layer_args = self.get_positionwise_layer( + positionwise_layer_type, + attention_dim, + linear_units, + dropout_rate, + positionwise_conv_kernel_size, + ) + # if 
selfattention_layer_type in [ + # "selfattn", + # "rel_selfattn", + # "legacy_rel_selfattn", + # ]: + # logging.info("encoder self-attention layer type = self-attention") + # encoder_selfattn_layer = MultiHeadedAttention + # encoder_selfattn_layer_args = [ + # ( + # attention_heads, + # attention_dim, + # attention_dropout_rate, + # ) + # ] * num_blocks + if selfattention_layer_type == "selfattn": + logging.info("encoder self-attention layer type = self-attention") + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = [( + attention_heads, + attention_dim, + attention_dropout_rate, + )] * num_blocks + elif selfattention_layer_type == "legacy_rel_selfattn": + logging.info("encoder self-attention layer type = legacy relative self-attention") + assert pos_enc_class == LegacyRelPositionalEncoding + encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention + encoder_selfattn_layer_args = [( + attention_heads, + attention_dim, + attention_dropout_rate, + )] * num_blocks + logging.warning( + "Using legacy_rel_selfattn and it will be deprecated in the future." + ) + elif selfattention_layer_type == "rel_selfattn": + logging.info("encoder self-attention layer type = relative self-attention") + assert pos_enc_class == RelPositionalEncoding + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = [( + attention_heads, + attention_dim, + attention_dropout_rate, + zero_triu, + )] * num_blocks + elif selfattention_layer_type == "lightconv": + logging.info("encoder self-attention layer type = lightweight convolution") + encoder_selfattn_layer = LightweightConvolution + encoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + False, + conv_usebias, + ) + for lnum in range(num_blocks) + ] + elif selfattention_layer_type == "lightconv2d": + logging.info( + "encoder self-attention layer " + "type = lightweight convolution 2-dimensional" + ) + encoder_selfattn_layer = LightweightConvolution2D + encoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + False, + conv_usebias, + ) + for lnum in range(num_blocks) + ] + elif selfattention_layer_type == "dynamicconv": + logging.info("encoder self-attention layer type = dynamic convolution") + encoder_selfattn_layer = DynamicConvolution + encoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + False, + conv_usebias, + ) + for lnum in range(num_blocks) + ] + elif selfattention_layer_type == "dynamicconv2d": + logging.info( + "encoder self-attention layer type = dynamic convolution 2-dimensional" + ) + encoder_selfattn_layer = DynamicConvolution2D + encoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + False, + conv_usebias, + ) + for lnum in range(num_blocks) + ] + else: + raise NotImplementedError(selfattention_layer_type) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + attention_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args[lnum]), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + stochastic_depth_rate * float(1 + lnum) / num_blocks, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + + self.intermediate_layers = intermediate_layers + 
self.use_conditioning = True if ctc_softmax is not None else False + if self.use_conditioning: + self.ctc_softmax = ctc_softmax + self.conditioning_layer = torch.nn.Linear( + conditioning_layer_dim, attention_dim + ) + + def get_positionwise_layer( + self, + positionwise_layer_type="linear", + attention_dim=256, + linear_units=2048, + dropout_rate=0.1, + positionwise_conv_kernel_size=1, + ): + """Define positionwise layer.""" + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = (attention_dim, linear_units, dropout_rate) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + return positionwise_layer, positionwise_layer_args + + def forward(self, xs, masks): + """Encode input sequence. + + Args: + xs (torch.Tensor): Input tensor (#batch, time, idim). + masks (torch.Tensor): Mask tensor (#batch, time). + + Returns: + torch.Tensor: Output tensor (#batch, time, attention_dim). + torch.Tensor: Mask tensor (#batch, time). + + """ + if isinstance( + self.embed, + (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8), + ): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + + if self.intermediate_layers is None: + xs, masks = self.encoders(xs, masks) + else: + intermediate_outputs = [] + for layer_idx, encoder_layer in enumerate(self.encoders): + xs, masks = encoder_layer(xs, masks) + + if ( + self.intermediate_layers is not None + and layer_idx + 1 in self.intermediate_layers + ): + if isinstance(xs, tuple): + encoder_output = xs[0] + else: + encoder_output = xs + # intermediate branches also require normalization. + if self.normalize_before: + encoder_output = self.after_norm(encoder_output) + intermediate_outputs.append(encoder_output) + + if self.use_conditioning: + intermediate_result = self.ctc_softmax(encoder_output) + xs = xs + self.conditioning_layer(intermediate_result) + + if isinstance(xs, tuple): + xs = xs[0] + if self.normalize_before: + xs = self.after_norm(xs) + + if self.intermediate_layers is not None: + return xs, masks, intermediate_outputs + return xs, masks + + def forward_one_step(self, xs, masks, cache=None): + """Encode input frame. + + Args: + xs (torch.Tensor): Input tensor. + masks (torch.Tensor): Mask tensor. + cache (List[torch.Tensor]): List of cache tensors. + + Returns: + torch.Tensor: Output tensor. + torch.Tensor: Mask tensor. + List[torch.Tensor]: List of new cache tensors. 
+ + """ + if isinstance(self.embed, (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + if cache is None: + cache = [None for _ in range(len(self.encoders))] + new_cache = [] + for c, e in zip(cache, self.encoders): + xs, masks = e(xs, masks, cache=c) + if isinstance(xs, tuple): + new_cache.append(xs[0]) + else: + new_cache.append(xs) + if isinstance(xs, tuple): + xs = xs[0] + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks, new_cache diff --git a/funasr/models/llm_asr/transformer_lm.py b/funasr/models/llm_asr/transformer_lm.py index e05568742..bcb18418b 100644 --- a/funasr/models/llm_asr/transformer_lm.py +++ b/funasr/models/llm_asr/transformer_lm.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn from funasr.models.transformer.embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding, LegacyRelPositionalEncoding -from funasr.models.transformer.encoder import TransformerEncoder as Encoder +from funasr.models.llm_asr.transformer_encoder import TransformerEncoder_s0 as Encoder from funasr.models.transformer.utils.mask import subsequent_mask from funasr.models.transformer.utils.nets_utils import make_pad_mask import logging From 8839f1038cd2342a87fdcdc6f063bbf8093aea2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Tue, 2 Jul 2024 15:26:36 +0800 Subject: [PATCH 08/24] add TransformerEncoder_s0 --- funasr/models/llm_asr/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index b3d36cb7f..15c240772 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2139,7 +2139,6 @@ class LLMASR5(nn.Module): else: raise TypeError(f"Unknown codec decoder type {name}") - conf["name"] = name return lm_model def calc_dense_vector(self, codec, codec_lengths): From 8bde4a2cfbb43dd6e93800556ae53f164adcbc41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 2 Jul 2024 15:30:32 +0800 Subject: [PATCH 09/24] update --- funasr/models/llm_asr/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index b3d36cb7f..5d3002657 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2084,6 +2084,7 @@ class LLMASR5(nn.Module): load_in_8bit=None, device_map=None, use_cache=None, + output_hidden_states=True, ) freeze = llm_conf.get("freeze", True) if freeze: From 38f588479df74542abe26ed8989d8e732fc490ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 2 Jul 2024 16:04:34 +0800 Subject: [PATCH 10/24] update --- funasr/models/llm_asr/label_smoothing_loss.py | 82 +++++++++++++++++++ funasr/models/llm_asr/model.py | 13 ++- 2 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 funasr/models/llm_asr/label_smoothing_loss.py diff --git a/funasr/models/llm_asr/label_smoothing_loss.py b/funasr/models/llm_asr/label_smoothing_loss.py new file mode 100644 index 000000000..25433431c --- /dev/null +++ b/funasr/models/llm_asr/label_smoothing_loss.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Label smoothing module.""" + +import torch +from torch import nn +from cosyvoice.modules.nets_utils import make_pad_mask + + +class LabelSmoothingLoss(nn.Module): + """Label-smoothing loss. 
+ + :param int size: the number of class + :param int padding_idx: ignored class id + :param float smoothing: smoothing rate (0.0 means the conventional CE) + :param bool normalize_length: normalize loss by sequence length if True + :param torch.nn.Module criterion: loss function to be smoothed + """ + + def __init__( + self, + size, + padding_idx, + smoothing, + normalize_length=False, + criterion=nn.KLDivLoss(reduction="none"), + reduction=True, + ): + """Construct an LabelSmoothingLoss object.""" + super(LabelSmoothingLoss, self).__init__() + self.criterion = criterion + self.padding_idx = padding_idx + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.size = size + self.true_dist = None + self.normalize_length = normalize_length + self.reduction = reduction + + def forward(self, x, target): + """Compute loss between x and target. + + :param torch.Tensor x: prediction (batch, seqlen, class) + :param torch.Tensor target: + target signal masked with self.padding_id (batch, seqlen) + :return: scalar float value + :rtype torch.Tensor + """ + assert x.size(2) == self.size + batch_size = x.size(0) + x = x.reshape(-1, self.size) + target = target.reshape(-1) + with torch.no_grad(): + true_dist = x.clone() + true_dist.fill_(self.smoothing / (self.size - 1)) + ignore = target == self.padding_idx # (B,) + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) + if not self.reduction: + return kl + else: + denom = total if self.normalize_length else batch_size + return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom + + +class SequenceBinaryCrossEntropy(nn.Module): + def __init__(self, normalize_length=False, criterion=nn.BCEWithLogitsLoss(reduction="none")): + super().__init__() + self.normalize_length = normalize_length + self.criterion = criterion + + def forward(self, pred, label, lengths): + pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device) + loss = self.criterion(pred, label) + denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0] + return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 99b61f986..4cfb82ee4 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2126,6 +2126,17 @@ class LLMASR5(nn.Module): self.ad_sos_eos = 0 self.ad_task_id = 1 self.ad_ignore_id = -1 + self.predict_nq = 1 + + from .label_smoothing_loss import LabelSmoothingLoss + + self.criterion_ce = LabelSmoothingLoss( + size=self.lm_out_voc_size // self.predict_nq, + padding_idx=self.ad_ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + reduction=False, + ) def build_audio_decoder(self, name, conf): if name == "transformer": @@ -2204,7 +2215,7 @@ class LLMASR5(nn.Module): ) for i, codec_len in enumerate(codec_lengths): llm_targets[i, :codec_len] = codec[i, :codec_len] - llm_targets[i, codec_len] = self.codebook_size + self.sos_eos + llm_targets[i, codec_len] = self.codebook_size + self.ad_sos_eos return (llm_inputs, llm_targets), (llm_lengths, codec_lengths + 1) From 2095a607e4604a481d30caa63439fd76c6fe3ee9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 2 Jul 2024 17:41:09 +0800 Subject: [PATCH 11/24] update --- funasr/models/llm_asr/label_smoothing_loss.py | 2 +- funasr/models/llm_asr/model.py | 21 +++++++++++-- 
funasr/train_utils/load_pretrained_model.py | 31 +++++++++++++++++-- 3 files changed, 49 insertions(+), 5 deletions(-) diff --git a/funasr/models/llm_asr/label_smoothing_loss.py b/funasr/models/llm_asr/label_smoothing_loss.py index 25433431c..78709a045 100644 --- a/funasr/models/llm_asr/label_smoothing_loss.py +++ b/funasr/models/llm_asr/label_smoothing_loss.py @@ -8,7 +8,7 @@ import torch from torch import nn -from cosyvoice.modules.nets_utils import make_pad_mask +from funasr.models.transformer.utils.nets_utils import make_pad_mask class LabelSmoothingLoss(nn.Module): diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 4cfb82ee4..e26195d1b 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2377,6 +2377,7 @@ class LLMASR5(nn.Module): target_ids.append(target_ids_i) hidden_states_i = hidden_states[batch_idx, beg_i - 1 : end_i - 1, :] hidden_states_select.append(hidden_states_i) + end_i += 1 beg_i = end_i continue @@ -2393,16 +2394,32 @@ class LLMASR5(nn.Module): target_ids_len = torch.tensor(target_ids_len, dtype=torch.int32, device=input_ids.device) target_ids = target_ids.to(device=input_ids.device) hidden_states_select = hidden_states_select.to(device=input_ids.device) - loss, logits, target, codec_lengths = self.nll( - hidden_states_select, target_ids_len, codec, codec_len + nll, logits, target, target_lengths = self.nll( + hidden_states_select, target_ids_len, codec[:, :, None], codec_len ) + output_mask = ( + ~make_pad_mask(target_lengths, maxlen=target_lengths.max()) + .to(hidden_states_select.device) + .unsqueeze(-1) + ) + total, batch_size = output_mask.sum() * self.predict_nq, nll.shape[0] * self.predict_nq + denom = total if self.length_normalized_loss else batch_size + loss = (nll * output_mask).sum() / denom + stats = {} with torch.no_grad(): preds = torch.argmax(model_outputs.logits, -1) acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100) stats["acc"] = acc_att + cc = logits.shape[-1] + for i in range(self.predict_nq): + acc = th_accuracy( + logits[:, :, i, :].reshape(-1, cc), target[:, :, i], self.ad_ignore_id + ) + stats[f"codec_acc_{i + 1}"] = acc + stats["loss"] = torch.clone(loss.detach()) stats["batch_size"] = batch_size stats["batch_size_speech"] = batch_size_speech diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py index 8ed613caf..055daff3d 100644 --- a/funasr/train_utils/load_pretrained_model.py +++ b/funasr/train_utils/load_pretrained_model.py @@ -10,8 +10,8 @@ import torch.optim import pdb -def load_pretrained_model( - path: str, +def _load_pretrained_model( + path, model: torch.nn.Module, ignore_init_mismatch: bool = True, map_location: str = "cpu", @@ -100,3 +100,30 @@ def load_pretrained_model( flag = obj.load_state_dict(dst_state, strict=True) logging.info(f"Loading ckpt: {path}, status: {flag}") + + +def load_pretrained_model( + path, + model: torch.nn.Module, + ignore_init_mismatch: bool = True, + map_location: str = "cpu", + oss_bucket=None, + scope_map=[], + excludes=None, + **kwargs, +): + if isinstance(path, str): + path = path.split(",") + + for i, path_i in enumerate(path): + logging.info(f"Loading ckpt-{i}: {path_i}") + _load_pretrained_model( + path_i, + model=model, + ignore_init_mismatch=ignore_init_mismatch, + map_location=map_location, + oss_bucket=oss_bucket, + scope_map=scope_map, + excludes=excludes, + **kwargs, + ) From 1cf30eca3625206c457b158350a46d92ae857235 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 2 Jul 2024 17:56:37 +0800 Subject: [PATCH 12/24] update --- funasr/datasets/openai_datasets/index_ds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funasr/datasets/openai_datasets/index_ds.py b/funasr/datasets/openai_datasets/index_ds.py index 010d2d5d7..eefb7f618 100644 --- a/funasr/datasets/openai_datasets/index_ds.py +++ b/funasr/datasets/openai_datasets/index_ds.py @@ -55,7 +55,7 @@ class OpenAIIndexDSJsonl(torch.utils.data.Dataset): # torch.utils.data.Dataset text_length = data_dict.get("text_length", 0) if speech_length > self.max_source_length: logging.info( - "speech_length: {speech_length} > {self.max_source_length}, drop it" + f"speech_length: {speech_length} > {self.max_source_length}, drop it" ) continue if text_length > self.max_target_length: From c99ecd3687969b478c83df0672d6cce6a30bea89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 2 Jul 2024 23:25:28 +0800 Subject: [PATCH 13/24] update --- funasr/auto/auto_model.py | 27 ++++++----- funasr/train_utils/load_pretrained_model.py | 52 ++++++++++----------- 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index e027e070e..8677c69e6 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -224,18 +224,21 @@ class AutoModel: # init_param init_param = kwargs.get("init_param", None) if init_param is not None: - if os.path.exists(init_param): - logging.info(f"Loading pretrained params from {init_param}") - load_pretrained_model( - model=model, - path=init_param, - ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), - oss_bucket=kwargs.get("oss_bucket", None), - scope_map=kwargs.get("scope_map", []), - excludes=kwargs.get("excludes", None), - ) - else: - print(f"error, init_param does not exist!: {init_param}") + if isinstance(init_param, str): + init_param = [init_param] + for i, init_param_i in enumerate(init_param): + if os.path.exists(init_param_i): + logging.info(f"Loading pretrained params from ckpt-{i}: {init_param_i}") + load_pretrained_model( + model=model, + path=init_param_i, + ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), + oss_bucket=kwargs.get("oss_bucket", None), + scope_map=kwargs.get("scope_map", []), + excludes=kwargs.get("excludes", None), + ) + else: + print(f"error, init_param from ckpt-{i} does not exist!: {init_param_i}") # fp16 if kwargs.get("fp16", False): diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py index 055daff3d..0cfe89630 100644 --- a/funasr/train_utils/load_pretrained_model.py +++ b/funasr/train_utils/load_pretrained_model.py @@ -10,7 +10,7 @@ import torch.optim import pdb -def _load_pretrained_model( +def load_pretrained_model( path, model: torch.nn.Module, ignore_init_mismatch: bool = True, @@ -102,28 +102,28 @@ def _load_pretrained_model( logging.info(f"Loading ckpt: {path}, status: {flag}") -def load_pretrained_model( - path, - model: torch.nn.Module, - ignore_init_mismatch: bool = True, - map_location: str = "cpu", - oss_bucket=None, - scope_map=[], - excludes=None, - **kwargs, -): - if isinstance(path, str): - path = path.split(",") - - for i, path_i in enumerate(path): - logging.info(f"Loading ckpt-{i}: {path_i}") - _load_pretrained_model( - path_i, - model=model, - ignore_init_mismatch=ignore_init_mismatch, - map_location=map_location, - oss_bucket=oss_bucket, - scope_map=scope_map, - excludes=excludes, - **kwargs, - ) +# def 
load_pretrained_model( +# path, +# model: torch.nn.Module, +# ignore_init_mismatch: bool = True, +# map_location: str = "cpu", +# oss_bucket=None, +# scope_map=[], +# excludes=None, +# **kwargs, +# ): +# if isinstance(path, str): +# path = path.split(",") +# +# for i, path_i in enumerate(path): +# logging.info(f"Loading ckpt-{i}: {path_i}") +# _load_pretrained_model( +# path_i, +# model=model, +# ignore_init_mismatch=ignore_init_mismatch, +# map_location=map_location, +# oss_bucket=oss_bucket, +# scope_map=scope_map, +# excludes=excludes, +# **kwargs, +# ) From f9bb49c7ffa0151b0ab92d96d0fcfc572ff47eb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Wed, 3 Jul 2024 11:08:53 +0800 Subject: [PATCH 14/24] update --- funasr/models/llm_asr/model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index e26195d1b..ea1ab9d42 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2357,9 +2357,10 @@ class LLMASR5(nn.Module): loss = model_outputs.loss codec = kwargs.get("codec") - codec_len = kwargs.get("codec_len") - if len(codec_len.size()) > 1: - codec_len = codec_len[:, 0] + # codec_len = kwargs.get("codec_len") + # if len(codec_len.size()) > 1: + # codec_len = codec_len[:, 0] + codec_len = (codec > 0).sum(-1) hidden_states = model_outputs.hidden_states[-1].float() target_ids = [] From 1bf66d044ffb777d17592bb94869b485c90933a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Wed, 3 Jul 2024 17:53:23 +0800 Subject: [PATCH 15/24] update --- funasr/models/llm_asr/model.py | 45 +++++++++++++--------------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index ea1ab9d42..21c96559e 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2715,37 +2715,26 @@ class LLMASR5(nn.Module): self.llm = self.llm.to(dtype_map[llm_dtype]) inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) - if not kwargs.get("tearchforing", False): + generated_ids = self.llm.generate( + inputs_embeds=inputs_embeds, + max_new_tokens=kwargs.get("max_length", 512), + output_hidden_states=True, + return_dict_in_generate=True, + output_scores=True, + ) + hidden_states = generated_ids["hidden_states"] - generated_ids = self.llm.generate( - inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512) - ) - # generated_ids = [ - # output_ids[len(input_id) :] - # for input_id, output_ids in zip(input_ids, generated_ids) - # ] - response = tokenizer.batch_decode( - generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True) - )[0] + # hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584 - loss = None - else: + # generated_ids = [ + # output_ids[len(input_id) :] + # for input_id, output_ids in zip(input_ids, generated_ids) + # ] + # response = tokenizer.batch_decode( + # generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True) + # )[0] - labels_ids = batch["labels_ids"] - labels_ids[labels_ids == -1] = -100 - attention_mask = batch.get("attention_mask", None) - # attention_mask = attention_mask.to(dtype_map[llm_dtype]) - model_outputs = self.llm( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids - ) - - preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :] - response = tokenizer.batch_decode( - preds, - add_special_tokens=False, - 
skip_special_tokens=kwargs.get("skip_special_tokens", True), - )[0] - loss = model_outputs.loss.item() + loss = None ibest_writer = None if kwargs.get("output_dir") is not None: From e8fe5711a26eef610fe1057c0ef718153cf1ac71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 4 Jul 2024 09:43:29 +0800 Subject: [PATCH 16/24] update --- funasr/models/llm_asr/model.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 21c96559e..937ad07fc 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2722,9 +2722,17 @@ class LLMASR5(nn.Module): return_dict_in_generate=True, output_scores=True, ) - hidden_states = generated_ids["hidden_states"] + hidden_states = generated_ids[ + "hidden_states" + ] # hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584 - # hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584 + token_num = len(hidden_states) + hidden_states_out = torch.zeros((1, token_num, 3584), dtype=torch.float32).to( + inputs_embeds.device + ) + + for i in range(token_num): + hidden_states_out[0, i, :] = hidden_states[1, -1][0, 0, :].to(torch.float32) # generated_ids = [ # output_ids[len(input_id) :] From 2ab9f44113649b451c7f2ac27203090c7c4d669e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 4 Jul 2024 09:48:20 +0800 Subject: [PATCH 17/24] update --- funasr/models/llm_asr/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 937ad07fc..05f52ef93 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2732,7 +2732,7 @@ class LLMASR5(nn.Module): ) for i in range(token_num): - hidden_states_out[0, i, :] = hidden_states[1, -1][0, 0, :].to(torch.float32) + hidden_states_out[0, i, :] = hidden_states[1][-1][0, 0, :].to(torch.float32) # generated_ids = [ # output_ids[len(input_id) :] From 05acd675ec507b48b96ea560242e563636692ece Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Thu, 4 Jul 2024 10:37:36 +0800 Subject: [PATCH 18/24] add audio decoding --- funasr/models/llm_asr/model.py | 110 +++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 05f52ef93..6358a8ddc 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.cuda.amp import autocast +import numpy as np import re from funasr.models.scama.utils import sequence_mask from funasr.losses.label_smoothing_loss import LabelSmoothingLoss @@ -2734,6 +2735,8 @@ class LLMASR5(nn.Module): for i in range(token_num): hidden_states_out[0, i, :] = hidden_states[1][-1][0, 0, :].to(torch.float32) + speech_tokens = audio_decode(hidden_states) + # generated_ids = [ # output_ids[len(input_id) :] # for input_id, output_ids in zip(input_ids, generated_ids) @@ -2763,3 +2766,110 @@ class LLMASR5(nn.Module): ibest_writer["text_tn"][key[0]] = response_clean return results, meta_data + + def audio_decode( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + min_length=None, + max_length: int = 30 * 25, + infer_cfg_ratio=None, + decoding_length=None, + ): + # 1. 
encode text + text = self.audio_decoder_in_proj(text) + device = text.device + out_tokens = [] + sos_eos_emb = self.audio_decoder_embedding(torch.tensor([[self.ad_sos_eos]], dtype=torch.int64, device=device)) + task_id_emb = self.audio_decoder_embedding(torch.tensor([[self.ad_task_id]], dtype=torch.int64, device=device)) + prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1) + state, cfg_state = None, None + for i in range(max_length): + if len(out_tokens) > 0: + codec_prompt = torch.tensor([out_tokens], dtype=torch.int64, device=device) + codec_lengths = torch.tensor([len(out_tokens)], dtype=torch.int64, device=device) + # if any quantizer output is eos + if torch.any(codec_prompt[:, -1] == (self.codebook_size+self.sos_eos)): + break + seq_input, _ = self.prepare_audio_decoder_io( + text, text_lengths, + codec_prompt, codec_lengths, + need_targets=False + ) + else: + seq_input, _ = self.prepare_audio_decoder_io( + text, text_lengths, None, None, + need_targets=False + ) + + # use state for speedup + pred, (state, _) = self.audio_decoder.score( + seq_input[0], + state, + prompt[0] + ) + if infer_cfg_ratio is not None: + cond_len = prompt[0].shape[0] + cfg_pred, (cfg_state, _) = self.audio_decoder.score( + seq_input[0][cond_len-1:], + cfg_state, + prompt[0][cond_len-1:] + ) + pred = (1 + infer_cfg_ratio) * pred - infer_cfg_ratio * cfg_pred + + # sampling all `nq` token ids + pred = pred.reshape(self.predict_nq, -1) + # normalize scores + pred = torch.log_softmax(pred, dim=-1) + if min_length is not None and i < min_length: + pred[:, self.codebook_size + self.ad_sos_eos] = float(np.finfo(np.float32).min) + top_ids = [] + for k in range(self.predict_nq): + top_ids.append(self.ras_sampling(pred[k], out_tokens)[0].item()) + out_tokens.append(top_ids) + + # remove eos token + hit_eos = False + if torch.any(torch.tensor(out_tokens[-1], dtype=torch.int64) == self.codebook_size+self.ad_sos_eos): + hit_eos = True + out_tokens = out_tokens[:-1] + + if decoding_length is None: + return torch.tensor([out_tokens], dtype=torch.int64, device=device) + else: + return torch.tensor([out_tokens], dtype=torch.int64, device=device), hit_eos + + # Repetition Aware Sampling in VALL-E 2 + def ras_sampling( + self, + weighted_scores, decoded_tokens, *, + top_p=0.8, top_k=25, win_size=10, tau_r=0.1 + ): + top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) + rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(top_ids) == top_ids).sum().item() + if rep_num >= win_size * tau_r: + top_ids = self.random_sampling(weighted_scores) + + return top_ids + + def nucleus_sampling(self, weighted_scores, top_p=0.8, top_k=25): + prob, indices = [], [] + cum_prob = 0.0 + sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) + for i in range(len(sorted_idx)): + # sampling both top-p and numbers. 
+ if cum_prob < top_p and len(prob) < top_k: + cum_prob += sorted_value[i] + prob.append(sorted_value[i]) + indices.append(sorted_idx[i]) + else: + break + prob = torch.tensor(prob).to(weighted_scores) + indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device) + sampling_ids = prob.multinomial(1, replacement=True) + top_ids = indices[sampling_ids] + return top_ids + + def random_sampling(self, weighted_scores): + top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True) + return top_ids From 63800cb852b063ea2001bbd716b8c0a15df3b3a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 4 Jul 2024 11:07:27 +0800 Subject: [PATCH 19/24] update --- funasr/models/llm_asr/model.py | 81 ++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 6358a8ddc..89ada01ab 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2731,19 +2731,27 @@ class LLMASR5(nn.Module): hidden_states_out = torch.zeros((1, token_num, 3584), dtype=torch.float32).to( inputs_embeds.device ) - + hidden_states_out_len = torch.tensor( + [ + token_num, + ], + dtype=torch.int32, + ).to(inputs_embeds.device) for i in range(token_num): hidden_states_out[0, i, :] = hidden_states[1][-1][0, 0, :].to(torch.float32) - speech_tokens = audio_decode(hidden_states) + speech_tokens = self.audio_decode( + hidden_states_out, hidden_states_out_len + ) # 1xl: 2,10,1023 + sequences = generated_ids["sequences"] # generated_ids = [ # output_ids[len(input_id) :] # for input_id, output_ids in zip(input_ids, generated_ids) # ] - # response = tokenizer.batch_decode( - # generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True) - # )[0] + response = tokenizer.batch_decode( + sequences, skip_special_tokens=kwargs.get("skip_special_tokens", True) + )[0] loss = None @@ -2755,33 +2763,49 @@ class LLMASR5(nn.Module): results = [] response_clean = re.sub("[^\w\s\u3000\u4e00-\u9fff]+", "", response) - result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label} + result_i = { + "key": key[0], + "text": response, + "text_tn": response_clean, + "label": label, + "speech_tokens": speech_tokens, + } if loss is not None: result_i["loss"] = loss results.append(result_i) + speech_tokens_out = "<|startofspeech|>" + for i in range(speech_tokens.shape[-1]): + tmp = speech_tokens[0, i].item() + speech_tokens_out += f"<|c{tmp}|>" + speech_tokens_out += "<|endofspeech|><|im_end|>" if ibest_writer is not None: ibest_writer["text"][key[0]] = response.replace("\n", " ") ibest_writer["label"][key[0]] = label.replace("\n", " ") ibest_writer["text_tn"][key[0]] = response_clean + ibest_writer["speech_tokens"][key[0]] = speech_tokens_out return results, meta_data def audio_decode( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - min_length=None, - max_length: int = 30 * 25, - infer_cfg_ratio=None, - decoding_length=None, + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + min_length=None, + max_length: int = 30 * 25, + infer_cfg_ratio=None, + decoding_length=None, ): # 1. 
encode text text = self.audio_decoder_in_proj(text) device = text.device out_tokens = [] - sos_eos_emb = self.audio_decoder_embedding(torch.tensor([[self.ad_sos_eos]], dtype=torch.int64, device=device)) - task_id_emb = self.audio_decoder_embedding(torch.tensor([[self.ad_task_id]], dtype=torch.int64, device=device)) + sos_eos_emb = self.audio_decoder_embedding( + torch.tensor([[self.ad_sos_eos]], dtype=torch.int64, device=device) + ) + task_id_emb = self.audio_decoder_embedding( + torch.tensor([[self.ad_task_id]], dtype=torch.int64, device=device) + ) prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1) state, cfg_state = None, None for i in range(max_length): @@ -2789,31 +2813,22 @@ class LLMASR5(nn.Module): codec_prompt = torch.tensor([out_tokens], dtype=torch.int64, device=device) codec_lengths = torch.tensor([len(out_tokens)], dtype=torch.int64, device=device) # if any quantizer output is eos - if torch.any(codec_prompt[:, -1] == (self.codebook_size+self.sos_eos)): + if torch.any(codec_prompt[:, -1] == (self.codebook_size + self.ad_sos_eos)): break seq_input, _ = self.prepare_audio_decoder_io( - text, text_lengths, - codec_prompt, codec_lengths, - need_targets=False + text, text_lengths, codec_prompt, codec_lengths, need_targets=False ) else: seq_input, _ = self.prepare_audio_decoder_io( - text, text_lengths, None, None, - need_targets=False + text, text_lengths, None, None, need_targets=False ) # use state for speedup - pred, (state, _) = self.audio_decoder.score( - seq_input[0], - state, - prompt[0] - ) + pred, (state, _) = self.audio_decoder.score(seq_input[0], state, prompt[0]) if infer_cfg_ratio is not None: cond_len = prompt[0].shape[0] cfg_pred, (cfg_state, _) = self.audio_decoder.score( - seq_input[0][cond_len-1:], - cfg_state, - prompt[0][cond_len-1:] + seq_input[0][cond_len - 1 :], cfg_state, prompt[0][cond_len - 1 :] ) pred = (1 + infer_cfg_ratio) * pred - infer_cfg_ratio * cfg_pred @@ -2830,7 +2845,9 @@ class LLMASR5(nn.Module): # remove eos token hit_eos = False - if torch.any(torch.tensor(out_tokens[-1], dtype=torch.int64) == self.codebook_size+self.ad_sos_eos): + if torch.any( + torch.tensor(out_tokens[-1], dtype=torch.int64) == self.codebook_size + self.ad_sos_eos + ): hit_eos = True out_tokens = out_tokens[:-1] @@ -2841,9 +2858,7 @@ class LLMASR5(nn.Module): # Repetition Aware Sampling in VALL-E 2 def ras_sampling( - self, - weighted_scores, decoded_tokens, *, - top_p=0.8, top_k=25, win_size=10, tau_r=0.1 + self, weighted_scores, decoded_tokens, *, top_p=0.8, top_k=25, win_size=10, tau_r=0.1 ): top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(top_ids) == top_ids).sum().item() From 256defef106f3bbb71d24027c6bc2316fa136162 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 4 Jul 2024 13:04:45 +0800 Subject: [PATCH 20/24] update --- funasr/auto/auto_model.py | 2 +- funasr/download/download_model_from_hub.py | 17 +++++++++++++---- funasr/models/llm_asr/model.py | 18 +++++++++++++----- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 8677c69e6..2f713929c 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -225,7 +225,7 @@ class AutoModel: init_param = kwargs.get("init_param", None) if init_param is not None: if isinstance(init_param, str): - init_param = [init_param] + init_param = init_param.split(",") for i, init_param_i in enumerate(init_param): if 
os.path.exists(init_param_i): logging.info(f"Loading pretrained params from ckpt-{i}: {init_param_i}") diff --git a/funasr/download/download_model_from_hub.py b/funasr/download/download_model_from_hub.py index df4f33daf..892752445 100644 --- a/funasr/download/download_model_from_hub.py +++ b/funasr/download/download_model_from_hub.py @@ -59,10 +59,19 @@ def download_from_ms(**kwargs): elif os.path.exists(os.path.join(model_or_path, "config.yaml")): config = OmegaConf.load(os.path.join(model_or_path, "config.yaml")) kwargs = OmegaConf.merge(config, kwargs) - init_param = os.path.join(model_or_path, "model.pt") - if "init_param" not in kwargs or not os.path.exists(kwargs["init_param"]): - kwargs["init_param"] = init_param - assert os.path.exists(kwargs["init_param"]), "init_param does not exist" + + init_param = kwargs.get("init_param", "") + if not os.path.exists(init_param): + init_param_new = init_param + if isinstance(init_param, str): + init_param = init_param.split(",") + for init_param_i in init_param: + if not os.path.exists(init_param_i): + print(f"init_param: {init_param_i}, does not exist") + init_param_i = os.path.join(model_or_path, "model.pt") + init_param_new = f"{init_param_new},{init_param_i}" + kwargs["init_param"] = init_param_new + # assert os.path.exists(kwargs["init_param"]), "init_param does not exist" if os.path.exists(os.path.join(model_or_path, "tokens.txt")): kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt") if os.path.exists(os.path.join(model_or_path, "tokens.json")): diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 89ada01ab..88eb8c001 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2564,8 +2564,16 @@ class LLMASR5(nn.Module): fbank_beg += [fbank_beg_i + len(input_ids)] fake_token_len += [fake_token_len_i] source_mask = [-100] * len(source_ids) - target_out = f"{target_out}<|im_end|>" - target_ids = tokenizer.encode(target_out) + splits = pattern.split(target_out) + for k, sub_str in enumerate(splits): + if len(sub_str) < 1: + continue + if not sub_str.startswith("<|startofspeech|>"): + sub_str = f"{sub_str}<|im_end|>" + sub_token = tokenizer.encode(sub_str) + target_ids = sub_token + # target_out = f"{target_out}<|im_end|>" + # target_ids = tokenizer.encode(target_out) input_source_ids = input_ids + source_ids input_ids += source_ids + target_ids labels += source_mask + target_ids @@ -2740,9 +2748,9 @@ class LLMASR5(nn.Module): for i in range(token_num): hidden_states_out[0, i, :] = hidden_states[1][-1][0, 0, :].to(torch.float32) - speech_tokens = self.audio_decode( - hidden_states_out, hidden_states_out_len - ) # 1xl: 2,10,1023 + speech_tokens = self.audio_decode(hidden_states_out, hidden_states_out_len)[ + :, :, 0 + ] # 1xlx1: 2,10,1023 sequences = generated_ids["sequences"] # generated_ids = [ From e969be589e6270d69906fca252609aec8530321c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 4 Jul 2024 23:30:30 +0800 Subject: [PATCH 21/24] update --- funasr/models/llm_asr/model.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 88eb8c001..a6a05ca0b 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2117,11 +2117,15 @@ class LLMASR5(nn.Module): self.eos = kwargs.get("eos", 151645) # audio decoder related + self.concat_emb_hidden = audio_decoder_conf.get("concat_emb_hidden", False) self.codebook_dim = 
audio_decoder_conf.get("codebook_dim", 1024) self.codebook_size = audio_decoder_conf.get("codebook_size", 4096) self.lm_out_voc_size = self.codebook_size + 1 self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf) - self.audio_decoder_in_proj = torch.nn.Linear(llm_dim, self.audio_decoder.embed_unit) + audio_decoder_in_proj_dim = llm_dim * 2 if self.concat_emb_hidden else llm_dim + self.audio_decoder_in_proj = torch.nn.Linear( + audio_decoder_in_proj_dim, self.audio_decoder.embed_unit + ) self.codec_embedder = torch.nn.Embedding(self.codebook_size, self.codebook_dim) self.audio_decoder_embedding = torch.nn.Embedding(2, self.audio_decoder.embed_unit) self.ad_sos_eos = 0 @@ -2395,7 +2399,11 @@ class LLMASR5(nn.Module): ) target_ids_len = torch.tensor(target_ids_len, dtype=torch.int32, device=input_ids.device) target_ids = target_ids.to(device=input_ids.device) + target_ids[target_ids < 0] = 0 + target_emb = self.llm.model.get_input_embeddings()(target_ids) hidden_states_select = hidden_states_select.to(device=input_ids.device) + if self.concat_emb_hidden: + hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1) nll, logits, target, target_lengths = self.nll( hidden_states_select, target_ids_len, codec[:, :, None], codec_len ) From 496ca8eddb1da8f928d3a58932a7fb3c820e2313 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 4 Jul 2024 23:33:51 +0800 Subject: [PATCH 22/24] update --- funasr/models/llm_asr/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index a6a05ca0b..8779a6a86 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2754,7 +2754,7 @@ class LLMASR5(nn.Module): dtype=torch.int32, ).to(inputs_embeds.device) for i in range(token_num): - hidden_states_out[0, i, :] = hidden_states[1][-1][0, 0, :].to(torch.float32) + hidden_states_out[0, i, :] = hidden_states[i][-1][0, 0, :].to(torch.float32) speech_tokens = self.audio_decode(hidden_states_out, hidden_states_out_len)[ :, :, 0 From 8f6d2787f02027bb70c9d034e63dc08221ddf079 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Fri, 5 Jul 2024 20:53:32 +0800 Subject: [PATCH 23/24] update --- funasr/models/llm_asr/model.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 8779a6a86..b899f0e42 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2744,7 +2744,7 @@ class LLMASR5(nn.Module): ] # hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584 token_num = len(hidden_states) - hidden_states_out = torch.zeros((1, token_num, 3584), dtype=torch.float32).to( + hidden_states_select = torch.zeros((1, token_num, 3584), dtype=torch.float32).to( inputs_embeds.device ) hidden_states_out_len = torch.tensor( @@ -2754,19 +2754,23 @@ class LLMASR5(nn.Module): dtype=torch.int32, ).to(inputs_embeds.device) for i in range(token_num): - hidden_states_out[0, i, :] = hidden_states[i][-1][0, 0, :].to(torch.float32) + hidden_states_select[0, i, :] = hidden_states[i][-1][0, 0, :].to(torch.float32) - speech_tokens = self.audio_decode(hidden_states_out, hidden_states_out_len)[ + target_ids = generated_ids["sequences"] + target_emb = self.llm.model.get_input_embeddings()(target_ids) + if self.concat_emb_hidden: + hidden_states_select = torch.concat((hidden_states_select, target_emb), 
dim=-1) + + speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[ :, :, 0 ] # 1xlx1: 2,10,1023 - sequences = generated_ids["sequences"] # generated_ids = [ # output_ids[len(input_id) :] # for input_id, output_ids in zip(input_ids, generated_ids) # ] response = tokenizer.batch_decode( - sequences, skip_special_tokens=kwargs.get("skip_special_tokens", True) + target_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True) )[0] loss = None From ef5ea9b05f4e034742dd75cd44be2936c7375f3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Mon, 8 Jul 2024 14:32:58 +0800 Subject: [PATCH 24/24] update --- funasr/datasets/openai_datasets/datasets.py | 14 ++- funasr/models/llm_asr/model.py | 96 ++++++++++----------- 2 files changed, 58 insertions(+), 52 deletions(-) diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py index 33dbe31bb..ee7685ee7 100644 --- a/funasr/datasets/openai_datasets/datasets.py +++ b/funasr/datasets/openai_datasets/datasets.py @@ -610,6 +610,8 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset): fake_token_len_i = 0 fbank_beg_i = -1 fbank_lens_i = [] + speech = [] + speech_lengths = [] for k, sub_str in enumerate(splits): if not sub_str.startswith("<|startofspeech|>"): sub_token = self.tokenizer.encode(sub_str) @@ -688,9 +690,11 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset): input_ids += source_ids + target_ids labels += source_mask + target_ids - fbank.append(speech[0, :, :]) + fbank_mask += fbank_mask_i - fbank_lens.append(speech_lengths) + if len(speech) > 0: + fbank.append(speech[0, :, :]) + fbank_lens.append(speech_lengths) if badcase_flag: continue @@ -706,8 +710,6 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset): fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32) output = { - "speech": fbank, - "speech_lengths": fbank_lens, "fbank_mask": fbank_mask, "fbank_beg": fbank_beg, "fake_token_len": fake_token_len, @@ -719,6 +721,10 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset): codec_len = torch.tensor(codec_len, dtype=torch.int32) output["codec"] = codec output["codec_len"] = codec_len + if len(fbank) > 0: + output["speech"] = fbank + output["speech_lengths"] = fbank_lens + break return output diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index b899f0e42..f8fedf25c 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -974,13 +974,13 @@ class LLMASR4(nn.Module): def forward( self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - labels_ids: torch.Tensor, - fbank_beg: torch.Tensor, - fbank_mask: torch.Tensor, + speech: torch.Tensor = None, + speech_lengths: torch.Tensor = None, + input_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + labels_ids: torch.Tensor = None, + fbank_beg: torch.Tensor = None, + fbank_mask: torch.Tensor = None, **kwargs, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Encoder + Decoder + Calc loss @@ -993,55 +993,55 @@ class LLMASR4(nn.Module): # import pdb # # pdb.set_trace() - if len(speech_lengths.size()) > 1: - speech_lengths = speech_lengths[:, 0] - - batch_size_speech, frames, _ = speech.shape - batch_size, token_num = input_ids.shape - - with torch.cuda.amp.autocast(enabled=False): - # audio encoder - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - - # audio_adaptor - encoder_out, 
encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) - input_ids[input_ids < 0] = 0 inputs_embeds = self.llm.model.get_input_embeddings()(input_ids) + if speech is not None: + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] - batch_size, token_num, dims = inputs_embeds.shape - fake_token_len = kwargs.get("fake_token_len") - fake_token_len[fake_token_len < 0] = 0 - fbank_beg[fbank_beg < 0] = 0 + batch_size_speech, frames, _ = speech.shape + batch_size, token_num = input_ids.shape - speech_idx = 0 - for batch_idx in range(batch_size): + with torch.cuda.amp.autocast(enabled=False): + # audio encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - for turn_id in range(fbank_beg.shape[1]): - fbank_beg_idx = fbank_beg[batch_idx, turn_id].item() - if fbank_beg_idx > 0: - speech_token_len = fake_token_len[batch_idx, turn_id] - speech_token = encoder_out[speech_idx, :speech_token_len, :] + # audio_adaptor + encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) - try: - inputs_embeds[ - batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : - ] = speech_token - except Exception as e: - # - logging.error(f"{str(e)}, {traceback.format_exc()}") - logging.info( - f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" - ) - # import pdb; - # pdb.set_trace() - speech_token_len = encoder_out_lens[speech_idx].item() + batch_size, token_num, dims = inputs_embeds.shape + fake_token_len = kwargs.get("fake_token_len") + fake_token_len[fake_token_len < 0] = 0 + fbank_beg[fbank_beg < 0] = 0 + + speech_idx = 0 + for batch_idx in range(batch_size): + + for turn_id in range(fbank_beg.shape[1]): + fbank_beg_idx = fbank_beg[batch_idx, turn_id].item() + if fbank_beg_idx > 0: + speech_token_len = fake_token_len[batch_idx, turn_id] speech_token = encoder_out[speech_idx, :speech_token_len, :] - inputs_embeds[ - batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : - ] = speech_token - speech_idx += 1 + try: + inputs_embeds[ + batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : + ] = speech_token + except Exception as e: + # + logging.error(f"{str(e)}, {traceback.format_exc()}") + logging.info( + f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" + ) + # import pdb; + # pdb.set_trace() + speech_token_len = encoder_out_lens[speech_idx].item() + speech_token = encoder_out[speech_idx, :speech_token_len, :] + inputs_embeds[ + batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : + ] = speech_token + + speech_idx += 1 with torch.cuda.amp.autocast( enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
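
For reference, PATCH 11, 13 and 20 together let `init_param` carry a comma-separated list of checkpoints, each loaded in turn via `load_pretrained_model()`. A minimal usage sketch is below; the model directory and checkpoint file names are placeholders for illustration only, not paths from this series, and the call assumes a local directory containing a `config.yaml` as handled in `download_model_from_hub.py`.

    from funasr import AutoModel

    model = AutoModel(
        model="local/llm_asr_model_dir",                # hypothetical local model directory with config.yaml
        init_param="audio_encoder.pt,audio_adaptor.pt", # hypothetical checkpoints, loaded one after another
        ignore_init_mismatch=True,                      # skip parameters whose shapes do not match
    )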
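The sampling utilities introduced in PATCH 18 combine nucleus (top-p/top-k) sampling with the Repetition Aware Sampling fallback from VALL-E 2: when the nucleus pick already dominates the recent decoding window, the model resamples from the full distribution. The sketch below is a simplified standalone re-implementation of that logic under assumed shapes (a 1-D score vector over `codebook_size + 1` entries), not code taken verbatim from the patches above.

    import torch

    def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
        # keep the smallest prefix of highest-probability tokens whose mass reaches top_p,
        # capped at top_k entries, then sample one token id from that prefix
        probs = weighted_scores.softmax(dim=0)
        sorted_probs, sorted_idx = probs.sort(descending=True, stable=True)
        keep = min(top_k, int((sorted_probs.cumsum(0) < top_p).sum().item()) + 1)
        sampled = sorted_probs[:keep].multinomial(1, replacement=True)
        return sorted_idx[sampled]

    def ras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
        # Repetition Aware Sampling (VALL-E 2): if the nucleus pick already repeats too often
        # within the last `win_size` decoded tokens, fall back to unrestricted multinomial sampling
        top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
        recent = torch.tensor(decoded_tokens[-win_size:], device=top_ids.device)
        if (recent == top_ids).sum().item() >= win_size * tau_r:
            top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
        return top_ids

    # toy usage: scores over a 4096-entry codebook plus one sos/eos entry
    scores = torch.randn(4097)
    print(ras_sampling(scores, decoded_tokens=[5, 5, 5, 5, 5]))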