diff --git a/argos/server/models.py b/argos/server/models.py index 451eb6b..0637e7e 100644 --- a/argos/server/models.py +++ b/argos/server/models.py @@ -93,7 +93,8 @@ class Result(Base): __tablename__ = "results" id: Mapped[int] = mapped_column(primary_key=True) task_id: Mapped[int] = mapped_column(ForeignKey("tasks.id")) - task: Mapped["Task"] = relationship(back_populates="results") + task: Mapped["Task"] = relationship(back_populates="results", + cascade="save-update, merge, delete") agent_id: Mapped[str] = mapped_column(nullable=True) submitted_at: Mapped[datetime] = mapped_column() @@ -112,3 +113,13 @@ class Result(Base): def __str__(self): return f"DB Result {self.id} - {self.status} - {self.context}" + +class ConfigCache(Base): + """Contains some informations on the previous config state + + Used to quickly determine if we need to update the tasks + """ + __tablename__ = "config_cache" + name: Mapped[str] = mapped_column(primary_key=True) + val: Mapped[str] = mapped_column() + updated_at: Mapped[datetime] = mapped_column() diff --git a/argos/server/queries.py b/argos/server/queries.py index 271b2c9..345f480 100644 --- a/argos/server/queries.py +++ b/argos/server/queries.py @@ -1,5 +1,7 @@ """Functions to ease SQL queries management""" from datetime import datetime, timedelta +from hashlib import sha256 +from typing import List from urllib.parse import urljoin from sqlalchemy import desc, func @@ -7,7 +9,7 @@ from sqlalchemy.orm import Session from argos import schemas from argos.logging import logger -from argos.server.models import Result, Task +from argos.server.models import Result, Task, ConfigCache async def list_tasks(db: Session, agent_id: str, limit: int = 100): @@ -60,10 +62,64 @@ async def count_results(db: Session): return db.query(Result).count() +async def is_config_unchanged(db: Session, config: schemas.Config) -> bool: + """Check if websites config has changed by using a hashsum and a config cache""" + websites_hash = sha256(str(config.websites).encode()).hexdigest() + conf_caches = ( + db.query(ConfigCache) + .all() + ) + same_config = True + if conf_caches: + for conf in conf_caches: + if not same_config: + break + + if conf.name == 'websites_hash': + same_config = conf.val == websites_hash + elif conf.name == 'general_frequency': + same_config = conf.val == str(config.general.frequency) + + if same_config: + return True + + for conf in conf_caches: + if conf.name == 'websites_hash': + conf.val = websites_hash + elif conf.name == 'general_frequency': + conf.val = config.general.frequency + conf.updated_at = datetime.now() + else: + web_hash = ConfigCache( + name='websites_hash', + val=websites_hash, + updated_at=datetime.now() + ) + gen_freq = ConfigCache( + name='general_frequency', + val=str(config.general.frequency), + updated_at=datetime.now() + ) + db.add(web_hash) + db.add(gen_freq) + db.commit() + + return False + + async def update_from_config(db: Session, config: schemas.Config): """Update tasks from config file""" + config_unchanged = await is_config_unchanged(db, config) + if config_unchanged: + return None + + max_task_id = ( + db.query(func.max(Task.id).label('max_id')) # pylint: disable-msg=not-callable + .all() + )[0].max_id tasks = [] unique_properties = [] + seen_tasks: List[int] = [] for website in config.websites: domain = str(website.domain) frequency = website.frequency or config.general.frequency @@ -83,6 +139,7 @@ async def update_from_config(db: Session, config: schemas.Config): ) if existing_tasks: existing_task = existing_tasks[0] + seen_tasks.append(existing_task.id) if frequency != existing_task.frequency: existing_task.frequency = frequency @@ -107,6 +164,18 @@ async def update_from_config(db: Session, config: schemas.Config): db.add_all(tasks) db.commit() + # Delete vanished tasks + if max_task_id: + vanished_tasks = ( + db.query(Task) + .filter( + Task.id <= max_task_id, + Task.id.not_in(seen_tasks) + ).delete() + ) + db.commit() + logger.info("%i tasks has been removed since not in config file anymore", vanished_tasks) + async def get_severity_counts(db: Session) -> dict: """Get the severities (ok, warning, critical…) and their count""" diff --git a/tests/test_queries.py b/tests/test_queries.py index 8770d98..d92c750 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -78,7 +78,7 @@ async def test_update_from_config_with_duplicate_tasks(db, empty_config): @pytest.mark.asyncio -async def test_update_from_config_db_can_handle_already_present_duplicates( +async def test_update_from_config_db_can_remove_duplicates_and_old_tasks( db, empty_config, task ): # Add a duplicate in the db @@ -99,12 +99,28 @@ async def test_update_from_config_db_can_handle_already_present_duplicates( dict( path="https://another-example.com", checks=[{task.check: task.expected}] ), + dict( + path=task.url, checks=[{task.check: task.expected}] + ), ], ) empty_config.websites = [website] await queries.update_from_config(db, empty_config) - assert db.query(Task).count() == 3 + assert db.query(Task).count() == 2 + + website = schemas.config.Website( + domain=task.domain, + paths=[ + dict( + path="https://another-example.com", checks=[{task.check: task.expected}] + ), + ], + ) + empty_config.websites = [website] + + await queries.update_from_config(db, empty_config) + assert db.query(Task).count() == 1 @pytest.mark.asyncio