refactor: add a coroutine decorator for the click commands

This commit is contained in:
Alexis Métaireau 2023-12-16 22:53:39 +01:00
parent 08c3d8fc20
commit a8d86ea525

View file

@ -1,5 +1,6 @@
import asyncio import asyncio
import os import os
from functools import wraps
import click import click
import uvicorn import uvicorn
@ -8,16 +9,22 @@ from argos import logging
from argos.agent import ArgosAgent from argos.agent import ArgosAgent
def validate_max_lock_seconds(ctx, param, value): async def get_db():
if value <= 60: from argos.server.main import connect_to_db, get_application, setup_database
raise click.BadParameter("Should be strictly higher than 60")
return value app = get_application()
setup_database(app)
return await connect_to_db(app)
def validate_max_results(ctx, param, value): def coroutine(f):
if value <= 0: """Decorator to enable async functions in click"""
raise click.BadParameter("Should be a positive integer")
return value @wraps(f)
def wrapper(*args, **kwargs):
return asyncio.run(f(*args, **kwargs))
return wrapper
@click.group() @click.group()
@ -79,6 +86,12 @@ def start(host, port, config, reload):
uvicorn.run("argos.server:app", host=host, port=port, reload=reload) uvicorn.run("argos.server:app", host=host, port=port, reload=reload)
def validate_max_lock_seconds(ctx, param, value):
if value <= 60:
raise click.BadParameter("Should be strictly higher than 60")
return value
@server.command() @server.command()
@click.option("--max-results", default=100, help="Number of results per task to keep") @click.option("--max-results", default=100, help="Number of results per task to keep")
@click.option( @click.option(
@ -88,7 +101,8 @@ def start(host, port, config, reload):
"(the checks have a timeout value of 60 seconds)", "(the checks have a timeout value of 60 seconds)",
callback=validate_max_lock_seconds, callback=validate_max_lock_seconds,
) )
def cleandb(max_results, max_lock_seconds): @coroutine
async def cleandb(max_results, max_lock_seconds):
"""Clean the database (to run routinely) """Clean the database (to run routinely)
\b \b
@ -97,19 +111,13 @@ def cleandb(max_results, max_lock_seconds):
""" """
# The imports are made here otherwise the agent will need server configuration files. # The imports are made here otherwise the agent will need server configuration files.
from argos.server import queries from argos.server import queries
from argos.server.main import connect_to_db, get_application, setup_database
async def clean_old_results(): db = await get_db()
app = get_application() removed = await queries.remove_old_results(db, max_results)
setup_database(app) updated = await queries.release_old_locks(db, max_lock_seconds)
db = await connect_to_db(app)
removed = await queries.remove_old_results(db, max_results)
updated = await queries.release_old_locks(db, max_lock_seconds)
click.echo(f"{removed} results removed") click.echo(f"{removed} results removed")
click.echo(f"{updated} locks released") click.echo(f"{updated} locks released")
asyncio.run(clean_old_results())
if __name__ == "__main__": if __name__ == "__main__":