Sfoglia il codice sorgente

[eval] refactor process instance logic into `update_progress` (#3875)

Xingyao Wang 1 anno fa
parent
commit
2b3925278d
1 ha cambiato i file con 74 aggiunte e 50 eliminazioni
  1. 74 50
      evaluation/utils/shared.py

+ 74 - 50
evaluation/utils/shared.py

@@ -6,8 +6,8 @@ import pathlib
 import subprocess
 import time
 import traceback
-from concurrent.futures import ProcessPoolExecutor
-from typing import Any, Awaitable, Callable
+from concurrent.futures import Future, ProcessPoolExecutor
+from typing import Any, Awaitable, Callable, TextIO
 
 import pandas as pd
 from pydantic import BaseModel
@@ -234,18 +234,66 @@ def prepare_dataset(
     return pd.DataFrame(new_dataset)
 
 
-def process_instance(
-    instance, metadata, use_multiprocessing, process_instance_func
-) -> EvalOutput | EvalError:
+def update_progress(
+    result_or_future: Future | EvalOutput | EvalError,
+    instance: pd.Series,
+    pbar: tqdm,
+    output_fp: TextIO,
+    instance_queue: mp.Queue,
+):
+    """Update the progress bar and write the result to the output file."""
     try:
-        return process_instance_func(instance, metadata, use_multiprocessing)
+        if isinstance(result_or_future, Future):
+            result = result_or_future.result()
+        else:
+            result = result_or_future
     except Exception as e:
-        logger.error(f'Error processing instance [{instance.instance_id}]: {e}')
-        return EvalError(
-            instance_id=instance.instance_id,
-            error=str(e),
-            stacktrace=traceback.format_exc(),
+        # Handle the error
+        # Exception may be raised in the process_instance_func and will
+        # be raised here when we try to access the .result() of the future
+        handle_error(
+            EvalError(
+                instance_id=instance.instance_id,
+                error=str(e),
+                stacktrace=traceback.format_exc(),
+            ),
+            instance,
+            pbar,
+            instance_queue,
         )
+        return
+
+    # Update the progress bar and write the result to the output file
+    if isinstance(result, EvalOutput):
+        pbar.update(1)
+        pbar.set_description(f'Instance {result.instance_id}')
+        pbar.set_postfix_str(f'Test Result: {result.test_result}')
+        logger.info(
+            f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
+        )
+        output_fp.write(json.dumps(result.model_dump()) + '\n')
+        output_fp.flush()
+    elif isinstance(result, EvalError):
+        handle_error(result, instance, pbar, instance_queue)
+    else:
+        raise ValueError(f'Unexpected result type: {type(result)}')
+
+
+def handle_error(
+    error: EvalError, instance: pd.Series, pbar: tqdm, instance_queue: mp.Queue
+):
+    """Handle an error that occurred during evaluation."""
+    logger.error(
+        f'Retrying instance [{instance.instance_id}] due to error: {error.error}. Stacktrace:\n{error.stacktrace}'
+        + '\n'
+        + '-' * 10
+        + '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
+        + '-' * 10
+        + '\n'
+    )
+    instance_queue.put(instance)
+    pbar.total += 1
+    pbar.refresh()
 
 
 def run_evaluation(
@@ -271,29 +319,6 @@ def run_evaluation(
     pbar = tqdm(total=total_instances, desc='Instances processed')
     output_fp = open(output_file, 'a')
 
-    def update_progress(result: EvalOutput | EvalError, instance: pd.Series):
-        if isinstance(result, EvalOutput):
-            pbar.update(1)
-            pbar.set_description(f'Instance {result.instance_id}')
-            pbar.set_postfix_str(f'Test Result: {result.test_result}')
-            logger.info(
-                f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
-            )
-            output_fp.write(json.dumps(result.model_dump()) + '\n')
-            output_fp.flush()
-        else:
-            logger.error(
-                f'Retrying instance [{instance.instance_id}] due to error: {result.error}. Stacktrace:\n{result.stacktrace}'
-                + '\n'
-                + '-' * 10
-                + '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
-                + '-' * 10
-                + '\n'
-            )
-            instance_queue.put(instance)
-            pbar.total += 1
-            pbar.refresh()
-
     try:
         if use_multiprocessing:
             with ProcessPoolExecutor(num_workers) as executor:
@@ -302,23 +327,17 @@ def run_evaluation(
                 # Loop until there are *no more instances to be processed* and *all (in-progress) futures are done*
                 # since a running future may add new instances to the queue when error occurs
                 while not instance_queue.empty() or batch_futures:
-                    # Submit new tasks if **there are instances to be processed** and **available workers**
+                    # Submit new tasks if there are instances to be processed and available workers
                     while (
                         not instance_queue.empty() and len(batch_futures) < num_workers
                     ):
                         try:
                             instance = instance_queue.get(block=False)
                             future = executor.submit(
-                                process_instance,
-                                instance,
-                                metadata,
-                                True,
-                                process_instance_func,
+                                process_instance_func, instance, metadata, True
                             )
-                            future.add_done_callback(
-                                lambda f, inst=instance: update_progress(
-                                    f.result(), inst
-                                )
+                            future.instance = (
+                                instance  # Attach the instance to the future
                             )
                             batch_futures.append(future)
                         except mp.queues.Empty:
@@ -328,12 +347,19 @@ def run_evaluation(
                             break  # Queue is empty, stop submitting new tasks
 
                     # Continue to wait for the futures to be done & remove completed futures
-                    batch_futures = [f for f in batch_futures if not f.done()]
+                    new_batch_futures = []
+                    for future in batch_futures:
+                        if future.done():
+                            update_progress(
+                                future, future.instance, pbar, output_fp, instance_queue
+                            )
+                        else:
+                            new_batch_futures.append(future)
+                    batch_futures = new_batch_futures
 
                     # Short sleep to prevent busy-waiting
                     time.sleep(1)
 
-                # Ensure all futures are done
                 assert instance_queue.empty(), 'instance_queue should be empty after all futures are done. This is a bug.'
                 assert (
                     len(batch_futures) == 0
@@ -341,10 +367,8 @@ def run_evaluation(
         else:
             while not instance_queue.empty():
                 instance = instance_queue.get()
-                result = process_instance(
-                    instance, metadata, False, process_instance_func
-                )
-                update_progress(result, instance)
+                result = process_instance_func(instance, metadata, False)
+                update_progress(result, instance, pbar, output_fp, instance_queue)
 
     except KeyboardInterrupt:
         print('\nKeyboardInterrupt received. Cleaning up...\n')