diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 30c8c7a..6f3ef20 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,3 +1,4 @@ +--- image: python:3.11 stages: @@ -18,6 +19,9 @@ default: install: stage: install + before_script: + - apt-get update + - apt-get install -y build-essential libldap-dev libsasl2-dev script: - make venv - make develop @@ -64,7 +68,7 @@ release_job: - if: $CI_COMMIT_TAG script: - sed -n '/^## '$CI_COMMIT_TAG'/,/^#/p' CHANGELOG.md | sed -e '/^\(#\|$\|Date\)/d' > release.md - release: # See https://docs.gitlab.com/ee/ci/yaml/#release for available properties + release: # See https://docs.gitlab.com/ee/ci/yaml/#release for available properties tag_name: '$CI_COMMIT_TAG' description: './release.md' assets: diff --git a/CHANGELOG.md b/CHANGELOG.md index a73ebca..42f7dc2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,20 @@ ## [Unreleased] +- 💄 — Show only not-OK domains by default in domains list, to reduce the load on browser +- ♿️ — Fix not-OK domains display if javascript is disabled +- ✨ — Retry check right after a httpx.ReadError +- ✨ — The HTTP method used by checks is now configurable +- ♻️ — Refactor some agent code +- 💄 — Filter form on domains list (#66) +- ✨ — Add "Remember me" checkbox on login (#65) +- ✨ — Add a setting to set a reschedule delay if check failed (#67) + BREAKING CHANGE: `mo` is no longer accepted for declaring a duration in month in the configuration + You need to use `M`, `month` or `months` +- ✨ - Allow to choose a frequency smaller than a minute +- ✨🛂 — Allow partial or total anonymous access to web interface (#63) +- ✨🛂 — Allow to use a LDAP server for authentication (#64) + ## 0.5.0 Date: 2024-09-26 @@ -68,7 +82,7 @@ Date: 2024-06-24 - 💄📯 — Improve notifications and result(s) pages - 🔊 — Add level of log before the log message -— 🔊 — Add a warning messages in the logs if there is no tasks in database. (fix #41) +- 🔊 — Add a warning message in the logs if there is no tasks in database. (fix #41) - ✨ — Add command to generate example configuration (fix #38) - 📝 — Improve documentation - ✨ — Add command to warn if it’s been long since last viewing an agent (fix #49) diff --git a/Makefile b/Makefile index 9d6bec1..86e1737 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ NC=\033[0m # No Color venv: ## Create the venv python3 -m venv venv develop: venv ## Install the dev dependencies - venv/bin/pip install -e ".[dev,docs]" + venv/bin/pip install -e ".[dev,docs,ldap]" docs: cog ## Build the docs venv/bin/sphinx-build docs public if [ ! -e "public/mermaid.min.js" ]; then curl -sL $$(grep mermaid.min.js public/search.html | cut -f 2 -d '"') --output public/mermaid.min.js; fi diff --git a/argos/agent.py b/argos/agent.py index 580553d..d791641 100644 --- a/argos/agent.py +++ b/argos/agent.py @@ -6,6 +6,7 @@ import asyncio import json import logging import socket +from time import sleep from typing import List import httpx @@ -63,9 +64,24 @@ class ArgosAgent: async def _complete_task(self, _task: dict) -> AgentResult: try: task = Task(**_task) + + url = task.url + if task.check == "http-to-https": + url = str(httpx.URL(task.url).copy_with(scheme="http")) + + try: + response = await self._http_client.request( # type: ignore[attr-defined] + method=task.method, url=url, timeout=60 + ) + except httpx.ReadError: + sleep(1) + response = await self._http_client.request( # type: ignore[attr-defined] + method=task.method, url=url, timeout=60 + ) + check_class = get_registered_check(task.check) - check = check_class(self._http_client, task) - result = await check.run() + check = check_class(task) + result = await check.run(response) status = result.status context = result.context diff --git a/argos/checks/base.py b/argos/checks/base.py index 3221809..8206714 100644 --- a/argos/checks/base.py +++ b/argos/checks/base.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from typing import Type -import httpx from pydantic import BaseModel from argos.schemas.models import Task @@ -92,8 +91,7 @@ class BaseCheck: raise CheckNotFound(name) return check - def __init__(self, http_client: httpx.AsyncClient, task: Task): - self.http_client = http_client + def __init__(self, task: Task): self.task = task @property diff --git a/argos/checks/checks.py b/argos/checks/checks.py index e729f72..b30977a 100644 --- a/argos/checks/checks.py +++ b/argos/checks/checks.py @@ -4,7 +4,7 @@ import json import re from datetime import datetime -from httpx import URL +from httpx import Response from jsonpointer import resolve_pointer, JsonPointerException from argos.checks.base import ( @@ -22,13 +22,7 @@ class HTTPStatus(BaseCheck): config = "status-is" expected_cls = ExpectedIntValue - async def run(self) -> dict: - # XXX Get the method from the task - task = self.task - response = await self.http_client.request( - method="get", url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: return self.response( status=response.status_code == self.expected, expected=self.expected, @@ -42,13 +36,7 @@ class HTTPStatusIn(BaseCheck): config = "status-in" expected_cls = ExpectedStringValue - async def run(self) -> dict: - # XXX Get the method from the task - task = self.task - response = await self.http_client.request( - method="get", url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: return self.response( status=response.status_code in json.loads(self.expected), expected=self.expected, @@ -62,11 +50,7 @@ class HTTPToHTTPS(BaseCheck): config = "http-to-https" expected_cls = ExpectedStringValue - async def run(self) -> dict: - task = self.task - url = URL(task.url).copy_with(scheme="http") - response = await self.http_client.request(method="get", url=url, timeout=60) - + async def run(self, response: Response) -> dict: expected_dict = json.loads(self.expected) expected = range(300, 400) if "range" in expected_dict: @@ -90,13 +74,7 @@ class HTTPHeadersContain(BaseCheck): config = "headers-contain" expected_cls = ExpectedStringValue - async def run(self) -> dict: - # XXX Get the method from the task - task = self.task - response = await self.http_client.request( - method="get", url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: status = True for header in json.loads(self.expected): if header not in response.headers: @@ -116,13 +94,7 @@ class HTTPHeadersHave(BaseCheck): config = "headers-have" expected_cls = ExpectedStringValue - async def run(self) -> dict: - # XXX Get the method from the task - task = self.task - response = await self.http_client.request( - method="get", url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: status = True for header, value in json.loads(self.expected).items(): if header not in response.headers: @@ -146,13 +118,7 @@ class HTTPHeadersLike(BaseCheck): config = "headers-like" expected_cls = ExpectedStringValue - async def run(self) -> dict: - # XXX Get the method from the task - task = self.task - response = await self.http_client.request( - method="get", url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: status = True for header, value in json.loads(self.expected).items(): if header not in response.headers: @@ -175,10 +141,7 @@ class HTTPBodyContains(BaseCheck): config = "body-contains" expected_cls = ExpectedStringValue - async def run(self) -> dict: - response = await self.http_client.request( - method="get", url=self.task.url, timeout=60 - ) + async def run(self, response: Response) -> dict: return self.response(status=self.expected in response.text) @@ -188,10 +151,7 @@ class HTTPBodyLike(BaseCheck): config = "body-like" expected_cls = ExpectedStringValue - async def run(self) -> dict: - response = await self.http_client.request( - method="get", url=self.task.url, timeout=60 - ) + async def run(self, response: Response) -> dict: if re.search(rf"{self.expected}", response.text): return self.response(status=True) @@ -205,13 +165,7 @@ class HTTPJsonContains(BaseCheck): config = "json-contains" expected_cls = ExpectedStringValue - async def run(self) -> dict: - # XXX Get the method from the task - task = self.task - response = await self.http_client.request( - method="get", url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: obj = response.json() status = True @@ -235,13 +189,7 @@ class HTTPJsonHas(BaseCheck): config = "json-has" expected_cls = ExpectedStringValue - async def run(self) -> dict: - # XXX Get the method from the task - task = self.task - response = await self.http_client.request( - method="get", url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: obj = response.json() status = True @@ -269,13 +217,7 @@ class HTTPJsonLike(BaseCheck): config = "json-like" expected_cls = ExpectedStringValue - async def run(self) -> dict: - # XXX Get the method from the task - task = self.task - response = await self.http_client.request( - method="get", url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: obj = response.json() status = True @@ -302,13 +244,7 @@ class HTTPJsonIs(BaseCheck): config = "json-is" expected_cls = ExpectedStringValue - async def run(self) -> dict: - # XXX Get the method from the task - task = self.task - response = await self.http_client.request( - method="get", url=task.url, timeout=60 - ) - + async def run(self, response: Response) -> dict: obj = response.json() status = response.json() == json.loads(self.expected) @@ -326,10 +262,8 @@ class SSLCertificateExpiration(BaseCheck): config = "ssl-certificate-expiration" expected_cls = ExpectedStringValue - async def run(self): + async def run(self, response: Response) -> dict: """Returns the number of days in which the certificate will expire.""" - response = await self.http_client.get(self.task.url, timeout=60) - network_stream = response.extensions["network_stream"] ssl_obj = network_stream.get_extra_info("ssl_object") cert = ssl_obj.getpeercert() diff --git a/argos/config-example.yaml b/argos/config-example.yaml index 2999e12..b20c48f 100644 --- a/argos/config-example.yaml +++ b/argos/config-example.yaml @@ -1,5 +1,7 @@ --- general: + # Except for frequency and recheck_delay settings, changes in general + # section of the configuration will need a restart of argos server. db: # The database URL, as defined in SQLAlchemy docs : # https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls @@ -14,13 +16,54 @@ general: # Can be "production", "dev", "test". # If not present, default value is "production" env: "production" - # to get a good string for cookie_secret, run: + # To get a good string for cookie_secret, run: # openssl rand -hex 32 cookie_secret: "foo_bar_baz" + + # Session duration + # Use m for minutes, h for hours, d for days + # w for weeks, M for months, y for years + # See https://github.com/timwedde/durations_nlp#scales-reference for details + # If not present, default value is "7d" + session_duration: "7d" + # Session opened with "Remember me" checked + # If not present, the "Remember me" feature is not available + # remember_me_duration: "1M" + + # Unauthenticated access + # If can grant an unauthenticated access to the dashboard or to all pages + # To do so, choose either "dashboard", or "all" + # If not present, all pages needs authentication + # unauthenticated_access: "all" + + # LDAP authentication + # Instead of relying on Argos’ users, use a LDAP server to authenticate users. + # If not present, Argos’ native user system is used. + # ldap: + # # Server URI + # uri: "ldaps://ldap.example.org" + # # Search base DN + # user_tree: "ou=users,dc=example,dc=org" + # # Search bind DN + # bind_dn: "uid=ldap_user,ou=users,dc=example,dc=org" + # # Search bind password + # bind_pwd: "secr3t" + # # User attribute (uid, mail, sAMAccountName, etc.) + # user_attr: "uid" + # # User filter (to exclude some users, etc.) + # user_filter: "(!(uid=ldap_user))" + # Default delay for checks. # Can be superseeded in domain configuration. - # For ex., to run checks every minute: - frequency: "1m" + # For ex., to run checks every 5 minutes: + frequency: "5m" + # Default re-check delay if a check has failed. + # Can be superseeded in domain configuration. + # If not present, failed checked won’t be re-checked (they will be + # run again like if they succeded + # For ex., to re-try a check one minute after a failure: + # recheck_delay: "1m" + # Which way do you want to be warned when a check goes to that severity? # "local" emits a message in the server log # You’ll need to configure mail, gotify or apprise below to be able to use @@ -93,6 +136,11 @@ websites: - domain: "https://mypads.example.org" paths: - path: "/mypads/" + # Specify the method of the HTTP request + # Valid values are "GET", "HEAD", "POST", "OPTIONS", + # "CONNECT", "TRACE", "PUT", "PATCH" and "DELETE" + # default is "GET" if omitted + method: "GET" checks: # Check that the returned HTTP status is 200 - status-is: 200 @@ -164,6 +212,7 @@ websites: - json-is: '{"foo": "bar", "baz": 42}' - domain: "https://munin.example.org" frequency: "20m" + recheck_delay: "5m" paths: - path: "/" checks: diff --git a/argos/schemas/config.py b/argos/schemas/config.py index 7cc29d1..13119b0 100644 --- a/argos/schemas/config.py +++ b/argos/schemas/config.py @@ -5,8 +5,9 @@ For database models, see argos.server.models. import json -from typing import Dict, List, Literal, Optional, Tuple +from typing import Dict, List, Literal, Tuple +from durations_nlp import Duration from pydantic import ( BaseModel, ConfigDict, @@ -22,10 +23,11 @@ from pydantic.networks import UrlConstraints from pydantic_core import Url from typing_extensions import Annotated -from argos.schemas.utils import string_to_duration +from argos.schemas.utils import Method Severity = Literal["warning", "error", "critical", "unknown"] Environment = Literal["dev", "test", "production"] +Unauthenticated = Literal["dashboard", "all"] SQLiteDsn = Annotated[ Url, UrlConstraints( @@ -37,7 +39,7 @@ SQLiteDsn = Annotated[ def parse_threshold(value): """Parse duration threshold for SSL certificate validity""" for duration_str, severity in value.items(): - days = string_to_duration(duration_str, "days") + days = Duration(duration_str).to_days() # Return here because it's one-item dicts. return (days, severity) @@ -104,6 +106,7 @@ def parse_checks(value): class WebsitePath(BaseModel): path: str + method: Method = "GET" checks: List[ Annotated[ Tuple[str, str], @@ -114,14 +117,23 @@ class WebsitePath(BaseModel): class Website(BaseModel): domain: HttpUrl - frequency: Optional[int] = None + frequency: float | None = None + recheck_delay: float | None = None paths: List[WebsitePath] @field_validator("frequency", mode="before") def parse_frequency(cls, value): """Convert the configured frequency to minutes""" if value: - return string_to_duration(value, "minutes") + return Duration(value).to_minutes() + + return None + + @field_validator("recheck_delay", mode="before") + def parse_recheck_delay(cls, value): + """Convert the configured recheck delay to minutes""" + if value: + return Duration(value).to_minutes() return None @@ -147,7 +159,7 @@ class Mail(BaseModel): port: PositiveInt = 25 ssl: StrictBool = False starttls: StrictBool = False - auth: Optional[MailAuth] = None + auth: MailAuth | None = None addresses: List[EmailStr] @@ -171,23 +183,58 @@ class DbSettings(BaseModel): max_overflow: int = 20 +class LdapSettings(BaseModel): + uri: str + user_tree: str + bind_dn: str | None = None + bind_pwd: str | None = None + user_attr: str + user_filter: str | None = None + + class General(BaseModel): """Frequency for the checks and alerts""" - cookie_secret: str - frequency: int db: DbSettings env: Environment = "production" + cookie_secret: str + session_duration: int = 10080 # 7 days + remember_me_duration: int | None = None + unauthenticated_access: Unauthenticated | None = None + ldap: LdapSettings | None = None + frequency: float + recheck_delay: float | None = None root_path: str = "" alerts: Alert - mail: Optional[Mail] = None - gotify: Optional[List[GotifyUrl]] = None - apprise: Optional[Dict[str, List[str]]] = None + mail: Mail | None = None + gotify: List[GotifyUrl] | None = None + apprise: Dict[str, List[str]] | None = None + + @field_validator("session_duration", mode="before") + def parse_session_duration(cls, value): + """Convert the configured session duration to minutes""" + return Duration(value).to_minutes() + + @field_validator("remember_me_duration", mode="before") + def parse_remember_me_duration(cls, value): + """Convert the configured session duration with remember me feature to minutes""" + if value: + return int(Duration(value).to_minutes()) + + return None @field_validator("frequency", mode="before") def parse_frequency(cls, value): """Convert the configured frequency to minutes""" - return string_to_duration(value, "minutes") + return Duration(value).to_minutes() + + @field_validator("recheck_delay", mode="before") + def parse_recheck_delay(cls, value): + """Convert the configured recheck delay to minutes""" + if value: + return Duration(value).to_minutes() + + return None class Config(BaseModel): diff --git a/argos/schemas/models.py b/argos/schemas/models.py index a297acf..a4a37c2 100644 --- a/argos/schemas/models.py +++ b/argos/schemas/models.py @@ -8,6 +8,8 @@ from typing import Literal from pydantic import BaseModel, ConfigDict +from argos.schemas.utils import Method + # XXX Refactor using SQLModel to avoid duplication of model data @@ -18,6 +20,7 @@ class Task(BaseModel): url: str domain: str check: str + method: Method expected: str selected_at: datetime | None selected_by: str | None diff --git a/argos/schemas/utils.py b/argos/schemas/utils.py index f225241..05d716a 100644 --- a/argos/schemas/utils.py +++ b/argos/schemas/utils.py @@ -1,42 +1,6 @@ from typing import Literal -def string_to_duration( - value: str, target: Literal["days", "hours", "minutes"] -) -> int | float: - """Convert a string to a number of hours, days or minutes""" - num = int("".join(filter(str.isdigit, value))) - - # It's not possible to convert from a smaller unit to a greater one: - # - hours and minutes cannot be converted to days - # - minutes cannot be converted to hours - if (target == "days" and ("h" in value or "m" in value.replace("mo", ""))) or ( - target == "hours" and "m" in value.replace("mo", "") - ): - msg = ( - "Durations cannot be converted from a smaller to a greater unit. " - f"(trying to convert '{value}' to {target})" - ) - raise ValueError(msg, value) - - # Consider we're converting to minutes, do the eventual multiplication at the end. - if "h" in value: - num = num * 60 - elif "d" in value: - num = num * 60 * 24 - elif "w" in value: - num = num * 60 * 24 * 7 - elif "mo" in value: - num = num * 60 * 24 * 30 # considers 30d in a month - elif "y" in value: - num = num * 60 * 24 * 365 # considers 365d in a year - elif "m" not in value: - raise ValueError("Invalid duration value", value) - - if target == "hours": - return num / 60 - if target == "days": - return num / 60 / 24 - - # target == "minutes" - return num +Method = Literal[ + "GET", "HEAD", "POST", "OPTIONS", "CONNECT", "TRACE", "PUT", "PATCH", "DELETE" +] diff --git a/argos/server/alerting.py b/argos/server/alerting.py index 60384aa..4c82a9b 100644 --- a/argos/server/alerting.py +++ b/argos/server/alerting.py @@ -25,7 +25,7 @@ def get_icon_from_severity(severity: str) -> str: return icon -def handle_alert(config: Config, result, task, severity, old_severity, request): +def handle_alert(config: Config, result, task, severity, old_severity, request): # pylint: disable-msg=too-many-positional-arguments """Dispatch alert through configured alert channels""" if "local" in getattr(config.general.alerts, severity): @@ -64,7 +64,7 @@ def handle_alert(config: Config, result, task, severity, old_severity, request): ) -def notify_with_apprise( +def notify_with_apprise( # pylint: disable-msg=too-many-positional-arguments result, task, severity: str, old_severity: str, group: List[str], request ) -> None: logger.debug("Will send apprise notification") @@ -90,7 +90,7 @@ See results of task on {request.url_for('get_task_results_view', task_id=task.id apobj.notify(title=title, body=msg) -def notify_by_mail( +def notify_by_mail( # pylint: disable-msg=too-many-positional-arguments result, task, severity: str, old_severity: str, config: Mail, request ) -> None: logger.debug("Will send mail notification") @@ -137,7 +137,7 @@ See results of task on {request.url_for('get_task_results_view', task_id=task.id smtp.send_message(mail, to_addrs=address) -def notify_with_gotify( +def notify_with_gotify( # pylint: disable-msg=too-many-positional-arguments result, task, severity: str, old_severity: str, config: List[GotifyUrl], request ) -> None: logger.debug("Will send gotify notification") diff --git a/argos/server/main.py b/argos/server/main.py index b6ee412..0543182 100644 --- a/argos/server/main.py +++ b/argos/server/main.py @@ -36,13 +36,25 @@ def get_application() -> FastAPI: appli.add_exception_handler(NotAuthenticatedException, auth_exception_handler) appli.state.manager = create_manager(config.general.cookie_secret) + if config.general.ldap is not None: + import ldap + + l = ldap.initialize(config.general.ldap.uri) + l.simple_bind_s(config.general.ldap.bind_dn, config.general.ldap.bind_pwd) + appli.state.ldap = l + @appli.state.manager.user_loader() - async def query_user(user: str) -> None | models.User: + async def query_user(user: str) -> None | str | models.User: """ - Get a user from the db + Get a user from the db or LDAP :param user: name of the user :return: None or the user object """ + if appli.state.config.general.ldap is not None: + from argos.server.routes.dependencies import find_ldap_user + + return await find_ldap_user(appli.state.config, appli.state.ldap, user) + return await queries.get_user(appli.state.db, user) appli.include_router(routes.api, prefix="/api") @@ -100,7 +112,7 @@ def setup_database(appli): models.Base.metadata.create_all(bind=engine) -def create_manager(cookie_secret): +def create_manager(cookie_secret: str) -> LoginManager: if cookie_secret == "foo_bar_baz": logger.warning( "You should change the cookie_secret secret in your configuration file." diff --git a/argos/server/migrations/versions/127d74c770bb_add_recheck_delay.py b/argos/server/migrations/versions/127d74c770bb_add_recheck_delay.py new file mode 100644 index 0000000..3605e8b --- /dev/null +++ b/argos/server/migrations/versions/127d74c770bb_add_recheck_delay.py @@ -0,0 +1,30 @@ +"""Add recheck delay + +Revision ID: 127d74c770bb +Revises: dcf73fa19fce +Create Date: 2024-11-27 16:04:58.138768 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "127d74c770bb" +down_revision: Union[str, None] = "dcf73fa19fce" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.add_column(sa.Column("recheck_delay", sa.Float(), nullable=True)) + batch_op.add_column(sa.Column("already_retried", sa.Boolean(), nullable=False)) + + +def downgrade() -> None: + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.drop_column("already_retried") + batch_op.drop_column("recheck_delay") diff --git a/argos/server/migrations/versions/a1e98cf72a5c_make_frequency_a_float.py b/argos/server/migrations/versions/a1e98cf72a5c_make_frequency_a_float.py new file mode 100644 index 0000000..d0facb7 --- /dev/null +++ b/argos/server/migrations/versions/a1e98cf72a5c_make_frequency_a_float.py @@ -0,0 +1,38 @@ +"""Make frequency a float + +Revision ID: a1e98cf72a5c +Revises: 127d74c770bb +Create Date: 2024-11-27 16:10:13.000705 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "a1e98cf72a5c" +down_revision: Union[str, None] = "127d74c770bb" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.alter_column( + "frequency", + existing_type=sa.INTEGER(), + type_=sa.Float(), + existing_nullable=False, + ) + + +def downgrade() -> None: + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.alter_column( + "frequency", + existing_type=sa.Float(), + type_=sa.INTEGER(), + existing_nullable=False, + ) diff --git a/argos/server/migrations/versions/dcf73fa19fce_specify_check_method.py b/argos/server/migrations/versions/dcf73fa19fce_specify_check_method.py new file mode 100644 index 0000000..f218108 --- /dev/null +++ b/argos/server/migrations/versions/dcf73fa19fce_specify_check_method.py @@ -0,0 +1,45 @@ +"""Specify check method + +Revision ID: dcf73fa19fce +Revises: c780864dc407 +Create Date: 2024-11-26 14:40:27.510587 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "dcf73fa19fce" +down_revision: Union[str, None] = "c780864dc407" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "method", + sa.Enum( + "GET", + "HEAD", + "POST", + "OPTIONS", + "CONNECT", + "TRACE", + "PUT", + "PATCH", + "DELETE", + name="method", + ), + nullable=False, + ) + ) + + +def downgrade() -> None: + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.drop_column("method") diff --git a/argos/server/models.py b/argos/server/models.py index 5a03399..33b05b9 100644 --- a/argos/server/models.py +++ b/argos/server/models.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from argos.checks import BaseCheck, get_registered_check from argos.schemas import WebsiteCheck +from argos.schemas.utils import Method class Base(DeclarativeBase): @@ -34,7 +35,24 @@ class Task(Base): domain: Mapped[str] = mapped_column() check: Mapped[str] = mapped_column() expected: Mapped[str] = mapped_column() - frequency: Mapped[int] = mapped_column() + frequency: Mapped[float] = mapped_column() + recheck_delay: Mapped[float] = mapped_column(nullable=True) + already_retried: Mapped[bool] = mapped_column(insert_default=False) + method: Mapped[Method] = mapped_column( + Enum( + "GET", + "HEAD", + "POST", + "OPTIONS", + "CONNECT", + "TRACE", + "PUT", + "PATCH", + "DELETE", + name="method", + ), + insert_default="GET", + ) # Orchestration-related selected_by: Mapped[str] = mapped_column(nullable=True) @@ -70,7 +88,16 @@ class Task(Base): now = datetime.now() self.completed_at = now - self.next_run = now + timedelta(minutes=self.frequency) + if ( + self.recheck_delay is not None + and severity != "ok" + and not self.already_retried + ): + self.next_run = now + timedelta(minutes=self.recheck_delay) + self.already_retried = True + else: + self.next_run = now + timedelta(minutes=self.frequency) + self.already_retried = False @property def last_result(self): diff --git a/argos/server/queries.py b/argos/server/queries.py index e887ebe..94fc0f4 100644 --- a/argos/server/queries.py +++ b/argos/server/queries.py @@ -100,6 +100,11 @@ async def has_config_changed(db: Session, config: schemas.Config) -> bool: same_config = False conf.val = str(config.general.frequency) conf.updated_at = datetime.now() + case "general_recheck_delay": + if conf.val != str(config.general.recheck_delay): + same_config = False + conf.val = str(config.general.recheck_delay) + conf.updated_at = datetime.now() db.commit() @@ -115,8 +120,14 @@ async def has_config_changed(db: Session, config: schemas.Config) -> bool: val=str(config.general.frequency), updated_at=datetime.now(), ) + gen_recheck = ConfigCache( + name="general_recheck_delay", + val=str(config.general.recheck_delay), + updated_at=datetime.now(), + ) db.add(web_hash) db.add(gen_freq) + db.add(gen_recheck) db.commit() return True @@ -137,6 +148,7 @@ async def update_from_config(db: Session, config: schemas.Config): for website in config.websites: domain = str(website.domain) frequency = website.frequency or config.general.frequency + recheck_delay = website.recheck_delay or config.general.recheck_delay for p in website.paths: url = urljoin(domain, str(p.path)) @@ -146,6 +158,7 @@ async def update_from_config(db: Session, config: schemas.Config): db.query(Task) .filter( Task.url == url, + Task.method == p.method, Task.check == check_key, Task.expected == expected, ) @@ -157,13 +170,18 @@ async def update_from_config(db: Session, config: schemas.Config): if frequency != existing_task.frequency: existing_task.frequency = frequency + if recheck_delay != existing_task.recheck_delay: + existing_task.recheck_delay = recheck_delay # type: ignore[assignment] logger.debug( "Skipping db task creation for url=%s, " - "check_key=%s, expected=%s, frequency=%s.", + "method=%s, check_key=%s, expected=%s, " + "frequency=%s, recheck_delay=%s.", url, + p.method, check_key, expected, frequency, + recheck_delay, ) else: @@ -173,9 +191,12 @@ async def update_from_config(db: Session, config: schemas.Config): task = Task( domain=domain, url=url, + method=p.method, check=check_key, expected=expected, frequency=frequency, + recheck_delay=recheck_delay, + already_retried=False, ) logger.debug("Adding a new task in the db: %s", task) tasks.append(task) diff --git a/argos/server/routes/api.py b/argos/server/routes/api.py index ec132ca..cc96132 100644 --- a/argos/server/routes/api.py +++ b/argos/server/routes/api.py @@ -30,7 +30,7 @@ async def read_tasks( @route.post("/results", status_code=201, dependencies=[Depends(verify_token)]) -async def create_results( +async def create_results( # pylint: disable-msg=too-many-positional-arguments request: Request, results: List[AgentResult], background_tasks: BackgroundTasks, diff --git a/argos/server/routes/dependencies.py b/argos/server/routes/dependencies.py index f26d5ee..e61b77a 100644 --- a/argos/server/routes/dependencies.py +++ b/argos/server/routes/dependencies.py @@ -18,6 +18,9 @@ def get_config(request: Request): async def get_manager(request: Request) -> LoginManager: + if request.app.state.config.general.unauthenticated_access is not None: + return await request.app.state.manager.optional(request) + return await request.app.state.manager(request) @@ -28,3 +31,28 @@ async def verify_token( if token.credentials not in request.app.state.config.service.secrets: raise HTTPException(status_code=401, detail="Unauthorized") return token + + +async def find_ldap_user(config, ldap, user: str) -> str | None: + """Do a LDAP search for user and return its dn""" + import ldap.filter as ldap_filter + from ldapurl import LDAP_SCOPE_SUBTREE + + result = ldap.search_s( + config.general.ldap.user_tree, + LDAP_SCOPE_SUBTREE, + filterstr=ldap_filter.filter_format( + f"(&(%s=%s){config.general.ldap.user_filter})", + [ + config.general.ldap.user_attr, + user, + ], + ), + attrlist=[config.general.ldap.user_attr], + ) + + # If there is a result, there should, logically, be only one entry + if len(result) > 0: + return result[0][0] + + return None diff --git a/argos/server/routes/views.py b/argos/server/routes/views.py index 6dc5ace..ae2f51c 100644 --- a/argos/server/routes/views.py +++ b/argos/server/routes/views.py @@ -17,6 +17,7 @@ from sqlalchemy.orm import Session from argos.checks.base import Status from argos.schemas import Config from argos.server import queries +from argos.server.exceptions import NotAuthenticatedException from argos.server.models import Result, Task, User from argos.server.routes.dependencies import get_config, get_db, get_manager @@ -28,7 +29,17 @@ SEVERITY_LEVELS = {"ok": 1, "warning": 2, "critical": 3, "unknown": 4} @route.get("/login") -async def login_view(request: Request, msg: str | None = None): +async def login_view( + request: Request, + msg: str | None = None, + config: Config = Depends(get_config), +): + if config.general.unauthenticated_access == "all": + return RedirectResponse( + request.url_for("get_severity_counts_view"), + status_code=status.HTTP_303_SEE_OTHER, + ) + token = request.cookies.get("access-token") if token is not None and token != "": manager = request.app.state.manager @@ -44,7 +55,14 @@ async def login_view(request: Request, msg: str | None = None): else: msg = None - return templates.TemplateResponse("login.html", {"request": request, "msg": msg}) + return templates.TemplateResponse( + "login.html", + { + "request": request, + "msg": msg, + "remember": config.general.remember_me_duration, + }, + ) @route.post("/login") @@ -52,37 +70,77 @@ async def post_login( request: Request, db: Session = Depends(get_db), data: OAuth2PasswordRequestForm = Depends(), + rememberme: Annotated[str | None, Form()] = None, + config: Config = Depends(get_config), ): + if config.general.unauthenticated_access == "all": + return RedirectResponse( + request.url_for("get_severity_counts_view"), + status_code=status.HTTP_303_SEE_OTHER, + ) + username = data.username - user = await queries.get_user(db, username) + invalid_credentials = templates.TemplateResponse( "login.html", {"request": request, "msg": "Sorry, invalid username or bad password."}, ) - if user is None: - return invalid_credentials - pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - if not pwd_context.verify(data.password, user.password): - return invalid_credentials + if config.general.ldap is not None: + from ldap import INVALID_CREDENTIALS # pylint: disable-msg=no-name-in-module + from argos.server.routes.dependencies import find_ldap_user - user.last_login_at = datetime.now() - db.commit() + ldap_dn = await find_ldap_user(config, request.app.state.ldap, username) + if ldap_dn is None: + return invalid_credentials + try: + request.app.state.ldap.simple_bind_s(ldap_dn, data.password) + except INVALID_CREDENTIALS: + return invalid_credentials + else: + user = await queries.get_user(db, username) + if user is None: + return invalid_credentials + + pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + if not pwd_context.verify(data.password, user.password): + return invalid_credentials + + user.last_login_at = datetime.now() + db.commit() manager = request.app.state.manager - token = manager.create_access_token( - data={"sub": username}, expires=timedelta(days=7) - ) + session_duration = config.general.session_duration + if config.general.remember_me_duration is not None and rememberme == "on": + session_duration = config.general.remember_me_duration + delta = timedelta(minutes=session_duration) + token = manager.create_access_token(data={"sub": username}, expires=delta) response = RedirectResponse( request.url_for("get_severity_counts_view"), status_code=status.HTTP_303_SEE_OTHER, ) - manager.set_cookie(response, token) + response.set_cookie( + key=manager.cookie_name, + value=token, + httponly=True, + samesite="strict", + expires=int(delta.total_seconds()), + ) return response @route.get("/logout") -async def logout_view(request: Request, user: User | None = Depends(get_manager)): +async def logout_view( + request: Request, + config: Config = Depends(get_config), + user: User | None = Depends(get_manager), +): + if config.general.unauthenticated_access == "all": + return RedirectResponse( + request.url_for("get_severity_counts_view"), + status_code=status.HTTP_303_SEE_OTHER, + ) + response = RedirectResponse( request.url_for("login_view").include_query_params(msg="logout"), status_code=status.HTTP_303_SEE_OTHER, @@ -112,6 +170,7 @@ async def get_severity_counts_view( "agents": agents, "auto_refresh_enabled": auto_refresh_enabled, "auto_refresh_seconds": auto_refresh_seconds, + "user": user, }, ) @@ -120,9 +179,14 @@ async def get_severity_counts_view( async def get_domains_view( request: Request, user: User | None = Depends(get_manager), + config: Config = Depends(get_config), db: Session = Depends(get_db), ): """Show all tasks and their current state""" + if config.general.unauthenticated_access == "dashboard": + if user is None: + raise NotAuthenticatedException + tasks = db.query(Task).all() domains_severities = defaultdict(list) @@ -163,6 +227,7 @@ async def get_domains_view( "last_checks": domains_last_checks, "total_task_count": len(tasks), "agents": agents, + "user": user, }, ) @@ -172,12 +237,23 @@ async def get_domain_tasks_view( request: Request, domain: str, user: User | None = Depends(get_manager), + config: Config = Depends(get_config), db: Session = Depends(get_db), ): """Show all tasks attached to a domain""" + if config.general.unauthenticated_access == "dashboard": + if user is None: + raise NotAuthenticatedException + tasks = db.query(Task).filter(Task.domain.contains(f"//{domain}")).all() return templates.TemplateResponse( - "domain.html", {"request": request, "domain": domain, "tasks": tasks} + "domain.html", + { + "request": request, + "domain": domain, + "tasks": tasks, + "user": user, + }, ) @@ -186,12 +262,23 @@ async def get_result_view( request: Request, result_id: int, user: User | None = Depends(get_manager), + config: Config = Depends(get_config), db: Session = Depends(get_db), ): """Show the details of a result""" + if config.general.unauthenticated_access == "dashboard": + if user is None: + raise NotAuthenticatedException + result = db.query(Result).get(result_id) return templates.TemplateResponse( - "result.html", {"request": request, "result": result, "error": Status.ERROR} + "result.html", + { + "request": request, + "result": result, + "error": Status.ERROR, + "user": user, + }, ) @@ -204,6 +291,10 @@ async def get_task_results_view( config: Config = Depends(get_config), ): """Show history of a task’s results""" + if config.general.unauthenticated_access == "dashboard": + if user is None: + raise NotAuthenticatedException + results = ( db.query(Result) .filter(Result.task_id == task_id) @@ -222,6 +313,7 @@ async def get_task_results_view( "task": task, "description": description, "error": Status.ERROR, + "user": user, }, ) @@ -230,9 +322,14 @@ async def get_task_results_view( async def get_agents_view( request: Request, user: User | None = Depends(get_manager), + config: Config = Depends(get_config), db: Session = Depends(get_db), ): """Show argos agents and the last time the server saw them""" + if config.general.unauthenticated_access == "dashboard": + if user is None: + raise NotAuthenticatedException + last_seen = ( db.query(Result.agent_id, func.max(Result.submitted_at).label("submitted_at")) .group_by(Result.agent_id) @@ -240,7 +337,12 @@ async def get_agents_view( ) return templates.TemplateResponse( - "agents.html", {"request": request, "last_seen": last_seen} + "agents.html", + { + "request": request, + "last_seen": last_seen, + "user": user, + }, ) diff --git a/argos/server/templates/base.html b/argos/server/templates/base.html index 4964031..c88065f 100644 --- a/argos/server/templates/base.html +++ b/argos/server/templates/base.html @@ -63,6 +63,8 @@ Agents + {% set unauthenticated_access = request.app.state.config.general.unauthenticated_access %} + {% if (user is defined and user is not none) or unauthenticated_access == "all" %}