|
|
@@ -6,7 +6,7 @@ import pathlib
|
|
|
import subprocess
|
|
|
import time
|
|
|
import traceback
|
|
|
-from concurrent.futures import Future, ProcessPoolExecutor
|
|
|
+from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
|
from typing import Any, Awaitable, Callable, TextIO
|
|
|
|
|
|
import pandas as pd
|
|
|
@@ -78,12 +78,6 @@ class EvalOutput(BaseModel):
|
|
|
return json.dumps(dumped_dict)
|
|
|
|
|
|
|
|
|
-class EvalError(BaseModel):
|
|
|
- instance_id: str
|
|
|
- error: str
|
|
|
- stacktrace: str
|
|
|
-
|
|
|
-
|
|
|
def codeact_user_response(
|
|
|
state: State,
|
|
|
encapsulate_solution: bool = False,
|
|
|
@@ -235,65 +229,58 @@ def prepare_dataset(
|
|
|
|
|
|
|
|
|
def update_progress(
|
|
|
- result_or_future: Future | EvalOutput | EvalError,
|
|
|
- instance: pd.Series,
|
|
|
+ result: EvalOutput,
|
|
|
pbar: tqdm,
|
|
|
output_fp: TextIO,
|
|
|
- instance_queue: mp.Queue,
|
|
|
):
|
|
|
"""Update the progress bar and write the result to the output file."""
|
|
|
- try:
|
|
|
- if isinstance(result_or_future, Future):
|
|
|
- result = result_or_future.result()
|
|
|
- else:
|
|
|
- result = result_or_future
|
|
|
- except Exception as e:
|
|
|
- # Handle the error
|
|
|
- # Exception may be raised in the process_instance_func and will
|
|
|
- # be raised here when we try to access the .result() of the future
|
|
|
- handle_error(
|
|
|
- EvalError(
|
|
|
- instance_id=instance.instance_id,
|
|
|
- error=str(e),
|
|
|
- stacktrace=traceback.format_exc(),
|
|
|
- ),
|
|
|
- instance,
|
|
|
- pbar,
|
|
|
- instance_queue,
|
|
|
- )
|
|
|
- return
|
|
|
+ pbar.update(1)
|
|
|
+ pbar.set_description(f'Instance {result.instance_id}')
|
|
|
+ pbar.set_postfix_str(f'Test Result: {result.test_result}')
|
|
|
+ logger.info(
|
|
|
+ f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
|
|
|
+ )
|
|
|
+ output_fp.write(json.dumps(result.model_dump()) + '\n')
|
|
|
+ output_fp.flush()
|
|
|
|
|
|
- # Update the progress bar and write the result to the output file
|
|
|
- if isinstance(result, EvalOutput):
|
|
|
- pbar.update(1)
|
|
|
- pbar.set_description(f'Instance {result.instance_id}')
|
|
|
- pbar.set_postfix_str(f'Test Result: {result.test_result}')
|
|
|
- logger.info(
|
|
|
- f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
|
|
|
- )
|
|
|
- output_fp.write(json.dumps(result.model_dump()) + '\n')
|
|
|
- output_fp.flush()
|
|
|
- elif isinstance(result, EvalError):
|
|
|
- handle_error(result, instance, pbar, instance_queue)
|
|
|
- else:
|
|
|
- raise ValueError(f'Unexpected result type: {type(result)}')
|
|
|
|
|
|
+def _process_instance_wrapper(
|
|
|
+ process_instance_func: Callable[[pd.Series, EvalMetadata, bool], EvalOutput],
|
|
|
+ instance: pd.Series,
|
|
|
+ metadata: EvalMetadata,
|
|
|
+ use_mp: bool,
|
|
|
+ max_retries: int = 5,
|
|
|
+) -> EvalOutput:
|
|
|
+ """Wrap the process_instance_func to handle retries and errors.
|
|
|
|
|
|
-def handle_error(
|
|
|
- error: EvalError, instance: pd.Series, pbar: tqdm, instance_queue: mp.Queue
|
|
|
-):
|
|
|
- """Handle an error that occurred during evaluation."""
|
|
|
- logger.error(
|
|
|
- f'Retrying instance [{instance.instance_id}] due to error: {error.error}. Stacktrace:\n{error.stacktrace}'
|
|
|
- + '\n'
|
|
|
- + '-' * 10
|
|
|
- + '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
|
|
|
- + '-' * 10
|
|
|
- + '\n'
|
|
|
- )
|
|
|
- instance_queue.put(instance)
|
|
|
- pbar.total += 1
|
|
|
- pbar.refresh()
|
|
|
+ Retry an instance up to max_retries times if it fails (e.g., due to transient network/runtime issues).
|
|
|
+ """
|
|
|
+ for attempt in range(max_retries + 1):
|
|
|
+ try:
|
|
|
+ result = process_instance_func(instance, metadata, use_mp)
|
|
|
+ return result
|
|
|
+ except Exception as e:
|
|
|
+ if attempt == max_retries:
|
|
|
+ # Raise an error after all retries & stop the evaluation
|
|
|
+ raise RuntimeError(
|
|
|
+ f'Maximum error retries reached for instance {instance.instance_id}'
|
|
|
+ ) from e
|
|
|
+ error = str(e)
|
|
|
+ stacktrace = traceback.format_exc()
|
|
|
+ msg = (
|
|
|
+ '-' * 10
|
|
|
+ + '\n'
|
|
|
+ + f'Error in instance [{instance.instance_id}]: {error}. Stacktrace:\n{stacktrace}'
|
|
|
+ + '\n'
|
|
|
+ + '-' * 10
|
|
|
+                + '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
|
|
|
+ + '-' * 10
|
|
|
+ + '\n'
|
|
|
+ )
|
|
|
+ logger.error(msg)
|
|
|
+ if use_mp:
|
|
|
+ print(msg) # use print to directly print to console
|
|
|
+ time.sleep(1) # Add a small delay before retrying
|
|
|
|
|
|
|
|
|
def run_evaluation(
|
|
|
@@ -304,6 +291,7 @@ def run_evaluation(
|
|
|
process_instance_func: Callable[
|
|
|
[pd.Series, EvalMetadata, bool], Awaitable[EvalOutput]
|
|
|
],
|
|
|
+ max_retries: int = 5, # number of retries for each instance
|
|
|
):
|
|
|
use_multiprocessing = num_workers > 1
|
|
|
logger.info(
|
|
|
@@ -311,10 +299,6 @@ def run_evaluation(
|
|
|
f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.\n'
|
|
|
)
|
|
|
|
|
|
- instance_queue = mp.Queue()
|
|
|
- for _, instance in dataset.iterrows():
|
|
|
- instance_queue.put(instance)
|
|
|
-
|
|
|
total_instances = len(dataset)
|
|
|
pbar = tqdm(total=total_instances, desc='Instances processed')
|
|
|
output_fp = open(output_file, 'a')
|
|
|
@@ -322,53 +306,30 @@ def run_evaluation(
|
|
|
try:
|
|
|
if use_multiprocessing:
|
|
|
with ProcessPoolExecutor(num_workers) as executor:
|
|
|
- batch_futures = []
|
|
|
-
|
|
|
- # Loop until there are *no more instances to be processed* and *all (in-progress) futures are done*
|
|
|
- # since a running future may add new instances to the queue when error occurs
|
|
|
- while not instance_queue.empty() or batch_futures:
|
|
|
- # Submit new tasks if there are instances to be processed and available workers
|
|
|
- while (
|
|
|
- not instance_queue.empty() and len(batch_futures) < num_workers
|
|
|
- ):
|
|
|
- try:
|
|
|
- instance = instance_queue.get(block=False)
|
|
|
- future = executor.submit(
|
|
|
- process_instance_func, instance, metadata, True
|
|
|
- )
|
|
|
- future.instance = (
|
|
|
- instance # Attach the instance to the future
|
|
|
- )
|
|
|
- batch_futures.append(future)
|
|
|
- except mp.queues.Empty:
|
|
|
- logger.warning(
|
|
|
- 'Queue is empty - This should not happen. This is a bug.'
|
|
|
- )
|
|
|
- break # Queue is empty, stop submitting new tasks
|
|
|
-
|
|
|
- # Continue to wait for the futures to be done & remove completed futures
|
|
|
- new_batch_futures = []
|
|
|
- for future in batch_futures:
|
|
|
- if future.done():
|
|
|
- update_progress(
|
|
|
- future, future.instance, pbar, output_fp, instance_queue
|
|
|
- )
|
|
|
- else:
|
|
|
- new_batch_futures.append(future)
|
|
|
- batch_futures = new_batch_futures
|
|
|
-
|
|
|
- # Short sleep to prevent busy-waiting
|
|
|
- time.sleep(1)
|
|
|
-
|
|
|
- assert instance_queue.empty(), 'instance_queue should be empty after all futures are done. This is a bug.'
|
|
|
- assert (
|
|
|
- len(batch_futures) == 0
|
|
|
- ), 'batch_futures should be empty after all futures are done. This is a bug.'
|
|
|
+ futures = [
|
|
|
+ executor.submit(
|
|
|
+ _process_instance_wrapper,
|
|
|
+ process_instance_func=process_instance_func,
|
|
|
+ instance=instance,
|
|
|
+ metadata=metadata,
|
|
|
+ use_mp=True,
|
|
|
+ max_retries=max_retries,
|
|
|
+ )
|
|
|
+ for _, instance in dataset.iterrows()
|
|
|
+ ]
|
|
|
+ for future in as_completed(futures):
|
|
|
+ result = future.result()
|
|
|
+ update_progress(result, pbar, output_fp)
|
|
|
else:
|
|
|
- while not instance_queue.empty():
|
|
|
- instance = instance_queue.get()
|
|
|
- result = process_instance_func(instance, metadata, False)
|
|
|
- update_progress(result, instance, pbar, output_fp, instance_queue)
|
|
|
+ for _, instance in dataset.iterrows():
|
|
|
+ result = _process_instance_wrapper(
|
|
|
+ process_instance_func=process_instance_func,
|
|
|
+ instance=instance,
|
|
|
+ metadata=metadata,
|
|
|
+ use_mp=False,
|
|
|
+ max_retries=max_retries,
|
|
|
+ )
|
|
|
+ update_progress(result, pbar, output_fp)
|
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
print('\nKeyboardInterrupt received. Cleaning up...\n')
|