From a8d86ea5256fdcf50c3188fdde057d7d8c75f063 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20M=C3=A9taireau?= Date: Sat, 16 Dec 2023 22:53:39 +0100 Subject: [PATCH] refactor: add a coroutine decorator for the click commands --- argos/commands.py | 48 +++++++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/argos/commands.py b/argos/commands.py index 1e94071..23ae9f1 100644 --- a/argos/commands.py +++ b/argos/commands.py @@ -1,5 +1,6 @@ import asyncio import os +from functools import wraps import click import uvicorn @@ -8,16 +9,22 @@ from argos import logging from argos.agent import ArgosAgent -def validate_max_lock_seconds(ctx, param, value): - if value <= 60: - raise click.BadParameter("Should be strictly higher than 60") - return value +async def get_db(): + from argos.server.main import connect_to_db, get_application, setup_database + + app = get_application() + setup_database(app) + return await connect_to_db(app) -def validate_max_results(ctx, param, value): - if value <= 0: - raise click.BadParameter("Should be a positive integer") - return value +def coroutine(f): + """Decorator to enable async functions in click""" + + @wraps(f) + def wrapper(*args, **kwargs): + return asyncio.run(f(*args, **kwargs)) + + return wrapper @click.group() @@ -79,6 +86,12 @@ def start(host, port, config, reload): uvicorn.run("argos.server:app", host=host, port=port, reload=reload) +def validate_max_lock_seconds(ctx, param, value): + if value <= 60: + raise click.BadParameter("Should be strictly higher than 60") + return value + + @server.command() @click.option("--max-results", default=100, help="Number of results per task to keep") @click.option( @@ -88,7 +101,8 @@ def start(host, port, config, reload): "(the checks have a timeout value of 60 seconds)", callback=validate_max_lock_seconds, ) -def cleandb(max_results, max_lock_seconds): +@coroutine +async def cleandb(max_results, max_lock_seconds): """Clean the database (to run routinely) \b @@ -97,19 +111,13 @@ def cleandb(max_results, max_lock_seconds): """ # The imports are made here otherwise the agent will need server configuration files. from argos.server import queries - from argos.server.main import connect_to_db, get_application, setup_database - async def clean_old_results(): - app = get_application() - setup_database(app) - db = await connect_to_db(app) - removed = await queries.remove_old_results(db, max_results) - updated = await queries.release_old_locks(db, max_lock_seconds) + db = await get_db() + removed = await queries.remove_old_results(db, max_results) + updated = await queries.release_old_locks(db, max_lock_seconds) - click.echo(f"{removed} results removed") - click.echo(f"{updated} locks released") - - asyncio.run(clean_old_results()) + click.echo(f"{removed} results removed") + click.echo(f"{updated} locks released") if __name__ == "__main__":