mirror of
https://framagit.org/framasoft/framaspace/argos.git
synced 2025-04-28 18:02:41 +02:00
Refactor argos check and cli modules
- Added a simple way to have a registry for the checks. - Use a Result dataclass to send back results from the workers
This commit is contained in:
parent
835ee50c1f
commit
ff4588bc39
4 changed files with 113 additions and 94 deletions
|
@ -1,93 +1,2 @@
|
||||||
import httpx
|
from argos.checks.checks import HTTPStatus, HTTPBodyContains, SSLCertificateExpiration
|
||||||
from argos.logging import logger
|
from argos.checks.base import get_check_by_name, CheckNotFound
|
||||||
|
|
||||||
from argos.schemas import Task
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
|
|
||||||
class BaseExpectedValue(BaseModel):
|
|
||||||
expected: str
|
|
||||||
|
|
||||||
def get_converted(self):
|
|
||||||
return self.expected
|
|
||||||
|
|
||||||
|
|
||||||
class ExpectedIntValue(BaseExpectedValue):
|
|
||||||
def get_converted(self):
|
|
||||||
return int(self.expected)
|
|
||||||
|
|
||||||
|
|
||||||
class ExpectedStringValue(BaseExpectedValue):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BaseCheck:
|
|
||||||
config: str
|
|
||||||
expected_cls : Type[BaseExpectedValue] = None
|
|
||||||
|
|
||||||
def response(self, passed, **kwargs):
|
|
||||||
status = "success" if passed else "failure"
|
|
||||||
return status, kwargs
|
|
||||||
|
|
||||||
def __init__(self, client: httpx.AsyncClient, task: Task):
|
|
||||||
self.client = client
|
|
||||||
self.task = task
|
|
||||||
|
|
||||||
@property
|
|
||||||
def expected(self):
|
|
||||||
return self.expected_cls(expected=self.task.expected).get_converted()
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPStatusCheck(BaseCheck):
|
|
||||||
config = "status-is"
|
|
||||||
expected_cls = ExpectedIntValue
|
|
||||||
|
|
||||||
async def run(self) -> dict:
|
|
||||||
# XXX Get the method from the task
|
|
||||||
task = self.task
|
|
||||||
response = await self.client.request(method="get", url=task.url)
|
|
||||||
logger.error(f"{response.status_code=}, {self.expected=}")
|
|
||||||
return self.response(
|
|
||||||
response.status_code == self.expected,
|
|
||||||
expected=self.expected,
|
|
||||||
retrieved=response.status_code
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPBodyContains(BaseCheck):
|
|
||||||
config = "body-contains"
|
|
||||||
expected_cls = ExpectedStringValue
|
|
||||||
|
|
||||||
async def run(self) -> dict:
|
|
||||||
response = await self.client.request(method="get", url=self.task.url)
|
|
||||||
return self.response(
|
|
||||||
self.expected in response.text
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SSLCertificateExpiration(BaseCheck):
|
|
||||||
config = "ssl-certificate-expiration"
|
|
||||||
expected_cls = ExpectedStringValue
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
AVAILABLE_CHECKS = (HTTPStatusCheck, HTTPBodyContains, SSLCertificateExpiration)
|
|
||||||
|
|
||||||
|
|
||||||
class CheckNotFound(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def get_names(checks=AVAILABLE_CHECKS):
|
|
||||||
return [c.config for c in checks]
|
|
||||||
|
|
||||||
|
|
||||||
def get_check_by_name(name, checks=AVAILABLE_CHECKS):
|
|
||||||
checks_dict = {c.config: c for c in checks}
|
|
||||||
check = checks_dict.get(name)
|
|
||||||
if not check:
|
|
||||||
raise CheckNotFound(name)
|
|
||||||
return check
|
|
71
argos/checks/base.py
Normal file
71
argos/checks/base.py
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from typing import Type
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from argos.schemas import Task
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Response:
|
||||||
|
status: str
|
||||||
|
context: dict
|
||||||
|
|
||||||
|
|
||||||
|
class BaseExpectedValue(BaseModel):
|
||||||
|
expected: str
|
||||||
|
|
||||||
|
def get_converted(self):
|
||||||
|
return self.expected
|
||||||
|
|
||||||
|
|
||||||
|
class ExpectedIntValue(BaseExpectedValue):
|
||||||
|
def get_converted(self):
|
||||||
|
return int(self.expected)
|
||||||
|
|
||||||
|
|
||||||
|
class ExpectedStringValue(BaseExpectedValue):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CheckNotFound(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCheck:
|
||||||
|
config: str
|
||||||
|
expected_cls: Type[BaseExpectedValue] = None
|
||||||
|
|
||||||
|
_registry = []
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
cls._registry.append(cls)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_registered_checks(cls):
|
||||||
|
return {c.config: c for c in cls._registry}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_registered_check(cls, name):
|
||||||
|
check = cls.get_registered_checks().get(name)
|
||||||
|
if not check:
|
||||||
|
raise CheckNotFound(name)
|
||||||
|
return check
|
||||||
|
|
||||||
|
def response(self, passed, **kwargs) -> Response:
|
||||||
|
status = "success" if passed else "failure"
|
||||||
|
return Response(status, kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected(self):
|
||||||
|
return self.expected_cls(expected=self.task.expected).get_converted()
|
||||||
|
|
||||||
|
def __init__(self, client: httpx.AsyncClient, task: Task):
|
||||||
|
self.client = client
|
||||||
|
self.task = task
|
||||||
|
|
||||||
|
|
||||||
|
def get_check_by_name(name):
|
||||||
|
return BaseCheck.get_registered_check(name)
|
37
argos/checks/checks.py
Normal file
37
argos/checks/checks.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
from argos.logging import logger
|
||||||
|
from argos.checks.base import BaseCheck, Response, ExpectedIntValue, ExpectedStringValue
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPStatus(BaseCheck):
|
||||||
|
config = "status-is"
|
||||||
|
expected_cls = ExpectedIntValue
|
||||||
|
|
||||||
|
async def run(self) -> dict:
|
||||||
|
# XXX Get the method from the task
|
||||||
|
task = self.task
|
||||||
|
response = await self.client.request(method="get", url=task.url)
|
||||||
|
logger.error(f"{response.status_code=}, {self.expected=}")
|
||||||
|
return self.response(
|
||||||
|
response.status_code == self.expected,
|
||||||
|
expected=self.expected,
|
||||||
|
retrieved=response.status_code
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPBodyContains(BaseCheck):
|
||||||
|
config = "body-contains"
|
||||||
|
expected_cls = ExpectedStringValue
|
||||||
|
|
||||||
|
async def run(self) -> dict:
|
||||||
|
response = await self.client.request(method="get", url=self.task.url)
|
||||||
|
return self.response(
|
||||||
|
self.expected in response.text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SSLCertificateExpiration(BaseCheck):
|
||||||
|
config = "ssl-certificate-expiration"
|
||||||
|
expected_cls = ExpectedStringValue
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
return True
|
|
@ -16,7 +16,9 @@ async def complete_task(client: httpx.AsyncClient, task: dict) -> dict:
|
||||||
task = Task(**task)
|
task = Task(**task)
|
||||||
check_class = get_check_by_name(task.check)
|
check_class = get_check_by_name(task.check)
|
||||||
check = check_class(client, task)
|
check = check_class(client, task)
|
||||||
status, context = await check.run()
|
result = await check.run()
|
||||||
|
status = result.status
|
||||||
|
context = result.context
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
status = "error"
|
status = "error"
|
||||||
|
|
Loading…
Reference in a new issue