Refactor server codebase for testing.

- Restructured server module to separate the application creation and configuration.
- Moved code dealing with SQLAlchemy database
  setup and teardown to the main application file.
- Moved functions related to configuration file loading to `argos.server.settings`.
- Fixed SQLAlchemy expressions in `argos.server.queries`.
- Implemented a more granular system of setting checks' schedule on the server.
- Introduced frequency scheduling on per-website basis in the YAML config.
- Introduced Pytest fixtures for handling test database and authorized HTTP client in `tests/conftest.py`.
- Included a first test for the API.
- Implemented changes to models to accommodate changes to task scheduling.
- Fixed errors concerning database concurrency arising from changes to the application setup.
This commit is contained in:
Alexis Métaireau 2023-10-11 23:52:33 +02:00
parent e540eee9b3
commit 43f8aabb2c
15 changed files with 322 additions and 139 deletions

View file

@ -4,14 +4,8 @@ from datetime import datetime
from OpenSSL import crypto from OpenSSL import crypto
from argos.checks.base import ( from argos.checks.base import (BaseCheck, ExpectedIntValue,
BaseCheck, ExpectedStringValue, Response, Severity, Status)
ExpectedIntValue,
ExpectedStringValue,
Response,
Status,
Severity,
)
from argos.logging import logger from argos.logging import logger

View file

@ -1,33 +1,29 @@
import os from typing import Dict, List, Literal, Optional, Tuple
from datetime import datetime
from enum import StrEnum
from typing import Dict, List, Literal, Optional, Tuple, Union
import yaml from pydantic import BaseModel, HttpUrl, field_validator
from pydantic import BaseModel, Field, HttpUrl, validator from pydantic.functional_validators import BeforeValidator
from yamlinclude import YamlIncludeConstructor from typing_extensions import Annotated
from argos.schemas.utils import string_to_duration from argos.schemas.utils import string_to_duration
# XXX Find a way to check without having cirular imports # This file contains the pydantic schemas.
# For the database models, check in argos.server.models.
# This file contains the pydantic schemas. For the database models, check in argos.model.
Severity = Literal["warning", "error", "critical"] Severity = Literal["warning", "error", "critical"]
class SSL(BaseModel): def parse_threshold(value):
thresholds: List[Tuple[int, Severity]]
@validator("thresholds", each_item=True, pre=True)
def parse_threshold(cls, value):
for duration_str, severity in value.items(): for duration_str, severity in value.items():
days = string_to_duration(duration_str, "days") days = string_to_duration(duration_str, "days")
# Return here because it's one-item dicts. # Return here because it's one-item dicts.
return (days, severity) return (days, severity)
class SSL(BaseModel):
thresholds: List[Annotated[Tuple[int, Severity], BeforeValidator(parse_threshold)]]
class WebsiteCheck(BaseModel): class WebsiteCheck(BaseModel):
key: str key: str
value: str | List[str] | Dict[str, str] value: str | List[str] | Dict[str, str]
@ -51,13 +47,9 @@ class WebsiteCheck(BaseModel):
raise ValueError("Invalid type") raise ValueError("Invalid type")
class WebsitePath(BaseModel): def parse_checks(value):
path: str # To avoid circular imports
checks: List[Tuple[str, str | dict | int]] from argos.checks import get_registered_checks
@validator("checks", each_item=True, pre=True)
def parse_checks(cls, value):
from argos.checks import get_registered_checks # To avoid circular imports
available_names = get_registered_checks().keys() available_names = get_registered_checks().keys()
@ -68,13 +60,28 @@ class WebsitePath(BaseModel):
return (name, expected) return (name, expected)
class WebsitePath(BaseModel):
path: str
checks: List[
Annotated[
Tuple[str, str | dict | int],
BeforeValidator(parse_checks),
]
]
class Website(BaseModel): class Website(BaseModel):
domain: HttpUrl domain: HttpUrl
frequency: Optional[int] = None
paths: List[WebsitePath] paths: List[WebsitePath]
@field_validator("frequency", mode="before")
def parse_frequency(cls, value):
if value:
return string_to_duration(value, "hours")
class Service(BaseModel): class Service(BaseModel):
port: int
secrets: List[str] secrets: List[str]
@ -85,12 +92,12 @@ class Alert(BaseModel):
class General(BaseModel): class General(BaseModel):
frequency: str frequency: int
alerts: Alert alerts: Alert
@validator("frequency", pre=True) @field_validator("frequency", mode="before")
def parse_frequency(cls, value): def parse_frequency(cls, value):
return string_to_duration(value, "hours") return string_to_duration(value, "minutes")
class Config(BaseModel): class Config(BaseModel):
@ -98,22 +105,3 @@ class Config(BaseModel):
service: Service service: Service
ssl: SSL ssl: SSL
websites: List[Website] websites: List[Website]
def validate_config(config: dict):
return Config(**config)
def from_yaml(filename):
parsed = load_yaml(filename)
return validate_config(parsed)
def load_yaml(filename):
base_dir = os.path.dirname(filename)
YamlIncludeConstructor.add_to_loader_class(
loader_class=yaml.FullLoader, base_dir=base_dir
)
with open(filename, "r") as stream:
return yaml.load(stream, Loader=yaml.FullLoader)

