mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
finetune
This commit is contained in:
parent
add1ac00f7
commit
1cbd2015f0
@ -171,7 +171,8 @@ class AutoModel:
|
||||
self.spk_kwargs = spk_kwargs
|
||||
self.model_path = kwargs.get("model_path")
|
||||
|
||||
def build_model(self, **kwargs):
|
||||
@staticmethod
|
||||
def build_model(**kwargs):
|
||||
assert "model" in kwargs
|
||||
if "model_conf" not in kwargs:
|
||||
logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
|
||||
|
||||
@ -1,9 +0,0 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Iterator
|
||||
|
||||
|
||||
class AbsIterFactory(ABC):
    """Abstract factory that builds a data iterator for a given epoch."""

    @abstractmethod
    def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
        # epoch: epoch index (implementations may use it to reseed shuffling).
        # shuffle: whether to shuffle; None leaves the choice to the implementation.
        raise NotImplementedError
|
||||
@ -1,109 +0,0 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import sentencepiece as spm
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from funasr.datasets.large_datasets.dataset import Dataset
|
||||
from funasr.datasets.large_datasets.abs_iter_factory import AbsIterFactory
|
||||
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||
|
||||
from funasr.register import tables
|
||||
|
||||
|
||||
def read_symbol_table(symbol_table_file):
    """Build a token -> index mapping.

    Accepts either a path to a vocab file (one token per line) or an
    already-materialized list of tokens. Indices follow input order; a
    duplicated token keeps its last index, as before.
    """
    if isinstance(symbol_table_file, str):
        with open(symbol_table_file, "r", encoding="utf8") as fin:
            return {line.strip(): idx for idx, line in enumerate(fin)}
    assert isinstance(symbol_table_file, list)
    return {token: idx for idx, token in enumerate(symbol_table_file)}
|
||||
|
||||
|
||||
def load_seg_dict(seg_dict_file):
    """Load a segmentation dictionary from a text file.

    Each line is "<key> <piece> <piece> ..."; the value stored is the pieces
    re-joined with single spaces.
    """
    assert isinstance(seg_dict_file, str)
    seg_dict = {}
    with open(seg_dict_file, "r", encoding="utf8") as f:
        for line in f:
            fields = line.strip().split()
            seg_dict[fields[0]] = " ".join(fields[1:])
    return seg_dict
|
||||
|
||||
|
||||
class SentencepiecesTokenizer(AbsTokenizer):
    """Tokenizer backed by a SentencePiece model, loaded lazily on first use."""

    def __init__(self, model: Union[Path, str]):
        # Only the model path is stored here; the processor itself is built on
        # demand (presumably so the object stays cheap to copy across dataloader
        # workers — confirm with callers).
        self.model = str(model)
        self.sp = None

    def __repr__(self):
        return f'{self.__class__.__name__}(model="{self.model}")'

    def _build_sentence_piece_processor(self):
        # Lazily construct and load the SentencePiece processor once.
        if self.sp is None:
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(self.model)

    def text2tokens(self, line: str) -> List[str]:
        """Split raw text into SentencePiece pieces."""
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsPieces(line)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join SentencePiece pieces back into text."""
        self._build_sentence_piece_processor()
        return self.sp.DecodePieces(list(tokens))
|
||||
|
||||
|
||||
@tables.register("dataset_classes", "LargeDataset")
class LargeDataLoader(AbsIterFactory):
    """Iter-factory wrapping the large-scale (scp-list) dataset pipeline."""

    def __init__(self, args, mode="train"):
        # Optional vocabularies/tokenizers; each is built only when configured.
        symbol_table, seg_dict, punc_dict, bpe_tokenizer = None, None, None, None
        if hasattr(args, "token_list") and args.token_list is not None:
            symbol_table = read_symbol_table(args.token_list)
        if hasattr(args, "seg_dict_file") and args.seg_dict_file is not None:
            seg_dict = load_seg_dict(args.seg_dict_file)
        if hasattr(args, "punc_list") and args.punc_list is not None:
            punc_dict = read_symbol_table(args.punc_list)
        if hasattr(args, "bpemodel") and args.bpemodel is not None:
            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
        self.dataset_conf = args.dataset_conf
        # NOTE(review): membership test ("in args") implies args supports
        # __contains__, while hasattr() is used above — confirm args type.
        if "frontend_conf" not in args:
            self.frontend_conf = None
        else:
            self.frontend_conf = args.frontend_conf
        self.speed_perturb = args.speed_perturb if hasattr(args, "speed_perturb") else None
        logging.info("dataloader config: {}".format(self.dataset_conf))
        batch_mode = self.dataset_conf.get("batch_mode", "padding")
        data_list = args.train_data_file if mode == "train" else args.valid_data_file
        self.dataset = Dataset(
            data_list,
            symbol_table,
            seg_dict,
            punc_dict,
            bpe_tokenizer,
            self.dataset_conf,
            self.frontend_conf,
            # speed perturbation is applied in training mode only
            speed_perturb=self.speed_perturb if mode == "train" else None,
            mode=mode,
            batch_mode=batch_mode,
        )

    def build_iter(self, epoch, shuffle=True):
        """Create a DataLoader for one epoch; epoch reseeds dataset shuffling."""
        self.dataset.set_epoch(epoch)
        data_loader = DataLoader(
            self.dataset,
            batch_size=None,  # the dataset already yields fully formed batches
            pin_memory=True,
            num_workers=self.dataset_conf.get("num_workers", 8),
        )
        return data_loader
|
||||
@ -1,194 +0,0 @@
|
||||
from typing import Collection
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from funasr.models.transformer.utils.nets_utils import pad_list, pad_list_all_dim
|
||||
|
||||
|
||||
class CommonCollateFn:
    """Functor wrapper around common_collate_fn().

    Holds the padding configuration and forwards each call to the
    module-level common_collate_fn.
    """

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size

    def __repr__(self):
        # Bug fix: previously echoed float_pad_value for both fields.
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        # Delegate to the module-level collate function with this instance's config.
        return common_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
|
||||
|
||||
|
||||
def common_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor.

    data is a collection of (utterance_id, {field: ndarray}) pairs; the result
    is (uttids, {field: padded tensor, field+"_lengths": pre-padding lengths}).
    """
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    # "*_lengths" keys are generated below, so inputs must not already use them.
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        # Integer arrays (e.g. token ids) get the int pad value; everything
        # else gets the float pad value.
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]
        tensor_list = [torch.from_numpy(a) for a in array_list]
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor

        # Record per-sample lengths along dim 0 unless the field is declared
        # a non-sequence (scalar-like) one.
        if key not in not_sequence:
            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    return output
