async_utils.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import asyncio
  2. from concurrent import futures
  3. from concurrent.futures import ThreadPoolExecutor
  4. from typing import Callable, Coroutine, Iterable, List
  # Default timeout, in seconds, used by the helpers in this module when the
  # caller does not pass one explicitly.
  GENERAL_TIMEOUT: int = 15

  # Module-wide thread pool (default worker count) on which coroutines are run
  # when they are invoked from synchronous code.
  EXECUTOR = ThreadPoolExecutor()
  async def call_sync_from_async(fn: Callable, *args, **kwargs):
      """
      Run a blocking callable in the loop's default executor and await its result.

      Args:
          fn: Synchronous callable to execute off the event loop thread.
          *args: Positional arguments forwarded to ``fn``.
          **kwargs: Keyword arguments forwarded to ``fn``.

      Returns:
          Whatever ``fn(*args, **kwargs)`` returns.

      Raises:
          Exception: Any exception raised by ``fn`` is re-raised to the
              awaiting caller.

      Note:
          Synchronous code cannot be interrupted: cancelling the task that
          awaits this function does not stop ``fn`` from running to completion
          in its worker thread.
      """
      # Inside a coroutine a loop is always running; get_running_loop() returns
      # exactly that loop and never creates one implicitly, unlike
      # get_event_loop().
      loop = asyncio.get_running_loop()

      def invoke():
          # run_in_executor only forwards positional arguments, so bind the
          # keyword arguments in a closure.
          return fn(*args, **kwargs)

      # Passing None selects the event loop's default executor.
      return await loop.run_in_executor(None, invoke)
  def call_async_from_sync(
      corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
  ):
      """
      Run a coroutine function to completion on a worker thread and return its result.

      The coroutine is driven by ``asyncio.run`` on a thread taken from the
      module-level ``EXECUTOR`` pool, while the calling thread blocks until the
      outcome is available.

      Args:
          corofn: Coroutine function (``async def``) to call.
          timeout: Maximum number of seconds the coroutine may run, measured
              from the moment it starts executing on the worker thread. A falsy
              value (``0`` or ``None``) disables the limit.
          *args: Positional arguments forwarded to ``corofn``.
          **kwargs: Keyword arguments forwarded to ``corofn``.

      Returns:
          The value returned by the coroutine.

      Raises:
          ValueError: If ``corofn`` is ``None`` or is not a coroutine function.
          asyncio.TimeoutError: If the coroutine does not finish within
              ``timeout`` seconds; the coroutine is cancelled in that case.
          Exception: Any exception raised by the coroutine itself.
      """
      if corofn is None:
          raise ValueError('corofn is None')
      if not asyncio.iscoroutinefunction(corofn):
          raise ValueError('corofn is not a coroutine function')

      # Falsy timeouts mean "wait indefinitely".
      limit = timeout or None

      async def bounded():
          # Enforce the timeout inside the worker's event loop. Waiting on the
          # outer future with a timeout and then calling result() would still
          # block until the coroutine finished, so the limit must be applied
          # here, where the coroutine can actually be cancelled.
          return await asyncio.wait_for(corofn(*args, **kwargs), timeout=limit)

      def run_in_worker():
          # asyncio.run creates a fresh event loop for this thread and closes
          # it afterwards, so no manual loop management is required.
          return asyncio.run(bounded())

      # result() re-raises whatever the worker raised, including TimeoutError.
      return EXECUTOR.submit(run_in_worker).result()
  async def call_coro_in_bg_thread(
      corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
  ):
      """
      Run a coroutine function in a background thread and await its result.

      The work is delegated to ``call_async_from_sync``, which is itself pushed
      off the current event loop through ``call_sync_from_async`` so that the
      calling loop is not blocked while the coroutine runs elsewhere.

      Args:
          corofn: Coroutine function (``async def``) to call.
          timeout: Timeout value handed to ``call_async_from_sync``.
          *args: Positional arguments forwarded to ``corofn``.
          **kwargs: Keyword arguments forwarded to ``corofn``.

      Returns:
          The value produced by the coroutine.

      Raises:
          Exception: Whatever ``call_async_from_sync`` raises, including its
              argument validation errors and exceptions from the coroutine.
      """
      # Hand the coroutine's value back to the caller instead of discarding it.
      return await call_sync_from_async(
          call_async_from_sync, corofn, timeout, *args, **kwargs
      )
  async def wait_all(
      iterable: Iterable[Coroutine], timeout: int = GENERAL_TIMEOUT
  ) -> List:
      """
      Run every coroutine in ``iterable`` concurrently and collect the results.

      Each coroutine is wrapped in its own task. The results come back as a
      list in the same order as the input; an empty input yields an empty list.

      If the tasks have not all finished within ``timeout`` seconds, the
      unfinished ones are cancelled and ``asyncio.TimeoutError`` is raised.
      If exactly one task failed, its exception is raised as-is. If several
      tasks failed, an ``AsyncException`` holding all of the exceptions is
      raised.
      """
      scheduled = [asyncio.create_task(coro) for coro in iterable]
      # asyncio.wait rejects an empty collection, so answer that case directly.
      if len(scheduled) == 0:
          return []

      _finished, unfinished = await asyncio.wait(scheduled, timeout=timeout)
      if unfinished:
          # Do not leave stragglers running once we give up on them.
          for straggler in unfinished:
              straggler.cancel()
          raise asyncio.TimeoutError()

      # Walk the tasks in submission order so results line up with the input.
      outcomes = []
      failures = []
      for job in scheduled:
          try:
              outcomes.append(job.result())
          except Exception as exc:
              failures.append(exc)

      if len(failures) == 1:
          raise failures[0]
      if failures:
          raise AsyncException(failures)
      return outcomes
  class AsyncException(Exception):
      """Exception that bundles a collection of other exceptions."""

      def __init__(self, exceptions):
          """Store the given collection of exceptions on ``self.exceptions``."""
          self.exceptions = exceptions

      def __str__(self):
          """Return the messages of the wrapped exceptions, one per line."""
          rendered = [str(error) for error in self.exceptions]
          return '\n'.join(rendered)