🔀 Merge branch 'fix-19' into 'main'

🐛 — Delete tasks which are no longer in the config file (fix #19)

Closes #19

See merge request framasoft/framaspace/argos!25
Luc Didry 2024-03-25 10:41:19 +00:00
commit 6f856afe4a
14 changed files with 244 additions and 14 deletions

View file

@@ -24,6 +24,8 @@ djlint: venv ## Format the templates
	venv/bin/djlint --ignore=H030,H031,H006 --profile jinja --lint argos/server/templates/*html
pylint: venv ## Runs pylint on the code
	venv/bin/pylint argos
pylint-alembic: venv ## Runs pylint on alembic migration files
	venv/bin/pylint --disable invalid-name,no-member alembic/versions/*.py
lint: djlint pylint
help:
	@python3 -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)

View file

@@ -28,6 +28,8 @@ def run_migrations_offline() -> None:
    context.configure(
        url=url,
        target_metadata=target_metadata,
        render_as_batch=True,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
@@ -50,7 +51,10 @@ def run_migrations_online() -> None:
    )

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)
        context.configure(connection=connection,
                          target_metadata=target_metadata,
                          render_as_batch=True,
                          )

        with context.begin_transaction():
            context.run_migrations()
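SQLite cannot drop or alter constraints on an existing table, so Alembic's batch mode recreates the table behind the scenes; setting render_as_batch=True here makes `alembic revision --autogenerate` emit migrations in that style. A minimal sketch of what batch mode looks like in a generated migration (table and column names are illustrative):

import sqlalchemy as sa
from alembic import op

def upgrade() -> None:
    # On SQLite this copies the table; on PostgreSQL it is a plain ALTER.
    with op.batch_alter_table("results") as batch_op:
        batch_op.alter_column("task_id", existing_type=sa.Integer(), nullable=False)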

View file

@@ -0,0 +1,31 @@
"""Adding ConfigCache model

Revision ID: 1a3497f9f71b
Revises: e99bc35702c9
Create Date: 2024-03-13 15:28:09.185377

"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '1a3497f9f71b'
down_revision: Union[str, None] = 'e99bc35702c9'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table('config_cache',
    sa.Column('name', sa.String(), nullable=False),
    sa.Column('val', sa.String(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.PrimaryKeyConstraint('name')
    )


def downgrade() -> None:
    op.drop_table('config_cache')

View file

@@ -53,6 +53,7 @@ def upgrade() -> None:
        sa.ForeignKeyConstraint(
            ["task_id"],
            ["tasks.id"],
            name="results_task_id_fkey",
        ),
        sa.PrimaryKeyConstraint("id"),
    )

View file

@@ -0,0 +1,33 @@
"""Add ON DELETE CASCADE to results task_id

Revision ID: defda3f2952d
Revises: 1a3497f9f71b
Create Date: 2024-03-18 15:09:34.544573

"""
from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = 'defda3f2952d'
down_revision: Union[str, None] = '1a3497f9f71b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    with op.batch_alter_table('results', schema=None) as batch_op:
        batch_op.drop_constraint('results_task_id_fkey', type_='foreignkey')
        batch_op.create_foreign_key('results_task_id_fkey',
                                    'tasks',
                                    ['task_id'],
                                    ['id'],
                                    ondelete='CASCADE')


def downgrade() -> None:
    with op.batch_alter_table('results', schema=None) as batch_op:
        batch_op.drop_constraint('results_task_id_fkey', type_='foreignkey')
        batch_op.create_foreign_key('results_task_id_fkey', 'tasks', ['task_id'], ['id'])
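Note that drop_constraint needs the foreign key's name, which is why the earlier hunk retrofits name="results_task_id_fkey" into the original create_table call, and batch_alter_table lets the same migration run on SQLite. Once migrated, the cascade can be sanity-checked in a few lines (a sketch assuming an open session named db):

from argos.server.models import Result, Task

task = db.query(Task).first()
task_id = task.id
db.delete(task)  # deleting the parent task...
db.commit()
# ...should remove its results via ON DELETE CASCADE
assert db.query(Result).filter(Result.task_id == task_id).count() == 0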

View file

@@ -4,6 +4,8 @@ from functools import wraps

import click
import uvicorn
from alembic import command
from alembic.config import Config

from argos import logging
from argos.agent import ArgosAgent
@@ -131,5 +133,41 @@ async def cleandb(max_results, max_lock_seconds):
    click.echo(f"{updated} locks released")


@server.command()
@coroutine
async def reload_config():
    """Read the tasks config and add/delete tasks in the database if needed
    """
    # The imports are made here; otherwise the agent would need the server configuration files.
    from argos.server import queries
    from argos.server.main import get_application, read_config
    from argos.server.settings import get_app_settings

    appli = get_application()
    settings = get_app_settings()
    config = read_config(appli, settings)

    db = await get_db()
    changed = await queries.update_from_config(db, config)

    click.echo(f"{changed['added']} tasks added")
    click.echo(f"{changed['vanished']} tasks deleted")


@server.command()
@coroutine
async def migrate():
    """Run database migrations
    """
    # The imports are made here; otherwise the agent would need the server configuration files.
    from argos.server.settings import get_app_settings

    settings = get_app_settings()

    alembic_cfg = Config("alembic.ini")
    alembic_cfg.set_main_option("sqlalchemy.url", settings.database_url)
    command.upgrade(alembic_cfg, "head")


if __name__ == "__main__":
    cli()
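Both new subcommands are async and rely on the @coroutine decorator defined earlier in this file, outside the hunk. Assuming it follows the usual click-plus-asyncio pattern, it presumably looks something like this sketch:

import asyncio
from functools import wraps

def coroutine(f):
    # Run an async click command to completion from a sync entry point.
    @wraps(f)
    def wrapper(*args, **kwargs):
        return asyncio.run(f(*args, **kwargs))
    return wrapper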

View file

@@ -3,11 +3,11 @@ import sys

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pydantic import ValidationError
from sqlalchemy import create_engine
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker

from argos.logging import logger
from argos.server import models, queries, routes
from argos.server import models, routes
from argos.server.settings import get_app_settings, read_yaml_config
@@ -39,15 +39,14 @@ def get_application() -> FastAPI:
def create_start_app_handler(appli):
    """Warmup the server:
    setup database connection and update the tasks in it before making it available
    setup database connection
    """

    async def read_config_and_populate_db():
    async def _get_db():
        setup_database(appli)

        db = await connect_to_db(appli)
        await queries.update_from_config(db, appli.state.config)
        return await connect_to_db(appli)

    return read_config_and_populate_db
    return _get_db

async def connect_to_db(appli):
@@ -95,6 +94,13 @@ def setup_database(appli):
        settings.database_url,
        **extra_settings
    )

    def _fk_pragma_on_connect(dbapi_con, con_record):
        dbapi_con.execute('pragma foreign_keys=ON')

    if settings.database_url.startswith("sqlite:////"):
        event.listen(engine, 'connect', _fk_pragma_on_connect)

    appli.state.SessionLocal = sessionmaker(
        autocommit=False, autoflush=False, bind=engine
    )

View file

@@ -47,7 +47,9 @@ class Task(Base):
    )
    last_severity_update: Mapped[datetime] = mapped_column(nullable=True)

    results: Mapped[List["Result"]] = relationship(back_populates="task")
    results: Mapped[List["Result"]] = relationship(back_populates="task",
                                                   cascade="all, delete",
                                                   passive_deletes=True,)

    def __str__(self):
        return f"DB Task {self.url} - {self.check} - {self.expected}"
@@ -92,7 +94,7 @@ class Result(Base):
    """
    __tablename__ = "results"
    id: Mapped[int] = mapped_column(primary_key=True)
    task_id: Mapped[int] = mapped_column(ForeignKey("tasks.id"))
    task_id: Mapped[int] = mapped_column(ForeignKey("tasks.id", ondelete="CASCADE"))
    task: Mapped["Task"] = relationship(back_populates="results")

    agent_id: Mapped[str] = mapped_column(nullable=True)
@@ -112,3 +114,19 @@ class Result(Base):

    def __str__(self):
        return f"DB Result {self.id} - {self.status} - {self.context}"


class ConfigCache(Base):
    """Contains some information on the previous config state

    Used to quickly determine if we need to update the tasks.

    There are currently two cached settings:
    - general_frequency: the content of the general.frequency setting, in minutes
      ex: 5
    - websites_hash: the sha256sum of the websites setting, to allow a quick
      comparison without looping through all websites
      ex: 8b886e7db7b553fe99f6d5437f31745987e243c77b2109b84cf9a7f8bf7d75b1
    """
    __tablename__ = "config_cache"
    name: Mapped[str] = mapped_column(primary_key=True)
    val: Mapped[str] = mapped_column()
    updated_at: Mapped[datetime] = mapped_column()
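Pairing cascade="all, delete" with passive_deletes=True makes the ORM cascade deletes for objects already loaded in the session while trusting the database's ON DELETE CASCADE for everything else, instead of emitting one DELETE per result row. The pattern in isolation, with hypothetical models in SQLAlchemy 2.0 style:

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

class Base(DeclarativeBase):
    pass

class Parent(Base):
    __tablename__ = "parents"
    id: Mapped[int] = mapped_column(primary_key=True)
    children: Mapped[list["Child"]] = relationship(
        cascade="all, delete",  # ORM-side cascade for loaded objects
        passive_deletes=True,   # let ON DELETE CASCADE handle the rest
    )

class Child(Base):
    __tablename__ = "children"
    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[int] = mapped_column(
        ForeignKey("parents.id", ondelete="CASCADE")
    )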

View file

@@ -1,5 +1,7 @@
"""Functions to ease SQL queries management"""
from datetime import datetime, timedelta
from hashlib import sha256
from typing import List
from urllib.parse import urljoin

from sqlalchemy import desc, func
@@ -7,7 +9,7 @@ from sqlalchemy.orm import Session

from argos import schemas
from argos.logging import logger
from argos.server.models import Result, Task
from argos.server.models import Result, Task, ConfigCache


async def list_tasks(db: Session, agent_id: str, limit: int = 100):
@@ -60,10 +62,64 @@ async def count_results(db: Session):
    return db.query(Result).count()


async def has_config_changed(db: Session, config: schemas.Config) -> bool:
    """Check if the websites config has changed, using a hash and a config cache"""
    websites_hash = sha256(str(config.websites).encode()).hexdigest()
    conf_caches = (
        db.query(ConfigCache)
        .all()
    )
    same_config = True
    if conf_caches:
        for conf in conf_caches:
            match conf.name:
                case 'websites_hash':
                    if conf.val != websites_hash:
                        same_config = False
                        conf.val = websites_hash
                        conf.updated_at = datetime.now()
                case 'general_frequency':
                    if conf.val != str(config.general.frequency):
                        same_config = False
                        conf.val = str(config.general.frequency)
                        conf.updated_at = datetime.now()

        db.commit()

        if same_config:
            return False

    else:  # no config cache found
        web_hash = ConfigCache(
            name='websites_hash',
            val=websites_hash,
            updated_at=datetime.now()
        )
        gen_freq = ConfigCache(
            name='general_frequency',
            val=str(config.general.frequency),
            updated_at=datetime.now()
        )
        db.add(web_hash)
        db.add(gen_freq)
        db.commit()

    return True
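The fingerprint hashes the repr of the parsed pydantic models rather than the raw YAML, so it is stable for an unchanged file but would also register a change if the models' repr format ever changed (say, across a pydantic upgrade); the worst case is one harmless extra resync. The whole comparison reduces to a couple of lines (here config is the parsed schemas.Config and cached_value stands for the stored websites_hash row):

from hashlib import sha256

# one string comparison replaces a field-by-field diff of every website
fingerprint = sha256(str(config.websites).encode()).hexdigest()
changed = fingerprint != cached_value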

async def update_from_config(db: Session, config: schemas.Config):
    """Update tasks from config file"""
    config_changed = await has_config_changed(db, config)
    if not config_changed:
        return {'added': 0, 'vanished': 0}

    max_task_id = (
        db.query(func.max(Task.id).label('max_id'))  # pylint: disable-msg=not-callable
        .all()
    )[0].max_id
    tasks = []
    unique_properties = []
    seen_tasks: List[int] = []
    for website in config.websites:
        domain = str(website.domain)
        frequency = website.frequency or config.general.frequency
@@ -83,6 +139,7 @@ async def update_from_config(db: Session, config: schemas.Config):
            )
            if existing_tasks:
                existing_task = existing_tasks[0]
                seen_tasks.append(existing_task.id)

                if frequency != existing_task.frequency:
                    existing_task.frequency = frequency
@@ -107,6 +164,21 @@ async def update_from_config(db: Session, config: schemas.Config):
    db.add_all(tasks)
    db.commit()

    # Delete vanished tasks
    if max_task_id:
        vanished_tasks = (
            db.query(Task)
            .filter(
                Task.id <= max_task_id,
                Task.id.not_in(seen_tasks)
            ).delete()
        )
        db.commit()
        logger.info("%i tasks have been removed since they are no longer in the config file", vanished_tasks)
        return {'added': len(tasks), 'vanished': vanished_tasks}

    return {'added': len(tasks), 'vanished': 0}
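A worked example of the bookkeeping: suppose the database holds tasks 1 through 5 before a sync, the config still matches tasks 1, 2 and 4, and defines two new ones.

# max_task_id == 5 (highest id before this sync)
# seen_tasks == [1, 2, 4] (pre-existing tasks still present in the config)
# db.add_all(tasks) inserts the two new tasks as ids 6 and 7
# the delete() removes tasks 3 and 5, so the call returns:
# {'added': 2, 'vanished': 2}

The Task.id <= max_task_id guard keeps the freshly inserted rows out of the sweep, since their ids are above the pre-sync maximum.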

async def get_severity_counts(db: Session) -> dict:
    """Get the severities (ok, warning, critical…) and their count"""

View file

@@ -62,7 +62,7 @@ async def create_results(
            # XXX Use a job queue or make it async
            handle_alert(config, result, task, severity, last_severity, request)

            db_results.append(result)
        db_results.append(result)

    db.commit()
    return {"result_ids": [r.id for r in db_results]}

View file

@@ -8,6 +8,8 @@ PartOf=postgresql.service
[Service]
User=www-data
WorkingDirectory=/var/www/argos/
ExecStartPre=/var/www/argos/venv/bin/argos server migrate
ExecStartPre=/var/www/argos/venv/bin/argos server reload-config
ExecStart=/var/www/argos/venv/bin/argos server start
ExecReload=/var/www/argos/venv/bin/argos server reload
SyslogIdentifier=argos-server
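systemd runs ExecStartPre commands in order and aborts startup if one of them fails, so the server only comes up against a database that has been migrated and resynced with the config file.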

View file

@@ -1,3 +1,4 @@
import asyncio
import os

import pytest
@@ -40,6 +41,7 @@ def _create_app() -> FastAPI:
    from argos.server.main import (  # local import for testing purpose
        get_application,
        setup_database,
        connect_to_db,
    )

    app = get_application()
@@ -49,4 +51,5 @@ def _create_app() -> FastAPI:
    app.state.settings.yaml_file = "tests/config.yaml"

    setup_database(app)
    asyncio.run(connect_to_db(app))
    return app

View file

@@ -1,8 +1,11 @@
import asyncio

import pytest
from fastapi.testclient import TestClient

from argos.schemas import AgentResult, SerializableException
from argos.server import models
from argos.server.queries import update_from_config


def test_read_tasks_requires_auth(app):
@@ -12,6 +15,7 @@ def test_read_tasks_requires_auth(app):
def test_tasks_retrieval_and_results(authorized_client, app):
    asyncio.run(update_from_config(app.state.db, app.state.config))
    with authorized_client as client:
        response = client.get("/api/tasks")

        assert response.status_code == 200
View file

@@ -78,7 +78,7 @@ async def test_update_from_config_with_duplicate_tasks(db, empty_config):

@pytest.mark.asyncio
async def test_update_from_config_db_can_handle_already_present_duplicates(
async def test_update_from_config_db_can_remove_duplicates_and_old_tasks(
    db, empty_config, task
):
    # Add a duplicate in the db
@@ -99,12 +99,28 @@
            dict(
                path="https://another-example.com", checks=[{task.check: task.expected}]
            ),
            dict(
                path=task.url, checks=[{task.check: task.expected}]
            ),
        ],
    )
    empty_config.websites = [website]
    await queries.update_from_config(db, empty_config)

    assert db.query(Task).count() == 3
    assert db.query(Task).count() == 2

    website = schemas.config.Website(
        domain=task.domain,
        paths=[
            dict(
                path="https://another-example.com", checks=[{task.check: task.expected}]
            ),
        ],
    )
    empty_config.websites = [website]
    await queries.update_from_config(db, empty_config)

    assert db.query(Task).count() == 1

@pytest.mark.asyncio