mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

fp16

This commit is contained in:
parent aaf61bbb64
commit 54b6ff5764
0 funasr/datasets/llm_datasets/__init__.py Normal file
130 funasr/datasets/llm_datasets/datasets.py Normal file
@ -0,0 +1,130 @@
import torch
import copy

from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video


@tables.register("dataset_classes", "AudioLLMDataset")
class AudioLLMDataset(torch.utils.data.Dataset):
    """
    AudioLLMDataset
    """

    def __init__(self,
                 path,
                 index_ds: str = None,
                 frontend=None,
                 tokenizer=None,
                 int_pad_value: int = -1,
                 float_pad_value: float = 0.0,
                 **kwargs):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
        self.preprocessor_text = preprocessor_text

        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer

        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value
        self.IGNORE_INDEX = kwargs.get("ignore_index", -100)  # label id ignored by the LLM loss
        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
        self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
            self.prompt)  # "USER: \nINSTRUCTION: {}\nINPUT: {}\nASSISTANT: "
        self.prompt_af = ""

    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        item = self.index_ds[index]
        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type,
                                               frontend=self.frontend, is_final=True)  # speech: [b, T, d]
        speech = speech.squeeze(0)

        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)

        prompt_ids_pre = self.tokenizer.encode(self.prompt_pre)  # [bos, prompt]
        prompt_pre_length = len(prompt_ids_pre)

        prompt_input = "{}{}".format(self.prompt_pre, target)
        prompt_input_ids = self.tokenizer.encode(prompt_input)
        audio_length = len(prompt_input_ids) - prompt_pre_length
        input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [bos, prompt, input, pad]
        input_ids[prompt_pre_length:] = -1  # [bos, prompt, -1, -1]
        attention_mask = input_ids.ge(-1)  # [True, True, True, True], length mask

        labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
        labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, input, eos]
        labels_ids[:prompt_pre_length] = -1  # [-1, -1, input, eos]
        label_mask = labels_ids.ge(0)  # [False, False, True, True]
        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100, -100, input, eos]

        audio_mask = [0] * prompt_pre_length + [1] * audio_length
        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)

        ids = self.tokenizer.encode(target)
        text = torch.tensor(ids, dtype=torch.int64)
        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)

        return {"speech": speech,
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels_ids": labels_ids,
                "label_mask": label_mask,
                "audio_mask": audio_mask,
                }

    def collator(self, samples: list = None):
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])

        for key, data_list in outputs.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value
                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True,
                                                               padding_value=pad_value)
        return outputs
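
For reference, a minimal standalone sketch of what collator() does to variable-length int64 fields (tensor values are made up; int_pad_value = -1 as in the defaults above):

import torch

samples = [
    {"input_ids": torch.tensor([101, 7, 42])},
    {"input_ids": torch.tensor([101, 9])},
]
# the same call the collator makes for each int64 field
padded = torch.nn.utils.rnn.pad_sequence(
    [s["input_ids"] for s in samples], batch_first=True, padding_value=-1
)
# padded -> tensor([[101,   7,  42],
#                   [101,   9,  -1]])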
51 funasr/datasets/llm_datasets/preprocessor.py Normal file
@ -0,0 +1,51 @@
import torch
import torchaudio
from torch import nn
import random
from typing import Collection

from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables


@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    def __init__(self, speed_perturb: list = None, **kwargs):
        super().__init__()
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        if self.speed_perturb is None:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            if not isinstance(waveform, torch.Tensor):
                waveform = torch.tensor(waveform)
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform.view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
            waveform = waveform.view(-1)

        return waveform


@tables.register("preprocessor_classes", "TextPreprocessSegDict")
class TextPreprocessSegDict(nn.Module):
    def __init__(self, seg_dict: str = None,
                 text_cleaner: Collection[str] = None,
                 split_with_space: bool = False,
                 **kwargs):
        super().__init__()

        self.text_cleaner = TextCleaner(text_cleaner)

    def forward(self, text, **kwargs):
        text = self.text_cleaner(text)

        return text
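
A quick usage sketch of the speed perturbation above (assumes torchaudio is built with sox support, which apply_effects_tensor requires; the input here is just random noise):

import torch

perturb = SpeechPreprocessSpeedPerturb(speed_perturb=[0.9, 1.0, 1.1])
wav = torch.randn(16000)       # 1 s of 16 kHz audio
out = perturb(wav, fs=16000)   # still 1-D; length is roughly 16000/speed after resampling back to fs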
277 funasr/datasets/llm_datasets/samplers.py Normal file
@ -0,0 +1,277 @@
import torch
import numpy as np
import torch.distributed as dist

from funasr.register import tables


@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):

    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = False,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):

        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

    def __len__(self):
        return (self.total_samples - 1) // self.batch_size + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)

    def __iter__(self):

        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        for iter_idx in range(self.pre_idx + 1, iter_num):
            # collect a local buffer of (index, length) pairs
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter_idx * self.buffer_size + i
                if idx >= self.total_samples:
                    continue

                idx_map = self.shuffle_idx[idx]
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                sample_len_cur = source_len + target_len

                datalen_with_index.append([idx, sample_len_cur])

            # sort the buffer by length, then pack batches greedily
            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if self.batch_type != 'example':
                    max_token_padding *= max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    yield batch
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1


@tables.register("batch_sampler_classes", "BatchSampler")
@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):

    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):

        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception:
            rank = 0
            world_size = 1
        self.rank = rank
        self.world_size = world_size

    def __len__(self):
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)

    def __iter__(self):

        batch_size_total = self.batch_size * self.world_size

        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        for iter_idx in range(self.pre_idx + 1, iter_num):
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter_idx * self.buffer_size + i
                if idx >= self.total_samples:
                    continue

                idx_map = self.shuffle_idx[idx]
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
                sample_len_cur = source_len + target_len

                datalen_with_index.append([idx, sample_len_cur])

            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if max_token_padding <= batch_size_total:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    # each rank takes its own slice of the world-size-wide batch
                    batch_rank = batch[self.rank * self.batch_size: (self.rank + 1) * self.batch_size]
                    yield batch_rank
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1


@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):

    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):

        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception:
            rank = 0
            world_size = 1
        self.rank = rank
        self.world_size = world_size

    def __len__(self):
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)

    def __iter__(self):

        batch_size_total = self.batch_size * self.world_size
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch_list_all_rank = []
        batch_list_cur = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        for iter_idx in range(self.pre_idx + 1, iter_num):
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter_idx * self.buffer_size + i
                if idx >= self.total_samples:
                    continue

                idx_map = self.shuffle_idx[idx]
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
                sample_len_cur = source_len + target_len

                datalen_with_index.append([idx, sample_len_cur])

            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for ii, item in enumerate(datalen_with_index_sort):
                is_last_batch = iter_idx == iter_num - 1 and ii == len(datalen_with_index_sort) - 1
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample

                if self.batch_type != 'example':
                    max_token_padding *= max_token_cur
                if len(batch_list_all_rank) < self.world_size:
                    if max_token_padding <= self.batch_size:
                        batch_list_cur.append(idx)
                        max_token = max_token_cur
                        num_sample += 1
                    else:
                        batch_list_all_rank.append(batch_list_cur)
                        batch_list_cur = []
                else:
                    batch_rank = batch_list_all_rank[self.rank]
                    yield batch_rank
                    batch_list_all_rank = []
                    batch_list_cur = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
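
To illustrate the packing rule in __iter__ when batch_type is "length": batch_size acts as a padded-token budget, and a sample joins the current batch only while (num_sample + 1) * max_len stays within it. A standalone toy run with made-up lengths:

lengths = [12, 15, 20, 90, 95]   # buffer already sorted by length
budget = 200                     # batch_size in "length" mode
batch, max_len = [], 0
for i, n in enumerate(lengths):
    if (len(batch) + 1) * max(max_len, n) <= budget:
        batch.append(i)
        max_len = max(max_len, n)
    else:
        print(batch)             # -> [0, 1, 2]; adding i=3 would cost 4 * 90 = 360
        batch, max_len = [i], n
print(batch)                     # -> [3, 4]; 2 * 95 = 190 still fits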
96 funasr/datasets/llm_datasets/scp2jsonl.py Normal file
@ -0,0 +1,96 @@
import os
import json
import hydra
from omegaconf import DictConfig, OmegaConf
import concurrent.futures
import librosa
import torch.distributed as dist


def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs):
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    except Exception:
        rank = 0
        world_size = 1

    cpu_cores = os.cpu_count() or 1
    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
    if rank == 0:
        json_dict = {}
        for data_type, data_file in zip(data_type_list, path):
            json_dict[data_type] = {}
            with open(data_file, "r") as f:
                data_file_lists = f.readlines()
                lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
                    futures = [executor.submit(parse_context_length,
                                               data_file_lists[i * lines_for_each_th:(i + 1) * lines_for_each_th],
                                               data_type)
                               for i in range(task_num)]
                    for future in concurrent.futures.as_completed(futures):
                        json_dict[data_type].update(future.result())

        with open(jsonl_file_out, "w") as f:
            for key in json_dict[data_type_list[0]].keys():
                jsonl_line = {"key": key}
                for data_type in data_type_list:
                    jsonl_line.update(json_dict[data_type][key])
                jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
                f.write(jsonl_line + "\n")
                f.flush()

    if world_size > 1:
        dist.barrier()


def parse_context_length(data_list: list, data_type: str):
    res = {}
    for i, line in enumerate(data_list):
        key, line = line.strip().split(maxsplit=1)
        line = line.strip()
        if os.path.exists(line):
            waveform, _ = librosa.load(line, sr=16000)
            sample_num = len(waveform)
            context_len = int(sample_num / 16000 * 1000 / 10)  # number of 10 ms frames
        else:
            context_len = len(line.split()) if " " in line else len(line)
        res[key] = {data_type: line, f"{data_type}_len": context_len}
    return res


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    kwargs = OmegaConf.to_container(cfg, resolve=True)

    scp_file_list = kwargs.get("scp_file_list",
                               ("/Users/zhifu/funasr1.0/test_local/wav.scp",
                                "/Users/zhifu/funasr1.0/test_local/text.txt"))
    if isinstance(scp_file_list, str):
        scp_file_list = eval(scp_file_list)
    data_type_list = kwargs.get("data_type_list", ("source", "target"))
    jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl")
    gen_jsonl_from_wav_text_list(scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out)


"""
python -m funasr.datasets.llm_datasets.scp2jsonl \
++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
++data_type_list='["source", "target"]' \
++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
"""

if __name__ == "__main__":
    main_hydra()
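
Each line of the resulting jsonl is one utterance keyed by id; with the default ("source", "target") types a line looks like this (hypothetical path, text, and lengths):

{"key": "utt1", "source": "/data/wav/utt1.wav", "source_len": 998, "target": "hello world", "target_len": 2}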
funasr/metrics/compute_acc.py
@ -21,3 +21,22 @@ def th_accuracy(pad_outputs, pad_targets, ignore_label):
    )
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return numerator.float() / denominator.float()  # (FIX:MZY): return torch.Tensor type
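
A worked example of compute_accuracy (standalone; -100 marks ignored positions, matching the dataset's IGNORE_INDEX):

import torch

pad_targets = torch.tensor([[5, 7, -100]])
pad_outputs = torch.tensor([[5, 9, 3]])
# 2 valid positions, 1 correct -> tensor(0.5000)
acc = compute_accuracy(pad_outputs, pad_targets, ignore_label=-100)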
0 funasr/models/llm_asr/__init__.py Normal file
29 funasr/models/llm_asr/adaptor.py Normal file
@ -0,0 +1,29 @@
import torch
import torch.nn as nn

from funasr.register import tables


@tables.register("adaptor_classes", "Linear")
class Linear(nn.Module):
    def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
        super().__init__()
        self.k = downsample_rate
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ffn_dim, self.llm_dim)

    def forward(self, x):
        batch_size, seq_len, dim = x.size()
        # drop trailing frames so the time axis divides evenly by k
        num_frames_to_discard = seq_len % self.k
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)

        # stack k consecutive frames, then project into the LLM embedding space
        x = x.contiguous()
        x = x.view(batch_size, seq_len // self.k, dim * self.k)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x
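
A shape-check sketch of the adaptor (sizes are illustrative; assumes the Linear class above is in scope):

import torch

adaptor = Linear(downsample_rate=2, encoder_dim=512, llm_dim=4096)
x = torch.randn(3, 101, 512)       # (batch, frames, encoder_dim); odd frame count on purpose
y = adaptor(x)
assert y.shape == (3, 50, 4096)    # 1 trailing frame dropped, time halved, projected to llm_dim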
353 funasr/models/llm_asr/model.py Normal file
@ -0,0 +1,353 @@
import logging
from typing import Union, Dict, List, Tuple, Optional

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy, compute_accuracy
# from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables


@tables.register("model_classes", "LLMASR")
class LLMASR(nn.Module):
    """ """

    def __init__(
            self,
            specaug: str = None,
            specaug_conf: dict = None,
            normalize: str = None,
            normalize_conf: dict = None,
            encoder: str = None,
            encoder_conf: dict = None,
            decoder: str = None,
            decoder_conf: dict = None,
            ctc: str = None,
            ctc_conf: dict = None,
            ctc_weight: float = 0.5,
            llm: str = None,
            llm_conf: dict = None,
            adaptor: str = None,
            adaptor_conf: dict = None,
            input_size: int = 80,
            vocab_size: int = -1,
            ignore_id: int = -1,
            blank_id: int = 0,
            sos: int = 1,
            eos: int = 2,
            lsm_weight: float = 0.0,
            length_normalized_loss: bool = False,
            report_cer: bool = True,
            report_wer: bool = True,
            sym_space: str = "<space>",
            sym_blank: str = "<blank>",
            # extract_feats_in_collect_stats: bool = True,
            share_embedding: bool = False,
            # preencoder: Optional[AbsPreEncoder] = None,
            # postencoder: Optional[AbsPostEncoder] = None,
            **kwargs,
    ):

        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)

        # audio encoder
        hub = encoder_conf.get("hub", None)
        if hub == "funasr":
            from funasr import AutoModel
            from funasr.models.scama.utils import sequence_mask
            init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
            model = AutoModel(model=init_param_path, model_revision="v2.0.4")
            frontend = model.kwargs.get("frontend")
            model.model.decoder = None

            self.model = model.model
            self.frontend = frontend
            self.mask_fn = sequence_mask

        elif hub == "hf":
            pass
        else:
            encoder_class = tables.encoder_classes.get(encoder)
            encoder = encoder_class(input_size=input_size, **encoder_conf)
            encoder_output_size = encoder.output_size()

        # llm
        hub = llm_conf.get("hub", "hf")
        self.llm = None
        if hub == "hf":
            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
            model = AutoModelForCausalLM.from_pretrained(
                init_param_path,
                load_in_8bit=None,
                device_map=None,
                use_cache=None,
            )
            freeze_llm = llm_conf.get("freeze_llm", True)
            if freeze_llm:
                for name, param in model.named_parameters():
                    param.requires_grad = False
                model.eval()
            self.llm = model

        # adaptor
        adaptor_class = tables.adaptor_classes.get(adaptor)
        adaptor = adaptor_class(**adaptor_conf)

        self.adaptor = adaptor

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.metric = kwargs.get("metric", True)  # whether to report token accuracy

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        self.error_calculator = None

        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            labels_ids: torch.Tensor,
            label_mask: torch.Tensor,
            audio_mask: torch.Tensor,
            **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)

        # adaptor
        encoder_out = self.adaptor(encoder_out)

        if input_ids is not None:
            input_ids[input_ids == -1] = 0
            if hasattr(self.llm.model, "embed_tokens"):
                inputs_embeds = self.llm.model.embed_tokens(input_ids)
            elif hasattr(self.llm.model.model, "embed_tokens"):
                inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
            else:
                inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)

        if audio_mask is not None:
            batch_size, token_num, dims = inputs_embeds.shape
            _, l, _ = encoder_out.shape
            # splice the audio embeddings into the token slots flagged by audio_mask
            encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num - l - 1, 1, 0, 0), value=0.0)
            inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0 - audio_mask[:, :, None])
            inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)

        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
        loss = model_outputs.loss

        acc_att = -1
        if self.metric:
            with torch.no_grad():
                preds = torch.argmax(model_outputs.logits, -1)
                acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)

        stats = {}
        # Collect Attn branch stats
        stats["acc"] = acc_att.detach() if torch.is_tensor(acc_att) else acc_att

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        audio_mask = kwargs.get("audio_mask")
        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None

        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_lens = self.model.encode(**batch)
        enc_mask = self.mask_fn(enc_lens, enc.size(1), device=enc.device)[:, None, :]
        pre_acoustic_embeds, pre_token_length, _, _ = self.model.predictor(enc,
                                                                           mask=enc_mask,
                                                                           target_label_length=audio_token_lengths,
                                                                           )

        return pre_acoustic_embeds, pre_token_length

    def _calc_att_loss(
            self,
            encoder_out: torch.Tensor,
            encoder_out_lens: torch.Tensor,
            ys_pad: torch.Tensor,
            ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=kwargs.get("data_type", "sound"),
                                                            tokenizer=tokenizer)
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # c. Pass the encoder result to beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
        )

        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, n, d = encoder_out.size()
        for i in range(b):

            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))

                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)

                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
                results.append(result_i)

                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text_postprocessed

        return results, meta_data
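
A toy illustration of the embedding splice in forward() (standalone sketch with tiny shapes; values are random): audio frames from the adaptor overwrite the token positions flagged by audio_mask, and text embeddings fill the rest:

import torch
import torch.nn.functional as F

B, T, D, L = 1, 6, 4, 3                    # batch, token slots, dims, audio frames
inputs_embeds = torch.randn(B, T, D)       # embeddings of [bos, prompt, -1, -1, -1, pad]
encoder_out = torch.randn(B, L, D)         # adaptor output
audio_mask = torch.tensor([[0., 0., 1., 1., 1., 0.]])

# left-pad the audio so it lines up with the masked slots, as in the model code
encoder_outs_pad = F.pad(encoder_out, (0, 0, T - L - 1, 1, 0, 0), value=0.0)
mixed = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0 - audio_mask[:, :, None])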
90 funasr/models/llm_asr/template.yaml Normal file
@ -0,0 +1,90 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.

# to print the register_table:
# from funasr.register import tables
# tables.print()

# network architecture
model: LLMASR
model_conf:
  lsm_weight: 0.1 # label smoothing option
  length_normalized_loss: true

# encoder
encoder: Paraformer
encoder_conf:
  hub: funasr
  init_param_path: "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"

llm: Vicuna
llm_conf:
  hub: hf
  init_param_path: null
  freeze_llm: true

adaptor: Linear
adaptor_conf:
  downsample_rate: 1
  llm_dim: 4096
  encoder_dim: 2048

# frontend related
frontend: WavFrontend
frontend_conf:
  fs: 16000
  window: hamming
  n_mels: 80
  frame_length: 25
  frame_shift: 10
  dither: 0.0
  lfr_m: 1
  lfr_n: 1

specaug: SpecAug
specaug_conf:
  apply_time_warp: true
  time_warp_window: 5
  time_warp_mode: bicubic
  apply_freq_mask: true
  freq_mask_width_range:
  - 0
  - 30
  num_freq_mask: 2
  apply_time_mask: true
  time_mask_width_range:
  - 0
  - 40
  num_time_mask: 2

train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 150
  keep_nbest_models: 10
  log_interval: 50

optim: adam
optim_conf:
  lr: 0.001
  weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
  warmup_steps: 35000

dataset: AudioLLMDataset
dataset_conf:
  index_ds: IndexDSJsonl
  batch_sampler: RankFullLocalShuffleBatchSampler
  batch_type: example # example or length
  batch_size: 4 # if batch_type is example, batch_size is the number of samples; if length, it is the per-batch budget of source_token_len + target_token_len
  max_token_length: 2048 # filter out samples whose source_token_len + target_token_len exceeds this
  buffer_size: 500
  shuffle: true
  num_workers: 4

tokenizer: HuggingfaceTokenizer
tokenizer_conf:
  unk_symbol: <unk>
  init_param_path: null
15 funasr/tokenizer/hf_tokenizer.py Normal file
@ -0,0 +1,15 @@
try:
    from transformers import AutoTokenizer
except ImportError:
    print("If you want to use the Huggingface tokenizer, please run `pip install -U transformers`")

from funasr.register import tables


@tables.register("tokenizer_classes", "HuggingfaceTokenizer")
def HuggingfaceTokenizer(init_param_path, **kwargs):
    tokenizer = AutoTokenizer.from_pretrained(init_param_path)
    return tokenizer
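
Usage sketch (the checkpoint name below is illustrative; any local directory or hub model that ships a tokenizer works):

tokenizer = HuggingfaceTokenizer("lmsys/vicuna-7b-v1.5")
ids = tokenizer.encode("USER: \nINSTRUCTION: Transcribe speech to text.\nINPUT: ")
text = tokenizer.decode(ids)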
funasr/train_utils/trainer.py
@ -5,7 +5,8 @@ import logging
from tqdm import tqdm
from datetime import datetime
import torch.distributed as dist
-from contextlib import nullcontext
+from torch.cuda.amp import autocast, GradScaler
+from contextlib import nullcontext, contextmanager
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from pathlib import Path
@ -14,6 +15,14 @@ from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints

+@contextmanager
+def maybe_autocast(enabled):
+    if enabled:
+        with autocast():
+            yield
+    else:
+        yield
+
class Trainer:
    """
    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
@ -36,8 +45,9 @@ class Trainer:
                 dataloader_train,
                 dataloader_val,
                 local_rank,
-                use_ddp=False,
-                use_fsdp=False,
+                use_ddp: bool = False,
+                use_fsdp: bool = False,
+                use_fp16: bool = False,
                 output_dir: str = "./",
                 **kwargs):
        """
@ -72,6 +82,9 @@ class Trainer:
        self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0
+        self.use_fp16 = use_fp16
+        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
+        self.scaler = GradScaler(enabled=use_fp16) if use_fp16 else None


        try:
@ -103,6 +116,8 @@ class Trainer:
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
+        if self.scaler:
+            state["scaler_state"] = self.scaler.state_dict()
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
@ -141,6 +156,8 @@ class Trainer:
            self.model.load_state_dict(dst_state)
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
+            if self.scaler and 'scaler_state' in checkpoint:
+                self.scaler.load_state_dict(checkpoint['scaler_state'])
            print(f"Checkpoint loaded successfully from '{ckpt}'")
        else:
            print(f"No checkpoint found at '{ckpt}', starting from scratch")
@ -221,9 +238,10 @@ class Trainer:
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                time2 = time.perf_counter()

-                retval = self.model(**batch)
-                torch.cuda.empty_cache()
+                with maybe_autocast(self.use_fp16):
+                    retval = self.model(**batch)
+
+                if self.disable_gpu_cache: torch.cuda.empty_cache()

                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
@ -241,7 +259,10 @@ class Trainer:
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
-                loss.backward()
+                if self.use_fp16:
+                    self.scaler.scale(loss).backward()
+                else:
+                    loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"

@ -264,10 +285,14 @@ class Trainer:
                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
-                self.optim.step()
+                if self.use_fp16:
+                    self.scaler.step(self.optim)
+                    self.scaler.update()
+                else:
+                    self.optim.step()
                self.scheduler.step()
                # Clear gradients for the next accumulation stage
-                self.optim.zero_grad()
+                self.optim.zero_grad(set_to_none=True)
                total_time = f"{time.perf_counter() - time5:0.3f}"
                time5 = time.perf_counter()
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"