|
|
@@ -1,11 +1,13 @@
|
|
|
import asyncio
|
|
|
+import time
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
from openhands.utils.async_utils import (
|
|
|
AsyncException,
|
|
|
- async_from_sync,
|
|
|
- sync_from_async,
|
|
|
+ call_async_from_sync,
|
|
|
+ call_coro_in_bg_thread,
|
|
|
+ call_sync_from_async,
|
|
|
wait_all,
|
|
|
)
|
|
|
|
|
|
@@ -80,44 +82,44 @@ async def test_await_all_timeout():
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
-async def test_sync_from_async():
|
|
|
+async def test_call_sync_from_async():
|
|
|
def dummy(value: int = 2):
|
|
|
return value * 2
|
|
|
|
|
|
- result = await sync_from_async(dummy)
|
|
|
+ result = await call_sync_from_async(dummy)
|
|
|
assert result == 4
|
|
|
- result = await sync_from_async(dummy, 3)
|
|
|
+ result = await call_sync_from_async(dummy, 3)
|
|
|
assert result == 6
|
|
|
- result = await sync_from_async(dummy, value=5)
|
|
|
+ result = await call_sync_from_async(dummy, value=5)
|
|
|
assert result == 10
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
-async def test_sync_from_async_error():
|
|
|
+async def test_call_sync_from_async_error():
|
|
|
def dummy():
|
|
|
raise ValueError()
|
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
- await sync_from_async(dummy)
|
|
|
+ await call_sync_from_async(dummy)
|
|
|
|
|
|
|
|
|
-def test_async_from_sync():
|
|
|
+def test_call_async_from_sync():
|
|
|
async def dummy(value: int):
|
|
|
return value * 2
|
|
|
|
|
|
- result = async_from_sync(dummy, 0, 3)
|
|
|
+ result = call_async_from_sync(dummy, 0, 3)
|
|
|
assert result == 6
|
|
|
|
|
|
|
|
|
-def test_async_from_sync_error():
|
|
|
+def test_call_async_from_sync_error():
|
|
|
async def dummy(value: int):
|
|
|
raise ValueError()
|
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
- async_from_sync(dummy, 0, 3)
|
|
|
+ call_async_from_sync(dummy, 0, 3)
|
|
|
|
|
|
|
|
|
-def test_async_from_sync_background_tasks():
|
|
|
+def test_call_async_from_sync_background_tasks():
|
|
|
events = []
|
|
|
|
|
|
async def bg_task():
|
|
|
@@ -132,9 +134,33 @@ def test_async_from_sync_background_tasks():
|
|
|
asyncio.create_task(bg_task())
|
|
|
events.append('dummy_started')
|
|
|
|
|
|
- async_from_sync(dummy, 0, 3)
|
|
|
+ call_async_from_sync(dummy, 0, 3)
|
|
|
|
|
|
# We check that the function did not return until all coroutines completed
|
|
|
# (Even though some of these were started as background tasks)
|
|
|
expected = ['dummy_started', 'dummy_started', 'bg_started', 'bg_finished']
|
|
|
assert expected == events
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_call_coro_in_bg_thread():
|
|
|
+ times = {}
|
|
|
+
|
|
|
+ async def bad_async(id_):
|
|
|
+ # Dummy demonstrating some bad async function that does not cede control
|
|
|
+ time.sleep(0.1)
|
|
|
+ times[id_] = time.time()
|
|
|
+
|
|
|
+ async def curve_ball():
|
|
|
+ # A curve ball - an async function that wants to run while the bad async functions are in progress
|
|
|
+ await asyncio.sleep(0.05)
|
|
|
+ times['curve_ball'] = time.time()
|
|
|
+
|
|
|
+ start = time.time()
|
|
|
+ asyncio.create_task(curve_ball())
|
|
|
+ await wait_all(
|
|
|
+ call_coro_in_bg_thread(bad_async, id_=f'bad_async_{id_}') for id_ in range(5)
|
|
|
+ )
|
|
|
+ assert (times['curve_ball'] - start) == pytest.approx(0.05, abs=0.1)
|
|
|
+ for id_ in range(5):
|
|
|
+ assert (times[f'bad_async_{id_}'] - start) == pytest.approx(0.1, abs=0.1)
|