refactor: add a coroutine decorator for the click commands

This commit is contained in:
Alexis Métaireau 2023-12-16 22:53:39 +01:00
parent 08c3d8fc20
commit a8d86ea525

View file

@ -1,5 +1,6 @@
import asyncio
import os
from functools import wraps
import click
import uvicorn
@ -8,16 +9,22 @@ from argos import logging
from argos.agent import ArgosAgent
def validate_max_lock_seconds(ctx, param, value):
if value <= 60:
raise click.BadParameter("Should be strictly higher than 60")
return value
async def get_db():
from argos.server.main import connect_to_db, get_application, setup_database
app = get_application()
setup_database(app)
return await connect_to_db(app)
def validate_max_results(ctx, param, value):
if value <= 0:
raise click.BadParameter("Should be a positive integer")
return value
def coroutine(f):
"""Decorator to enable async functions in click"""
@wraps(f)
def wrapper(*args, **kwargs):
return asyncio.run(f(*args, **kwargs))
return wrapper
@click.group()
@ -79,6 +86,12 @@ def start(host, port, config, reload):
uvicorn.run("argos.server:app", host=host, port=port, reload=reload)
def validate_max_lock_seconds(ctx, param, value):
if value <= 60:
raise click.BadParameter("Should be strictly higher than 60")
return value
@server.command()
@click.option("--max-results", default=100, help="Number of results per task to keep")
@click.option(
@ -88,7 +101,8 @@ def start(host, port, config, reload):
"(the checks have a timeout value of 60 seconds)",
callback=validate_max_lock_seconds,
)
def cleandb(max_results, max_lock_seconds):
@coroutine
async def cleandb(max_results, max_lock_seconds):
"""Clean the database (to run routinely)
\b
@ -97,19 +111,13 @@ def cleandb(max_results, max_lock_seconds):
"""
# The imports are made here otherwise the agent will need server configuration files.
from argos.server import queries
from argos.server.main import connect_to_db, get_application, setup_database
async def clean_old_results():
app = get_application()
setup_database(app)
db = await connect_to_db(app)
removed = await queries.remove_old_results(db, max_results)
updated = await queries.release_old_locks(db, max_lock_seconds)
db = await get_db()
removed = await queries.remove_old_results(db, max_results)
updated = await queries.release_old_locks(db, max_lock_seconds)
click.echo(f"{removed} results removed")
click.echo(f"{updated} locks released")
asyncio.run(clean_old_results())
click.echo(f"{removed} results removed")
click.echo(f"{updated} locks released")
if __name__ == "__main__":