mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
code update
This commit is contained in:
parent
3fcb5dcfed
commit
1233c0d3ff
@ -310,9 +310,6 @@ class AutoModel:
|
||||
logging.info("decoding, utt: {}, empty speech".format(key))
|
||||
continue
|
||||
|
||||
|
||||
# if kwargs["device"] == "cpu":
|
||||
# batch_size = 0
|
||||
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
|
||||
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
|
||||
|
||||
|
||||
@ -1,27 +1,29 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from io import BytesIO
|
||||
from collections.abc import Sequence
|
||||
import torch
|
||||
import hydra
|
||||
import logging
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
import torch.distributed as dist
|
||||
from collections.abc import Sequence
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.models.lora.utils import mark_only_lora_as_trainable
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
|
||||
from funasr.register import tables
|
||||
from funasr.optimizers import optim_classes
|
||||
from funasr.train_utils.trainer import Trainer
|
||||
from funasr.schedulers import scheduler_classes
|
||||
from funasr.train_utils.load_pretrained_model import load_pretrained_model
|
||||
from funasr.train_utils.initialize import initialize
|
||||
from funasr.download.download_from_hub import download_model
|
||||
from funasr.models.lora.utils import mark_only_lora_as_trainable
|
||||
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.train_utils.load_pretrained_model import load_pretrained_model
|
||||
# from funasr.tokenizer.build_tokenizer import build_tokenizer
|
||||
# from funasr.tokenizer.token_id_converter import TokenIDConverter
|
||||
# from funasr.tokenizer.funtoken import build_tokenizer
|
||||
from funasr.train_utils.trainer import Trainer
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from funasr.download.download_from_hub import download_model
|
||||
from funasr.register import tables
|
||||
|
||||
|
||||
@hydra.main(config_name=None, version_base=None)
|
||||
def main_hydra(kwargs: DictConfig):
|
||||
|
||||
@ -1,15 +1,8 @@
|
||||
import torch
|
||||
import json
|
||||
import torch.distributed as dist
|
||||
import numpy as np
|
||||
import kaldiio
|
||||
import librosa
|
||||
import torchaudio
|
||||
import time
|
||||
import logging
|
||||
|
||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||
from funasr.register import tables
|
||||
from funasr.utils.load_utils import extract_fbank
|
||||
|
||||
|
||||
@tables.register("dataset_classes", "AudioDataset")
|
||||
class AudioDataset(torch.utils.data.Dataset):
|
||||
@ -82,8 +75,6 @@ class AudioDataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
def collator(self, samples: list=None):
|
||||
|
||||
|
||||
outputs = {}
|
||||
for sample in samples:
|
||||
for key in sample.keys():
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import torch
|
||||
import json
|
||||
import torch.distributed as dist
|
||||
import time
|
||||
import torch
|
||||
import logging
|
||||
import torch.distributed as dist
|
||||
|
||||
from funasr.register import tables
|
||||
|
||||
|
||||
@tables.register("index_ds_classes", "IndexDSJsonl")
|
||||
class IndexDSJsonl(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
|
||||
from funasr.register import tables
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
import json
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
|
||||
from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf
|
||||
|
||||
|
||||
def download_model(**kwargs):
|
||||
model_hub = kwargs.get("model_hub", "ms")
|
||||
if model_hub == "ms":
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from funasr.utils.types import str2bool
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model-name', type=str, required=True)
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
import logging
|
||||
import torch.nn.functional as F
|
||||
from funasr.models.scama.utils import sequence_mask
|
||||
|
||||
from funasr.models.scama.utils import sequence_mask
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class overlap_chunk():
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import os
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import yaml
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
|
||||
if maxlen is None:
|
||||
|
||||
@ -1,15 +1,9 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Union
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from abc import abstractmethod
|
||||
from typing import Union, Iterable, List, Dict
|
||||
|
||||
|
||||
class AbsTokenizer(ABC):
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
import torch
|
||||
import os
|
||||
from funasr.train_utils.device_funcs import to_device
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from contextlib import nullcontext
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from funasr.train_utils.device_funcs import to_device
|
||||
from funasr.train_utils.recursive_op import recursive_average
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def slice_padding_fbank(speech, speech_lengths, vad_segments):
|
||||
speech_list = []
|
||||
speech_lengths_list = []
|
||||
@ -16,7 +17,6 @@ def slice_padding_fbank(speech, speech_lengths, vad_segments):
|
||||
speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
|
||||
return feats_pad, speech_lengths_pad
|
||||
|
||||
|
||||
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
|
||||
speech_list = []
|
||||
speech_lengths_list = []
|
||||
|
||||
Loading…
Reference in New Issue
Block a user