From e0edb50e12346e843e8012955548cc264c5db643 Mon Sep 17 00:00:00 2001 From: Luc Didry Date: Wed, 4 Dec 2024 15:04:06 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=20=E2=80=94=20Mutualize=20check=20req?= =?UTF-8?q?uests=20(fix=20#68)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 1 + argos/agent.py | 83 ++++++++++++------- argos/schemas/models.py | 1 + .../versions/8b58ced14d6e_add_task_index.py | 35 ++++++++ argos/server/models.py | 15 +++- argos/server/queries.py | 9 +- tests/test_checks.py | 1 + 7 files changed, 111 insertions(+), 34 deletions(-) create mode 100644 argos/server/migrations/versions/8b58ced14d6e_add_task_index.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 86cf372..4ada528 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## [Unreleased] - ✨ — IPv4/IPv6 choice for checks, and choice for a dual-stack check (#69) +- ⚡ — Mutualize check requests (#68) ## 0.6.1 diff --git a/argos/agent.py b/argos/agent.py index 4e17597..ac26c93 100644 --- a/argos/agent.py +++ b/argos/agent.py @@ -41,9 +41,10 @@ class ArgosAgent: # pylint: disable-msg=too-many-instance-attributes self.max_tasks = max_tasks self.wait_time = wait_time self.auth = auth - self._http_client = None - self._http_client_v4 = None - self._http_client_v6 = None + self._http_client: httpx.AsyncClient | None = None + self._http_client_v4: httpx.AsyncClient | None = None + self._http_client_v6: httpx.AsyncClient | None = None + self._res_cache: dict[str, httpx.Response] = {} self.agent_id = socket.gethostname() @@ -51,6 +52,7 @@ class ArgosAgent: # pylint: disable-msg=too-many-instance-attributes async def run(self): auth_header = { "Authorization": f"Bearer {self.auth}", + "User-Agent": f"Argos Panoptes agent {VERSION}", } self._http_client = httpx.AsyncClient(headers=auth_header) @@ -74,37 +76,36 @@ class ArgosAgent: # pylint: disable-msg=too-many-instance-attributes logger.info("Waiting %i seconds before next retry", self.wait_time) await asyncio.sleep(self.wait_time) + async def _do_request(self, group: str, details: dict): + try: + if details["ip_version"] == "4": + response = await self._http_client_v4.request( # type: ignore[union-attr] + method=details["method"], url=details["url"], timeout=60 + ) + else: + response = await self._http_client_v6.request( # type: ignore[union-attr] + method=details["method"], url=details["url"], timeout=60 + ) + except httpx.ReadError: + sleep(1) + if details["ip_version"] == "4": + response = await self._http_client_v4.request( # type: ignore[union-attr] + method=details["method"], url=details["url"], timeout=60 + ) + else: + response = await self._http_client_v6.request( # type: ignore[union-attr] + method=details["method"], url=details["url"], timeout=60 + ) + + self._res_cache[group] = response + async def _complete_task(self, _task: dict) -> AgentResult: try: task = Task(**_task) - url = task.url - if task.check == "http-to-https": - url = str(httpx.URL(task.url).copy_with(scheme="http")) - - try: - if task.ip_version == "4": - response = await self._http_client_v4.request( # type: ignore[attr-defined] - method=task.method, url=url, timeout=60 - ) - else: - response = await self._http_client_v6.request( # type: ignore[attr-defined] - method=task.method, url=url, timeout=60 - ) - except httpx.ReadError: - sleep(1) - if task.ip_version == "4": - response = await self._http_client_v4.request( # type: ignore[attr-defined] - method=task.method, url=url, timeout=60 - ) - else: - response = await self._http_client_v6.request( # type: ignore[attr-defined] - method=task.method, url=url, timeout=60 - ) - check_class = get_registered_check(task.check) check = check_class(task) - result = await check.run(response) + result = await check.run(self._res_cache[task.task_group]) status = result.status context = result.context @@ -123,10 +124,34 @@ class ArgosAgent: # pylint: disable-msg=too-many-instance-attributes ) if response.status_code == httpx.codes.OK: - # XXX Maybe we want to group the tests by URL ? (to issue one request per URL) data = response.json() logger.info("Received %i tasks from the server", len(data)) + req_groups = {} + for _task in data: + task = Task(**_task) + + url = task.url + group = task.task_group + + if task.check == "http-to-https": + url = str(httpx.URL(task.url).copy_with(scheme="http")) + group = f"{task.method}-{task.ip_version}-{url}" + _task["task_group"] = group + + req_groups[group] = { + "url": url, + "ip_version": task.ip_version, + "method": task.method, + } + + requests = [] + for group, details in req_groups.items(): + requests.append(self._do_request(group, details)) + + if requests: + await asyncio.gather(*requests) + tasks = [] for task in data: tasks.append(self._complete_task(task)) diff --git a/argos/schemas/models.py b/argos/schemas/models.py index 8f9daeb..36c5fe8 100644 --- a/argos/schemas/models.py +++ b/argos/schemas/models.py @@ -23,6 +23,7 @@ class Task(BaseModel): check: str method: Method expected: str + task_group: str selected_at: datetime | None selected_by: str | None diff --git a/argos/server/migrations/versions/8b58ced14d6e_add_task_index.py b/argos/server/migrations/versions/8b58ced14d6e_add_task_index.py new file mode 100644 index 0000000..8bbf313 --- /dev/null +++ b/argos/server/migrations/versions/8b58ced14d6e_add_task_index.py @@ -0,0 +1,35 @@ +"""Add task index + +Revision ID: 8b58ced14d6e +Revises: 64f73a79b7d8 +Create Date: 2024-12-03 16:41:44.842213 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "8b58ced14d6e" +down_revision: Union[str, None] = "64f73a79b7d8" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.add_column(sa.Column("task_group", sa.String(), nullable=True)) + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.execute( + "UPDATE tasks SET task_group = method || '-' || ip_version || '-' || url" + ) + batch_op.alter_column("task_group", nullable=False) + batch_op.create_index("similar_tasks", ["task_group"], unique=False) + + +def downgrade() -> None: + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.drop_index("similar_tasks") + batch_op.drop_column("task_group") diff --git a/argos/server/models.py b/argos/server/models.py index e4503e4..45b811e 100644 --- a/argos/server/models.py +++ b/argos/server/models.py @@ -9,12 +9,21 @@ from sqlalchemy import ( ForeignKey, ) from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy.schema import Index from argos.checks import BaseCheck, get_registered_check from argos.schemas import WebsiteCheck from argos.schemas.utils import IPVersion, Method +def compute_task_group(context) -> str: + return ( + f"{context.current_parameters['method']}-" + f"{context.current_parameters['ip_version']}-" + f"{context.current_parameters['url']}" + ) + + class Base(DeclarativeBase): type_annotation_map = {List[WebsiteCheck]: JSON, dict: JSON} @@ -62,6 +71,7 @@ class Task(Base): selected_at: Mapped[datetime] = mapped_column(nullable=True) completed_at: Mapped[datetime] = mapped_column(nullable=True) next_run: Mapped[datetime] = mapped_column(nullable=True) + task_group: Mapped[str] = mapped_column(insert_default=compute_task_group) severity: Mapped[Literal["ok", "warning", "critical", "unknown"]] = mapped_column( Enum("ok", "warning", "critical", "unknown", name="severity"), @@ -75,7 +85,7 @@ class Task(Base): passive_deletes=True, ) - def __str__(self): + def __str__(self) -> str: return f"DB Task {self.url} (IPv{self.ip_version}) - {self.check} - {self.expected}" def get_check(self) -> BaseCheck: @@ -117,6 +127,9 @@ class Task(Base): return self.last_result.status +Index("similar_tasks", Task.task_group) + + class Result(Base): """There are multiple results per task. diff --git a/argos/server/queries.py b/argos/server/queries.py index 60deaa8..6489dfe 100644 --- a/argos/server/queries.py +++ b/argos/server/queries.py @@ -4,7 +4,7 @@ from hashlib import sha256 from typing import List from urllib.parse import urljoin -from sqlalchemy import asc, desc, func +from sqlalchemy import asc, desc, func, Select from sqlalchemy.orm import Session from argos import schemas @@ -14,15 +14,16 @@ from argos.server.models import Result, Task, ConfigCache, User async def list_tasks(db: Session, agent_id: str, limit: int = 100): """List tasks and mark them as selected""" - tasks = ( - db.query(Task) + subquery = ( + db.query(func.distinct(Task.task_group)) .filter( Task.selected_by == None, # noqa: E711 ((Task.next_run <= datetime.now()) | (Task.next_run == None)), # noqa: E711 ) .limit(limit) - .all() + .subquery() ) + tasks = db.query(Task).filter(Task.task_group.in_(Select(subquery))).all() now = datetime.now() for task in tasks: diff --git a/tests/test_checks.py b/tests/test_checks.py index 7102cf3..460d5bf 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -37,6 +37,7 @@ def ssl_task(now): domain="https://example.org", ip_version="6", method="GET", + task_group="GET-6-https://example.org", check="ssl-certificate-expiration", expected="on-check", selected_at=now,