analyzer.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import re
  2. import uuid
  3. from typing import Any
  4. import docker
  5. from fastapi import HTTPException, Request
  6. from fastapi.responses import JSONResponse
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.events.action.action import (
  9. Action,
  10. ActionConfirmationStatus,
  11. ActionSecurityRisk,
  12. )
  13. from openhands.events.event import Event, EventSource
  14. from openhands.events.observation import Observation
  15. from openhands.events.serialization.action import action_from_dict
  16. from openhands.events.stream import EventStream
  17. from openhands.runtime.utils import find_available_tcp_port
  18. from openhands.security.analyzer import SecurityAnalyzer
  19. from openhands.security.invariant.client import InvariantClient
  20. from openhands.security.invariant.parser import TraceElement, parse_element
  21. from openhands.utils.async_utils import call_sync_from_async
  22. class InvariantAnalyzer(SecurityAnalyzer):
  23. """Security analyzer based on Invariant."""
  24. trace: list[TraceElement]
  25. input: list[dict]
  26. container_name: str = 'openhands-invariant-server'
  27. image_name: str = 'ghcr.io/invariantlabs-ai/server:openhands'
  28. api_host: str = 'http://localhost'
  29. timeout: int = 180
  30. settings: dict = {}
  31. def __init__(
  32. self,
  33. event_stream: EventStream,
  34. policy: str | None = None,
  35. sid: str | None = None,
  36. ):
  37. """Initializes a new instance of the InvariantAnalzyer class."""
  38. super().__init__(event_stream)
  39. self.trace = []
  40. self.input = []
  41. self.settings = {}
  42. if sid is None:
  43. self.sid = str(uuid.uuid4())
  44. try:
  45. self.docker_client = docker.from_env()
  46. except Exception as ex:
  47. logger.exception(
  48. 'Error creating Invariant Security Analyzer container. Please check that Docker is running or disable the Security Analyzer in settings.',
  49. exc_info=False,
  50. )
  51. raise ex
  52. running_containers = self.docker_client.containers.list(
  53. filters={'name': self.container_name}
  54. )
  55. if not running_containers:
  56. all_containers = self.docker_client.containers.list(
  57. all=True, filters={'name': self.container_name}
  58. )
  59. if all_containers:
  60. self.container = all_containers[0]
  61. all_containers[0].start()
  62. else:
  63. self.api_port = find_available_tcp_port()
  64. self.container = self.docker_client.containers.run(
  65. self.image_name,
  66. name=self.container_name,
  67. platform='linux/amd64',
  68. ports={'8000/tcp': self.api_port},
  69. detach=True,
  70. )
  71. else:
  72. self.container = running_containers[0]
  73. elapsed = 0
  74. while self.container.status != 'running':
  75. self.container = self.docker_client.containers.get(self.container_name)
  76. elapsed += 1
  77. logger.debug(
  78. f'waiting for container to start: {elapsed}, container status: {self.container.status}'
  79. )
  80. if elapsed > self.timeout:
  81. break
  82. self.api_port = int(
  83. self.container.attrs['NetworkSettings']['Ports']['8000/tcp'][0]['HostPort']
  84. )
  85. self.api_server = f'{self.api_host}:{self.api_port}'
  86. self.client = InvariantClient(self.api_server, self.sid)
  87. if policy is None:
  88. policy, _ = self.client.Policy.get_template()
  89. if policy is None:
  90. policy = ''
  91. self.monitor = self.client.Monitor.from_string(policy)
  92. async def close(self):
  93. self.container.stop()
  94. async def log_event(self, event: Event) -> None:
  95. if isinstance(event, Observation):
  96. element = parse_element(self.trace, event)
  97. self.trace.extend(element)
  98. self.input.extend([e.model_dump(exclude_none=True) for e in element]) # type: ignore [call-overload]
  99. else:
  100. logger.debug('Invariant skipping element: event')
  101. def get_risk(self, results: list[str]) -> ActionSecurityRisk:
  102. mapping = {
  103. 'high': ActionSecurityRisk.HIGH,
  104. 'medium': ActionSecurityRisk.MEDIUM,
  105. 'low': ActionSecurityRisk.LOW,
  106. }
  107. regex = r'(?<=risk=)\w+'
  108. risks = []
  109. for result in results:
  110. m = re.search(regex, result)
  111. if m and m.group() in mapping:
  112. risks.append(mapping[m.group()])
  113. if risks:
  114. return max(risks)
  115. return ActionSecurityRisk.LOW
  116. async def act(self, event: Event) -> None:
  117. if await self.should_confirm(event):
  118. await self.confirm(event)
  119. async def should_confirm(self, event: Event) -> bool:
  120. risk = event.security_risk # type: ignore [attr-defined]
  121. return (
  122. risk is not None
  123. and risk < self.settings.get('RISK_SEVERITY', ActionSecurityRisk.MEDIUM)
  124. and hasattr(event, 'confirmation_state')
  125. and event.confirmation_state
  126. == ActionConfirmationStatus.AWAITING_CONFIRMATION
  127. )
  128. async def confirm(self, event: Event) -> None:
  129. new_event = action_from_dict(
  130. {'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
  131. )
  132. # we should confirm only on agent actions
  133. event_source = event.source if event.source else EventSource.AGENT
  134. await call_sync_from_async(self.event_stream.add_event, new_event, event_source)
  135. async def security_risk(self, event: Action) -> ActionSecurityRisk:
  136. logger.debug('Calling security_risk on InvariantAnalyzer')
  137. new_elements = parse_element(self.trace, event)
  138. input = [e.model_dump(exclude_none=True) for e in new_elements] # type: ignore [call-overload]
  139. self.trace.extend(new_elements)
  140. result, err = self.monitor.check(self.input, input)
  141. self.input.extend(input)
  142. risk = ActionSecurityRisk.UNKNOWN
  143. if err:
  144. logger.warning(f'Error checking policy: {err}')
  145. return risk
  146. risk = self.get_risk(result)
  147. return risk
  148. ### Handle API requests
  149. async def handle_api_request(self, request: Request) -> Any:
  150. path_parts = request.url.path.strip('/').split('/')
  151. endpoint = path_parts[-1] # Get the last part of the path
  152. if request.method == 'GET':
  153. if endpoint == 'export-trace':
  154. return await self.export_trace(request)
  155. elif endpoint == 'policy':
  156. return await self.get_policy(request)
  157. elif endpoint == 'settings':
  158. return await self.get_settings(request)
  159. elif request.method == 'POST':
  160. if endpoint == 'policy':
  161. return await self.update_policy(request)
  162. elif endpoint == 'settings':
  163. return await self.update_settings(request)
  164. raise HTTPException(status_code=405, detail='Method Not Allowed')
  165. async def export_trace(self, request: Request) -> Any:
  166. return JSONResponse(content=self.input)
  167. async def get_policy(self, request: Request) -> Any:
  168. return JSONResponse(content={'policy': self.monitor.policy})
  169. async def update_policy(self, request: Request) -> Any:
  170. data = await request.json()
  171. policy = data.get('policy')
  172. new_monitor = self.client.Monitor.from_string(policy)
  173. self.monitor = new_monitor
  174. return JSONResponse(content={'policy': policy})
  175. async def get_settings(self, request: Request) -> Any:
  176. return JSONResponse(content=self.settings)
  177. async def update_settings(self, request: Request) -> Any:
  178. settings = await request.json()
  179. self.settings = settings
  180. return JSONResponse(content=self.settings)