diff --git a/Makefile b/Makefile
index f42e098..9f61d37 100644
--- a/Makefile
+++ b/Makefile
@@ -24,6 +24,8 @@ djlint: venv ## Format the templates
 	venv/bin/djlint --ignore=H030,H031,H006 --profile jinja --lint argos/server/templates/*html
 pylint: venv ## Runs pylint on the code
 	venv/bin/pylint argos
+pylint-alembic: venv ## Runs pylint on alembic migration files
+	venv/bin/pylint --disable invalid-name,no-member alembic/versions/*.py
 lint: djlint pylint
 help:
 	@python3 -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)
diff --git a/alembic/env.py b/alembic/env.py
index 347ff3f..4f34ce8 100644
--- a/alembic/env.py
+++ b/alembic/env.py
@@ -28,6 +28,7 @@ def run_migrations_offline() -> None:
     context.configure(
         url=url,
         target_metadata=target_metadata,
+        render_as_batch=True,
         literal_binds=True,
         dialect_opts={"paramstyle": "named"},
     )
@@ -50,7 +51,10 @@ def run_migrations_online() -> None:
     )
 
     with connectable.connect() as connection:
-        context.configure(connection=connection, target_metadata=target_metadata)
+        context.configure(connection=connection,
+                          target_metadata=target_metadata,
+                          render_as_batch=True,
+        )
 
         with context.begin_transaction():
             context.run_migrations()
diff --git a/alembic/versions/1a3497f9f71b_adding_configcache_model.py b/alembic/versions/1a3497f9f71b_adding_configcache_model.py
new file mode 100644
index 0000000..66befb1
--- /dev/null
+++ b/alembic/versions/1a3497f9f71b_adding_configcache_model.py
@@ -0,0 +1,31 @@
+"""Adding ConfigCache model
+
+Revision ID: 1a3497f9f71b
+Revises: e99bc35702c9
+Create Date: 2024-03-13 15:28:09.185377
+
+"""
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision: str = '1a3497f9f71b'
+down_revision: Union[str, None] = 'e99bc35702c9'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.create_table('config_cache',
+        sa.Column('name', sa.String(), nullable=False),
+        sa.Column('val', sa.String(), nullable=False),
+        sa.Column('updated_at', sa.DateTime(), nullable=False),
+        sa.PrimaryKeyConstraint('name')
+    )
+
+
+def downgrade() -> None:
+    op.drop_table('config_cache')
diff --git a/alembic/versions/7d480e6f1112_initial_migrations.py b/alembic/versions/7d480e6f1112_initial_migrations.py
index 87bb415..3382fe1 100644
--- a/alembic/versions/7d480e6f1112_initial_migrations.py
+++ b/alembic/versions/7d480e6f1112_initial_migrations.py
@@ -53,6 +53,7 @@ def upgrade() -> None:
         sa.ForeignKeyConstraint(
             ["task_id"],
             ["tasks.id"],
+            name="results_task_id_fkey",
         ),
         sa.PrimaryKeyConstraint("id"),
     )
diff --git a/alembic/versions/defda3f2952d_add_on_delete_cascade_to_results_task_id.py b/alembic/versions/defda3f2952d_add_on_delete_cascade_to_results_task_id.py
new file mode 100644
index 0000000..777cfa2
--- /dev/null
+++ b/alembic/versions/defda3f2952d_add_on_delete_cascade_to_results_task_id.py
@@ -0,0 +1,33 @@
+"""Add ON DELETE CASCADE to results’ task_id
+
+Revision ID: defda3f2952d
+Revises: 1a3497f9f71b
+Create Date: 2024-03-18 15:09:34.544573
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision: str = 'defda3f2952d'
+down_revision: Union[str, None] = '1a3497f9f71b'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table('results', schema=None) as batch_op:
+        batch_op.drop_constraint('results_task_id_fkey', type_='foreignkey')
+        batch_op.create_foreign_key('results_task_id_fkey',
+                                    'tasks',
+                                    ['task_id'],
+                                    ['id'],
+                                    ondelete='CASCADE')
+
+
+def downgrade() -> None:
+    with op.batch_alter_table('results', schema=None) as batch_op:
+        batch_op.drop_constraint('results_task_id_fkey', type_='foreignkey')
+        batch_op.create_foreign_key('results_task_id_fkey', 'tasks', ['task_id'], ['id'])
diff --git a/argos/commands.py b/argos/commands.py
index 9c3ccd6..b90057a 100644
--- a/argos/commands.py
+++ b/argos/commands.py
@@ -4,6 +4,8 @@ from functools import wraps
 import click
 import uvicorn
 
+from alembic import command
+from alembic.config import Config
 from argos import logging
 from argos.agent import ArgosAgent
 
@@ -131,5 +133,41 @@ async def cleandb(max_results, max_lock_seconds):
     click.echo(f"{updated} locks released")
 
 
+@server.command()
+@coroutine
+async def reload_config():
+    """Read tasks config and add/delete tasks in database if needed
+    """
+    # The imports are made here otherwise the agent will need server configuration files.
+    from argos.server import queries
+    from argos.server.main import get_application, read_config
+    from argos.server.settings import get_app_settings
+
+    appli = get_application()
+    settings = get_app_settings()
+    config = read_config(appli, settings)
+
+    db = await get_db()
+    changed = await queries.update_from_config(db, config)
+
+    click.echo(f"{changed['added']} tasks added")
+    click.echo(f"{changed['vanished']} tasks deleted")
+
+
+@server.command()
+@coroutine
+async def migrate():
+    """Run database migrations
+    """
+    # The imports are made here otherwise the agent will need server configuration files.
+    from argos.server.settings import get_app_settings
+
+    settings = get_app_settings()
+
+    alembic_cfg = Config("alembic.ini")
+    alembic_cfg.set_main_option("sqlalchemy.url", settings.database_url)
+    command.upgrade(alembic_cfg, "head")
+
+
 if __name__ == "__main__":
     cli()
diff --git a/argos/server/main.py b/argos/server/main.py
index ce3f534..45c2c5f 100644
--- a/argos/server/main.py
+++ b/argos/server/main.py
@@ -3,11 +3,11 @@ import sys
 
 from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from pydantic import ValidationError
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, event
 from sqlalchemy.orm import sessionmaker
 
 from argos.logging import logger
-from argos.server import models, queries, routes
+from argos.server import models, routes
 from argos.server.settings import get_app_settings, read_yaml_config
 
@@ -39,15 +39,14 @@ def get_application() -> FastAPI:
 def create_start_app_handler(appli):
     """Warmup the server:
-    setup database connection and update the tasks in it before making it available
+    setup database connection
     """
 
-    async def read_config_and_populate_db():
+    async def _get_db():
         setup_database(appli)
 
-        db = await connect_to_db(appli)
-        await queries.update_from_config(db, appli.state.config)
+        return await connect_to_db(appli)
 
-    return read_config_and_populate_db
+    return _get_db
 
 
 async def connect_to_db(appli):
@@ -95,6 +94,13 @@ def setup_database(appli):
     engine = create_engine(
         settings.database_url, **extra_settings
     )
+
+    def _fk_pragma_on_connect(dbapi_con, con_record):
+        dbapi_con.execute('pragma foreign_keys=ON')
+
+    if settings.database_url.startswith("sqlite:"):
+        event.listen(engine, 'connect', _fk_pragma_on_connect)
+
     appli.state.SessionLocal = sessionmaker(
         autocommit=False, autoflush=False, bind=engine
     )
diff --git a/argos/server/models.py b/argos/server/models.py
index 451eb6b..866ccd6 100644
--- a/argos/server/models.py
+++ b/argos/server/models.py
@@ -47,7 +47,9 @@ class Task(Base):
     )
     last_severity_update: Mapped[datetime] = mapped_column(nullable=True)
 
-    results: Mapped[List["Result"]] = relationship(back_populates="task")
+    results: Mapped[List["Result"]] = relationship(back_populates="task",
+                                                   cascade="all, delete",
+                                                   passive_deletes=True,)
 
     def __str__(self):
         return f"DB Task {self.url} - {self.check} - {self.expected}"
@@ -92,7 +94,7 @@ class Result(Base):
     """
 
     __tablename__ = "results"
     id: Mapped[int] = mapped_column(primary_key=True)
-    task_id: Mapped[int] = mapped_column(ForeignKey("tasks.id"))
+    task_id: Mapped[int] = mapped_column(ForeignKey("tasks.id", ondelete="CASCADE"))
     task: Mapped["Task"] = relationship(back_populates="results")
     agent_id: Mapped[str] = mapped_column(nullable=True)
@@ -112,3 +114,19 @@ class Result(Base):
 
     def __str__(self):
         return f"DB Result {self.id} - {self.status} - {self.context}"
+
+
+class ConfigCache(Base):
+    """Contains some informations on the previous config state
+
+    Used to quickly determine if we need to update the tasks.
+    There is currently two cached settings:
+    - general_frequency: the content of general.frequency setting, in minutes
+      ex: 5
+    - websites_hash: the sha256sum of websites setting, to allow a quick
+      comparison without looping through all websites
+      ex: 8b886e7db7b553fe99f6d5437f31745987e243c77b2109b84cf9a7f8bf7d75b1
+    """
+    __tablename__ = "config_cache"
+    name: Mapped[str] = mapped_column(primary_key=True)
+    val: Mapped[str] = mapped_column()
+    updated_at: Mapped[datetime] = mapped_column()
diff --git a/argos/server/queries.py b/argos/server/queries.py
index 271b2c9..742c516 100644
--- a/argos/server/queries.py
+++ b/argos/server/queries.py
@@ -1,5 +1,7 @@
 """Functions to ease SQL queries management"""
 from datetime import datetime, timedelta
+from hashlib import sha256
+from typing import List
 from urllib.parse import urljoin
 
 from sqlalchemy import desc, func
@@ -7,7 +9,7 @@ from sqlalchemy.orm import Session
 
 from argos import schemas
 from argos.logging import logger
-from argos.server.models import Result, Task
+from argos.server.models import Result, Task, ConfigCache
 
 
 async def list_tasks(db: Session, agent_id: str, limit: int = 100):
@@ -60,10 +62,64 @@ async def count_results(db: Session):
     return db.query(Result).count()
 
 
+async def has_config_changed(db: Session, config: schemas.Config) -> bool:
+    """Check if websites config has changed by using a hashsum and a config cache"""
+    websites_hash = sha256(str(config.websites).encode()).hexdigest()
+    conf_caches = (
+        db.query(ConfigCache)
+        .all()
+    )
+    same_config = True
+    if conf_caches:
+        for conf in conf_caches:
+            match (conf.name):
+                case 'websites_hash':
+                    if conf.val != websites_hash:
+                        same_config = False
+                        conf.val = websites_hash
+                        conf.updated_at = datetime.now()
+                case 'general_frequency':
+                    if conf.val != str(config.general.frequency):
+                        same_config = False
+                        conf.val = str(config.general.frequency)
+                        conf.updated_at = datetime.now()
+
+        db.commit()
+
+        if same_config:
+            return False
+
+    else:  # no config cache found
+        web_hash = ConfigCache(
+            name='websites_hash',
+            val=websites_hash,
+            updated_at=datetime.now()
+        )
+        gen_freq = ConfigCache(
+            name='general_frequency',
+            val=str(config.general.frequency),
+            updated_at=datetime.now()
+        )
+        db.add(web_hash)
+        db.add(gen_freq)
+        db.commit()
+
+    return True
+
+
 async def update_from_config(db: Session, config: schemas.Config):
     """Update tasks from config file"""
+    config_changed = await has_config_changed(db, config)
+    if not config_changed:
+        return {'added': 0, 'vanished': 0}
+
+    max_task_id = (
+        db.query(func.max(Task.id).label('max_id'))  # pylint: disable-msg=not-callable
+        .all()
+    )[0].max_id
     tasks = []
     unique_properties = []
+    seen_tasks: List[int] = []
     for website in config.websites:
         domain = str(website.domain)
         frequency = website.frequency or config.general.frequency
@@ -83,6 +139,7 @@ async def update_from_config(db: Session, config: schemas.Config):
         )
         if existing_tasks:
             existing_task = existing_tasks[0]
+            seen_tasks.append(existing_task.id)
 
             if frequency != existing_task.frequency:
                 existing_task.frequency = frequency
@@ -107,6 +164,21 @@ async def update_from_config(db: Session, config: schemas.Config):
     db.add_all(tasks)
     db.commit()
 
+    # Delete vanished tasks
+    if max_task_id:
+        vanished_tasks = (
+            db.query(Task)
+            .filter(
+                Task.id <= max_task_id,
+                Task.id.not_in(seen_tasks)
+            ).delete()
+        )
+        db.commit()
+        logger.info("%i tasks has been removed since not in config file anymore", vanished_tasks)
+        return {'added': len(tasks), 'vanished': vanished_tasks}
+
+    return {'added': len(tasks), 'vanished': 0}
+
 
 async def get_severity_counts(db: Session) -> dict:
     """Get the severities (ok, warning, critical…) and their count"""
diff --git a/conf/systemd-server.service b/conf/systemd-server.service
index 698e637..f5116a9 100644
--- a/conf/systemd-server.service
+++ b/conf/systemd-server.service
@@ -8,6 +8,8 @@ PartOf=postgresql.service
 [Service]
 User=www-data
 WorkingDirectory=/var/www/argos/
+ExecStartPre=/var/www/argos/venv/bin/argos server migrate
+ExecStartPre=/var/www/argos/venv/bin/argos server reload-config
 ExecStart=/var/www/argos/venv/bin/argos server start
 ExecReload=/var/www/argos/venv/bin/argos server reload
 SyslogIdentifier=argos-server
diff --git a/tests/conftest.py b/tests/conftest.py
index 84e1a44..d07ea1a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,4 @@
+import asyncio
 import os
 
 import pytest
@@ -40,6 +41,7 @@ def _create_app() -> FastAPI:
     from argos.server.main import (  # local import for testing purpose
         get_application,
         setup_database,
+        connect_to_db,
     )
 
     app = get_application()
@@ -49,4 +51,5 @@ def _create_app() -> FastAPI:
     app.state.settings.yaml_file = "tests/config.yaml"
 
     setup_database(app)
+    asyncio.run(connect_to_db(app))
     return app
diff --git a/tests/test_api.py b/tests/test_api.py
index 53ca7a7..21c87d1 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,8 +1,11 @@
+import asyncio
+
 import pytest
 from fastapi.testclient import TestClient
 
 from argos.schemas import AgentResult, SerializableException
 from argos.server import models
+from argos.server.queries import update_from_config
 
 
 def test_read_tasks_requires_auth(app):
@@ -12,6 +15,7 @@ def test_read_tasks_requires_auth(app):
 
 
 def test_tasks_retrieval_and_results(authorized_client, app):
+    asyncio.run(update_from_config(app.state.db, app.state.config))
     with authorized_client as client:
         response = client.get("/api/tasks")
         assert response.status_code == 200
diff --git a/tests/test_queries.py b/tests/test_queries.py
index 8770d98..d92c750 100644
--- a/tests/test_queries.py
+++ b/tests/test_queries.py
@@ -78,7 +78,7 @@ async def test_update_from_config_with_duplicate_tasks(db, empty_config):
 
 
 @pytest.mark.asyncio
-async def test_update_from_config_db_can_handle_already_present_duplicates(
+async def test_update_from_config_db_can_remove_duplicates_and_old_tasks(
     db, empty_config, task
 ):
     # Add a duplicate in the db
@@ -99,12 +99,28 @@ async def test_update_from_config_db_can_remove_duplicates_and_old_tasks(
             dict(
                 path="https://another-example.com", checks=[{task.check: task.expected}]
            ),
+            dict(
+                path=task.url, checks=[{task.check: task.expected}]
+            ),
         ],
     )
     empty_config.websites = [website]
 
     await queries.update_from_config(db, empty_config)
-    assert db.query(Task).count() == 3
+    assert db.query(Task).count() == 2
+
+    website = schemas.config.Website(
+        domain=task.domain,
+        paths=[
+            dict(
+                path="https://another-example.com", checks=[{task.check: task.expected}]
+            ),
+        ],
+    )
+    empty_config.websites = [website]
+
+    await queries.update_from_config(db, empty_config)
+    assert db.query(Task).count() == 1
 
 
 @pytest.mark.asyncio