test_async_utils.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import asyncio
  2. import pytest
  3. from openhands.utils.async_utils import (
  4. AsyncException,
  5. call_async_from_sync,
  6. call_sync_from_async,
  7. wait_all,
  8. )
  9. @pytest.mark.asyncio
  10. async def test_await_all():
  11. # Mock function demonstrating some calculation - always takes a minimum of 0.1 seconds
  12. async def dummy(value: int):
  13. await asyncio.sleep(0.1)
  14. return value * 2
  15. # wait for 10 calculations - serially this would take 1 second
  16. coro = wait_all(dummy(i) for i in range(10))
  17. # give the task only 0.3 seconds to complete (This verifies they occur in parallel)
  18. task = asyncio.create_task(coro)
  19. await asyncio.wait([task], timeout=0.3)
  20. # validate the results (We need to sort because they can return in any order)
  21. results = list(await task)
  22. expected = [i * 2 for i in range(10)]
  23. assert expected == results
  24. @pytest.mark.asyncio
  25. async def test_await_all_single_exception():
  26. # Mock function demonstrating some calculation - always takes a minimum of 0.1 seconds
  27. async def dummy(value: int):
  28. await asyncio.sleep(0.1)
  29. if value == 1:
  30. raise ValueError('Invalid value 1') # Throw an exception on every odd value
  31. return value * 2
  32. # expect an exception to be raised.
  33. with pytest.raises(ValueError, match='Invalid value 1'):
  34. await wait_all(dummy(i) for i in range(10))
  35. @pytest.mark.asyncio
  36. async def test_await_all_multi_exception():
  37. # Mock function demonstrating some calculation - always takes a minimum of 0.1 seconds
  38. async def dummy(value: int):
  39. await asyncio.sleep(0.1)
  40. if value & 1:
  41. raise ValueError(
  42. f'Invalid value {value}'
  43. ) # Throw an exception on every odd value
  44. return value * 2
  45. # expect an exception to be raised.
  46. with pytest.raises(AsyncException):
  47. await wait_all(dummy(i) for i in range(10))
  48. @pytest.mark.asyncio
  49. async def test_await_all_timeout():
  50. result = 0
  51. # Mock function updates a nonlocal variable after a delay
  52. async def dummy(value: int):
  53. nonlocal result
  54. await asyncio.sleep(0.2)
  55. result += value
  56. # expect an exception to be raised.
  57. with pytest.raises(asyncio.TimeoutError):
  58. await wait_all((dummy(i) for i in range(10)), 0.1)
  59. # Wait and then check the shared result - this makes sure that pending tasks were cancelled.
  60. asyncio.sleep(0.2)
  61. assert result == 0
  62. @pytest.mark.asyncio
  63. async def test_call_sync_from_async():
  64. def dummy(value: int = 2):
  65. return value * 2
  66. result = await call_sync_from_async(dummy)
  67. assert result == 4
  68. result = await call_sync_from_async(dummy, 3)
  69. assert result == 6
  70. result = await call_sync_from_async(dummy, value=5)
  71. assert result == 10
  72. @pytest.mark.asyncio
  73. async def test_call_sync_from_async_error():
  74. def dummy():
  75. raise ValueError()
  76. with pytest.raises(ValueError):
  77. await call_sync_from_async(dummy)
  78. def test_call_async_from_sync():
  79. async def dummy(value: int):
  80. return value * 2
  81. result = call_async_from_sync(dummy, 0, 3)
  82. assert result == 6
  83. def test_call_async_from_sync_error():
  84. async def dummy(value: int):
  85. raise ValueError()
  86. with pytest.raises(ValueError):
  87. call_async_from_sync(dummy, 0, 3)
  88. def test_call_async_from_sync_background_tasks():
  89. events = []
  90. async def bg_task():
  91. # This background task should finish after the dummy task
  92. events.append('bg_started')
  93. asyncio.sleep(0.2)
  94. events.append('bg_finished')
  95. async def dummy(value: int):
  96. events.append('dummy_started')
  97. # This coroutine kicks off a background task
  98. asyncio.create_task(bg_task())
  99. events.append('dummy_started')
  100. call_async_from_sync(dummy, 0, 3)
  101. # We check that the function did not return until all coroutines completed
  102. # (Even though some of these were started as background tasks)
  103. expected = ['dummy_started', 'dummy_started', 'bg_started', 'bg_finished']
  104. assert expected == events