Refactor check config vlidation

- Refactored the `get_check_by_name` method to `get_registered_check` in the
  BaseCheck class and added a `get_registered_checks` method to get all the
  registered checks.
- Added a validation in the `WebsitePath` class to ensure that a check exists
  when parsing the configuration file.
- Updated the existing test for parsing SSL duration and added new tests to
  validate path parsing and check existence validation.
This commit is contained in:
Alexis Métaireau 2023-10-10 10:04:46 +02:00
parent 42ec15c6f4
commit 43e1767002
7 changed files with 58 additions and 28 deletions

View file

@ -1,2 +1,2 @@
from argos.checks.checks import HTTPStatus, HTTPBodyContains, SSLCertificateExpiration from argos.checks.checks import HTTPStatus, HTTPBodyContains, SSLCertificateExpiration
from argos.checks.base import get_check_by_name, CheckNotFound from argos.checks.base import get_registered_checks, get_registered_check, CheckNotFound

View file

@ -90,5 +90,9 @@ class BaseCheck:
return Response.new(status, **kwargs) return Response.new(status, **kwargs)
def get_check_by_name(name): def get_registered_check(name):
return BaseCheck.get_registered_check(name) return BaseCheck.get_registered_check(name)
def get_registered_checks():
return BaseCheck.get_registered_checks()

View file

@ -10,7 +10,6 @@ from pydantic import BaseModel, Field, HttpUrl, validator
from datetime import datetime from datetime import datetime
# from argos.checks import get_names as get_check_names
# XXX Find a way to check without having cirular imports # XXX Find a way to check without having cirular imports
# This file contains the pydantic schemas. For the database models, check in argos.model. # This file contains the pydantic schemas. For the database models, check in argos.model.
@ -63,7 +62,19 @@ class WebsiteCheck(BaseModel):
class WebsitePath(BaseModel): class WebsitePath(BaseModel):
path: str path: str
checks: List[Dict[str, str | dict | int]] checks: List[Tuple[str, str | dict | int]]
@validator("checks", each_item=True, pre=True)
def parse_checks(cls, value):
from argos.checks import get_registered_checks # To avoid circular imports
available_names = get_registered_checks().keys()
for name, expected in value.items():
if name not in available_names:
msg = f"Check should be one of f{available_names}. ({name} given)"
raise ValueError(msg)
return (name, expected)
class Website(BaseModel): class Website(BaseModel):

View file

@ -7,7 +7,7 @@ from argos.server import queries, models
from argos.schemas import AgentResult, Task from argos.schemas import AgentResult, Task
from argos.schemas.config import from_yaml as get_schemas_from_yaml from argos.schemas.config import from_yaml as get_schemas_from_yaml
from argos.server.database import SessionLocal, engine from argos.server.database import SessionLocal, engine
from argos.checks import get_check_by_name from argos.checks import get_registered_check
from argos.logging import logger from argos.logging import logger
from typing import List from typing import List

View file

@ -16,7 +16,7 @@ from sqlalchemy.orm import mapped_column, relationship
from datetime import datetime from datetime import datetime
from argos.schemas import WebsiteCheck from argos.schemas import WebsiteCheck
from argos.checks import get_check_by_name from argos.checks import get_registered_check
class Base(DeclarativeBase): class Base(DeclarativeBase):
@ -51,7 +51,7 @@ class Task(Base):
def get_check(self): def get_check(self):
"""Returns a check instance for this specific task""" """Returns a check instance for this specific task"""
return get_check_by_name(self.check) return get_registered_check(self.check)
class Result(Base): class Result(Base):

View file

@ -49,25 +49,24 @@ async def update_from_config(db: Session, config: schemas.Config):
domain = str(website.domain) domain = str(website.domain)
for p in website.paths: for p in website.paths:
url = urljoin(domain, str(p.path)) url = urljoin(domain, str(p.path))
for check in p.checks: for check_key, expected in p.checks:
for check_key, expected in check.items(): # Check the db for already existing tasks.
# Check the db for already existing tasks. existing_task = db.query(
existing_task = db.query( exists().where(
exists().where( Task.url == url
Task.url == url and Task.check == check_key
and Task.check == check_key and Task.expected == expected
and Task.expected == expected )
) ).scalar()
).scalar()
if not existing_task: if not existing_task:
task = Task( task = Task(
domain=domain, url=url, check=check_key, expected=expected domain=domain, url=url, check=check_key, expected=expected
) )
logger.debug(f"Adding a new task in the db: {task}") logger.debug(f"Adding a new task in the db: {task}")
db.add(task) db.add(task)
else: else:
logger.debug( logger.debug(
f"Skipping db task creation for {url=}, {check_key=}, {expected=}." f"Skipping db task creation for {url=}, {check_key=}, {expected=}."
) )
db.commit() db.commit()

View file

@ -1,5 +1,5 @@
import pytest import pytest
from argos.schemas.config import SSL from argos.schemas.config import SSL, WebsitePath
def test_ssl_duration_parsing(): def test_ssl_duration_parsing():
@ -14,3 +14,19 @@ def test_ssl_duration_parsing():
with pytest.raises(ValueError): with pytest.raises(ValueError):
erroneous_data = {"thresholds": [{"1d": "caution"}, {"1w": "danger"}]} erroneous_data = {"thresholds": [{"1d": "caution"}, {"1w": "danger"}]}
SSL(**erroneous_data) SSL(**erroneous_data)
def test_path_parsing():
data = {"path": "/", "checks": [{"body-contains": "youpi"}, {"status-is": 200}]}
path = WebsitePath(**data)
assert len(path.checks) == 2
assert path.checks == [("body-contains", "youpi"), ("status-is", 200)]
def test_path_ensures_check_exists():
with pytest.raises(ValueError):
erroneous_data = {
"path": "/",
"checks": [{"non-existing-key": "youpi"}, {"status-is": 200}],
}
WebsitePath(**erroneous_data)