mirror of
https://framagit.org/framasoft/framaspace/argos.git
synced 2025-04-28 18:02:41 +02:00
Update API endpoints and test cases
- Adjusted test cases in test_api.py - Updated test_db fixture to delete the database between each test. - Consolidated test cases for tasks retrieval and results
This commit is contained in:
parent
43f8aabb2c
commit
e2d8066746
3 changed files with 32 additions and 13 deletions
|
@ -5,7 +5,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from argos.logging import logger
|
||||
from argos.schemas import AgentResult, Task
|
||||
from argos.schemas import AgentResult, Config, Task
|
||||
from argos.server import queries
|
||||
from argos.server.alerting import handle_alert
|
||||
|
||||
|
@ -21,6 +21,10 @@ def get_db(request: Request):
|
|||
db.close()
|
||||
|
||||
|
||||
def get_config(request: Request):
|
||||
return request.app.state.config
|
||||
|
||||
|
||||
async def verify_token(
|
||||
request: Request, token: HTTPAuthorizationCredentials = Depends(auth_scheme)
|
||||
):
|
||||
|
@ -38,7 +42,11 @@ async def read_tasks(request: Request, db: Session = Depends(get_db), limit: int
|
|||
|
||||
|
||||
@api.post("/results", status_code=201, dependencies=[Depends(verify_token)])
|
||||
async def create_result(results: List[AgentResult], db: Session = Depends(get_db)):
|
||||
async def create_result(
|
||||
results: List[AgentResult],
|
||||
db: Session = Depends(get_db),
|
||||
config: Config = Depends(get_config),
|
||||
):
|
||||
"""Get the results from the agents and store them locally.
|
||||
|
||||
- Finalize the checks (some checks need the server to do some part of the validation,
|
||||
|
@ -57,13 +65,11 @@ async def create_result(results: List[AgentResult], db: Session = Depends(get_db
|
|||
logger.error(f"Unable to find task {agent_result.task_id}")
|
||||
else:
|
||||
check = task.get_check()
|
||||
status, severity = await check.finalize(
|
||||
api.config, result, **result.context
|
||||
)
|
||||
status, severity = await check.finalize(config, result, **result.context)
|
||||
result.set_status(status, severity)
|
||||
task.set_times_and_deselect()
|
||||
|
||||
handle_alert(api.config, result, task, severity)
|
||||
handle_alert(config, result, task, severity)
|
||||
|
||||
db_results.append(result)
|
||||
db.commit()
|
||||
|
|
|
@ -29,7 +29,7 @@ async def list_tasks(db: Session, agent_id: str, limit: int = 100):
|
|||
|
||||
|
||||
async def get_task(db: Session, id):
|
||||
return db.query(Task).get(id)
|
||||
return db.get(Task, id)
|
||||
|
||||
|
||||
async def create_result(db: Session, agent_result: schemas.AgentResult):
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from argos.schemas import AgentResult
|
||||
from argos.server import app, models
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_db():
|
||||
models.Base.metadata.create_all(bind=app.engine)
|
||||
yield
|
||||
models.Base.metadata.drop_all(bind=app.engine)
|
||||
models.Base.metadata.create_all(bind=app.state.engine)
|
||||
yield app.state.db
|
||||
models.Base.metadata.drop_all(bind=app.state.engine)
|
||||
|
||||
|
||||
def test_read_tasks_requires_auth():
|
||||
|
@ -17,12 +18,24 @@ def test_read_tasks_requires_auth():
|
|||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_read_tasks_returns_tasks():
|
||||
def test_tasks_retrieval_and_results(test_db):
|
||||
with TestClient(app) as client:
|
||||
token = app.state.config.service.secrets[0]
|
||||
client.headers = {"Authorization": f"Bearer {token}"}
|
||||
response = client.get("/tasks")
|
||||
assert response.status_code == 200
|
||||
|
||||
# We should have only two tasks
|
||||
assert len(response.json()) == 2
|
||||
tasks = response.json()
|
||||
assert len(tasks) == 2
|
||||
|
||||
results = []
|
||||
for task in tasks:
|
||||
results.append(
|
||||
AgentResult(task_id=task["id"], status="success", context={})
|
||||
)
|
||||
|
||||
data = [r.model_dump() for r in results]
|
||||
response = client.post("/results", json=data)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert test_db.query(models.Result).count() == 2
|
||||
|
|
Loading…
Reference in a new issue