|
|
@@ -6,8 +6,8 @@ import pathlib
|
|
|
import subprocess
|
|
|
import time
|
|
|
import traceback
|
|
|
-from concurrent.futures import ProcessPoolExecutor
|
|
|
-from typing import Any, Awaitable, Callable
|
|
|
+from concurrent.futures import Future, ProcessPoolExecutor
|
|
|
+from typing import Any, Awaitable, Callable, TextIO
|
|
|
|
|
|
import pandas as pd
|
|
|
from pydantic import BaseModel
|
|
|
@@ -234,18 +234,66 @@ def prepare_dataset(
|
|
|
return pd.DataFrame(new_dataset)
|
|
|
|
|
|
|
|
|
-def process_instance(
|
|
|
- instance, metadata, use_multiprocessing, process_instance_func
|
|
|
-) -> EvalOutput | EvalError:
|
|
|
+def update_progress(
|
|
|
+ result_or_future: Future | EvalOutput | EvalError,
|
|
|
+ instance: pd.Series,
|
|
|
+ pbar: tqdm,
|
|
|
+ output_fp: TextIO,
|
|
|
+ instance_queue: mp.Queue,
|
|
|
+):
|
|
|
+ """Update the progress bar and write the result to the output file."""
|
|
|
try:
|
|
|
- return process_instance_func(instance, metadata, use_multiprocessing)
|
|
|
+ if isinstance(result_or_future, Future):
|
|
|
+ result = result_or_future.result()
|
|
|
+ else:
|
|
|
+ result = result_or_future
|
|
|
except Exception as e:
|
|
|
- logger.error(f'Error processing instance [{instance.instance_id}]: {e}')
|
|
|
- return EvalError(
|
|
|
- instance_id=instance.instance_id,
|
|
|
- error=str(e),
|
|
|
- stacktrace=traceback.format_exc(),
|
|
|
+ # Handle the error
|
|
|
+ # Exception may be raised in the process_instance_func and will
|
|
|
+ # be raised here when we try to access the .result() of the future
|
|
|
+ handle_error(
|
|
|
+ EvalError(
|
|
|
+ instance_id=instance.instance_id,
|
|
|
+ error=str(e),
|
|
|
+ stacktrace=traceback.format_exc(),
|
|
|
+ ),
|
|
|
+ instance,
|
|
|
+ pbar,
|
|
|
+ instance_queue,
|
|
|
)
|
|
|
+ return
|
|
|
+
|
|
|
+ # Update the progress bar and write the result to the output file
|
|
|
+ if isinstance(result, EvalOutput):
|
|
|
+ pbar.update(1)
|
|
|
+ pbar.set_description(f'Instance {result.instance_id}')
|
|
|
+ pbar.set_postfix_str(f'Test Result: {result.test_result}')
|
|
|
+ logger.info(
|
|
|
+ f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
|
|
|
+ )
|
|
|
+ output_fp.write(json.dumps(result.model_dump()) + '\n')
|
|
|
+ output_fp.flush()
|
|
|
+ elif isinstance(result, EvalError):
|
|
|
+ handle_error(result, instance, pbar, instance_queue)
|
|
|
+ else:
|
|
|
+ raise ValueError(f'Unexpected result type: {type(result)}')
|
|
|
+
|
|
|
+
|
|
|
+def handle_error(
|
|
|
+ error: EvalError, instance: pd.Series, pbar: tqdm, instance_queue: mp.Queue
|
|
|
+):
|
|
|
+ """Handle an error that occurred during evaluation."""
|
|
|
+ logger.error(
|
|
|
+ f'Retrying instance [{instance.instance_id}] due to error: {error.error}. Stacktrace:\n{error.stacktrace}'
|
|
|
+ + '\n'
|
|
|
+ + '-' * 10
|
|
|
+ + '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
|
|
|
+ + '-' * 10
|
|
|
+ + '\n'
|
|
|
+ )
|
|
|
+ instance_queue.put(instance)
|
|
|
+ pbar.total += 1
|
|
|
+ pbar.refresh()
|
|
|
|
|
|
|
|
|
def run_evaluation(
|
|
|
@@ -271,29 +319,6 @@ def run_evaluation(
|
|
|
pbar = tqdm(total=total_instances, desc='Instances processed')
|
|
|
output_fp = open(output_file, 'a')
|
|
|
|
|
|
- def update_progress(result: EvalOutput | EvalError, instance: pd.Series):
|
|
|
- if isinstance(result, EvalOutput):
|
|
|
- pbar.update(1)
|
|
|
- pbar.set_description(f'Instance {result.instance_id}')
|
|
|
- pbar.set_postfix_str(f'Test Result: {result.test_result}')
|
|
|
- logger.info(
|
|
|
- f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
|
|
|
- )
|
|
|
- output_fp.write(json.dumps(result.model_dump()) + '\n')
|
|
|
- output_fp.flush()
|
|
|
- else:
|
|
|
- logger.error(
|
|
|
- f'Retrying instance [{instance.instance_id}] due to error: {result.error}. Stacktrace:\n{result.stacktrace}'
|
|
|
- + '\n'
|
|
|
- + '-' * 10
|
|
|
- + '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
|
|
|
- + '-' * 10
|
|
|
- + '\n'
|
|
|
- )
|
|
|
- instance_queue.put(instance)
|
|
|
- pbar.total += 1
|
|
|
- pbar.refresh()
|
|
|
-
|
|
|
try:
|
|
|
if use_multiprocessing:
|
|
|
with ProcessPoolExecutor(num_workers) as executor:
|
|
|
@@ -302,23 +327,17 @@ def run_evaluation(
|
|
|
# Loop until there are *no more instances to be processed* and *all (in-progress) futures are done*
|
|
|
# since a running future may add new instances to the queue when error occurs
|
|
|
while not instance_queue.empty() or batch_futures:
|
|
|
- # Submit new tasks if **there are instances to be processed** and **available workers**
|
|
|
+ # Submit new tasks if there are instances to be processed and available workers
|
|
|
while (
|
|
|
not instance_queue.empty() and len(batch_futures) < num_workers
|
|
|
):
|
|
|
try:
|
|
|
instance = instance_queue.get(block=False)
|
|
|
future = executor.submit(
|
|
|
- process_instance,
|
|
|
- instance,
|
|
|
- metadata,
|
|
|
- True,
|
|
|
- process_instance_func,
|
|
|
+ process_instance_func, instance, metadata, True
|
|
|
)
|
|
|
- future.add_done_callback(
|
|
|
- lambda f, inst=instance: update_progress(
|
|
|
- f.result(), inst
|
|
|
- )
|
|
|
+ future.instance = (
|
|
|
+ instance # Attach the instance to the future
|
|
|
)
|
|
|
batch_futures.append(future)
|
|
|
except mp.queues.Empty:
|
|
|
@@ -328,12 +347,19 @@ def run_evaluation(
|
|
|
break # Queue is empty, stop submitting new tasks
|
|
|
|
|
|
# Continue to wait for the futures to be done & remove completed futures
|
|
|
- batch_futures = [f for f in batch_futures if not f.done()]
|
|
|
+ new_batch_futures = []
|
|
|
+ for future in batch_futures:
|
|
|
+ if future.done():
|
|
|
+ update_progress(
|
|
|
+ future, future.instance, pbar, output_fp, instance_queue
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ new_batch_futures.append(future)
|
|
|
+ batch_futures = new_batch_futures
|
|
|
|
|
|
# Short sleep to prevent busy-waiting
|
|
|
time.sleep(1)
|
|
|
|
|
|
- # Ensure all futures are done
|
|
|
assert instance_queue.empty(), 'instance_queue should be empty after all futures are done. This is a bug.'
|
|
|
assert (
|
|
|
len(batch_futures) == 0
|
|
|
@@ -341,10 +367,8 @@ def run_evaluation(
|
|
|
else:
|
|
|
while not instance_queue.empty():
|
|
|
instance = instance_queue.get()
|
|
|
- result = process_instance(
|
|
|
- instance, metadata, False, process_instance_func
|
|
|
- )
|
|
|
- update_progress(result, instance)
|
|
|
+ result = process_instance_func(instance, metadata, False)
|
|
|
+ update_progress(result, instance, pbar, output_fp, instance_queue)
|
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
print('\nKeyboardInterrupt received. Cleaning up...\n')
|