|
||||
|
||||
|
||||
class DiarCollateFn:
    """Functor wrapper around diar_collate_fn() (diarization collate).

    Same configuration holder as CommonCollateFn, but delegates to
    diar_collate_fn, which pads every tensor dimension rather than only dim 0.
    (Fixed docstring: the original said common_collate_fn.)
    """

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size

    def __repr__(self):
        # Bug fix: previously echoed float_pad_value for both fields.
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        # Delegate to the module-level collate function with this instance's config.
        return diar_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
|
||||
|
||||
|
||||
def diar_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor.

    Like common_collate_fn, but uses pad_list_all_dim so every dimension is
    padded (presumably because diarization labels can differ in more than the
    time dimension — confirm against pad_list_all_dim).
    """
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    # "*_lengths" keys are generated below, so inputs must not already use them.
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        # Integer arrays get the int pad value; everything else the float one.
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]
        tensor_list = [torch.from_numpy(a) for a in array_list]
        # Pads across ALL dims (the only difference from common_collate_fn).
        tensor = pad_list_all_dim(tensor_list, pad_value)
        output[key] = tensor

        # Pre-padding lengths along dim 0 for sequence fields.
        if key not in not_sequence:
            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    return output
|
||||
|
||||
|
||||
def crop_to_max_size(feature, target_size):
    """Randomly crop a contiguous window of target_size from feature.

    If feature is already no longer than target_size it is returned as-is.
    The crop offset is drawn uniformly (np.random), so all windows are
    equally likely.
    """
    excess = len(feature) - target_size
    if excess <= 0:
        return feature
    offset = np.random.randint(0, excess + 1)
    return feature[offset : offset + target_size]
|
||||
|
||||
|
||||
def clipping_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    max_sample_size=None,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    # mainly for pre-training
    # Instead of padding up to the longest sample, every sample is randomly
    # cropped DOWN to the shortest one (optionally capped by max_sample_size).
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        array_list = [d[key] for d in data]
        tensor_list = [torch.from_numpy(a) for a in array_list]
        sizes = [len(s) for s in tensor_list]
        if max_sample_size is None:
            target_size = min(sizes)
        else:
            target_size = min(min(sizes), max_sample_size)
        # NOTE(review): shape[1] assumes each array is 2-D (time x feature) —
        # confirm callers never pass 1-D waveforms here.
        tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
        for i, (source, size) in enumerate(zip(tensor_list, sizes)):
            diff = size - target_size
            if diff == 0:
                tensor[i] = source
            else:
                tensor[i] = crop_to_max_size(source, target_size)
        output[key] = tensor

        if key not in not_sequence:
            # After cropping, every row has length target_size, so these
            # lengths are uniform by construction.
            lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    return output
|
||||
@ -1,213 +0,0 @@
|
||||
import random
|
||||
|
||||
from itertools import count
|
||||
from functools import partial
|
||||
from torch.utils.data import IterableDataset
|
||||
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
|
||||
|
||||
tiebreaker = count()
|
||||
|
||||
|
||||
def _default_len_fn(token):
|
||||
return len(token), next(tiebreaker)
|
||||
|
||||
|
||||
def _token_len_fn(token, len_fn):
|
||||
return len_fn(token), next(tiebreaker), token
|
||||
|
||||
|
||||
class MaxTokenBucketizerIterDataPipe(IterableDataset):
    """Batch samples under a token budget: batch_count * extreme-length <= batch_size.

    Items from the wrapped datapipe are turned into (length, tiebreak, token)
    triples (via _token_len_fn), so bucket.sort() orders by length and the
    unique tie-breaker prevents comparing the payloads.

    NOTE(review): the flush logic is intentionally repeated (full buffer /
    leftover buffer / leftover bucket) in each mode, and both `batch` and the
    running min/max length carry over between flushes — keep that in mind
    before de-duplicating.
    """

    def __init__(
        self,
        datapipe,
        batch_size=8000,
        len_fn=_default_len_fn,
        buffer_size=10240,
        sort_size=500,
        batch_mode="padding",
    ):
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
        assert sort_size > 0, "Sort size is required to be larger than 0!"

        # Wrap the source so each item becomes (length, tiebreak, token).
        datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
        self.datapipe = datapipe
        self.batch_size = batch_size  # token budget per batch
        self.buffer_size = buffer_size  # shuffle window; -1 = whole dataset, 0 = none
        self.sort_size = sort_size  # number of samples locally sorted together
        self.batch_mode = batch_mode  # "padding" (cost = max len) or "clipping" (cost = min len)

    def set_epoch(self, epoch):
        # Forward the epoch downstream (used for reshuffling at the source).
        self.datapipe.set_epoch(epoch)

    def __iter__(self):
        buffer = []
        batch = []
        bucket = []
        max_lengths = 0
        min_lengths = 999999
        batch_lengths = 0

        if self.batch_mode == "clipping":
            # Clipping mode: every sample will later be cropped to the batch
            # minimum, so the batch cost is min_length * batch_count.
            assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
            for d in self.datapipe:
                if d[0] > self.batch_size:
                    # A single sample exceeding the whole budget is dropped.
                    continue
                buffer.append(d)
                if len(buffer) == self.buffer_size:
                    random.shuffle(buffer)
                    for sample in buffer:
                        bucket.append(sample)
                        if len(bucket) == self.sort_size:
                            bucket.sort()
                            for x in bucket:
                                length, _, token = x
                                if length < min_lengths:
                                    min_lengths = length
                                batch_lengths = min_lengths * (len(batch) + 1)
                                if batch_lengths > self.batch_size:
                                    yield batch
                                    batch = []
                                    min_lengths = length
                                batch.append(token)
                            bucket = []
                    buffer = []

            # Flush the partially filled shuffle buffer the same way.
            if buffer:
                random.shuffle(buffer)
                for sample in buffer:
                    bucket.append(sample)
                    if len(bucket) == self.sort_size:
                        bucket.sort()
                        for x in bucket:
                            length, _, token = x
                            if length < min_lengths:
                                min_lengths = length
                            batch_lengths = min_lengths * (len(batch) + 1)
                            if batch_lengths > self.batch_size:
                                yield batch
                                batch = []
                                min_lengths = length
                            batch.append(token)
                        bucket = []
                buffer = []

            # Flush the partially filled sort bucket.
            if bucket:
                bucket.sort()
                for x in bucket:
                    length, _, token = x
                    if length < min_lengths:
                        min_lengths = length
                    batch_lengths = min_lengths * (len(batch) + 1)
                    if batch_lengths > self.batch_size:
                        yield batch
                        batch = []
                        min_lengths = length
                    batch.append(token)
                bucket = []

            if batch:
                yield batch

        else:
            # Padding mode: batch cost is max_length * batch_count, since every
            # sample will be padded up to the longest one.
            if self.buffer_size == -1:
                # Global sort: materialize everything, sort by length, group
                # into batches, then shuffle the batch order.
                for d in self.datapipe:
                    if d[0] > self.batch_size:
                        continue
                    buffer.append(d)
                buffer.sort()
                for sample in buffer:
                    length, _, token = sample
                    if length > max_lengths:
                        max_lengths = length
                    batch_lengths = max_lengths * (len(batch) + 1)
                    if batch_lengths > self.batch_size:
                        bucket.append(batch)
                        batch = []
                        max_lengths = length
                    batch.append(token)
                random.shuffle(bucket)
                if bucket:
                    for batch_sample in bucket:
                        yield batch_sample
                if batch:
                    yield batch

            elif self.buffer_size == 0:
                # Streaming: no shuffling, no sorting.
                for d in self.datapipe:
                    if d[0] > self.batch_size:
                        continue
                    length, _, token = d
                    if length > self.batch_size:
                        # NOTE(review): redundant with the d[0] check above.
                        continue
                    if length > max_lengths:
                        max_lengths = length
                    batch_lengths = max_lengths * (len(batch) + 1)
                    if batch_lengths > self.batch_size:
                        yield batch
                        batch = []
                        max_lengths = length
                    batch.append(token)
                if batch:
                    yield batch

            else:
                # Windowed shuffle + local sort (same shape as clipping mode,
                # but tracking the running MAX length).
                for d in self.datapipe:
                    if d[0] > self.batch_size:
                        continue
                    buffer.append(d)
                    if len(buffer) == self.buffer_size:
                        random.shuffle(buffer)
                        for sample in buffer:
                            bucket.append(sample)
                            if len(bucket) == self.sort_size:
                                bucket.sort()
                                for x in bucket:
                                    length, _, token = x
                                    if length > max_lengths:
                                        max_lengths = length
                                    batch_lengths = max_lengths * (len(batch) + 1)
                                    if batch_lengths > self.batch_size:
                                        yield batch
                                        batch = []
                                        max_lengths = length
                                    batch.append(token)
                                bucket = []
                        buffer = []

                # Flush the partially filled shuffle buffer.
                if buffer:
                    random.shuffle(buffer)
                    for sample in buffer:
                        bucket.append(sample)
                        if len(bucket) == self.sort_size:
                            bucket.sort()
                            for x in bucket:
                                length, _, token = x
                                if length > max_lengths:
                                    max_lengths = length
                                batch_lengths = max_lengths * (len(batch) + 1)
                                if batch_lengths > self.batch_size:
                                    yield batch
                                    batch = []
                                    max_lengths = length
                                batch.append(token)
                            bucket = []
                    buffer = []

                # Flush the partially filled sort bucket.
                if bucket:
                    bucket.sort()
                    for x in bucket:
                        length, _, token = x
                        if length > max_lengths:
                            max_lengths = length
                        batch_lengths = max_lengths * (len(batch) + 1)
                        if batch_lengths > self.batch_size:
                            yield batch
                            batch = []
                            max_lengths = length
                        batch.append(token)
                    bucket = []

                if batch:
                    yield batch
|
||||
@ -1,23 +0,0 @@
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
|
||||
def default_fn(data):
    """Identity function used when no predicate is supplied."""
    return data


class FilterIterDataPipe(IterableDataset):
    """Iterable datapipe yielding only the items for which fn(item) is truthy."""

    def __init__(self, datapipe, fn=default_fn):
        self.datapipe = datapipe
        self.fn = fn

    def set_epoch(self, epoch):
        # Forward the epoch to the wrapped datapipe (e.g. to reseed shuffling).
        self.datapipe.set_epoch(epoch)

    def __iter__(self):
        assert callable(self.fn)
        predicate = self.fn
        yield from (item for item in self.datapipe if predicate(item))
|
||||
@ -1,20 +0,0 @@
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
|
||||
def default_fn(data):
    """Identity transform used when no mapping function is supplied."""
    return data


class MapperIterDataPipe(IterableDataset):
    """Iterable datapipe that applies fn to every item of the wrapped datapipe."""

    def __init__(self, datapipe, fn=default_fn):
        self.datapipe = datapipe
        self.fn = fn

    def set_epoch(self, epoch):
        # Forward the epoch to the wrapped datapipe (e.g. to reseed shuffling).
        self.datapipe.set_epoch(epoch)

    def __iter__(self):
        assert callable(self.fn)
        transform = self.fn
        for item in self.datapipe:
            yield transform(item)
|
||||
@ -1,299 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
|
||||
# import librosa
|
||||
import librosa
|
||||
from kaldiio import ReadHelper
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
|
||||
from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
|
||||
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
|
||||
from funasr.datasets.large_datasets.utils.clipping import clipping
|
||||
from funasr.datasets.large_datasets.utils.filter import filter
|
||||
from funasr.datasets.large_datasets.utils.padding import padding
|
||||
from funasr.datasets.large_datasets.utils.tokenize import tokenize
|
||||
|
||||
|
||||
def read_lists(list_file):
    """Read a text file into a list of whitespace-stripped lines."""
    with open(list_file, "r", encoding="utf8") as fin:
        return [line.strip() for line in fin]
|
||||
|
||||
|
||||
class AudioDataset(IterableDataset):
    """Iterable dataset reading parallel scp-sharded data (kaldi ark / wav / text).

    Each entry of scp_lists is one whitespace-separated line of file paths, read
    in lockstep according to the comma-separated data_names / data_types.
    Sharding is two-level: by distributed rank, then by dataloader worker.
    """

    def __init__(
        self,
        scp_lists,
        data_names,
        data_types,
        frontend_conf=None,
        shuffle=True,
        speed_perturb=None,
        mode="train",
    ):
        self.scp_lists = scp_lists
        self.data_names = data_names  # e.g. "speech,text" (comma separated)
        self.data_types = data_types  # e.g. "kaldi_ark,text" (comma separated)
        self.frontend_conf = frontend_conf  # used for its "fs" (target sample rate)
        self.shuffle = shuffle
        self.mode = mode
        self.epoch = -1  # set via set_epoch; seeds the shard shuffle
        self.rank = 0
        self.world_size = 1
        self.worker_id = 0
        self.num_workers = 1
        self.speed_perturb = speed_perturb  # e.g. list of speed factors, or None
        if self.speed_perturb is not None:
            logging.info("Using speed_perturb: {}".format(speed_perturb))

    def set_epoch(self, epoch):
        # Stored so the per-epoch shuffle of shards is deterministic.
        self.epoch = epoch

    def get_rank_data_list(self, data_index):
        """Shard (and, in training, shuffle) the shard indices across dist ranks."""
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1

        if self.mode == "train":
            if self.shuffle:
                # Same seed on every rank => identical permutation before slicing.
                random.seed(self.epoch)
                random.shuffle(data_index)
            return data_index[self.rank :: self.world_size]

        return data_index

    def get_worker_data_list(self, rank_data_index):
        """Further shard this rank's indices across dataloader workers."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers

        return rank_data_index[self.worker_id :: self.num_workers]

    def close_reader(self, reader_list):
        # Close every open ark/text reader for the current shard.
        for reader in reader_list:
            reader.close()

    def __iter__(self):
        data_index = list(range(len(self.scp_lists)))
        rank_data_index = self.get_rank_data_list(data_index)
        worker_data_index = self.get_worker_data_list(rank_data_index)

        for index in worker_data_index:
            data = dict(scp=self.scp_lists[index])

            assert "scp" in data
            scp = data["scp"]
            # One scp line lists the parallel files for this shard.
            data_file_list = scp.strip().split()
            data_name_list = self.data_names.split(",")
            data_type_list = self.data_types.split(",")

            for file in data_file_list:
                assert os.path.exists(file), "{} not exists".format(file)

            assert (
                len(data_file_list) == len(data_name_list) == len(data_type_list)
            ), "The item number of data, data_names, data_types must be the same "

            # Open one reader per file; "none" entries are skipped entirely.
            reader_list = []
            for data_file, data_type in zip(data_file_list, data_type_list):
                if data_type == "kaldi_ark":
                    ark_reader = ReadHelper("ark:{}".format(data_file))
                    reader_list.append(ark_reader)
                elif data_type == "text" or data_type == "sound" or data_type == "text_hotword":
                    text_reader = open(data_file, "r", encoding="utf-8")
                    reader_list.append(text_reader)
                elif data_type == "none":
                    continue
                else:
                    raise TypeError("Data type {} is not supported".format(data_type))

            # Read all parallel files in lockstep, one sample per iteration.
            for items in zip(*reader_list):
                sample_dict = {}
                for item, (data_name, data_type) in zip(items, zip(data_name_list, data_type_list)):
                    if data_type == "kaldi_ark":
                        key, mat = item
                        sample_dict[data_name] = mat
                        if data_name == "speech":
                            sample_dict["key"] = key
                    elif data_type == "sound":
                        # Each text line is "<key> <wav path>".
                        key, path = item.strip().split()
                        try:
                            waveform, sampling_rate = torchaudio.load(path)
                        # NOTE(review): bare except — any torchaudio failure
                        # (not just unsupported formats) falls back to librosa.
                        except:
                            # waveform, sampling_rate = librosa.load(path, dtype='float32')
                            waveform, sampling_rate = librosa.load(path, dtype="float32")
                            # NOTE(review): nesting reconstructed — the mono fix
                            # and (1, T) reshape are assumed to belong to this
                            # librosa fallback, matching torchaudio's layout.
                            if waveform.ndim == 2:
                                waveform = waveform[:, 0]
                            waveform = np.expand_dims(waveform, axis=0)
                            waveform = torch.tensor(waveform)
                        if self.frontend_conf is not None:
                            # Resample to the frontend's expected rate if needed.
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(
                                    orig_freq=sampling_rate, new_freq=self.frontend_conf["fs"]
                                )(waveform)
                                sampling_rate = self.frontend_conf["fs"]
                        waveform = waveform.numpy()
                        mat = waveform[0]
                        if self.speed_perturb is not None:
                            # Randomly pick a speed factor; 1.0 means no change.
                            speed = random.choice(self.speed_perturb)
                            if speed != 1.0:
                                mat, _ = torchaudio.sox_effects.apply_effects_tensor(
                                    torch.tensor(mat).view(1, -1),
                                    sampling_rate,
                                    [["speed", str(speed)], ["rate", str(sampling_rate)]],
                                )
                                mat = mat.view(-1).numpy()
                        sample_dict[data_name] = mat
                        sample_dict["sampling_rate"] = sampling_rate
                        if data_name == "speech":
                            sample_dict["key"] = key
                    elif data_type == "text_hotword":
                        # "<key> <tok> <tok> ..." plus a marker for hotword data.
                        text = item
                        segs = text.strip().split()
                        sample_dict[data_name] = segs[1:]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                        sample_dict["hw_tag"] = 1
                    elif data_type == "text_nospace":
                        # "<key> <chars...>": split key once, then per-character.
                        text = item
                        segs = text.strip().split(maxsplit=1)
                        sample_dict[data_name] = [x for x in segs[1]]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                    else:
                        # Plain "text": "<key> <tok> <tok> ...".
                        text = item
                        segs = text.strip().split()
                        sample_dict[data_name] = segs[1:]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                yield sample_dict

            self.close_reader(reader_list)
