♻ — Move tasks config (re)loading into a command

This commit is contained in:
Luc Didry 2024-03-14 12:18:05 +01:00
parent e3b1b714b3
commit f976905433
No known key found for this signature in database
GPG key ID: EA868E12D0257E3C
8 changed files with 75 additions and 8 deletions

View file

@ -0,0 +1,35 @@
"""Adding ConfigCache model
Revision ID: 1a3497f9f71b
Revises: 7d480e6f1112
Create Date: 2024-03-13 15:28:09.185377
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '1a3497f9f71b'
down_revision: Union[str, None] = '7d480e6f1112'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('config_cache',
sa.Column('name', sa.String(), nullable=False),
sa.Column('val', sa.String(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('name')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('config_cache')
# ### end Alembic commands ###

View file

@ -131,5 +131,26 @@ async def cleandb(max_results, max_lock_seconds):
click.echo(f"{updated} locks released") click.echo(f"{updated} locks released")
@server.command()
@coroutine
async def reload_config():
"""Read tasks config and add/delete tasks in database if needed
"""
# The imports are made here otherwise the agent will need server configuration files.
from argos.server import queries
from argos.server.main import get_application, read_config
from argos.server.settings import get_app_settings
appli = get_application()
settings = get_app_settings()
config = read_config(appli, settings)
db = await get_db()
changed = await queries.update_from_config(db, config)
click.echo(f"{changed['added']} tasks added")
click.echo(f"{changed['vanished']} tasks deleted")
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()

View file

@ -7,7 +7,7 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from argos.logging import logger from argos.logging import logger
from argos.server import models, queries, routes from argos.server import models, routes
from argos.server.settings import get_app_settings, read_yaml_config from argos.server.settings import get_app_settings, read_yaml_config
@ -39,15 +39,14 @@ def get_application() -> FastAPI:
def create_start_app_handler(appli): def create_start_app_handler(appli):
"""Warmup the server: """Warmup the server:
setup database connection and update the tasks in it before making it available setup database connection
""" """
async def read_config_and_populate_db(): async def connect_db_at_startup():
setup_database(appli) setup_database(appli)
db = await connect_to_db(appli) return await connect_to_db(appli)
await queries.update_from_config(db, appli.state.config)
return read_config_and_populate_db return connect_db_at_startup
async def connect_to_db(appli): async def connect_to_db(appli):

View file

@ -111,7 +111,7 @@ async def update_from_config(db: Session, config: schemas.Config):
"""Update tasks from config file""" """Update tasks from config file"""
config_unchanged = await is_config_unchanged(db, config) config_unchanged = await is_config_unchanged(db, config)
if config_unchanged: if config_unchanged:
return None return {'added': 0, 'vanished': 0}
max_task_id = ( max_task_id = (
db.query(func.max(Task.id).label('max_id')) # pylint: disable-msg=not-callable db.query(func.max(Task.id).label('max_id')) # pylint: disable-msg=not-callable
@ -175,6 +175,9 @@ async def update_from_config(db: Session, config: schemas.Config):
) )
db.commit() db.commit()
logger.info("%i tasks has been removed since not in config file anymore", vanished_tasks) logger.info("%i tasks has been removed since not in config file anymore", vanished_tasks)
return {'added': len(tasks), 'vanished': vanished_tasks}
return {'added': len(tasks), 'vanished': 0}
async def get_severity_counts(db: Session) -> dict: async def get_severity_counts(db: Session) -> dict:

View file

@ -8,6 +8,8 @@ PartOf=postgresql.service
[Service] [Service]
User=www-data User=www-data
WorkingDirectory=/var/www/argos/ WorkingDirectory=/var/www/argos/
ExecStartPre=/var/www/argos/venv/bin/alembic upgrade head
ExecStartPre=/var/www/argos/venv/bin/argos server reload-config
ExecStart=/var/www/argos/venv/bin/argos server start ExecStart=/var/www/argos/venv/bin/argos server start
ExecReload=/var/www/argos/venv/bin/argos server reload ExecReload=/var/www/argos/venv/bin/argos server reload
SyslogIdentifier=argos-server SyslogIdentifier=argos-server

View file

@ -1,3 +1,4 @@
import asyncio
import os import os
import pytest import pytest
@ -40,6 +41,7 @@ def _create_app() -> FastAPI:
from argos.server.main import ( # local import for testing purpose from argos.server.main import ( # local import for testing purpose
get_application, get_application,
setup_database, setup_database,
connect_to_db,
) )
app = get_application() app = get_application()
@ -49,4 +51,5 @@ def _create_app() -> FastAPI:
app.state.settings.yaml_file = "tests/config.yaml" app.state.settings.yaml_file = "tests/config.yaml"
setup_database(app) setup_database(app)
asyncio.run(connect_to_db(app))
return app return app

View file

@ -1,8 +1,11 @@
import asyncio
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from argos.schemas import AgentResult, SerializableException from argos.schemas import AgentResult, SerializableException
from argos.server import models from argos.server import models
from argos.server.queries import update_from_config
def test_read_tasks_requires_auth(app): def test_read_tasks_requires_auth(app):
@ -12,6 +15,7 @@ def test_read_tasks_requires_auth(app):
def test_tasks_retrieval_and_results(authorized_client, app): def test_tasks_retrieval_and_results(authorized_client, app):
asyncio.run(update_from_config(app.state.db, app.state.config))
with authorized_client as client: with authorized_client as client:
response = client.get("/api/tasks") response = client.get("/api/tasks")
assert response.status_code == 200 assert response.status_code == 200