Refactor argos check and cli modules

- Added a simple way to have a registry for the checks.
- Use a Result dataclass to send back results from the workers
This commit is contained in:
Alexis Métaireau 2023-10-05 11:36:36 +02:00
parent 835ee50c1f
commit ff4588bc39
4 changed files with 113 additions and 94 deletions

View file

@ -1,93 +1,2 @@
import httpx
from argos.logging import logger
from argos.schemas import Task
from pydantic import BaseModel, Field
from typing import Type
class BaseExpectedValue(BaseModel):
expected: str
def get_converted(self):
return self.expected
class ExpectedIntValue(BaseExpectedValue):
def get_converted(self):
return int(self.expected)
class ExpectedStringValue(BaseExpectedValue):
pass
class BaseCheck:
config: str
expected_cls : Type[BaseExpectedValue] = None
def response(self, passed, **kwargs):
status = "success" if passed else "failure"
return status, kwargs
def __init__(self, client: httpx.AsyncClient, task: Task):
self.client = client
self.task = task
@property
def expected(self):
return self.expected_cls(expected=self.task.expected).get_converted()
class HTTPStatusCheck(BaseCheck):
config = "status-is"
expected_cls = ExpectedIntValue
async def run(self) -> dict:
# XXX Get the method from the task
task = self.task
response = await self.client.request(method="get", url=task.url)
logger.error(f"{response.status_code=}, {self.expected=}")
return self.response(
response.status_code == self.expected,
expected=self.expected,
retrieved=response.status_code
)
class HTTPBodyContains(BaseCheck):
config = "body-contains"
expected_cls = ExpectedStringValue
async def run(self) -> dict:
response = await self.client.request(method="get", url=self.task.url)
return self.response(
self.expected in response.text
)
class SSLCertificateExpiration(BaseCheck):
config = "ssl-certificate-expiration"
expected_cls = ExpectedStringValue
async def run(self):
return True
AVAILABLE_CHECKS = (HTTPStatusCheck, HTTPBodyContains, SSLCertificateExpiration)
class CheckNotFound(Exception):
pass
def get_names(checks=AVAILABLE_CHECKS):
return [c.config for c in checks]
def get_check_by_name(name, checks=AVAILABLE_CHECKS):
checks_dict = {c.config: c for c in checks}
check = checks_dict.get(name)
if not check:
raise CheckNotFound(name)
return check
from argos.checks.checks import HTTPStatus, HTTPBodyContains, SSLCertificateExpiration
from argos.checks.base import get_check_by_name, CheckNotFound

71
argos/checks/base.py Normal file
View file

@ -0,0 +1,71 @@
from pydantic import BaseModel, Field
from dataclasses import dataclass
from typing import Type
import httpx
from argos.schemas import Task
@dataclass
class Response:
status: str
context: dict
class BaseExpectedValue(BaseModel):
expected: str
def get_converted(self):
return self.expected
class ExpectedIntValue(BaseExpectedValue):
def get_converted(self):
return int(self.expected)
class ExpectedStringValue(BaseExpectedValue):
pass
class CheckNotFound(Exception):
pass
class BaseCheck:
config: str
expected_cls: Type[BaseExpectedValue] = None
_registry = []
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._registry.append(cls)
@classmethod
def get_registered_checks(cls):
return {c.config: c for c in cls._registry}
@classmethod
def get_registered_check(cls, name):
check = cls.get_registered_checks().get(name)
if not check:
raise CheckNotFound(name)
return check
def response(self, passed, **kwargs) -> Response:
status = "success" if passed else "failure"
return Response(status, kwargs)
@property
def expected(self):
return self.expected_cls(expected=self.task.expected).get_converted()
def __init__(self, client: httpx.AsyncClient, task: Task):
self.client = client
self.task = task
def get_check_by_name(name):
return BaseCheck.get_registered_check(name)

37
argos/checks/checks.py Normal file
View file

@ -0,0 +1,37 @@
from argos.logging import logger
from argos.checks.base import BaseCheck, Response, ExpectedIntValue, ExpectedStringValue
class HTTPStatus(BaseCheck):
config = "status-is"
expected_cls = ExpectedIntValue
async def run(self) -> dict:
# XXX Get the method from the task
task = self.task
response = await self.client.request(method="get", url=task.url)
logger.error(f"{response.status_code=}, {self.expected=}")
return self.response(
response.status_code == self.expected,
expected=self.expected,
retrieved=response.status_code
)
class HTTPBodyContains(BaseCheck):
config = "body-contains"
expected_cls = ExpectedStringValue
async def run(self) -> dict:
response = await self.client.request(method="get", url=self.task.url)
return self.response(
self.expected in response.text
)
class SSLCertificateExpiration(BaseCheck):
config = "ssl-certificate-expiration"
expected_cls = ExpectedStringValue
async def run(self):
return True

View file

@ -16,7 +16,9 @@ async def complete_task(client: httpx.AsyncClient, task: dict) -> dict:
task = Task(**task)
check_class = get_check_by_name(task.check)
check = check_class(client, task)
status, context = await check.run()
result = await check.run()
status = result.status
context = result.context
except Exception as e:
status = "error"