From cbb76f3ca322ace0aacf720ba52c5961d73880b9 Mon Sep 17 00:00:00 2001
From: Alex Pyrgiotis
Date: Tue, 6 May 2025 17:32:07 +0300
Subject: [PATCH] WIP: More polishing

---
 dev_scripts/inventory.py | 390 +++++++++++++++++++++++++++------------
 1 file changed, 267 insertions(+), 123 deletions(-)

diff --git a/dev_scripts/inventory.py b/dev_scripts/inventory.py
index 92cbaa2..13c59f4 100755
--- a/dev_scripts/inventory.py
+++ b/dev_scripts/inventory.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 """
-GitHub assets inventory
+Dangerzone assets inventory
 
 This script keeps an inventory of assets (currently GitHub release assets) in a TOML
 file, resolves their versions (via GitHub API and semver ranges), calculates file
@@ -8,20 +8,25 @@ checksums, and downloads assets based on a JSON “lock” file.
 """
 
 import argparse
+import contextlib
 import fnmatch
 import hashlib
 import json
+import logging
 import os
 import platform
 import shutil
 import sys
 import tarfile
+import traceback
 import zipfile
 from pathlib import Path
 
 import requests
 import semver
 import toml
+from colorama import Fore, Style
 from platformdirs import user_cache_dir
 
 # CONSTANTS
@@ -30,34 +35,71 @@ LOCK_FILE = "inventory.lock"
 GITHUB_API_URL = "https://api.github.com"
 
 # Determine the cache directory using platformdirs
-CACHE_ROOT = Path(user_cache_dir("gh_assets_manager"))
+CACHE_ROOT = Path(user_cache_dir("dz-inventory"))
+
+
+logger = logging.getLogger(__name__)
+
+
+class InvException(Exception):
+    """Inventory-specific error."""
 
 
 # HELPER FUNCTIONS
+@contextlib.contextmanager
+def report_error(verbose=False, fail=False):
+    """Report errors in a more uniform way.
+
+    Report errors to the user, based on their type and the log verbosity. In case of
+    non-recoverable errors (defined by the caller), exit with status code 1.
+    """
+    try:
+        yield
+    except InvException as e:
+        if not verbose:
+            print(f"{Fore.RED}{e}{Style.RESET_ALL}", file=sys.stderr)
+        elif verbose < 2:
+            traceback.print_exception(e, limit=1, chain=False)
+        else:
+            traceback.print_exception(e, chain=True)
+    except Exception:
+        logger.exception("An unknown error occurred:")
+    else:
+        return
+
+    if fail:
+        sys.exit(1)
+
+
 def read_config():
     try:
         with open(CONFIG_FILE, "r") as fp:
             return toml.load(fp)
     except Exception as e:
-        print(f"Could not load configuration file: {e}")
-        sys.exit(1)
+        raise InvException(f"Could not load configuration file '{CONFIG_FILE}'") from e
+
+
+def write_lock(lock_data):
+    config = read_config()
+    config_hash = hashlib.sha256(json.dumps(config).encode()).hexdigest()
+    lock_data["config_checksum"] = config_hash
+    with open(LOCK_FILE, "w") as fp:
+        json.dump(lock_data, fp, indent=2)
 
 
 def check_lock_stale(lock):
     config = read_config()
     config_hash = hashlib.sha256(json.dumps(config).encode()).hexdigest()
     if config_hash != lock["config_checksum"]:
-        raise Exception(
+        raise InvException(
             "You have made changes to the inventory since you last updated the lock"
             " file. You need to run the 'lock' command again."
         )
 
 
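The report_error() context manager above becomes the single place where failures are presented: commands just raise InvException and let it choose between a one-line red message, a short traceback (-v), or a fully chained one (-vv). A usage sketch, assuming the definitions above (the error text is made up):

    with report_error(verbose=0, fail=True):
        raise InvException("could not resolve asset 'example'")
    # Prints the message in red to stderr, then exits with status code 1.
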
-def write_lock(lock_data):
-    with open(LOCK_FILE, "w") as fp:
-        json.dump(lock_data, fp, indent=2)
-
-
 def load_lock(check=True):
     try:
         with open(LOCK_FILE, "r") as fp:
@@ -66,8 +108,7 @@ def load_lock(check=True):
         check_lock_stale(lock)
         return lock
     except Exception as e:
-        print(f"Could not load lock file: {e}")
-        sys.exit(1)
+        raise InvException(f"Could not load lock file '{LOCK_FILE}': {e}") from e
 
 
 def calc_checksum(stream):
@@ -82,25 +123,43 @@ def calc_checksum(stream):
 
 def cache_file_path(url):
     """
-    Generate a safe cache file path for a given URL,
-    using sha256(url) as filename.
+    Generate a safe cache file path for a given URL, using the SHA-256 hash of the
+    URL, plus the asset name.
     """
+    # Calculate a unique hash for this URL, so that it doesn't clash with different
+    # versions of the same asset.
     url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()
+
+    # Get the name of the asset from the URL, which is typically the last part of the
+    # URL. However, if the asset is the GitHub-generated zipball/tarball, use a
+    # different naming scheme.
+    parsed = url.split("/")
+    asset_name = parsed[-1]
+    if parsed[-2] in ("zipball", "tarball"):
+        repo_name = parsed[-3]
+        asset_type = parsed[-2]
+        asset_name = f"{repo_name}-{asset_type}"
+
     CACHE_ROOT.mkdir(parents=True, exist_ok=True)
-    return CACHE_ROOT / url_hash
+    return CACHE_ROOT / f"{url_hash}-{asset_name}"
+
+
+def checksum_file_path(url):
+    """Generate the checksum file path for a given URL."""
+    path = cache_file_path(url)
+    return path.parent / (path.name + ".sha256")
 
 
 def store_checksum_in_cache(url, checksum):
     """Store the checksum in a file whose name is based on the URL hash."""
-    checksum_path = cache_file_path(url).with_suffix(".sha256")
-    with open(checksum_path, "w") as fp:
+    with open(checksum_file_path(url), "w") as fp:
         fp.write(checksum)
 
 
 def read_checksum_from_cache(url):
-    checksum_path = cache_file_path(url).with_suffix(".sha256")
-    if checksum_path.exists():
-        return checksum_path.read_text().strip()
+    checksum_file = checksum_file_path(url)
+    if checksum_file.exists():
+        return checksum_file.read_text().strip()
     return None
 
 
@@ -126,7 +185,7 @@ def download_to_cache(url):
     if cached:
         return cached
 
-    print(f"Downloading {url} into cache...")
+    logger.debug(f"Downloading {url} into cache...")
     response = requests.get(url, stream=True)
     response.raise_for_status()
 
@@ -137,11 +196,12 @@ def download_to_cache(url):
     with open(cached, "rb") as f:
         checksum = calc_checksum(f)
     store_checksum_in_cache(url, checksum)
-    print("Download to cache completed.")
+    logger.debug("Download to cache completed.")
     return cached
 
 
 def detect_platform():
+    """Detect the platform that the script runs on."""
     # Return a string like 'windows/amd64' or 'linux/amd64' or 'darwin/amd64'
     sys_platform = sys.platform
     if sys_platform.startswith("win"):
@@ -156,6 +216,7 @@ def detect_platform():
     machine = platform.machine().lower()
     # Normalize architecture names
     arch = {"x86_64": "amd64", "amd64": "amd64", "arm64": "arm64"}.get(machine, machine)
+
     return f"{os_name}/{arch}"
 
 
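The release matching that follows leans on the semver package, which the script already imports: each Git tag is parsed into a VersionInfo and tested against the configured range with match(). A standalone sketch of that behavior:

    import semver

    version = semver.VersionInfo.parse("4.2.1")  # e.g. from tag "v4.2.1"
    print(version.match(">=4.0.0"))  # True
    print(version.match("<4.0.0"))   # False
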
@@ -166,9 +227,12 @@ def get_latest_release(repo, semver_range):
     """
     url = f"{GITHUB_API_URL}/repos/{repo}/releases"
     response = requests.get(url)
     if response.status_code != 200:
-        print(f"Failed to fetch releases for repo {repo}. HTTP {response.status_code}")
-        return None
+        raise InvException(
+            f"Unexpected response when fetching releases for repo '{repo}': HTTP"
+            f" {response.status_code}"
+        )
 
     releases = response.json()
     matching = []
@@ -176,23 +240,41 @@
     for release in releases:
         tag = release.get("tag_name", "")
         # Strip any prefix 'v' if necessary
         version_str = tag.lstrip("v")
+
+        # Attempt to parse the asset version as semver. If the project has a tag that
+        # does not conform to SemVer, just skip it.
         try:
             version = semver.VersionInfo.parse(version_str)
-            # Skip prereleases and non-matching versions
-            if release["prerelease"] or not version.match(semver_range):
-                continue
-            matching.append((release, version))
         except ValueError:
+            logger.debug(
+                f"Skipping non SemVer-compliant version '{version_str}' from repo"
+                f" '{repo}'"
+            )
             continue
 
+        # Skip prereleases and non-matching versions.
+        if release["prerelease"]:
+            logger.debug(
+                f"Skipping prerelease version '{version_str}' from repo '{repo}'"
+            )
+            continue
+        elif not version.match(semver_range):
+            logger.debug(
+                f"Skipping version '{version_str}' from repo '{repo}' because it does"
+                f" not match the '{semver_range}' requirement"
+            )
+            continue
+        matching.append((release, version))
+
     if not matching:
-        print(f"No releases match version requirement {semver_range} for repo {repo}")
-        return None
+        raise InvException(
+            f"No releases match version requirement {semver_range} for repo '{repo}'"
+        )
 
     return max(matching, key=lambda x: x[1])[0]
 
 
-def resolve_asset_for_platform(release, name):
+def get_download_url(release, name):
     """
     Given the release JSON and an asset name, find the asset download URL by matching
     filename. If the asset name contains "{version}", it will be formatted using the
@@ -211,7 +293,8 @@
     for asset in assets:
         if asset.get("name") == expected_name:
             return asset.get("browser_download_url")
-    return None
+
+    raise InvException(f"Could not find an asset with name '{name}'")
 
 
 def hash_asset(url):
@@ -219,6 +302,15 @@
     """
     Download the asset using caching and return its SHA256 checksum. The checksum is
     also stored in the cache as a .sha256 file.
     """
+    # If we have downloaded the file and hashed it before, return the checksum
+    # immediately.
+    checksum_file = checksum_file_path(url)
+    if checksum_file.exists():
+        logger.debug(f"Using cached checksum for URL: {url}")
+        with open(checksum_file, "r") as f:
+            return f.read().strip()
+
+    # Else, download the file, hash it, and store the checksum in the cache.
     cached_file = download_to_cache(url)
     with open(cached_file, "rb") as f:
         checksum = calc_checksum(f)
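The cache filename scheme above pairs the SHA-256 of the full URL (so two versions of the same asset never collide) with a human-readable asset name. A self-contained sketch of the naming logic, using a hypothetical zipball URL:

    import hashlib

    url = "https://api.github.com/repos/example/mytool/zipball/v1.0.0"
    url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()
    parts = url.split("/")
    if parts[-2] in ("zipball", "tarball"):
        name = f"{parts[-3]}-{parts[-2]}"  # "mytool-zipball"
    else:
        name = parts[-1]
    print(f"{url_hash}-{name}")  # <64 hex chars>-mytool-zipball
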
""" cached_file = download_to_cache(url) + checksum_file = checksum_file_path(url) with open(cached_file, "rb") as f: computed_checksum = calc_checksum(f) if computed_checksum != expected_checksum: # Remove cache file and its checksum file - try: - cached_file.unlink() - checksum_file = cached_file.with_suffix(".sha256") - if checksum_file.exists(): - checksum_file.unlink() - except Exception: - pass - raise Exception( - f"Hash mismatch for URL {url}: computed '{computed_checksum}', expected '{expected_checksum}'" + cached_file.unlink(missing_ok=True) + checksum_file.unlink(missing_ok=True) + raise InvException( + f"Hash mismatch for URL {url}: computed '{computed_checksum}'," + f" expected '{expected_checksum}'" ) return cached_file @@ -266,7 +355,7 @@ def determine_extract_opts(extract): globs = ["*"] flatten = False else: - raise Exception(f"Unexpected format for 'extract' field: {extract}") + raise InvException(f"Unexpected format for 'extract' field: {extract}") return { "globs": globs, @@ -284,12 +373,13 @@ def detect_archive_type(name): return "tar" if name.endswith(".zip") or name == "!zipball": return "zip" - raise Exception(f"Unsupported archive type for extraction: {name}") + raise InvException(f"Unsupported archive type for extraction: {name}") def flatten_extracted_files(destination): """ - After extraction, move all files found in subdirectories of destination into destination root. + After extraction, move all files found in subdirectories of destination into + destination root. """ for root, dirs, files in os.walk(destination): # Skip the root directory itself @@ -323,6 +413,7 @@ def extract_asset(archive_path, destination, options): For tarfiles, use filter="data" when extracting to mitigate malicious tar entries. """ + logger.debug(f"Extracting '{archive_path}' to '{destination}'...") ft = options["filetype"] globs = options["globs"] flatten = options["flatten"] @@ -337,10 +428,10 @@ def extract_asset(archive_path, destination, options): if any(fnmatch.fnmatch(m.name, glob) for glob in globs) ] if not members: - raise Exception("Globs did not match any files in the archive") + raise InvException("Globs did not match any files in the archive") tar.extractall(path=destination, members=members, filter="data") except Exception as e: - raise Exception(f"Error extracting '{archive_path}': {e}") + raise InvException(f"Error extracting '{archive_path}': {e}") from e elif ft == "zip": try: with zipfile.ZipFile(archive_path, "r") as zip_ref: @@ -350,17 +441,17 @@ def extract_asset(archive_path, destination, options): if any(fnmatch.fnmatch(m, glob) for glob in globs) ] if not members: - raise Exception("Globs did not match any files in the archive") + raise InvException("Globs did not match any files in the archive") zip_ref.extractall(path=destination, members=members) except Exception as e: - raise Exception(f"Error extracting zip archive: {e}") + raise InvException(f"Error extracting zip archive: {e}") from e else: - raise Exception(f"Unsupported archive type for file {archive_path}") + raise InvException(f"Unsupported archive type for file {archive_path}") if flatten: flatten_extracted_files(destination) - print(f"Extraction of {archive_path} complete.") + logger.debug(f"Successfully extracted '{archive_path}'") def get_platform_assets(assets, platform): @@ -387,6 +478,58 @@ def chmod_exec(path): path.chmod(path.stat().st_mode | 0o111) +def compute_asset_lock(asset_name, asset): + try: + repo = asset["repo"] + version_range = asset["version"] + asset_map = asset["platform"] # 
 def get_platform_assets(assets, platform):
@@ -387,6 +478,58 @@ def chmod_exec(path):
     path.chmod(path.stat().st_mode | 0o111)
 
 
+def compute_asset_lock(asset_name, asset):
+    try:
+        repo = asset["repo"]
+        version_range = asset["version"]
+        asset_map = asset["platform"]  # mapping platform -> asset file name
+        destination_str = asset["destination"]
+        executable = asset.get("executable", False)
+        extract = asset.get("extract", False)
+    except KeyError as e:
+        raise InvException(f"Required field {e} is missing") from e
+
+    if extract:
+        extract = determine_extract_opts(extract)
+
+    logger.debug(
+        f"Fetching a release that satisfies version range '{version_range}' for repo"
+        f" '{repo}'"
+    )
+    release = get_latest_release(repo, version_range)
+    version = release["tag_name"].lstrip("v")
+    logger.debug(f"Found release '{version}' for repo '{repo}'")
+
+    asset_lock_data = {}
+    # Process each defined platform key in the asset_map
+    for plat_key, plat_name in asset_map.items():
+        logger.debug(f"Getting download URL for asset '{asset_name}' of repo '{repo}'")
+        download_url = get_download_url(release, plat_name)
+        logger.debug(f"Found download URL: {download_url}")
+
+        if extract:
+            extract = extract.copy()
+            extract["filetype"] = detect_archive_type(plat_name)
+
+        logger.debug(
+            f"Hashing asset '{asset_name}' of repo '{repo}' for platform"
+            f" '{plat_key}'..."
+        )
+        checksum = hash_asset(download_url)
+        logger.debug(f"Computed the following SHA-256 checksum: {checksum}")
+        asset_lock_data[plat_key] = {
+            "repo": repo,
+            "download_url": download_url,
+            "version": version,
+            "checksum": checksum,
+            "executable": executable,
+            "destination": destination_str,
+            "extract": extract,
+        }
+
+    return asset_lock_data
+
+
 # COMMAND FUNCTIONS
 def cmd_lock(args):
     """
@@ -416,64 +559,22 @@
     # and flatten = True|False.
     assets_cfg = config.get("asset", {})
     if not assets_cfg:
-        print("No assets defined under the [asset] section in the config file.")
-        sys.exit(1)
+        raise InvException(
+            "No assets defined under the [asset] section in the config file."
+        )
 
+    lock_assets = lock["assets"]
     for asset_name, asset in assets_cfg.items():
-        repo = asset.get("repo")
-        version_range = asset.get("version")
-        asset_map = asset.get("platform")  # mapping platform -> asset file name
-        executable = asset.get("executable")
-        destination_str = asset.get("destination")
-        extract = asset.get("extract", False)
+        logger.info(f"Processing asset '{asset_name}'...")
+        try:
+            lock_assets[asset_name] = compute_asset_lock(asset_name, asset)
+        except Exception as e:
+            raise InvException(
+                f"Error when processing asset '{asset_name}': {e}"
            ) from e
 
-        if extract:
-            extract = determine_extract_opts(extract)
-
-        if not repo or not version_range or not asset_map or not destination_str:
-            print(f"Asset {asset_name} is missing required fields.")
-            continue
-
-        print(f"Processing asset '{asset_name}' for repo '{repo}' ...")
-        release = get_latest_release(repo, version_range)
-        if not release:
-            print(f"Could not resolve release for asset '{asset_name}'.")
-            continue
-
-        lock_assets = lock["assets"]
-        asset_lock_data = {}
-        # Process each defined platform key in the asset_map
-        for plat_key, plat_name in asset_map.items():
-            download_url = resolve_asset_for_platform(release, plat_name)
-            if not download_url:
-                print(
-                    f"Warning: No asset found for platform '{plat_key}' in repo '{repo}' for asset '{asset_name}'."
-                )
-                continue
-
-            if extract:
-                extract = extract.copy()
-                extract["filetype"] = detect_archive_type(plat_name)
-            print(f"Hashing asset '{asset_name}' for platform '{plat_key}'...")
-            checksum = hash_asset(download_url)
-            asset_lock_data[plat_key] = {
-                "repo": repo,
-                "download_url": download_url,
-                "version": release.get("tag_name").lstrip("v"),
-                "checksum": checksum,
-                "executable": executable,
-                "destination": destination_str,
-                "extract": extract,
-            }
-        if not asset_lock_data:
-            print(f"No valid platforms found for asset '{asset_name}'.")
-            continue
-        lock_assets[asset_name] = asset_lock_data
-
-    config_hash = hashlib.sha256(json.dumps(config).encode()).hexdigest()
-    lock["config_checksum"] = config_hash
     write_lock(lock)
-    print(f"Lock file '{LOCK_FILE}' updated.")
+    logger.info(f"Lock file '{LOCK_FILE}' updated.")
 
 
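For reference, each asset entry written to inventory.lock under "assets" has the shape of the dict that compute_asset_lock() returns, keyed by platform. A sketch with invented values (asset name, URL, and checksum are illustrative only):

    example_lock_entry = {
        "mytool": {
            "linux/amd64": {
                "repo": "example/mytool",
                "download_url": "https://github.com/example/mytool/releases/download/v1.2.3/mytool-linux-amd64.tar.gz",
                "version": "1.2.3",
                "checksum": "<sha256 hex digest>",
                "executable": True,
                "destination": "share/mytool",
                "extract": {"globs": ["*"], "flatten": False, "filetype": "tar"},
            }
        }
    }
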
 def cmd_sync(args):
@@ -552,10 +653,7 @@
             if executable:
                 chmod_exec(destination)
         except Exception as e:
-            print(
-                f"Error processing asset '{asset_name}' for platform '{target_plat}': {e}"
-            )
-            continue
+            raise InvException(f"Error processing asset '{asset_name}': {e}") from e
 
     print("Downloads completed")
 
@@ -573,20 +671,14 @@ def cmd_list(args):
         print(f"{asset_name} {asset['version']} {asset['download_url']}")
 
 
-def main():
-    parser = argparse.ArgumentParser(description="GitHub Release Assets Manager")
+def parse_args():
+    parser = argparse.ArgumentParser(description="Dangerzone assets inventory")
     subparsers = parser.add_subparsers(dest="command", required=True)
 
-    lock_parser = subparsers.add_parser(
-        "lock", help="Update lock file from config (without downloading)"
-    )
+    lock_parser = subparsers.add_parser("lock", help="Update lock file from config")
     lock_parser.set_defaults(func=cmd_lock)
 
     sync_parser = subparsers.add_parser("sync", help="Sync assets as per lock file")
-    sync_parser.add_argument(
-        "--platform",
-        help="Platform name/arch (e.g., windows/amd64) to download assets for",
-    )
     sync_parser.add_argument(
         "assets",
         nargs="*",
@@ -595,14 +687,66 @@ def main():
     sync_parser.set_defaults(func=cmd_sync)
 
     list_parser = subparsers.add_parser("list", help="List assets for a platform")
-    list_parser.add_argument(
-        "--platform",
-        help="Platform name/arch (e.g., windows/amd64) to list assets for",
-    )
     list_parser.set_defaults(func=cmd_list)
+
+    # Add arguments that are common to all subcommands.
+    for subparser in subparsers.choices.values():
+        subparser.add_argument(
+            "-p",
+            "--platform",
+            default=detect_platform(),
+            help=(
+                "The platform to choose when determining which inventory assets to"
+                " work on. Examples: windows/amd64, linux/amd64, darwin/amd64,"
+                " darwin/arm64 (default: %(default)s)"
+            ),
+        )
+        subparser.add_argument(
+            "-v",
+            "--verbose",
+            action="count",
+            default=0,
+            help="Enable verbose logging (-v for INFO, -vv for DEBUG)",
+        )
+        subparser.add_argument(
+            "-C",
+            "--directory",
+            help=(
+                "The working directory for the script (defaults to the current"
+                " working directory)"
+            ),
+        )
+
+    args = parser.parse_args()
+    return args
+
+
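The repeatable -v flag uses argparse's action="count", so the verbosity arrives as an integer that setup_logging() below maps to a level: no flag keeps WARNING, -v gives INFO, -vv (or more) gives DEBUG. A quick standalone check:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--verbose", action="count", default=0)
    print(parser.parse_args(["-vv"]).verbose)  # 2
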
+def setup_logging(verbose=False):
+    """Simple way to set up logging.
+
+    Copied from: https://docs.python.org/3/howto/logging.html
+    """
+    # Pick the log level based on the number of -v flags.
+    if not verbose:
+        lvl = logging.WARNING
+    elif verbose == 1:
+        lvl = logging.INFO
+    else:
+        lvl = logging.DEBUG
+
+    logging.basicConfig(
+        level=lvl,
+        format="%(asctime)s - %(levelname)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+
+def main():
+    args = parse_args()
+    setup_logging(args.verbose)
+
+    # Honor the -C/--directory flag before running the selected command.
+    if args.directory:
+        os.chdir(args.directory)
+
+    with report_error(verbose=args.verbose, fail=True):
+        return args.func(args)
 
 
 if __name__ == "__main__":
-    main()
+    sys.exit(main())
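Finally, for context on what the script consumes: the [asset] section of the TOML inventory needs the fields that compute_asset_lock() reads. A minimal, hypothetical entry (tool, repo, and file names are invented; "{version}" is substituted from the resolved release tag, per get_download_url()):

    import toml

    example = toml.loads("""
    [asset.mytool]
    repo = "example/mytool"
    version = ">=1.0.0"
    executable = true
    destination = "share/mytool"

    [asset.mytool.platform]
    "linux/amd64" = "mytool-{version}-linux-amd64.tar.gz"
    "darwin/arm64" = "mytool-{version}-darwin-arm64.tar.gz"
    """)
    print(example["asset"]["mytool"]["version"])  # >=1.0.0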