diff --git a/argos/checks/__init__.py b/argos/checks/__init__.py index 93eb551..ec63d62 100644 --- a/argos/checks/__init__.py +++ b/argos/checks/__init__.py @@ -1,93 +1,2 @@ -import httpx -from argos.logging import logger - -from argos.schemas import Task -from pydantic import BaseModel, Field -from typing import Type - - -class BaseExpectedValue(BaseModel): - expected: str - - def get_converted(self): - return self.expected - - -class ExpectedIntValue(BaseExpectedValue): - def get_converted(self): - return int(self.expected) - - -class ExpectedStringValue(BaseExpectedValue): - pass - - -class BaseCheck: - config: str - expected_cls : Type[BaseExpectedValue] = None - - def response(self, passed, **kwargs): - status = "success" if passed else "failure" - return status, kwargs - - def __init__(self, client: httpx.AsyncClient, task: Task): - self.client = client - self.task = task - - @property - def expected(self): - return self.expected_cls(expected=self.task.expected).get_converted() - - -class HTTPStatusCheck(BaseCheck): - config = "status-is" - expected_cls = ExpectedIntValue - - async def run(self) -> dict: - # XXX Get the method from the task - task = self.task - response = await self.client.request(method="get", url=task.url) - logger.error(f"{response.status_code=}, {self.expected=}") - return self.response( - response.status_code == self.expected, - expected=self.expected, - retrieved=response.status_code - ) - - -class HTTPBodyContains(BaseCheck): - config = "body-contains" - expected_cls = ExpectedStringValue - - async def run(self) -> dict: - response = await self.client.request(method="get", url=self.task.url) - return self.response( - self.expected in response.text - ) - - -class SSLCertificateExpiration(BaseCheck): - config = "ssl-certificate-expiration" - expected_cls = ExpectedStringValue - - async def run(self): - return True - - -AVAILABLE_CHECKS = (HTTPStatusCheck, HTTPBodyContains, SSLCertificateExpiration) - - -class CheckNotFound(Exception): - pass - - -def get_names(checks=AVAILABLE_CHECKS): - return [c.config for c in checks] - - -def get_check_by_name(name, checks=AVAILABLE_CHECKS): - checks_dict = {c.config: c for c in checks} - check = checks_dict.get(name) - if not check: - raise CheckNotFound(name) - return check +from argos.checks.checks import HTTPStatus, HTTPBodyContains, SSLCertificateExpiration +from argos.checks.base import get_check_by_name, CheckNotFound \ No newline at end of file diff --git a/argos/checks/base.py b/argos/checks/base.py new file mode 100644 index 0000000..bc93f75 --- /dev/null +++ b/argos/checks/base.py @@ -0,0 +1,71 @@ +from pydantic import BaseModel, Field +from dataclasses import dataclass + +from typing import Type +import httpx + +from argos.schemas import Task + + +@dataclass +class Response: + status: str + context: dict + + +class BaseExpectedValue(BaseModel): + expected: str + + def get_converted(self): + return self.expected + + +class ExpectedIntValue(BaseExpectedValue): + def get_converted(self): + return int(self.expected) + + +class ExpectedStringValue(BaseExpectedValue): + pass + + +class CheckNotFound(Exception): + pass + + +class BaseCheck: + config: str + expected_cls: Type[BaseExpectedValue] = None + + _registry = [] + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls._registry.append(cls) + + @classmethod + def get_registered_checks(cls): + return {c.config: c for c in cls._registry} + + @classmethod + def get_registered_check(cls, name): + check = cls.get_registered_checks().get(name) + if not check: + raise CheckNotFound(name) + return check + + def response(self, passed, **kwargs) -> Response: + status = "success" if passed else "failure" + return Response(status, kwargs) + + @property + def expected(self): + return self.expected_cls(expected=self.task.expected).get_converted() + + def __init__(self, client: httpx.AsyncClient, task: Task): + self.client = client + self.task = task + + +def get_check_by_name(name): + return BaseCheck.get_registered_check(name) \ No newline at end of file diff --git a/argos/checks/checks.py b/argos/checks/checks.py new file mode 100644 index 0000000..c57aa3f --- /dev/null +++ b/argos/checks/checks.py @@ -0,0 +1,37 @@ +from argos.logging import logger +from argos.checks.base import BaseCheck, Response, ExpectedIntValue, ExpectedStringValue + + +class HTTPStatus(BaseCheck): + config = "status-is" + expected_cls = ExpectedIntValue + + async def run(self) -> dict: + # XXX Get the method from the task + task = self.task + response = await self.client.request(method="get", url=task.url) + logger.error(f"{response.status_code=}, {self.expected=}") + return self.response( + response.status_code == self.expected, + expected=self.expected, + retrieved=response.status_code + ) + + +class HTTPBodyContains(BaseCheck): + config = "body-contains" + expected_cls = ExpectedStringValue + + async def run(self) -> dict: + response = await self.client.request(method="get", url=self.task.url) + return self.response( + self.expected in response.text + ) + + +class SSLCertificateExpiration(BaseCheck): + config = "ssl-certificate-expiration" + expected_cls = ExpectedStringValue + + async def run(self): + return True diff --git a/argos/client/cli.py b/argos/client/cli.py index 06d2bac..aaccaca 100644 --- a/argos/client/cli.py +++ b/argos/client/cli.py @@ -16,7 +16,9 @@ async def complete_task(client: httpx.AsyncClient, task: dict) -> dict: task = Task(**task) check_class = get_check_by_name(task.check) check = check_class(client, task) - status, context = await check.run() + result = await check.run() + status = result.status + context = result.context except Exception as e: status = "error"