| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208 |
- import re
- import uuid
- from typing import Any
- import docker
- from fastapi import HTTPException, Request
- from fastapi.responses import JSONResponse
- from openhands.core.logger import openhands_logger as logger
- from openhands.events.action.action import (
- Action,
- ActionConfirmationStatus,
- ActionSecurityRisk,
- )
- from openhands.events.event import Event, EventSource
- from openhands.events.observation import Observation
- from openhands.events.serialization.action import action_from_dict
- from openhands.events.stream import EventStream
- from openhands.runtime.utils import find_available_tcp_port
- from openhands.security.analyzer import SecurityAnalyzer
- from openhands.security.invariant.client import InvariantClient
- from openhands.security.invariant.parser import TraceElement, parse_element
- from openhands.utils.async_utils import call_sync_from_async
class InvariantAnalyzer(SecurityAnalyzer):
    """Security analyzer based on Invariant.

    On construction the analyzer makes sure a Docker container named
    ``container_name`` is running (reusing, restarting or creating it as
    needed), connects an ``InvariantClient`` to the host port mapped to the
    container's ``8000/tcp`` and builds a policy monitor. Events are
    accumulated in ``trace`` (parsed elements) and ``input`` (their dumped
    dict form) and checked against the monitor in ``security_risk``.
    """

    # Parsed trace elements collected so far.
    trace: list[TraceElement]
    # Dumped (dict) form of ``trace``; this is what is sent to the monitor
    # and what ``export_trace`` returns.
    input: list[dict]
    container_name: str = 'openhands-invariant-server'
    image_name: str = 'ghcr.io/invariantlabs-ai/server:openhands'
    api_host: str = 'http://localhost'
    # Upper bound on the number of one-second polls while waiting for the
    # container to report the 'running' status.
    timeout: int = 180
    # Class-level default only; every instance gets its own dict in __init__.
    settings: dict = {}

    def __init__(
        self,
        event_stream: EventStream,
        policy: str | None = None,
        sid: str | None = None,
    ):
        """Initializes a new instance of the InvariantAnalyzer class.

        Args:
            event_stream: Event stream passed to the base class and used by
                ``confirm`` to publish confirmation events.
            policy: Policy source for the monitor. When ``None`` the client's
                policy template is used, falling back to an empty policy.
            sid: Session id handed to the Invariant client. A random UUID4
                string is generated when ``None``.

        Raises:
            Exception: Whatever ``docker.from_env()`` raises when the Docker
                client cannot be created; it is logged and re-raised.
        """
        super().__init__(event_stream)
        self.trace = []
        self.input = []
        self.settings = {}
        # Always set the attribute: previously a caller-supplied sid was
        # dropped and self.sid was only assigned when sid was None.
        self.sid = sid if sid is not None else str(uuid.uuid4())

        try:
            self.docker_client = docker.from_env()
        except Exception:
            logger.exception(
                'Error creating Invariant Security Analyzer container. Please check that Docker is running or disable the Security Analyzer in settings.',
                exc_info=False,
            )
            raise

        self._ensure_container()
        self._wait_until_running()

        # Read the port back from the container instead of trusting a locally
        # chosen one: a reused container keeps the mapping it was created with.
        self.api_port = int(
            self.container.attrs['NetworkSettings']['Ports']['8000/tcp'][0]['HostPort']
        )
        self.api_server = f'{self.api_host}:{self.api_port}'
        self.client = InvariantClient(self.api_server, self.sid)

        if policy is None:
            policy, _ = self.client.Policy.get_template()
        if policy is None:
            policy = ''
        self.monitor = self.client.Monitor.from_string(policy)

    def _ensure_container(self) -> None:
        """Set ``self.container`` to a running, restarted or new container."""
        name_filter = {'name': self.container_name}

        running_containers = self.docker_client.containers.list(filters=name_filter)
        if running_containers:
            self.container = running_containers[0]
            return

        # Nothing running: look for a stopped container with the same name.
        all_containers = self.docker_client.containers.list(
            all=True, filters=name_filter
        )
        if all_containers:
            self.container = all_containers[0]
            self.container.start()
            return

        self.api_port = find_available_tcp_port()
        self.container = self.docker_client.containers.run(
            self.image_name,
            name=self.container_name,
            platform='linux/amd64',
            ports={'8000/tcp': self.api_port},
            detach=True,
        )

    def _wait_until_running(self) -> None:
        """Poll the container until it is 'running' or ``timeout`` polls pass.

        Falls through silently on timeout, exactly like the original loop.
        """
        # Local import: the module-level import block is left untouched.
        import time

        elapsed = 0
        while self.container.status != 'running':
            # Re-fetch the container; the status is read from the new object.
            self.container = self.docker_client.containers.get(self.container_name)
            elapsed += 1
            logger.debug(
                f'waiting for container to start: {elapsed}, container status: {self.container.status}'
            )
            if elapsed > self.timeout:
                break
            if self.container.status != 'running':
                # Pause between polls so the loop is not a busy-wait and
                # ``timeout`` approximates seconds rather than iterations.
                time.sleep(1)

    async def close(self):
        """Stop the Invariant server container."""
        # container.stop() blocks; run it via call_sync_from_async, the same
        # helper confirm() uses, so the event loop is not stalled.
        await call_sync_from_async(self.container.stop)

    async def log_event(self, event: Event) -> None:
        """Append an observation to the trace; other events are skipped.

        Args:
            event: The event to record. Only ``Observation`` instances are
                parsed and stored.
        """
        if isinstance(event, Observation):
            element = parse_element(self.trace, event)
            self.trace.extend(element)
            self.input.extend([e.model_dump(exclude_none=True) for e in element])  # type: ignore [call-overload]
        else:
            # The original message logged the literal word "event"; report
            # the skipped event's type so the log line is informative.
            logger.debug(f'Invariant skipping element: {type(event).__name__}')

    def get_risk(self, results: list[str]) -> ActionSecurityRisk:
        """Return the highest risk mentioned in the monitor results.

        Args:
            results: Result strings; each is searched for ``risk=<level>``
                where level is ``high``, ``medium`` or ``low``.

        Returns:
            The maximum recognised risk, or ``ActionSecurityRisk.LOW`` when no
            result carries a recognised level.
        """
        mapping = {
            'high': ActionSecurityRisk.HIGH,
            'medium': ActionSecurityRisk.MEDIUM,
            'low': ActionSecurityRisk.LOW,
        }
        pattern = re.compile(r'(?<=risk=)\w+')
        matches = (pattern.search(result) for result in results)
        risks = [mapping[m.group()] for m in matches if m and m.group() in mapping]
        return max(risks, default=ActionSecurityRisk.LOW)

    async def act(self, event: Event) -> None:
        """Confirm the event automatically when ``should_confirm`` allows it."""
        if await self.should_confirm(event):
            await self.confirm(event)

    async def should_confirm(self, event: Event) -> bool:
        """Tell whether the event may be confirmed without asking the user.

        Returns:
            ``True`` when the event has a risk strictly below the
            ``RISK_SEVERITY`` setting (default ``ActionSecurityRisk.MEDIUM``)
            and is awaiting confirmation.
        """
        # getattr keeps events without a security_risk attribute from raising.
        risk = getattr(event, 'security_risk', None)
        if risk is None:
            return False
        # NOTE(review): any risk value that orders below the threshold is
        # auto-confirmed, including whatever UNKNOWN compares as - confirm
        # this is intended.
        threshold = self.settings.get('RISK_SEVERITY', ActionSecurityRisk.MEDIUM)
        return (
            risk < threshold
            and hasattr(event, 'confirmation_state')
            and event.confirmation_state
            == ActionConfirmationStatus.AWAITING_CONFIRMATION
        )

    async def confirm(self, event: Event) -> None:
        """Publish a ``user_confirmed`` agent-state change to the event stream.

        Args:
            event: The event being confirmed; its source (or
                ``EventSource.AGENT`` when unset) is used for the new event.
        """
        new_event = action_from_dict(
            {'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
        )
        # we should confirm only on agent actions
        event_source = event.source if event.source else EventSource.AGENT
        await call_sync_from_async(self.event_stream.add_event, new_event, event_source)

    async def security_risk(self, event: Action) -> ActionSecurityRisk:
        """Check an action against the policy monitor and return its risk.

        The action is appended to ``trace`` and ``input`` whether or not the
        check succeeds.

        Args:
            event: The action to evaluate.

        Returns:
            ``ActionSecurityRisk.UNKNOWN`` when the monitor reports an error,
            otherwise the result of ``get_risk`` on the monitor output.
        """
        logger.debug('Calling security_risk on InvariantAnalyzer')
        new_elements = parse_element(self.trace, event)
        # Named pending_input rather than ``input`` to avoid shadowing the builtin.
        pending_input = [e.model_dump(exclude_none=True) for e in new_elements]  # type: ignore [call-overload]
        self.trace.extend(new_elements)
        # The monitor receives the past input and the pending input separately.
        result, err = self.monitor.check(self.input, pending_input)
        self.input.extend(pending_input)
        if err:
            logger.warning(f'Error checking policy: {err}')
            return ActionSecurityRisk.UNKNOWN
        return self.get_risk(result)

    ### Handle API requests
    async def handle_api_request(self, request: Request) -> Any:
        """Dispatch a request on its HTTP method and last path segment.

        Raises:
            HTTPException: 405 for any method/endpoint pair not handled here.
        """
        endpoint = request.url.path.strip('/').split('/')[-1]  # last path part
        handlers = {
            ('GET', 'export-trace'): self.export_trace,
            ('GET', 'policy'): self.get_policy,
            ('GET', 'settings'): self.get_settings,
            ('POST', 'policy'): self.update_policy,
            ('POST', 'settings'): self.update_settings,
        }
        handler = handlers.get((request.method, endpoint))
        if handler is None:
            raise HTTPException(status_code=405, detail='Method Not Allowed')
        return await handler(request)

    async def export_trace(self, request: Request) -> Any:
        """Return the accumulated ``input`` trace as JSON."""
        return JSONResponse(content=self.input)

    async def get_policy(self, request: Request) -> Any:
        """Return the current monitor's policy as ``{'policy': ...}``."""
        return JSONResponse(content={'policy': self.monitor.policy})

    async def update_policy(self, request: Request) -> Any:
        """Replace the monitor with one built from the request's ``policy``."""
        data = await request.json()
        policy = data.get('policy')
        # Build the new monitor first so self.monitor is only replaced when
        # construction did not raise.
        new_monitor = self.client.Monitor.from_string(policy)
        self.monitor = new_monitor
        return JSONResponse(content={'policy': policy})

    async def get_settings(self, request: Request) -> Any:
        """Return the current settings as JSON."""
        return JSONResponse(content=self.settings)

    async def update_settings(self, request: Request) -> Any:
        """Replace the settings with the request's JSON body and echo them."""
        settings = await request.json()
        self.settings = settings
        return JSONResponse(content=self.settings)
|