class_choices.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. from typing import Mapping
  2. from typing import Optional
  3. from typing import Tuple
  4. from funasr.utils.nested_dict_action import NestedDictAction
  5. from funasr.utils.types import str_or_none
  6. class ClassChoices:
  7. """Helper class to manage the options for variable objects and its configuration.
  8. Example:
  9. >>> class A:
  10. ... def __init__(self, foo=3): pass
  11. >>> class B:
  12. ... def __init__(self, bar="aaaa"): pass
  13. >>> choices = ClassChoices("var", dict(a=A, b=B), default="a")
  14. >>> import argparse
  15. >>> parser = argparse.ArgumentParser()
  16. >>> choices.add_arguments(parser)
  17. >>> args = parser.parse_args(["--var", "a", "--var_conf", "foo=4")
  18. >>> args.var
  19. a
  20. >>> args.var_conf
  21. {"foo": 4}
  22. >>> class_obj = choices.get_class(args.var)
  23. >>> a_object = class_obj(**args.var_conf)
  24. """
  25. def __init__(
  26. self,
  27. name: str,
  28. classes: Mapping[str, type],
  29. type_check: type = None,
  30. default: str = None,
  31. optional: bool = False,
  32. ):
  33. self.name = name
  34. self.base_type = type_check
  35. self.classes = {k.lower(): v for k, v in classes.items()}
  36. if "none" in self.classes or "nil" in self.classes or "null" in self.classes:
  37. raise ValueError('"none", "nil", and "null" are reserved.')
  38. if type_check is not None:
  39. for v in self.classes.values():
  40. if not issubclass(v, type_check):
  41. raise ValueError(f"must be {type_check.__name__}, but got {v}")
  42. self.optional = optional
  43. self.default = default
  44. if default is None:
  45. self.optional = True
  46. def choices(self) -> Tuple[Optional[str], ...]:
  47. retval = tuple(self.classes)
  48. if self.optional:
  49. return retval + (None,)
  50. else:
  51. return retval
  52. def get_class(self, name: Optional[str]) -> Optional[type]:
  53. if name is None or (self.optional and name.lower() == ("none", "null", "nil")):
  54. retval = None
  55. elif name.lower() in self.classes:
  56. class_obj = self.classes[name]
  57. retval = class_obj
  58. else:
  59. raise ValueError(
  60. f"--{self.name} must be one of {self.choices()}: "
  61. f"--{self.name} {name.lower()}"
  62. )
  63. return retval
  64. def add_arguments(self, parser):
  65. parser.add_argument(
  66. f"--{self.name}",
  67. type=lambda x: str_or_none(x.lower()),
  68. default=self.default,
  69. choices=self.choices(),
  70. help=f"The {self.name} type",
  71. )
  72. parser.add_argument(
  73. f"--{self.name}_conf",
  74. action=NestedDictAction,
  75. default=dict(),
  76. help=f"The keyword arguments for {self.name}",
  77. )