analyzer.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
import re
import time
import uuid
from typing import Any

import docker
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse

from opendevin.core.logger import opendevin_logger as logger
from opendevin.events.action.action import (
    Action,
    ActionSecurityRisk,
)
from opendevin.events.event import Event, EventSource
from opendevin.events.observation import Observation
from opendevin.events.serialization.action import action_from_dict
from opendevin.events.stream import EventStream
from opendevin.runtime.utils import find_available_tcp_port
from opendevin.security.analyzer import SecurityAnalyzer
from opendevin.security.invariant.client import InvariantClient
from opendevin.security.invariant.parser import TraceElement, parse_element
  20. class InvariantAnalyzer(SecurityAnalyzer):
  21. """Security analyzer based on Invariant."""
  22. trace: list[TraceElement]
  23. input: list[dict]
  24. container_name: str = 'opendevin-invariant-server'
  25. image_name: str = 'ghcr.io/invariantlabs-ai/server:opendevin'
  26. api_host: str = 'http://localhost'
  27. timeout: int = 180
  28. settings: dict = {}
  29. def __init__(
  30. self,
  31. event_stream: EventStream,
  32. policy: str | None = None,
  33. sid: str | None = None,
  34. ):
  35. """Initializes a new instance of the InvariantAnalzyer class."""
  36. super().__init__(event_stream)
  37. self.trace = []
  38. self.input = []
  39. self.settings = {}
  40. if sid is None:
  41. self.sid = str(uuid.uuid4())
  42. try:
  43. self.docker_client = docker.from_env()
  44. except Exception as ex:
  45. logger.exception(
  46. 'Error creating Invariant Security Analyzer container. Please check that Docker is running or disable the Security Analyzer in settings.',
  47. exc_info=False,
  48. )
  49. raise ex
  50. running_containers = self.docker_client.containers.list(
  51. filters={'name': self.container_name}
  52. )
  53. if not running_containers:
  54. all_containers = self.docker_client.containers.list(
  55. all=True, filters={'name': self.container_name}
  56. )
  57. if all_containers:
  58. self.container = all_containers[0]
  59. all_containers[0].start()
  60. else:
  61. self.api_port = find_available_tcp_port()
  62. self.container = self.docker_client.containers.run(
  63. self.image_name,
  64. name=self.container_name,
  65. platform='linux/amd64',
  66. ports={'8000/tcp': self.api_port},
  67. detach=True,
  68. )
  69. else:
  70. self.container = running_containers[0]
  71. elapsed = 0
  72. while self.container.status != 'running':
  73. self.container = self.docker_client.containers.get(self.container_name)
  74. elapsed += 1
  75. logger.info(
  76. f'waiting for container to start: {elapsed}, container status: {self.container.status}'
  77. )
  78. if elapsed > self.timeout:
  79. break
  80. self.api_port = int(
  81. self.container.attrs['NetworkSettings']['Ports']['8000/tcp'][0]['HostPort']
  82. )
  83. self.api_server = f'{self.api_host}:{self.api_port}'
  84. self.client = InvariantClient(self.api_server, self.sid)
  85. if policy is None:
  86. policy, _ = self.client.Policy.get_template()
  87. if policy is None:
  88. policy = ''
  89. self.monitor = self.client.Monitor.from_string(policy)
  90. async def close(self):
  91. self.container.stop()
  92. async def log_event(self, event: Event) -> None:
  93. if isinstance(event, Observation):
  94. element = parse_element(self.trace, event)
  95. self.trace.extend(element)
  96. self.input.extend([e.model_dump(exclude_none=True) for e in element]) # type: ignore [call-overload]
  97. else:
  98. logger.info('Invariant skipping element: event')
  99. def get_risk(self, results: list[str]) -> ActionSecurityRisk:
  100. mapping = {
  101. 'high': ActionSecurityRisk.HIGH,
  102. 'medium': ActionSecurityRisk.MEDIUM,
  103. 'low': ActionSecurityRisk.LOW,
  104. }
  105. regex = r'(?<=risk=)\w+'
  106. risks = []
  107. for result in results:
  108. m = re.search(regex, result)
  109. if m and m.group() in mapping:
  110. risks.append(mapping[m.group()])
  111. if risks:
  112. return max(risks)
  113. return ActionSecurityRisk.LOW
  114. async def act(self, event: Event) -> None:
  115. if await self.should_confirm(event):
  116. await self.confirm(event)
  117. async def should_confirm(self, event: Event) -> bool:
  118. risk = event.security_risk # type: ignore [attr-defined]
  119. return (
  120. risk is not None
  121. and risk < self.settings.get('RISK_SEVERITY', ActionSecurityRisk.MEDIUM)
  122. and hasattr(event, 'is_confirmed')
  123. and event.is_confirmed == 'awaiting_confirmation'
  124. )
  125. async def confirm(self, event: Event) -> None:
  126. new_event = action_from_dict(
  127. {'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
  128. )
  129. if event.source:
  130. self.event_stream.add_event(new_event, event.source)
  131. else:
  132. self.event_stream.add_event(new_event, EventSource.AGENT)
  133. async def security_risk(self, event: Action) -> ActionSecurityRisk:
  134. logger.info('Calling security_risk on InvariantAnalyzer')
  135. new_elements = parse_element(self.trace, event)
  136. input = [e.model_dump(exclude_none=True) for e in new_elements] # type: ignore [call-overload]
  137. self.trace.extend(new_elements)
  138. result, err = self.monitor.check(self.input, input)
  139. self.input.extend(input)
  140. risk = ActionSecurityRisk.UNKNOWN
  141. if err:
  142. logger.warning(f'Error checking policy: {err}')
  143. return risk
  144. risk = self.get_risk(result)
  145. return risk
  146. ### Handle API requests
  147. async def handle_api_request(self, request: Request) -> Any:
  148. path_parts = request.url.path.strip('/').split('/')
  149. endpoint = path_parts[-1] # Get the last part of the path
  150. if request.method == 'GET':
  151. if endpoint == 'export-trace':
  152. return await self.export_trace(request)
  153. elif endpoint == 'policy':
  154. return await self.get_policy(request)
  155. elif endpoint == 'settings':
  156. return await self.get_settings(request)
  157. elif request.method == 'POST':
  158. if endpoint == 'policy':
  159. return await self.update_policy(request)
  160. elif endpoint == 'settings':
  161. return await self.update_settings(request)
  162. raise HTTPException(status_code=405, detail='Method Not Allowed')
  163. async def export_trace(self, request: Request) -> Any:
  164. return JSONResponse(content=self.input)
  165. async def get_policy(self, request: Request) -> Any:
  166. return JSONResponse(content={'policy': self.monitor.policy})
  167. async def update_policy(self, request: Request) -> Any:
  168. data = await request.json()
  169. policy = data.get('policy')
  170. new_monitor = self.client.Monitor.from_string(policy)
  171. self.monitor = new_monitor
  172. return JSONResponse(content={'policy': policy})
  173. async def get_settings(self, request: Request) -> Any:
  174. return JSONResponse(content=self.settings)
  175. async def update_settings(self, request: Request) -> Any:
  176. settings = await request.json()
  177. self.settings = settings
  178. return JSONResponse(content=self.settings)