code update

This commit is contained in:
shixian.shi 2024-01-15 20:34:47 +08:00
parent 3fcb5dcfed
commit 1233c0d3ff
24 changed files with 1391 additions and 1404 deletions

View File

@ -310,9 +310,6 @@ class AutoModel:
logging.info("decoding, utt: {}, empty speech".format(key)) logging.info("decoding, utt: {}, empty speech".format(key))
continue continue
# if kwargs["device"] == "cpu":
# batch_size = 0
if len(sorted_data) > 0 and len(sorted_data[0]) > 0: if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]) batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])

View File

@ -1,27 +1,29 @@
import argparse
import logging
import os import os
import sys import sys
from io import BytesIO
from collections.abc import Sequence
import torch import torch
import hydra import hydra
import logging
import argparse
from io import BytesIO
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from funasr.train_utils.set_all_random_seed import set_all_random_seed from torch.nn.parallel import DistributedDataParallel as DDP
from funasr.models.lora.utils import mark_only_lora_as_trainable from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.register import tables
from funasr.optimizers import optim_classes from funasr.optimizers import optim_classes
from funasr.train_utils.trainer import Trainer
from funasr.schedulers import scheduler_classes from funasr.schedulers import scheduler_classes
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.train_utils.initialize import initialize from funasr.train_utils.initialize import initialize
from funasr.download.download_from_hub import download_model
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
# from funasr.tokenizer.build_tokenizer import build_tokenizer # from funasr.tokenizer.build_tokenizer import build_tokenizer
# from funasr.tokenizer.token_id_converter import TokenIDConverter # from funasr.tokenizer.token_id_converter import TokenIDConverter
# from funasr.tokenizer.funtoken import build_tokenizer # from funasr.tokenizer.funtoken import build_tokenizer
from funasr.train_utils.trainer import Trainer
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.download.download_from_hub import download_model
from funasr.register import tables
@hydra.main(config_name=None, version_base=None) @hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig): def main_hydra(kwargs: DictConfig):

View File

@ -1,15 +1,8 @@
import torch import torch
import json
import torch.distributed as dist
import numpy as np
import kaldiio
import librosa
import torchaudio
import time
import logging
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables from funasr.register import tables
from funasr.utils.load_utils import extract_fbank
@tables.register("dataset_classes", "AudioDataset") @tables.register("dataset_classes", "AudioDataset")
class AudioDataset(torch.utils.data.Dataset): class AudioDataset(torch.utils.data.Dataset):
@ -82,8 +75,6 @@ class AudioDataset(torch.utils.data.Dataset):
def collator(self, samples: list=None): def collator(self, samples: list=None):
outputs = {} outputs = {}
for sample in samples: for sample in samples:
for key in sample.keys(): for key in sample.keys():

View File

@ -1,11 +1,11 @@
import torch
import json import json
import torch.distributed as dist import torch
import time
import logging import logging
import torch.distributed as dist
from funasr.register import tables from funasr.register import tables
@tables.register("index_ds_classes", "IndexDSJsonl") @tables.register("index_ds_classes", "IndexDSJsonl")
class IndexDSJsonl(torch.utils.data.Dataset): class IndexDSJsonl(torch.utils.data.Dataset):

View File

@ -1,5 +1,4 @@
import torch import torch
import numpy as np import numpy as np
from funasr.register import tables from funasr.register import tables

View File

@ -1,9 +1,10 @@
import json
import os import os
import json
from omegaconf import OmegaConf from omegaconf import OmegaConf
import torch
from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf
def download_model(**kwargs): def download_model(**kwargs):
model_hub = kwargs.get("model_hub", "ms") model_hub = kwargs.get("model_hub", "ms")
if model_hub == "ms": if model_hub == "ms":

View File

@ -1,8 +1,10 @@
from pathlib import Path
import os import os
import argparse import argparse
from pathlib import Path
from funasr.utils.types import str2bool from funasr.utils.types import str2bool
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model-name', type=str, required=True) parser.add_argument('--model-name', type=str, required=True)

View File

@ -1,12 +1,10 @@
import math
import torch import torch
import numpy as np import numpy as np
import math
from funasr.models.transformer.utils.nets_utils import make_pad_mask
import logging
import torch.nn.functional as F import torch.nn.functional as F
from funasr.models.scama.utils import sequence_mask
from funasr.models.scama.utils import sequence_mask
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class overlap_chunk(): class overlap_chunk():

View File

@ -1,8 +1,9 @@
import os import os
import torch
from torch.nn import functional as F
import yaml import yaml
import torch
import numpy as np import numpy as np
from torch.nn import functional as F
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None): def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
if maxlen is None: if maxlen is None:

View File

@ -1,15 +1,9 @@
from abc import ABC
from abc import abstractmethod
from typing import Iterable
from typing import List
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union
import json import json
import numpy as np import numpy as np
from abc import ABC
from pathlib import Path
from abc import abstractmethod
from typing import Union, Iterable, List, Dict
class AbsTokenizer(ABC): class AbsTokenizer(ABC):

View File

@ -1,13 +1,15 @@
import torch
import os import os
from funasr.train_utils.device_funcs import to_device
import logging
import time import time
import torch
import logging
from tqdm import tqdm from tqdm import tqdm
from contextlib import nullcontext
import torch.distributed as dist import torch.distributed as dist
from contextlib import nullcontext
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average from funasr.train_utils.recursive_op import recursive_average
class Trainer: class Trainer:
""" """
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch, A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,

View File

@ -1,6 +1,7 @@
import torch import torch
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
def slice_padding_fbank(speech, speech_lengths, vad_segments): def slice_padding_fbank(speech, speech_lengths, vad_segments):
speech_list = [] speech_list = []
speech_lengths_list = [] speech_lengths_list = []
@ -16,7 +17,6 @@ def slice_padding_fbank(speech, speech_lengths, vad_segments):
speech_lengths_pad = torch.Tensor(speech_lengths_list).int() speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
return feats_pad, speech_lengths_pad return feats_pad, speech_lengths_pad
def slice_padding_audio_samples(speech, speech_lengths, vad_segments): def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
speech_list = [] speech_list = []
speech_lengths_list = [] speech_lengths_list = []