class_choices.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from typing import Mapping
  2. from typing import Optional
  3. from typing import Tuple
  4. from typeguard import check_argument_types
  5. from typeguard import check_return_type
  6. from funasr.utils.nested_dict_action import NestedDictAction
  7. from funasr.utils.types import str_or_none
  8. class ClassChoices:
  9. """Helper class to manage the options for variable objects and its configuration.
  10. Example:
  11. >>> class A:
  12. ... def __init__(self, foo=3): pass
  13. >>> class B:
  14. ... def __init__(self, bar="aaaa"): pass
  15. >>> choices = ClassChoices("var", dict(a=A, b=B), default="a")
  16. >>> import argparse
  17. >>> parser = argparse.ArgumentParser()
  18. >>> choices.add_arguments(parser)
  19. >>> args = parser.parse_args(["--var", "a", "--var_conf", "foo=4")
  20. >>> args.var
  21. a
  22. >>> args.var_conf
  23. {"foo": 4}
  24. >>> class_obj = choices.get_class(args.var)
  25. >>> a_object = class_obj(**args.var_conf)
  26. """
  27. def __init__(
  28. self,
  29. name: str,
  30. classes: Mapping[str, type],
  31. type_check: type = None,
  32. default: str = None,
  33. optional: bool = False,
  34. ):
  35. assert check_argument_types()
  36. self.name = name
  37. self.base_type = type_check
  38. self.classes = {k.lower(): v for k, v in classes.items()}
  39. if "none" in self.classes or "nil" in self.classes or "null" in self.classes:
  40. raise ValueError('"none", "nil", and "null" are reserved.')
  41. if type_check is not None:
  42. for v in self.classes.values():
  43. if not issubclass(v, type_check):
  44. raise ValueError(f"must be {type_check.__name__}, but got {v}")
  45. self.optional = optional
  46. self.default = default
  47. if default is None:
  48. self.optional = True
  49. def choices(self) -> Tuple[Optional[str], ...]:
  50. retval = tuple(self.classes)
  51. if self.optional:
  52. return retval + (None,)
  53. else:
  54. return retval
  55. def get_class(self, name: Optional[str]) -> Optional[type]:
  56. assert check_argument_types()
  57. if name is None or (self.optional and name.lower() == ("none", "null", "nil")):
  58. retval = None
  59. elif name.lower() in self.classes:
  60. class_obj = self.classes[name]
  61. assert check_return_type(class_obj)
  62. retval = class_obj
  63. else:
  64. raise ValueError(
  65. f"--{self.name} must be one of {self.choices()}: "
  66. f"--{self.name} {name.lower()}"
  67. )
  68. return retval
  69. def add_arguments(self, parser):
  70. parser.add_argument(
  71. f"--{self.name}",
  72. type=lambda x: str_or_none(x.lower()),
  73. default=self.default,
  74. choices=self.choices(),
  75. help=f"The {self.name} type",
  76. )
  77. parser.add_argument(
  78. f"--{self.name}_conf",
  79. action=NestedDictAction,
  80. default=dict(),
  81. help=f"The keyword arguments for {self.name}",
  82. )