|
||||
|
||||
|
||||
def len_fn_example(data):
    """Constant length: every sample counts as one (batch_type == "example")."""
    return 1
|
||||
|
||||
|
||||
def len_fn_token(data):
    """Sample length for token batching: milliseconds when the sampling rate
    is known, otherwise the raw frame count of the speech array."""
    assert "speech" in data
    speech = data["speech"]
    if "sampling_rate" not in data:
        return speech.shape[0]
    return (speech.shape[0] / data["sampling_rate"]) * 1000.0
|
||||
|
||||
|
||||
def Dataset(
    data_list_file,
    dict,
    seg_dict,
    punc_dict,
    bpe_tokenizer,
    conf,
    frontend_conf,
    speed_perturb=None,
    mode="train",
    batch_mode="padding",
):
    """Compose the large-data pipeline:
    AudioDataset -> tokenize -> filter -> token-budget batching -> pad/clip.

    NOTE(review): the parameter name `dict` shadows the builtin; it is the
    token->index vocabulary (see read_symbol_table).
    """
    scp_lists = read_lists(data_list_file)
    shuffle = conf.get("shuffle", True)
    data_names = conf.get("data_names", "speech,text")
    data_types = conf.get("data_types", "kaldi_ark,text")

    # Optional pre-defined hotword list used by hotword sampling.
    pre_hwfile = conf.get("pre_hwlist", None)
    # pre_prob = conf.get("pre_prob", 0) # unused yet
    if pre_hwfile is not None:
        pre_hwlist = []
        with open(pre_hwfile, "r", encoding="utf-8") as fin:
            for line in fin.readlines():
                pre_hwlist.append(line.strip())
    else:
        pre_hwlist = None

    # Hotword sampling configuration forwarded to the tokenizer step.
    hw_config = {
        "sample_rate": conf.get("sample_rate", 0.6),
        "double_rate": conf.get("double_rate", 0.1),
        "hotword_min_length": conf.get("hotword_min_length", 2),
        "hotword_max_length": conf.get("hotword_max_length", 8),
        "pre_prob": conf.get("pre_prob", 0.0),
        "pre_hwlist": pre_hwlist,
    }

    dataset = AudioDataset(
        scp_lists,
        data_names,
        data_types,
        frontend_conf=frontend_conf,
        shuffle=shuffle,
        speed_perturb=speed_perturb,
        mode=mode,
    )

    # Tokenization only applies when a text stream is configured.
    if "text" in data_names:
        vocab = {
            "vocab": dict,
            "seg_dict": seg_dict,
            "punc_dict": punc_dict,
            "bpe_tokenizer": bpe_tokenizer,
            "hw_config": hw_config,
        }
        tokenize_fn = partial(tokenize, **vocab)
        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)

    # Drop samples whose speech/text lengths fall outside the configured range.
    filter_conf = conf.get("filter_conf", {})
    filter_fn = partial(filter, **filter_conf)
    dataset = FilterIterDataPipe(dataset, fn=filter_fn)

    # Shuffle window / local-sort sizes; disabled (0 / 1) when not shuffling.
    if shuffle:
        buffer_conf = conf.get("shuffle_conf", {})
        buffer_size = buffer_conf["shuffle_size"]
        sort_size = buffer_conf["sort_size"]
    else:
        buffer_size = 0
        sort_size = 1

    batch_conf = conf.get("batch_conf", {})
    batch_size = batch_conf["batch_size"]
    batch_type = batch_conf["batch_type"]

    # "example" counts samples; "token" uses per-sample length (ms or frames).
    assert batch_type in ["example", "token"]
    if batch_type == "example":
        len_fn = len_fn_example
    else:
        len_fn = len_fn_token

    dataset = MaxTokenBucketizerIterDataPipe(
        dataset,
        batch_size=batch_size,
        len_fn=len_fn,
        buffer_size=buffer_size,
        sort_size=sort_size,
        batch_mode=batch_mode,
    )

    # Final collation: pad batches up, or clip them down, to a common length.
    int_pad_value = conf.get("int_pad_value", -1)
    float_pad_value = conf.get("float_pad_value", 0.0)
    padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
    padding_fn = partial(padding, **padding_conf)
    dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)

    return dataset
