⚡️ — Faster websites configuration reloading (fix #85)

Luc Didry 2025-03-20 16:50:09 +01:00
parent 32d11c5598
commit ee703a505f
6 changed files with 351 additions and 98 deletions
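
The heart of the change: instead of checking every configured task against the database one query at a time, the reload now bulk-inserts the parsed configuration into a temporary tasks_tmp table and computes vanished, added and updated tasks with a few joins against the tasks table. A minimal sketch of the diffing idea, with task identity reduced to the URL alone (the real queries below also match on method, request_data, check, expected and ip_version):

    # Sketch only — Task and TaskTmp are the ORM models from this commit.
    vanished = (
        db.query(Task)                                # known tasks…
        .outerjoin(TaskTmp, Task.url == TaskTmp.url)  # …left-joined to config rows
        .filter(TaskTmp.url == None)                  # anti-join: no config row matched
        .all()
    )
    added = (
        db.query(TaskTmp)                             # config rows…
        .outerjoin(Task, TaskTmp.url == Task.url)     # …left-joined to known tasks
        .filter(Task.url == None)                     # present in config only
        .all()
    )

One round-trip per category instead of one query per task is where the speed-up comes from.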

.gitignore

@@ -7,4 +7,5 @@ public
 *.swp
 argos-config.yaml
 config.yaml
+websites.yaml
 dist

@@ -9,6 +9,7 @@
 - 📝 — Improve OpenAPI doc
 - 🤕 — Fix order of tasks sent to agent
 - ✨ — Add application API (fix #86)
+- ⚡️ — Faster websites configuration reloading (#85)

 ## 0.9.0

@@ -182,8 +182,9 @@ async def reload_config(config, enqueue):
     else:
         changed = await queries.update_from_config(db, _config)
+        click.echo(f"{changed['deleted']} task(s) deleted")
         click.echo(f"{changed['added']} task(s) added")
-        click.echo(f"{changed['vanished']} task(s) deleted")
+        click.echo(f"{changed['updated']} task(s) updated")


 @server.command()

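Assuming the CLI command wraps this function unchanged, a reload would now print counts in this order (numbers hypothetical):

    12 task(s) deleted
    3 task(s) added
    7 task(s) updated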
@@ -0,0 +1,54 @@
+"""Add table for configuration comparison
+
+Revision ID: 655eefd69858
+Revises: 1d0aaa07743c
+Create Date: 2025-03-20 14:13:33.006662
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = "655eefd69858"
+down_revision: Union[str, None] = "1d0aaa07743c"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    ip_version_enum = sa.Enum("4", "6", name="ip_version_enum", create_type=False)
+    method_enum = sa.Enum(
+        "GET",
+        "HEAD",
+        "POST",
+        "OPTIONS",
+        "CONNECT",
+        "TRACE",
+        "PUT",
+        "PATCH",
+        "DELETE",
+        name="method",
+        create_type=False,
+    )
+    op.create_table(
+        "tasks_tmp",
+        sa.Column("id", sa.Integer(), nullable=False),
+        sa.Column("url", sa.String(), nullable=False),
+        sa.Column("domain", sa.String(), nullable=False),
+        sa.Column("check", sa.String(), nullable=False),
+        sa.Column("expected", sa.String(), nullable=False),
+        sa.Column("frequency", sa.Float(), nullable=False),
+        sa.Column("recheck_delay", sa.Float(), nullable=True),
+        sa.Column("retry_before_notification", sa.Integer(), nullable=False),
+        sa.Column("request_data", sa.String(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.add_column("tasks_tmp", sa.Column("ip_version", ip_version_enum, nullable=False))
+    op.add_column("tasks_tmp", sa.Column("method", method_enum, nullable=False))
+
+
+def downgrade() -> None:
+    op.drop_table("tasks_tmp")

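The commit does not show how the revision is applied; with a standard Alembic setup it runs like any other migration. A hypothetical programmatic invocation, assuming an alembic.ini at the project root:

    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")           # path is an assumption
    command.upgrade(cfg, "655eefd69858")  # this file's revision id
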
@@ -151,6 +151,44 @@ class Task(Base):
 Index("similar_tasks", Task.task_group)


+class TaskTmp(Base):
+    """Table with temporary data, only used for websites
+    configuration refreshing"""
+
+    __tablename__ = "tasks_tmp"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    url: Mapped[str] = mapped_column()
+    domain: Mapped[str] = mapped_column()
+    ip_version: Mapped[IPVersion] = mapped_column(
+        Enum("4", "6", name="ip_version_enum"),
+    )
+    check: Mapped[str] = mapped_column()
+    expected: Mapped[str] = mapped_column()
+    frequency: Mapped[float] = mapped_column()
+    recheck_delay: Mapped[float] = mapped_column(nullable=True)
+    retry_before_notification: Mapped[int] = mapped_column(insert_default=0)
+    method: Mapped[Method] = mapped_column(
+        Enum(
+            "GET",
+            "HEAD",
+            "POST",
+            "OPTIONS",
+            "CONNECT",
+            "TRACE",
+            "PUT",
+            "PATCH",
+            "DELETE",
+            name="method",
+        ),
+        insert_default="GET",
+    )
+    request_data: Mapped[str] = mapped_column(nullable=True)
+
+    def __str__(self) -> str:
+        return f"DB TaskTmp {self.url} (IPv{self.ip_version}) - {self.check} - {self.expected}"
+
+
 class Result(Base):
     """There are multiple results per task.

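For illustration, here is roughly what one tasks_tmp row would hold for a config entry probing https://example.org over IPv4 (check name and values are made up):

    tmp = TaskTmp(
        domain="https://example.org",
        url="https://example.org/",
        ip_version="4",
        method="GET",
        request_data=None,
        check="status-is",       # hypothetical check name
        expected="200",
        frequency=1.0,
        recheck_delay=None,
        retry_before_notification=0,
    )
    print(tmp)  # DB TaskTmp https://example.org/ (IPv4) - status-is - 200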
@ -1,18 +1,26 @@
"""Functions to ease SQL queries management"""
from datetime import datetime, timedelta
from hashlib import sha256
from typing import List
from urllib.parse import urljoin
import jwt
from fastapi import Request
from sqlalchemy import asc, func, Select
from sqlalchemy.orm import Session
from sqlalchemy import and_, asc, func, or_, Select
from sqlalchemy.orm import aliased, Session
from sqlalchemy.sql import text as sa_text
from argos import schemas
from argos.logging import logger
from argos.server.models import BlockedToken, ConfigCache, Job, Result, Task, User
from argos.server.models import (
BlockedToken,
ConfigCache,
Job,
Result,
Task,
TaskTmp,
User,
)
from argos.server.settings import read_config
@@ -287,8 +295,9 @@ async def process_jobs(db: Session) -> int:
             logger.info("Processing job %i: %s %s", job.id, job.todo, job.args)
             _config = read_config(job.args)
             changed = await update_from_config(db, _config)
+            logger.info("%i task(s) deleted", changed["deleted"])
             logger.info("%i task(s) added", changed["added"])
-            logger.info("%i task(s) deleted", changed["vanished"])
+            logger.info("%i task(s) updated", changed["updated"])
             db.delete(job)
             db.commit()
@@ -297,15 +306,204 @@
     return 0


-async def update_from_config(db: Session, config: schemas.Config):  # pylint: disable-msg=too-many-branches
-    """Update tasks from config file"""
-    max_task_id = (
-        db.query(func.max(Task.id).label("max_id")).all()  # pylint: disable-msg=not-callable
-    )[0].max_id
+async def delete_duplicate_tasks(db: Session) -> int:
+    """Find duplicate tasks in DB and delete one of them"""
+    f_task = aliased(Task)
+    s_task = aliased(Task)
+
+    duplicate_tasks = (
+        db.query(f_task, s_task)
+        .join(
+            s_task,
+            and_(
+                f_task.url == s_task.url,
+                f_task.method == s_task.method,
+                or_(
+                    and_(
+                        f_task.request_data == None,
+                        s_task.request_data == None,
+                    ),
+                    f_task.request_data == s_task.request_data,
+                ),
+                f_task.check == s_task.check,
+                f_task.expected == s_task.expected,
+                f_task.ip_version == s_task.ip_version,
+            ),
+        )
+        .filter(f_task.id != s_task.id)
+        .all()
+    )
+
+    deleted_duplicate_tasks = len(duplicate_tasks)
+    primary_duplicate_ids: list[int] = []
+    for i in duplicate_tasks:
+        primary = i[0]
+        secondary = i[1]
+        if primary.id in primary_duplicate_ids:
+            db.delete(secondary)
+        elif secondary not in primary_duplicate_ids:
+            primary_duplicate_ids.append(primary.id)
+            db.delete(secondary)
+        else:
+            db.delete(primary)
+
+    if deleted_duplicate_tasks:
+        db.commit()
+
+    return deleted_duplicate_tasks
+
+
+async def delete_vanished_tasks(db: Session) -> int:
+    """Delete tasks not in temporary config tasks table"""
+    tasks_to_delete = (
+        db.query(Task)
+        .outerjoin(
+            TaskTmp,
+            and_(
+                Task.url == TaskTmp.url,
+                Task.method == TaskTmp.method,
+                or_(
+                    and_(
+                        Task.request_data == None,
+                        TaskTmp.request_data == None,
+                    ),
+                    Task.request_data == TaskTmp.request_data,
+                ),
+                Task.check == TaskTmp.check,
+                Task.expected == TaskTmp.expected,
+                Task.ip_version == TaskTmp.ip_version,
+            ),
+        )
+        .filter(
+            TaskTmp.url == None,
+        )
+        .all()
+    )
+
+    vanished_tasks = len(tasks_to_delete)
+    for task in tasks_to_delete:
+        logger.debug("Deleting a task from the db: %s", task)
+        db.delete(task)
+
+    if vanished_tasks:
+        db.commit()
+
+    return vanished_tasks
+
+
+async def add_tasks_from_config_table(db: Session) -> int:
+    """Add tasks from temporary config tasks table"""
+    tasks_to_add = (
+        db.query(TaskTmp)
+        .outerjoin(
+            Task,
+            and_(
+                TaskTmp.url == Task.url,
+                TaskTmp.method == Task.method,
+                or_(
+                    and_(
+                        TaskTmp.request_data == None,
+                        Task.request_data == None,
+                    ),
+                    TaskTmp.request_data == Task.request_data,
+                ),
+                TaskTmp.check == Task.check,
+                TaskTmp.expected == Task.expected,
+                TaskTmp.ip_version == Task.ip_version,
+            ),
+        )
+        .filter(
+            Task.url == None,
+        )
+        .all()
+    )
+
+    tasks = []
+    for task_tmp in tasks_to_add:
+        task = Task(
+            domain=task_tmp.domain,
+            url=task_tmp.url,
+            ip_version=task_tmp.ip_version,
+            method=task_tmp.method,
+            request_data=task_tmp.request_data,
+            check=task_tmp.check,
+            expected=task_tmp.expected,
+            frequency=task_tmp.frequency,
+            recheck_delay=task_tmp.recheck_delay,
+            retry_before_notification=task_tmp.retry_before_notification,
+            already_retried=False,
+        )
+        logger.debug("Adding a new task in the db: %s", task)
+        tasks.append(task)
+
+    if tasks:
+        db.add_all(tasks)
+        db.commit()
+
+    return len(tasks)
+
+
+async def update_tasks(db: Session) -> int:
+    """Update tasks from temporary config tasks table"""
+    tasks_to_update = (
+        db.query(Task, TaskTmp)
+        .join(
+            TaskTmp,
+            and_(
+                Task.url == TaskTmp.url,
+                Task.method == TaskTmp.method,
+                or_(
+                    and_(
+                        Task.request_data == None,
+                        TaskTmp.request_data == None,
+                    ),
+                    Task.request_data == TaskTmp.request_data,
+                ),
+                Task.check == TaskTmp.check,
+                Task.expected == TaskTmp.expected,
+                Task.ip_version == TaskTmp.ip_version,
+            ),
+        )
+        .filter(
+            or_(
+                Task.frequency != TaskTmp.frequency,
+                Task.recheck_delay != TaskTmp.recheck_delay,
+                Task.retry_before_notification != TaskTmp.retry_before_notification,
+            )
+        )
+        .all()
+    )
+
+    updated_tasks = len(tasks_to_update)
+    for tasks in tasks_to_update:
+        task = tasks[0]
+        task_tmp = tasks[1]
+        logger.debug("Updating task: %s", task)
+        task.frequency = task_tmp.frequency
+        task.recheck_delay = task_tmp.recheck_delay
+        task.retry_before_notification = task_tmp.retry_before_notification
+
+    if updated_tasks:
+        db.commit()
+
+    return updated_tasks
+
+
+async def update_from_config(db: Session, config: schemas.Config):
+    """Update tasks from config file"""
+    deleted_duplicate_tasks = await delete_duplicate_tasks(db)
-    unique_properties = []
-    seen_tasks: List[int] = []
-    for website in config.websites:  # pylint: disable-msg=too-many-nested-blocks
+
+    tmp_tasks: list[TaskTmp] = []
+
+    # Fill the tasks_tmp table
+    for website in config.websites:
         domain = str(website.domain)
         frequency = website.frequency or config.general.frequency
         recheck_delay = website.recheck_delay or config.general.recheck_delay
@@ -324,101 +522,61 @@ async def update_from_config(db: Session, config: schemas.Config):  # pylint: di
         for p in website.paths:
             url = urljoin(domain, str(p.path))
             for check_key, expected in p.checks:
-                # Check the db for already existing tasks.
-                existing_tasks = (
-                    db.query(Task)
-                    .filter(
-                        Task.url == url,
-                        Task.method == p.method,
-                        Task.request_data == p.request_data,
-                        Task.check == check_key,
-                        Task.expected == expected,
-                        Task.ip_version == ip_version,
-                    )
-                    .all()
-                )
                 if (ip_version == "4" and ipv4 is False) or (
                     ip_version == "6" and ipv6 is False
                 ):
                     continue

-                if existing_tasks:
-                    existing_task = existing_tasks[0]
-                    seen_tasks.append(existing_task.id)
-                    if frequency != existing_task.frequency:
-                        existing_task.frequency = frequency
-                    if recheck_delay != existing_task.recheck_delay:
-                        existing_task.recheck_delay = recheck_delay  # type: ignore[assignment]
-                    if (
-                        retry_before_notification
-                        != existing_task.retry_before_notification
-                    ):
-                        existing_task.retry_before_notification = (
-                            retry_before_notification
-                        )
-                    logger.debug(
-                        "Skipping db task creation for url=%s, "
-                        "method=%s, check_key=%s, expected=%s, "
-                        "frequency=%s, recheck_delay=%s, "
-                        "retry_before_notification=%s, ip_version=%s.",
-                        url,
-                        p.method,
-                        check_key,
-                        expected,
-                        frequency,
-                        recheck_delay,
-                        retry_before_notification,
-                        ip_version,
-                    )
-                else:
-                    properties = (
-                        url,
-                        p.method,
-                        p.request_data,
-                        check_key,
-                        expected,
-                        ip_version,
-                        p.request_data,
-                    )
-                    if properties not in unique_properties:
-                        unique_properties.append(properties)
-                        task = Task(
-                            domain=domain,
-                            url=url,
-                            method=p.method,
-                            request_data=p.request_data,
-                            check=check_key,
-                            expected=expected,
-                            ip_version=ip_version,
-                            frequency=frequency,
-                            recheck_delay=recheck_delay,
-                            retry_before_notification=retry_before_notification,
-                            already_retried=False,
-                        )
-                        logger.debug("Adding a new task in the db: %s", task)
-                        tasks.append(task)
+                config_task = TaskTmp(
+                    domain=domain,
+                    url=url,
+                    ip_version=ip_version,
+                    method=p.method,
+                    request_data=p.request_data,
+                    check=check_key,
+                    expected=expected,
+                    frequency=frequency,
+                    recheck_delay=recheck_delay,
+                    retry_before_notification=retry_before_notification,
+                )
+                tmp_tasks.append(config_task)

-    db.add_all(tasks)
+    db.add_all(tmp_tasks)
     db.commit()

-    # Delete vanished tasks
-    if max_task_id:
-        vanished_tasks = (
-            db.query(Task)
-            .filter(Task.id <= max_task_id, Task.id.not_in(seen_tasks))
-            .delete()
-        )
-        db.commit()
-        logger.info(
-            "%i task(s) has been removed since not in config file anymore",
-            vanished_tasks,
-        )
-        return {"added": len(tasks), "vanished": vanished_tasks}
-    return {"added": len(tasks), "vanished": 0}
+    vanished_tasks = await delete_vanished_tasks(db)
+    added_tasks = await add_tasks_from_config_table(db)
+    updated_tasks = await update_tasks(db)
+
+    if str(config.general.db.url).startswith("sqlite"):
+        # SQLite has no TRUNCATE instruction
+        # See https://www.techonthenet.com/sqlite/truncate.php
+        logger.debug("Truncating tasks_tmp table (sqlite)")
+        db.query(TaskTmp).delete()
+        db.commit()
+    else:
+        logger.debug("Truncating tasks_tmp table")
+        db.execute(
+            sa_text("TRUNCATE TABLE tasks_tmp;").execution_options(autocommit=True)
+        )
+
+    return {
+        "added": added_tasks,
+        "deleted": vanished_tasks + deleted_duplicate_tasks,
+        "updated": updated_tasks,
+    }

async def get_severity_counts(db: Session) -> dict:
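
One detail worth noting in the join conditions above: `Task.request_data == TaskTmp.request_data` alone would never match two rows whose request_data is unset, because in SQL, NULL = NULL is not true. SQLAlchemy renders `column == None` as `column IS NULL`, so the or_(and_(both IS NULL), equality) pattern treats two NULLs as equal. A self-contained sketch of what that expression compiles to:

    from sqlalchemy import and_, column, or_

    a = column("task_request_data")
    b = column("tmp_request_data")
    cond = or_(and_(a == None, b == None), a == b)
    print(cond)  # task_request_data IS NULL AND tmp_request_data IS NULL
                 # OR task_request_data = tmp_request_data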