FunASR/funasr/register.py
2024-02-19 14:59:26 +08:00

84 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import inspect
from dataclasses import dataclass
import re
@dataclass
class RegisterTables:
    """Named registries that map a string key to a registered class.

    Each table is a plain dict stored as an attribute.  The tables declared
    below carry no type annotations, so ``@dataclass`` does not turn them
    into fields: they stay class attributes and are shared by every instance.

    ``register`` additionally maintains a companion ``<table>_meta`` dict on
    the instance, holding ``[class name, class location]`` per key; ``print``
    renders those companion dicts.
    """

    model_classes = {}
    frontend_classes = {}
    specaug_classes = {}
    normalize_classes = {}
    encoder_classes = {}
    decoder_classes = {}
    joint_network_classes = {}
    predictor_classes = {}
    stride_conv_classes = {}
    tokenizer_classes = {}
    batch_sampler_classes = {}
    dataset_classes = {}
    index_ds_classes = {}

    def print(self, key=None):
        """Print every ``*_meta`` table as an aligned text table.

        Args:
            key: Optional substring; when given, only meta tables whose
                attribute name contains it are printed.
        """
        print("\ntables: \n")
        headers = ["class name", "class location"]
        # vars(self) only exposes instance attributes, i.e. the tables that
        # register() created with setattr (the *_meta dicts among them).
        for classes_key, classes_dict in vars(self).items():
            if not classes_key.endswith("_meta"):
                continue
            if key is not None and key not in classes_key:
                continue
            print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
            rows = sorted(classes_dict.values(), key=lambda meta: meta[0])
            data = [headers] + rows
            # Width of each column = longest cell in that column, header included.
            col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
            for row in data:
                print("| " + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths)) + " |")
            print("\n")

    @staticmethod
    def _class_location(target_class):
        """Return ``"<file>:<line>"`` for *target_class*, degrading gracefully.

        The location is display-only metadata, so a class whose file or
        source cannot be resolved (built-ins, classes built with ``type()``,
        classes defined interactively) must not make registration fail.
        Falls back to the bare file path, or ``"<unknown>"`` if even the file
        is unavailable.
        """
        try:
            class_file = inspect.getfile(target_class)
        except (TypeError, OSError):
            return "<unknown>"
        # Shorten absolute install paths to a package-relative "funasr/..." path.
        class_file = re.sub(r'^.+/funasr/', 'funasr/', class_file)
        try:
            class_line = inspect.getsourcelines(target_class)[1]
        except (TypeError, OSError):
            return class_file
        return f"{class_file}:{class_line}"

    def register(self, register_tables_key: str, key=None):
        """Return a class decorator that records the class in a table.

        Args:
            register_tables_key: Attribute name of the target table.  A table
                that does not exist yet is created on this instance.
            key: Name to register the class under; defaults to the class's
                ``__name__``.

        Returns:
            A decorator that stores the class and its meta row, then returns
            the class unchanged.  Registering an already-used key overwrites
            the previous entry.
        """
        def decorator(target_class):
            if not hasattr(self, register_tables_key):
                setattr(self, register_tables_key, {})
                logging.info("new registry table has been added: %s", register_tables_key)
            registry = getattr(self, register_tables_key)
            registry_key = key if key is not None else target_class.__name__
            registry[registry_key] = target_class

            # Companion table consumed by print(): [class name, class location].
            meta_table_name = register_tables_key + "_meta"
            if not hasattr(self, meta_table_name):
                setattr(self, meta_table_name, {})
            registry_meta = getattr(self, meta_table_name)
            registry_meta[registry_key] = [
                f"{target_class.__name__}",
                self._class_location(target_class),
            ]
            return target_class

        return decorator
# Module-level registry instance; classes are added through the decorator
# returned by `tables.register(<table name>, key=...)`.
tables = RegisterTables()
# NOTE(review): this import sits after `tables` is created — presumably so that
# funasr submodules can register themselves against it on import; confirm.
import funasr