♻ — Refactor some agent code

This commit is contained in:
Luc Didry 2024-11-26 16:44:19 +01:00
parent 8ac2519398
commit 4117f9f628
No known key found for this signature in database
GPG key ID: EA868E12D0257E3C
5 changed files with 39 additions and 161 deletions

View file

@ -6,6 +6,7 @@
- ♿️ — Fix not-OK domains display if javascript is disabled - ♿️ — Fix not-OK domains display if javascript is disabled
- ✨ — Retry check right after a httpx.ReadError - ✨ — Retry check right after a httpx.ReadError
- ✨ — The HTTP method used by checks is now configurable - ✨ — The HTTP method used by checks is now configurable
- ♻️ — Refactor some agent code
## 0.5.0 ## 0.5.0

View file

@ -6,6 +6,7 @@ import asyncio
import json import json
import logging import logging
import socket import socket
from time import sleep
from typing import List from typing import List
import httpx import httpx
@ -63,9 +64,24 @@ class ArgosAgent:
async def _complete_task(self, _task: dict) -> AgentResult: async def _complete_task(self, _task: dict) -> AgentResult:
try: try:
task = Task(**_task) task = Task(**_task)
url = task.url
if task.check == "http-to-https":
url = str(httpx.URL(task.url).copy_with(scheme="http"))
try:
response = await self._http_client.request( # type: ignore[attr-defined]
method=task.method, url=url, timeout=60
)
except httpx.ReadError:
sleep(1)
response = await self._http_client.request( # type: ignore[attr-defined]
method=task.method, url=url, timeout=60
)
check_class = get_registered_check(task.check) check_class = get_registered_check(task.check)
check = check_class(self._http_client, task) check = check_class(task)
result = await check.run() result = await check.run(response)
status = result.status status = result.status
context = result.context context = result.context

View file

@ -3,7 +3,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Type from typing import Type
import httpx
from pydantic import BaseModel from pydantic import BaseModel
from argos.schemas.models import Task from argos.schemas.models import Task
@ -92,8 +91,7 @@ class BaseCheck:
raise CheckNotFound(name) raise CheckNotFound(name)
return check return check
def __init__(self, http_client: httpx.AsyncClient, task: Task): def __init__(self, task: Task):
self.http_client = http_client
self.task = task self.task = task
@property @property

View file

