From 43f8aabb2c55a5ec658e5412cd088b8dddc75286 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20M=C3=A9taireau?= Date: Wed, 11 Oct 2023 23:52:33 +0200 Subject: [PATCH] Refactor server codebase for testing. - Restructured server module to separate the application creation and configuration. - Moved code dealing with SQLAlchemy database setup and teardown to the main application file. - Moved functions related to configuration file loading to `argos.server.settings`. - Fixed SQLAlchemy expressions in `argos.server.queries`. - Implemented a more granular system of setting checks' schedule on the server. - Introduced frequency scheduling on a per-website basis in the YAML config. - Introduced Pytest fixtures for handling test database and authorized HTTP client in `tests/conftest.py`. - Included a first test for the API. - Implemented changes to models to accommodate changes to task scheduling. - Fixed errors concerning database concurrency arising from changes to the application setup. --- argos/checks/checks.py | 10 +---- argos/schemas/config.py | 96 ++++++++++++++++++---------------------- argos/server/__init__.py | 2 +- argos/server/alerting.py | 2 +- argos/server/api.py | 61 ++++++++----------------- argos/server/database.py | 11 ----- argos/server/main.py | 78 ++++++++++++++++++++++++++++++++ argos/server/models.py | 22 ++++++--- argos/server/queries.py | 42 ++++++++++++------ argos/server/settings.py | 58 ++++++++++++++++++++++++ config.yaml | 2 +- tests/config.yaml | 18 ++++++++ tests/conftest.py | 25 +++++++++++ tests/test_api.py | 28 ++++++++++++ tests/websites.yaml | 6 +++ 15 files changed, 322 insertions(+), 139 deletions(-) delete mode 100644 argos/server/database.py create mode 100644 argos/server/main.py create mode 100644 argos/server/settings.py create mode 100644 tests/config.yaml create mode 100644 tests/conftest.py create mode 100644 tests/test_api.py create mode 100644 tests/websites.yaml diff --git a/argos/checks/checks.py b/argos/checks/checks.py 
index e8bfa84..3eeb6e0 100644 --- a/argos/checks/checks.py +++ b/argos/checks/checks.py @@ -4,14 +4,8 @@ from datetime import datetime from OpenSSL import crypto -from argos.checks.base import ( - BaseCheck, - ExpectedIntValue, - ExpectedStringValue, - Response, - Status, - Severity, -) +from argos.checks.base import (BaseCheck, ExpectedIntValue, + ExpectedStringValue, Response, Severity, Status) from argos.logging import logger diff --git a/argos/schemas/config.py b/argos/schemas/config.py index f039740..2feed4b 100644 --- a/argos/schemas/config.py +++ b/argos/schemas/config.py @@ -1,31 +1,27 @@ -import os -from datetime import datetime -from enum import StrEnum -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import Dict, List, Literal, Optional, Tuple -import yaml -from pydantic import BaseModel, Field, HttpUrl, validator -from yamlinclude import YamlIncludeConstructor +from pydantic import BaseModel, HttpUrl, field_validator +from pydantic.functional_validators import BeforeValidator +from typing_extensions import Annotated from argos.schemas.utils import string_to_duration -# XXX Find a way to check without having cirular imports - -# This file contains the pydantic schemas. For the database models, check in argos.model. +# This file contains the pydantic schemas. +# For the database models, check in argos.server.models. Severity = Literal["warning", "error", "critical"] -class SSL(BaseModel): - thresholds: List[Tuple[int, Severity]] +def parse_threshold(value): + for duration_str, severity in value.items(): + days = string_to_duration(duration_str, "days") + # Return here because it's one-item dicts. + return (days, severity) - @validator("thresholds", each_item=True, pre=True) - def parse_threshold(cls, value): - for duration_str, severity in value.items(): - days = string_to_duration(duration_str, "days") - # Return here because it's one-item dicts. 
- return (days, severity) + +class SSL(BaseModel): + thresholds: List[Annotated[Tuple[int, Severity], BeforeValidator(parse_threshold)]] class WebsiteCheck(BaseModel): @@ -51,30 +47,41 @@ class WebsiteCheck(BaseModel): raise ValueError("Invalid type") +def parse_checks(value): + # To avoid circular imports + from argos.checks import get_registered_checks + + available_names = get_registered_checks().keys() + + for name, expected in value.items(): + if name not in available_names: + msg = f"Check should be one of f{available_names}. ({name} given)" + raise ValueError(msg) + return (name, expected) + + class WebsitePath(BaseModel): path: str - checks: List[Tuple[str, str | dict | int]] - - @validator("checks", each_item=True, pre=True) - def parse_checks(cls, value): - from argos.checks import get_registered_checks # To avoid circular imports - - available_names = get_registered_checks().keys() - - for name, expected in value.items(): - if name not in available_names: - msg = f"Check should be one of f{available_names}. 
({name} given)" - raise ValueError(msg) - return (name, expected) + checks: List[ + Annotated[ + Tuple[str, str | dict | int], + BeforeValidator(parse_checks), + ] + ] class Website(BaseModel): domain: HttpUrl + frequency: Optional[int] = None paths: List[WebsitePath] + @field_validator("frequency", mode="before") + def parse_frequency(cls, value): + if value: + return string_to_duration(value, "hours") + class Service(BaseModel): - port: int secrets: List[str] @@ -85,12 +92,12 @@ class Alert(BaseModel): class General(BaseModel): - frequency: str + frequency: int alerts: Alert - @validator("frequency", pre=True) + @field_validator("frequency", mode="before") def parse_frequency(cls, value): - return string_to_duration(value, "hours") + return string_to_duration(value, "minutes") class Config(BaseModel): @@ -98,22 +105,3 @@ class Config(BaseModel): service: Service ssl: SSL websites: List[Website] - - -def validate_config(config: dict): - return Config(**config) - - -def from_yaml(filename): - parsed = load_yaml(filename) - return validate_config(parsed) - - -def load_yaml(filename): - base_dir = os.path.dirname(filename) - YamlIncludeConstructor.add_to_loader_class( - loader_class=yaml.FullLoader, base_dir=base_dir - ) - - with open(filename, "r") as stream: - return yaml.load(stream, Loader=yaml.FullLoader) diff --git a/argos/server/__init__.py b/argos/server/__init__.py index ce71875..64c3804 100644 --- a/argos/server/__init__.py +++ b/argos/server/__init__.py @@ -1 +1 @@ -from argos.server.api import app +from argos.server.main import app # noqa: F401 diff --git a/argos/server/alerting.py b/argos/server/alerting.py index b2775c0..eef18a4 100644 --- a/argos/server/alerting.py +++ b/argos/server/alerting.py @@ -3,4 +3,4 @@ from argos.logging import logger def handle_alert(config, result, task, severity): msg = f"{result=}, {task=}, {severity=}" - logger.error(msg) + logger.error(f"Alerting stub: {msg}") diff --git a/argos/server/api.py b/argos/server/api.py index 
4d62bcc..5959ce0 100644 --- a/argos/server/api.py +++ b/argos/server/api.py @@ -1,67 +1,43 @@ -import sys -from typing import Annotated, List, Optional +from typing import List -from fastapi import Depends, FastAPI, Header, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from pydantic import BaseModel, ValidationError from sqlalchemy.orm import Session -from argos.checks import get_registered_check from argos.logging import logger from argos.schemas import AgentResult, Task -from argos.schemas.config import from_yaml as get_schemas_from_yaml -from argos.server import models, queries -from argos.server.database import SessionLocal, engine +from argos.server import queries from argos.server.alerting import handle_alert -models.Base.metadata.create_all(bind=engine) - -app = FastAPI() +api = APIRouter() auth_scheme = HTTPBearer() -def get_db(): - db = SessionLocal() +def get_db(request: Request): + db = request.app.state.SessionLocal() try: yield db finally: db.close() -async def verify_token(token: HTTPAuthorizationCredentials = Depends(auth_scheme)): - if token.credentials not in app.config.service.secrets: +async def verify_token( + request: Request, token: HTTPAuthorizationCredentials = Depends(auth_scheme) +): + if token.credentials not in request.app.state.config.service.secrets: raise HTTPException(status_code=401, detail="Unauthorized") return token -@app.on_event("startup") -async def read_config_and_populate_db(): - # XXX Get filename from environment. 
- try: - config = get_schemas_from_yaml("config.yaml") - app.config = config - except ValidationError as e: - logger.error(f"Errors where found while reading configuration:") - for error in e.errors(): - logger.error(f"{error['loc']} is {error['type']}") - sys.exit(1) - - db = SessionLocal() - try: - await queries.update_from_config(db, config) - finally: - db.close() - - # XXX Get the default limit from the config -@app.get("/tasks", response_model=list[Task], dependencies=[Depends(verify_token)]) +@api.get("/tasks", response_model=list[Task], dependencies=[Depends(verify_token)]) async def read_tasks(request: Request, db: Session = Depends(get_db), limit: int = 20): # XXX Let the agents specifify their names (and use hostnames) tasks = await queries.list_tasks(db, agent_id=request.client.host, limit=limit) return tasks -@app.post("/results", status_code=201, dependencies=[Depends(verify_token)]) +@api.post("/results", status_code=201, dependencies=[Depends(verify_token)]) async def create_result(results: List[AgentResult], db: Session = Depends(get_db)): """Get the results from the agents and store them locally. 
@@ -82,20 +58,19 @@ async def create_result(results: List[AgentResult], db: Session = Depends(get_db else: check = task.get_check() status, severity = await check.finalize( - app.config, result, **result.context + api.config, result, **result.context ) - result.severity = severity - result.status = status - # Set the selection status to None - task.selected_by = None - handle_alert(app.config, result, task, severity) + result.set_status(status, severity) + task.set_times_and_deselect() + + handle_alert(api.config, result, task, severity) db_results.append(result) db.commit() return {"result_ids": [r.id for r in db_results]} -@app.get("/stats", dependencies=[Depends(verify_token)]) +@api.get("/stats", dependencies=[Depends(verify_token)]) async def get_stats(db: Session = Depends(get_db)): return { "upcoming_tasks_count": await queries.count_tasks(db, selected=False), diff --git a/argos/server/database.py b/argos/server/database.py deleted file mode 100644 index 715c93e..0000000 --- a/argos/server/database.py +++ /dev/null @@ -1,11 +0,0 @@ -from sqlalchemy import create_engine -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.orm import DeclarativeBase, sessionmaker - -SQLALCHEMY_DATABASE_URL = "sqlite:////tmp/argos.db" -# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db" - -engine = create_engine( - SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} -) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) diff --git a/argos/server/main.py b/argos/server/main.py new file mode 100644 index 0000000..7a57f4e --- /dev/null +++ b/argos/server/main.py @@ -0,0 +1,78 @@ +import sys + +from fastapi import FastAPI +from pydantic import ValidationError +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from argos.logging import logger +from argos.server import models, queries +from argos.server.api import api as api_router +from 
argos.server.settings import get_app_settings, read_yaml_config + + +def get_application() -> FastAPI: + settings = get_app_settings() + app = FastAPI() + + config = read_config(app, settings) + app.state.config = config + + app.add_event_handler( + "startup", + create_start_app_handler(app, settings), + ) + app.add_event_handler( + "shutdown", + create_stop_app_handler(app), + ) + app.include_router(api_router) + return app + + +def create_start_app_handler(app, settings): + async def read_config_and_populate_db(): + setup_database(app, settings) + + db = await connect_to_db(app, settings) + await queries.update_from_config(db, app.state.config) + + return read_config_and_populate_db + + +async def connect_to_db(app, settings): + app.state.db = app.state.SessionLocal() + return app.state.db + + +def create_stop_app_handler(app): + async def stop_app(): + app.state.db.close() + + return stop_app + + +def read_config(app, settings): + try: + config = read_yaml_config(settings.yaml_file) + app.state.config = config + return config + except ValidationError as e: + logger.error("Errors where found while reading configuration:") + for error in e.errors(): + logger.error(f"{error['loc']} is {error['type']}") + sys.exit(1) + + +def setup_database(app, settings): + engine = create_engine( + settings.database_url, connect_args={"check_same_thread": False} + ) + app.state.SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=engine + ) + app.state.engine = engine + models.Base.metadata.create_all(bind=engine) + + +app = get_application() diff --git a/argos/server/models.py b/argos/server/models.py index 24bd7ff..b83d0cc 100644 --- a/argos/server/models.py +++ b/argos/server/models.py @@ -1,18 +1,12 @@ -from datetime import datetime +from datetime import datetime, timedelta from typing import List, Literal from sqlalchemy import ( JSON, - Boolean, - Column, - DateTime, Enum, ForeignKey, - Integer, - String, ) from sqlalchemy.orm import DeclarativeBase, 
Mapped, mapped_column, relationship -from sqlalchemy_utils import ChoiceType from argos.checks import get_registered_check from argos.schemas import WebsiteCheck @@ -38,10 +32,13 @@ class Task(Base): domain: Mapped[str] = mapped_column() check: Mapped[str] = mapped_column() expected: Mapped[str] = mapped_column() + frequency: Mapped[int] = mapped_column() # Orchestration-related selected_by: Mapped[str] = mapped_column(nullable=True) selected_at: Mapped[datetime] = mapped_column(nullable=True) + completed_at: Mapped[datetime] = mapped_column(nullable=True) + next_run: Mapped[datetime] = mapped_column(nullable=True) results: Mapped[List["Result"]] = relationship(back_populates="task") @@ -52,6 +49,13 @@ class Task(Base): """Returns a check instance for this specific task""" return get_registered_check(self.check) + def set_times_and_deselect(self): + self.selected_by = None + + now = datetime.now() + self.completed_at = now + self.next_run = now + timedelta(hours=self.frequency) + class Result(Base): __tablename__ = "results" @@ -68,5 +72,9 @@ class Result(Base): ) context: Mapped[dict] = mapped_column() + def set_status(self, status, severity): + self.severity = severity + self.status = status + def __str__(self): return f"DB Result {self.id} - {self.status} - {self.context}" diff --git a/argos/server/queries.py b/argos/server/queries.py index 0033cbf..39e17c0 100644 --- a/argos/server/queries.py +++ b/argos/server/queries.py @@ -1,7 +1,6 @@ from datetime import datetime from urllib.parse import urljoin -from sqlalchemy import exists from sqlalchemy.orm import Session from argos import schemas @@ -11,9 +10,17 @@ from argos.server.models import Result, Task async def list_tasks(db: Session, agent_id: str, limit: int = 100): """List tasks and mark them as selected""" - tasks = db.query(Task).where(Task.selected_by == None).limit(limit).all() - now = datetime.now() + tasks = ( + db.query(Task) + .filter( + Task.selected_by == None, # noqa: E711 + ((Task.next_run >= 
datetime.now()) | (Task.next_run == None)), # noqa: E711 + ) + .limit(limit) + .all() + ) + now = datetime.now() for task in tasks: task.selected_at = now task.selected_by = agent_id @@ -53,26 +60,35 @@ async def count_results(db: Session): async def update_from_config(db: Session, config: schemas.Config): for website in config.websites: domain = str(website.domain) + frequency = website.frequency or config.general.frequency + for p in website.paths: url = urljoin(domain, str(p.path)) for check_key, expected in p.checks: # Check the db for already existing tasks. - existing_task = db.query( - exists().where( - Task.url == url - and Task.check == check_key - and Task.expected == expected + existing_task = ( + db.query(Task) + .filter( + Task.url == url, + Task.check == check_key, + Task.expected == expected, ) - ).scalar() + .scalar() + ) + if existing_task and frequency != existing_task.frequency: + existing_task.frequency = frequency if not existing_task: task = Task( - domain=domain, url=url, check=check_key, expected=expected + domain=domain, + url=url, + check=check_key, + expected=expected, + frequency=frequency, ) logger.debug(f"Adding a new task in the db: {task}") db.add(task) else: - logger.debug( - f"Skipping db task creation for {url=}, {check_key=}, {expected=}." - ) + msg = f"Skipping db task creation for {url=}, {check_key=}, {expected=}, {frequency=}." 
+ logger.debug(msg) db.commit() diff --git a/argos/server/settings.py b/argos/server/settings.py new file mode 100644 index 0000000..f4e64b7 --- /dev/null +++ b/argos/server/settings.py @@ -0,0 +1,58 @@ +import os +from functools import lru_cache +from os import environ + +import yaml +from pydantic_settings import BaseSettings +from yamlinclude import YamlIncludeConstructor + +from argos.schemas.config import Config + + +class DefaultSettings(BaseSettings): + app_env: str = "prod" + database_url: str = "" + yaml_file: str = "" + + +class DevSettings(DefaultSettings): + database_url: str = "sqlite:////tmp/argos.db" + yaml_file: str = "config.yaml" + + +class TestSettings(DefaultSettings): + database_url: str = "sqlite:////tmp/test-argos.db" + yaml_file: str = "tests/config.yaml" + + +class ProdSettings(DefaultSettings): + pass + + +environments = { + "dev": DevSettings, + "prod": ProdSettings, + "test": TestSettings, +} + + +@lru_cache() +def get_app_settings() -> DefaultSettings: + app_env = environ.get("APP_ENV", "dev") + settings = environments.get(app_env) + return settings() + + +def read_yaml_config(filename): + parsed = _load_yaml(filename) + return Config(**parsed) + + +def _load_yaml(filename): + base_dir = os.path.dirname(filename) + YamlIncludeConstructor.add_to_loader_class( + loader_class=yaml.FullLoader, base_dir=base_dir + ) + + with open(filename, "r") as stream: + return yaml.load(stream, Loader=yaml.FullLoader) diff --git a/config.yaml b/config.yaml index 4118503..d4a9a28 100644 --- a/config.yaml +++ b/config.yaml @@ -1,5 +1,5 @@ general: - frequency: 4h # Run checks every 4 hours. + frequency: "1h" # Run checks every hour. 
alerts: error: - local diff --git a/tests/config.yaml b/tests/config.yaml new file mode 100644 index 0000000..ab6c3c6 --- /dev/null +++ b/tests/config.yaml @@ -0,0 +1,18 @@ +general: + frequency: "1m" + alerts: + error: + - local + warning: + - local + alert: + - local +service: + secrets: + - "O4kt8Max9/k0EmHaEJ0CGGYbBNFmK8kOZNIoUk3Kjwc" + - "x1T1VZR51pxrv5pQUyzooMG4pMUvHNMhA5y/3cUsYVs=" +ssl: + thresholds: + - "1d": critical + "5d": warning +websites: !include websites.yaml \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..7a99060 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,25 @@ +from os import environ + +import pytest +from fastapi import FastAPI +from httpx import AsyncClient + +environ["APP_ENV"] = "test" + + +@pytest.fixture +def app() -> FastAPI: + from argos.server.main import get_application # local import for testing purpose + + return get_application() + + +@pytest.fixture +def authorized_client( + client: AsyncClient, token: str, authorization_prefix: str +) -> AsyncClient: + client.headers = { + "Authorization": f"Bearer {token}", + **client.headers, + } + return client diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..e26c624 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,28 @@ +import pytest +from fastapi.testclient import TestClient + +from argos.server import app, models + + +@pytest.fixture() +def test_db(): + models.Base.metadata.create_all(bind=app.engine) + yield + models.Base.metadata.drop_all(bind=app.engine) + + +def test_read_tasks_requires_auth(): + with TestClient(app) as client: + response = client.get("/tasks") + assert response.status_code == 403 + + +def test_read_tasks_returns_tasks(): + with TestClient(app) as client: + token = app.state.config.service.secrets[0] + client.headers = {"Authorization": f"Bearer {token}"} + response = client.get("/tasks") + assert response.status_code == 200 + + # We should have only two tasks + 
assert len(response.json()) == 2 diff --git a/tests/websites.yaml b/tests/websites.yaml new file mode 100644 index 0000000..f2d50dc --- /dev/null +++ b/tests/websites.yaml @@ -0,0 +1,6 @@ +- domain: "https://mypads.framapad.org" + paths: + - path: "/mypads/" + checks: + - status-is: 200 + - body-contains: '
' \ No newline at end of file