diff --git a/argos/server/queries.py b/argos/server/queries.py index 266186d..b6cc453 100644 --- a/argos/server/queries.py +++ b/argos/server/queries.py @@ -60,6 +60,8 @@ async def count_results(db: Session): async def update_from_config(db: Session, config: schemas.Config): + tasks = [] + unique_properties = [] for website in config.websites: domain = str(website.domain) frequency = website.frequency or config.general.frequency @@ -68,31 +70,38 @@ async def update_from_config(db: Session, config: schemas.Config): url = urljoin(domain, str(p.path)) for check_key, expected in p.checks: # Check the db for already existing tasks. - existing_task = ( + existing_tasks = ( db.query(Task) .filter( Task.url == url, Task.check == check_key, Task.expected == expected, ) - .scalar() + .all() ) - if existing_task and frequency != existing_task.frequency: - existing_task.frequency = frequency + if existing_tasks: + existing_task = existing_tasks[0] - if not existing_task: - task = Task( - domain=domain, - url=url, - check=check_key, - expected=expected, - frequency=frequency, - ) - logger.debug(f"Adding a new task in the db: {task}") - db.add(task) - else: + if frequency != existing_task.frequency: + existing_task.frequency = frequency msg = f"Skipping db task creation for {url=}, {check_key=}, {expected=}, {frequency=}." logger.debug(msg) + + else: + properties = (url, check_key, expected) + if properties not in unique_properties: + unique_properties.append(properties) + task = Task( + domain=domain, + url=url, + check=check_key, + expected=expected, + frequency=frequency, + ) + logger.debug(f"Adding a new task in the db: {task}") + tasks.append(task) + + db.add_all(tasks) db.commit() diff --git a/tests/test_queries.py b/tests/test_queries.py index 1d9ccab..32f9c72 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -2,6 +2,7 @@ from datetime import datetime, timedelta import pytest +from argos import schemas from argos.server import queries from argos.server.models import Result, Task @@ -52,11 +53,81 @@ async def test_release_old_locks_with_empty_db(db): assert released == 0 +@pytest.mark.asyncio +async def test_update_from_config_with_duplicate_tasks(db, empty_config): + # We pass the same path twice + fake_path = dict(path="/", checks=[{"body-contains": "foo"}]) + website = schemas.config.Website( + domain="https://example.org", + paths=[ + fake_path, + fake_path, + ], + ) + empty_config.websites = [website] + + assert db.query(Task).count() == 0 + await queries.update_from_config(db, empty_config) + + # Only one path has been saved in the database + assert db.query(Task).count() == 1 + + # Calling again with the same data works, and will not result in more tasks being + # created. + await queries.update_from_config(db, empty_config) + + +@pytest.mark.asyncio +async def test_update_from_config_db_can_handle_already_present_duplicates( + db, empty_config, task +): + # Add a duplicate in the db + same_task = Task( + url=task.url, + domain=task.domain, + check=task.check, + expected=task.expected, + frequency=task.frequency, + ) + db.add(same_task) + db.commit() + assert db.query(Task).count() == 2 + + website = schemas.config.Website( + domain=task.domain, + paths=[ + dict( + path="https://another-example.com", checks=[{task.check: task.expected}] + ), + ], + ) + empty_config.websites = [website] + + await queries.update_from_config(db, empty_config) + assert db.query(Task).count() == 3 + + +@pytest.mark.asyncio +async def test_update_from_config_db_updates_existing_tasks(db, empty_config, task): + assert db.query(Task).count() == 1 + + website = schemas.config.Website( + domain=task.domain, + paths=[ + dict(path=task.url, checks=[{task.check: task.expected}]), + ], + ) + empty_config.websites = [website] + + await queries.update_from_config(db, empty_config) + assert db.query(Task).count() == 1 + + @pytest.fixture def task(db): task = Task( url="https://www.example.com", - domain="example.com", + domain="https://www.example.com", check="body-contains", expected="foo", frequency=1, @@ -66,6 +137,27 @@ def task(db): return task +@pytest.fixture +def empty_config(): + return schemas.config.Config( + general=schemas.config.General( + frequency="1m", + alerts=schemas.config.Alert( + error=["", ""], + warning=["", ""], + alert=["", ""], + ), + ), + service=schemas.config.Service( + secrets=[ + "1234", + ] + ), + ssl=schemas.config.SSL(thresholds=[]), + websites=[], + ) + + @pytest.fixture def ten_results(db, task): results = []