diff --git a/CHANGELOG.md b/CHANGELOG.md index 02528f2..9b4debd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - ♿️ — Fix not-OK domains display if javascript is disabled - ✨ — Retry check right after a httpx.ReadError - ✨ — The HTTP method used by checks is now configurable +- ♻️ — Refactor some agent code ## 0.5.0 diff --git a/argos/agent.py b/argos/agent.py index 580553d..d791641 100644 --- a/argos/agent.py +++ b/argos/agent.py @@ -6,6 +6,7 @@ import asyncio import json import logging import socket +from time import sleep from typing import List import httpx @@ -63,9 +64,24 @@ class ArgosAgent: async def _complete_task(self, _task: dict) -> AgentResult: try: task = Task(**_task) + + url = task.url + if task.check == "http-to-https": + url = str(httpx.URL(task.url).copy_with(scheme="http")) + + try: + response = await self._http_client.request( # type: ignore[attr-defined] + method=task.method, url=url, timeout=60 + ) + except httpx.ReadError: + sleep(1) + response = await self._http_client.request( # type: ignore[attr-defined] + method=task.method, url=url, timeout=60 + ) + check_class = get_registered_check(task.check) - check = check_class(self._http_client, task) - result = await check.run() + check = check_class(task) + result = await check.run(response) status = result.status context = result.context diff --git a/argos/checks/base.py b/argos/checks/base.py index 3221809..8206714 100644 --- a/argos/checks/base.py +++ b/argos/checks/base.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from typing import Type -import httpx from pydantic import BaseModel from argos.schemas.models import Task @@ -92,8 +91,7 @@ class BaseCheck: raise CheckNotFound(name) return check - def __init__(self, http_client: httpx.AsyncClient, task: Task): - self.http_client = http_client + def __init__(self, task: Task): self.task = task @property diff --git a/argos/checks/checks.py b/argos/checks/checks.py index ad1ff5c..b30977a 100644 --- a/argos/checks/checks.py +++ b/argos/checks/checks.py @@ -3,9 +3,8 @@ import json import re from datetime import datetime -from time import sleep -from httpx import URL, ReadError +from httpx import Response from jsonpointer import resolve_pointer, JsonPointerException from argos.checks.base import ( @@ -23,18 +22,7 @@ class HTTPStatus(BaseCheck): config = "status-is" expected_cls = ExpectedIntValue - async def run(self) -> dict: - task = self.task - try: - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: return self.response( status=response.status_code == self.expected, expected=self.expected, @@ -48,18 +36,7 @@ class HTTPStatusIn(BaseCheck): config = "status-in" expected_cls = ExpectedStringValue - async def run(self) -> dict: - task = self.task - try: - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: return self.response( status=response.status_code in json.loads(self.expected), expected=self.expected, @@ -73,19 +50,7 @@ class HTTPToHTTPS(BaseCheck): config = "http-to-https" expected_cls = ExpectedStringValue - async def run(self) -> dict: - task = self.task - url = URL(task.url).copy_with(scheme="http") - try: - response = await self.http_client.request( - method=task.method, url=url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=task.method, url=url, timeout=60 - ) - + async def run(self, response: Response) -> dict: expected_dict = json.loads(self.expected) expected = range(300, 400) if "range" in expected_dict: @@ -109,18 +74,7 @@ class HTTPHeadersContain(BaseCheck): config = "headers-contain" expected_cls = ExpectedStringValue - async def run(self) -> dict: - task = self.task - try: - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: status = True for header in json.loads(self.expected): if header not in response.headers: @@ -140,18 +94,7 @@ class HTTPHeadersHave(BaseCheck): config = "headers-have" expected_cls = ExpectedStringValue - async def run(self) -> dict: - task = self.task - try: - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: status = True for header, value in json.loads(self.expected).items(): if header not in response.headers: @@ -175,18 +118,7 @@ class HTTPHeadersLike(BaseCheck): config = "headers-like" expected_cls = ExpectedStringValue - async def run(self) -> dict: - task = self.task - try: - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: status = True for header, value in json.loads(self.expected).items(): if header not in response.headers: @@ -209,16 +141,7 @@ class HTTPBodyContains(BaseCheck): config = "body-contains" expected_cls = ExpectedStringValue - async def run(self) -> dict: - try: - response = await self.http_client.request( - method=self.task.method, url=self.task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=self.task.method, url=self.task.url, timeout=60 - ) + async def run(self, response: Response) -> dict: return self.response(status=self.expected in response.text) @@ -228,16 +151,7 @@ class HTTPBodyLike(BaseCheck): config = "body-like" expected_cls = ExpectedStringValue - async def run(self) -> dict: - try: - response = await self.http_client.request( - method=self.task.method, url=self.task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=self.task.method, url=self.task.url, timeout=60 - ) + async def run(self, response: Response) -> dict: if re.search(rf"{self.expected}", response.text): return self.response(status=True) @@ -251,18 +165,7 @@ class HTTPJsonContains(BaseCheck): config = "json-contains" expected_cls = ExpectedStringValue - async def run(self) -> dict: - task = self.task - try: - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: obj = response.json() status = True @@ -286,18 +189,7 @@ class HTTPJsonHas(BaseCheck): config = "json-has" expected_cls = ExpectedStringValue - async def run(self) -> dict: - task = self.task - try: - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: obj = response.json() status = True @@ -325,18 +217,7 @@ class HTTPJsonLike(BaseCheck): config = "json-like" expected_cls = ExpectedStringValue - async def run(self) -> dict: - task = self.task - try: - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: obj = response.json() status = True @@ -363,18 +244,7 @@ class HTTPJsonIs(BaseCheck): config = "json-is" expected_cls = ExpectedStringValue - async def run(self) -> dict: - task = self.task - try: - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=task.method, url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: obj = response.json() status = response.json() == json.loads(self.expected) @@ -392,18 +262,8 @@ class SSLCertificateExpiration(BaseCheck): config = "ssl-certificate-expiration" expected_cls = ExpectedStringValue - async def run(self): + async def run(self, response: Response) -> dict: """Returns the number of days in which the certificate will expire.""" - try: - response = await self.http_client.request( - method=self.task.method, url=self.task.url, timeout=60 - ) - except ReadError: - sleep(1) - response = await self.http_client.request( - method=self.task.method, url=self.task.url, timeout=60 - ) - network_stream = response.extensions["network_stream"] ssl_obj = network_stream.get_extra_info("ssl_object") cert = ssl_obj.getpeercert() diff --git a/tests/test_checks.py b/tests/test_checks.py index df58b04..28fe2fc 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -52,6 +52,9 @@ async def test_ssl_check_accepts_statuts( return_value=httpx.Response(http_status, extensions=httpx_extensions_ssl), ) async with httpx.AsyncClient() as client: - check = SSLCertificateExpiration(client, ssl_task) - check_response = await check.run() + check = SSLCertificateExpiration(ssl_task) + response = await client.request( + method=ssl_task.method, url=ssl_task.url, timeout=60 + ) + check_response = await check.run(response) assert check_response.status == "on-check"