⚡ — Mutualize check requests (fix #68)

Luc Didry 2024-12-04 15:04:06 +01:00
parent ea23ea7c1f
commit e0edb50e12
No known key found for this signature in database
GPG key ID: EA868E12D0257E3C
7 changed files with 111 additions and 34 deletions

View file

@@ -3,6 +3,7 @@
## [Unreleased]
- ✨ — IPv4/IPv6 choice for checks, and choice for a dual-stack check (#69)
- ⚡ — Mutualize check requests (#68)
## 0.6.1

View file

@@ -41,9 +41,10 @@ class ArgosAgent: # pylint: disable-msg=too-many-instance-attributes
self.max_tasks = max_tasks
self.wait_time = wait_time
self.auth = auth
self._http_client = None
self._http_client_v4 = None
self._http_client_v6 = None
self._http_client: httpx.AsyncClient | None = None
self._http_client_v4: httpx.AsyncClient | None = None
self._http_client_v6: httpx.AsyncClient | None = None
self._res_cache: dict[str, httpx.Response] = {}
self.agent_id = socket.gethostname()
@@ -51,6 +52,7 @@ class ArgosAgent: # pylint: disable-msg=too-many-instance-attributes
async def run(self):
auth_header = {
"Authorization": f"Bearer {self.auth}",
"User-Agent": f"Argos Panoptes agent {VERSION}",
}
self._http_client = httpx.AsyncClient(headers=auth_header)
@@ -74,37 +76,36 @@ class ArgosAgent: # pylint: disable-msg=too-many-instance-attributes
logger.info("Waiting %i seconds before next retry", self.wait_time)
await asyncio.sleep(self.wait_time)
async def _do_request(self, group: str, details: dict):
try:
if details["ip_version"] == "4":
response = await self._http_client_v4.request( # type: ignore[union-attr]
method=details["method"], url=details["url"], timeout=60
)
else:
response = await self._http_client_v6.request( # type: ignore[union-attr]
method=details["method"], url=details["url"], timeout=60
)
except httpx.ReadError:
sleep(1)
if details["ip_version"] == "4":
response = await self._http_client_v4.request( # type: ignore[union-attr]
method=details["method"], url=details["url"], timeout=60
)
else:
response = await self._http_client_v6.request( # type: ignore[union-attr]
method=details["method"], url=details["url"], timeout=60
)
self._res_cache[group] = response
async def _complete_task(self, _task: dict) -> AgentResult:
try:
task = Task(**_task)
url = task.url
if task.check == "http-to-https":
url = str(httpx.URL(task.url).copy_with(scheme="http"))
try:
if task.ip_version == "4":
response = await self._http_client_v4.request( # type: ignore[attr-defined]
method=task.method, url=url, timeout=60
)
else:
response = await self._http_client_v6.request( # type: ignore[attr-defined]
method=task.method, url=url, timeout=60
)
except httpx.ReadError:
sleep(1)
if task.ip_version == "4":
response = await self._http_client_v4.request( # type: ignore[attr-defined]
method=task.method, url=url, timeout=60
)
else:
response = await self._http_client_v6.request( # type: ignore[attr-defined]
method=task.method, url=url, timeout=60
)
check_class = get_registered_check(task.check)
check = check_class(task)
result = await check.run(response)
result = await check.run(self._res_cache[task.task_group])
status = result.status
context = result.context
@@ -123,10 +124,34 @@ class ArgosAgent: # pylint: disable-msg=too-many-instance-attributes
)
if response.status_code == httpx.codes.OK:
# XXX Maybe we want to group the tests by URL ? (to issue one request per URL)
data = response.json()
logger.info("Received %i tasks from the server", len(data))
req_groups = {}
for _task in data:
task = Task(**_task)
url = task.url
group = task.task_group
if task.check == "http-to-https":
url = str(httpx.URL(task.url).copy_with(scheme="http"))
group = f"{task.method}-{task.ip_version}-{url}"
_task["task_group"] = group
req_groups[group] = {
"url": url,
"ip_version": task.ip_version,
"method": task.method,
}
requests = []
for group, details in req_groups.items():
requests.append(self._do_request(group, details))
if requests:
await asyncio.gather(*requests)
tasks = []
for task in data:
tasks.append(self._complete_task(task))
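This hunk is the heart of the mutualization: tasks are grouped by method, IP version and URL (with the URL rewritten to plain http first for the http-to-https check), `_do_request` issues one request per group, and `_complete_task` reuses the cached response from `self._res_cache` instead of fetching the URL again. A minimal, self-contained sketch of that pattern (illustrative only; `fetch_groups` is a hypothetical helper, not part of the agent):

import asyncio
import httpx

async def fetch_groups(tasks: list[dict]) -> dict[str, httpx.Response]:
    # Build one entry per (method, ip_version, url) combination.
    groups: dict[str, dict] = {}
    for task in tasks:
        key = f"{task['method']}-{task['ip_version']}-{task['url']}"
        task["task_group"] = key
        groups[key] = {"method": task["method"], "url": task["url"]}

    # Issue a single request per group and cache the responses.
    cache: dict[str, httpx.Response] = {}
    async with httpx.AsyncClient() as client:
        async def do_request(key: str, details: dict) -> None:
            cache[key] = await client.request(
                method=details["method"], url=details["url"], timeout=60
            )
        await asyncio.gather(*(do_request(k, d) for k, d in groups.items()))
    return cache

Each check then reads cache[task["task_group"]], so ten checks against the same URL trigger a single request per IP version instead of ten.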

View file

@@ -23,6 +23,7 @@ class Task(BaseModel):
check: str
method: Method
expected: str
task_group: str
selected_at: datetime | None
selected_by: str | None
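For reference, the new `task_group` field carries the same `method-ip_version-url` key that both the agent and `compute_task_group()` build, for example:

method, ip_version, url = "GET", "6", "https://example.org"
task_group = f"{method}-{ip_version}-{url}"  # -> "GET-6-https://example.org"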

View file

@@ -0,0 +1,35 @@
"""Add task index
Revision ID: 8b58ced14d6e
Revises: 64f73a79b7d8
Create Date: 2024-12-03 16:41:44.842213
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "8b58ced14d6e"
down_revision: Union[str, None] = "64f73a79b7d8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
with op.batch_alter_table("tasks", schema=None) as batch_op:
batch_op.add_column(sa.Column("task_group", sa.String(), nullable=True))
with op.batch_alter_table("tasks", schema=None) as batch_op:
batch_op.execute(
"UPDATE tasks SET task_group = method || '-' || ip_version || '-' || url"
)
batch_op.alter_column("task_group", nullable=False)
batch_op.create_index("similar_tasks", ["task_group"], unique=False)
def downgrade() -> None:
with op.batch_alter_table("tasks", schema=None) as batch_op:
batch_op.drop_index("similar_tasks")
batch_op.drop_column("task_group")

View file

@@ -9,12 +9,21 @@ from sqlalchemy import (
ForeignKey,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.schema import Index
from argos.checks import BaseCheck, get_registered_check
from argos.schemas import WebsiteCheck
from argos.schemas.utils import IPVersion, Method
def compute_task_group(context) -> str:
return (
f"{context.current_parameters['method']}-"
f"{context.current_parameters['ip_version']}-"
f"{context.current_parameters['url']}"
)
class Base(DeclarativeBase):
type_annotation_map = {List[WebsiteCheck]: JSON, dict: JSON}
@@ -62,6 +71,7 @@ class Task(Base):
selected_at: Mapped[datetime] = mapped_column(nullable=True)
completed_at: Mapped[datetime] = mapped_column(nullable=True)
next_run: Mapped[datetime] = mapped_column(nullable=True)
task_group: Mapped[str] = mapped_column(insert_default=compute_task_group)
severity: Mapped[Literal["ok", "warning", "critical", "unknown"]] = mapped_column(
Enum("ok", "warning", "critical", "unknown", name="severity"),
@@ -75,7 +85,7 @@ class Task(Base):
passive_deletes=True,
)
def __str__(self):
def __str__(self) -> str:
return f"DB Task {self.url} (IPv{self.ip_version}) - {self.check} - {self.expected}"
def get_check(self) -> BaseCheck:
@@ -117,6 +127,9 @@ class Task(Base):
return self.last_result.status
Index("similar_tasks", Task.task_group)
class Result(Base):
"""There are multiple results per task.

View file

@@ -4,7 +4,7 @@ from hashlib import sha256
from typing import List
from urllib.parse import urljoin
from sqlalchemy import asc, desc, func
from sqlalchemy import asc, desc, func, Select
from sqlalchemy.orm import Session
from argos import schemas
@@ -14,15 +14,16 @@ from argos.server.models import Result, Task, ConfigCache, User
async def list_tasks(db: Session, agent_id: str, limit: int = 100):
"""List tasks and mark them as selected"""
tasks = (
db.query(Task)
subquery = (
db.query(func.distinct(Task.task_group))
.filter(
Task.selected_by == None, # noqa: E711
((Task.next_run <= datetime.now()) | (Task.next_run == None)), # noqa: E711
)
.limit(limit)
.all()
.subquery()
)
tasks = db.query(Task).filter(Task.task_group.in_(Select(subquery))).all()
now = datetime.now()
for task in tasks:

View file

@@ -37,6 +37,7 @@ def ssl_task(now):
domain="https://example.org",
ip_version="6",
method="GET",
task_group="GET-6-https://example.org",
check="ssl-certificate-expiration",
expected="on-check",
selected_at=now,