register.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import logging
  2. import inspect
  3. from dataclasses import dataclass
  4. import re
  5. @dataclass
  6. class RegisterTables:
  7. model_classes = {}
  8. frontend_classes = {}
  9. specaug_classes = {}
  10. normalize_classes = {}
  11. encoder_classes = {}
  12. decoder_classes = {}
  13. joint_network_classes = {}
  14. predictor_classes = {}
  15. stride_conv_classes = {}
  16. tokenizer_classes = {}
  17. batch_sampler_classes = {}
  18. dataset_classes = {}
  19. index_ds_classes = {}
  20. def print(self, key=None):
  21. print("\ntables: \n")
  22. fields = vars(self)
  23. for classes_key, classes_dict in fields.items():
  24. flag = True
  25. if key is not None:
  26. flag = key in classes_key
  27. if classes_key.endswith("_meta") and flag:
  28. print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
  29. headers = ["class name", "class location"]
  30. metas = []
  31. for register_key, meta in classes_dict.items():
  32. metas.append(meta)
  33. metas.sort(key=lambda x: x[0])
  34. data = [headers] + metas
  35. col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
  36. for row in data:
  37. print("| " + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths)) + " |")
  38. print("\n")
  39. def register(self, register_tables_key: str, key=None):
  40. def decorator(target_class):
  41. if not hasattr(self, register_tables_key):
  42. setattr(self, register_tables_key, {})
  43. logging.info("new registry table has been added: {}".format(register_tables_key))
  44. registry = getattr(self, register_tables_key)
  45. registry_key = key if key is not None else target_class.__name__
  46. assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format(
  47. registry_key, target_class, register_tables_key)
  48. registry[registry_key] = target_class
  49. # meta, headers = ["class name", "register name", "class location"]
  50. register_tables_key_meta = register_tables_key + "_meta"
  51. if not hasattr(self, register_tables_key_meta):
  52. setattr(self, register_tables_key_meta, {})
  53. registry_meta = getattr(self, register_tables_key_meta)
  54. # doc = target_class.__doc__
  55. class_file = inspect.getfile(target_class)
  56. class_line = inspect.getsourcelines(target_class)[1]
  57. pattern = r'^.+/funasr/'
  58. class_file = re.sub(pattern, 'funasr/', class_file)
  59. meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
  60. # meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
  61. registry_meta[registry_key] = meata_data
  62. # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
  63. return target_class
  64. return decorator
# Module-level RegisterTables instance; classes are added to it through
# ``tables.register(<table name>)`` used as a class decorator.
tables: RegisterTables = RegisterTables()
# NOTE(review): kept after ``tables`` on purpose, presumably so that package
# modules importing this registry during ``import funasr`` find ``tables``
# already defined (registration via import side effects) — confirm intent.
import funasr