|
||||
@ -1,44 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from funasr.datasets.large_datasets.collate_fn import crop_to_max_size
|
||||
|
||||
|
||||
def clipping(data):
    """Collate a list of sample dicts by cropping every field to the batch minimum.

    Returns (keys, batch): keys is the list of sample "key" values; batch maps
    each field name to a stacked tensor of shape (B, min_len, D) plus a
    "<name>_lengths" int64 tensor (uniform by construction after cropping).
    """
    assert isinstance(data, list)
    assert "key" in data[0]

    keys = [sample["key"] for sample in data]

    batch = {}
    for name in data[0].keys():
        if name == "key":
            continue

        # Integer features become int64 tensors, everything else float32.
        dtype = torch.int64 if data[0][name].dtype.kind == "i" else torch.float32
        tensors = [torch.tensor(np.copy(sample[name]), dtype=dtype) for sample in data]
        lengths = torch.tensor([len(sample[name]) for sample in data], dtype=torch.int32)

        target_len = min(lengths)
        stacked = tensors[0].new_zeros(len(tensors), target_len, tensors[0].shape[1])
        for row, (tensor, length) in enumerate(zip(tensors, lengths)):
            surplus = length - target_len
            assert surplus >= 0
            # Copy directly when already at target length; otherwise random-crop.
            stacked[row] = tensor if surplus == 0 else crop_to_max_size(tensor, target_len)

        batch[name] = stacked
        batch[name + "_lengths"] = torch.tensor(
            [row.shape[0] for row in stacked], dtype=torch.long
        )

    return keys, batch
