From 837cd548ade7b787b2867bc8f6a29eac1350b5b2 Mon Sep 17 00:00:00 2001 From: Luc Didry Date: Wed, 19 Mar 2025 17:21:53 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=92=20=E2=80=94=20Logging=20out=20now?= =?UTF-8?q?=20invalidate=20tokens?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 1 + argos/server/exceptions.py | 4 +- argos/server/main.py | 3 + .../1d0aaa07743c_add_blocked_tokens_table.py | 32 ++++++++ argos/server/models.py | 15 ++++ argos/server/queries.py | 29 ++++++- argos/server/routes/dependencies.py | 70 +++++++++++++++++ argos/server/routes/views.py | 75 ++++++++----------- 8 files changed, 183 insertions(+), 46 deletions(-) create mode 100644 argos/server/migrations/versions/1d0aaa07743c_add_blocked_tokens_table.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ff4c87..7c5ef2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - 🚸 — Use ReconnectLDAPObject - ⏰ — Set recurring task to every minute - 🔊 — Improve check agent log +- 🔒️ — Logging out now invalidate tokens ## 0.9.0 diff --git a/argos/server/exceptions.py b/argos/server/exceptions.py index 8fdd39d..042d80d 100644 --- a/argos/server/exceptions.py +++ b/argos/server/exceptions.py @@ -10,7 +10,9 @@ def auth_exception_handler(request: Request, exc: NotAuthenticatedException): """ Redirect the user to the login page if not logged in """ - response = RedirectResponse(url=request.url_for("login_view")) + response = RedirectResponse( + url=request.url_for("login_view").include_query_params(msg="not-authenticated") + ) manager = request.app.state.manager manager.set_cookie(response, "") return response diff --git a/argos/server/main.py b/argos/server/main.py index 00043d4..1ca82e2 100644 --- a/argos/server/main.py +++ b/argos/server/main.py @@ -151,6 +151,9 @@ async def recurring_tasks() -> None: updated = await queries.release_old_locks(db, config.max_lock_seconds) logger.info("%i lock(s) released", updated) + removed_tokens = await queries.remove_old_tokens(db) + logger.info("%i old token(s) removed", removed_tokens) + processed_jobs = await queries.process_jobs(db) logger.info("%i job(s) processed", processed_jobs) diff --git a/argos/server/migrations/versions/1d0aaa07743c_add_blocked_tokens_table.py b/argos/server/migrations/versions/1d0aaa07743c_add_blocked_tokens_table.py new file mode 100644 index 0000000..99367d9 --- /dev/null +++ b/argos/server/migrations/versions/1d0aaa07743c_add_blocked_tokens_table.py @@ -0,0 +1,32 @@ +"""Add blocked_tokens table + +Revision ID: 1d0aaa07743c +Revises: 5f6cb30db996 +Create Date: 2025-03-19 15:23:20.233843 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "1d0aaa07743c" +down_revision: Union[str, None] = "5f6cb30db996" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "blocked_tokens", + sa.Column("token", sa.String(), nullable=False), + sa.Column("expires_at", sa.DateTime(), nullable=False), + sa.Column("excluded_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("token"), + ) + + +def downgrade() -> None: + op.drop_table("blocked_tokens") diff --git a/argos/server/models.py b/argos/server/models.py index eab35ac..7d6997d 100644 --- a/argos/server/models.py +++ b/argos/server/models.py @@ -217,3 +217,18 @@ class User(Base): def update_last_login_at(self): self.last_login_at = datetime.now() + + +class BlockedToken(Base): + """ + List of tokens discarded by their users + (when they logout) + """ + + __tablename__ = "blocked_tokens" + token: Mapped[str] = mapped_column(primary_key=True) + expires_at: Mapped[datetime] = mapped_column() + excluded_at: Mapped[datetime] = mapped_column(default=datetime.now()) + + def __str__(self) -> str: + return f"DB BlockedToken {self.token} - {self.expires_at} - {self.excluded_at}" diff --git a/argos/server/queries.py b/argos/server/queries.py index 6e93c02..581d5ec 100644 --- a/argos/server/queries.py +++ b/argos/server/queries.py @@ -4,12 +4,15 @@ from hashlib import sha256 from typing import List from urllib.parse import urljoin +import jwt + +from fastapi import Request from sqlalchemy import asc, func, Select from sqlalchemy.orm import Session from argos import schemas from argos.logging import logger -from argos.server.models import ConfigCache, Job, Result, Task, User +from argos.server.models import BlockedToken, ConfigCache, Job, Result, Task, User from argos.server.settings import read_config @@ -420,6 +423,30 @@ async def remove_old_results(db: Session, max_results_age: float): return deleted +async def block_token(db: Session, request: Request): + """Discard user token""" + manager = request.app.state.manager + token = await manager._get_token(request) # pylint: disable-msg=protected-access + payload = jwt.decode( + token, manager.secret.secret_for_decode, algorithms=[manager.algorithm] + ) + blocked_token = BlockedToken( + token=token, expires_at=datetime.utcfromtimestamp(payload["exp"]) + ) + db.add(blocked_token) + db.commit() + + +async def remove_old_tokens(db: Session): + """Remove expired discarded tokens""" + deleted = ( + db.query(BlockedToken).filter(BlockedToken.expires_at < datetime.now()).delete() + ) + db.commit() + + return deleted + + async def release_old_locks(db: Session, max_lock_seconds: int): """Remove outdated locks on tasks""" max_acceptable_time = datetime.now() - timedelta(seconds=max_lock_seconds) diff --git a/argos/server/routes/dependencies.py b/argos/server/routes/dependencies.py index ed0399a..829c0b8 100644 --- a/argos/server/routes/dependencies.py +++ b/argos/server/routes/dependencies.py @@ -1,8 +1,17 @@ +from datetime import datetime, timedelta + from fastapi import Depends, HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi_login import LoginManager +from ldap import INVALID_CREDENTIALS # pylint: disable-msg=no-name-in-module +from ldap.ldapobject import ReconnectLDAPObject +from passlib.context import CryptContext from argos.logging import logger +from argos.schemas import Config, General +from argos.server.exceptions import NotAuthenticatedException +from argos.server.models import BlockedToken +from argos.server.queries import get_user auth_scheme = HTTPBearer() @@ -23,6 +32,11 @@ async def get_manager(request: Request) -> LoginManager: if request.app.state.config.general.unauthenticated_access is not None: return await request.app.state.manager.optional(request) + token = await request.app.state.manager._get_token(request) # pylint: disable-msg=protected-access + db = request.app.state.SessionLocal() + if db.query(BlockedToken).filter(BlockedToken.token == token).count() > 0: + raise NotAuthenticatedException + return await request.app.state.manager(request) @@ -35,6 +49,33 @@ async def verify_token( return token +async def good_user_credentials( + config: Config, request: Request, username: str, password: str +): + if config.general.ldap is not None: + return await good_ldap_user_credentials( + config, request.app.state.ldap, username, password + ) + + return await good_internal_user_credentials( + request.app.state.SessionLocal(), username, password + ) + + +async def good_ldap_user_credentials( + config: Config, ldapobj: ReconnectLDAPObject, username: str, password: str +) -> bool: + ldap_dn = await find_ldap_user(config, ldapobj, username) + if ldap_dn is None: + return False + try: + ldapobj.simple_bind_s(ldap_dn, password) + except INVALID_CREDENTIALS: + return False + + return True + + async def find_ldap_user(config, ldapobj, user: str) -> str | None: """Do a LDAP search for user and return its dn""" import ldap @@ -65,3 +106,32 @@ async def find_ldap_user(config, ldapobj, user: str) -> str | None: return result[0][0] return None + + +async def good_internal_user_credentials(db, username, password) -> bool: + user = await get_user(db, username) + if user is None: + return False + + pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + if not pwd_context.verify(password, user.password): + return False + + user.last_login_at = datetime.now() + db.commit() + + return True + + +async def create_user_token( + manager, config: General, username: str, rememberme: str | None = None +): + session_duration = config.session_duration + if config.remember_me_duration is not None and rememberme == "on": + session_duration = config.remember_me_duration + + delta = timedelta(minutes=session_duration) + return { + "token": manager.create_access_token(data={"sub": username}, expires=delta), + "delta": delta, + } diff --git a/argos/server/routes/views.py b/argos/server/routes/views.py index 3da2a6b..49da12c 100644 --- a/argos/server/routes/views.py +++ b/argos/server/routes/views.py @@ -1,6 +1,6 @@ """Web interface for humans""" from collections import defaultdict -from datetime import datetime, timedelta +from datetime import timedelta from functools import cmp_to_key from pathlib import Path from typing import Annotated @@ -10,7 +10,6 @@ from fastapi import APIRouter, Cookie, Depends, Form, Request, status from fastapi.responses import RedirectResponse from fastapi.security import OAuth2PasswordRequestForm from fastapi.templating import Jinja2Templates -from passlib.context import CryptContext from sqlalchemy import func from sqlalchemy.orm import Session @@ -19,7 +18,13 @@ from argos.schemas import Config from argos.server import queries from argos.server.exceptions import NotAuthenticatedException from argos.server.models import Result, Task, User -from argos.server.routes.dependencies import get_config, get_db, get_manager +from argos.server.routes.dependencies import ( + create_user_token, + get_config, + get_db, + get_manager, + good_user_credentials, +) route = APIRouter() @@ -52,6 +57,8 @@ async def login_view( if msg == "logout": msg = "You have been successfully disconnected." + elif msg == "not-authenticated": + msg = "You are not authenticated or your token has expired" else: msg = None @@ -79,61 +86,38 @@ async def post_login( status_code=status.HTTP_303_SEE_OTHER, ) - username = data.username - - invalid_credentials = templates.TemplateResponse( - "login.html", - {"request": request, "msg": "Sorry, invalid username or bad password."}, + good_credentials = await good_user_credentials( + config, request, data.username, data.password ) - if config.general.ldap is not None: - from ldap import INVALID_CREDENTIALS # pylint: disable-msg=no-name-in-module - from argos.server.routes.dependencies import find_ldap_user - - invalid_credentials = templates.TemplateResponse( + if not good_credentials: + return templates.TemplateResponse( + "login.html", + { + "request": request, + "msg": "Sorry, invalid username or bad password. " + "Or the LDAP server is unreachable (see logs to verify).", + }, + ) + elif not good_credentials: + return templates.TemplateResponse( "login.html", - { - "request": request, - "msg": "Sorry, invalid username or bad password. " - "Or the LDAP server is unreachable (see logs to verify).", - }, + {"request": request, "msg": "Sorry, invalid username or bad password."}, ) - ldap_dn = await find_ldap_user(config, request.app.state.ldap, username) - if ldap_dn is None: - return invalid_credentials - try: - request.app.state.ldap.simple_bind_s(ldap_dn, data.password) - except INVALID_CREDENTIALS: - return invalid_credentials - else: - user = await queries.get_user(db, username) - if user is None: - return invalid_credentials - - pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - if not pwd_context.verify(data.password, user.password): - return invalid_credentials - - user.last_login_at = datetime.now() - db.commit() - manager = request.app.state.manager - session_duration = config.general.session_duration - if config.general.remember_me_duration is not None and rememberme == "on": - session_duration = config.general.remember_me_duration - delta = timedelta(minutes=session_duration) - token = manager.create_access_token(data={"sub": username}, expires=delta) + token = await create_user_token(manager, config.general, data.username, rememberme) + response = RedirectResponse( request.url_for("get_severity_counts_view"), status_code=status.HTTP_303_SEE_OTHER, ) response.set_cookie( key=manager.cookie_name, - value=token, + value=token["token"], httponly=True, samesite="strict", - expires=int(delta.total_seconds()), + expires=int(token["delta"].total_seconds()), ) return response @@ -143,6 +127,7 @@ async def logout_view( request: Request, config: Config = Depends(get_config), user: User | None = Depends(get_manager), + db: Session = Depends(get_db), ): if config.general.unauthenticated_access == "all": return RedirectResponse( @@ -150,6 +135,8 @@ async def logout_view( status_code=status.HTTP_303_SEE_OTHER, ) + await queries.block_token(db, request) + response = RedirectResponse( request.url_for("login_view").include_query_params(msg="logout"), status_code=status.HTTP_303_SEE_OTHER,