client.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import time
  2. from typing import Any, Union
  3. import requests
  4. from requests.exceptions import ConnectionError, HTTPError, Timeout
  5. class InvariantClient:
  6. timeout: int = 120
  7. def __init__(self, server_url: str, session_id: str | None = None):
  8. self.server = server_url
  9. self.session_id, err = self._create_session(session_id)
  10. if err:
  11. raise RuntimeError(f'Failed to create session: {err}')
  12. self.Policy = self._Policy(self)
  13. self.Monitor = self._Monitor(self)
  14. def _create_session(
  15. self, session_id: str | None = None
  16. ) -> tuple[str | None, Exception | None]:
  17. elapsed = 0
  18. while elapsed < self.timeout:
  19. try:
  20. if session_id:
  21. response = requests.get(
  22. f'{self.server}/session/new?session_id={session_id}', timeout=60
  23. )
  24. else:
  25. response = requests.get(f'{self.server}/session/new', timeout=60)
  26. response.raise_for_status()
  27. return response.json().get('id'), None
  28. except (ConnectionError, Timeout):
  29. elapsed += 1
  30. time.sleep(1)
  31. except HTTPError as http_err:
  32. return None, http_err
  33. except Exception as err:
  34. return None, err
  35. return None, ConnectionError('Connection timed out')
  36. def close_session(self) -> Union[None, Exception]:
  37. try:
  38. response = requests.delete(
  39. f'{self.server}/session/?session_id={self.session_id}', timeout=60
  40. )
  41. response.raise_for_status()
  42. except (ConnectionError, Timeout, HTTPError) as err:
  43. return err
  44. return None
  45. class _Policy:
  46. def __init__(self, invariant):
  47. self.server = invariant.server
  48. self.session_id = invariant.session_id
  49. def _create_policy(self, rule: str) -> tuple[str | None, Exception | None]:
  50. try:
  51. response = requests.post(
  52. f'{self.server}/policy/new?session_id={self.session_id}',
  53. json={'rule': rule},
  54. timeout=60,
  55. )
  56. response.raise_for_status()
  57. return response.json().get('policy_id'), None
  58. except (ConnectionError, Timeout, HTTPError) as err:
  59. return None, err
  60. def get_template(self) -> tuple[str | None, Exception | None]:
  61. try:
  62. response = requests.get(
  63. f'{self.server}/policy/template',
  64. timeout=60,
  65. )
  66. response.raise_for_status()
  67. return response.json(), None
  68. except (ConnectionError, Timeout, HTTPError) as err:
  69. return None, err
  70. def from_string(self, rule: str):
  71. policy_id, err = self._create_policy(rule)
  72. if err:
  73. raise err
  74. self.policy_id = policy_id
  75. return self
  76. def analyze(self, trace: list[dict]) -> Union[Any, Exception]:
  77. try:
  78. response = requests.post(
  79. f'{self.server}/policy/{self.policy_id}/analyze?session_id={self.session_id}',
  80. json={'trace': trace},
  81. timeout=60,
  82. )
  83. response.raise_for_status()
  84. return response.json(), None
  85. except (ConnectionError, Timeout, HTTPError) as err:
  86. return None, err
  87. class _Monitor:
  88. def __init__(self, invariant):
  89. self.server = invariant.server
  90. self.session_id = invariant.session_id
  91. self.policy = ''
  92. def _create_monitor(self, rule: str) -> tuple[str | None, Exception | None]:
  93. try:
  94. response = requests.post(
  95. f'{self.server}/monitor/new?session_id={self.session_id}',
  96. json={'rule': rule},
  97. timeout=60,
  98. )
  99. response.raise_for_status()
  100. return response.json().get('monitor_id'), None
  101. except (ConnectionError, Timeout, HTTPError) as err:
  102. return None, err
  103. def from_string(self, rule: str):
  104. monitor_id, err = self._create_monitor(rule)
  105. if err:
  106. raise err
  107. self.monitor_id = monitor_id
  108. self.policy = rule
  109. return self
  110. def check(
  111. self, past_events: list[dict], pending_events: list[dict]
  112. ) -> Union[Any, Exception]:
  113. try:
  114. response = requests.post(
  115. f'{self.server}/monitor/{self.monitor_id}/check?session_id={self.session_id}',
  116. json={'past_events': past_events, 'pending_events': pending_events},
  117. timeout=60,
  118. )
  119. response.raise_for_status()
  120. return response.json(), None
  121. except (ConnectionError, Timeout, HTTPError) as err:
  122. return None, err