diff --git a/argos/checks/__init__.py b/argos/checks/__init__.py index bb0ebd9..21c44ff 100644 --- a/argos/checks/__init__.py +++ b/argos/checks/__init__.py @@ -1,2 +1,2 @@ from argos.checks.checks import HTTPStatus, HTTPBodyContains, SSLCertificateExpiration -from argos.checks.base import get_check_by_name, CheckNotFound +from argos.checks.base import get_registered_checks, get_registered_check, CheckNotFound diff --git a/argos/checks/base.py b/argos/checks/base.py index 073094a..ff9908d 100644 --- a/argos/checks/base.py +++ b/argos/checks/base.py @@ -90,5 +90,9 @@ class BaseCheck: return Response.new(status, **kwargs) -def get_check_by_name(name): +def get_registered_check(name): return BaseCheck.get_registered_check(name) + + +def get_registered_checks(): + return BaseCheck.get_registered_checks() diff --git a/argos/schemas/config.py b/argos/schemas/config.py index 680f6e5..712cc02 100644 --- a/argos/schemas/config.py +++ b/argos/schemas/config.py @@ -10,7 +10,6 @@ from pydantic import BaseModel, Field, HttpUrl, validator from datetime import datetime -# from argos.checks import get_names as get_check_names # XXX Find a way to check without having cirular imports # This file contains the pydantic schemas. For the database models, check in argos.model. @@ -63,7 +62,19 @@ class WebsiteCheck(BaseModel): class WebsitePath(BaseModel): path: str - checks: List[Dict[str, str | dict | int]] + checks: List[Tuple[str, str | dict | int]] + + @validator("checks", each_item=True, pre=True) + def parse_checks(cls, value): + from argos.checks import get_registered_checks # To avoid circular imports + + available_names = get_registered_checks().keys() + + for name, expected in value.items(): + if name not in available_names: + msg = f"Check should be one of f{available_names}. ({name} given)" + raise ValueError(msg) + return (name, expected) class Website(BaseModel): diff --git a/argos/server/api.py b/argos/server/api.py index 9e4489d..2d9ec50 100644 --- a/argos/server/api.py +++ b/argos/server/api.py @@ -7,7 +7,7 @@ from argos.server import queries, models from argos.schemas import AgentResult, Task from argos.schemas.config import from_yaml as get_schemas_from_yaml from argos.server.database import SessionLocal, engine -from argos.checks import get_check_by_name +from argos.checks import get_registered_check from argos.logging import logger from typing import List diff --git a/argos/server/models.py b/argos/server/models.py index 717f54f..bed3ef2 100644 --- a/argos/server/models.py +++ b/argos/server/models.py @@ -16,7 +16,7 @@ from sqlalchemy.orm import mapped_column, relationship from datetime import datetime from argos.schemas import WebsiteCheck -from argos.checks import get_check_by_name +from argos.checks import get_registered_check class Base(DeclarativeBase): @@ -51,7 +51,7 @@ class Task(Base): def get_check(self): """Returns a check instance for this specific task""" - return get_check_by_name(self.check) + return get_registered_check(self.check) class Result(Base): diff --git a/argos/server/queries.py b/argos/server/queries.py index 9cc6f1c..4600439 100644 --- a/argos/server/queries.py +++ b/argos/server/queries.py @@ -49,25 +49,24 @@ async def update_from_config(db: Session, config: schemas.Config): domain = str(website.domain) for p in website.paths: url = urljoin(domain, str(p.path)) - for check in p.checks: - for check_key, expected in check.items(): - # Check the db for already existing tasks. - existing_task = db.query( - exists().where( - Task.url == url - and Task.check == check_key - and Task.expected == expected - ) - ).scalar() + for check_key, expected in p.checks: + # Check the db for already existing tasks. + existing_task = db.query( + exists().where( + Task.url == url + and Task.check == check_key + and Task.expected == expected + ) + ).scalar() - if not existing_task: - task = Task( - domain=domain, url=url, check=check_key, expected=expected - ) - logger.debug(f"Adding a new task in the db: {task}") - db.add(task) - else: - logger.debug( - f"Skipping db task creation for {url=}, {check_key=}, {expected=}." - ) + if not existing_task: + task = Task( + domain=domain, url=url, check=check_key, expected=expected + ) + logger.debug(f"Adding a new task in the db: {task}") + db.add(task) + else: + logger.debug( + f"Skipping db task creation for {url=}, {check_key=}, {expected=}." + ) db.commit() diff --git a/tests/test_schemas_config.py b/tests/test_schemas_config.py index 65c62c1..f7c6a43 100644 --- a/tests/test_schemas_config.py +++ b/tests/test_schemas_config.py @@ -1,5 +1,5 @@ import pytest -from argos.schemas.config import SSL +from argos.schemas.config import SSL, WebsitePath def test_ssl_duration_parsing(): @@ -14,3 +14,19 @@ def test_ssl_duration_parsing(): with pytest.raises(ValueError): erroneous_data = {"thresholds": [{"1d": "caution"}, {"1w": "danger"}]} SSL(**erroneous_data) + + +def test_path_parsing(): + data = {"path": "/", "checks": [{"body-contains": "youpi"}, {"status-is": 200}]} + path = WebsitePath(**data) + assert len(path.checks) == 2 + assert path.checks == [("body-contains", "youpi"), ("status-is", 200)] + + +def test_path_ensures_check_exists(): + with pytest.raises(ValueError): + erroneous_data = { + "path": "/", + "checks": [{"non-existing-key": "youpi"}, {"status-is": 200}], + } + WebsitePath(**erroneous_data)