diff --git a/argos/server/api.py b/argos/server/api.py index 5959ce0..57bc592 100644 --- a/argos/server/api.py +++ b/argos/server/api.py @@ -5,7 +5,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy.orm import Session from argos.logging import logger -from argos.schemas import AgentResult, Task +from argos.schemas import AgentResult, Config, Task from argos.server import queries from argos.server.alerting import handle_alert @@ -21,6 +21,10 @@ def get_db(request: Request): db.close() +def get_config(request: Request): + return request.app.state.config + + async def verify_token( request: Request, token: HTTPAuthorizationCredentials = Depends(auth_scheme) ): @@ -38,7 +42,11 @@ async def read_tasks(request: Request, db: Session = Depends(get_db), limit: int @api.post("/results", status_code=201, dependencies=[Depends(verify_token)]) -async def create_result(results: List[AgentResult], db: Session = Depends(get_db)): +async def create_result( + results: List[AgentResult], + db: Session = Depends(get_db), + config: Config = Depends(get_config), +): """Get the results from the agents and store them locally. - Finalize the checks (some checks need the server to do some part of the validation, @@ -57,13 +65,11 @@ async def create_result(results: List[AgentResult], db: Session = Depends(get_db logger.error(f"Unable to find task {agent_result.task_id}") else: check = task.get_check() - status, severity = await check.finalize( - api.config, result, **result.context - ) + status, severity = await check.finalize(config, result, **result.context) result.set_status(status, severity) task.set_times_and_deselect() - handle_alert(api.config, result, task, severity) + handle_alert(config, result, task, severity) db_results.append(result) db.commit() diff --git a/argos/server/queries.py b/argos/server/queries.py index 39e17c0..eea4bf5 100644 --- a/argos/server/queries.py +++ b/argos/server/queries.py @@ -29,7 +29,7 @@ async def list_tasks(db: Session, agent_id: str, limit: int = 100): async def get_task(db: Session, id): - return db.query(Task).get(id) + return db.get(Task, id) async def create_result(db: Session, agent_result: schemas.AgentResult): diff --git a/tests/test_api.py b/tests/test_api.py index e26c624..d90ca67 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,14 +1,15 @@ import pytest from fastapi.testclient import TestClient +from argos.schemas import AgentResult from argos.server import app, models @pytest.fixture() def test_db(): - models.Base.metadata.create_all(bind=app.engine) - yield - models.Base.metadata.drop_all(bind=app.engine) + models.Base.metadata.create_all(bind=app.state.engine) + yield app.state.db + models.Base.metadata.drop_all(bind=app.state.engine) def test_read_tasks_requires_auth(): @@ -17,12 +18,24 @@ def test_read_tasks_requires_auth(): assert response.status_code == 403 -def test_read_tasks_returns_tasks(): +def test_tasks_retrieval_and_results(test_db): with TestClient(app) as client: token = app.state.config.service.secrets[0] client.headers = {"Authorization": f"Bearer {token}"} response = client.get("/tasks") assert response.status_code == 200 - # We should have only two tasks - assert len(response.json()) == 2 + tasks = response.json() + assert len(tasks) == 2 + + results = [] + for task in tasks: + results.append( + AgentResult(task_id=task["id"], status="success", context={}) + ) + + data = [r.model_dump() for r in results] + response = client.post("/results", json=data) + + assert response.status_code == 201 + assert test_db.query(models.Result).count() == 2