From 9885a5809a224dd72d392029b650b88e702a1405 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20M=C3=A9taireau?= Date: Wed, 18 Oct 2023 21:30:03 +0200 Subject: [PATCH] Add an 'argos server clean' command that needs to be run periodically to clean the db --- README.md | 16 ++++-- argos/commands.py | 38 +++++++++++++- argos/server/main.py | 17 +++--- argos/server/queries.py | 38 ++++++++++++-- argos/server/settings.py | 4 +- pyproject.toml | 1 + tests/conftest.py | 41 ++++++++++----- tests/test_api.py | 16 ++---- tests/test_queries.py | 109 +++++++++++++++++++++++++++++++++++++++ 9 files changed, 236 insertions(+), 44 deletions(-) create mode 100644 tests/test_queries.py diff --git a/README.md b/README.md index d7dcaf6..ed5229e 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,9 @@ Argos is an HTTP monitoring service. It allows you to define a list of websites Todo: -- [x] Use Postgresql as a database -- [x] Expose a simple read-only website. -- [x] Agents should wait and retry on timeout - [ ] Last seen agents - [ ] Use background tasks for alerting - [ ] Add a command to generate new authentication tokens -- [ ] Task for database cleanup (to run periodically) - [ ] Handles multiple alerting backends (email, sms, gotify) - [ ] Add a way to specify the severity of the alerts in the config - [ ] Do not send "expected" and "got" values in case check-status and body-contains suceeded @@ -50,7 +46,7 @@ cp config-example.yaml config.yaml Then, you can run the server: ```bash -argos server +argos server start ``` You can specify the environment variables to configure the server, or you can put them in an `.env` file: @@ -68,6 +64,16 @@ And here is how to run the agent: argos agent http://localhost:8000 "" ``` +You also need to run cleaning tasks periodically. `argos server clean --help` will give you more information on how to do that. 
+ +Here is a crontab example: + +```bash +# Run the cleaning tasks every hour (at minute 7) +7 * * * * argos server clean --max-results 100000 --max-lock-seconds 3600 +``` + + ## Configuration Here is a simple configuration file: diff --git a/argos/commands.py b/argos/commands.py index c1ae8e9..7f8c25e 100644 --- a/argos/commands.py +++ b/argos/commands.py @@ -5,6 +5,8 @@ import click from argos import logging from argos.agent import ArgosAgent +from argos.server import queries +from argos.server.main import connect_to_db, get_application, setup_database @click.group() @@ -12,6 +14,11 @@ def cli(): pass +@cli.group() +def server(): + pass + + @cli.command() @click.argument("server") @click.argument("auth") @@ -43,12 +50,12 @@ def agent(server, auth, max_tasks, wait_time, log_level): asyncio.run(agent.run()) -@cli.command() +@server.command() @click.option("--host", default="127.0.0.1", help="Host to bind") @click.option("--port", default=8000, type=int, help="Port to bind") @click.option("--reload", is_flag=True, help="Enable hot reloading") @click.option("--log-config", help="Path to the logging configuration file") -def server(host, port, reload, log_config): +def start(host, port, reload, log_config): """Starts the server.""" command = ["uvicorn", "argos.server:app", "--host", host, "--port", str(port)] if reload: @@ -58,5 +65,32 @@ def server(host, port, reload, log_config): subprocess.run(command) +@server.command() +@click.option("--max-results", default=100, help="Maximum number of results to keep") +@click.option( + "--max-lock-seconds", + default=100, + help="The number of seconds after which a lock is considered stale", +) +def clean(max_results, max_lock_seconds): + """Clean the database (to run routinely) + + - Removes old results from the database. + - Removes locks from tasks that have been locked for too long. 
+ """ + + async def clean_old_results(): + app = get_application() + setup_database(app) + db = await connect_to_db(app) + removed = await queries.remove_old_results(db, max_results) + updated = await queries.release_old_locks(db, max_lock_seconds) + + click.echo(f"{removed} results removed") + click.echo(f"{updated} locks released") + + asyncio.run(clean_old_results()) + + if __name__ == "__main__": cli() diff --git a/argos/server/main.py b/argos/server/main.py index b9e78b2..dd2fa9c 100644 --- a/argos/server/main.py +++ b/argos/server/main.py @@ -16,11 +16,15 @@ def get_application() -> FastAPI: app = FastAPI() config = read_config(app, settings) + + # Settings is the pydantic settings object + # Config is the argos config object (built from yaml) app.state.config = config + app.state.settings = settings app.add_event_handler( "startup", - create_start_app_handler(app, settings), + create_start_app_handler(app), ) app.add_event_handler( "shutdown", @@ -32,17 +36,17 @@ def get_application() -> FastAPI: return app -def create_start_app_handler(app, settings): +def create_start_app_handler(app): async def read_config_and_populate_db(): - setup_database(app, settings) + setup_database(app) - db = await connect_to_db(app, settings) + db = await connect_to_db(app) await queries.update_from_config(db, app.state.config) return read_config_and_populate_db -async def connect_to_db(app, settings): +async def connect_to_db(app): app.state.db = app.state.SessionLocal() return app.state.db @@ -66,7 +70,8 @@ def read_config(app, settings): sys.exit(1) -def setup_database(app, settings): +def setup_database(app): + settings = app.state.settings # For sqlite, we need to add connect_args={"check_same_thread": False} logger.debug(f"Using database URL {settings.database_url}") if settings.database_url.startswith("sqlite:////tmp"): diff --git a/argos/server/queries.py b/argos/server/queries.py index 68e6d17..e309c86 100644 --- a/argos/server/queries.py +++ b/argos/server/queries.py 
@@ -1,7 +1,7 @@ -from datetime import datetime +from datetime import datetime, timedelta from urllib.parse import urljoin -from sqlalchemy import func +from sqlalchemy import Select, desc, func from sqlalchemy.orm import Session from argos import schemas @@ -33,12 +33,13 @@ async def get_task(db: Session, id: int) -> Task: return db.get(Task, id) -async def create_result(db: Session, agent_result: schemas.AgentResult): +async def create_result(db: Session, agent_result: schemas.AgentResult, agent_id: str): result = Result( submitted_at=datetime.now(), status=agent_result.status, context=agent_result.context, task_id=agent_result.task_id, + agent_id=agent_id, ) db.add(result) return result @@ -113,3 +114,34 @@ async def get_severity_counts(db: Session): # Execute the query and fetch the results task_counts_by_severity = query.all() return task_counts_by_severity + + +async def remove_old_results(db: Session, max_results: int): + # Get the id of the oldest result to keep, then delete all results older than that + subquery = ( + db.query(Result.id).order_by(desc(Result.id)).limit(max_results).subquery() + ) + min_id = db.query(func.min(subquery.c.id)).scalar() + if min_id: + deleted = db.query(Result).where(Result.id < min_id).delete() + db.commit() + else: + deleted = 0 + + return deleted + + +async def release_old_locks(db: Session, max_lock_seconds: int): + # Get all the jobs that have been selected_at for more than max_lock_time + max_acceptable_time = datetime.now() - timedelta(seconds=max_lock_seconds) + subquery = ( + db.query(Task.id).filter(Task.selected_at < max_acceptable_time).subquery() + ) + # Release the locks on these jobs + updated = ( + db.query(Task) + .filter(Task.id.in_(Select(subquery))) + .update({Task.selected_at: None, Task.selected_by: None}) + ) + db.commit() + return updated diff --git a/argos/server/settings.py b/argos/server/settings.py index a4a95a1..3bdd7b5 100644 --- a/argos/server/settings.py +++ b/argos/server/settings.py @@ -23,9 
+23,7 @@ class DevSettings(Settings): class TestSettings(Settings): - app_env: str = "test" - database_url: str = "sqlite:////tmp/test-argos.db" - yaml_file: str = "tests/config.yaml" + pass class ProdSettings(Settings): diff --git a/pyproject.toml b/pyproject.toml index a27cfbb..0ab9c05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dev = [ "black==23.3.0", "isort==5.11.5", "pytest>=6.2.5", + "pytest-asyncio>=0.21,<1", "ipython>=8.16,<9", "ipdb>=0.13,<0.14", ] diff --git a/tests/conftest.py b/tests/conftest.py index 42e7425..2fa7061 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,25 +1,40 @@ +from datetime import datetime from os import environ import pytest from fastapi import FastAPI -from httpx import AsyncClient +from sqlalchemy.orm import Session + +from argos.server import models environ["ARGOS_APP_ENV"] = "test" @pytest.fixture -def app() -> FastAPI: - from argos.server.main import get_application # local import for testing purpose - - return get_application() +def db() -> Session: + app = _create_app() + models.Base.metadata.create_all(bind=app.state.engine) + yield app.state.SessionLocal() + models.Base.metadata.drop_all(bind=app.state.engine) @pytest.fixture -def authorized_client( - client: AsyncClient, token: str, authorization_prefix: str -) -> AsyncClient: - client.headers = { - "Authorization": f"Bearer {token}", - **client.headers, - } - return client +def app() -> FastAPI: + app = _create_app() + models.Base.metadata.create_all(bind=app.state.engine) + yield app + models.Base.metadata.drop_all(bind=app.state.engine) + + +def _create_app() -> FastAPI: + from argos.server.main import ( # local import for testing purpose + get_application, + setup_database, + ) + + app = get_application() + app.state.settings.database_url = "sqlite:////tmp/test-argos.db" + app.state.settings.yaml_file = "tests/config.yaml" + + setup_database(app) + return app diff --git a/tests/test_api.py b/tests/test_api.py index d7323cf..09512f8 
100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,24 +1,16 @@ -import pytest from fastapi.testclient import TestClient from argos.schemas import AgentResult -from argos.server import app, models +from argos.server import models -@pytest.fixture() -def test_db(): - models.Base.metadata.create_all(bind=app.state.engine) - yield app.state.db - models.Base.metadata.drop_all(bind=app.state.engine) - - -def test_read_tasks_requires_auth(): +def test_read_tasks_requires_auth(app): with TestClient(app) as client: response = client.get("/api/tasks") assert response.status_code == 403 -def test_tasks_retrieval_and_results(test_db): +def test_tasks_retrieval_and_results(app): with TestClient(app) as client: token = app.state.config.service.secrets[0] client.headers = {"Authorization": f"Bearer {token}"} @@ -38,7 +30,7 @@ def test_tasks_retrieval_and_results(test_db): response = client.post("/api/results", json=data) assert response.status_code == 201 - assert test_db.query(models.Result).count() == 2 + assert app.state.db.query(models.Result).count() == 2 # The list of tasks should be empty now response = client.get("/api/tasks") diff --git a/tests/test_queries.py b/tests/test_queries.py new file mode 100644 index 0000000..b52434c --- /dev/null +++ b/tests/test_queries.py @@ -0,0 +1,109 @@ +from datetime import datetime, timedelta + +import pytest + +from argos.server import queries +from argos.server.models import Result, Task + + +@pytest.mark.asyncio +async def test_remove_old_results(db, ten_results): + assert db.query(Result).count() == 10 + deleted = await queries.remove_old_results(db, 2) + assert deleted == 8 + assert db.query(Result).count() == 2 + # We should keep the last two results + assert db.query(Result).all() == ten_results[-2:] + + +@pytest.mark.asyncio +async def test_remove_old_results_with_empty_db(db): + assert db.query(Result).count() == 0 + deleted = await queries.remove_old_results(db, 2) + assert deleted == 0 + + +@pytest.mark.asyncio 
+async def test_release_old_locks(db, ten_locked_tasks, ten_tasks): + assert db.query(Task).count() == 20 + released = await queries.release_old_locks(db, 10) + assert released == 10 + + +@pytest.mark.asyncio +async def test_release_old_locks_with_empty_db(db): + assert db.query(Task).count() == 0 + released = await queries.release_old_locks(db, 10) + assert released == 0 + + +@pytest.fixture +def task(db): + task = Task( + url="https://www.example.com", + domain="example.com", + check="body-contains", + expected="foo", + frequency=1, + ) + db.add(task) + db.commit() + return task + + +@pytest.fixture +def ten_results(db, task): + results = [] + for i in range(10): + result = Result( + submitted_at=datetime.now(), + status="success", + context={"foo": "bar"}, + task=task, + agent_id="test", + severity="ok", + ) + db.add(result) + results.append(result) + db.commit() + return results + + +@pytest.fixture +def ten_locked_tasks(db): + a_minute_ago = datetime.now() - timedelta(minutes=1) + tasks = [] + for i in range(10): + task = Task( + url="https://www.example.com", + domain="example.com", + check="body-contains", + expected="foo", + frequency=1, + selected_by="test", + selected_at=a_minute_ago, + ) + db.add(task) + tasks.append(task) + db.commit() + return tasks + + +@pytest.fixture +def ten_tasks(db): + now = datetime.now() + tasks = [] + for i in range(10): + task = Task( + url="https://www.example.com", + domain="example.com", + check="body-contains", + expected="foo", + frequency=1, + selected_by="test", + selected_at=now, + ) + db.add(task) + tasks.append(task) + db.commit() + return tasks