🔒 — Logging out now invalidate tokens

This commit is contained in:
Luc Didry 2025-03-19 17:21:53 +01:00
parent dbe05178b8
commit 837cd548ad
No known key found for this signature in database
GPG key ID: EA868E12D0257E3C
8 changed files with 183 additions and 46 deletions

View file

@ -5,6 +5,7 @@
- 🚸 — Use ReconnectLDAPObject
- ⏰ — Set recurring task to every minute
- 🔊 — Improve check agent log
- 🔒️ — Logging out now invalidate tokens
## 0.9.0

View file

@ -10,7 +10,9 @@ def auth_exception_handler(request: Request, exc: NotAuthenticatedException):
"""
Redirect the user to the login page if not logged in
"""
response = RedirectResponse(url=request.url_for("login_view"))
response = RedirectResponse(
url=request.url_for("login_view").include_query_params(msg="not-authenticated")
)
manager = request.app.state.manager
manager.set_cookie(response, "")
return response

View file

@ -151,6 +151,9 @@ async def recurring_tasks() -> None:
updated = await queries.release_old_locks(db, config.max_lock_seconds)
logger.info("%i lock(s) released", updated)
removed_tokens = await queries.remove_old_tokens(db)
logger.info("%i old token(s) removed", removed_tokens)
processed_jobs = await queries.process_jobs(db)
logger.info("%i job(s) processed", processed_jobs)

View file

