auto frontend

游雁 2024-06-06 15:45:32 +08:00
parent 6a7d34392c
commit 27256ed429
5 changed files with 630 additions and 0 deletions

@@ -0,0 +1,216 @@
import logging
import re
import torch
import random
import traceback
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
@tables.register("dataset_classes", "OpenAIDataset")
class OpenAIDataset(torch.utils.data.Dataset):
"""
OpenAIDataset: OpenAI chat-format (system/user/assistant) samples with inline speech segments.
"""
def __init__(
self,
path,
index_ds: str = None,
frontend=None,
tokenizer=None,
int_pad_value: int = -1,
float_pad_value: float = 0.0,
**kwargs,
):
super().__init__()
index_ds_class = tables.index_ds_classes.get(index_ds)
self.index_ds = index_ds_class(path, **kwargs)
preprocessor_speech = kwargs.get("preprocessor_speech", None)
if preprocessor_speech:
preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
preprocessor_speech = preprocessor_speech_class(
**kwargs.get("preprocessor_speech_conf")
)
self.preprocessor_speech = preprocessor_speech
preprocessor_text = kwargs.get("preprocessor_text", None)
if preprocessor_text:
preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
self.preprocessor_text = preprocessor_text
self.frontend = frontend
self.fs = 16000 if frontend is None else frontend.fs
self.data_type = "sound"
self.tokenizer = tokenizer
self.int_pad_value = int_pad_value
self.float_pad_value = float_pad_value
self.sos = kwargs.get("sos", "<|startoftranscript|>")
self.eos = kwargs.get("eos", "<|endoftext|>")
self.batch_size = kwargs.get("batch_size")
self.batch_type = kwargs.get("batch_type")
self.prompt_ids_len = 0
self.retry = kwargs.get("retry", 5)
self.permute = False
from funasr.frontends.whisper_frontend import WhisperFrontend
if isinstance(self.frontend, WhisperFrontend):
self.permute = True
self.pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
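# Source/target lengths below are used by the length-based batch sampler to sort and bucket samples.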
def get_source_len(self, index):
item = self.index_ds[index]
return self.index_ds.get_source_len(item)
def get_target_len(self, index):
item = self.index_ds[index]
return self.index_ds.get_target_len(item)
def __len__(self):
return len(self.index_ds)
def __getitem__(self, index):
# import pdb;
# pdb.set_trace()
output = None
for idx in range(self.retry):
if idx == 0:
index_cur = index
else:
index_cur = torch.randint(0, len(self.index_ds), ()).item()
item = self.index_ds[index_cur]
system = item["system"]
user = item["user"]
assistant = item["assistant"]
input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg = [], [], [], [], [], []
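# Build the chat-style prompt turn by turn; speech segments wrapped in
# <|startofspeech|>...<|endofspeech|> are replaced with placeholder tokens whose
# count matches the expected encoder output length, so the corresponding audio
# embeddings can be spliced into the LLM input later.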
for i, (system_prompt, user_prompt, target_out) in enumerate(
zip(system, user, assistant)
):
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
splits = self.pattern.split(source_input)
source_ids = []
fbank_mask_i = []
fbank_beg_i = []
fbank_lens_i = []
for k, sub_str in enumerate(splits):
if not sub_str.startswith("<|startofspeech|>"):
sub_token = self.tokenizer.encode(sub_str)
source_ids += sub_token
fbank_mask_i += [0] * len(sub_token)
else:
sub_str = sub_str.replace("<|startofspeech|>", "").replace(
"<|endofspeech|>", ""
)
if sub_str.startswith("!"):
data_src = load_audio_text_image_video(sub_str[1:], fs=self.fs)
speech, speech_lengths = extract_fbank(
data_src,
data_type=self.data_type,
frontend=self.frontend,
is_final=True,
) # speech: [b, T, d]
if self.permute:
speech = speech.permute(0, 2, 1)
if speech_lengths > self.batch_size:
continue
fbank_lens = speech_lengths[0].item()
# estimate the number of audio placeholder tokens: two stride-2 reductions
# (kernel 3, padding 1) followed by a further 2x downsampling, assumed to
# match the audio encoder and adaptor
olens = 1 + (fbank_lens - 3 + 2 * 1) // 2
olens = 1 + (olens - 3 + 2 * 1) // 2
sub_token_len = (olens - 1) // 2 + 1
sub_token = [0] * sub_token_len
fbank_beg_i = [len(source_ids)]
source_ids += sub_token
fbank_mask_i += [1] * len(sub_token)
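# prompt tokens are masked with -100 so the LM loss only covers the assistant response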
source_mask = [-100] * len(source_ids)
target_out = f"{target_out}<|im_end|>"
target_ids = self.tokenizer.encode(target_out)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
fbank_mask += fbank_mask_i
fbank_beg.append(fbank_beg_i)
input_ids = torch.tensor(input_ids, dtype=torch.int64)
attention_mask = torch.tensor([len(input_ids)], dtype=torch.int32)
labels = torch.tensor(labels, dtype=torch.int64)
fbank = speech[0, :, :]
fbank_lens = speech_lengths
fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
output = {
"speech": fbank,
"speech_lengths": fbank_lens,
"fbank_mask": fbank_mask,
"fbank_beg": fbank_beg,
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels_ids": labels,
}
break
return output
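# Collate variable-length samples into a batch: integer tensors are padded with
# int_pad_value and float tensors with float_pad_value.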
def collator(self, samples: list = None):
outputs = {}
for sample in samples:
if sample is None:
continue
for key in sample.keys():
if key not in outputs:
outputs[key] = []
outputs[key].append(sample[key])
for key, data_list in outputs.items():
if isinstance(data_list[0], torch.Tensor):
if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
pad_value = self.int_pad_value
else:
pad_value = self.float_pad_value
outputs[key] = torch.nn.utils.rnn.pad_sequence(
data_list, batch_first=True, padding_value=pad_value
)
if self.batch_type != "example":
for i in range(10):
outputs = self._filter_badcase(outputs, i=i)
return outputs
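# If the padded batch exceeds the frame budget (batch_size * 1.25), drop every
# other sample and re-trim the padding to the new maximum lengths.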
def _filter_badcase(self, outputs, i=0):
b, t, _ = outputs["speech"].shape
if b * t > self.batch_size * 1.25:
beg = torch.randint(0, 2, ()).item()
if b < 2:
beg = 0
logging.info(
f"Warning, b * t: {b * t} > {self.batch_size}, drop half data {i}th, beg:{beg}"
)
for key, data_list in outputs.items():
outputs[key] = outputs[key][beg : beg + b : 2]
speech_lengths_max = outputs["speech_lengths"].max().item()
outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
# trim the token-level fields emitted by this dataset to the new maximum token length
token_lengths_max = int(outputs["attention_mask"].max().item())
outputs["input_ids"] = outputs["input_ids"][:, :token_lengths_max]
outputs["labels_ids"] = outputs["labels_ids"][:, :token_lengths_max]
outputs["fbank_mask"] = outputs["fbank_mask"][:, :token_lengths_max]
return outputs

@@ -0,0 +1,95 @@
import os
import json
import torch
import logging
import librosa
import random
import torch.distributed as dist
from funasr.register import tables
@tables.register("index_ds_classes", "OpenAIIndexDSJsonl")
class OpenAIIndexDSJsonl(torch.utils.data.Dataset):
def __init__(self, path: str, **kwargs):
super().__init__()
self.max_source_length = kwargs.get("max_source_length", 2048)
self.min_source_length = kwargs.get("min_source_length", 0)
self.max_target_length = kwargs.get("max_target_length", 2048)
self.min_target_length = kwargs.get("min_target_length", 0)
self.max_token_length = kwargs.get("max_token_length", 2200)
is_training = kwargs.get("is_training", True)
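# A non-json(l) path is treated as a plain-text list of jsonl files; during training it can be
# split into data_split_num shards, with shard data_split_i selected here.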
if not (path.endswith(".jsonl") or path.endswith(".json")):
# jsonl list file
data_split_num = kwargs.get("data_split_num", 1)
data_split_i = kwargs.get("data_split_i", 0)
if not is_training:
data_split_num = 1
data_split_i = 0
with open(path, encoding="utf-8") as fin:
file_list_all = fin.readlines()
num_per_slice = (len(file_list_all) - 1) // data_split_num + 1  # ceil division
file_list = file_list_all[
data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice
]
logging.info(
f"is_training: {is_training}, data_split_num: {data_split_num}, data_split_i: {data_split_i}, \nfile_list: {file_list}, \nfile_list_all: {file_list_all}"
)
else:
file_list = [path]
contents = []
for file_json in file_list:
with open(file_json.strip(), encoding="utf-8") as fin:
for line in fin:
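# Each line is expected to follow the OpenAI chat format, e.g. (illustrative sample, not from the repo):
# {"messages": [{"role": "system", "content": "..."},
#               {"role": "user", "content": "<|startofspeech|>!/path/to/audio.wav<|endofspeech|>"},
#               {"role": "assistant", "content": "transcription text"}]}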
data = json.loads(line.strip())["messages"]
system, user, assistant = [], [], []
for i, item in enumerate(data):
role = item["role"]
content = item["content"]
if role == "system":
system.append(content)
elif role == "user":
user.append(content)
elif role == "assistant":
assistant.append(content)
system = system * len(user)
contents_i = {"system": system, "user": user, "assistant": assistant}
contents.append(contents_i)
self.contents = contents
logging.info("total_num of samplers: {}, {}".format(len(self.contents), path))
def __len__(self):
return len(self.contents)
def __getitem__(self, index):
data = self.contents[index]
return data
def get_source_len(self, data_dict):
return len(data_dict["system"]) + len(data_dict["user"])
def get_target_len(self, data_dict):
return len(data_dict["assistant"])
if __name__ == "__main__":
index_ds = OpenAIIndexDSJsonl(
path="/Users/zhifu/funasr1.0/test_local/data_tmp/tmp_wav_10.jsonl"
)
print(index_ds.contents)
pass

@@ -1,5 +1,6 @@
import logging
import re
import torch
import random
import traceback

@@ -341,3 +341,321 @@ class LLMASR(nn.Module):
ibest_writer["text"][key[0]] = text
return results, meta_data
@tables.register("model_classes", "LLMASR2")
class LLMASR2(nn.Module):
""" """
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
audio_encoder: str = None,
audio_encoder_conf: dict = None,
audio_adaptor: str = None,
audio_adaptor_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
llm: str = None,
llm_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
# extract_feats_in_collect_stats: bool = True,
share_embedding: bool = False,
# preencoder: Optional[AbsPreEncoder] = None,
# postencoder: Optional[AbsPostEncoder] = None,
**kwargs,
):
super().__init__()
if specaug is not None:
specaug_class = tables.specaug_classes.get(specaug)
specaug = specaug_class(**specaug_conf)
if normalize is not None:
normalize_class = tables.normalize_classes.get(normalize)
normalize = normalize_class(**normalize_conf)
# audio encoder
hub = audio_encoder_conf.get("hub", None)
if hub == "ms":
from funasr import AutoModel
model = AutoModel(model=audio_encoder, model_revision="master")
# frontend = model.kwargs.get("frontend")
audio_encoder_output_size = model.model.encoder_output_size
audio_encoder = model.model.model.encoder
# self.frontend = frontend
elif hub == "hf":
pass
else:
encoder_class = tables.encoder_classes.get(audio_encoder)
audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
audio_encoder_output_size = audio_encoder.output_size()
freeze = audio_encoder_conf.get("freeze", True)
if freeze:
for name, param in audio_encoder.named_parameters():
param.requires_grad = False
audio_encoder.eval()
self.audio_encoder = audio_encoder
# llm
hub = llm_conf.get("hub", "hf")
self.llm = None
# if hub == "hf":
# from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
#
# init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
#
# model = AutoModelForCausalLM.from_pretrained(
# init_param_path,
# load_in_8bit=None,
# device_map=None,
# use_cache=None,
# )
# freeze = llm_conf.get("freeze", True)
# if freeze:
# for name, param in model.named_parameters():
# param.requires_grad = False
# model.eval()
# self.llm = model
# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
audio_adaptor = adaptor_class(**audio_adaptor_conf)
self.audio_adaptor = audio_adaptor
self.blank_id = blank_id
self.sos = sos if sos is not None else vocab_size - 1
self.eos = eos if eos is not None else vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.specaug = specaug
self.normalize = normalize
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
self.error_calculator = None
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
fbank_beg: torch.Tensor,
fbank_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
encoder_out = self.audio_adaptor(encoder_out)
input_ids[input_ids == -1] = 0
input_ids[input_ids == -100] = 0
if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(input_ids)
elif hasattr(self.llm.model.model, "embed_tokens"):
inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
else:
inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
batch_size, token_num, dims = inputs_embeds.shape
_, l, _ = encoder_out.shape
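# splice the adapted audio embeddings over the placeholder span of each sample's text embeddings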
for batch_idx in range(batch_size):
fbank_beg_idx = fbank_beg[batch_idx, 0].item()
inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + l, :] = encoder_out[
batch_idx, :l, :
]
model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
)
loss = model_outputs.loss
stats = {}
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
stats["acc"] = acc_att
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
# normalize by the number of supervised label tokens (labels_ids uses -100 for ignored positions)
batch_size = int((labels_ids != -100).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
):
speech = speech.permute(0, 2, 1)
res = self.audio_encoder(speech)
if isinstance(res, (list, tuple)):
encoder_out, encoder_out_lens = res[0], res[1]
else:
encoder_out, encoder_out_lens = res, speech_lengths
return encoder_out, encoder_out_lens
def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
prompt = kwargs.get("prompt", "Transcribe speech to text.")
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int32)
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(
data_in,
fs=frontend.fs,
audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer,
)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(
audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = (
speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
)
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# adaptor
encoder_out = self.audio_adaptor(encoder_out)
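# Build the text prompt, embed it with the LLM's token embedding table, and append
# the adapted audio embeddings after it before generation.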
prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
prompt_ids = tokenizer.encode(prompt_pre)
prompt_length = len(prompt_ids)
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
elif hasattr(self.llm.model.model, "embed_tokens"):
inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
else:
inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
inputs_embeds = torch.cat(
(inputs_embeds[None, :, :], encoder_out), dim=1
) # [prompt, audio]
attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
kwargs["device"]
)
preds = self.llm.generate(
inputs_embeds=inputs_embeds,
max_length=kwargs.get("max_length", 200),
max_new_tokens=kwargs.get("max_new_tokens", 200),
num_beams=kwargs.get("num_beams", 4),
do_sample=kwargs.get("do_sample", False),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 1.0),
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
attention_mask=attention_mask,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
)
text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
text = text[0].split(": ")[-1]
text = text.strip()
# preds = torch.argmax(model_outputs.logits, -1)
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer[f"{0 + 1}best_recog"]
results = []
result_i = {"key": key[0], "text": text}
results.append(result_i)
if ibest_writer is not None:
ibest_writer["text"][key[0]] = text
return results, meta_data