diff --git a/alembic/versions/1a3497f9f71b_adding_configcache_model.py b/alembic/versions/1a3497f9f71b_adding_configcache_model.py new file mode 100644 index 0000000..2f92fd3 --- /dev/null +++ b/alembic/versions/1a3497f9f71b_adding_configcache_model.py @@ -0,0 +1,35 @@ +"""Adding ConfigCache model + +Revision ID: 1a3497f9f71b +Revises: 7d480e6f1112 +Create Date: 2024-03-13 15:28:09.185377 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '1a3497f9f71b' +down_revision: Union[str, None] = '7d480e6f1112' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('config_cache', + sa.Column('name', sa.String(), nullable=False), + sa.Column('val', sa.String(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('name') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('config_cache') + # ### end Alembic commands ### diff --git a/argos/commands.py b/argos/commands.py index 9c3ccd6..e53a93d 100644 --- a/argos/commands.py +++ b/argos/commands.py @@ -131,5 +131,26 @@ async def cleandb(max_results, max_lock_seconds): click.echo(f"{updated} locks released") +@server.command() +@coroutine +async def reload_config(): + """Read tasks config and add/delete tasks in database if needed + """ + # The imports are made here otherwise the agent will need server configuration files. + from argos.server import queries + from argos.server.main import get_application, read_config + from argos.server.settings import get_app_settings + + appli = get_application() + settings = get_app_settings() + config = read_config(appli, settings) + + db = await get_db() + changed = await queries.update_from_config(db, config) + + click.echo(f"{changed['added']} tasks added") + click.echo(f"{changed['vanished']} tasks deleted") + + if __name__ == "__main__": cli() diff --git a/argos/server/main.py b/argos/server/main.py index ce3f534..d3e175e 100644 --- a/argos/server/main.py +++ b/argos/server/main.py @@ -7,7 +7,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from argos.logging import logger -from argos.server import models, queries, routes +from argos.server import models, routes from argos.server.settings import get_app_settings, read_yaml_config @@ -39,15 +39,14 @@ def get_application() -> FastAPI: def create_start_app_handler(appli): """Warmup the server: - setup database connection and update the tasks in it before making it available + setup database connection """ - async def read_config_and_populate_db(): + async def connect_db_at_startup(): setup_database(appli) - db = await connect_to_db(appli) - await queries.update_from_config(db, appli.state.config) + return await connect_to_db(appli) - return read_config_and_populate_db + return connect_db_at_startup async def connect_to_db(appli): diff --git a/argos/server/queries.py b/argos/server/queries.py index 345f480..6d4d790 100644 --- a/argos/server/queries.py +++ b/argos/server/queries.py @@ -111,7 +111,7 @@ async def update_from_config(db: Session, config: schemas.Config): """Update tasks from config file""" config_unchanged = await is_config_unchanged(db, config) if config_unchanged: - return None + return {'added': 0, 'vanished': 0} max_task_id = ( db.query(func.max(Task.id).label('max_id')) # pylint: disable-msg=not-callable @@ -175,6 +175,9 @@ async def update_from_config(db: Session, config: schemas.Config): ) db.commit() logger.info("%i tasks has been removed since not in config file anymore", vanished_tasks) + return {'added': len(tasks), 'vanished': vanished_tasks} + + return {'added': len(tasks), 'vanished': 0} async def get_severity_counts(db: Session) -> dict: diff --git a/argos/server/routes/api.py b/argos/server/routes/api.py index 432660b..4a316b7 100644 --- a/argos/server/routes/api.py +++ b/argos/server/routes/api.py @@ -62,7 +62,7 @@ async def create_results( # XXX Use a job queue or make it async handle_alert(config, result, task, severity, last_severity, request) - db_results.append(result) + db_results.append(result) db.commit() return {"result_ids": [r.id for r in db_results]} diff --git a/conf/systemd-server.service b/conf/systemd-server.service index 698e637..0c69536 100644 --- a/conf/systemd-server.service +++ b/conf/systemd-server.service @@ -8,6 +8,8 @@ PartOf=postgresql.service [Service] User=www-data WorkingDirectory=/var/www/argos/ +ExecStartPre=/var/www/argos/venv/bin/alembic upgrade head +ExecStartPre=/var/www/argos/venv/bin/argos server reload-config ExecStart=/var/www/argos/venv/bin/argos server start ExecReload=/var/www/argos/venv/bin/argos server reload SyslogIdentifier=argos-server diff --git a/tests/conftest.py b/tests/conftest.py index 84e1a44..d07ea1a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import asyncio import os import pytest @@ -40,6 +41,7 @@ def _create_app() -> FastAPI: from argos.server.main import ( # local import for testing purpose get_application, setup_database, + connect_to_db, ) app = get_application() @@ -49,4 +51,5 @@ def _create_app() -> FastAPI: app.state.settings.yaml_file = "tests/config.yaml" setup_database(app) + asyncio.run(connect_to_db(app)) return app diff --git a/tests/test_api.py b/tests/test_api.py index 53ca7a7..21c87d1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,8 +1,11 @@ +import asyncio + import pytest from fastapi.testclient import TestClient from argos.schemas import AgentResult, SerializableException from argos.server import models +from argos.server.queries import update_from_config def test_read_tasks_requires_auth(app): @@ -12,6 +15,7 @@ def test_read_tasks_requires_auth(app): def test_tasks_retrieval_and_results(authorized_client, app): + asyncio.run(update_from_config(app.state.db, app.state.config)) with authorized_client as client: response = client.get("/api/tasks") assert response.status_code == 200