From 5e6fd09a4968ec8b74ac5d9db6ed594469b0e379 Mon Sep 17 00:00:00 2001 From: dcaaaa Date: Wed, 21 Aug 2024 18:00:08 +0800 Subject: [PATCH] add llm semantic vad model code --- funasr/datasets/openai_datasets/datasets.py | 275 ++++++ funasr/datasets/openai_datasets/index_ds.py | 114 +++ funasr/models/llm_asr/model.py | 912 ++++++++++++++++++++ 3 files changed, 1301 insertions(+) diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py index 78612ae70..f2c59ce93 100644 --- a/funasr/datasets/openai_datasets/datasets.py +++ b/funasr/datasets/openai_datasets/datasets.py @@ -770,3 +770,278 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset): break return outputs + +@tables.register("dataset_classes", "OpenAIDatasetMultiTurnForFullDuplexVAD") +class OpenAIDatasetMultiTurnForFullDuplexVAD(torch.utils.data.Dataset): + """ + SenseVoiceDataset + """ + + def __init__( + self, + path, + index_ds: str = None, + frontend=None, + tokenizer=None, + int_pad_value: int = -1, + float_pad_value: float = 0.0, + **kwargs, + ): + super().__init__() + index_ds_class = tables.index_ds_classes.get(index_ds) + self.index_ds = index_ds_class(path, **kwargs) + preprocessor_speech = kwargs.get("preprocessor_speech", None) + if preprocessor_speech: + preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech) + preprocessor_speech = preprocessor_speech_class( + **kwargs.get("preprocessor_speech_conf") + ) + self.preprocessor_speech = preprocessor_speech + preprocessor_text = kwargs.get("preprocessor_text", None) + if preprocessor_text: + preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text) + preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf")) + self.preprocessor_text = preprocessor_text + + self.frontend = frontend + self.fs = 16000 if frontend is None else frontend.fs + self.data_type = "sound" + self.tokenizer = tokenizer + + self.int_pad_value = int_pad_value + self.float_pad_value = float_pad_value + self.sos = kwargs.get("sos", "<|startoftranscript|>") + self.eos = kwargs.get("eos", "<|endoftext|>") + self.batch_size = kwargs.get("batch_size") + self.batch_type = kwargs.get("batch_type") + self.prompt_ids_len = 0 + self.retry = kwargs.get("retry", 100) + + self.permute = False + from funasr.frontends.whisper_frontend import WhisperFrontend + + if isinstance(self.frontend, WhisperFrontend): + self.permute = True + + self.pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)") + # self.kwargs = kwargs + self.max_token_length = kwargs.get("max_token_length", 1500) + self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5) + self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500) + self.multiturn_num_max = kwargs.get("multiturn_num_max", 5) + self.max_source_length = kwargs.get("max_source_length", 3000) + + def get_source_len(self, index): + item = self.index_ds[index] + return self.index_ds.get_source_len(item) + + def get_target_len(self, index): + item = self.index_ds[index] + return self.index_ds.get_target_len(item) + + def __len__(self): + return len(self.index_ds) + + def __getitem__(self, index): + # import pdb + # + # pdb.set_trace() + + output = None + + for idx in range(self.retry): + badcase_flag = False + if idx == 0: + index_cur = index + else: + index_cur = torch.randint(0, len(self.index_ds), ()).item() + + item = self.index_ds[index_cur] + + system = item["system"] + user = item["user"] + assistant = item["assistant"] + task 
= item["task"] + true_time_span = item["true_time_span"] + last_time_span = item["last_total_time"] + + input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = ( + [], + [], + [], + [], + [], + [], + [], + ) + + for i, (system_prompt, user_prompt, target_out) in enumerate( + zip(system, user, assistant) + ): + if len(input_ids) > self.max_token_length: + logging.info( + f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}" + ) + break + + if i == 0: + source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + elif i == len(system)-1: + source_input = ( + f"<|im_start|>user\n{user_prompt}" + ) + else: + source_input = ( + f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + ) + # self.pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)") + splits = self.pattern.split(source_input) + source_ids = [] + fbank_i = [] + fbank_mask_i = [] + fake_token_len_i = 0 + fbank_beg_i = -1 + fbank_lens_i = [] + for k, sub_str in enumerate(splits): + if not sub_str.startswith("<|startofspeech|>"): + sub_token = self.tokenizer.encode(sub_str) + source_ids += sub_token + fbank_mask_i += [0] * len(sub_token) + else: + sub_str = sub_str.replace("<|startofspeech|>", "").replace( + "<|endofspeech|>", "" + ) + if sub_str.startswith("!"): + try: + data_src = load_audio_text_image_video(sub_str[1:], fs=self.fs) + except Exception as e: + logging.error( + f"Loading wav failed! {str(e)}, {traceback.format_exc()}" + ) + badcase_flag = True + continue + speech, speech_lengths = extract_fbank( + data_src, + data_type=self.data_type, + frontend=self.frontend, + is_final=True, + ) # speech: [b, T, d] + if speech_lengths > self.max_source_length: + logging.info( + f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}" + ) + badcase_flag = True + if self.permute: + speech = speech.permute(0, 2, 1) + # if speech_lengths > self.batch_size: + # continue + + olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2 + olens = 1 + (olens - 3 + 2 * 1) // 2 + fake_token_len_i = (olens - 1) // 2 + 1 + fake_token = [0] * fake_token_len_i + fbank_beg_i = len(source_ids) + source_ids += fake_token + fbank_mask_i += [1] * len(fake_token) + + if badcase_flag: + continue + + fbank_beg += [fbank_beg_i + len(input_ids)] + fake_token_len += [fake_token_len_i] + source_mask = [-100] * len(source_ids) + # target_out = f"{target_out}<|im_end|>" + # target_ids = self.tokenizer.encode(target_out) + target_ids = [] + input_ids += source_ids + target_ids + labels += source_mask + target_ids + fbank.append(speech[0, :, :]) + fbank_mask += fbank_mask_i + fbank_lens.append(speech_lengths) + + if badcase_flag: + continue + + turn_taking_labels = [-100] * len(labels) + barge_in_labels = [-100] * len(labels) + last_vad = [0] * fake_token_len[-1] + pos_vad = math.ceil(fake_token_len[-1] * (true_time_span/last_time_span)) + assert pos_vad <= fake_token_len[-1] + if pos_vad > 0: + last_vad[-pos_vad:] = [1] * pos_vad + + if task == "turn-taking": + turn_taking_labels[-fake_token_len[-1]:] = last_vad + elif task == "barge-in": + # print(f'barge-in: {last_vad}') + barge_in_labels[-fake_token_len[-1]:] = last_vad + + input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length] + attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32) + labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length] + turn_taking_labels = 
torch.tensor(turn_taking_labels, dtype=torch.int64) # [: self.max_token_length] + barge_in_labels = torch.tensor(barge_in_labels, dtype=torch.int64) # [: self.max_token_length] + + # fbank = speech[0, :, :] + # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32) + fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32) + fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32) + fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32) + + output = { + "speech": fbank, + "speech_lengths": fbank_lens, + "fbank_mask": fbank_mask, + "fbank_beg": fbank_beg, + "fake_token_len": fake_token_len, + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels_ids": labels, + "turn_taking_labels": turn_taking_labels, + "barge_in_labels": barge_in_labels, + } + break + + return output + + def collator(self, samples: list = None): + + for idx in range(self.retry): + badcase_flag = False + + outputs = {} + for sample in samples: + if sample is None: + continue + for key in sample.keys(): + if key not in outputs: + outputs[key] = [] + if isinstance(sample[key], (list, tuple)): + outputs[key].extend(sample[key]) + else: + outputs[key].append(sample[key]) + + for key, data_list in outputs.items(): + if isinstance(data_list[0], torch.Tensor): + if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32: + + pad_value = self.int_pad_value + else: + pad_value = self.float_pad_value + + outputs[key] = torch.nn.utils.rnn.pad_sequence( + data_list, batch_first=True, padding_value=pad_value + ) + + if self.batch_type != "example": + b, t = outputs["input_ids"].shape + if b > 1 and b * t > self.batch_size_token_max: + logging.info( + f"Warning, {idx}th, b*t: {b}*{t}={b * t} > batch_size_sample_max: {self.batch_size_token_max}, drop last data" + ) + samples = samples[:-1] + continue + + break + + return outputs diff --git a/funasr/datasets/openai_datasets/index_ds.py b/funasr/datasets/openai_datasets/index_ds.py index eefb7f618..f9c4573c1 100644 --- a/funasr/datasets/openai_datasets/index_ds.py +++ b/funasr/datasets/openai_datasets/index_ds.py @@ -107,6 +107,120 @@ class OpenAIIndexDSJsonl(torch.utils.data.Dataset): # torch.utils.data.Dataset return 0 +@tables.register("index_ds_classes", "OpenAIIndexDSJsonlForFullDuplexVAD") +class OpenAIIndexDSJsonlForFullDuplexVAD(torch.utils.data.Dataset): # torch.utils.data.Dataset + + def __init__(self, path: str, **kwargs): + super().__init__() + + self.max_source_length = kwargs.get("max_source_length", 3000) + self.min_source_length = kwargs.get("min_source_length", 0) + self.max_target_length = kwargs.get("max_target_length", 2048) + self.min_target_length = kwargs.get("min_target_length", 0) + self.max_token_length = kwargs.get("max_token_length", 2200) + + is_training = kwargs.get("is_training", True) + if not (path.endswith(".jsonl") or path.endswith(".json")): + # jsonl list file + data_split_num = kwargs.get("data_split_num", 1) + data_split_i = kwargs.get("data_split_i", 0) + + if not is_training: + data_split_num = 1 + data_split_i = 0 + with open(path, encoding="utf-8") as fin: + file_list_all = fin.readlines() + + num_per_slice = (len(file_list_all) - 1) // data_split_num + 1 # 16 + file_list = file_list_all[ + data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice + ] + logging.info( + f"is_training: {is_training}, data_split_num: {data_split_num}, data_split_i: {data_split_i}, \nfile_list: {file_list}, \nfile_list_all: {file_list_all}" + ) + + else: + file_list = [path] + + contents = [] + for file_json in 
file_list: + with open(file_json.strip(), encoding="utf-8") as fin: + for line in fin: + data_dict = json.loads(line.strip()) + data = data_dict["messages"] + for message in data: + if message['role'] == 'user': + message['content'] = message['content'].replace("/home/qinglin.zql/project/dataset/gpt-4o/vad", "/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad") + message['content'] = message['content'].replace("/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad/alimeeting/wav", "/cpfs_speech/qinglin.zql/project/datasets/gpt-4o/vad/alimeeting/alimeeting_vad/wav") + + speech_length = data_dict.get("speech_length", -1) // 8 + text_length = data_dict.get("text_length", 0) + task = data_dict['task'] + last_total_time = data[-1]['end_time'] - data[-1]['start_time'] + if task == 'turn-taking': + true_time_span = data[-1]['turn-taking-gap_time-added'] + elif task == "barge-in": + true_time_span = last_total_time - data[-1]['barge-in-0'] + if speech_length > self.max_source_length: + logging.info( + f"speech_length: {speech_length} > {self.max_source_length}, drop it" + ) + continue + if text_length > self.max_target_length: + continue + + self.max_target_length = kwargs.get("max_target_length", 2048) + + system, user, assistant = [], [], [] + for i, item in enumerate(data): + role = item["role"] + content = item["content"] + if role == "system": + system.append(content) + elif role == "user": + user.append(content) + elif role == "assistant": + assistant.append(content) + + system = system * len(user) + assert len(user) - 1 == len(assistant) + assistant.append("") + + contents_i = { + "system": system, + "user": user, + "assistant": assistant, + "source_len": speech_length + text_length, + "task": task, + "true_time_span": true_time_span, + "last_total_time": last_total_time + } + + contents.append(contents_i) + + self.contents = contents + + logging.info("total_num of samplers: {}, {}".format(len(self.contents), path)) + + def __len__(self): + return len(self.contents) + + def __getitem__(self, index): + + data = self.contents[index] + + return data + + def get_source_len(self, data_dict): + source_len = data_dict.get("source_len", -1) + if source_len < 0: + source_len = len(data_dict["system"]) + len(data_dict["user"]) + return source_len + + def get_target_len(self, data_dict): + + return 0 + if __name__ == "__main__": index_ds = OpenAIIndexDSJsonl( diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index e4a5e77d9..79678d5df 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -10,6 +10,9 @@ import torch.nn.functional as F from torch.cuda.amp import autocast import numpy as np import re +import math +from torch.nn import CrossEntropyLoss + from funasr.models.scama.utils import sequence_mask from funasr.losses.label_smoothing_loss import LabelSmoothingLoss from funasr.models.ctc.ctc import CTC @@ -2465,3 +2468,912 @@ class LLMASR5(nn.Module): def random_sampling(self, weighted_scores): top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True) return top_ids + +@tables.register("model_classes", "LLMVAD") +class LLMVAD(nn.Module): + """ """ + + def __init__( + self, + audio_encoder: str = None, + audio_encoder_conf: dict = None, + audio_adaptor: str = None, + audio_adaptor_conf: dict = None, + llm: str = None, + llm_conf: dict = None, + input_size: int = 80, + length_normalized_loss: bool = False, + **kwargs, + ): + + super().__init__() + + # audio encoder + hub = audio_encoder_conf.get("hub", None) + 
self.audio_encoder_activation_checkpoint = audio_encoder_conf.get( + "activation_checkpoint", False + ) + if hub == "ms": + from funasr import AutoModel + + model = AutoModel(model=audio_encoder, model_revision="master") + # frontend = model.kwargs.get("frontend") + audio_encoder_output_size = model.model.encoder_output_size + + audio_encoder = ( + model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder + ) + + # self.frontend = frontend + + elif hub == "hf": + pass + else: + encoder_class = tables.encoder_classes.get(audio_encoder) + audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf) + audio_encoder_output_size = audio_encoder.output_size() + freeze = audio_encoder_conf.get("freeze", True) + freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1)) + # if freeze_layer_num > 0: + # freeze_layer_num = range(freeze_layer_num) + + if freeze: + for name, param in audio_encoder.named_parameters(): + if freeze_layer_num > 0: + idx = re.search(r"\.\d+\.", name) + if idx is not None: + beg, end = idx.regs[0] + layer_id = int(name[beg + 1 : end - 1]) + if layer_id < freeze_layer_num: + param.requires_grad = False + elif "ln_post." not in name: + param.requires_grad = False + else: + param.requires_grad = False + + audio_encoder.eval() + + self.audio_encoder = audio_encoder + + # llm + self.llm = None + + from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig + + init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5") + logging.info(f"Loading llm ckpt: {init_param_path}") + model = AutoModelForCausalLM.from_pretrained( + init_param_path, + load_in_8bit=None, + device_map=None, + use_cache=None, + ) + logging.info(f"llm ckpt loaded: {init_param_path}") + + freeze = llm_conf.get("freeze", True) + if freeze: + for name, param in model.named_parameters(): + param.requires_grad = False + model.eval() + + logging.info(f"use_lora: {llm_conf.get('use_lora', False)}") + if llm_conf.get("use_lora", False): + from omegaconf import OmegaConf, DictConfig + + lora_conf = llm_conf.get("lora_conf", {}) + if isinstance(lora_conf, (OmegaConf, DictConfig)): + lora_conf = OmegaConf.to_container(lora_conf, resolve=True) + from peft import get_peft_model, LoraConfig, TaskType, PeftConfig, PeftModel + + lora_init_param_path = lora_conf.get("init_param_path", None) + if lora_init_param_path is not None: + model = PeftModel.from_pretrained(model, lora_init_param_path) + else: + peft_config = LoraConfig(**lora_conf) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + if llm_conf.get("activation_checkpoint", False): + model.gradient_checkpointing_enable() + + self.llm_dtype = llm_conf.get("llm_dtype", "fp32") + self.llm = model.to(dtype_map[self.llm_dtype]) + llm_dim = model.get_input_embeddings().weight.shape[-1] + + # adaptor + adaptor_class = tables.adaptor_classes.get(audio_adaptor) + audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size + audio_adaptor_conf["llm_dim"] = llm_dim + audio_adaptor = adaptor_class(**audio_adaptor_conf) + init_param_path = audio_adaptor_conf.get("init_param_path", None) + if init_param_path is not None: + src_state = torch.load(init_param_path, map_location="cpu") + flag = audio_adaptor.load_state_dict(src_state, strict=False) + logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}") + freeze = audio_adaptor_conf.get("freeze", False) + if freeze: + for name, param in audio_adaptor.named_parameters(): + param.requires_grad = False + 
audio_adaptor.eval() + + self.audio_adaptor = audio_adaptor + + self.error_calculator = None + + self.length_normalized_loss = length_normalized_loss + self.beam_search = None + + self.loss_fct = CrossEntropyLoss() + + print("self.llm.config:", self.llm.config) + from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer + from copy import deepcopy + self.task_decoder_layer_config = deepcopy(self.llm.config) + self.task_decoder_layer_config.hidden_size = self.llm.config.hidden_size // 4 + self.task_decoder_layer_config.intermediate_size = self.llm.config.intermediate_size // 4 + self.task_decoder_layer_config.num_attention_heads = self.llm.config.num_attention_heads // 4 + self.task_decoder_layer_config.num_key_value_heads = self.llm.config.num_key_value_heads // 4 + print("self.task_decoder_layer_config:", self.task_decoder_layer_config) + self.down_proj = nn.Linear(self.llm.config.hidden_size, self.task_decoder_layer_config.hidden_size, bias=False).to(dtype_map[self.llm_dtype]) + self.task_decoder_layer = Qwen2DecoderLayer(self.task_decoder_layer_config, self.llm.config.num_hidden_layers).to(dtype_map[self.llm_dtype]) + if getattr(self.llm.config, "classifier_dropout", None) is not None: + classifier_dropout = self.llm.config.classifier_dropout + elif getattr(self.llm.config, "hidden_dropout", None) is not None: + classifier_dropout = self.llm.config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.barge_in_num_labels = 2 + self.turn_taking_num_labels = 2 + self.barge_in_score = nn.Linear(self.task_decoder_layer_config.hidden_size, self.barge_in_num_labels).to(dtype_map[self.llm_dtype]) + self.turn_taking_score = nn.Linear(self.task_decoder_layer_config.hidden_size, self.turn_taking_num_labels).to(dtype_map[self.llm_dtype]) + + + def forward( + self, + speech: torch.Tensor = None, + speech_lengths: torch.Tensor = None, + input_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + labels_ids: torch.Tensor = None, + fbank_beg: torch.Tensor = None, + fbank_mask: torch.Tensor = None, + turn_taking_labels: torch.Tensor = None, + barge_in_labels: torch.Tensor = None, + **kwargs, + ): + """Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) 
+ speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + # import pdb + # + # pdb.set_trace() + input_ids[input_ids < 0] = 0 + inputs_embeds = self.llm.model.get_input_embeddings()(input_ids) + + if speech is not None: + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + + batch_size_speech, frames, _ = speech.shape + batch_size, token_num = input_ids.shape + + # with torch.cuda.amp.autocast(enabled=False): + # audio encoder + if self.audio_encoder_activation_checkpoint: + from torch.utils.checkpoint import checkpoint + + encoder_out, encoder_out_lens = checkpoint( + self.encode, speech, speech_lengths, use_reentrant=False + ) + else: + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + # audio_adaptor + encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) + + batch_size, token_num, dims = inputs_embeds.shape + fake_token_len = kwargs.get("fake_token_len") + fake_token_len[fake_token_len < 0] = 0 + fbank_beg[fbank_beg < 0] = 0 + + speech_idx = 0 + for batch_idx in range(batch_size): + + for turn_id in range(fbank_beg.shape[1]): + fbank_beg_idx = fbank_beg[batch_idx, turn_id].item() + if fbank_beg_idx > 0: + speech_token_len = fake_token_len[batch_idx, turn_id] + speech_token = encoder_out[speech_idx, :speech_token_len, :] + + try: + inputs_embeds[ + batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : + ] = speech_token + except Exception as e: + # + logging.error(f"{str(e)}, {traceback.format_exc()}") + logging.info( + f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" + ) + # import pdb; + # pdb.set_trace() + speech_token_len = encoder_out_lens[speech_idx].item() + speech_token = encoder_out[speech_idx, :speech_token_len, :] + inputs_embeds[ + batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : + ] = speech_token + + speech_idx += 1 + + with torch.cuda.amp.autocast( + enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype] + ): + labels_ids[labels_ids == -1] = -100 + attention_mask[attention_mask < 0] = 0 + model_outputs = self.llm( + inputs_embeds=inputs_embeds.to(dtype_map[self.llm_dtype]), + attention_mask=attention_mask, + labels=labels_ids, + output_hidden_states=True, + ) + output_attentions = kwargs.get("output_attentions", None) + past_key_values = kwargs.get("past_key_values", None) + past_key_values_length = kwargs.get("past_key_values_length", 0) + position_ids = kwargs.get("position_ids", None) + use_cache = kwargs.get("use_cache", None) + seq_length = token_num + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, \ + _prepare_4d_causal_attention_mask_for_sdpa + + if self.llm.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif 
self.llm.config._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.llm.config.sliding_window, + ) + else: + # 4d mask is passed through the layers + causal_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.llm.config.sliding_window, + ) + + sequence_output = model_outputs.hidden_states[-1] + sequence_output = self.down_proj(sequence_output) + if self.llm.model.gradient_checkpointing and self.llm.model.training: + layer_outputs = self.llm._gradient_checkpointing_func( + self.task_decoder_layer.__call__, + sequence_output, + causal_attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = self.task_decoder_layer( + sequence_output, + attention_mask=causal_attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + sequence_output = layer_outputs[0] + + sequence_output = self.dropout(sequence_output) + turn_taking_logits = self.turn_taking_score(sequence_output) + barge_in_logits = self.barge_in_score(sequence_output) + + loss = None + if barge_in_labels is not None: + barge_in_labels[barge_in_labels == -1] = -100 + barge_in_loss = self.loss_fct(barge_in_logits.view(-1, self.barge_in_num_labels), barge_in_labels.view(-1)) + loss = barge_in_loss + if turn_taking_labels is not None: + turn_taking_labels[turn_taking_labels == -1] = -100 + turn_taking_loss = self.loss_fct(turn_taking_logits.view(-1, self.turn_taking_num_labels), turn_taking_labels.view(-1)) + loss = turn_taking_loss if loss is None else loss + turn_taking_loss + + stats = {} + # with torch.no_grad(): + # preds = torch.argmax(model_outputs.logits, -1) + # acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100) + # stats["acc"] = acc_att + if turn_taking_labels is not None: + stats["turn_taking_loss"] = torch.clone(turn_taking_loss.detach()) + with torch.no_grad(): + turn_taking_preds = torch.argmax(turn_taking_logits, -1) + turn_taking_acc = compute_accuracy(turn_taking_preds, turn_taking_labels, ignore_label=-100) + stats["turn_taking_acc"] = turn_taking_acc + if barge_in_labels is not None: + stats["barge_in_loss"] = torch.clone(barge_in_loss.detach()) + with torch.no_grad(): + barge_in_preds = torch.argmax(barge_in_logits, -1) + barge_in_acc = compute_accuracy(barge_in_preds, barge_in_labels, ignore_label=-100) + stats["barge_in_acc"] = barge_in_acc + stats["loss"] = torch.clone(loss.detach()) + stats["batch_size"] = batch_size + stats["batch_size_speech"] = batch_size_speech + stats["batch_size_x_frames"] = frames * batch_size_speech + stats["batch_size_real_frames"] = speech_lengths.sum().item() + stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"] + stats["batch_size_x_tokens"] = token_num * batch_size + stats["batch_size_real_tokens"] = attention_mask.sum().item() + stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] + + dialog_turns = (fbank_beg > 0).sum(-1) + dialog_turns_max = torch.max(dialog_turns).int().item() + 
dialog_turns_avg = dialog_turns.sum().item() / batch_size + stats["dialog_turns_max"] = dialog_turns_max + stats["dialog_turns_avg"] = dialog_turns_avg + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = int((labels_ids > 0 + 1).sum()) + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def vad_inference( + self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + inputs_embeds, contents, batch, source_ids, meta_data = self.vad_inference_prepare( + data_in, data_lengths, key, tokenizer, frontend, **kwargs + ) + task = contents.get("task", "vad") + fbank_beg = batch["fbank_beg"] + fake_token_len = batch["fake_token_len"] + fbank_mask = batch["fbank_mask"] + batch_size, token_num, dims = inputs_embeds.shape + fake_token_len[fake_token_len < 0] = 0 + fbank_beg[fbank_beg < 0] = 0 + + llm_dtype = kwargs.get("llm_dtype", "fp32") + if llm_dtype == "fp32": + llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype + llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype + + stats = {"turn_taking_preds": [], "barge_in_preds": [], "turn_taking_labels": [], "barge_in_labels": [], 'task': task} + with torch.cuda.amp.autocast( + enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] + ): + self.llm = self.llm.to(dtype_map[llm_dtype]) + self.down_proj = self.down_proj.to(dtype_map[llm_dtype]) + self.task_decoder_layer = self.task_decoder_layer.to(dtype_map[llm_dtype]) + self.turn_taking_score = self.turn_taking_score.to(dtype_map[llm_dtype]) + self.barge_in_score = self.barge_in_score.to(dtype_map[llm_dtype]) + + inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) + llm_kwargs = kwargs.get("llm_kwargs", {}) + + attention_mask = batch.get("attention_mask", None) + # attention_mask = attention_mask.to(dtype_map[llm_dtype]) + model_outputs = self.llm( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + labels=None, + output_hidden_states=True, + **llm_kwargs, + ) + output_attentions = llm_kwargs.get("output_attentions", None) + past_key_values = llm_kwargs.get("past_key_values", None) + past_key_values_length = llm_kwargs.get("past_key_values_length", 0) + position_ids = llm_kwargs.get("position_ids", None) + use_cache = llm_kwargs.get("use_cache", None) + seq_length = token_num + if position_ids is None: + device = inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, \ + _prepare_4d_causal_attention_mask_for_sdpa + + if self.llm.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.llm.config._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
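+                # NOTE (added comment): the 2D/4D mask rebuilt in these branches is only consumed by
+                # self.task_decoder_layer further below; the self.llm call above already ran with the
+                # original 2D attention_mask.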
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.llm.config.sliding_window, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.llm.config.sliding_window, + ) + + sequence_output = model_outputs.hidden_states[-1] + sequence_output = self.down_proj(sequence_output) + + layer_outputs = self.task_decoder_layer( + sequence_output, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + sequence_output = layer_outputs[0] + + sequence_output = self.dropout(sequence_output) + turn_taking_logits = self.turn_taking_score(sequence_output) + barge_in_logits = self.barge_in_score(sequence_output) + + turn_taking_labels = batch.get("turn_taking_labels", None) + barge_in_labels = batch.get("barge_in_labels", None) + # print(f'batch: {batch}') + # print(f"fake_token_len: {fake_token_len}") + # print(f"turn taking labels: {turn_taking_labels}") + # print(f"barge in labels: {barge_in_labels}") + turn_taking_preds_res = [] + barge_in_preds_res = [] + turn_taking_labels_res = [] + barge_in_labels_res = [] + with torch.no_grad(): + turn_taking_preds = torch.argmax(turn_taking_logits, -1) + barge_in_preds = torch.argmax(barge_in_logits, -1) + for batch_idx in range(batch_size): + fbank_begin_index = fbank_beg[batch_idx, -1].item() + fbank_end_index = fbank_begin_index + fake_token_len[batch_idx, -1].item() + turn_taking_preds_last = turn_taking_preds[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist() + turn_taking_preds_res.append(turn_taking_preds_last) + # print(f"turn_taking_labels: {turn_taking_labels}") + turn_taking_labels_last = turn_taking_labels[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist() + turn_taking_labels_res.append(turn_taking_labels_last) + # print(f"turn_taking_preds: {turn_taking_preds_last}") + barge_in_preds_last = barge_in_preds[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist() + barge_in_preds_res.append(barge_in_preds_last) + # print(f"barge_in_labels: {barge_in_labels}") + barge_in_labels_last = barge_in_labels[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist() + barge_in_labels_res.append(barge_in_labels_last) + + turn_taking_acc = compute_accuracy(turn_taking_preds, turn_taking_labels, ignore_label=-100) + stats["turn_taking_acc"] = turn_taking_acc.item() + + barge_in_acc = compute_accuracy(barge_in_preds, barge_in_labels, ignore_label=-100) + stats["barge_in_acc"] = barge_in_acc.item() + stats["turn_taking_preds"].append(turn_taking_preds_res) + stats["barge_in_preds"].append(barge_in_preds_res) + stats["turn_taking_labels"].append(turn_taking_labels_res) + stats["barge_in_labels"].append(barge_in_labels_res) + return turn_taking_logits, barge_in_logits, meta_data, stats + + + def encode(self, speech, speech_lengths): + # audio encoder + encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths) + + return encoder_out, encoder_out_lens + + def vad_data_template(self, sample): + data = sample["messages"] + system, user, assistant = [], [], [] + for i, item in enumerate(data): + role = item["role"] + content = item["content"] + if role == "system": + system.append(content) + elif 
role == "user": + if "audio" in item: + audio = item["audio"] + content = [content, audio] + user.append(content) + elif role == "assistant": + assistant.append(content) + + system = system * len(user) + assistant.append("") + contents = { + "system": system, + "user": user, + "assistant": assistant, + } + + if "task" in sample: + task = sample['task'] + last_total_time = data[-1]['end_time'] - data[-1]['start_time'] + if task == 'turn-taking': + true_time_span = data[-1]['turn-taking-gap_time-added'] + elif task == "barge-in": + true_time_span = last_total_time - data[-1]['barge-in-0'] + else: + raise ValueError("task must be turn-taking or barge-in") + contents["true_time_span"] = true_time_span + contents["last_total_time"] = last_total_time + contents['task'] = sample['task'] + return contents + + + def data_template(self, data): + system, user, assistant = [], [], [] + for i, item in enumerate(data): + role = item["role"] + content = item["content"] + if role == "system": + system.append(content) + elif role == "user": + if "audio" in item: + audio = item["audio"] + content = [content, audio] + user.append(content) + elif role == "assistant": + assistant.append(content) + + system = system * len(user) + + contents = { + "system": system, + "user": user, + "assistant": assistant, + } + + return contents + + + def vad_data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs): + + system = contents["system"] + user = contents["user"] + assistant = contents["assistant"] + pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)") + + input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = ( + [], + [], + [], + [], + [], + [], + [], + ) + input_source_ids = [] + for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)): + if isinstance(user_prompt, (list, tuple)): + user_prompt, audio = user_prompt + if i == 0: + source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + elif i == len(system) - 1: + source_input = ( + f"<|im_start|>user\n{user_prompt}" + ) + else: + source_input = ( + f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + ) + + splits = pattern.split(source_input) + source_ids = [] + fbank_i = [] + fbank_mask_i = [] + fake_token_len_i = 0 + fbank_beg_i = -1 + fbank_lens_i = [] + speech, speech_lengths = [], [] + for k, sub_str in enumerate(splits): + if not sub_str.startswith("<|startofspeech|>"): + sub_token = tokenizer.encode(sub_str) + source_ids += sub_token + fbank_mask_i += [0] * len(sub_token) + else: + sub_str = sub_str.replace("<|startofspeech|>", "").replace( + "<|endofspeech|>", "" + ) + if sub_str.startswith("!"): + sub_str = sub_str[1:] + if sub_str.startswith("!"): # !!: audio sample point + sub_str = audio + try: + time1 = time.perf_counter() + data_src = load_audio_text_image_video(sub_str, fs=frontend.fs) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + except Exception as e: + logging.error(f"Loading wav failed! 
{str(e)}, {traceback.format_exc()}") + + speech, speech_lengths = extract_fbank( + data_src, + data_type=kwargs.get("data_type", "sound"), + frontend=frontend, + is_final=True, + ) # speech: [b, T, d] + + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = ( + speech_lengths.sum().item() + * frontend.frame_shift + * frontend.lfr_n + / 1000 + ) + + if kwargs.get("permute", True): + speech = speech.permute(0, 2, 1) + if speech_lengths > kwargs.get("max_source_length", 5500): + # logging.info( + # f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}" + # ) + badcase_flag = True + + olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2 + olens = 1 + (olens - 3 + 2 * 1) // 2 + fake_token_len_i = (olens - 1) // 2 + 1 + fake_token = [0] * fake_token_len_i + fbank_beg_i = len(source_ids) + source_ids += fake_token + fbank_mask_i += [1] * len(fake_token) + + fbank_beg += [fbank_beg_i + len(input_ids)] + fake_token_len += [fake_token_len_i] + source_mask = [-100] * len(source_ids) + # target_out = f"{target_out}<|im_end|>" + # target_ids = tokenizer.encode(target_out) + target_ids = [] + input_source_ids = input_ids + source_ids + input_ids += source_ids + target_ids + labels += source_mask + target_ids + fbank_mask += fbank_mask_i + if len(speech) > 0: + fbank.append(speech[0, :, :]) + fbank_lens.append(speech_lengths) + + turn_taking_labels = [-100] * len(labels) + barge_in_labels = [-100] * len(labels) + last_vad = [0] * fake_token_len[-1] + if "true_time_span" in contents: + true_time_span = contents["true_time_span"] + last_time_span = contents["last_total_time"] + pos_vad = math.ceil(fake_token_len[-1] * (true_time_span/last_time_span)) + assert pos_vad <= fake_token_len[-1] + if pos_vad > 0: + last_vad[-pos_vad:] = [1] * pos_vad + turn_taking_labels[-fake_token_len[-1]:] = last_vad + barge_in_labels[-fake_token_len[-1]:] = last_vad + + input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length] + attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32) + labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length] + turn_taking_labels = torch.tensor([turn_taking_labels], dtype=torch.int64) # [: self.max_token_length] + barge_in_labels = torch.tensor([barge_in_labels], dtype=torch.int64) # [: self.max_token_length] + + # fbank = speech[0, :, :] + # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32) + fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32) + fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32) + fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32) + source_ids = torch.tensor(input_source_ids, dtype=torch.int64) + target_ids = torch.tensor(target_ids, dtype=torch.int64) + + if len(fbank) > 0: + speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0) + speech_lengths = torch.nn.utils.rnn.pad_sequence( + fbank_lens, batch_first=True, padding_value=-1 + ) + else: + speech = [] + speech_lengths = [] + output = { + "speech": speech, + "speech_lengths": speech_lengths, + "fbank_mask": fbank_mask[None, :], + "fbank_beg": fbank_beg[None,], + "fake_token_len": fake_token_len[None, :], + "input_ids": input_ids[None,], + "attention_mask": attention_mask[None,], + "labels_ids": labels, + "source_ids": source_ids[None, :], + "target_ids": target_ids[None, :], + "turn_taking_labels": turn_taking_labels, + "barge_in_labels": barge_in_labels, + } + + return output + + def 
vad_inference_prepare( + self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + meta_data = {} + prompt = kwargs.get("prompt", None) + + if kwargs.get("batch_size", 1) > 1: + raise NotImplementedError("batch decoding is not implemented") + + contents = self.vad_data_template(data_in[0]) + output = self.vad_data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs) + batch = to_device(output, kwargs["device"]) + + # audio encoder + speech = batch["speech"] + if len(speech) > 0: + speech_lengths = batch["speech_lengths"][:, 0] + # fp16 + if kwargs.get("fp16", False): + speech = speech.to(torch.float16) + elif kwargs.get("bf16", False): + speech = speech.to(torch.bfloat16) + # audio encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + # audio_adaptor + encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) + + input_ids = batch["input_ids"] + source_ids = batch["source_ids"] + fbank_beg = batch["fbank_beg"] + fake_token_len = batch["fake_token_len"] + + if not kwargs.get("tearchforing", False): + input_ids = source_ids + + input_ids[input_ids < 0] = 0 + inputs_embeds = self.llm.model.get_input_embeddings()(input_ids) + + batch_size, token_num, dims = inputs_embeds.shape + + fake_token_len[fake_token_len < 0] = 0 + fbank_beg[fbank_beg < 0] = 0 + + speech_idx = 0 + for batch_idx in range(batch_size): + + for turn_id in range(fbank_beg.shape[1]): + fbank_beg_idx = fbank_beg[batch_idx, turn_id].item() + if fbank_beg_idx > 0: + speech_token_len = fake_token_len[batch_idx, turn_id] + speech_token = encoder_out[speech_idx, :speech_token_len, :] + + try: + inputs_embeds[ + batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : + ] = speech_token + except Exception as e: + # + logging.error(f"{str(e)}, {traceback.format_exc()}") + logging.info( + f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" + ) + # import pdb; + # pdb.set_trace() + speech_token_len = encoder_out_lens[speech_idx].item() + speech_token = encoder_out[speech_idx, :speech_token_len, :] + inputs_embeds[ + batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : + ] = speech_token + + speech_idx += 1 + return inputs_embeds, contents, batch, source_ids, meta_data + + def inference( + self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare( + data_in, data_lengths, key, tokenizer, frontend, **kwargs + ) + + llm_dtype = kwargs.get("llm_dtype", "fp32") + if llm_dtype == "fp32": + llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype + llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype + + with torch.cuda.amp.autocast( + enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] + ): + label = contents["assistant"][-1] + self.llm = self.llm.to(dtype_map[llm_dtype]) + inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) + llm_kwargs = kwargs.get("llm_kwargs", {}) + if not kwargs.get("tearchforing", False): + + generated_ids = self.llm.generate( + inputs_embeds=inputs_embeds, + max_new_tokens=kwargs.get("max_length", 512), + **llm_kwargs, + ) + # generated_ids = [ + # output_ids[len(input_id) :] + # for 
input_id, output_ids in zip(input_ids, generated_ids) + # ] + response = tokenizer.batch_decode( + generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True) + )[0] + + loss = None + else: + + labels_ids = batch["labels_ids"] + labels_ids[labels_ids == -1] = -100 + attention_mask = batch.get("attention_mask", None) + # attention_mask = attention_mask.to(dtype_map[llm_dtype]) + model_outputs = self.llm( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + labels=labels_ids, + **llm_kwargs, + ) + + preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :] + response = tokenizer.batch_decode( + preds, + add_special_tokens=False, + skip_special_tokens=kwargs.get("skip_special_tokens", True), + )[0] + loss = model_outputs.loss.item() + + ibest_writer = None + if kwargs.get("output_dir") is not None: + if not hasattr(self, "writer"): + self.writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = self.writer[f"{0 + 1}best_recog"] + + results = [] + response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response) + result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label} + if loss is not None: + result_i["loss"] = loss + results.append(result_i) + + if ibest_writer is not None: + ibest_writer["text"][key[0]] = response.replace("\n", " ") + ibest_writer["label"][key[0]] = label.replace("\n", " ") + ibest_writer["text_tn"][key[0]] = response_clean + + return results, meta_data
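
Note (illustrative sketch, not part of the patch): the snippet below restates, in isolation, the two pieces of arithmetic that the dataset code above relies on — how a count of fbank frames is mapped to the number of audio placeholder ("fake") tokens, and how the frame-level turn-taking / barge-in targets are derived from the time spans stored in the jsonl. Function names and the numeric example are invented for illustration; the downsampling interpretation (two stride-2 reductions plus a further 2x reduction in the adaptor) is an assumption read off the hard-coded formula.

# Standalone sketch mirroring OpenAIDatasetMultiTurnForFullDuplexVAD above.
import math


def fake_token_len_from_fbank(num_fbank_frames: int) -> int:
    # Same arithmetic as the dataset code: two kernel-3 / stride-2 / pad-1
    # reductions followed by a further 2x reduction (assumed to match the
    # audio encoder + adaptor downsampling).
    olens = 1 + (num_fbank_frames - 3 + 2 * 1) // 2
    olens = 1 + (olens - 3 + 2 * 1) // 2
    return (olens - 1) // 2 + 1


def frame_level_vad_targets(fake_token_len: int,
                            true_time_span: float,
                            last_total_time: float):
    # Mark the trailing fraction of the last audio segment as positive (1):
    # true_time_span / last_total_time is the share of the final user audio
    # covered by the turn-taking gap (or the region after the barge-in onset).
    last_vad = [0] * fake_token_len
    pos_vad = math.ceil(fake_token_len * (true_time_span / last_total_time))
    assert pos_vad <= fake_token_len
    if pos_vad > 0:
        last_vad[-pos_vad:] = [1] * pos_vad
    return last_vad


if __name__ == "__main__":
    # Hypothetical example: ~5.6 s of audio, assuming a 10 ms fbank frame shift,
    # with the last 1.2 s labeled as the positive VAD region.
    n_tokens = fake_token_len_from_fbank(560)
    print(n_tokens, frame_level_vad_targets(n_tokens, true_time_span=1.2, last_total_time=5.6))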