— Add mypy test

This commit is contained in:
Luc Didry 2024-07-04 13:33:54 +02:00
parent 3b49594bef
commit 5bd4d9909a
No known key found for this signature in database
GPG key ID: EA868E12D0257E3C
13 changed files with 66 additions and 36 deletions

View file

@ -37,6 +37,12 @@ djlint:
script: script:
- make djlint - make djlint
mypy:
<<: *pull_cache
stage: test
script:
- make mypy
pylint: pylint:
<<: *pull_cache <<: *pull_cache
stage: test stage: test

View file

@ -3,6 +3,7 @@
## [Unreleased] ## [Unreleased]
- 🩹 — Fix release documentation - 🩹 — Fix release documentation
- ✅ — Add mypy test
## 0.2.2 ## 0.2.2

View file

@ -28,7 +28,9 @@ djlint: venv ## Format the templates
venv/bin/djlint --ignore=H030,H031,H006 --profile jinja --lint argos/server/templates/*html venv/bin/djlint --ignore=H030,H031,H006 --profile jinja --lint argos/server/templates/*html
pylint: venv ## Runs pylint on the code pylint: venv ## Runs pylint on the code
venv/bin/pylint argos venv/bin/pylint argos
lint: djlint pylint ruff mypy: venv
venv/bin/mypy argos tests
lint: djlint pylint mypy ruff
help: help:
@python3 -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) @python3 -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)

View file

@ -57,9 +57,9 @@ class ArgosAgent:
logger.error("Waiting %i seconds before next retry", self.wait_time) logger.error("Waiting %i seconds before next retry", self.wait_time)
await asyncio.sleep(self.wait_time) await asyncio.sleep(self.wait_time)
async def _complete_task(self, task: dict) -> dict: async def _complete_task(self, _task: dict) -> AgentResult:
try: try:
task = Task(**task) task = Task(**_task)
check_class = get_registered_check(task.check) check_class = get_registered_check(task.check)
check = check_class(self._http_client, task) check = check_class(self._http_client, task)
result = await check.run() result = await check.run()
@ -69,7 +69,7 @@ class ArgosAgent:
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
status = "error" status = "error"
context = SerializableException.from_exception(err) context = SerializableException.from_exception(err)
msg = f"An exception occured when running {task}. {err.__class__.__name__} : {err}" msg = f"An exception occured when running {_task}. {err.__class__.__name__} : {err}"
logger.error(msg) logger.error(msg)
return AgentResult(task_id=task.id, status=status, context=context) return AgentResult(task_id=task.id, status=status, context=context)
@ -102,12 +102,19 @@ class ArgosAgent:
async def _post_results(self, results: List[AgentResult]): async def _post_results(self, results: List[AgentResult]):
data = [r.model_dump() for r in results] data = [r.model_dump() for r in results]
if self._http_client is not None:
response = await self._http_client.post( response = await self._http_client.post(
f"{self.server}/api/results", params={"agent_id": self.agent_id}, json=data f"{self.server}/api/results",
params={"agent_id": self.agent_id},
json=data,
) )
if response.status_code == httpx.codes.CREATED: if response.status_code == httpx.codes.CREATED:
logger.error("Successfully posted results %s", json.dumps(response.json())) logger.error(
"Successfully posted results %s", json.dumps(response.json())
)
else: else:
logger.error("Failed to post results: %s", response.read()) logger.error("Failed to post results: %s", response.read())
return response return response
logger.error("self._http_client is None")

View file

@ -1,7 +1,7 @@
"""Various base classes for checks""" """Various base classes for checks"""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Type, Union from typing import Type
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
@ -71,7 +71,7 @@ class InvalidResponse(Exception):
class BaseCheck: class BaseCheck:
config: str config: str
expected_cls: Union[None, Type[BaseExpectedValue]] = None expected_cls: None | Type[BaseExpectedValue] = None
_registry = [] # type: ignore[var-annotated] _registry = [] # type: ignore[var-annotated]

View file

@ -1,9 +1,9 @@
from typing import Literal, Union from typing import Literal
def string_to_duration( def string_to_duration(
value: str, target: Literal["days", "hours", "minutes"] value: str, target: Literal["days", "hours", "minutes"]
) -> Union[int, float]: ) -> int | float:
"""Convert a string to a number of hours, days or minutes""" """Convert a string to a number of hours, days or minutes"""
num = int("".join(filter(str.isdigit, value))) num = int("".join(filter(str.isdigit, value)))

View file

@ -51,7 +51,7 @@ async def list_users(db: Session):
return db.query(User).order_by(asc(User.username)) return db.query(User).order_by(asc(User.username))
async def get_task(db: Session, task_id: int) -> Task: async def get_task(db: Session, task_id: int) -> None | Task:
return db.get(Task, task_id) return db.get(Task, task_id)
@ -71,9 +71,9 @@ async def count_tasks(db: Session, selected: None | bool = None):
query = db.query(Task) query = db.query(Task)
if selected is not None: if selected is not None:
if selected: if selected:
query = query.filter(Task.selected_by is not None) query = query.filter(Task.selected_by is not None) # type: ignore[arg-type]
else: else:
query = query.filter(Task.selected_by is None) query = query.filter(Task.selected_by is None) # type: ignore[arg-type]
return query.count() return query.count()
@ -98,7 +98,7 @@ async def has_config_changed(db: Session, config: schemas.Config) -> bool:
case "general_frequency": case "general_frequency":
if conf.val != str(config.general.frequency): if conf.val != str(config.general.frequency):
same_config = False same_config = False
conf.val = config.general.frequency conf.val = str(config.general.frequency)
conf.updated_at = datetime.now() conf.updated_at = datetime.now()
db.commit() db.commit()
@ -208,7 +208,7 @@ async def get_severity_counts(db: Session) -> dict:
# Execute the query and fetch the results # Execute the query and fetch the results
task_counts_by_severity = query.all() task_counts_by_severity = query.all()
counts_dict = dict(task_counts_by_severity) counts_dict = dict(task_counts_by_severity) # type: ignore[var-annotated,arg-type]
for key in ("ok", "warning", "critical", "unknown"): for key in ("ok", "warning", "critical", "unknown"):
counts_dict.setdefault(key, 0) counts_dict.setdefault(key, 0)
return counts_dict return counts_dict

View file

@ -1,5 +1,5 @@
"""Web interface for machines""" """Web interface for machines"""
from typing import List, Union from typing import List
from fastapi import APIRouter, BackgroundTasks, Depends, Request from fastapi import APIRouter, BackgroundTasks, Depends, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -18,10 +18,13 @@ async def read_tasks(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
limit: int = 10, limit: int = 10,
agent_id: Union[None, str] = None, agent_id: None | str = None,
): ):
"""Return a list of tasks to execute""" """Return a list of tasks to execute"""
agent_id = agent_id or request.client.host host = ""
if request.client is not None:
host = request.client.host
agent_id = agent_id or host
tasks = await queries.list_tasks(db, agent_id=agent_id, limit=limit) tasks = await queries.list_tasks(db, agent_id=agent_id, limit=limit)
return tasks return tasks
@ -33,7 +36,7 @@ async def create_results(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
db: Session = Depends(get_db), db: Session = Depends(get_db),
config: Config = Depends(get_config), config: Config = Depends(get_config),
agent_id: Union[None, str] = None, agent_id: None | str = None,
): ):
"""Get the results from the agents and store them locally. """Get the results from the agents and store them locally.
@ -42,7 +45,10 @@ async def create_results(
- If it's an error, determine its severity ; - If it's an error, determine its severity ;
- Trigger the reporting calls - Trigger the reporting calls
""" """
agent_id = agent_id or request.client.host host = ""
if request.client is not None:
host = request.client.host
agent_id = agent_id or host
db_results = [] db_results = []
for agent_result in results: for agent_result in results:
# XXX Maybe offload this to a queue. # XXX Maybe offload this to a queue.

View file

@ -1,5 +1,6 @@
from fastapi import Depends, HTTPException, Request from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi_login import LoginManager
auth_scheme = HTTPBearer() auth_scheme = HTTPBearer()
@ -16,7 +17,7 @@ def get_config(request: Request):
return request.app.state.config return request.app.state.config
async def get_manager(request: Request): async def get_manager(request: Request) -> LoginManager:
return await request.app.state.manager(request) return await request.app.state.manager(request)

View file

@ -125,7 +125,7 @@ async def get_domains_view(
tasks = db.query(Task).all() tasks = db.query(Task).all()
domains_severities = defaultdict(list) domains_severities = defaultdict(list)
domains_last_checks = defaultdict(list) domains_last_checks = defaultdict(list) # type: ignore[var-annotated]
for task in tasks: for task in tasks:
domain = urlparse(task.url).netloc domain = urlparse(task.url).netloc
@ -210,6 +210,8 @@ async def get_task_results_view(
.all() .all()
) )
task = db.query(Task).get(task_id) task = db.query(Task).get(task_id)
description = ""
if task is not None:
description = task.get_check().get_description(config) description = task.get_check().get_description(config)
return templates.TemplateResponse( return templates.TemplateResponse(
"results.html", "results.html",
@ -251,8 +253,8 @@ async def set_refresh_cookies_view(
request.url_for("get_severity_counts_view"), request.url_for("get_severity_counts_view"),
status_code=status.HTTP_303_SEE_OTHER, status_code=status.HTTP_303_SEE_OTHER,
) )
response.set_cookie(key="auto_refresh_enabled", value=auto_refresh_enabled) response.set_cookie(key="auto_refresh_enabled", value=str(auto_refresh_enabled))
response.set_cookie( response.set_cookie(
key="auto_refresh_seconds", value=max(5, int(auto_refresh_seconds)) key="auto_refresh_seconds", value=str(max(5, int(auto_refresh_seconds)))
) )
return response return response

View file

@ -7,12 +7,12 @@ from yamlinclude import YamlIncludeConstructor
from argos.schemas.config import Config from argos.schemas.config import Config
def read_yaml_config(filename): def read_yaml_config(filename: str) -> Config:
parsed = _load_yaml(filename) parsed = _load_yaml(filename)
return Config(**parsed) return Config(**parsed)
def _load_yaml(filename): def _load_yaml(filename: str):
base_dir = Path(filename).resolve().parent base_dir = Path(filename).resolve().parent
YamlIncludeConstructor.add_to_loader_class( YamlIncludeConstructor.add_to_loader_class(
loader_class=yaml.FullLoader, base_dir=str(base_dir) loader_class=yaml.FullLoader, base_dir=str(base_dir)

View file

@ -45,16 +45,18 @@ dependencies = [
dev = [ dev = [
"black==23.3.0", "black==23.3.0",
"djlint>=1.34.0", "djlint>=1.34.0",
"hatch==1.9.4",
"ipdb>=0.13,<0.14", "ipdb>=0.13,<0.14",
"ipython>=8.16,<9", "ipython>=8.16,<9",
"isort==5.11.5", "isort==5.11.5",
"mypy>=1.10.0,<2",
"pylint>=3.0.2", "pylint>=3.0.2",
"pytest-asyncio>=0.21,<1", "pytest-asyncio>=0.21,<1",
"pytest>=6.2.5", "pytest>=6.2.5",
"respx>=0.20,<1", "respx>=0.20,<1",
"ruff==0.1.5,<1", "ruff==0.1.5,<1",
"sphinx-autobuild", "sphinx-autobuild",
"hatch==1.9.4", "types-PyYAML",
] ]
docs = [ docs = [
"cogapp", "cogapp",
@ -103,3 +105,6 @@ filterwarnings = [
"ignore:'crypt' is deprecated and slated for removal in Python 3.13:DeprecationWarning", "ignore:'crypt' is deprecated and slated for removal in Python 3.13:DeprecationWarning",
"ignore:The 'app' shortcut is now deprecated:DeprecationWarning", "ignore:The 'app' shortcut is now deprecated:DeprecationWarning",
] ]
[tool.mypy]
ignore_missing_imports = "True"

View file

@ -10,7 +10,7 @@ os.environ["ARGOS_YAML_FILE"] = "tests/config.yaml"
@pytest.fixture @pytest.fixture
def db() -> Session: def db() -> Session: # type: ignore[misc]
from argos.server import models from argos.server import models
app = _create_app() app = _create_app()
@ -20,7 +20,7 @@ def db() -> Session:
@pytest.fixture @pytest.fixture
def app() -> FastAPI: def app() -> FastAPI: # type: ignore[misc]
from argos.server import models from argos.server import models
app = _create_app() app = _create_app()