|
||||
@ -1,27 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
|
||||
def filter(
    data, speech_length_min=100, speech_length_max=15000, token_length_min=0, token_length_max=200
):
    """Return True iff the sample's speech duration and/or token count are in range.

    Speech length is measured in milliseconds when "sampling_rate" is present,
    otherwise in raw frames. Whichever of "speech"/"text" is present is checked;
    a missing stream imposes no constraint.
    (Name note: intentionally shadows the builtin — kept for callers.)
    """
    assert "speech" in data or "text" in data

    has_speech = "speech" in data
    has_text = "text" in data

    speech_ok = True
    if has_speech:
        if "sampling_rate" in data:
            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.0
        else:
            speech_length = data["speech"].shape[0]
        speech_ok = speech_length_min < speech_length < speech_length_max

    text_ok = True
    if has_text:
        text_ok = token_length_min < len(data["text"]) < token_length_max

    return speech_ok and text_ok
|
||||
@ -1,42 +0,0 @@
|
||||
import random
|
||||
|
||||
|
||||
def sample_hotword(
    length,
    hotword_min_length,
    hotword_max_length,
    sample_rate,
    double_rate,
    pre_prob,
    pre_index=None,
    pre_hwlist=None,
):
    """Randomly pick hotword span(s) inside a transcript of `length` tokens.

    Returns [-1] when no hotword is sampled, [start, end] (inclusive) for a
    single span, [s1, e1, s2, e2] for two non-overlapping spans, or the given
    pre_index verbatim with probability pre_prob.
    """
    if length < hotword_min_length:
        return [-1]  # sentence too short to contain any hotword
    if random.random() >= sample_rate:
        return [-1]  # this sentence gets no hotword

    # Optionally reuse a pre-sampled hotword index.
    if pre_prob > 0 and random.random() < pre_prob and pre_index is not None:
        return pre_index

    if length == hotword_min_length:
        # The whole sentence is exactly one minimal hotword.
        return [0, length - 1]

    if random.random() < double_rate and length > hotword_max_length + hotword_min_length + 2:
        # Sample two hotwords in one sentence.
        max_span = min(hotword_max_length, length // 2)
        first_start = random.randint(0, length // 3)
        first_end = random.randint(
            first_start + hotword_min_length - 1, first_start + max_span - 1
        )
        second_start = random.randint(first_end + 1, length - hotword_min_length)
        second_end = random.randint(
            min(length - 1, second_start + hotword_min_length - 1),
            min(length - 1, second_start + hotword_max_length - 1),
        )
        return [first_start, first_end, second_start, second_end]

    # Single hotword.
    start = random.randint(0, length - hotword_min_length)
    end = random.randint(
        min(length - 1, start + hotword_min_length - 1),
        min(length - 1, start + hotword_max_length - 1),
    )
    return [start, end]
|
||||
@ -1,30 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def build_LFR_features(data, m, n):
|
||||
"""
|
||||
Actually, this implements stacking frames and skipping frames.
|
||||
if m = 1 and n = 1, just return the origin features.
|
||||
if m = 1 and n > 1, it works like skipping.
|
||||
if m > 1 and n = 1, it works like stacking but only support right frames.
|
||||
if m > 1 and n > 1, it works like LFR.
|
||||
|
||||
Args:
|
||||
inputs_batch: inputs is T x D np.ndarray
|
||||
m: number of frames to stack
|
||||
n: number of frames to skip
|
||||
"""
|
||||
|
||||
LFR_inputs = []
|
||||
T = data.shape[0]
|
||||
T_lfr = int(np.ceil(T / n))
|
||||
for i in range(T_lfr):
|
||||
if m <= T - i * n:
|
||||
LFR_inputs.append(np.hstack(data[i * n : i * n + m]))
|
||||
else:
|
||||
num_padding = m - (T - i * n)
|
||||
frame = np.hstack(data[i * n :])
|
||||
for _ in range(num_padding):
|
||||
frame = np.hstack((frame, data[-1]))
|
||||
LFR_inputs.append(frame)
|
||||
return np.vstack(LFR_inputs)
|
||||
@ -1,72 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def padding(data, float_pad_value=0.0, int_pad_value=-1):
|
||||
assert isinstance(data, list)
|
||||
assert "key" in data[0]
|
||||
assert "speech" in data[0] or "text" in data[0]
|
||||
|
||||
keys = [x["key"] for x in data]
|
||||
|
||||
batch = {}
|
||||
data_names = data[0].keys()
|
||||
for data_name in data_names:
|
||||
if data_name == "key" or data_name == "sampling_rate":
|
||||
continue
|
||||
else:
|
||||
if data_name != "hotword_indxs":
|
||||
if data[0][data_name].dtype.kind == "i":
|
||||
pad_value = int_pad_value
|
||||
tensor_type = torch.int64
|
||||
else:
|
||||
pad_value = float_pad_value
|
||||
tensor_type = torch.float32
|
||||
|
||||
tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
|
||||
tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
|
||||
tensor_pad = pad_sequence(tensor_list, batch_first=True, padding_value=pad_value)
|
||||
batch[data_name] = tensor_pad
|
||||
batch[data_name + "_lengths"] = tensor_lengths
|
||||
|
||||
# SAC LABEL INCLUDE
|
||||
if "hotword_indxs" in batch:
|
||||
# if hotword indxs in batch
|
||||
# use it to slice hotwords out
|
||||
hotword_list = []
|
||||
hotword_lengths = []
|
||||
text = batch["text"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
hotword_indxs = batch["hotword_indxs"]
|
||||
dha_pad = torch.ones_like(text) * -1
|
||||
_, t1 = text.shape
|
||||
t1 += 1 # TODO: as parameter which is same as predictor_bias
|
||||
nth_hw = 0
|
||||
for b, (hotword_indx, one_text, length) in enumerate(
|
||||
zip(hotword_indxs, text, text_lengths)
|
||||
):
|
||||
dha_pad[b][:length] = 8405
|
||||
if hotword_indx[0] != -1:
|
||||
start, end = int(hotword_indx[0]), int(hotword_indx[1])
|
||||
hotword = one_text[start : end + 1]
|
||||
hotword_list.append(hotword)
|
||||
hotword_lengths.append(end - start + 1)
|
||||
dha_pad[b][start : end + 1] = one_text[start : end + 1]
|
||||
nth_hw += 1
|
||||
if len(hotword_indx) == 4 and hotword_indx[2] != -1:
|
||||
# the second hotword if exist
|
||||
start, end = int(hotword_indx[2]), int(hotword_indx[3])
|
||||
hotword_list.append(one_text[start : end + 1])
|
||||
hotword_lengths.append(end - start + 1)
|
||||
dha_pad[b][start : end + 1] = one_text[start : end + 1]
|
||||
nth_hw += 1
|
||||
hotword_list.append(torch.tensor([1]))
|
||||
hotword_lengths.append(1)
|
||||
hotword_pad = pad_sequence(hotword_list, batch_first=True, padding_value=0)
|
||||
batch["hotword_pad"] = hotword_pad
|
||||
batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
|
||||
batch["dha_pad"] = dha_pad
|
||||
del batch["hotword_indxs"]
|
||||
del batch["hotword_indxs_lengths"]
|
||||
return keys, batch
|
||||
@ -1,93 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
import re
|
||||
import numpy as np
|
||||
from funasr.datasets.large_datasets.utils.hotword_utils import sample_hotword
|
||||
|
||||
|
||||
def forward_segment(text, seg_dict):
|
||||
word_list = []
|
||||
i = 0
|
||||
while i < len(text):
|
||||
longest_word = text[i]
|
||||
for j in range(i + 1, len(text) + 1):
|
||||
word = text[i:j]
|
||||
if word in seg_dict:
|
||||
if len(word) > len(longest_word):
|
||||
longest_word = word
|
||||
word_list.append(longest_word)
|
||||
i += len(longest_word)
|
||||
return word_list
|
||||
|
||||
|
||||
def seg_tokenize(txt, seg_dict):
|
||||
pattern = re.compile(r"^[\u4E00-\u9FA50-9]+$")
|
||||
out_txt = ""
|
||||
for word in txt:
|
||||
word = word.lower()
|
||||
if word in seg_dict:
|
||||
out_txt += seg_dict[word] + " "
|
||||
else:
|
||||
if pattern.match(word):
|
||||
for char in word:
|
||||
if char in seg_dict:
|
||||
out_txt += seg_dict[char] + " "
|
||||
else:
|
||||
out_txt += "<unk>" + " "
|
||||
else:
|
||||
out_txt += "<unk>" + " "
|
||||
return out_txt.strip().split()
|
||||
|
||||
|
||||
def tokenize(data, vocab=None, seg_dict=None, punc_dict=None, bpe_tokenizer=None, hw_config=None):
|
||||
assert "text" in data
|
||||
assert isinstance(vocab, dict)
|
||||
text = data["text"]
|
||||
token = []
|
||||
vad = -2
|
||||
if bpe_tokenizer is not None:
|
||||
text = bpe_tokenizer.text2tokens(" ".join(text))
|
||||
if seg_dict is not None:
|
||||
assert isinstance(seg_dict, dict)
|
||||
text = seg_tokenize(text, seg_dict)
|
||||
|
||||
length = len(text)
|
||||
if "hw_tag" in data:
|
||||
pre_index = None
|
||||
if hw_config["pre_hwlist"] is not None and hw_config["pre_prob"] > 0:
|
||||
# enable preset hotword detect in sampling
|
||||
for hw in hw_config["pre_hwlist"]:
|
||||
hw = " ".join(seg_tokenize(hw, seg_dict))
|
||||
_find = " ".join(text).find(hw)
|
||||
if _find != -1:
|
||||
# _find = text[:_find].count(" ") # bpe sometimes
|
||||
pre_index = [_find, _find + max(hw.count(" "), 1)]
|
||||
break
|
||||
hotword_indxs = sample_hotword(length, **hw_config, pre_index=pre_index)
|
||||
data["hotword_indxs"] = hotword_indxs
|
||||
del data["hw_tag"]
|
||||
for i in range(length):
|
||||
x = text[i]
|
||||
if i == length - 1 and "punc" in data and x.startswith("vad:"):
|
||||
vad = x[4:]
|
||||
if len(vad) == 0:
|
||||
vad = -1
|
||||
else:
|
||||
vad = int(vad)
|
||||
elif x in vocab:
|
||||
token.append(vocab[x])
|
||||
else:
|
||||
token.append(vocab["<unk>"])
|
||||
|
||||
if "punc" in data and punc_dict is not None:
|
||||
punc_token = []
|
||||
for punc in data["punc"]:
|
||||
if punc in punc_dict:
|
||||
punc_token.append(punc_dict[punc])
|
||||
else:
|
||||
punc_token.append(punc_dict["_"])
|
||||
data["punc"] = np.array(punc_token)
|
||||
|
||||
data["text"] = np.array(token)
|
||||
if vad is not -2:
|
||||
data["vad_indexes"] = np.array([vad], dtype=np.int64)
|
||||
return data
|
||||
@ -85,8 +85,10 @@ def download_from_ms(**kwargs):
|
||||
|
||||
install_requirements(requirements)
|
||||
if kwargs.get("trust_remote_code", False):
|
||||
from funasr.utils.dynamic_import import import_module_from_path
|
||||
|
||||
import model
|
||||
model_code = kwargs.get("remote_code", "model")
|
||||
import_module_from_path(model_code)
|
||||
|
||||
# from funasr.register import tables
|
||||
# tables.print("model")
|
||||
|
||||
@ -2,6 +2,7 @@ import importlib.util
|
||||
|
||||
import importlib.util
|
||||
import inspect
|
||||
import os.path
|
||||
|
||||
|
||||
def load_module_from_path(file_path):
|
||||
@ -18,6 +19,21 @@ def load_module_from_path(file_path):
|
||||
return module
|
||||
|
||||
|
||||
def import_module_from_path(file_path):
|
||||
|
||||
current_working_directory = os.getcwd()
|
||||
|
||||
# 获取当前文件所在的目录
|
||||
file_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.basename(file_path)
|
||||
module_name = file_path.split("/")[-1].replace(".py", "")
|
||||
|
||||
if len(file_dir) > 0:
|
||||
os.chdir(file_dir)
|
||||
importlib.import_module(module_name)
|
||||
os.chdir(current_working_directory)
|
||||
|
||||
|
||||
#
|
||||
# def load_module_from_path(module_name, file_path):
|
||||
# """
|
||||
|
||||
Loading…
Reference in New Issue
Block a user