Update API endpoints and test cases

- Adjusted test cases in test_api.py
  - Updated test_db fixture to delete the database between each test.
  - Consolidated test cases for tasks retrieval and results
This commit is contained in:
Alexis Métaireau 2023-10-12 00:16:06 +02:00
parent 43f8aabb2c
commit e2d8066746
3 changed files with 32 additions and 13 deletions

View file

@ -5,7 +5,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from argos.logging import logger from argos.logging import logger
from argos.schemas import AgentResult, Task from argos.schemas import AgentResult, Config, Task
from argos.server import queries from argos.server import queries
from argos.server.alerting import handle_alert from argos.server.alerting import handle_alert
@ -21,6 +21,10 @@ def get_db(request: Request):
db.close() db.close()
def get_config(request: Request):
return request.app.state.config
async def verify_token( async def verify_token(
request: Request, token: HTTPAuthorizationCredentials = Depends(auth_scheme) request: Request, token: HTTPAuthorizationCredentials = Depends(auth_scheme)
): ):
@ -38,7 +42,11 @@ async def read_tasks(request: Request, db: Session = Depends(get_db), limit: int
@api.post("/results", status_code=201, dependencies=[Depends(verify_token)]) @api.post("/results", status_code=201, dependencies=[Depends(verify_token)])
async def create_result(results: List[AgentResult], db: Session = Depends(get_db)): async def create_result(
results: List[AgentResult],
db: Session = Depends(get_db),
config: Config = Depends(get_config),
):
"""Get the results from the agents and store them locally. """Get the results from the agents and store them locally.
- Finalize the checks (some checks need the server to do some part of the validation, - Finalize the checks (some checks need the server to do some part of the validation,
@ -57,13 +65,11 @@ async def create_result(results: List[AgentResult], db: Session = Depends(get_db
logger.error(f"Unable to find task {agent_result.task_id}") logger.error(f"Unable to find task {agent_result.task_id}")
else: else:
check = task.get_check() check = task.get_check()
status, severity = await check.finalize( status, severity = await check.finalize(config, result, **result.context)
api.config, result, **result.context
)
result.set_status(status, severity) result.set_status(status, severity)
task.set_times_and_deselect() task.set_times_and_deselect()
handle_alert(api.config, result, task, severity) handle_alert(config, result, task, severity)
db_results.append(result) db_results.append(result)
db.commit() db.commit()

View file

@ -29,7 +29,7 @@ async def list_tasks(db: Session, agent_id: str, limit: int = 100):
async def get_task(db: Session, id): async def get_task(db: Session, id):
return db.query(Task).get(id) return db.get(Task, id)
async def create_result(db: Session, agent_result: schemas.AgentResult): async def create_result(db: Session, agent_result: schemas.AgentResult):

View file

@ -1,14 +1,15 @@
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from argos.schemas import AgentResult
from argos.server import app, models from argos.server import app, models
@pytest.fixture() @pytest.fixture()
def test_db(): def test_db():
models.Base.metadata.create_all(bind=app.engine) models.Base.metadata.create_all(bind=app.state.engine)
yield yield app.state.db
models.Base.metadata.drop_all(bind=app.engine) models.Base.metadata.drop_all(bind=app.state.engine)
def test_read_tasks_requires_auth(): def test_read_tasks_requires_auth():
@ -17,12 +18,24 @@ def test_read_tasks_requires_auth():
assert response.status_code == 403 assert response.status_code == 403
def test_read_tasks_returns_tasks(): def test_tasks_retrieval_and_results(test_db):
with TestClient(app) as client: with TestClient(app) as client:
token = app.state.config.service.secrets[0] token = app.state.config.service.secrets[0]
client.headers = {"Authorization": f"Bearer {token}"} client.headers = {"Authorization": f"Bearer {token}"}
response = client.get("/tasks") response = client.get("/tasks")
assert response.status_code == 200 assert response.status_code == 200
# We should have only two tasks tasks = response.json()
assert len(response.json()) == 2 assert len(tasks) == 2
results = []
for task in tasks:
results.append(
AgentResult(task_id=task["id"], status="success", context={})
)
data = [r.model_dump() for r in results]
response = client.post("/results", json=data)
assert response.status_code == 201
assert test_db.query(models.Result).count() == 2