get_default_kwargs.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import inspect
  2. class Invalid:
  3. """Marker object for not serializable-object"""
  4. def get_default_kwargs(func):
  5. """Get the default values of the input function.
  6. Examples:
  7. >>> def func(a, b=3): pass
  8. >>> get_default_kwargs(func)
  9. {'b': 3}
  10. """
  11. def yaml_serializable(value):
  12. # isinstance(x, tuple) includes namedtuple, so type is used here
  13. if type(value) is tuple:
  14. return yaml_serializable(list(value))
  15. elif isinstance(value, set):
  16. return yaml_serializable(list(value))
  17. elif isinstance(value, dict):
  18. if not all(isinstance(k, str) for k in value):
  19. return Invalid
  20. retval = {}
  21. for k, v in value.items():
  22. v2 = yaml_serializable(v)
  23. # Register only valid object
  24. if v2 not in (Invalid, inspect.Parameter.empty):
  25. retval[k] = v2
  26. return retval
  27. elif isinstance(value, list):
  28. retval = []
  29. for v in value:
  30. v2 = yaml_serializable(v)
  31. # If any elements in the list are invalid,
  32. # the list also becomes invalid
  33. if v2 is Invalid:
  34. return Invalid
  35. else:
  36. retval.append(v2)
  37. return retval
  38. elif value in (inspect.Parameter.empty, None):
  39. return value
  40. elif isinstance(value, (float, int, complex, bool, str, bytes)):
  41. return value
  42. else:
  43. return Invalid
  44. # params: An ordered mapping of inspect.Parameter
  45. params = inspect.signature(func).parameters
  46. data = {p.name: p.default for p in params.values()}
  47. # Remove not yaml-serializable object
  48. data = yaml_serializable(data)
  49. return data