Compare commits

..

No commits in common. "0563cf185a0122fc3bacc9a6aa244563e88d58d5" and "a31c12e037feb0c96ca6820212b467adbedc707c" have entirely different histories.

16 changed files with 110 additions and 227 deletions

View file

@ -4,11 +4,6 @@
- 💄 — Show only not-OK domains by default in domains list, to reduce the load on browser
- ♿️ — Fix not-OK domains display if javascript is disabled
- ✨ — Retry check right after a httpx.ReadError
- ✨ — The HTTP method used by checks is now configurable
- ♻️ — Refactor some agent code
- 💄 — Filter form on domains list (#66)
- ✨ — Add "Remember me" checkbox on login (#65)
## 0.5.0

View file

@ -6,7 +6,6 @@ import asyncio
import json
import logging
import socket
from time import sleep
from typing import List
import httpx
@ -64,24 +63,9 @@ class ArgosAgent:
async def _complete_task(self, _task: dict) -> AgentResult:
try:
task = Task(**_task)
url = task.url
if task.check == "http-to-https":
url = str(httpx.URL(task.url).copy_with(scheme="http"))
try:
response = await self._http_client.request( # type: ignore[attr-defined]
method=task.method, url=url, timeout=60
)
except httpx.ReadError:
sleep(1)
response = await self._http_client.request( # type: ignore[attr-defined]
method=task.method, url=url, timeout=60
)
check_class = get_registered_check(task.check)
check = check_class(task)
result = await check.run(response)
check = check_class(self._http_client, task)
result = await check.run()
status = result.status
context = result.context

View file

@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Type
import httpx
from pydantic import BaseModel
from argos.schemas.models import Task
@ -91,7 +92,8 @@ class BaseCheck:
raise CheckNotFound(name)
return check
def __init__(self, task: Task):
def __init__(self, http_client: httpx.AsyncClient, task: Task):
self.http_client = http_client
self.task = task
@property

View file

@ -4,7 +4,7 @@ import json
import re
from datetime import datetime
from httpx import Response
from httpx import URL
from jsonpointer import resolve_pointer, JsonPointerException
from argos.checks.base import (
@ -22,7 +22,13 @@ class HTTPStatus(BaseCheck):
config = "status-is"
expected_cls = ExpectedIntValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
# XXX Get the method from the task
task = self.task
response = await self.http_client.request(
method="get", url=task.url, timeout=60
)
return self.response(
status=response.status_code == self.expected,
expected=self.expected,
@ -36,7 +42,13 @@ class HTTPStatusIn(BaseCheck):
config = "status-in"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
# XXX Get the method from the task
task = self.task
response = await self.http_client.request(
method="get", url=task.url, timeout=60
)
return self.response(
status=response.status_code in json.loads(self.expected),
expected=self.expected,
@ -50,7 +62,11 @@ class HTTPToHTTPS(BaseCheck):
config = "http-to-https"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
task = self.task
url = URL(task.url).copy_with(scheme="http")
response = await self.http_client.request(method="get", url=url, timeout=60)
expected_dict = json.loads(self.expected)
expected = range(300, 400)
if "range" in expected_dict:
@ -74,7 +90,13 @@ class HTTPHeadersContain(BaseCheck):
config = "headers-contain"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
# XXX Get the method from the task
task = self.task
response = await self.http_client.request(
method="get", url=task.url, timeout=60
)
status = True
for header in json.loads(self.expected):
if header not in response.headers:
@ -94,7 +116,13 @@ class HTTPHeadersHave(BaseCheck):
config = "headers-have"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
# XXX Get the method from the task
task = self.task
response = await self.http_client.request(
method="get", url=task.url, timeout=60
)
status = True
for header, value in json.loads(self.expected).items():
if header not in response.headers:
@ -118,7 +146,13 @@ class HTTPHeadersLike(BaseCheck):
config = "headers-like"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
# XXX Get the method from the task
task = self.task
response = await self.http_client.request(
method="get", url=task.url, timeout=60
)
status = True
for header, value in json.loads(self.expected).items():
if header not in response.headers:
@ -141,7 +175,10 @@ class HTTPBodyContains(BaseCheck):
config = "body-contains"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
response = await self.http_client.request(
method="get", url=self.task.url, timeout=60
)
return self.response(status=self.expected in response.text)
@ -151,7 +188,10 @@ class HTTPBodyLike(BaseCheck):
config = "body-like"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
response = await self.http_client.request(
method="get", url=self.task.url, timeout=60
)
if re.search(rf"{self.expected}", response.text):
return self.response(status=True)
@ -165,7 +205,13 @@ class HTTPJsonContains(BaseCheck):
config = "json-contains"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
# XXX Get the method from the task
task = self.task
response = await self.http_client.request(
method="get", url=task.url, timeout=60
)
obj = response.json()
status = True
@ -189,7 +235,13 @@ class HTTPJsonHas(BaseCheck):
config = "json-has"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
# XXX Get the method from the task
task = self.task
response = await self.http_client.request(
method="get", url=task.url, timeout=60
)
obj = response.json()
status = True
@ -217,7 +269,13 @@ class HTTPJsonLike(BaseCheck):
config = "json-like"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
# XXX Get the method from the task
task = self.task
response = await self.http_client.request(
method="get", url=task.url, timeout=60
)
obj = response.json()
status = True
@ -244,7 +302,13 @@ class HTTPJsonIs(BaseCheck):
config = "json-is"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self) -> dict:
# XXX Get the method from the task
task = self.task
response = await self.http_client.request(
method="get", url=task.url, timeout=60
)
obj = response.json()
status = response.json() == json.loads(self.expected)
@ -262,8 +326,10 @@ class SSLCertificateExpiration(BaseCheck):
config = "ssl-certificate-expiration"
expected_cls = ExpectedStringValue
async def run(self, response: Response) -> dict:
async def run(self):
"""Returns the number of days in which the certificate will expire."""
response = await self.http_client.get(self.task.url, timeout=60)
network_stream = response.extensions["network_stream"]
ssl_obj = network_stream.get_extra_info("ssl_object")
cert = ssl_obj.getpeercert()

View file

@ -14,19 +14,9 @@ general:
# Can be "production", "dev", "test".
# If not present, default value is "production"
env: "production"
# To get a good string for cookie_secret, run:
# to get a good string for cookie_secret, run:
# openssl rand -hex 32
cookie_secret: "foo_bar_baz"
# Session duration
# Use m for minutes, h for hours, d for days
# w for weeks, mo for months, y for years
# If not present, default value is "7d"
session_duration: "7d"
# Session opened with "Remember me" checked
# If not present, the "Remember me" feature is not available
# remember_me_duration: "1mo"
# Default delay for checks.
# Can be superseeded in domain configuration.
# For ex., to run checks every minute:
@ -103,11 +93,6 @@ websites:
- domain: "https://mypads.example.org"
paths:
- path: "/mypads/"
# Specify the method of the HTTP request
# Valid values are "GET", "HEAD", "POST", "OPTIONS",
# "CONNECT", "TRACE", "PUT", "PATCH" and "DELETE"
# default is "GET" if omitted
method: "GET"
checks:
# Check that the returned HTTP status is 200
- status-is: 200

View file

@ -22,7 +22,7 @@ from pydantic.networks import UrlConstraints
from pydantic_core import Url
from typing_extensions import Annotated
from argos.schemas.utils import string_to_duration, Method
from argos.schemas.utils import string_to_duration
Severity = Literal["warning", "error", "critical", "unknown"]
Environment = Literal["dev", "test", "production"]
@ -104,7 +104,6 @@ def parse_checks(value):
class WebsitePath(BaseModel):
path: str
method: Method = "GET"
checks: List[
Annotated[
Tuple[str, str],
@ -175,31 +174,16 @@ class DbSettings(BaseModel):
class General(BaseModel):
"""Frequency for the checks and alerts"""
cookie_secret: str
frequency: int
db: DbSettings
env: Environment = "production"
cookie_secret: str
session_duration: int = 10080 # 7 days
remember_me_duration: Optional[int] = None
frequency: int
root_path: str = ""
alerts: Alert
mail: Optional[Mail] = None
gotify: Optional[List[GotifyUrl]] = None
apprise: Optional[Dict[str, List[str]]] = None
@field_validator("session_duration", mode="before")
def parse_session_duration(cls, value):
"""Convert the configured session duration to minutes"""
return string_to_duration(value, "minutes")
@field_validator("remember_me_duration", mode="before")
def parse_remember_me_duration(cls, value):
"""Convert the configured session duration with remember me feature to minutes"""
if value:
return string_to_duration(value, "minutes")
return None
@field_validator("frequency", mode="before")
def parse_frequency(cls, value):
"""Convert the configured frequency to minutes"""

View file

@ -8,8 +8,6 @@ from typing import Literal
from pydantic import BaseModel, ConfigDict
from argos.schemas.utils import Method
# XXX Refactor using SQLModel to avoid duplication of model data
@ -20,7 +18,6 @@ class Task(BaseModel):
url: str
domain: str
check: str
method: Method
expected: str
selected_at: datetime | None
selected_by: str | None

View file

@ -1,11 +1,6 @@
from typing import Literal
Method = Literal[
"GET", "HEAD", "POST", "OPTIONS", "CONNECT", "TRACE", "PUT", "PATCH", "DELETE"
]
def string_to_duration(
value: str, target: Literal["days", "hours", "minutes"]
) -> int | float:

View file

@ -1,45 +0,0 @@
"""Specify check method
Revision ID: dcf73fa19fce
Revises: c780864dc407
Create Date: 2024-11-26 14:40:27.510587
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "dcf73fa19fce"
down_revision: Union[str, None] = "c780864dc407"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
with op.batch_alter_table("tasks", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"method",
sa.Enum(
"GET",
"HEAD",
"POST",
"OPTIONS",
"CONNECT",
"TRACE",
"PUT",
"PATCH",
"DELETE",
name="method",
),
nullable=False,
)
)
def downgrade() -> None:
with op.batch_alter_table("tasks", schema=None) as batch_op:
batch_op.drop_column("method")

View file

@ -12,7 +12,6 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from argos.checks import BaseCheck, get_registered_check
from argos.schemas import WebsiteCheck
from argos.schemas.utils import Method
class Base(DeclarativeBase):
@ -36,21 +35,6 @@ class Task(Base):
check: Mapped[str] = mapped_column()
expected: Mapped[str] = mapped_column()
frequency: Mapped[int] = mapped_column()
method: Mapped[Method] = mapped_column(
Enum(
"GET",
"HEAD",
"POST",
"OPTIONS",
"CONNECT",
"TRACE",
"PUT",
"PATCH",
"DELETE",
name="method",
),
insert_default="GET",
)
# Orchestration-related
selected_by: Mapped[str] = mapped_column(nullable=True)

View file

@ -146,7 +146,6 @@ async def update_from_config(db: Session, config: schemas.Config):
db.query(Task)
.filter(
Task.url == url,
Task.method == p.method,
Task.check == check_key,
Task.expected == expected,
)
@ -160,10 +159,8 @@ async def update_from_config(db: Session, config: schemas.Config):
existing_task.frequency = frequency
logger.debug(
"Skipping db task creation for url=%s, "
"method=%s, check_key=%s, expected=%s, "
"frequency=%s.",
"check_key=%s, expected=%s, frequency=%s.",
url,
p.method,
check_key,
expected,
frequency,
@ -176,7 +173,6 @@ async def update_from_config(db: Session, config: schemas.Config):
task = Task(
domain=domain,
url=url,
method=p.method,
check=check_key,
expected=expected,
frequency=frequency,

View file

@ -28,11 +28,7 @@ SEVERITY_LEVELS = {"ok": 1, "warning": 2, "critical": 3, "unknown": 4}
@route.get("/login")
async def login_view(
request: Request,
msg: str | None = None,
config: Config = Depends(get_config),
):
async def login_view(request: Request, msg: str | None = None):
token = request.cookies.get("access-token")
if token is not None and token != "":
manager = request.app.state.manager
@ -48,14 +44,7 @@ async def login_view(
else:
msg = None
return templates.TemplateResponse(
"login.html",
{
"request": request,
"msg": msg,
"remember": config.general.remember_me_duration,
},
)
return templates.TemplateResponse("login.html", {"request": request, "msg": msg})
@route.post("/login")
@ -63,8 +52,6 @@ async def post_login(
request: Request,
db: Session = Depends(get_db),
data: OAuth2PasswordRequestForm = Depends(),
rememberme: Annotated[str | None, Form()] = None,
config: Config = Depends(get_config),
):
username = data.username
user = await queries.get_user(db, username)
@ -83,22 +70,14 @@ async def post_login(
db.commit()
manager = request.app.state.manager
session_duration = config.general.session_duration
if config.general.remember_me_duration is not None and rememberme == "on":
session_duration = config.general.remember_me_duration
delta = timedelta(minutes=session_duration)
token = manager.create_access_token(data={"sub": username}, expires=delta)
token = manager.create_access_token(
data={"sub": username}, expires=timedelta(days=7)
)
response = RedirectResponse(
request.url_for("get_severity_counts_view"),
status_code=status.HTTP_303_SEE_OTHER,
)
response.set_cookie(
key=manager.cookie_name,
value=token,
httponly=True,
samesite="strict",
expires=int(delta.total_seconds()),
)
manager.set_cookie(response, token)
return response

View file

@ -13,15 +13,7 @@
</li>
</ul>
{# djlint:off H021 #}
<ul id="js-only" style="display: none; ">{# djlint:on #}
<li>
<input id="domain-search"
type="search"
spellcheck="false"
placeholder="Filter domains list"
aria-label="Filter domains list"
/>
</li>
<ul id="status-selector" style="display: none;">{# djlint:on #}
<li>
<label for="select-status">Show domains with status:</label>
<select id="select-status">
@ -46,8 +38,7 @@
<tbody id="domains-body">
{% for (domain, status) in domains %}
<tr data-status="{{ status }}"
data-domain="{{ domain }}">
<tr data-status="{{ status }}">
<td>
<a href="{{ url_for('get_domain_tasks_view', domain=domain) }}">
{{ domain }}
@ -71,46 +62,29 @@
</table>
</div>
<script>
function filterDomains(e) {
let status = document.getElementById('select-status');
let filter = document.getElementById('domain-search').value;
console.log(filter)
if (status.value === 'all') {
document.getElementById('select-status').addEventListener('change', (e) => {
if (e.currentTarget.value === 'all') {
document.querySelectorAll('[data-status]').forEach((item) => {
if (filter && item.dataset.domain.indexOf(filter) == -1) {
item.style.display = 'none';
} else {
item.style.display = null;
}
item.style.display = null;
})
} else if (status.value === 'not-ok') {
} else if (e.currentTarget.value === 'not-ok') {
document.querySelectorAll('[data-status]').forEach((item) => {
if (item.dataset.status !== 'ok') {
if (filter && item.dataset.domain.indexOf(filter) == -1) {
item.style.display = 'none';
} else {
item.style.display = null;
}
item.style.display = null;
} else {
item.style.display = 'none';
}
})
} else {
document.querySelectorAll('[data-status]').forEach((item) => {
if (item.dataset.status === status.value) {
if (filter && item.dataset.domain.indexOf(filter) == -1) {
item.style.display = 'none';
} else {
item.style.display = null;
}
if (item.dataset.status === e.currentTarget.value) {
item.style.display = null;
} else {
item.style.display = 'none';
}
})
}
}
document.getElementById('select-status').addEventListener('change', filterDomains);
document.getElementById('domain-search').addEventListener('input', filterDomains);
});
document.querySelectorAll('[data-status]').forEach((item) => {
if (item.dataset.status !== 'ok') {
item.style.display = null;
@ -118,6 +92,6 @@
item.style.display = 'none';
}
})
document.getElementById('js-only').style.display = null;
document.getElementById('status-selector').style.display = null;
</script>
{% endblock content %}

View file

@ -16,14 +16,6 @@
name="password"
type="password"
form="login">
{% if remember is not none %}
<label>
<input type="checkbox"
name="rememberme"
form="login">
Remember me
</label>
{% endif %}
<form id="login"
method="post"
action="{{ url_for('post_login') }}">

View file

@ -10,8 +10,7 @@ First, do your changes in the code, change the model, add new tables, etc. Once
you're done, you can create a new migration.
```bash
venv/bin/alembic -c argos/server/migrations/alembic.ini revision \
--autogenerate -m "migration reason"
venv/bin/alembic -c argos/server/migrations/alembic.ini revision --autogenerate -m "migration reason"
```
Edit the created file to remove comments and adapt it to make sure the migration is complete (Alembic is not powerful enough to cover all the corner cases).

View file

@ -35,7 +35,6 @@ def ssl_task(now):
id=1,
url="https://example.org",
domain="https://example.org",
method="GET",
check="ssl-certificate-expiration",
expected="on-check",
selected_at=now,
@ -52,9 +51,6 @@ async def test_ssl_check_accepts_statuts(
return_value=httpx.Response(http_status, extensions=httpx_extensions_ssl),
)
async with httpx.AsyncClient() as client:
check = SSLCertificateExpiration(ssl_task)
response = await client.request(
method=ssl_task.method, url=ssl_task.url, timeout=60
)
check_response = await check.run(response)
check = SSLCertificateExpiration(client, ssl_task)
check_response = await check.run()
assert check_response.status == "on-check"