🎨 — Ruff formatting

Luc Didry 2024-03-25 13:19:07 +01:00
parent 6f856afe4a
commit 6f93eeee49
No known key found for this signature in database
GPG key ID: EA868E12D0257E3C
19 changed files with 221 additions and 202 deletions

View file

@@ -1,10 +1,12 @@
 import re
 import sys
+
+
 def fix_output(matchobj):
-    return f'{matchobj.group(1)}{float(matchobj.group(2)) * 10}/{int(matchobj.group(3)) * 10}'
-pattern = re.compile(r'(Your code has been rated at )([0-9.]+)/(10)')
+    return f"{matchobj.group(1)}{float(matchobj.group(2)) * 10}/{int(matchobj.group(3)) * 10}"
+
+pattern = re.compile(r"(Your code has been rated at )([0-9.]+)/(10)")
 
 for line in sys.stdin:
     line.rstrip()
-    print(re.sub(pattern, fix_output, line), end='')
+    print(re.sub(pattern, fix_output, line), end="")
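For context, this script rescales pylint's "rated at X/10" summary line to a /100 scale. A quick sanity check of the reformatted code, runnable as-is:

import re

def fix_output(matchobj):
    return f"{matchobj.group(1)}{float(matchobj.group(2)) * 10}/{int(matchobj.group(3)) * 10}"

pattern = re.compile(r"(Your code has been rated at )([0-9.]+)/(10)")

# "9.50/10" becomes "95.0/100"
print(re.sub(pattern, fix_output, "Your code has been rated at 9.50/10"))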

View file

@@ -51,7 +51,8 @@ def run_migrations_online() -> None:
     )
 
     with connectable.connect() as connection:
-        context.configure(connection=connection,
+        context.configure(
+            connection=connection,
             target_metadata=target_metadata,
             render_as_batch=True,
         )

View file

@@ -12,20 +12,21 @@ from alembic import op
 
 # revision identifiers, used by Alembic.
-revision: str = '1a3497f9f71b'
-down_revision: Union[str, None] = 'e99bc35702c9'
+revision: str = "1a3497f9f71b"
+down_revision: Union[str, None] = "e99bc35702c9"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
 
 
 def upgrade() -> None:
-    op.create_table('config_cache',
-        sa.Column('name', sa.String(), nullable=False),
-        sa.Column('val', sa.String(), nullable=False),
-        sa.Column('updated_at', sa.DateTime(), nullable=False),
-        sa.PrimaryKeyConstraint('name')
+    op.create_table(
+        "config_cache",
+        sa.Column("name", sa.String(), nullable=False),
+        sa.Column("val", sa.String(), nullable=False),
+        sa.Column("updated_at", sa.DateTime(), nullable=False),
+        sa.PrimaryKeyConstraint("name"),
     )
 
 
 def downgrade() -> None:
-    op.drop_table('config_cache')
+    op.drop_table("config_cache")

View file

@@ -11,23 +11,23 @@ from alembic import op
 
 # revision identifiers, used by Alembic.
-revision: str = 'defda3f2952d'
-down_revision: Union[str, None] = '1a3497f9f71b'
+revision: str = "defda3f2952d"
+down_revision: Union[str, None] = "1a3497f9f71b"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
 
 
 def upgrade() -> None:
-    with op.batch_alter_table('results', schema=None) as batch_op:
-        batch_op.drop_constraint('results_task_id_fkey', type_='foreignkey')
-        batch_op.create_foreign_key('results_task_id_fkey',
-                                    'tasks',
-                                    ['task_id'],
-                                    ['id'],
-                                    ondelete='CASCADE')
+    with op.batch_alter_table("results", schema=None) as batch_op:
+        batch_op.drop_constraint("results_task_id_fkey", type_="foreignkey")
+        batch_op.create_foreign_key(
+            "results_task_id_fkey", "tasks", ["task_id"], ["id"], ondelete="CASCADE"
+        )
 
 
 def downgrade() -> None:
-    with op.batch_alter_table('results', schema=None) as batch_op:
-        batch_op.drop_constraint('results_task_id_fkey', type_='foreignkey')
-        batch_op.create_foreign_key('results_task_id_fkey', 'tasks', ['task_id'], ['id'])
+    with op.batch_alter_table("results", schema=None) as batch_op:
+        batch_op.drop_constraint("results_task_id_fkey", type_="foreignkey")
+        batch_op.create_foreign_key(
+            "results_task_id_fkey", "tasks", ["task_id"], ["id"]
+        )

View file

@@ -12,26 +12,27 @@ from alembic import op
 
 # revision identifiers, used by Alembic.
-revision: str = 'e99bc35702c9'
-down_revision: Union[str, None] = '7d480e6f1112'
+revision: str = "e99bc35702c9"
+down_revision: Union[str, None] = "7d480e6f1112"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
 
 
 def upgrade() -> None:
     op.execute("ALTER TYPE severity ADD VALUE 'unknown'")
-    op.add_column('tasks',
-        sa.Column('severity',
-            sa.Enum(
-                'ok',
-                'warning',
-                'critical',
-                'unknown',
-                name='severity'),
-            nullable=False))
-    op.add_column('tasks', sa.Column('last_severity_update', sa.DateTime(), nullable=True))
+    op.add_column(
+        "tasks",
+        sa.Column(
+            "severity",
+            sa.Enum("ok", "warning", "critical", "unknown", name="severity"),
+            nullable=False,
+        ),
+    )
+    op.add_column(
+        "tasks", sa.Column("last_severity_update", sa.DateTime(), nullable=True)
+    )
 
 
 def downgrade() -> None:
-    op.drop_column('tasks', 'last_severity_update')
-    op.drop_column('tasks', 'severity')
+    op.drop_column("tasks", "last_severity_update")
+    op.drop_column("tasks", "severity")

View file

@@ -20,7 +20,9 @@ class HTTPStatus(BaseCheck):
     async def run(self) -> dict:
         # XXX Get the method from the task
        task = self.task
-        response = await self.http_client.request(method="get", url=task.url, timeout=60)
+        response = await self.http_client.request(
+            method="get", url=task.url, timeout=60
+        )
 
         return self.response(
             status=response.status_code == self.expected,
@@ -36,7 +38,9 @@ class HTTPBodyContains(BaseCheck):
     expected_cls = ExpectedStringValue
 
     async def run(self) -> dict:
-        response = await self.http_client.request(method="get", url=self.task.url, timeout=60)
+        response = await self.http_client.request(
+            method="get", url=self.task.url, timeout=60
+        )
 
         return self.response(status=self.expected in response.text)

View file

@@ -136,8 +136,7 @@ async def cleandb(max_results, max_lock_seconds):
 @server.command()
 @coroutine
 async def reload_config():
-    """Read tasks config and add/delete tasks in database if needed
-    """
+    """Read tasks config and add/delete tasks in database if needed"""
     # The imports are made here otherwise the agent will need server configuration files.
     from argos.server import queries
     from argos.server.main import get_application, read_config
@@ -157,8 +156,7 @@ async def reload_config():
 @server.command()
 @coroutine
 async def migrate():
-    """Run database migrations
-    """
+    """Run database migrations"""
     # The imports are made here otherwise the agent will need server configuration files.
     from argos.server.settings import get_app_settings

View file

@@ -103,14 +103,16 @@ class Service(BaseModel):
 
 class MailAuth(BaseModel):
     """Mail authentication configuration"""
+
     login: str
     password: str
 
 
 class Mail(BaseModel):
     """Mail configuration"""
+
     mailfrom: EmailStr
-    host: str = '127.0.0.1'
+    host: str = "127.0.0.1"
     port: PositiveInt = 25
     ssl: StrictBool = False
     starttls: StrictBool = False

View file

@@ -13,6 +13,7 @@ from pydantic import BaseModel, ConfigDict
 
 class Task(BaseModel):
     """A task corresponds to a check to execute"""
+
     id: int
     url: str
     domain: str
@@ -32,6 +33,7 @@ class Task(BaseModel):
 
 class SerializableException(BaseModel):
     """Task exception"""
+
     error_message: str
     error_type: str
     error_details: str
@@ -47,6 +49,7 @@ class SerializableException(BaseModel):
 
 class AgentResult(BaseModel):
     """Tasks result sent by agent"""
+
     task_id: int
     # The on-check status means that the service needs to finish the check
     # and will then determine the severity.

View file

@@ -1,7 +1,9 @@
 from typing import Literal, Union
 
-def string_to_duration(value: str, target: Literal["days", "hours", "minutes"]) -> Union[int,float]:
+
+def string_to_duration(
+    value: str, target: Literal["days", "hours", "minutes"]
+) -> Union[int, float]:
     """Convert a string to a number of hours, days or minutes"""
     num = int("".join(filter(str.isdigit, value)))
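As a quick illustration of the digit-extraction line above (the rest of the function lies outside this hunk), runnable on its own:

# Keeps only the digit characters of the input, then parses them as an int
print(int("".join(filter(str.isdigit, "3d"))))     # 3
print(int("".join(filter(str.isdigit, "45min"))))  # 45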

View file

@@ -13,26 +13,37 @@ from argos.schemas.config import Config, Mail, GotifyUrl
 
 # XXX Implement mail alerts https://framagit.org/framasoft/framaspace/argos/-/issues/15
 # XXX Implement gotify alerts https://framagit.org/framasoft/framaspace/argos/-/issues/16
 def handle_alert(config: Config, result, task, severity, old_severity, request):
     """Dispatch alert through configured alert channels"""
-    if 'local' in getattr(config.general.alerts, severity):
-        logger.error("Alerting stub: task=%i, status=%s, severity=%s",
-                     task.id,
-                     result.status,
-                     severity)
+    if "local" in getattr(config.general.alerts, severity):
+        logger.error(
+            "Alerting stub: task=%i, status=%s, severity=%s",
+            task.id,
+            result.status,
+            severity,
+        )
 
-    if config.general.mail is not None and \
-            'mail' in getattr(config.general.alerts, severity):
-        notify_by_mail(result, task, severity, old_severity, config.general.mail, request)
+    if config.general.mail is not None and "mail" in getattr(
+        config.general.alerts, severity
+    ):
+        notify_by_mail(
+            result, task, severity, old_severity, config.general.mail, request
+        )
 
-    if config.general.gotify is not None and \
-            'gotify' in getattr(config.general.alerts, severity):
-        notify_with_gotify(result, task, severity, old_severity, config.general.gotify, request)
+    if config.general.gotify is not None and "gotify" in getattr(
+        config.general.alerts, severity
+    ):
+        notify_with_gotify(
+            result, task, severity, old_severity, config.general.gotify, request
+        )
 
 
-def notify_by_mail(result, task, severity: str, old_severity: str, config: Mail, request) -> None:
-    logger.debug('Will send mail notification')
+def notify_by_mail(
+    result, task, severity: str, old_severity: str, config: Mail, request
+) -> None:
+    logger.debug("Will send mail notification")
 
     msg = f"""\
 URL: {task.url}
@@ -50,50 +61,43 @@ Subject: [Argos] {urlparse(task.url).netloc}: status {severity}
 
 {msg}"""
 
     if config.ssl:
-        logger.debug('Mail notification: SSL')
+        logger.debug("Mail notification: SSL")
         context = ssl.create_default_context()
-        smtp = smtplib.SMTP_SSL(host=config.host,
-                                port=config.port,
-                                context=context)
+        smtp = smtplib.SMTP_SSL(host=config.host, port=config.port, context=context)
     else:
-        smtp = smtplib.SMTP(host=config.host,  # type: ignore
-                            port=config.port)
+        smtp = smtplib.SMTP(
+            host=config.host,  # type: ignore
+            port=config.port,
+        )
         if config.starttls:
-            logger.debug('Mail notification: STARTTLS')
+            logger.debug("Mail notification: STARTTLS")
             context = ssl.create_default_context()
             smtp.starttls(context=context)
 
     if config.auth is not None:
-        logger.debug('Mail notification: authentification')
-        smtp.login(config.auth.login,
-                   config.auth.password)
+        logger.debug("Mail notification: authentification")
+        smtp.login(config.auth.login, config.auth.password)
 
     for address in config.addresses:
-        logger.debug('Sending mail to %s', address)
+        logger.debug("Sending mail to %s", address)
         logger.debug(msg)
         smtp.sendmail(config.mailfrom, address, mail)
 
 
 def notify_with_gotify(
-        result,
-        task,
-        severity: str,
-        old_severity: str,
-        config: List[GotifyUrl],
-        request
+    result, task, severity: str, old_severity: str, config: List[GotifyUrl], request
 ) -> None:
-    logger.debug('Will send gotify notification')
-    headers = {'accept': 'application/json',
-               'content-type': 'application/json'}
+    logger.debug("Will send gotify notification")
+    headers = {"accept": "application/json", "content-type": "application/json"}
 
     priority = 9
-    icon = ''
+    icon = ""
     if severity == Severity.OK:
         priority = 1
-        icon = ''
+        icon = ""
     elif severity == Severity.WARNING:
         priority = 5
-        icon = '⚠️'
+        icon = "⚠️"
 
     subject = f"{icon} {urlparse(task.url).netloc}: status {severity}"
     msg = f"""\
@@ -106,20 +110,22 @@ Previous status: {old_severity}
 
 See results of task on {request.url_for('get_task_results_view', task_id=task.id)}
 """
 
-    payload = {'title': subject,
-               'message': msg,
-               'priority': priority}
+    payload = {"title": subject, "message": msg, "priority": priority}
 
     for url in config:
-        logger.debug('Sending gotify message(s) to %s', url)
+        logger.debug("Sending gotify message(s) to %s", url)
         for token in url.tokens:
             try:
-                res = httpx.post(f"{url.url}message",
-                                 params={'token': token},
-                                 headers=headers,
-                                 json=payload)
+                res = httpx.post(
+                    f"{url.url}message",
+                    params={"token": token},
+                    headers=headers,
+                    json=payload,
+                )
                 res.raise_for_status()
             except httpx.RequestError as err:
-                logger.error('An error occurred while sending a message to %s with token %s',
-                             err.request.url,
-                             token)
+                logger.error(
+                    "An error occurred while sending a message to %s with token %s",
+                    err.request.url,
+                    token,
+                )
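A note on the dispatch logic in handle_alert above: each severity level on config.general.alerts holds the list of channels to notify, which the code probes with `in`. A hypothetical sketch (the attribute layout is inferred from the getattr calls; the real field names live in argos' config schema, not in this diff):

# Hypothetical alerts object, for illustration only
class _Alerts:
    ok = ["local"]
    warning = ["local", "mail"]
    critical = ["local", "mail", "gotify"]

severity = "critical"
channels = getattr(_Alerts, severity)
print("mail" in channels, "gotify" in channels)  # True True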

View file

@@ -41,6 +41,7 @@ def create_start_app_handler(appli):
     """Warmup the server:
     setup database connection
     """
+
     async def _get_db():
         setup_database(appli)
 
@@ -58,6 +59,7 @@ def create_stop_app_handler(appli):
     """Gracefully shutdown the server:
     close database connection.
     """
+
     async def stop_app():
         appli.state.db.close()
 
@@ -72,7 +74,7 @@ def read_config(appli, settings):
     except ValidationError as err:
         logger.error("Errors where found while reading configuration:")
         for error in err.errors():
-            logger.error("%s is %s", error['loc'], error['type'])
+            logger.error("%s is %s", error["loc"], error["type"])
         sys.exit(1)
 
@@ -90,16 +92,13 @@ def setup_database(appli):
     if settings.db_max_overflow:
         extra_settings.setdefault("max_overflow", settings.db_max_overflow)
 
-    engine = create_engine(
-        settings.database_url,
-        **extra_settings
-    )
+    engine = create_engine(settings.database_url, **extra_settings)
 
     def _fk_pragma_on_connect(dbapi_con, con_record):
-        dbapi_con.execute('pragma foreign_keys=ON')
+        dbapi_con.execute("pragma foreign_keys=ON")
 
     if settings.database_url.startswith("sqlite:////"):
-        event.listen(engine, 'connect', _fk_pragma_on_connect)
+        event.listen(engine, "connect", _fk_pragma_on_connect)
 
     appli.state.SessionLocal = sessionmaker(
         autocommit=False, autoflush=False, bind=engine
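Side note on _fk_pragma_on_connect above: SQLite keeps foreign-key enforcement off by default, and the pragma only lasts for the connection that issued it, hence the listener on the engine's "connect" event. A standalone illustration:

import sqlite3

con = sqlite3.connect(":memory:")
print(con.execute("PRAGMA foreign_keys").fetchone())  # (0,) off by default
con.execute("PRAGMA foreign_keys=ON")  # must be issued on every new connection
print(con.execute("PRAGMA foreign_keys").fetchone())  # (1,)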

View file

@@ -43,13 +43,15 @@ class Task(Base):
     severity: Mapped[Literal["ok", "warning", "critical", "unknown"]] = mapped_column(
         Enum("ok", "warning", "critical", "unknown", name="severity"),
-        insert_default="unknown"
+        insert_default="unknown",
     )
     last_severity_update: Mapped[datetime] = mapped_column(nullable=True)
 
-    results: Mapped[List["Result"]] = relationship(back_populates="task",
-                                                   cascade="all, delete",
-                                                   passive_deletes=True,)
+    results: Mapped[List["Result"]] = relationship(
+        back_populates="task",
+        cascade="all, delete",
+        passive_deletes=True,
+    )
 
     def __str__(self):
         return f"DB Task {self.url} - {self.check} - {self.expected}"
@@ -92,6 +94,7 @@ class Result(Base):
     The status is "Was the agent able to do the check?" while the severity
     depends on the return value of the check.
     """
+
     __tablename__ = "results"
     id: Mapped[int] = mapped_column(primary_key=True)
     task_id: Mapped[int] = mapped_column(ForeignKey("tasks.id", ondelete="CASCADE"))
@@ -115,6 +118,7 @@ class Result(Base):
     def __str__(self):
         return f"DB Result {self.id} - {self.status} - {self.context}"
 
+
 class ConfigCache(Base):
     """Contains some informations on the previous config state
 
@@ -126,6 +130,7 @@ class ConfigCache(Base):
     comparison without looping through all websites
     ex: 8b886e7db7b553fe99f6d5437f31745987e243c77b2109b84cf9a7f8bf7d75b1
     """
+
     __tablename__ = "config_cache"
     name: Mapped[str] = mapped_column(primary_key=True)
     val: Mapped[str] = mapped_column()

View file

@@ -65,20 +65,17 @@ async def count_results(db: Session):
 async def has_config_changed(db: Session, config: schemas.Config) -> bool:
     """Check if websites config has changed by using a hashsum and a config cache"""
     websites_hash = sha256(str(config.websites).encode()).hexdigest()
-    conf_caches = (
-        db.query(ConfigCache)
-        .all()
-    )
+    conf_caches = db.query(ConfigCache).all()
     same_config = True
     if conf_caches:
         for conf in conf_caches:
-            match (conf.name):
-                case 'websites_hash':
+            match conf.name:
+                case "websites_hash":
                     if conf.val != websites_hash:
                         same_config = False
                         conf.val = websites_hash
                         conf.updated_at = datetime.now()
-                case 'general_frequency':
+                case "general_frequency":
                     if conf.val != str(config.general.frequency):
                         same_config = False
                         conf.val = config.general.frequency
@@ -91,14 +88,12 @@ async def has_config_changed(db: Session, config: schemas.Config) -> bool:
     else:  # no config cache found
         web_hash = ConfigCache(
-            name='websites_hash',
-            val=websites_hash,
-            updated_at=datetime.now()
+            name="websites_hash", val=websites_hash, updated_at=datetime.now()
         )
         gen_freq = ConfigCache(
-            name='general_frequency',
+            name="general_frequency",
             val=str(config.general.frequency),
-            updated_at=datetime.now()
+            updated_at=datetime.now(),
         )
         db.add(web_hash)
         db.add(gen_freq)
@@ -111,11 +106,10 @@ async def update_from_config(db: Session, config: schemas.Config):
     """Update tasks from config file"""
     config_changed = await has_config_changed(db, config)
     if not config_changed:
-        return {'added': 0, 'vanished': 0}
+        return {"added": 0, "vanished": 0}
 
     max_task_id = (
-        db.query(func.max(Task.id).label('max_id'))  # pylint: disable-msg=not-callable
-        .all()
+        db.query(func.max(Task.id).label("max_id")).all()  # pylint: disable-msg=not-callable
     )[0].max_id
 
     tasks = []
     unique_properties = []
@@ -143,9 +137,14 @@ async def update_from_config(db: Session, config: schemas.Config):
                     if frequency != existing_task.frequency:
                         existing_task.frequency = frequency
 
-                    logger.debug("Skipping db task creation for url=%s, " \
-                                 "check_key=%s, expected=%s, frequency=%s.",
-                                 url, check_key, expected, frequency)
+                    logger.debug(
+                        "Skipping db task creation for url=%s, "
+                        "check_key=%s, expected=%s, frequency=%s.",
+                        url,
+                        check_key,
+                        expected,
+                        frequency,
+                    )
 
                 else:
                     properties = (url, check_key, expected)
@@ -168,23 +167,22 @@ async def update_from_config(db: Session, config: schemas.Config):
     if max_task_id:
         vanished_tasks = (
             db.query(Task)
-            .filter(
-                Task.id <= max_task_id,
-                Task.id.not_in(seen_tasks)
-            ).delete()
+            .filter(Task.id <= max_task_id, Task.id.not_in(seen_tasks))
+            .delete()
         )
         db.commit()
-        logger.info("%i tasks has been removed since not in config file anymore", vanished_tasks)
-        return {'added': len(tasks), 'vanished': vanished_tasks}
+        logger.info(
+            "%i tasks has been removed since not in config file anymore", vanished_tasks
+        )
+        return {"added": len(tasks), "vanished": vanished_tasks}
 
-    return {'added': len(tasks), 'vanished': 0}
+    return {"added": len(tasks), "vanished": 0}
 
 
 async def get_severity_counts(db: Session) -> dict:
     """Get the severities (ok, warning, critical…) and their count"""
-    query = (
-        db.query(Task.severity, func.count(Task.id).label("count"))  # pylint: disable-msg=not-callable
-        .group_by(Task.severity)
-    )
+    query = db.query(Task.severity, func.count(Task.id).label("count")).group_by(  # pylint: disable-msg=not-callable
+        Task.severity
+    )
 
     # Execute the query and fetch the results
@@ -198,9 +196,9 @@ async def get_severity_counts(db: Session) -> dict:
 async def reschedule_all(db: Session):
     """Reschedule checks of all non OK tasks ASAP"""
-    db.query(Task) \
-        .filter(Task.severity.in_(['warning', 'critical', 'unknown'])) \
-        .update({Task.next_run: datetime.now() - timedelta(days=1)})
+    db.query(Task).filter(Task.severity.in_(["warning", "critical", "unknown"])).update(
+        {Task.next_run: datetime.now() - timedelta(days=1)}
+    )
     db.commit()

View file

@@ -18,7 +18,7 @@ async def read_tasks(
     request: Request,
     db: Session = Depends(get_db),
     limit: int = 10,
-    agent_id: Union[None,str] = None,
+    agent_id: Union[None, str] = None,
 ):
     """Return a list of tasks to execute"""
     agent_id = agent_id or request.client.host
@@ -32,7 +32,7 @@ async def create_results(
     results: List[AgentResult],
     db: Session = Depends(get_db),
     config: Config = Depends(get_config),
-    agent_id: Union[None,str] = None,
+    agent_id: Union[None, str] = None,
 ):
     """Get the results from the agents and store them locally.
@@ -67,18 +67,15 @@ async def create_results(
     return {"result_ids": [r.id for r in db_results]}
 
 
-@route.post("/reschedule/all",
-            responses={
-                200: {
-                    "content": {
-                        "application/json": {
-                            "example": {
-                                "msg": "Non OK tasks reschuled"
-                            }
-                        }
-                    }
-                }
-            }
+@route.post(
+    "/reschedule/all",
+    responses={
+        200: {
+            "content": {
+                "application/json": {"example": {"msg": "Non OK tasks reschuled"}}
+            }
+        }
+    },
 )
 async def reschedule_all(request: Request, db: Session = Depends(get_db)):
     """Reschedule checks of all non OK tasks ASAP"""
@@ -86,20 +83,21 @@ async def reschedule_all(request: Request, db: Session = Depends(get_db)):
     return {"msg": "Non OK tasks reschuled"}
 
 
-@route.get("/stats",
-           responses={
-               200: {
-                   "content": {
-                       "application/json": {
-                           "example": {
-                               "upcoming_tasks_count":0,
-                               "results_count":1993085,
-                               "selected_tasks_count":1845
-                           }
-                       }
-                   }
-               }
-           }
+@route.get(
+    "/stats",
+    responses={
+        200: {
+            "content": {
+                "application/json": {
+                    "example": {
+                        "upcoming_tasks_count": 0,
+                        "results_count": 1993085,
+                        "selected_tasks_count": 1845,
+                    }
+                }
+            }
+        }
+    },
 )
 async def get_stats(db: Session = Depends(get_db)):
     """Get tasks statistics"""
@@ -110,18 +108,17 @@ async def get_stats(db: Session = Depends(get_db)):
     }
 
 
-@route.get("/severities",
-           responses={
-               200: {
-                   "content": {
-                       "application/json": {
-                           "example": {
-                               "ok":1541,"warning":0,"critical":0,"unknown":0
-                           }
-                       }
-                   }
-               }
-           }
+@route.get(
+    "/severities",
+    responses={
+        200: {
+            "content": {
+                "application/json": {
+                    "example": {"ok": 1541, "warning": 0, "critical": 0, "unknown": 0}
+                }
+            }
+        }
+    },
 )
 async def get_severity_counts(db: Session = Depends(get_db)):
     """Returns the number of results per severity"""

View file

@@ -16,12 +16,7 @@ from argos.server.routes.dependencies import get_config, get_db
 route = APIRouter()
 templates = Jinja2Templates(directory="argos/server/templates")
 
-SEVERITY_LEVELS = {
-    "ok": 1,
-    "warning": 2,
-    "critical": 3,
-    "unknown": 4
-}
+SEVERITY_LEVELS = {"ok": 1, "warning": 2, "critical": 3, "unknown": 4}
 
 
 @route.get("/")
@@ -29,7 +24,7 @@ async def get_severity_counts_view(
     request: Request,
     db: Session = Depends(get_db),
     refresh: bool = False,
-    delay: int = 15
+    delay: int = 15,
 ):
     """Shows the number of results per severity"""
     counts_dict = await queries.get_severity_counts(db)
@@ -62,7 +57,7 @@ async def get_domains_view(request: Request, db: Session = Depends(get_db)):
         if task.last_severity_update is not None:
             domains_last_checks[domain] = task.last_severity_update
         else:
-            domains_last_checks[domain] = 'Waiting to be checked'
+            domains_last_checks[domain] = "Waiting to be checked"
 
     def _max_severity(severities):
         return max(severities, key=SEVERITY_LEVELS.get)
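As a quick illustration, _max_severity above reduces a list of severities to the worst one by ranking through SEVERITY_LEVELS (both taken verbatim from this file):

SEVERITY_LEVELS = {"ok": 1, "warning": 2, "critical": 3, "unknown": 4}

def _max_severity(severities):
    return max(severities, key=SEVERITY_LEVELS.get)

print(_max_severity(["ok", "critical", "warning"]))  # critical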
@@ -100,14 +95,16 @@ async def get_domain_tasks_view(
     request: Request, domain: str, db: Session = Depends(get_db)
 ):
     """Show all tasks attached to a domain"""
-    tasks = db.query(Task).filter(Task.domain.contains(f'//{domain}')).all()
+    tasks = db.query(Task).filter(Task.domain.contains(f"//{domain}")).all()
     return templates.TemplateResponse(
         "domain.html", {"request": request, "domain": domain, "tasks": tasks}
     )
 
 
 @route.get("/result/{result_id}")
-async def get_result_view(request: Request, result_id: int, db: Session = Depends(get_db)):
+async def get_result_view(
+    request: Request, result_id: int, db: Session = Depends(get_db)
+):
     """Show the details of a result"""
     result = db.query(Result).get(result_id)
     return templates.TemplateResponse(
@@ -146,7 +143,7 @@ async def get_task_results_view(
 async def get_agents_view(request: Request, db: Session = Depends(get_db)):
     """Show argos agents and the last time the server saw them"""
     last_seen = (
-        db.query(Result.agent_id, func.max(Result.submitted_at).label('submitted_at'))
+        db.query(Result.agent_id, func.max(Result.submitted_at).label("submitted_at"))
         .group_by(Result.agent_id)
         .all()
     )

View file

@@ -2,7 +2,7 @@
 import os
 from functools import lru_cache
 from os import environ
-from typing import Optional,Union
+from typing import Optional, Union
 
 import yaml
 from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -25,6 +25,7 @@ class DevSettings(Settings):
     Uses config.yaml as config file.
     Uses a SQLite database."""
+
     app_env: str = "dev"
     yaml_file: str = "config.yaml"
     db_pool_size: Optional[int] = None
@@ -37,6 +38,7 @@ class TestSettings(Settings):
     Uses tests/config.yaml as config file.
     Uses a SQLite database."""
+
     app_env: str = "test"
     yaml_file: str = "tests/config.yaml"
     database_url: str = "sqlite:////tmp/test-argos.db"
@@ -46,6 +48,7 @@ class TestSettings(Settings):
 
 class ProdSettings(Settings):
     """Settings for prod environment."""
+
     app_env: str = "prod"
     db_pool_size: Optional[int] = 10
     db_max_overflow: Optional[int] = 20
@@ -59,7 +62,7 @@ environments = {
 
 
 @lru_cache()
-def get_app_settings() -> Union[None,Settings]:
+def get_app_settings() -> Union[None, Settings]:
     """Load settings depending on the environment"""
     app_env = environ.get("ARGOS_APP_ENV", "dev")
     settings = environments.get(app_env)
@@ -79,5 +82,5 @@ def _load_yaml(filename):
         loader_class=yaml.FullLoader, base_dir=base_dir
     )
 
-    with open(filename, "r", encoding='utf-8') as stream:
+    with open(filename, "r", encoding="utf-8") as stream:
         return yaml.load(stream, Loader=yaml.FullLoader)

View file

@@ -33,4 +33,4 @@ html_sidebars = {
 
 html_theme = "shibuya"
 html_static_path = ["_static"]
-html_css_files = ['fonts.css']
+html_css_files = ["fonts.css"]

View file

@@ -99,9 +99,7 @@ async def test_update_from_config_db_can_remove_duplicates_and_old_tasks(
             dict(
                 path="https://another-example.com", checks=[{task.check: task.expected}]
             ),
-            dict(
-                path=task.url, checks=[{task.check: task.expected}]
-            ),
+            dict(path=task.url, checks=[{task.check: task.expected}]),
         ],
     )
     empty_config.websites = [website]
@@ -140,7 +138,9 @@ async def test_update_from_config_db_updates_existing_tasks(db, empty_config, ta
 
 
 @pytest.mark.asyncio
-async def test_reschedule_all(db, ten_tasks, ten_warning_tasks, ten_critical_tasks, ten_ok_tasks):
+async def test_reschedule_all(
+    db, ten_tasks, ten_warning_tasks, ten_critical_tasks, ten_ok_tasks
+):
     assert db.query(Task).count() == 40
     assert db.query(Task).filter(Task.severity == "unknown").count() == 10
     assert db.query(Task).filter(Task.severity == "warning").count() == 10
@@ -260,7 +260,7 @@ def ten_warning_tasks(db):
             expected="foo",
             frequency=1,
             next_run=now,
-            severity="warning"
+            severity="warning",
         )
         db.add(task)
         tasks.append(task)
@@ -280,7 +280,7 @@ def ten_critical_tasks(db):
             expected="foo",
             frequency=1,
             next_run=now,
-            severity="critical"
+            severity="critical",
         )
         db.add(task)
         tasks.append(task)
@@ -300,7 +300,7 @@ def ten_ok_tasks(db):
             expected="foo",
             frequency=1,
            next_run=now,
-            severity="ok"
+            severity="ok",
         )
         db.add(task)
         tasks.append(task)