@ -0,0 +1,32 @@
"""Add blocked_tokens table
Revision ID: 1d0aaa07743c
Revises: 5f6cb30db996
Create Date: 2025-03-19 15:23:20.233843
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "1d0aaa07743c"
down_revision: Union[str, None] = "5f6cb30db996"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"blocked_tokens",
sa.Column("token", sa.String(), nullable=False),
sa.Column("expires_at", sa.DateTime(), nullable=False),
sa.Column("excluded_at", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("token"),
)
def downgrade() -> None:
op.drop_table("blocked_tokens")

View file

@ -217,3 +217,18 @@ class User(Base):
def update_last_login_at(self):
self.last_login_at = datetime.now()
class BlockedToken(Base):
"""
List of tokens discarded by their users
(when they logout)
"""
__tablename__ = "blocked_tokens"
token: Mapped[str] = mapped_column(primary_key=True)
expires_at: Mapped[datetime] = mapped_column()
excluded_at: Mapped[datetime] = mapped_column(default=datetime.now())
def __str__(self) -> str:
return f"DB BlockedToken {self.token} - {self.expires_at} - {self.excluded_at}"

View file

@ -4,12 +4,15 @@ from hashlib import sha256
from typing import List
from urllib.parse import urljoin
import jwt
from fastapi import Request
from sqlalchemy import asc, func, Select
from sqlalchemy.orm import Session
from argos import schemas
from argos.logging import logger
from argos.server.models import ConfigCache, Job, Result, Task, User
from argos.server.models import BlockedToken, ConfigCache, Job, Result, Task, User
from argos.server.settings import read_config
@ -420,6 +423,30 @@ async def remove_old_results(db: Session, max_results_age: float):
return deleted
async def block_token(db: Session, request: Request):
"""Discard user token"""
manager = request.app.state.manager
token = await manager._get_token(request) # pylint: disable-msg=protected-access
payload = jwt.decode(
token, manager.secret.secret_for_decode, algorithms=[manager.algorithm]
)
blocked_token = BlockedToken(
token=token, expires_at=datetime.utcfromtimestamp(payload["exp"])
)
db.add(blocked_token)
db.commit()
async def remove_old_tokens(db: Session):
"""Remove expired discarded tokens"""
deleted = (
db.query(BlockedToken).filter(BlockedToken.expires_at < datetime.now()).delete()
)
db.commit()
return deleted
async def release_old_locks(db: Session, max_lock_seconds: int):
"""Remove outdated locks on tasks"""
max_acceptable_time = datetime.now() - timedelta(seconds=max_lock_seconds)

View file

@ -1,8 +1,17 @@
from datetime import datetime, timedelta
from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi_login import LoginManager
from ldap import INVALID_CREDENTIALS # pylint: disable-msg=no-name-in-module
from ldap.ldapobject import ReconnectLDAPObject
from passlib.context import CryptContext
from argos.logging import logger
from argos.schemas import Config, General
from argos.server.exceptions import NotAuthenticatedException
from argos.server.models import BlockedToken
from argos.server.queries import get_user
auth_scheme = HTTPBearer()
@ -23,6 +32,11 @@ async def get_manager(request: Request) -> LoginManager:
if request.app.state.config.general.unauthenticated_access is not None:
return await request.app.state.manager.optional(request)
token = await request.app.state.manager._get_token(request) # pylint: disable-msg=protected-access
db = request.app.state.SessionLocal()
if db.query(BlockedToken).filter(BlockedToken.token == token).count() > 0:
raise NotAuthenticatedException
return await request.app.state.manager(request)
@ -35,6 +49,33 @@ async def verify_token(
return token
async def good_user_credentials(
config: Config, request: Request, username: str, password: str
):
if config.general.ldap is not None:
return await good_ldap_user_credentials(
config, request.app.state.ldap, username, password
)
return await good_internal_user_credentials(
request.app.state.SessionLocal(), username, password
)
async def good_ldap_user_credentials(
config: Config, ldapobj: ReconnectLDAPObject, username: str, password: str
) -> bool:
ldap_dn = await find_ldap_user(config, ldapobj, username)
if ldap_dn is None:
return False
try:
ldapobj.simple_bind_s(ldap_dn, password)
except INVALID_CREDENTIALS:
return False
return True
async def find_ldap_user(config, ldapobj, user: str) -> str | None:
"""Do a LDAP search for user and return its dn"""
import ldap
@ -65,3 +106,32 @@ async def find_ldap_user(config, ldapobj, user: str) -> str | None:
return result[0][0]
return None
async def good_internal_user_credentials(db, username, password) -> bool:
user = await get_user(db, username)
if user is None:
return False
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
if not pwd_context.verify(password, user.password):
return False
user.last_login_at = datetime.now()
db.commit()
return True
async def create_user_token(
manager, config: General, username: str, rememberme: str | None = None
):
session_duration = config.session_duration
if config.remember_me_duration is not None and rememberme == "on":
session_duration = config.remember_me_duration
delta = timedelta(minutes=session_duration)
return {
"token": manager.create_access_token(data={"sub": username}, expires=delta),
"delta": delta,
}

View file

@ -1,6 +1,6 @@
"""Web interface for humans"""
from collections import defaultdict
from datetime import datetime, timedelta
from datetime import timedelta
from functools import cmp_to_key
from pathlib import Path
from typing import Annotated
@ -10,7 +10,6 @@ from fastapi import APIRouter, Cookie, Depends, Form, Request, status
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from passlib.context import CryptContext
from sqlalchemy import func
from sqlalchemy.orm import Session
@ -19,7 +18,13 @@ from argos.schemas import Config
from argos.server import queries
from argos.server.exceptions import NotAuthenticatedException
from argos.server.models import Result, Task, User
from argos.server.routes.dependencies import get_config, get_db, get_manager
from argos.server.routes.dependencies import (
create_user_token,
get_config,
get_db,
get_manager,
good_user_credentials,
)
route = APIRouter()
@ -52,6 +57,8 @@ async def login_view(
if msg == "logout":
msg = "You have been successfully disconnected."
elif msg == "not-authenticated":
msg = "You are not authenticated or your token has expired"
else:
msg = None
@ -79,61 +86,38 @@ async def post_login(
status_code=status.HTTP_303_SEE_OTHER,
)
username = data.username
invalid_credentials = templates.TemplateResponse(
"login.html",
{"request": request, "msg": "Sorry, invalid username or bad password."},
good_credentials = await good_user_credentials(
config, request, data.username, data.password
)
if config.general.ldap is not None:
from ldap import INVALID_CREDENTIALS # pylint: disable-msg=no-name-in-module
from argos.server.routes.dependencies import find_ldap_user
invalid_credentials = templates.TemplateResponse(
if not good_credentials:
return templates.TemplateResponse(
"login.html",
{
"request": request,
"msg": "Sorry, invalid username or bad password. "
"Or the LDAP server is unreachable (see logs to verify).",
},
)
elif not good_credentials:
return templates.TemplateResponse(
"login.html",
{
"request": request,
"msg": "Sorry, invalid username or bad password. "
"Or the LDAP server is unreachable (see logs to verify).",
},
{"request": request, "msg": "Sorry, invalid username or bad password."},
)
ldap_dn = await find_ldap_user(config, request.app.state.ldap, username)
if ldap_dn is None:
return invalid_credentials
try:
request.app.state.ldap.simple_bind_s(ldap_dn, data.password)
except INVALID_CREDENTIALS:
return invalid_credentials
else:
user = await queries.get_user(db, username)
if user is None:
return invalid_credentials
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
if not pwd_context.verify(data.password, user.password):
return invalid_credentials
user.last_login_at = datetime.now()
db.commit()
manager = request.app.state.manager
session_duration = config.general.session_duration
if config.general.remember_me_duration is not None and rememberme == "on":
session_duration = config.general.remember_me_duration
delta = timedelta(minutes=session_duration)
token = manager.create_access_token(data={"sub": username}, expires=delta)
token = await create_user_token(manager, config.general, data.username, rememberme)
response = RedirectResponse(
request.url_for("get_severity_counts_view"),
status_code=status.HTTP_303_SEE_OTHER,
)
response.set_cookie(
key=manager.cookie_name,
value=token,
value=token["token"],
httponly=True,
samesite="strict",
expires=int(delta.total_seconds()),
expires=int(token["delta"].total_seconds()),
)
return response
@ -143,6 +127,7 @@ async def logout_view(
request: Request,
config: Config = Depends(get_config),
user: User | None = Depends(get_manager),
db: Session = Depends(get_db),
):
if config.general.unauthenticated_access == "all":
return RedirectResponse(
@ -150,6 +135,8 @@ async def logout_view(
status_code=status.HTTP_303_SEE_OTHER,
)
await queries.block_token(db, request)
response = RedirectResponse(
request.url_for("login_view").include_query_params(msg="logout"),
status_code=status.HTTP_303_SEE_OTHER,