shared.py

import asyncio
import json
import logging
import multiprocessing as mp
import os
import pathlib
import subprocess
import time
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Awaitable, Callable

import pandas as pd
from pydantic import BaseModel
from tqdm import tqdm

from opendevin.controller.state.state import State
from opendevin.core.config import LLMConfig
from opendevin.core.logger import get_console_handler
from opendevin.core.logger import opendevin_logger as logger
from opendevin.events.action import Action
from opendevin.events.action.message import MessageAction


class EvalMetadata(BaseModel):
    agent_class: str
    llm_config: LLMConfig
    max_iterations: int
    eval_output_dir: str
    start_time: str
    git_commit: str
    dataset: str | None = None
    data_split: str | None = None
    details: dict[str, Any] | None = None

    def model_dump_json(self, *args, **kwargs):
        dumped = super().model_dump_json(*args, **kwargs)
        dumped_dict = json.loads(dumped)
        logger.debug(f'Dumped metadata: {dumped_dict}')
        # avoid leaking sensitive information
        dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
        return json.dumps(dumped_dict)


class EvalOutput(BaseModel):
    # NOTE: User-specified
    instance_id: str
    instruction: str

    # Output of the evaluation:
    # store anything that is needed for the score calculation
    test_result: dict[str, Any]

    # Interaction info
    metadata: EvalMetadata
    history: list[tuple[dict[str, Any], dict[str, Any]]]
    metrics: dict[str, Any]
    error: str | None = None

    # Optionally save the input test instance
    instance: dict[str, Any] | None = None

    def model_dump_json(self, *args, **kwargs):
        dumped = super().model_dump_json(*args, **kwargs)
        dumped_dict = json.loads(dumped)
        # Apply custom serialization for metadata (to avoid leaking sensitive information)
        dumped_dict['metadata'] = json.loads(self.metadata.model_dump_json())
        return json.dumps(dumped_dict)
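

# Illustrative note (not part of the original file): serializing an EvalOutput
# with model_dump_json() routes the nested metadata through
# EvalMetadata.model_dump_json() above, so the record written to disk never
# contains the raw LLM config. For example:
#
#     line: str = eval_output.model_dump_json()  # one line of output.jsonl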


def codeact_user_response(
    state: State,
    encapsulate_solution: bool = False,
    try_parse: Callable[[Action], str] | None = None,
) -> str:
    encaps_str = (
        (
            'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
            'For example: The answer to the question is <solution> 42 </solution>.\n'
        )
        if encapsulate_solution
        else ''
    )
    msg = (
        'Please continue working on the task using whatever approach you think is suitable.\n'
        'If you think you have solved the task, please first send your answer to the user through a message and then <execute_bash> exit </execute_bash>.\n'
        f'{encaps_str}'
        'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP.\n'
    )
    if state.history:
        # If the last action already contains a parsable answer, exit early.
        if try_parse is not None:
            last_action = state.history.get_last_action()
            ans = try_parse(last_action)
            if ans is not None:
                return '/exit'
        # If the fake user has already replied twice (so this is at least the
        # third exchange), let the agent know it can give up.
        user_msgs = [
            event
            for event in state.history.get_events()
            if isinstance(event, MessageAction) and event.source == 'user'
        ]
        if len(user_msgs) >= 2:
            return (
                msg
                + 'If you want to give up, run: <execute_bash> exit </execute_bash>.\n'
            )
    return msg
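

# Illustrative wiring (not part of the original file; the driver name and
# keyword are assumptions): evaluation scripts typically pass this function as
# the automatic "fake user" reply when driving the agent, e.g.:
#
#     state = await run_agent_controller(
#         agent,
#         instruction,
#         fake_user_response_fn=codeact_user_response,
#     )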


def cleanup():
    print('Cleaning up child processes...')
    for process in mp.active_children():
        print(f'Terminating child process: {process.name}')
        process.terminate()
        process.join()
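

# Illustrative (not part of the original file): callers would typically
# register cleanup() so stray worker processes are reaped on exit, e.g.:
#
#     import atexit
#     atexit.register(cleanup)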


def make_metadata(
    llm_config: LLMConfig,
    dataset_name: str,
    agent_class: str,
    max_iterations: int,
    eval_note: str | None,
    eval_output_dir: str,
    data_split: str | None = None,
    details: dict[str, Any] | None = None,
) -> EvalMetadata:
    model_name = llm_config.model.split('/')[-1]
    eval_note = f'_N_{eval_note}' if eval_note else ''
    eval_output_path = os.path.join(
        eval_output_dir,
        dataset_name,
        agent_class,
        f'{model_name}_maxiter_{max_iterations}{eval_note}',
    )
    pathlib.Path(eval_output_path).mkdir(parents=True, exist_ok=True)
    pathlib.Path(os.path.join(eval_output_path, 'logs')).mkdir(
        parents=True, exist_ok=True
    )
    logger.info(f'Using evaluation output directory: {eval_output_path}')

    metadata = EvalMetadata(
        agent_class=agent_class,
        llm_config=llm_config,
        max_iterations=max_iterations,
        eval_output_dir=eval_output_path,
        start_time=time.strftime('%Y-%m-%d %H:%M:%S'),
        git_commit=subprocess.check_output(['git', 'rev-parse', 'HEAD'])
        .decode('utf-8')
        .strip(),
        dataset=dataset_name,
        data_split=data_split,
        details=details,
    )
    metadata_json = metadata.model_dump_json()
    logger.info(f'Metadata: {metadata_json}')
    with open(os.path.join(eval_output_path, 'metadata.json'), 'w') as f:
        f.write(metadata_json)
    return metadata
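

# Illustrative call (not part of the original file; values are made up):
#
#     metadata = make_metadata(
#         llm_config=llm_config,
#         dataset_name='my-dataset',
#         agent_class='CodeActAgent',
#         max_iterations=30,
#         eval_note=None,
#         eval_output_dir='evaluation/outputs',
#     )
#
# This writes metadata.json under
# evaluation/outputs/my-dataset/CodeActAgent/<model>_maxiter_30/.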


def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int):
    assert (
        'instance_id' in dataset.columns
    ), "Expected 'instance_id' column in the dataset. You should define your own unique identifier for each instance and use it as the 'instance_id' column."
    id_column = 'instance_id'
    logger.info(f'Writing evaluation output to {output_file}')
    finished_ids = set()
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            for line in f:
                data = json.loads(line)
                finished_ids.add(data[id_column])
        logger.warning(
            f'Output file {output_file} already exists. Loaded {len(finished_ids)} finished instances.'
        )
    if eval_n_limit:
        dataset = dataset.head(eval_n_limit)
        logger.info(f'Limiting evaluation to first {eval_n_limit} instances.')
    new_dataset = [
        instance
        for _, instance in dataset.iterrows()
        if instance[id_column] not in finished_ids
    ]
    logger.info(
        f'Finished instances: {len(finished_ids)}, Remaining instances: {len(new_dataset)}'
    )
    return pd.DataFrame(new_dataset)
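

# Illustrative call (not part of the original file): because instances already
# present in the output file are filtered out, re-running with the same
# output_file resumes an interrupted evaluation, e.g.:
#
#     remaining = prepare_dataset(full_df, output_file, eval_n_limit=10)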


async def run_evaluation(
    dataset: pd.DataFrame,
    metadata: EvalMetadata,
    output_file: str,
    num_workers: int,
    process_instance_func: Callable[
        [pd.Series, EvalMetadata, bool], Awaitable[EvalOutput]
    ],
):
    use_multiprocessing = num_workers > 1
    logger.info(
        f'Evaluation started with Agent {metadata.agent_class}, '
        f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.'
    )
    pbar = tqdm(total=len(dataset))
    output_fp = open(output_file, 'a')

    async def update_progress(future):
        pbar.update(1)
        # In multiprocessing mode `future` is awaitable; in single-process mode
        # it is already an EvalOutput. (`await x if c else y` parses as
        # `(await x) if c else y`.)
        output: EvalOutput = await future if use_multiprocessing else future
        pbar.set_description(f'Instance {output.instance_id}')
        pbar.set_postfix_str(f'Test Result: {output.test_result}')
        logger.info(
            f'Finished evaluation for instance {output.instance_id}: {output.test_result}'
        )
        # Use the custom model_dump_json() so sensitive fields in the metadata
        # are redacted before being written to disk.
        output_fp.write(output.model_dump_json() + '\n')
        output_fp.flush()

    try:
        if use_multiprocessing:
            with ProcessPoolExecutor(num_workers) as executor:
                loop = asyncio.get_event_loop()
                futures = []
                for _, instance in dataset.iterrows():
                    future = loop.run_in_executor(
                        executor,
                        process_instance_func,
                        instance,
                        metadata,
                        bool(num_workers > 1),
                    )
                    futures.append(update_progress(future))
                await asyncio.gather(*futures)
        # Use a plain for loop for single-process mode, for easier debugging
        else:
            assert num_workers == 1
            for _, instance in dataset.iterrows():
                output = await process_instance_func(instance, metadata, False)
                await update_progress(output)
    except KeyboardInterrupt:
        print('KeyboardInterrupt received. Cleaning up...')
        cleanup()
    output_fp.close()
    logger.info('Evaluation finished.')
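

# Illustrative end-to-end driver (not part of the original file;
# process_instance and the config values are hypothetical):
#
#     async def main():
#         metadata = make_metadata(
#             llm_config,
#             'my-dataset',
#             'CodeActAgent',
#             max_iterations=30,
#             eval_note=None,
#             eval_output_dir='evaluation/outputs',
#         )
#         output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
#         dataset = prepare_dataset(full_df, output_file, eval_n_limit=10)
#         await run_evaluation(
#             dataset, metadata, output_file,
#             num_workers=1,
#             process_instance_func=process_instance,
#         )
#
#     asyncio.run(main())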


def reset_logger_for_multiprocessing(
    logger: logging.Logger, instance_id: str, log_dir: str
):
    """Reset the logger for multiprocessing.

    Save logs to a separate file for each process, instead of trying to write
    to the same file/console from multiple processes.
    """
    # Set up the per-instance log file
    log_file = os.path.join(
        log_dir,
        f'instance_{instance_id}.log',
    )
    # Remove all existing handlers from the logger
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
    # Add back the console handler to print ONE line
    logger.addHandler(get_console_handler())
    logger.info(
        f'Starting evaluation for instance {instance_id}.\n'
        f'Hint: run "tail -f {log_file}" to see live logs in a separate shell'
    )
    # Remove the console handler again so everything below goes only to the file
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(file_handler)
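

# Illustrative usage (not part of the original file; process_instance is
# hypothetical): call this at the top of the per-instance worker so that each
# process logs to its own file, e.g.:
#
#     def process_instance(instance, metadata, reset_logger=True):
#         if reset_logger:
#             log_dir = os.path.join(metadata.eval_output_dir, 'logs')
#             reset_logger_for_multiprocessing(
#                 logger, instance['instance_id'], log_dir
#             )
#         ...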