import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi_login import LoginManager
from pydantic import ValidationError
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker

from argos_monitoring.logging import logger
from argos_monitoring.server import models, routes, queries
from argos_monitoring.server.exceptions import (
    NotAuthenticatedException,
    auth_exception_handler,
)
from argos_monitoring.server.settings import read_yaml_config


def get_application() -> FastAPI:
    """Spawn Argos FastAPI server.

    Reads the YAML configuration file named by the ``ARGOS_YAML_FILE``
    environment variable, stores it on ``state.config``, installs the login
    manager, registers the API and view routers and mounts the ``static``
    directory located next to this module.

    :return: the configured FastAPI application
    :raises KeyError: if ``ARGOS_YAML_FILE`` is not set
    """
    appli = FastAPI(lifespan=lifespan)

    config_file = os.environ["ARGOS_YAML_FILE"]
    config = read_config(config_file)

    # Config is the argos config object (built from yaml)
    appli.state.config = config
    appli.add_exception_handler(NotAuthenticatedException, auth_exception_handler)
    appli.state.manager = create_manager(config.general.cookie_secret)

    @appli.state.manager.user_loader()
    async def query_user(user: str) -> None | models.User:
        """
        Get a user from the db

        :param user: name of the user
        :return: None or the user object
        """
        # state.db is only set once the lifespan has run (see connect_to_db)
        return await queries.get_user(appli.state.db, user)

    appli.include_router(routes.api, prefix="/api")
    appli.include_router(routes.views)

    static_dir = Path(__file__).resolve().parent / "static"
    appli.mount("/static", StaticFiles(directory=static_dir), name="static")

    return appli


async def connect_to_db(appli):
    """Open a database session, store it on ``appli.state.db`` and return it.

    Requires ``setup_database`` to have populated ``appli.state.SessionLocal``.
    """
    appli.state.db = appli.state.SessionLocal()
    return appli.state.db


def read_config(yaml_file):
    """Read and validate the Argos YAML configuration.

    :param yaml_file: path to the YAML configuration file
    :return: the configuration object built by ``read_yaml_config``

    On a pydantic ``ValidationError`` every error is logged and the process
    exits with status 1 (``SystemExit``).
    """
    try:
        return read_yaml_config(yaml_file)
    except ValidationError as err:
        logger.error("Errors were found while reading configuration:")
        for error in err.errors():
            logger.error("%s is %s", error["loc"], error["type"])
        sys.exit(1)


def setup_database(appli):
    """Create the SQLAlchemy engine and session factory and create the tables.

    Reads the database settings from ``appli.state.config`` and sets
    ``appli.state.engine`` and ``appli.state.SessionLocal``.
    """
    config = appli.state.config
    db_url = str(config.general.db.url)
    logger.debug("Using database URL %s", db_url)

    if config.general.env == "production" and db_url.startswith("sqlite:////tmp"):
        logger.warning("Using sqlite in /tmp is not recommended for production")

    is_sqlite = db_url.startswith("sqlite:///")

    extra_settings = {}
    if config.general.db.pool_size:
        extra_settings["pool_size"] = config.general.db.pool_size
    if config.general.db.max_overflow:
        extra_settings["max_overflow"] = config.general.db.max_overflow
    if is_sqlite:
        # A single session (state.db) is shared by the whole app and may be
        # used from worker threads, so sqlite's same-thread check must be off.
        # Recent SQLAlchemy versions already default to this for file
        # databases; passing it explicitly keeps older versions working too.
        extra_settings["connect_args"] = {"check_same_thread": False}

    engine = create_engine(db_url, **extra_settings)

    def _fk_pragma_on_connect(dbapi_con, con_record):
        # sqlite does not enforce foreign keys unless asked, per connection
        dbapi_con.execute("pragma foreign_keys=ON")

    if is_sqlite:
        event.listen(engine, "connect", _fk_pragma_on_connect)

    appli.state.SessionLocal = sessionmaker(
        autocommit=False, autoflush=False, bind=engine
    )
    appli.state.engine = engine
    models.Base.metadata.create_all(bind=engine)


def create_manager(cookie_secret):
    """Build the cookie-based ``LoginManager`` used to authenticate users.

    Logs a warning when the secret is still ``"foo_bar_baz"`` (presumably the
    placeholder value from the example configuration — confirm).

    :param cookie_secret: secret used by the login manager
    :return: a ``LoginManager`` that reads cookies only (no header) and
        raises ``NotAuthenticatedException`` when the user is not logged in
    """
    if cookie_secret == "foo_bar_baz":
        logger.warning(
            "You should change the cookie_secret secret in your configuration file."
        )
    return LoginManager(
        cookie_secret,
        "/login",
        use_cookie=True,
        use_header=False,
        not_authenticated_exception=NotAuthenticatedException,
    )


@asynccontextmanager
async def lifespan(appli):
    """Server start and stop actions

    Setup database connection then close it at shutdown. Cleanup runs in a
    ``finally`` clause so that the session is closed and the engine's
    connection pool disposed even when startup or serving fails.
    """
    setup_database(appli)
    db = await connect_to_db(appli)
    try:
        tasks_count = await queries.count_tasks(db)
        if tasks_count == 0:
            logger.warning(
                "There is no tasks in the database. "
                'Please launch the command "argos-monitoring server reload-config"'
            )
        yield
    finally:
        appli.state.db.close()
        appli.state.engine.dispose()


app = get_application()