|
|
@@ -51,24 +51,48 @@ def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]])
|
|
|
|
|
|
|
|
|
def test_msg(temp_dir: str):
|
|
|
- file_store = get_file_store('local', temp_dir)
|
|
|
- event_stream = EventStream('main', file_store)
|
|
|
- policy = """
|
|
|
- raise "Disallow ABC [risk=medium]" if:
|
|
|
- (msg: Message)
|
|
|
- "ABC" in msg.content
|
|
|
- """
|
|
|
- InvariantAnalyzer(event_stream, policy)
|
|
|
- data = [
|
|
|
- (MessageAction('Hello world!'), EventSource.USER),
|
|
|
- (MessageAction('AB!'), EventSource.AGENT),
|
|
|
- (MessageAction('Hello world!'), EventSource.USER),
|
|
|
- (MessageAction('ABC!'), EventSource.AGENT),
|
|
|
+ mock_container = MagicMock()
|
|
|
+ mock_container.status = 'running'
|
|
|
+ mock_container.attrs = {
|
|
|
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
|
|
|
+ }
|
|
|
+ mock_docker = MagicMock()
|
|
|
+ mock_docker.from_env().containers.list.return_value = [mock_container]
|
|
|
+
|
|
|
+ mock_requests = MagicMock()
|
|
|
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
|
|
|
+ mock_requests.post().json.side_effect = [
|
|
|
+ {'monitor_id': 'mock-monitor-id'},
|
|
|
+ [], # First check
|
|
|
+ [], # Second check
|
|
|
+ [], # Third check
|
|
|
+ [
|
|
|
+ 'PolicyViolation(Disallow ABC [risk=medium], ranges=[<2 ranges>])'
|
|
|
+ ], # Fourth check
|
|
|
]
|
|
|
- add_events(event_stream, data)
|
|
|
- for i in range(3):
|
|
|
- assert data[i][0].security_risk == ActionSecurityRisk.LOW
|
|
|
- assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
|
|
|
+
|
|
|
+ with (
|
|
|
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
|
|
|
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
|
|
|
+ ):
|
|
|
+ file_store = get_file_store('local', temp_dir)
|
|
|
+ event_stream = EventStream('main', file_store)
|
|
|
+ policy = """
|
|
|
+ raise "Disallow ABC [risk=medium]" if:
|
|
|
+ (msg: Message)
|
|
|
+ "ABC" in msg.content
|
|
|
+ """
|
|
|
+ InvariantAnalyzer(event_stream, policy)
|
|
|
+ data = [
|
|
|
+ (MessageAction('Hello world!'), EventSource.USER),
|
|
|
+ (MessageAction('AB!'), EventSource.AGENT),
|
|
|
+ (MessageAction('Hello world!'), EventSource.USER),
|
|
|
+ (MessageAction('ABC!'), EventSource.AGENT),
|
|
|
+ ]
|
|
|
+ add_events(event_stream, data)
|
|
|
+ for i in range(3):
|
|
|
+ assert data[i][0].security_risk == ActionSecurityRisk.LOW
|
|
|
+ assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
|
@@ -76,22 +100,44 @@ def test_msg(temp_dir: str):
|
|
|
[('rm -rf root_dir', ActionSecurityRisk.MEDIUM), ['ls', ActionSecurityRisk.LOW]],
|
|
|
)
|
|
|
def test_cmd(cmd, expected_risk, temp_dir: str):
|
|
|
- file_store = get_file_store('local', temp_dir)
|
|
|
- event_stream = EventStream('main', file_store)
|
|
|
- policy = """
|
|
|
- raise "Disallow rm -rf [risk=medium]" if:
|
|
|
- (call: ToolCall)
|
|
|
- call is tool:run
|
|
|
- match("rm -rf", call.function.arguments.command)
|
|
|
- """
|
|
|
- InvariantAnalyzer(event_stream, policy)
|
|
|
- data = [
|
|
|
- (MessageAction('Hello world!'), EventSource.USER),
|
|
|
- (CmdRunAction(cmd), EventSource.USER),
|
|
|
+ mock_container = MagicMock()
|
|
|
+ mock_container.status = 'running'
|
|
|
+ mock_container.attrs = {
|
|
|
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
|
|
|
+ }
|
|
|
+ mock_docker = MagicMock()
|
|
|
+ mock_docker.from_env().containers.list.return_value = [mock_container]
|
|
|
+
|
|
|
+ mock_requests = MagicMock()
|
|
|
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
|
|
|
+ mock_requests.post().json.side_effect = [
|
|
|
+ {'monitor_id': 'mock-monitor-id'},
|
|
|
+ [], # First check
|
|
|
+ ['PolicyViolation(Disallow rm -rf [risk=medium], ranges=[<2 ranges>])']
|
|
|
+ if expected_risk == ActionSecurityRisk.MEDIUM
|
|
|
+ else [], # Second check
|
|
|
]
|
|
|
- add_events(event_stream, data)
|
|
|
- assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
|
|
- assert data[1][0].security_risk == expected_risk
|
|
|
+
|
|
|
+ with (
|
|
|
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
|
|
|
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
|
|
|
+ ):
|
|
|
+ file_store = get_file_store('local', temp_dir)
|
|
|
+ event_stream = EventStream('main', file_store)
|
|
|
+ policy = """
|
|
|
+ raise "Disallow rm -rf [risk=medium]" if:
|
|
|
+ (call: ToolCall)
|
|
|
+ call is tool:run
|
|
|
+ match("rm -rf", call.function.arguments.command)
|
|
|
+ """
|
|
|
+ InvariantAnalyzer(event_stream, policy)
|
|
|
+ data = [
|
|
|
+ (MessageAction('Hello world!'), EventSource.USER),
|
|
|
+ (CmdRunAction(cmd), EventSource.USER),
|
|
|
+ ]
|
|
|
+ add_events(event_stream, data)
|
|
|
+ assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
|
|
+ assert data[1][0].security_risk == expected_risk
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
|
@@ -102,26 +148,49 @@ def test_cmd(cmd, expected_risk, temp_dir: str):
|
|
|
],
|
|
|
)
|
|
|
def test_leak_secrets(code, expected_risk, temp_dir: str):
|
|
|
- file_store = get_file_store('local', temp_dir)
|
|
|
- event_stream = EventStream('main', file_store)
|
|
|
- policy = """
|
|
|
- from invariant.detectors import secrets
|
|
|
-
|
|
|
- raise "Disallow writing secrets [risk=medium]" if:
|
|
|
- (call: ToolCall)
|
|
|
- call is tool:run_ipython
|
|
|
- any(secrets(call.function.arguments.code))
|
|
|
- """
|
|
|
- InvariantAnalyzer(event_stream, policy)
|
|
|
- data = [
|
|
|
- (MessageAction('Hello world!'), EventSource.USER),
|
|
|
- (IPythonRunCellAction(code), EventSource.AGENT),
|
|
|
- (IPythonRunCellAction('hello'), EventSource.AGENT),
|
|
|
+ mock_container = MagicMock()
|
|
|
+ mock_container.status = 'running'
|
|
|
+ mock_container.attrs = {
|
|
|
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
|
|
|
+ }
|
|
|
+ mock_docker = MagicMock()
|
|
|
+ mock_docker.from_env().containers.list.return_value = [mock_container]
|
|
|
+
|
|
|
+ mock_requests = MagicMock()
|
|
|
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
|
|
|
+ mock_requests.post().json.side_effect = [
|
|
|
+ {'monitor_id': 'mock-monitor-id'},
|
|
|
+ [], # First check
|
|
|
+ ['PolicyViolation(Disallow writing secrets [risk=medium], ranges=[<2 ranges>])']
|
|
|
+ if expected_risk == ActionSecurityRisk.MEDIUM
|
|
|
+ else [], # Second check
|
|
|
+ [], # Third check
|
|
|
]
|
|
|
- add_events(event_stream, data)
|
|
|
- assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
|
|
- assert data[1][0].security_risk == expected_risk
|
|
|
- assert data[2][0].security_risk == ActionSecurityRisk.LOW
|
|
|
+
|
|
|
+ with (
|
|
|
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
|
|
|
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
|
|
|
+ ):
|
|
|
+ file_store = get_file_store('local', temp_dir)
|
|
|
+ event_stream = EventStream('main', file_store)
|
|
|
+ policy = """
|
|
|
+ from invariant.detectors import secrets
|
|
|
+
|
|
|
+ raise "Disallow writing secrets [risk=medium]" if:
|
|
|
+ (call: ToolCall)
|
|
|
+ call is tool:run_ipython
|
|
|
+ any(secrets(call.function.arguments.code))
|
|
|
+ """
|
|
|
+ InvariantAnalyzer(event_stream, policy)
|
|
|
+ data = [
|
|
|
+ (MessageAction('Hello world!'), EventSource.USER),
|
|
|
+ (IPythonRunCellAction(code), EventSource.AGENT),
|
|
|
+ (IPythonRunCellAction('hello'), EventSource.AGENT),
|
|
|
+ ]
|
|
|
+ add_events(event_stream, data)
|
|
|
+ assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
|
|
+ assert data[1][0].security_risk == expected_risk
|
|
|
+ assert data[2][0].security_risk == ActionSecurityRisk.LOW
|
|
|
|
|
|
|
|
|
def test_unsafe_python_code(temp_dir: str):
|
|
|
@@ -458,26 +527,48 @@ def default_config():
|
|
|
def test_check_usertask(
|
|
|
mock_litellm_completion, usertask, is_appropriate, default_config, temp_dir: str
|
|
|
):
|
|
|
- file_store = get_file_store('local', temp_dir)
|
|
|
- event_stream = EventStream('main', file_store)
|
|
|
- analyzer = InvariantAnalyzer(event_stream)
|
|
|
- mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
|
|
|
- mock_litellm_completion.return_value = mock_response
|
|
|
- analyzer.guardrail_llm = LLM(config=default_config)
|
|
|
- analyzer.check_browsing_alignment = True
|
|
|
- data = [
|
|
|
- (MessageAction(usertask), EventSource.USER),
|
|
|
+ mock_container = MagicMock()
|
|
|
+ mock_container.status = 'running'
|
|
|
+ mock_container.attrs = {
|
|
|
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
|
|
|
+ }
|
|
|
+ mock_docker = MagicMock()
|
|
|
+ mock_docker.from_env().containers.list.return_value = [mock_container]
|
|
|
+
|
|
|
+ mock_requests = MagicMock()
|
|
|
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
|
|
|
+ mock_requests.post().json.side_effect = [
|
|
|
+ {'monitor_id': 'mock-monitor-id'},
|
|
|
+ [],
|
|
|
+ [
|
|
|
+ 'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])'
|
|
|
+ ],
|
|
|
]
|
|
|
- add_events(event_stream, data)
|
|
|
- event_list = list(event_stream.get_events())
|
|
|
|
|
|
- if is_appropriate == 'No':
|
|
|
- assert len(event_list) == 2
|
|
|
- assert type(event_list[0]) == MessageAction
|
|
|
- assert type(event_list[1]) == ChangeAgentStateAction
|
|
|
- elif is_appropriate == 'Yes':
|
|
|
- assert len(event_list) == 1
|
|
|
- assert type(event_list[0]) == MessageAction
|
|
|
+ with (
|
|
|
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
|
|
|
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
|
|
|
+ ):
|
|
|
+ file_store = get_file_store('local', temp_dir)
|
|
|
+ event_stream = EventStream('main', file_store)
|
|
|
+ analyzer = InvariantAnalyzer(event_stream)
|
|
|
+ mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
|
|
|
+ mock_litellm_completion.return_value = mock_response
|
|
|
+ analyzer.guardrail_llm = LLM(config=default_config)
|
|
|
+ analyzer.check_browsing_alignment = True
|
|
|
+ data = [
|
|
|
+ (MessageAction(usertask), EventSource.USER),
|
|
|
+ ]
|
|
|
+ add_events(event_stream, data)
|
|
|
+ event_list = list(event_stream.get_events())
|
|
|
+
|
|
|
+ if is_appropriate == 'No':
|
|
|
+ assert len(event_list) == 2
|
|
|
+ assert type(event_list[0]) == MessageAction
|
|
|
+ assert type(event_list[1]) == ChangeAgentStateAction
|
|
|
+ elif is_appropriate == 'Yes':
|
|
|
+ assert len(event_list) == 1
|
|
|
+ assert type(event_list[0]) == MessageAction
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
|
@@ -491,23 +582,45 @@ def test_check_usertask(
|
|
|
def test_check_fillaction(
|
|
|
mock_litellm_completion, fillaction, is_harmful, default_config, temp_dir: str
|
|
|
):
|
|
|
- file_store = get_file_store('local', temp_dir)
|
|
|
- event_stream = EventStream('main', file_store)
|
|
|
- analyzer = InvariantAnalyzer(event_stream)
|
|
|
- mock_response = {'choices': [{'message': {'content': is_harmful}}]}
|
|
|
- mock_litellm_completion.return_value = mock_response
|
|
|
- analyzer.guardrail_llm = LLM(config=default_config)
|
|
|
- analyzer.check_browsing_alignment = True
|
|
|
- data = [
|
|
|
- (BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
|
|
|
+ mock_container = MagicMock()
|
|
|
+ mock_container.status = 'running'
|
|
|
+ mock_container.attrs = {
|
|
|
+ 'NetworkSettings': {'Ports': {'8000/tcp': [{'HostPort': 34567}]}}
|
|
|
+ }
|
|
|
+ mock_docker = MagicMock()
|
|
|
+ mock_docker.from_env().containers.list.return_value = [mock_container]
|
|
|
+
|
|
|
+ mock_requests = MagicMock()
|
|
|
+ mock_requests.get().json.return_value = {'id': 'mock-session-id'}
|
|
|
+ mock_requests.post().json.side_effect = [
|
|
|
+ {'monitor_id': 'mock-monitor-id'},
|
|
|
+ [],
|
|
|
+ [
|
|
|
+ 'PolicyViolation(Vulnerability in python code [risk=medium], ranges=[<2 ranges>])'
|
|
|
+ ],
|
|
|
]
|
|
|
- add_events(event_stream, data)
|
|
|
- event_list = list(event_stream.get_events())
|
|
|
-
|
|
|
- if is_harmful == 'Yes':
|
|
|
- assert len(event_list) == 2
|
|
|
- assert type(event_list[0]) == BrowseInteractiveAction
|
|
|
- assert type(event_list[1]) == ChangeAgentStateAction
|
|
|
- elif is_harmful == 'No':
|
|
|
- assert len(event_list) == 1
|
|
|
- assert type(event_list[0]) == BrowseInteractiveAction
|
|
|
+
|
|
|
+ with (
|
|
|
+ patch(f'{InvariantAnalyzer.__module__}.docker', mock_docker),
|
|
|
+ patch(f'{InvariantClient.__module__}.requests', mock_requests),
|
|
|
+ ):
|
|
|
+ file_store = get_file_store('local', temp_dir)
|
|
|
+ event_stream = EventStream('main', file_store)
|
|
|
+ analyzer = InvariantAnalyzer(event_stream)
|
|
|
+ mock_response = {'choices': [{'message': {'content': is_harmful}}]}
|
|
|
+ mock_litellm_completion.return_value = mock_response
|
|
|
+ analyzer.guardrail_llm = LLM(config=default_config)
|
|
|
+ analyzer.check_browsing_alignment = True
|
|
|
+ data = [
|
|
|
+ (BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
|
|
|
+ ]
|
|
|
+ add_events(event_stream, data)
|
|
|
+ event_list = list(event_stream.get_events())
|
|
|
+
|
|
|
+ if is_harmful == 'Yes':
|
|
|
+ assert len(event_list) == 2
|
|
|
+ assert type(event_list[0]) == BrowseInteractiveAction
|
|
|
+ assert type(event_list[1]) == ChangeAgentStateAction
|
|
|
+ elif is_harmful == 'No':
|
|
|
+ assert len(event_list) == 1
|
|
|
+ assert type(event_list[0]) == BrowseInteractiveAction
|