diff --git a/.gitignore b/.gitignore
index 309295c..8b227ad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,5 @@ public
 *.swp
 argos-config.yaml
 config.yaml
+websites.yaml
 dist
diff --git a/CHANGELOG.md b/CHANGELOG.md
index aea0778..f3b3152 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@
 - 📝 — Improve OpenAPI doc
 - 🤕 — Fix order of tasks sent to agent
 - ✨ — Add application API (fix #86)
+- ⚡️ — Faster websites configuration reloading (#85)
 
 ## 0.9.0
 
diff --git a/argos/commands.py b/argos/commands.py
index 7bd5690..7a33150 100644
--- a/argos/commands.py
+++ b/argos/commands.py
@@ -182,8 +182,9 @@ async def reload_config(config, enqueue):
     else:
         changed = await queries.update_from_config(db, _config)
 
+        click.echo(f"{changed['deleted']} task(s) deleted")
         click.echo(f"{changed['added']} task(s) added")
-        click.echo(f"{changed['vanished']} task(s) deleted")
+        click.echo(f"{changed['updated']} task(s) updated")
 
 
 @server.command()
diff --git a/argos/server/migrations/versions/655eefd69858_add_table_for_configuration_comparison.py b/argos/server/migrations/versions/655eefd69858_add_table_for_configuration_comparison.py
new file mode 100644
index 0000000..75ad5da
--- /dev/null
+++ b/argos/server/migrations/versions/655eefd69858_add_table_for_configuration_comparison.py
@@ -0,0 +1,54 @@
+"""Add table for configuration comparison
+
+Revision ID: 655eefd69858
+Revises: 1d0aaa07743c
+Create Date: 2025-03-20 14:13:33.006662
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "655eefd69858"
+down_revision: Union[str, None] = "1d0aaa07743c"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    ip_version_enum = sa.Enum("4", "6", name="ip_version_enum", create_type=False)
+    method_enum = sa.Enum(
+        "GET",
+        "HEAD",
+        "POST",
+        "OPTIONS",
+        "CONNECT",
+        "TRACE",
+        "PUT",
+        "PATCH",
+        "DELETE",
+        name="method",
+        create_type=False,
+    )
+    op.create_table(
+        "tasks_tmp",
+        sa.Column("id", sa.Integer(), nullable=False),
+        sa.Column("url", sa.String(), nullable=False),
+        sa.Column("domain", sa.String(), nullable=False),
+        sa.Column("check", sa.String(), nullable=False),
+        sa.Column("expected", sa.String(), nullable=False),
+        sa.Column("frequency", sa.Float(), nullable=False),
+        sa.Column("recheck_delay", sa.Float(), nullable=True),
+        sa.Column("retry_before_notification", sa.Integer(), nullable=False),
+        sa.Column("request_data", sa.String(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.add_column("tasks_tmp", sa.Column("ip_version", ip_version_enum, nullable=False))
+    op.add_column("tasks_tmp", sa.Column("method", method_enum, nullable=False))
+
+
+def downgrade() -> None:
+    op.drop_table("tasks_tmp")
diff --git a/argos/server/models.py b/argos/server/models.py
index 7d6997d..380289f 100644
--- a/argos/server/models.py
+++ b/argos/server/models.py
@@ -151,6 +151,44 @@ class Task(Base):
 
 Index("similar_tasks", Task.task_group)
 
+
+class TaskTmp(Base):
+    """Table with temporary data, only used for websites
+    configuration refreshing"""
+
+    __tablename__ = "tasks_tmp"
+    id: Mapped[int] = mapped_column(primary_key=True)
+
+    url: Mapped[str] = mapped_column()
+    domain: Mapped[str] = mapped_column()
+    ip_version: Mapped[IPVersion] = mapped_column(
+        Enum("4", "6", name="ip_version_enum"),
+    )
+    check: Mapped[str] = mapped_column()
+    expected: Mapped[str] = mapped_column()
+    frequency: Mapped[float] = mapped_column()
+    recheck_delay: Mapped[float] = mapped_column(nullable=True)
+    retry_before_notification: Mapped[int] = mapped_column(insert_default=0)
+    method: Mapped[Method] = mapped_column(
+        Enum(
+            "GET",
+            "HEAD",
+            "POST",
+            "OPTIONS",
+            "CONNECT",
+            "TRACE",
+            "PUT",
+            "PATCH",
+            "DELETE",
+            name="method",
+        ),
+        insert_default="GET",
+    )
+    request_data: Mapped[str] = mapped_column(nullable=True)
+
+    def __str__(self) -> str:
+        return f"DB TaskTmp {self.url} (IPv{self.ip_version}) - {self.check} - {self.expected}"
+
 
 class Result(Base):
     """There are multiple results per task.
diff --git a/argos/server/queries.py b/argos/server/queries.py
index faab67e..cdf80fe 100644
--- a/argos/server/queries.py
+++ b/argos/server/queries.py
@@ -1,18 +1,26 @@
 """Functions to ease SQL queries management"""
 from datetime import datetime, timedelta
 from hashlib import sha256
-from typing import List
 from urllib.parse import urljoin
 
 import jwt
 from fastapi import Request
-from sqlalchemy import asc, func, Select
-from sqlalchemy.orm import Session
+from sqlalchemy import and_, asc, func, or_, Select
+from sqlalchemy.orm import aliased, Session
+from sqlalchemy.sql import text as sa_text
 
 from argos import schemas
 from argos.logging import logger
-from argos.server.models import BlockedToken, ConfigCache, Job, Result, Task, User
+from argos.server.models import (
+    BlockedToken,
+    ConfigCache,
+    Job,
+    Result,
+    Task,
+    TaskTmp,
+    User,
+)
 
 from argos.server.settings import read_config
 
@@ -287,8 +295,9 @@ async def process_jobs(db: Session) -> int:
                 logger.info("Processing job %i: %s %s", job.id, job.todo, job.args)
                 _config = read_config(job.args)
                 changed = await update_from_config(db, _config)
+                logger.info("%i task(s) deleted", changed["deleted"])
                 logger.info("%i task(s) added", changed["added"])
-                logger.info("%i task(s) deleted", changed["vanished"])
+                logger.info("%i task(s) updated", changed["updated"])
                 db.delete(job)
                 db.commit()
 
@@ -297,15 +306,204 @@
     return 0
 
 
-async def update_from_config(db: Session, config: schemas.Config):  # pylint: disable-msg=too-many-branches
-    """Update tasks from config file"""
-    max_task_id = (
-        db.query(func.max(Task.id).label("max_id")).all()  # pylint: disable-msg=not-callable
-    )[0].max_id
+async def delete_duplicate_tasks(db: Session) -> int:
+    """Find duplicate tasks in DB and delete one of them"""
+
+    f_task = aliased(Task)
+    s_task = aliased(Task)
+    duplicate_tasks = (
+        db.query(f_task, s_task)
+        .join(
+            s_task,
+            and_(
+                f_task.url == s_task.url,
+                f_task.method == s_task.method,
+                or_(
+                    and_(
+                        f_task.request_data == None,
+                        s_task.request_data == None,
+                    ),
+                    f_task.request_data == s_task.request_data,
+                ),
+                f_task.check == s_task.check,
+                f_task.expected == s_task.expected,
+                f_task.ip_version == s_task.ip_version,
+            ),
+        )
+        .filter(f_task.id != s_task.id)
+        .all()
+    )
+    deleted_duplicate_tasks = len(duplicate_tasks)
+    primary_duplicate_ids: list[int] = []
+    for i in duplicate_tasks:
+        primary = i[0]
+        secondary = i[1]
+        if primary.id in primary_duplicate_ids:
+            db.delete(secondary)
+        elif secondary.id not in primary_duplicate_ids:
+            primary_duplicate_ids.append(primary.id)
+            db.delete(secondary)
+        else:
+            db.delete(primary)
+
+    if deleted_duplicate_tasks:
+        db.commit()
+
+    return deleted_duplicate_tasks
+
+
+async def delete_vanished_tasks(db: Session) -> int:
+    """Delete tasks not in temporary config tasks table"""
+
+    tasks_to_delete = (
+        db.query(Task)
+        .outerjoin(
+            TaskTmp,
+            and_(
+                Task.url == TaskTmp.url,
+                Task.method == TaskTmp.method,
+                or_(
+                    and_(
+                        Task.request_data == None,
+                        TaskTmp.request_data == None,
+                    ),
+                    Task.request_data == TaskTmp.request_data,
+                ),
+                Task.check == TaskTmp.check,
+                Task.expected == TaskTmp.expected,
+                Task.ip_version == TaskTmp.ip_version,
+            ),
+        )
+        .filter(
+            TaskTmp.url == None,
+        )
+        .all()
+    )
+
+    vanished_tasks = len(tasks_to_delete)
+    for task in tasks_to_delete:
+        logger.debug("Deleting a task from the db: %s", task)
+        db.delete(task)
+
+    if vanished_tasks:
+        db.commit()
+
+    return vanished_tasks
+
+
+async def add_tasks_from_config_table(db: Session) -> int:
+    """Add tasks from temporary config tasks table"""
+
+    tasks_to_add = (
+        db.query(TaskTmp)
+        .outerjoin(
+            Task,
+            and_(
+                TaskTmp.url == Task.url,
+                TaskTmp.method == Task.method,
+                or_(
+                    and_(
+                        TaskTmp.request_data == None,
+                        Task.request_data == None,
+                    ),
+                    TaskTmp.request_data == Task.request_data,
+                ),
+                TaskTmp.check == Task.check,
+                TaskTmp.expected == Task.expected,
+                TaskTmp.ip_version == Task.ip_version,
+            ),
+        )
+        .filter(
+            Task.url == None,
+        )
+        .all()
+    )
+    tasks = []
+    for task_tmp in tasks_to_add:
+        task = Task(
+            domain=task_tmp.domain,
+            url=task_tmp.url,
+            ip_version=task_tmp.ip_version,
+            method=task_tmp.method,
+            request_data=task_tmp.request_data,
+            check=task_tmp.check,
+            expected=task_tmp.expected,
+            frequency=task_tmp.frequency,
+            recheck_delay=task_tmp.recheck_delay,
+            retry_before_notification=task_tmp.retry_before_notification,
+            already_retried=False,
+        )
+
+        logger.debug("Adding a new task in the db: %s", task)
+
+        tasks.append(task)
+
+    if tasks:
+        db.add_all(tasks)
+        db.commit()
+
+    return len(tasks)
+
+
+async def update_tasks(db: Session) -> int:
+    """Update tasks from temporary config tasks table"""
+
+    tasks_to_update = (
+        db.query(Task, TaskTmp)
+        .join(
+            TaskTmp,
+            and_(
+                Task.url == TaskTmp.url,
+                Task.method == TaskTmp.method,
+                or_(
+                    and_(
+                        Task.request_data == None,
+                        TaskTmp.request_data == None,
+                    ),
+                    Task.request_data == TaskTmp.request_data,
+                ),
+                Task.check == TaskTmp.check,
+                Task.expected == TaskTmp.expected,
+                Task.ip_version == TaskTmp.ip_version,
+            ),
+        )
+        .filter(
+            or_(
+                Task.frequency != TaskTmp.frequency,
+                Task.recheck_delay != TaskTmp.recheck_delay,
+                Task.retry_before_notification != TaskTmp.retry_before_notification,
+            )
+        )
+        .all()
+    )
+    updated_tasks = len(tasks_to_update)
+    for tasks in tasks_to_update:
+        task = tasks[0]
+        task_tmp = tasks[1]
+
+        logger.debug("Updating task: %s", task)
+
+        task.frequency = task_tmp.frequency
+        task.recheck_delay = task_tmp.recheck_delay
+        task.retry_before_notification = task_tmp.retry_before_notification
+
+    if updated_tasks:
+        db.commit()
+
+    return updated_tasks
+
+
+async def update_from_config(db: Session, config: schemas.Config):
+    """Update tasks from config file"""
+
+    deleted_duplicate_tasks = await delete_duplicate_tasks(db)
+
 
     unique_properties = []
-    seen_tasks: List[int] = []
-    for website in config.websites:  # pylint: disable-msg=too-many-nested-blocks
+    tmp_tasks: list[TaskTmp] = []
+
+    # Fill the tasks_tmp table
+    for website in config.websites:
         domain = str(website.domain)
         frequency = website.frequency or config.general.frequency
         recheck_delay = website.recheck_delay or config.general.recheck_delay
@@ -324,101 +522,61 @@ async def update_from_config(db: Session, config: schemas.Config): # pylint: di
         for p in website.paths:
             url = urljoin(domain, str(p.path))
             for check_key, expected in p.checks:
-                # Check the db for already existing tasks.
-                existing_tasks = (
-                    db.query(Task)
-                    .filter(
-                        Task.url == url,
-                        Task.method == p.method,
-                        Task.request_data == p.request_data,
-                        Task.check == check_key,
-                        Task.expected == expected,
-                        Task.ip_version == ip_version,
-                    )
-                    .all()
-                )
-
                 if (ip_version == "4" and ipv4 is False) or (
                     ip_version == "6" and ipv6 is False
                 ):
                     continue
 
-                if existing_tasks:
-                    existing_task = existing_tasks[0]
-
-                    seen_tasks.append(existing_task.id)
-
-                    if frequency != existing_task.frequency:
-                        existing_task.frequency = frequency
-                    if recheck_delay != existing_task.recheck_delay:
-                        existing_task.recheck_delay = recheck_delay  # type: ignore[assignment]
-                    if (
-                        retry_before_notification
-                        != existing_task.retry_before_notification
-                    ):
-                        existing_task.retry_before_notification = (
-                            retry_before_notification
-                        )
-                    logger.debug(
-                        "Skipping db task creation for url=%s, "
-                        "method=%s, check_key=%s, expected=%s, "
-                        "frequency=%s, recheck_delay=%s, "
-                        "retry_before_notification=%s, ip_version=%s.",
-                        url,
-                        p.method,
-                        check_key,
-                        expected,
-                        frequency,
-                        recheck_delay,
-                        retry_before_notification,
-                        ip_version,
+                properties = (
+                    url,
+                    p.method,
+                    p.request_data,
+                    check_key,
+                    expected,
+                    ip_version,
+                )
+                if properties not in unique_properties:
+                    unique_properties.append(properties)
+                    config_task = TaskTmp(
+                        domain=domain,
+                        url=url,
+                        method=p.method,
+                        request_data=p.request_data,
+                        check=check_key,
+                        expected=expected,
+                        ip_version=ip_version,
+                        frequency=frequency,
+                        recheck_delay=recheck_delay,
+                        retry_before_notification=retry_before_notification,
                     )
+                    tmp_tasks.append(config_task)
-                else:
-                    properties = (
-                        url,
-                        p.method,
-                        check_key,
-                        expected,
-                        ip_version,
-                        p.request_data,
-                    )
-                    if properties not in unique_properties:
-                        unique_properties.append(properties)
-                        task = Task(
-                            domain=domain,
-                            url=url,
-                            ip_version=ip_version,
-                            method=p.method,
-                            request_data=p.request_data,
-                            check=check_key,
-                            expected=expected,
-                            frequency=frequency,
-                            recheck_delay=recheck_delay,
-                            retry_before_notification=retry_before_notification,
-                            already_retried=False,
-                        )
-                        logger.debug("Adding a new task in the db: %s", task)
-                        tasks.append(task)
-
-    db.add_all(tasks)
+    db.add_all(tmp_tasks)
     db.commit()
 
-    # Delete vanished tasks
-    if max_task_id:
-        vanished_tasks = (
-            db.query(Task)
-            .filter(Task.id <= max_task_id, Task.id.not_in(seen_tasks))
-            .delete()
-        )
-        db.commit()
-        logger.info(
-            "%i task(s) has been removed since not in config file anymore",
-            vanished_tasks,
-        )
-        return {"added": len(tasks), "vanished": vanished_tasks}
+    vanished_tasks = await delete_vanished_tasks(db)
 
-    return {"added": len(tasks), "vanished": 0}
+    added_tasks = await add_tasks_from_config_table(db)
+
+    updated_tasks = await update_tasks(db)
+
+    if str(config.general.db.url).startswith("sqlite"):
+        # SQLite has no TRUNCATE instruction
+        # See https://www.techonthenet.com/sqlite/truncate.php
+        logger.debug("Truncating tasks_tmp table (sqlite)")
+        db.query(TaskTmp).delete()
+        db.commit()
+    else:
+        logger.debug("Truncating tasks_tmp table")
+        db.execute(
+            sa_text("TRUNCATE TABLE tasks_tmp;").execution_options(autocommit=True)
+        )
+
+    return {
+        "added": added_tasks,
+        "deleted": vanished_tasks + deleted_duplicate_tasks,
+        "updated": updated_tasks,
+    }
 
 
 async def get_severity_counts(db: Session) -> dict:
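
How the new reload works, in short: update_from_config() now stages every task derived from the websites configuration in the tasks_tmp table, then reconciles it with tasks through set operations — an anti-join (outer join filtered on TaskTmp.url == None) deletes vanished tasks, the reverse anti-join adds new ones, and an inner join filtered on changed frequency / recheck_delay / retry_before_notification updates existing ones — before tasks_tmp is emptied again. The self-contained sketch below only illustrates the anti-join pattern used by delete_vanished_tasks(); the Item / ItemTmp models and the in-memory SQLite engine are illustrative assumptions (not part of the Argos code base), and it assumes SQLAlchemy 2.x.

# Minimal sketch of the "staging table + anti-join" pattern used above.
# Item/ItemTmp and the in-memory SQLite engine are illustrative assumptions,
# not part of the Argos code base; requires SQLAlchemy >= 2.0.
from sqlalchemy import and_, create_engine, or_
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Item(Base):
    """Stand-in for the persistent tasks table."""

    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True)
    url: Mapped[str] = mapped_column()
    request_data: Mapped[str] = mapped_column(nullable=True)


class ItemTmp(Base):
    """Stand-in for the tasks_tmp staging table filled from the config."""

    __tablename__ = "items_tmp"
    id: Mapped[int] = mapped_column(primary_key=True)
    url: Mapped[str] = mapped_column()
    request_data: Mapped[str] = mapped_column(nullable=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as db:
    db.add_all(
        [
            Item(url="https://example.org", request_data=None),      # still configured
            Item(url="https://example.org/old", request_data=None),  # vanished from config
            ItemTmp(url="https://example.org", request_data=None),   # staged from config
        ]
    )
    db.commit()

    # Outer join on the identifying columns; rows without a staging-table match
    # (ItemTmp.url IS NULL) are the vanished ones, as in delete_vanished_tasks().
    vanished = (
        db.query(Item)
        .outerjoin(
            ItemTmp,
            and_(
                Item.url == ItemTmp.url,
                or_(
                    and_(Item.request_data == None, ItemTmp.request_data == None),
                    Item.request_data == ItemTmp.request_data,
                ),
            ),
        )
        .filter(ItemTmp.url == None)
        .all()
    )
    print([item.url for item in vanished])  # ['https://example.org/old']

In the patch itself the same join condition is extended with method, check, expected and ip_version alongside the NULL-tolerant request_data comparison, and the symmetric query in add_tasks_from_config_table() simply swaps the two tables.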