Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
auto frontend

commit 27256ed429 (parent 6a7d34392c)
funasr/datasets/openai_datasets/__init__.py (new file, 0 lines)
funasr/datasets/openai_datasets/datasets.py (new file, 216 lines)

@@ -0,0 +1,216 @@
import logging
import re
import torch
import random
import traceback

from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video


@tables.register("dataset_classes", "OpenAIDataset")
class OpenAIDataset(torch.utils.data.Dataset):
    """
    OpenAIDataset: chat-style (OpenAI "messages") dataset with inline speech segments.
    """

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(
                **kwargs.get("preprocessor_speech_conf")
            )
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
        self.preprocessor_text = preprocessor_text

        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer

        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value
        self.sos = kwargs.get("sos", "<|startoftranscript|>")
        self.eos = kwargs.get("eos", "<|endoftext|>")
        self.batch_size = kwargs.get("batch_size")
        self.batch_type = kwargs.get("batch_type")
        self.prompt_ids_len = 0
        self.retry = kwargs.get("retry", 5)

        self.permute = False
        from funasr.frontends.whisper_frontend import WhisperFrontend

        if isinstance(self.frontend, WhisperFrontend):
            self.permute = True

        self.pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")

    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        output = None
        for idx in range(self.retry):
            if idx == 0:
                index_cur = index
            else:
                index_cur = torch.randint(0, len(self.index_ds), ()).item()

            item = self.index_ds[index_cur]

            system = item["system"]
            user = item["user"]
            assistant = item["assistant"]

            input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg = [], [], [], [], [], []

            for i, (system_prompt, user_prompt, target_out) in enumerate(
                zip(system, user, assistant)
            ):

                source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"

                splits = self.pattern.split(source_input)
                source_ids = []
                fbank_mask_i = []
                fbank_beg_i = []
                fbank_lens_i = []
                for k, sub_str in enumerate(splits):
                    if not sub_str.startswith("<|startofspeech|>"):
                        sub_token = self.tokenizer.encode(sub_str)
                        source_ids += sub_token
                        fbank_mask_i += [0] * len(sub_token)
                    else:
                        sub_str = sub_str.replace("<|startofspeech|>", "").replace(
                            "<|endofspeech|>", ""
                        )
                        if sub_str.startswith("!"):

                            data_src = load_audio_text_image_video(sub_str[1:], fs=self.fs)

                            speech, speech_lengths = extract_fbank(
                                data_src,
                                data_type=self.data_type,
                                frontend=self.frontend,
                                is_final=True,
                            )  # speech: [b, T, d]
                            if self.permute:
                                speech = speech.permute(0, 2, 1)
                            if speech_lengths > self.batch_size:
                                continue

                            # Number of placeholder tokens reserved for this speech segment:
                            # stride-2 conv output-length formula applied twice, then halved again.
                            fbank_lens = speech_lengths[0].item()
                            olens = 1 + (fbank_lens - 3 + 2 * 1) // 2
                            olens = 1 + (olens - 3 + 2 * 1) // 2
                            sub_token_len = (olens - 1) // 2 + 1
                            sub_token = [0] * sub_token_len
                            fbank_beg_i = [len(source_ids)]
                            source_ids += sub_token
                            fbank_mask_i += [1] * len(sub_token)

                source_mask = [-100] * len(source_ids)
                target_out = f"{target_out}<|im_end|>"
                target_ids = self.tokenizer.encode(target_out)
                input_ids += source_ids + target_ids
                labels += source_mask + target_ids
                fbank_mask += fbank_mask_i
                fbank_beg.append(fbank_beg_i)

            input_ids = torch.tensor(input_ids, dtype=torch.int64)
            attention_mask = torch.tensor([len(input_ids)], dtype=torch.int32)
            labels = torch.tensor(labels, dtype=torch.int64)

            fbank = speech[0, :, :]
            fbank_lens = speech_lengths
            fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
            fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)

            output = {
                "speech": fbank,
                "speech_lengths": fbank_lens,
                "fbank_mask": fbank_mask,
                "fbank_beg": fbank_beg,
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels_ids": labels,
            }
            break

        return output

    def collator(self, samples: list = None):
        outputs = {}
        for sample in samples:
            if sample is None:
                continue
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])

        for key, data_list in outputs.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value

                outputs[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )

        if self.batch_type != "example":
            for i in range(10):
                outputs = self._filter_badcase(outputs, i=i)

        return outputs

    def _filter_badcase(self, outputs, i=0):
        b, t, _ = outputs["speech"].shape

        if b * t > self.batch_size * 1.25:
            beg = torch.randint(0, 2, ()).item()
            if b < 2:
                beg = 0
            logging.info(
                f"Warning, b * t: {b * t} > {self.batch_size}, drop half data {i}th, beg:{beg}"
            )
            for key, data_list in outputs.items():
                outputs[key] = outputs[key][beg : beg + b : 2]

            speech_lengths_max = outputs["speech_lengths"].max().item()
            outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
            text_lengths_max = outputs["text_lengths"].max().item()
            outputs["text"] = outputs["text"][:, :text_lengths_max]
            target_mask_lengths_max = outputs["target_mask_lengths"].max().item()
            outputs["target_mask"] = outputs["target_mask"][:, :target_mask_lengths_max]

        return outputs
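
A small self-contained sketch (not part of the commit; made-up shapes and only two of the real batch keys) of what collator above does with a list of __getitem__ outputs: each key is stacked batch-first with pad_sequence, integer dtypes padded with int_pad_value (-1 by default) and float dtypes with float_pad_value (0.0).

# Sketch only: collator-style padding of per-sample tensors.
import torch

samples = [
    {"input_ids": torch.tensor([5, 6, 7], dtype=torch.int64), "speech": torch.randn(30, 80)},
    {"input_ids": torch.tensor([8, 9], dtype=torch.int64), "speech": torch.randn(20, 80)},
]

outputs = {}
for sample in samples:
    for key, value in sample.items():
        outputs.setdefault(key, []).append(value)
for key, data_list in outputs.items():
    pad_value = -1 if data_list[0].dtype in (torch.int64, torch.int32) else 0.0
    outputs[key] = torch.nn.utils.rnn.pad_sequence(
        data_list, batch_first=True, padding_value=pad_value
    )

print(outputs["input_ids"])     # [[5, 6, 7], [8, 9, -1]]; second row padded with -1
print(outputs["speech"].shape)  # torch.Size([2, 30, 80]); shorter sample zero-padded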
funasr/datasets/openai_datasets/index_ds.py (new file, 95 lines)

@@ -0,0 +1,95 @@
import os
import json
import torch
import logging

import librosa
import random
import torch.distributed as dist

from funasr.register import tables


@tables.register("index_ds_classes", "OpenAIIndexDSJsonl")
class OpenAIIndexDSJsonl(torch.utils.data.Dataset):

    def __init__(self, path: str, **kwargs):
        super().__init__()
        self.max_source_length = kwargs.get("max_source_length", 2048)
        self.min_source_length = kwargs.get("min_source_length", 0)
        self.max_target_length = kwargs.get("max_target_length", 2048)
        self.min_target_length = kwargs.get("min_target_length", 0)
        self.max_token_length = kwargs.get("max_token_length", 2200)

        is_training = kwargs.get("is_training", True)
        if not (path.endswith(".jsonl") or path.endswith(".json")):
            # path is a text file listing jsonl files, one per line
            data_split_num = kwargs.get("data_split_num", 1)
            data_split_i = kwargs.get("data_split_i", 0)

            if not is_training:
                data_split_num = 1
                data_split_i = 0
            with open(path, encoding="utf-8") as fin:
                file_list_all = fin.readlines()

                num_per_slice = (len(file_list_all) - 1) // data_split_num + 1
                file_list = file_list_all[
                    data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice
                ]
                logging.info(
                    f"is_training: {is_training}, data_split_num: {data_split_num}, data_split_i: {data_split_i}, \nfile_list: {file_list}, \nfile_list_all: {file_list_all}"
                )

        else:
            file_list = [path]

        contents = []
        for file_json in file_list:
            with open(file_json.strip(), encoding="utf-8") as fin:
                for line in fin:
                    data = json.loads(line.strip())["messages"]

                    system, user, assistant = [], [], []
                    for i, item in enumerate(data):
                        role = item["role"]
                        content = item["content"]
                        if role == "system":
                            system.append(content)
                        elif role == "user":
                            user.append(content)
                        elif role == "assistant":
                            assistant.append(content)

                    system = system * len(user)

                    contents_i = {"system": system, "user": user, "assistant": assistant}
                    contents.append(contents_i)

        self.contents = contents

        logging.info("total_num of samples: {}, {}".format(len(self.contents), path))

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        data = self.contents[index]

        return data

    def get_source_len(self, data_dict):
        return len(data_dict["system"]) + len(data_dict["user"])

    def get_target_len(self, data_dict):
        return len(data_dict["assistant"])


if __name__ == "__main__":
    index_ds = OpenAIIndexDSJsonl(
        path="/Users/zhifu/funasr1.0/test_local/data_tmp/tmp_wav_10.jsonl"
    )
    print(index_ds.contents)
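
For reference, a sketch (not part of the commit) of what one line of such a jsonl might look like: the "messages" list is what __init__ above parses into system/user/assistant, and the <|startofspeech|>!path<|endofspeech|> span in the user turn is the convention that OpenAIDataset.__getitem__ later expands into audio placeholder tokens. The wav path and texts below are made up.

# Hypothetical example of one jsonl line consumed by OpenAIIndexDSJsonl.
import json

line = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "<|startofspeech|>!/data/wav/example_001.wav<|endofspeech|> Transcribe the speech to text.",
        },
        {"role": "assistant", "content": "hello world"},
    ]
}
print(json.dumps(line, ensure_ascii=False))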
@@ -1,5 +1,6 @@
import logging

import re
import torch
import random
import traceback
@@ -341,3 +341,321 @@ class LLMASR(nn.Module):
            ibest_writer["text"][key[0]] = text

        return results, meta_data


@tables.register("model_classes", "LLMASR2")
|
||||
class LLMASR2(nn.Module):
|
||||
""" """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
specaug: str = None,
|
||||
specaug_conf: dict = None,
|
||||
normalize: str = None,
|
||||
normalize_conf: dict = None,
|
||||
audio_encoder: str = None,
|
||||
audio_encoder_conf: dict = None,
|
||||
audio_adaptor: str = None,
|
||||
audio_adaptor_conf: dict = None,
|
||||
decoder: str = None,
|
||||
decoder_conf: dict = None,
|
||||
ctc: str = None,
|
||||
ctc_conf: dict = None,
|
||||
ctc_weight: float = 0.5,
|
||||
llm: str = None,
|
||||
llm_conf: dict = None,
|
||||
input_size: int = 80,
|
||||
vocab_size: int = -1,
|
||||
ignore_id: int = -1,
|
||||
blank_id: int = 0,
|
||||
sos: int = 1,
|
||||
eos: int = 2,
|
||||
lsm_weight: float = 0.0,
|
||||
length_normalized_loss: bool = False,
|
||||
report_cer: bool = True,
|
||||
report_wer: bool = True,
|
||||
sym_space: str = "<space>",
|
||||
sym_blank: str = "<blank>",
|
||||
# extract_feats_in_collect_stats: bool = True,
|
||||
share_embedding: bool = False,
|
||||
# preencoder: Optional[AbsPreEncoder] = None,
|
||||
# postencoder: Optional[AbsPostEncoder] = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if specaug is not None:
|
||||
specaug_class = tables.specaug_classes.get(specaug)
|
||||
specaug = specaug_class(**specaug_conf)
|
||||
if normalize is not None:
|
||||
normalize_class = tables.normalize_classes.get(normalize)
|
||||
normalize = normalize_class(**normalize_conf)
|
||||
|
||||
# audio encoder
|
||||
hub = audio_encoder_conf.get("hub", None)
|
||||
if hub == "ms":
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model=audio_encoder, model_revision="master")
|
||||
# frontend = model.kwargs.get("frontend")
|
||||
audio_encoder_output_size = model.model.encoder_output_size
|
||||
|
||||
audio_encoder = model.model.model.encoder
|
||||
|
||||
# self.frontend = frontend
|
||||
|
||||
elif hub == "hf":
|
||||
pass
|
||||
else:
|
||||
encoder_class = tables.encoder_classes.get(audio_encoder)
|
||||
audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
|
||||
audio_encoder_output_size = audio_encoder.output_size()
|
||||
freeze = audio_encoder_conf.get("freeze", True)
|
||||
if freeze:
|
||||
for name, param in audio_encoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
audio_encoder.eval()
|
||||
|
||||
self.audio_encoder = audio_encoder
|
||||
|
||||
# llm
|
||||
hub = llm_conf.get("hub", "hf")
|
||||
self.llm = None
|
||||
# if hub == "hf":
|
||||
# from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
|
||||
#
|
||||
# init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
|
||||
#
|
||||
# model = AutoModelForCausalLM.from_pretrained(
|
||||
# init_param_path,
|
||||
# load_in_8bit=None,
|
||||
# device_map=None,
|
||||
# use_cache=None,
|
||||
# )
|
||||
# freeze = llm_conf.get("freeze", True)
|
||||
# if freeze:
|
||||
# for name, param in model.named_parameters():
|
||||
# param.requires_grad = False
|
||||
# model.eval()
|
||||
# self.llm = model
|
||||
|
||||
# adaptor
|
||||
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
|
||||
audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
|
||||
audio_adaptor = adaptor_class(**audio_adaptor_conf)
|
||||
|
||||
self.audio_adaptor = audio_adaptor
|
||||
|
||||
self.blank_id = blank_id
|
||||
self.sos = sos if sos is not None else vocab_size - 1
|
||||
self.eos = eos if eos is not None else vocab_size - 1
|
||||
self.vocab_size = vocab_size
|
||||
self.ignore_id = ignore_id
|
||||
self.specaug = specaug
|
||||
self.normalize = normalize
|
||||
|
||||
self.criterion_att = LabelSmoothingLoss(
|
||||
size=vocab_size,
|
||||
padding_idx=ignore_id,
|
||||
smoothing=lsm_weight,
|
||||
normalize_length=length_normalized_loss,
|
||||
)
|
||||
|
||||
self.error_calculator = None
|
||||
|
||||
self.length_normalized_loss = length_normalized_loss
|
||||
self.beam_search = None
|
||||
|
||||
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        fbank_beg: torch.Tensor,
        fbank_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Audio encoder + adaptor + LLM forward + loss calculation.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            input_ids / labels_ids: token ids and training targets from the dataset
            fbank_beg / fbank_mask: position and mask of the audio placeholder tokens
        """
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # audio_adaptor
        encoder_out = self.audio_adaptor(encoder_out)

        input_ids[input_ids == -1] = 0
        input_ids[input_ids == -100] = 0
        if hasattr(self.llm.model, "embed_tokens"):
            inputs_embeds = self.llm.model.embed_tokens(input_ids)
        elif hasattr(self.llm.model.model, "embed_tokens"):
            inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
        else:
            inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)

        batch_size, token_num, dims = inputs_embeds.shape
        _, l, _ = encoder_out.shape
        for batch_idx in range(batch_size):
            fbank_beg_idx = fbank_beg[batch_idx, 0].item()
            inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + l, :] = encoder_out[
                batch_idx, :l, :
            ]

        model_outputs = self.llm(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
        )
        loss = model_outputs.loss

        stats = {}
        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
            stats["acc"] = acc_att

        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        speech = speech.permute(0, 2, 1)
        res = self.audio_encoder(speech)
        if isinstance(res, (list, tuple)):
            encoder_out, encoder_out_lens = res[0], res[1]
        else:
            encoder_out, encoder_out_lens = res, speech_lengths
        return encoder_out, encoder_out_lens

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        prompt = kwargs.get("prompt", "Transcribe speech to text.")

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # adaptor
        encoder_out = self.audio_adaptor(encoder_out)

        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
        prompt_ids = tokenizer.encode(prompt_pre)
        prompt_length = len(prompt_ids)
        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])

        if hasattr(self.llm.model, "embed_tokens"):
            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
        elif hasattr(self.llm.model.model, "embed_tokens"):
            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
        else:
            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)

        inputs_embeds = torch.cat(
            (inputs_embeds[None, :, :], encoder_out), dim=1
        )  # [prompt, audio]
        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
            kwargs["device"]
        )

        preds = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_length=kwargs.get("max_length", 200),
            max_new_tokens=kwargs.get("max_new_tokens", 200),
            num_beams=kwargs.get("num_beams", 4),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            attention_mask=attention_mask,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

        text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)

        text = text[0].split(": ")[-1]
        text = text.strip()

        ibest_writer = None
        if kwargs.get("output_dir") is not None:
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = self.writer[f"{0 + 1}best_recog"]

        results = []
        result_i = {"key": key[0], "text": text}
        results.append(result_i)

        if ibest_writer is not None:
            ibest_writer["text"][key[0]] = text

        return results, meta_data
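
A small self-contained sketch (made-up shapes, random tensors, not part of the commit) of the splice step in LLMASR2.forward above: the dataset reserves a run of placeholder token ids starting at fbank_beg, and forward overwrites exactly those embedding positions with the adaptor's audio embeddings before calling the LLM.

# Sketch only: splicing audio embeddings into token embeddings at fbank_beg.
import torch

batch_size, token_num, dims = 2, 12, 8     # made-up sizes
l = 5                                      # audio embedding length after the adaptor

inputs_embeds = torch.randn(batch_size, token_num, dims)  # stand-in for embed_tokens(input_ids)
encoder_out = torch.randn(batch_size, l, dims)            # stand-in for audio_adaptor output
fbank_beg = torch.tensor([[3], [6]], dtype=torch.int32)   # start of the placeholder run per sample

for batch_idx in range(batch_size):
    fbank_beg_idx = fbank_beg[batch_idx, 0].item()
    inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + l, :] = encoder_out[batch_idx, :l, :]

# inputs_embeds is now what would be passed to the LLM via inputs_embeds=...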