View file

@ -1 +1 @@
from argos.server.api import app from argos.server.main import app # noqa: F401

View file

@ -3,4 +3,4 @@ from argos.logging import logger
def handle_alert(config, result, task, severity): def handle_alert(config, result, task, severity):
msg = f"{result=}, {task=}, {severity=}" msg = f"{result=}, {task=}, {severity=}"
logger.error(msg) logger.error(f"Alerting stub: {msg}")

View file

@ -1,67 +1,43 @@
import sys from typing import List
from typing import Annotated, List, Optional
from fastapi import Depends, FastAPI, Header, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, ValidationError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from argos.checks import get_registered_check
from argos.logging import logger from argos.logging import logger
from argos.schemas import AgentResult, Task from argos.schemas import AgentResult, Task
from argos.schemas.config import from_yaml as get_schemas_from_yaml from argos.server import queries
from argos.server import models, queries
from argos.server.database import SessionLocal, engine
from argos.server.alerting import handle_alert from argos.server.alerting import handle_alert
models.Base.metadata.create_all(bind=engine) api = APIRouter()
app = FastAPI()
auth_scheme = HTTPBearer() auth_scheme = HTTPBearer()
def get_db(): def get_db(request: Request):
db = SessionLocal() db = request.app.state.SessionLocal()
try: try:
yield db yield db
finally: finally:
db.close() db.close()
async def verify_token(token: HTTPAuthorizationCredentials = Depends(auth_scheme)): async def verify_token(
if token.credentials not in app.config.service.secrets: request: Request, token: HTTPAuthorizationCredentials = Depends(auth_scheme)
):
if token.credentials not in request.app.state.config.service.secrets:
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
return token return token
@app.on_event("startup")
async def read_config_and_populate_db():
# XXX Get filename from environment.
try:
config = get_schemas_from_yaml("config.yaml")
app.config = config
except ValidationError as e:
logger.error(f"Errors where found while reading configuration:")
for error in e.errors():
logger.error(f"{error['loc']} is {error['type']}")
sys.exit(1)
db = SessionLocal()
try:
await queries.update_from_config(db, config)
finally:
db.close()
# XXX Get the default limit from the config # XXX Get the default limit from the config
@app.get("/tasks", response_model=list[Task], dependencies=[Depends(verify_token)]) @api.get("/tasks", response_model=list[Task], dependencies=[Depends(verify_token)])
async def read_tasks(request: Request, db: Session = Depends(get_db), limit: int = 20): async def read_tasks(request: Request, db: Session = Depends(get_db), limit: int = 20):
# XXX Let the agents specify their names (and use hostnames) # XXX Let the agents specify their names (and use hostnames)
tasks = await queries.list_tasks(db, agent_id=request.client.host, limit=limit) tasks = await queries.list_tasks(db, agent_id=request.client.host, limit=limit)
return tasks return tasks
@app.post("/results", status_code=201, dependencies=[Depends(verify_token)]) @api.post("/results", status_code=201, dependencies=[Depends(verify_token)])
async def create_result(results: List[AgentResult], db: Session = Depends(get_db)): async def create_result(results: List[AgentResult], db: Session = Depends(get_db)):
"""Get the results from the agents and store them locally. """Get the results from the agents and store them locally.
@ -82,20 +58,19 @@ async def create_result(results: List[AgentResult], db: Session = Depends(get_db
else: else:
check = task.get_check() check = task.get_check()
status, severity = await check.finalize( status, severity = await check.finalize(
app.config, result, **result.context api.config, result, **result.context
) )
result.severity = severity result.set_status(status, severity)
result.status = status task.set_times_and_deselect()
# Set the selection status to None
task.selected_by = None handle_alert(api.config, result, task, severity)
handle_alert(app.config, result, task, severity)
db_results.append(result) db_results.append(result)
db.commit() db.commit()
return {"result_ids": [r.id for r in db_results]} return {"result_ids": [r.id for r in db_results]}
@app.get("/stats", dependencies=[Depends(verify_token)]) @api.get("/stats", dependencies=[Depends(verify_token)])
async def get_stats(db: Session = Depends(get_db)): async def get_stats(db: Session = Depends(get_db)):
return { return {
"upcoming_tasks_count": await queries.count_tasks(db, selected=False), "upcoming_tasks_count": await queries.count_tasks(db, selected=False),

View file

@ -1,11 +0,0 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import DeclarativeBase, sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:////tmp/argos.db"
# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

78
argos/server/main.py Normal file
View file

@ -0,0 +1,78 @@
import sys
from fastapi import FastAPI
from pydantic import ValidationError
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from argos.logging import logger
from argos.server import models, queries
from argos.server.api import api as api_router
from argos.server.settings import get_app_settings, read_yaml_config
def get_application() -> FastAPI:
    """Build and wire the FastAPI application.

    Loads the settings, reads the YAML configuration onto
    ``app.state.config``, registers the startup and shutdown handlers and
    mounts the API router.
    """
    settings = get_app_settings()
    application = FastAPI()

    # read_config() already stores the config on the application state; the
    # explicit assignment mirrors that with the returned value.
    application.state.config = read_config(application, settings)

    handlers = (
        ("startup", create_start_app_handler(application, settings)),
        ("shutdown", create_stop_app_handler(application)),
    )
    for event_name, handler in handlers:
        application.add_event_handler(event_name, handler)

    application.include_router(api_router)
    return application
def create_start_app_handler(app, settings):
    """Return the coroutine function to run when the application starts.

    The returned handler sets up the database, opens a session and
    synchronises the tasks with ``app.state.config``.
    """

    async def read_config_and_populate_db():
        # The session factory has to exist before a session can be opened.
        setup_database(app, settings)
        session = await connect_to_db(app, settings)
        await queries.update_from_config(session, app.state.config)

    return read_config_and_populate_db
async def connect_to_db(app, settings):
    """Open a session and keep it on the application state.

    The session is created by calling ``app.state.SessionLocal`` and stored
    as ``app.state.db``. ``settings`` is accepted but not used here.

    Returns the session stored on ``app.state.db``.
    """
    app.state.db = app.state.SessionLocal()
    session = app.state.db
    return session
def create_stop_app_handler(app):
    """Return the coroutine function to run when the application stops.

    The returned handler closes the session stored on ``app.state.db``.
    """

    async def stop_app():
        session = app.state.db
        session.close()

    return stop_app
def read_config(app, settings):
    """Load the YAML configuration and attach it to the application state.

    Reads the file named by ``settings.yaml_file``, stores the validated
    configuration on ``app.state.config`` and returns it.

    If validation fails, every error is logged (location and error type) and
    the process exits with status 1.
    """
    try:
        config = read_yaml_config(settings.yaml_file)
    except ValidationError as err:
        logger.error("Errors were found while reading configuration:")
        for error in err.errors():
            logger.error(f"{error['loc']} is {error['type']}")
        sys.exit(1)

    # Only reached when the configuration validated successfully.
    app.state.config = config
    return config
def setup_database(app, settings):
    """Create the engine and session factory and make sure the tables exist.

    Stores the session factory on ``app.state.SessionLocal`` and the engine on
    ``app.state.engine``, then creates every table declared on
    ``models.Base``.
    """
    # ``check_same_thread`` is an option of the sqlite3 driver, so it is only
    # passed for SQLite URLs; ``database_url`` comes from the settings and may
    # point at another backend.
    connect_args = {}
    if settings.database_url.startswith("sqlite"):
        connect_args["check_same_thread"] = False

    engine = create_engine(settings.database_url, connect_args=connect_args)
    app.state.SessionLocal = sessionmaker(
        autocommit=False, autoflush=False, bind=engine
    )
    app.state.engine = engine
    models.Base.metadata.create_all(bind=engine)
# Module-level application instance, built when this module is imported.
app = get_application()

View file

@ -1,18 +1,12 @@
from datetime import datetime from datetime import datetime, timedelta
from typing import List, Literal from typing import List, Literal
from sqlalchemy import ( from sqlalchemy import (
JSON, JSON,
Boolean,
Column,
DateTime,
Enum, Enum,
ForeignKey, ForeignKey,
Integer,
String,
) )
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy_utils import ChoiceType
from argos.checks import get_registered_check from argos.checks import get_registered_check
from argos.schemas import WebsiteCheck from argos.schemas import WebsiteCheck
@ -38,10 +32,13 @@ class Task(Base):
domain: Mapped[str] = mapped_column() domain: Mapped[str] = mapped_column()
check: Mapped[str] = mapped_column() check: Mapped[str] = mapped_column()
expected: Mapped[str] = mapped_column() expected: Mapped[str] = mapped_column()
frequency: Mapped[int] = mapped_column()
# Orchestration-related # Orchestration-related
selected_by: Mapped[str] = mapped_column(nullable=True) selected_by: Mapped[str] = mapped_column(nullable=True)
selected_at: Mapped[datetime] = mapped_column(nullable=True) selected_at: Mapped[datetime] = mapped_column(nullable=True)
completed_at: Mapped[datetime] = mapped_column(nullable=True)
next_run: Mapped[datetime] = mapped_column(nullable=True)
results: Mapped[List["Result"]] = relationship(back_populates="task") results: Mapped[List["Result"]] = relationship(back_populates="task")
@ -52,6 +49,13 @@ class Task(Base):
"""Returns a check instance for this specific task""" """Returns a check instance for this specific task"""
return get_registered_check(self.check) return get_registered_check(self.check)
def set_times_and_deselect(self):
    """Record completion, schedule the next run and release the task.

    ``completed_at`` is set to the current time, ``next_run`` to that time
    plus ``frequency`` hours, and ``selected_by`` is reset to ``None``.
    """
    # NOTE(review): ``frequency`` is applied as hours here, while the general
    # configuration frequency is parsed with the "minutes" unit -- confirm
    # which unit is intended.
    finished_at = datetime.now()
    self.completed_at = finished_at
    self.next_run = finished_at + timedelta(hours=self.frequency)
    self.selected_by = None
class Result(Base): class Result(Base):
__tablename__ = "results" __tablename__ = "results"
@ -68,5 +72,9 @@ class Result(Base):
) )
context: Mapped[dict] = mapped_column() context: Mapped[dict] = mapped_column()
def set_status(self, status, severity):
    """Store the given status and severity on this result."""
    self.severity, self.status = severity, status
def __str__(self): def __str__(self):
return f"DB Result {self.id} - {self.status} - {self.context}" return f"DB Result {self.id} - {self.status} - {self.context}"

View file

@ -1,7 +1,6 @@
from datetime import datetime from datetime import datetime
from urllib.parse import urljoin from urllib.parse import urljoin
from sqlalchemy import exists
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from argos import schemas from argos import schemas
@ -11,9 +10,17 @@ from argos.server.models import Result, Task
async def list_tasks(db: Session, agent_id: str, limit: int = 100): async def list_tasks(db: Session, agent_id: str, limit: int = 100):
"""List tasks and mark them as selected""" """List tasks and mark them as selected"""
tasks = db.query(Task).where(Task.selected_by == None).limit(limit).all() tasks = (
now = datetime.now() db.query(Task)
.filter(
Task.selected_by == None, # noqa: E711
((Task.next_run >= datetime.now()) | (Task.next_run == None)), # noqa: E711
)
.limit(limit)
.all()
)
now = datetime.now()
for task in tasks: for task in tasks:
task.selected_at = now task.selected_at = now
task.selected_by = agent_id task.selected_by = agent_id
@ -53,26 +60,35 @@ async def count_results(db: Session):
async def update_from_config(db: Session, config: schemas.Config): async def update_from_config(db: Session, config: schemas.Config):
for website in config.websites: for website in config.websites:
domain = str(website.domain) domain = str(website.domain)
frequency = website.frequency or config.general.frequency
for p in website.paths: for p in website.paths:
url = urljoin(domain, str(p.path)) url = urljoin(domain, str(p.path))
for check_key, expected in p.checks: for check_key, expected in p.checks:
# Check the db for already existing tasks. # Check the db for already existing tasks.
existing_task = db.query( existing_task = (
exists().where( db.query(Task)
Task.url == url .filter(
and Task.check == check_key Task.url == url,
and Task.expected == expected Task.check == check_key,
Task.expected == expected,
) )
).scalar() .scalar()
)
if existing_task and frequency != existing_task.frequency:
existing_task.frequency = frequency
if not existing_task: if not existing_task:
task = Task( task = Task(
domain=domain, url=url, check=check_key, expected=expected domain=domain,
url=url,
check=check_key,
expected=expected,
frequency=frequency,
) )
logger.debug(f"Adding a new task in the db: {task}") logger.debug(f"Adding a new task in the db: {task}")
db.add(task) db.add(task)
else: else:
logger.debug( msg = f"Skipping db task creation for {url=}, {check_key=}, {expected=}, {frequency=}."
f"Skipping db task creation for {url=}, {check_key=}, {expected=}." logger.debug(msg)
)
db.commit() db.commit()

58
argos/server/settings.py Normal file
View file

@ -0,0 +1,58 @@
import os
from functools import lru_cache
from os import environ
import yaml
from pydantic_settings import BaseSettings
from yamlinclude import YamlIncludeConstructor
from argos.schemas.config import Config
class DefaultSettings(BaseSettings):
    """Base settings shared by every environment.

    ``database_url`` and ``yaml_file`` default to empty strings here; the
    environment-specific subclasses may override them.
    """

    # Name of the environment.
    app_env: str = "prod"
    # Database URL handed to the database setup.
    database_url: str = ""
    # Path of the YAML configuration file.
    yaml_file: str = ""
class DevSettings(DefaultSettings):
    """Settings for local development: SQLite file in /tmp, local config."""

    database_url: str = "sqlite:////tmp/argos.db"
    yaml_file: str = "config.yaml"
class TestSettings(DefaultSettings):
    """Settings for the test suite: separate SQLite file and test config."""

    database_url: str = "sqlite:////tmp/test-argos.db"
    yaml_file: str = "tests/config.yaml"
class ProdSettings(DefaultSettings):
    """Production settings; currently identical to ``DefaultSettings``."""
# Maps the value of the APP_ENV environment variable to its settings class.
environments = {
"dev": DevSettings,
"prod": ProdSettings,
"test": TestSettings,
}
@lru_cache()
def get_app_settings() -> DefaultSettings:
    """Return the settings instance for the current environment.

    The environment name is read from the ``APP_ENV`` environment variable
    and defaults to ``"dev"``. The result is cached, so the variable is only
    read on the first call.

    Raises:
        ValueError: if ``APP_ENV`` does not name a known environment.
    """
    app_env = environ.get("APP_ENV", "dev")
    try:
        settings_class = environments[app_env]
    except KeyError:
        # Fail with an explicit message instead of calling ``None``.
        known = ", ".join(sorted(environments))
        raise ValueError(
            f"Unknown APP_ENV {app_env!r}; expected one of: {known}"
        ) from None
    return settings_class()
def read_yaml_config(filename):
    """Load the YAML file *filename* and validate it as a ``Config``."""
    raw_config = _load_yaml(filename)
    return Config(**raw_config)
def _load_yaml(filename):
    """Parse the YAML file *filename* and return the loaded document.

    ``!include`` tags are resolved relative to the directory of *filename*.
    """
    base_dir = os.path.dirname(filename)
    YamlIncludeConstructor.add_to_loader_class(
        loader_class=yaml.FullLoader, base_dir=base_dir
    )
    # Decode explicitly as UTF-8 rather than with the locale's default
    # encoding, so the file is read the same way on every platform.
    with open(filename, "r", encoding="utf-8") as stream:
        # NOTE(review): FullLoader is more permissive than SafeLoader; only
        # use this on configuration files you trust.
        return yaml.load(stream, Loader=yaml.FullLoader)

View file

@ -1,5 +1,5 @@
general: general:
frequency: 4h # Run checks every 4 hours. frequency: "1h" # Run checks every hour.
alerts: alerts:
error: error:
- local - local

18
tests/config.yaml Normal file
View file

@ -0,0 +1,18 @@
general:
  frequency: "1m"
  alerts:
    error:
      - local
    warning:
      - local
    alert:
      - local
service:
  secrets:
    - "O4kt8Max9/k0EmHaEJ0CGGYbBNFmK8kOZNIoUk3Kjwc"
    - "x1T1VZR51pxrv5pQUyzooMG4pMUvHNMhA5y/3cUsYVs="
ssl:
  # One threshold per list item: the threshold parser reads a single
  # "duration: severity" pair from each item.
  thresholds:
    - "1d": critical
    - "5d": warning
websites: !include websites.yaml

25
tests/conftest.py Normal file
View file

@ -0,0 +1,25 @@
from os import environ
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
# Select the "test" settings for the whole test session.
environ["APP_ENV"] = "test"
@pytest.fixture
def app() -> FastAPI:
    """Provide a newly built application instance."""
    # Local import for testing purpose.
    from argos.server import main as server_main

    return server_main.get_application()
@pytest.fixture
def authorized_client(
    client: AsyncClient, token: str, authorization_prefix: str
) -> AsyncClient:
    """Return *client* with a bearer ``Authorization`` header added.

    Headers already present on the client take precedence over the added
    one. ``authorization_prefix`` is requested but not used.
    """
    # NOTE(review): the ``client``, ``token`` and ``authorization_prefix``
    # fixtures are not defined in the visible part of this conftest --
    # confirm they are provided elsewhere.
    headers = {"Authorization": f"Bearer {token}"}
    headers.update(client.headers)
    client.headers = headers
    return client

28
tests/test_api.py Normal file
View file

@ -0,0 +1,28 @@
import pytest
from fastapi.testclient import TestClient
from argos.server import app, models
@pytest.fixture()
def test_db():
    """Create every table before a test and drop them all afterwards."""
    # The engine is kept on ``app.state``, not on the application object.
    # NOTE(review): it appears to be set during application startup --
    # confirm the app has started before this fixture is used.
    engine = app.state.engine
    models.Base.metadata.create_all(bind=engine)
    yield
    models.Base.metadata.drop_all(bind=engine)
models.Base.metadata.drop_all(bind=app.engine)
def test_read_tasks_requires_auth():
    """A request to /tasks without credentials is rejected with a 403."""
    with TestClient(app) as anonymous_client:
        status_code = anonymous_client.get("/tasks").status_code
        assert status_code == 403
def test_read_tasks_returns_tasks():
    """An authenticated request to /tasks returns the configured tasks."""
    # Use the first secret from the application configuration as the token.
    secret = app.state.config.service.secrets[0]
    with TestClient(app) as client:
        client.headers = {"Authorization": f"Bearer {secret}"}
        response = client.get("/tasks")
        assert response.status_code == 200

        tasks = response.json()
        # We should have only two tasks
        assert len(tasks) == 2

6
tests/websites.yaml Normal file
View file

@ -0,0 +1,6 @@
# Websites used by the test configuration: one domain, one path, two checks.
- domain: "https://mypads.framapad.org"
  paths:
    - path: "/mypads/"
      checks:
        - status-is: 200
        - body-contains: '<div id= "mypads"></div>'