From 8ac251939810f7978d438795c936516a5e43bfc5 Mon Sep 17 00:00:00 2001 From: Luc Didry Date: Tue, 26 Nov 2024 15:59:19 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20=E2=80=94=20The=20HTTP=20method=20u?= =?UTF-8?q?sed=20by=20checks=20is=20now=20configurable?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 2 + argos/checks/checks.py | 69 +++++++++---------- argos/config-example.yaml | 5 ++ argos/schemas/config.py | 3 +- argos/schemas/models.py | 3 + argos/schemas/utils.py | 5 ++ .../dcf73fa19fce_specify_check_method.py | 45 ++++++++++++ argos/server/models.py | 16 +++++ argos/server/queries.py | 6 +- tests/test_checks.py | 1 + 10 files changed, 118 insertions(+), 37 deletions(-) create mode 100644 argos/server/migrations/versions/dcf73fa19fce_specify_check_method.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e36e838..02528f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ - 💄 — Show only not-OK domains by default in domains list, to reduce the load on browser - ♿️ — Fix not-OK domains display if javascript is disabled +- ✨ — Retry check right after a httpx.ReadError +- ✨ — The HTTP method used by checks is now configurable ## 0.5.0 diff --git a/argos/checks/checks.py b/argos/checks/checks.py index 3677d36..ad1ff5c 100644 --- a/argos/checks/checks.py +++ b/argos/checks/checks.py @@ -24,16 +24,15 @@ class HTTPStatus(BaseCheck): expected_cls = ExpectedIntValue async def run(self) -> dict: - # XXX Get the method from the task task = self.task try: response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) except ReadError: sleep(1) response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) return self.response( @@ -50,16 +49,15 @@ class HTTPStatusIn(BaseCheck): expected_cls = ExpectedStringValue async def run(self) -> dict: - # XXX Get the method from the task task = self.task try: response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) except ReadError: sleep(1) response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) return self.response( @@ -79,10 +77,14 @@ class HTTPToHTTPS(BaseCheck): task = self.task url = URL(task.url).copy_with(scheme="http") try: - response = await self.http_client.request(method="get", url=url, timeout=60) + response = await self.http_client.request( + method=task.method, url=url, timeout=60 + ) except ReadError: sleep(1) - response = await self.http_client.request(method="get", url=url, timeout=60) + response = await self.http_client.request( + method=task.method, url=url, timeout=60 + ) expected_dict = json.loads(self.expected) expected = range(300, 400) @@ -108,16 +110,15 @@ class HTTPHeadersContain(BaseCheck): expected_cls = ExpectedStringValue async def run(self) -> dict: - # XXX Get the method from the task task = self.task try: response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) except ReadError: sleep(1) response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) status = True @@ -140,16 +141,15 @@ class HTTPHeadersHave(BaseCheck): expected_cls = ExpectedStringValue async def run(self) -> dict: - # XXX Get the method from the task task = self.task try: response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) except ReadError: sleep(1) response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) status = True @@ -176,16 +176,15 @@ class HTTPHeadersLike(BaseCheck): expected_cls = ExpectedStringValue async def run(self) -> dict: - # XXX Get the method from the task task = self.task try: response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) except ReadError: sleep(1) response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) status = True @@ -213,12 +212,12 @@ class HTTPBodyContains(BaseCheck): async def run(self) -> dict: try: response = await self.http_client.request( - method="get", url=self.task.url, timeout=60 + method=self.task.method, url=self.task.url, timeout=60 ) except ReadError: sleep(1) response = await self.http_client.request( - method="get", url=self.task.url, timeout=60 + method=self.task.method, url=self.task.url, timeout=60 ) return self.response(status=self.expected in response.text) @@ -232,12 +231,12 @@ class HTTPBodyLike(BaseCheck): async def run(self) -> dict: try: response = await self.http_client.request( - method="get", url=self.task.url, timeout=60 + method=self.task.method, url=self.task.url, timeout=60 ) except ReadError: sleep(1) response = await self.http_client.request( - method="get", url=self.task.url, timeout=60 + method=self.task.method, url=self.task.url, timeout=60 ) if re.search(rf"{self.expected}", response.text): return self.response(status=True) @@ -253,16 +252,15 @@ class HTTPJsonContains(BaseCheck): expected_cls = ExpectedStringValue async def run(self) -> dict: - # XXX Get the method from the task task = self.task try: response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) except ReadError: sleep(1) response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) obj = response.json() @@ -289,16 +287,15 @@ class HTTPJsonHas(BaseCheck): expected_cls = ExpectedStringValue async def run(self) -> dict: - # XXX Get the method from the task task = self.task try: response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) except ReadError: sleep(1) response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) obj = response.json() @@ -329,16 +326,15 @@ class HTTPJsonLike(BaseCheck): expected_cls = ExpectedStringValue async def run(self) -> dict: - # XXX Get the method from the task task = self.task try: response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) except ReadError: sleep(1) response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) obj = response.json() @@ -368,16 +364,15 @@ class HTTPJsonIs(BaseCheck): expected_cls = ExpectedStringValue async def run(self) -> dict: - # XXX Get the method from the task task = self.task try: response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) except ReadError: sleep(1) response = await self.http_client.request( - method="get", url=task.url, timeout=60 + method=task.method, url=task.url, timeout=60 ) obj = response.json() @@ -400,10 +395,14 @@ class SSLCertificateExpiration(BaseCheck): async def run(self): """Returns the number of days in which the certificate will expire.""" try: - response = await self.http_client.get(self.task.url, timeout=60) + response = await self.http_client.request( + method=self.task.method, url=self.task.url, timeout=60 + ) except ReadError: sleep(1) - response = await self.http_client.get(self.task.url, timeout=60) + response = await self.http_client.request( + method=self.task.method, url=self.task.url, timeout=60 + ) network_stream = response.extensions["network_stream"] ssl_obj = network_stream.get_extra_info("ssl_object") diff --git a/argos/config-example.yaml b/argos/config-example.yaml index 2999e12..f0d081e 100644 --- a/argos/config-example.yaml +++ b/argos/config-example.yaml @@ -93,6 +93,11 @@ websites: - domain: "https://mypads.example.org" paths: - path: "/mypads/" + # Specify the method of the HTTP request + # Valid values are "GET", "HEAD", "POST", "OPTIONS", + # "CONNECT", "TRACE", "PUT", "PATCH" and "DELETE" + # default is "GET" if omitted + method: "GET" checks: # Check that the returned HTTP status is 200 - status-is: 200 diff --git a/argos/schemas/config.py b/argos/schemas/config.py index 7cc29d1..07b1c97 100644 --- a/argos/schemas/config.py +++ b/argos/schemas/config.py @@ -22,7 +22,7 @@ from pydantic.networks import UrlConstraints from pydantic_core import Url from typing_extensions import Annotated -from argos.schemas.utils import string_to_duration +from argos.schemas.utils import string_to_duration, Method Severity = Literal["warning", "error", "critical", "unknown"] Environment = Literal["dev", "test", "production"] @@ -104,6 +104,7 @@ def parse_checks(value): class WebsitePath(BaseModel): path: str + method: Method = "GET" checks: List[ Annotated[ Tuple[str, str], diff --git a/argos/schemas/models.py b/argos/schemas/models.py index a297acf..a4a37c2 100644 --- a/argos/schemas/models.py +++ b/argos/schemas/models.py @@ -8,6 +8,8 @@ from typing import Literal from pydantic import BaseModel, ConfigDict +from argos.schemas.utils import Method + # XXX Refactor using SQLModel to avoid duplication of model data @@ -18,6 +20,7 @@ class Task(BaseModel): url: str domain: str check: str + method: Method expected: str selected_at: datetime | None selected_by: str | None diff --git a/argos/schemas/utils.py b/argos/schemas/utils.py index f225241..fe12087 100644 --- a/argos/schemas/utils.py +++ b/argos/schemas/utils.py @@ -1,6 +1,11 @@ from typing import Literal +Method = Literal[ + "GET", "HEAD", "POST", "OPTIONS", "CONNECT", "TRACE", "PUT", "PATCH", "DELETE" +] + + def string_to_duration( value: str, target: Literal["days", "hours", "minutes"] ) -> int | float: diff --git a/argos/server/migrations/versions/dcf73fa19fce_specify_check_method.py b/argos/server/migrations/versions/dcf73fa19fce_specify_check_method.py new file mode 100644 index 0000000..f218108 --- /dev/null +++ b/argos/server/migrations/versions/dcf73fa19fce_specify_check_method.py @@ -0,0 +1,45 @@ +"""Specify check method + +Revision ID: dcf73fa19fce +Revises: c780864dc407 +Create Date: 2024-11-26 14:40:27.510587 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "dcf73fa19fce" +down_revision: Union[str, None] = "c780864dc407" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "method", + sa.Enum( + "GET", + "HEAD", + "POST", + "OPTIONS", + "CONNECT", + "TRACE", + "PUT", + "PATCH", + "DELETE", + name="method", + ), + nullable=False, + ) + ) + + +def downgrade() -> None: + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.drop_column("method") diff --git a/argos/server/models.py b/argos/server/models.py index 5a03399..d88fa39 100644 --- a/argos/server/models.py +++ b/argos/server/models.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from argos.checks import BaseCheck, get_registered_check from argos.schemas import WebsiteCheck +from argos.schemas.utils import Method class Base(DeclarativeBase): @@ -35,6 +36,21 @@ class Task(Base): check: Mapped[str] = mapped_column() expected: Mapped[str] = mapped_column() frequency: Mapped[int] = mapped_column() + method: Mapped[Method] = mapped_column( + Enum( + "GET", + "HEAD", + "POST", + "OPTIONS", + "CONNECT", + "TRACE", + "PUT", + "PATCH", + "DELETE", + name="method", + ), + insert_default="GET", + ) # Orchestration-related selected_by: Mapped[str] = mapped_column(nullable=True) diff --git a/argos/server/queries.py b/argos/server/queries.py index e887ebe..1369e61 100644 --- a/argos/server/queries.py +++ b/argos/server/queries.py @@ -146,6 +146,7 @@ async def update_from_config(db: Session, config: schemas.Config): db.query(Task) .filter( Task.url == url, + Task.method == p.method, Task.check == check_key, Task.expected == expected, ) @@ -159,8 +160,10 @@ async def update_from_config(db: Session, config: schemas.Config): existing_task.frequency = frequency logger.debug( "Skipping db task creation for url=%s, " - "check_key=%s, expected=%s, frequency=%s.", + "method=%s, check_key=%s, expected=%s, " + "frequency=%s.", url, + p.method, check_key, expected, frequency, @@ -173,6 +176,7 @@ async def update_from_config(db: Session, config: schemas.Config): task = Task( domain=domain, url=url, + method=p.method, check=check_key, expected=expected, frequency=frequency, diff --git a/tests/test_checks.py b/tests/test_checks.py index fc660b5..df58b04 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -35,6 +35,7 @@ def ssl_task(now): id=1, url="https://example.org", domain="https://example.org", + method="GET", check="ssl-certificate-expiration", expected="on-check", selected_at=now,