mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
code update
This commit is contained in:
parent
3fcb5dcfed
commit
1233c0d3ff
@ -310,9 +310,6 @@ class AutoModel:
|
|||||||
logging.info("decoding, utt: {}, empty speech".format(key))
|
logging.info("decoding, utt: {}, empty speech".format(key))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
# if kwargs["device"] == "cpu":
|
|
||||||
# batch_size = 0
|
|
||||||
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
|
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
|
||||||
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
|
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
|
||||||
|
|
||||||
|
|||||||
@ -1,27 +1,29 @@
|
|||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from io import BytesIO
|
|
||||||
from collections.abc import Sequence
|
|
||||||
import torch
|
import torch
|
||||||
import hydra
|
import hydra
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
from io import BytesIO
|
||||||
|
import torch.distributed as dist
|
||||||
|
from collections.abc import Sequence
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from funasr.models.lora.utils import mark_only_lora_as_trainable
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
|
|
||||||
|
from funasr.register import tables
|
||||||
from funasr.optimizers import optim_classes
|
from funasr.optimizers import optim_classes
|
||||||
|
from funasr.train_utils.trainer import Trainer
|
||||||
from funasr.schedulers import scheduler_classes
|
from funasr.schedulers import scheduler_classes
|
||||||
from funasr.train_utils.load_pretrained_model import load_pretrained_model
|
|
||||||
from funasr.train_utils.initialize import initialize
|
from funasr.train_utils.initialize import initialize
|
||||||
|
from funasr.download.download_from_hub import download_model
|
||||||
|
from funasr.models.lora.utils import mark_only_lora_as_trainable
|
||||||
|
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
||||||
|
from funasr.train_utils.load_pretrained_model import load_pretrained_model
|
||||||
# from funasr.tokenizer.build_tokenizer import build_tokenizer
|
# from funasr.tokenizer.build_tokenizer import build_tokenizer
|
||||||
# from funasr.tokenizer.token_id_converter import TokenIDConverter
|
# from funasr.tokenizer.token_id_converter import TokenIDConverter
|
||||||
# from funasr.tokenizer.funtoken import build_tokenizer
|
# from funasr.tokenizer.funtoken import build_tokenizer
|
||||||
from funasr.train_utils.trainer import Trainer
|
|
||||||
import torch.distributed as dist
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
||||||
from funasr.download.download_from_hub import download_model
|
|
||||||
from funasr.register import tables
|
|
||||||
|
|
||||||
@hydra.main(config_name=None, version_base=None)
|
@hydra.main(config_name=None, version_base=None)
|
||||||
def main_hydra(kwargs: DictConfig):
|
def main_hydra(kwargs: DictConfig):
|
||||||
|
|||||||
@ -1,15 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import json
|
|
||||||
import torch.distributed as dist
|
|
||||||
import numpy as np
|
|
||||||
import kaldiio
|
|
||||||
import librosa
|
|
||||||
import torchaudio
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
|
||||||
from funasr.register import tables
|
from funasr.register import tables
|
||||||
|
from funasr.utils.load_utils import extract_fbank
|
||||||
|
|
||||||
|
|
||||||
@tables.register("dataset_classes", "AudioDataset")
|
@tables.register("dataset_classes", "AudioDataset")
|
||||||
class AudioDataset(torch.utils.data.Dataset):
|
class AudioDataset(torch.utils.data.Dataset):
|
||||||
@ -82,8 +75,6 @@ class AudioDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def collator(self, samples: list=None):
|
def collator(self, samples: list=None):
|
||||||
|
|
||||||
|
|
||||||
outputs = {}
|
outputs = {}
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
for key in sample.keys():
|
for key in sample.keys():
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
import torch
|
|
||||||
import json
|
import json
|
||||||
import torch.distributed as dist
|
import torch
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
from funasr.register import tables
|
from funasr.register import tables
|
||||||
|
|
||||||
|
|
||||||
@tables.register("index_ds_classes", "IndexDSJsonl")
|
@tables.register("index_ds_classes", "IndexDSJsonl")
|
||||||
class IndexDSJsonl(torch.utils.data.Dataset):
|
class IndexDSJsonl(torch.utils.data.Dataset):
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from funasr.register import tables
|
from funasr.register import tables
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
import json
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
import torch
|
|
||||||
from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf
|
from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf
|
||||||
|
|
||||||
|
|
||||||
def download_model(**kwargs):
|
def download_model(**kwargs):
|
||||||
model_hub = kwargs.get("model_hub", "ms")
|
model_hub = kwargs.get("model_hub", "ms")
|
||||||
if model_hub == "ms":
|
if model_hub == "ms":
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
from pathlib import Path
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from funasr.utils.types import str2bool
|
from funasr.utils.types import str2bool
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model-name', type=str, required=True)
|
parser.add_argument('--model-name', type=str, required=True)
|
||||||
|
|||||||
@ -1,12 +1,10 @@
|
|||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
|
||||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
|
||||||
import logging
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from funasr.models.scama.utils import sequence_mask
|
|
||||||
|
|
||||||
|
from funasr.models.scama.utils import sequence_mask
|
||||||
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
class overlap_chunk():
|
class overlap_chunk():
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
from torch.nn import functional as F
|
|
||||||
import yaml
|
import yaml
|
||||||
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
|
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
|
||||||
if maxlen is None:
|
if maxlen is None:
|
||||||
|
|||||||
@ -1,15 +1,9 @@
|
|||||||
from abc import ABC
|
|
||||||
from abc import abstractmethod
|
|
||||||
from typing import Iterable
|
|
||||||
from typing import List
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict
|
|
||||||
from typing import Iterable
|
|
||||||
from typing import List
|
|
||||||
from typing import Union
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from abc import ABC
|
||||||
|
from pathlib import Path
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Union, Iterable, List, Dict
|
||||||
|
|
||||||
|
|
||||||
class AbsTokenizer(ABC):
|
class AbsTokenizer(ABC):
|
||||||
|
|||||||
@ -1,13 +1,15 @@
|
|||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
from funasr.train_utils.device_funcs import to_device
|
|
||||||
import logging
|
|
||||||
import time
|
import time
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from contextlib import nullcontext
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
from funasr.train_utils.device_funcs import to_device
|
||||||
from funasr.train_utils.recursive_op import recursive_average
|
from funasr.train_utils.recursive_op import recursive_average
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
"""
|
"""
|
||||||
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
|
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
|
||||||
def slice_padding_fbank(speech, speech_lengths, vad_segments):
|
def slice_padding_fbank(speech, speech_lengths, vad_segments):
|
||||||
speech_list = []
|
speech_list = []
|
||||||
speech_lengths_list = []
|
speech_lengths_list = []
|
||||||
@ -16,7 +17,6 @@ def slice_padding_fbank(speech, speech_lengths, vad_segments):
|
|||||||
speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
|
speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
|
||||||
return feats_pad, speech_lengths_pad
|
return feats_pad, speech_lengths_pad
|
||||||
|
|
||||||
|
|
||||||
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
|
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
|
||||||
speech_list = []
|
speech_list = []
|
||||||
speech_lengths_list = []
|
speech_lengths_list = []
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user