| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- import logging
- import inspect
- from dataclasses import dataclass
- import re
- @dataclass
- class RegisterTables:
- model_classes = {}
- frontend_classes = {}
- specaug_classes = {}
- normalize_classes = {}
- encoder_classes = {}
- decoder_classes = {}
- joint_network_classes = {}
- predictor_classes = {}
- stride_conv_classes = {}
- tokenizer_classes = {}
- batch_sampler_classes = {}
- dataset_classes = {}
- index_ds_classes = {}
- def print(self, key=None):
- print("\ntables: \n")
- fields = vars(self)
- for classes_key, classes_dict in fields.items():
-
- flag = True
- if key is not None:
- flag = key in classes_key
- if classes_key.endswith("_meta") and flag:
- print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
- headers = ["class name", "class location"]
- metas = []
- for register_key, meta in classes_dict.items():
- metas.append(meta)
- metas.sort(key=lambda x: x[0])
- data = [headers] + metas
- col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
-
- for row in data:
- print("| " + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths)) + " |")
- print("\n")
- def register(self, register_tables_key: str, key=None):
- def decorator(target_class):
-
- if not hasattr(self, register_tables_key):
- setattr(self, register_tables_key, {})
- logging.info("new registry table has been added: {}".format(register_tables_key))
-
- registry = getattr(self, register_tables_key)
- registry_key = key if key is not None else target_class.__name__
-
- assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format(
- registry_key, target_class, register_tables_key)
-
- registry[registry_key] = target_class
-
- # meta, headers = ["class name", "register name", "class location"]
- register_tables_key_meta = register_tables_key + "_meta"
- if not hasattr(self, register_tables_key_meta):
- setattr(self, register_tables_key_meta, {})
- registry_meta = getattr(self, register_tables_key_meta)
- # doc = target_class.__doc__
- class_file = inspect.getfile(target_class)
- class_line = inspect.getsourcelines(target_class)[1]
- pattern = r'^.+/funasr/'
- class_file = re.sub(pattern, 'funasr/', class_file)
- meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
- # meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
- registry_meta[registry_key] = meata_data
- # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
- return target_class
-
- return decorator
- tables = RegisterTables()
- import funasr
|