mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

funasr1.0

commit 2a0b2c795b (parent 831c48a886)
@@ -175,7 +175,7 @@ class AutoModel:
         # build tokenizer
         tokenizer = kwargs.get("tokenizer", None)
         if tokenizer is not None:
-            tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower())
+            tokenizer_class = tables.tokenizer_classes.get(tokenizer)
             tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
             kwargs["tokenizer"] = tokenizer
             kwargs["token_list"] = tokenizer.token_list

@@ -186,13 +186,13 @@ class AutoModel:
         # build frontend
         frontend = kwargs.get("frontend", None)
         if frontend is not None:
-            frontend_class = tables.frontend_classes.get(frontend.lower())
+            frontend_class = tables.frontend_classes.get(frontend)
             frontend = frontend_class(**kwargs["frontend_conf"])
             kwargs["frontend"] = frontend
             kwargs["input_size"] = frontend.output_size()

         # build model
-        model_class = tables.model_classes.get(kwargs["model"].lower())
+        model_class = tables.model_classes.get(kwargs["model"])
         model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
         model.eval()
         model.to(device)
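
Note: the lookups above now pass the configured name to the registry verbatim instead of lowercasing it first, so keys are matched case-sensitively. A minimal self-contained sketch of that behaviour (illustration only, not the FunASR implementation):

    model_classes = {}                      # stand-in for tables.model_classes

    def register(name):
        def decorator(cls):
            model_classes[name] = cls       # key stored exactly as written
            return cls
        return decorator

    @register("Paraformer")
    class Paraformer:
        pass

    print(model_classes.get("Paraformer"))  # resolves
    print(model_classes.get("paraformer"))  # None: lowercased spellings no longer match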
@@ -443,7 +443,7 @@ class AutoFrontend:
         # build frontend
         frontend = kwargs.get("frontend", None)
         if frontend is not None:
-            frontend_class = tables.frontend_classes.get(frontend.lower())
+            frontend_class = tables.frontend_classes.get(frontend)
             frontend = frontend_class(**kwargs["frontend_conf"])

         self.frontend = frontend
@@ -64,14 +64,14 @@ def main(**kwargs):

     tokenizer = kwargs.get("tokenizer", None)
     if tokenizer is not None:
-        tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower())
+        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
         tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
         kwargs["tokenizer"] = tokenizer

     # build frontend if frontend is none None
     frontend = kwargs.get("frontend", None)
     if frontend is not None:
-        frontend_class = tables.frontend_classes.get(frontend.lower())
+        frontend_class = tables.frontend_classes.get(frontend)
         frontend = frontend_class(**kwargs["frontend_conf"])
         kwargs["frontend"] = frontend
         kwargs["input_size"] = frontend.output_size()

@@ -79,7 +79,7 @@ def main(**kwargs):
     # import pdb;
     # pdb.set_trace()
     # build model
-    model_class = tables.model_classes.get(kwargs["model"].lower())
+    model_class = tables.model_classes.get(kwargs["model"])
     model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))



@@ -141,12 +141,12 @@ def main(**kwargs):
     # import pdb;
     # pdb.set_trace()
     # dataset
-    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower())
+    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
     dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))

     # dataloader
     batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler.lower())
+    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
     if batch_sampler is not None:
         batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
     dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
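
Note: the same exact-match rule now applies to every name the training entry point reads from its configuration ("model", "dataset", "batch_sampler", and so on). A small self-contained sketch of the failure mode when a config value differs only in case (stand-in registry and names, not FunASR code):

    dataset_classes = {"AudioDataset": object}      # stand-in registry, illustration only
    conf_name = "audiodataset"                      # hypothetical config value
    dataset_class = dataset_classes.get(conf_name)  # no implicit lowercasing any more
    print(dataset_class)                            # -> None; the exact name "AudioDataset" is required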
@@ -13,6 +13,9 @@ from funasr.register import tables

 @tables.register("dataset_classes", "AudioDataset")
 class AudioDataset(torch.utils.data.Dataset):
+    """
+    AudioDataset
+    """
     def __init__(self,
                  path,
                  index_ds: str = None,

@@ -22,16 +25,16 @@ class AudioDataset(torch.utils.data.Dataset):
                  float_pad_value: float = 0.0,
                  **kwargs):
         super().__init__()
-        index_ds_class = tables.index_ds_classes.get(index_ds.lower())
+        index_ds_class = tables.index_ds_classes.get(index_ds)
         self.index_ds = index_ds_class(path)
         preprocessor_speech = kwargs.get("preprocessor_speech", None)
         if preprocessor_speech:
-            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
+            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
             preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
         self.preprocessor_speech = preprocessor_speech
         preprocessor_text = kwargs.get("preprocessor_text", None)
         if preprocessor_text:
-            preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text.lower())
+            preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
             preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
         self.preprocessor_text = preprocessor_text

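
Note: AudioDataset gains a short class docstring, and the hunks above show the registration pattern the training code relies on: the second argument of @tables.register is the lookup key, now stored without lowercasing. A simplified, self-contained sketch of that pattern (the real decorator lives in funasr/register.py):

    dataset_classes = {}                    # stand-in registry

    def register(table, name):              # simplified: the table argument is ignored here
        def decorator(cls):
            dataset_classes[name] = cls
            return cls
        return decorator

    @register("dataset_classes", "AudioDataset")
    class AudioDataset:
        def __init__(self, path, index_ds=None, **kwargs):
            self.path = path

    cls = dataset_classes.get("AudioDataset")   # the exact registered name resolves
    ds = cls("train.jsonl")                     # hypothetical path, illustration only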
@@ -46,7 +46,7 @@ class CTTransformer(nn.Module):


         self.embed = nn.Embedding(vocab_size, embed_unit)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)

         self.decoder = nn.Linear(att_unit, punc_size)
@@ -268,7 +268,7 @@ class FsmnVADStreaming(nn.Module):
         super().__init__()
         self.vad_opts = VADXOptions(**kwargs)

-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
         self.encoder = encoder

@@ -41,15 +41,15 @@ class MonotonicAligner(torch.nn.Module):
         super().__init__()

         if specaug is not None:
-            specaug_class = tables.specaug_classes.get(specaug.lower())
+            specaug_class = tables.specaug_classes.get(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get(normalize.lower())
+            normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
-        predictor_class = tables.predictor_classes.get(predictor.lower())
+        predictor_class = tables.predictor_classes.get(predictor)
         predictor = predictor_class(**predictor_conf)
         self.specaug = specaug
         self.normalize = normalize
@@ -79,17 +79,17 @@ class Paraformer(nn.Module):
         super().__init__()

         if specaug is not None:
-            specaug_class = tables.specaug_classes.get(specaug.lower())
+            specaug_class = tables.specaug_classes.get(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get(normalize.lower())
+            normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()

         if decoder is not None:
-            decoder_class = tables.decoder_classes.get(decoder.lower())
+            decoder_class = tables.decoder_classes.get(decoder)
             decoder = decoder_class(
                 vocab_size=vocab_size,
                 encoder_output_size=encoder_output_size,

@@ -104,7 +104,7 @@ class Paraformer(nn.Module):
                 odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
             )
         if predictor is not None:
-            predictor_class = tables.predictor_classes.get(predictor.lower())
+            predictor_class = tables.predictor_classes.get(predictor)
             predictor = predictor_class(**predictor_conf)

         # note that eos is the same as sos (equivalent ID)

@@ -90,7 +90,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
         seaco_decoder = kwargs.get("seaco_decoder", None)
         if seaco_decoder is not None:
             seaco_decoder_conf = kwargs.get("seaco_decoder_conf")
-            seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower())
+            seaco_decoder_class = tables.decoder_classes.get(seaco_decoder)
             self.seaco_decoder = seaco_decoder_class(
                 vocab_size=self.vocab_size,
                 encoder_output_size=self.inner_dim,
@@ -60,19 +60,19 @@ class Transformer(nn.Module):
         super().__init__()

         if frontend is not None:
-            frontend_class = tables.frontend_classes.get_class(frontend.lower())
+            frontend_class = tables.frontend_classes.get_class(frontend)
             frontend = frontend_class(**frontend_conf)
         if specaug is not None:
-            specaug_class = tables.specaug_classes.get_class(specaug.lower())
+            specaug_class = tables.specaug_classes.get_class(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get_class(normalize.lower())
+            normalize_class = tables.normalize_classes.get_class(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get_class(encoder.lower())
+        encoder_class = tables.encoder_classes.get_class(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
         if decoder is not None:
-            decoder_class = tables.decoder_classes.get_class(decoder.lower())
+            decoder_class = tables.decoder_classes.get_class(decoder)
             decoder = decoder_class(
                 vocab_size=vocab_size,
                 encoder_output_size=encoder_output_size,
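
Note: this file resolves classes through get_class(...) rather than the plain dict get(...) used in the other hunks, but the change is the same: the lookup name is forwarded unchanged instead of being lowercased. A hypothetical helper capturing that behaviour (illustration only, not the FunASR API):

    def get_class(registry: dict, name: str):
        # exact, case-sensitive match; raises KeyError if the name was never registered
        return registry[name]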
@@ -1,7 +1,7 @@
 import logging
 import inspect
 from dataclasses import dataclass
+import re

 @dataclass
 class RegisterTables:

@@ -29,7 +29,7 @@ class RegisterTables:
                 flag = key in classes_key
             if classes_key.endswith("_meta") and flag:
                 print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
-                headers = ["class name", "register name", "class location"]
+                headers = ["class name", "class location"]
                 metas = []
                 for register_key, meta in classes_dict.items():
                     metas.append(meta)

@@ -51,8 +51,7 @@ class RegisterTables:

             registry = getattr(self, register_tables_key)
             registry_key = key if key is not None else target_class.__name__
-            registry_key = registry_key.lower()
-            # import pdb; pdb.set_trace()
             assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format(
                 registry_key, target_class, register_tables_key)

@@ -63,9 +62,13 @@ class RegisterTables:
             if not hasattr(self, register_tables_key_meta):
                 setattr(self, register_tables_key_meta, {})
             registry_meta = getattr(self, register_tables_key_meta)
+            # doc = target_class.__doc__
             class_file = inspect.getfile(target_class)
             class_line = inspect.getsourcelines(target_class)[1]
-            meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
+            pattern = r'^.+/funasr/'
+            class_file = re.sub(pattern, 'funasr/', class_file)
+            meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
+            # meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
             registry_meta[registry_key] = meata_data
             # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
             return target_class
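
Note: two things change in register.py itself: registry keys are stored verbatim (the registry_key.lower() line is dropped, matching the exact-match lookups above), and the per-class meta entry now records a path relative to the package via re.sub while the separate "register name" column disappears from the printed table. A runnable sketch of the path shortening (example path made up):

    import re

    class_file = "/home/user/work/FunASR/funasr/models/paraformer/model.py"  # made-up absolute path
    pattern = r'^.+/funasr/'
    print(re.sub(pattern, 'funasr/', class_file))  # -> funasr/models/paraformer/model.py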