argos/argos/agent.py

106 lines
3.6 KiB
Python

import asyncio
import logging
import socket
from typing import List
import httpx
from tenacity import retry, wait_random
from argos.checks import get_registered_check
from argos.logging import logger
from argos.schemas import AgentResult, SerializableException, Task
def log_failure(retry_state):
    """Log a failed attempt of a call retried by ``tenacity``.

    Installed as the ``after`` callback of :meth:`ArgosAgent.run`. The first
    failed attempt is logged at INFO, every later one at WARNING.

    :param retry_state: the ``tenacity`` retry state of the call being retried.
    """
    # tenacity counts attempts from 1, so the first failure is attempt 1
    # (a ``< 1`` test here could never be true).
    if retry_state.attempt_number <= 1:
        loglevel = logging.INFO
    else:
        loglevel = logging.WARNING
    outcome = retry_state.outcome
    logger.log(
        loglevel,
        "Retrying: attempt %s ended with: %s %s",
        retry_state.attempt_number,
        outcome,
        # Defensive: only read the exception when an outcome was recorded.
        outcome.exception() if outcome is not None else None,
    )
class ArgosAgent:
    """The Argos agent is responsible for running the checks and reporting the results.

    The agent repeatedly fetches a batch of tasks from ``{server}/api/tasks``,
    runs every task through its registered check concurrently, and posts the
    results to ``{server}/api/results``. Requests carry a bearer token and the
    local hostname as ``agent_id``.
    """

    def __init__(self, server: str, auth: str, max_tasks: int, wait_time: int):
        """
        :param server: base URL of the Argos server; ``/api/...`` paths are appended to it.
        :param auth: bearer token sent in the ``Authorization`` header.
        :param max_tasks: maximum number of tasks requested per poll (``limit`` param).
        :param wait_time: seconds to sleep when a poll yields nothing to process.
        """
        self.server = server
        self.max_tasks = max_tasks
        self.wait_time = wait_time
        self.auth = auth
        self.agent_id = socket.gethostname()
        # The HTTP client is created in run(); declared here so the attribute
        # always exists on the instance.
        self._http_client = None

    @retry(after=log_failure, wait=wait_random(min=1, max=2))
    async def run(self):
        """Poll the server for tasks forever.

        Any exception escaping the loop closes the current HTTP client (via
        ``async with``); ``tenacity`` then restarts ``run`` after a random
        1-2 second pause, which creates a fresh client.
        """
        logger.info("Running agent against %s", self.server)
        headers = {
            "Authorization": f"Bearer {self.auth}",
        }
        self._http_client = httpx.AsyncClient(headers=headers)
        async with self._http_client:
            while True:
                retry_now = await self._get_and_complete_tasks()
                if not retry_now:
                    # Nothing was processed: back off before polling again.
                    # This is normal idle behaviour, not an error.
                    logger.info("Waiting %s seconds before next retry", self.wait_time)
                    await asyncio.sleep(self.wait_time)

    async def _complete_task(self, task: dict) -> AgentResult:
        """Run a single task and wrap its outcome in an :class:`AgentResult`.

        A failing check, or a payload that cannot be parsed into a ``Task``,
        does not raise. It is reported as an ``"error"`` result whose context
        is the serialized exception, so one bad task does not abort the batch
        gathered in :meth:`_get_and_complete_tasks`.

        :param task: raw task payload as received from the server.
        :returns: the result to post back for this task.
        """
        # Capture the raw id before parsing. If ``Task(**task)`` fails,
        # ``task`` is still the dict, and reading ``task.id`` from it would
        # raise AttributeError and abort the whole batch.
        task_id = task.get("id") if isinstance(task, dict) else None
        try:
            task = Task(**task)
            task_id = task.id
            check_class = get_registered_check(task.check)
            check = check_class(self._http_client, task)
            result = await check.run()
            status = result.status
            context = result.context
        except Exception as e:  # boundary: every failure is reported to the server
            status = "error"
            context = SerializableException.from_exception(e)
            msg = f"An exception occured when running {task}. {e.__class__.__name__} : {e}"
            logger.error(msg)
        return AgentResult(task_id=task_id, status=status, context=context)

    async def _get_and_complete_tasks(self):
        """Fetch one batch of tasks, run them concurrently and post the results.

        :returns: ``True`` when a non-empty batch was processed (the caller then
            polls again immediately); ``False`` when the fetch failed or returned
            no tasks (the caller then waits ``wait_time`` seconds).
        """
        response = await self._http_client.get(
            f"{self.server}/api/tasks",
            params={"limit": self.max_tasks, "agent_id": self.agent_id},
        )
        if response.status_code != httpx.codes.OK:
            logger.error("Failed to fetch tasks: %s", response.read())
            return False

        # XXX Maybe we want to group the tests by URL ? (to issue one request per URL)
        data = response.json()
        logger.info("Received %s tasks from the server", len(data))
        if not data:
            # An empty queue is expected when idle; not an error condition.
            logger.info("Got no tasks from the server.")
            return False

        results = await asyncio.gather(*(self._complete_task(raw) for raw in data))
        await self._post_results(results)
        return True

    async def _post_results(self, results: List[AgentResult]):
        """POST a batch of results to ``{server}/api/results``.

        Failures are logged, not raised.

        :param results: results produced by :meth:`_complete_task`.
        :returns: the server's ``httpx`` response.
        """
        data = [r.model_dump() for r in results]
        response = await self._http_client.post(
            f"{self.server}/api/results", params={"agent_id": self.agent_id}, json=data
        )
        if response.status_code == httpx.codes.CREATED:
            logger.info("Successfully posted results %s", response.json())
        else:
            logger.error("Failed to post results: %s", response.read())
        return response