@ -3,9 +3,8 @@
import json import json
import re import re
from datetime import datetime from datetime import datetime
from time import sleep
from httpx import URL, ReadError from httpx import Response
from jsonpointer import resolve_pointer, JsonPointerException from jsonpointer import resolve_pointer, JsonPointerException
from argos.checks.base import ( from argos.checks.base import (
@ -23,18 +22,7 @@ class HTTPStatus(BaseCheck):
config = "status-is" config = "status-is"
expected_cls = ExpectedIntValue expected_cls = ExpectedIntValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
task = self.task
try:
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
return self.response( return self.response(
status=response.status_code == self.expected, status=response.status_code == self.expected,
expected=self.expected, expected=self.expected,
@ -48,18 +36,7 @@ class HTTPStatusIn(BaseCheck):
config = "status-in" config = "status-in"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
task = self.task
try:
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
return self.response( return self.response(
status=response.status_code in json.loads(self.expected), status=response.status_code in json.loads(self.expected),
expected=self.expected, expected=self.expected,
@ -73,19 +50,7 @@ class HTTPToHTTPS(BaseCheck):
config = "http-to-https" config = "http-to-https"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
task = self.task
url = URL(task.url).copy_with(scheme="http")
try:
response = await self.http_client.request(
method=task.method, url=url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=task.method, url=url, timeout=60
)
expected_dict = json.loads(self.expected) expected_dict = json.loads(self.expected)
expected = range(300, 400) expected = range(300, 400)
if "range" in expected_dict: if "range" in expected_dict:
@ -109,18 +74,7 @@ class HTTPHeadersContain(BaseCheck):
config = "headers-contain" config = "headers-contain"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
task = self.task
try:
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
status = True status = True
for header in json.loads(self.expected): for header in json.loads(self.expected):
if header not in response.headers: if header not in response.headers:
@ -140,18 +94,7 @@ class HTTPHeadersHave(BaseCheck):
config = "headers-have" config = "headers-have"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
task = self.task
try:
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
status = True status = True
for header, value in json.loads(self.expected).items(): for header, value in json.loads(self.expected).items():
if header not in response.headers: if header not in response.headers:
@ -175,18 +118,7 @@ class HTTPHeadersLike(BaseCheck):
config = "headers-like" config = "headers-like"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
task = self.task
try:
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
status = True status = True
for header, value in json.loads(self.expected).items(): for header, value in json.loads(self.expected).items():
if header not in response.headers: if header not in response.headers:
@ -209,16 +141,7 @@ class HTTPBodyContains(BaseCheck):
config = "body-contains" config = "body-contains"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
try:
response = await self.http_client.request(
method=self.task.method, url=self.task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=self.task.method, url=self.task.url, timeout=60
)
return self.response(status=self.expected in response.text) return self.response(status=self.expected in response.text)
@ -228,16 +151,7 @@ class HTTPBodyLike(BaseCheck):
config = "body-like" config = "body-like"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
try:
response = await self.http_client.request(
method=self.task.method, url=self.task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=self.task.method, url=self.task.url, timeout=60
)
if re.search(rf"{self.expected}", response.text): if re.search(rf"{self.expected}", response.text):
return self.response(status=True) return self.response(status=True)
@ -251,18 +165,7 @@ class HTTPJsonContains(BaseCheck):
config = "json-contains" config = "json-contains"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
task = self.task
try:
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
obj = response.json() obj = response.json()
status = True status = True
@ -286,18 +189,7 @@ class HTTPJsonHas(BaseCheck):
config = "json-has" config = "json-has"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
task = self.task
try:
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
obj = response.json() obj = response.json()
status = True status = True
@ -325,18 +217,7 @@ class HTTPJsonLike(BaseCheck):
config = "json-like" config = "json-like"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
task = self.task
try:
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
obj = response.json() obj = response.json()
status = True status = True
@ -363,18 +244,7 @@ class HTTPJsonIs(BaseCheck):
config = "json-is" config = "json-is"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self) -> dict: async def run(self, response: Response) -> dict:
task = self.task
try:
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=task.method, url=task.url, timeout=60
)
obj = response.json() obj = response.json()
status = response.json() == json.loads(self.expected) status = response.json() == json.loads(self.expected)
@ -392,18 +262,8 @@ class SSLCertificateExpiration(BaseCheck):
config = "ssl-certificate-expiration" config = "ssl-certificate-expiration"
expected_cls = ExpectedStringValue expected_cls = ExpectedStringValue
async def run(self): async def run(self, response: Response) -> dict:
"""Returns the number of days in which the certificate will expire.""" """Returns the number of days in which the certificate will expire."""
try:
response = await self.http_client.request(
method=self.task.method, url=self.task.url, timeout=60
)
except ReadError:
sleep(1)
response = await self.http_client.request(
method=self.task.method, url=self.task.url, timeout=60
)
network_stream = response.extensions["network_stream"] network_stream = response.extensions["network_stream"]
ssl_obj = network_stream.get_extra_info("ssl_object") ssl_obj = network_stream.get_extra_info("ssl_object")
cert = ssl_obj.getpeercert() cert = ssl_obj.getpeercert()

View file

@ -52,6 +52,9 @@ async def test_ssl_check_accepts_statuts(
return_value=httpx.Response(http_status, extensions=httpx_extensions_ssl), return_value=httpx.Response(http_status, extensions=httpx_extensions_ssl),
) )
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
check = SSLCertificateExpiration(client, ssl_task) check = SSLCertificateExpiration(ssl_task)
check_response = await check.run() response = await client.request(
method=ssl_task.method, url=ssl_task.url, timeout=60
)
check_response = await check.run(response)
assert check_response.status == "on-check" assert check_response.status == "on-check"