WIP: More polishing

Alex Pyrgiotis 2025-05-06 17:32:07 +03:00
parent 06697248fa
commit cbb76f3ca3

@ -1,6 +1,6 @@
#!/usr/bin/env python3
"""
Dangerzone assets inventory
This script keeps an inventory of assets (currently GitHub release assets) in a TOML
file, resolves their versions (via GitHub API and semver ranges), calculates file
@ -8,20 +8,25 @@ checksums, and downloads assets based on a JSON “lock” file.
"""
import argparse
import contextlib
import enum
import fnmatch
import hashlib
import json
import logging
import os
import platform
import shutil
import sys
import tarfile
import traceback
import zipfile
from pathlib import Path
import requests
import semver
import toml
from colorama import Back, Fore, Style
from platformdirs import user_cache_dir
# CONSTANTS
@ -30,34 +35,71 @@ LOCK_FILE = "inventory.lock"
GITHUB_API_URL = "https://api.github.com"
# Determine the cache directory using platformdirs
CACHE_ROOT = Path(user_cache_dir("gh_assets_manager"))
CACHE_ROOT = Path(user_cache_dir("dz-inventory"))
logger = logging.getLogger(__name__)
class InvException(Exception):
"""Inventory-specific error."""
pass
# HELPER FUNCTIONS
@contextlib.contextmanager
def report_error(verbose=False, fail=False):
"""Report errors in a more uniform way.
Report errors to the user, based on their type and the log verbosity. In the case
of non-recoverable errors (signaled by the caller via 'fail=True'), exit with
status code 1.
"""
try:
yield
except InvException as e:
if not verbose:
print(f"{Fore.RED}{e}{Style.RESET_ALL}", file=sys.stderr)
elif verbose < 2:
traceback.print_exception(e, limit=1, chain=False)
else:
traceback.print_exception(e, chain=True)
except Exception:
logger.exception("An unknown error occurred:")
else:
return
if fail:
sys.exit(1)
def read_config():
try:
with open(CONFIG_FILE, "r") as fp:
return toml.load(fp)
except Exception as e:
print(f"Could not load configuration file: {e}")
sys.exit(1)
raise InvException(f"Could not load configuration file '{CONFIG_FILE}'") from e
def write_lock(lock_data):
config = read_config()
config_hash = hashlib.sha256(json.dumps(config).encode()).hexdigest()
lock_data["config_checksum"] = config_hash
with open(LOCK_FILE, "w") as fp:
json.dump(lock_data, fp, indent=2)
def check_lock_stale(lock):
config = read_config()
config_hash = hashlib.sha256(json.dumps(config).encode()).hexdigest()
if config_hash != lock["config_checksum"]:
raise InvException(
"You have made changes to the inventory since you last updated the lock"
" file. You need to run the 'lock' command again."
)
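# The lock file itself is plain JSON. A minimal sketch of its shape, based on the
# 'config_checksum' key written above and the per-platform fields that
# compute_asset_lock() emits further below (all values illustrative):
#
#   {
#     "config_checksum": "<sha256 of the JSON-serialized config>",
#     "assets": {
#       "podman": {
#         "windows/amd64": {
#           "repo": "containers/podman",
#           "download_url": "https://github.com/...",
#           "version": "5.0.0",
#           "checksum": "<sha256 of the downloaded file>",
#           "executable": true,
#           "destination": "share/vendor/podman",
#           "extract": false
#         }
#       }
#     }
#   }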
def load_lock(check=True):
try:
with open(LOCK_FILE, "r") as fp:
@ -66,8 +108,7 @@ def load_lock(check=True):
check_lock_stale(lock)
return lock
except Exception as e:
print(f"Could not load lock file: {e}")
sys.exit(1)
raise InvException(f"Could not load lock file '{LOCK_FILE}': {e}") from e
def calc_checksum(stream):
@ -82,25 +123,43 @@ def calc_checksum(stream):
def cache_file_path(url):
"""
Generate a safe cache file path for a given URL, using the SHA-256 hash of the URL
plus the asset name.
"""
# Calculate a unique hash for this URL, so that it doesn't clash with different
# versions of the same asset.
url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()
# Get the name of the asset from the URL, which is typically the last part of the
# URL. However, if the asset is the GitHub-generated zipball/tarball, use a
# different naming scheme.
parsed = url.split("/")
asset_name = parsed[-1]
if parsed[-2] in ("zipball", "tarball"):
repo_name = parsed[-3]
asset_type = parsed[-2]
asset_name = f"{repo_name}-{asset_type}"
CACHE_ROOT.mkdir(parents=True, exist_ok=True)
return CACHE_ROOT / f"{url_hash}-{asset_name}"
def checksum_file_path(url):
"""Generate checksum filename for a given URL"""
path = cache_file_path(url)
return path.parent / (path.name + ".sha256")
def store_checksum_in_cache(url, checksum):
"""Store the checksum in a file whose name is based on the URL hash."""
with open(checksum_file_path(url), "w") as fp:
fp.write(checksum)
def read_checksum_from_cache(url):
checksum_file = checksum_file_path(url)
if checksum_file.exists():
return checksum_file.read_text().strip()
return None
@ -126,7 +185,7 @@ def download_to_cache(url):
if cached:
return cached
print(f"Downloading {url} into cache...")
logger.debug(f"Downloading {url} into cache...")
response = requests.get(url, stream=True)
response.raise_for_status()
@ -137,11 +196,12 @@ def download_to_cache(url):
with open(cached, "rb") as f:
checksum = calc_checksum(f)
store_checksum_in_cache(url, checksum)
print("Download to cache completed.")
logger.debug("Download to cache completed.")
return cached
def detect_platform():
"""Detect the platform that the script runs in"""
# Return a string like 'windows/amd64' or 'linux/amd64' or 'darwin/amd64'
sys_platform = sys.platform
if sys_platform.startswith("win"):
@ -156,6 +216,7 @@ def detect_platform():
machine = platform.machine().lower()
# Normalize architecture names
arch = {"x86_64": "amd64", "amd64": "amd64", "arm64": "arm64"}.get(machine, machine)
return f"{os_name}/{arch}"
@ -166,9 +227,12 @@ def get_latest_release(repo, semver_range):
"""
url = f"{GITHUB_API_URL}/repos/{repo}/releases"
response = requests.get(url)
response.raise_for_status()
if response.status_code != 200:
print(f"Failed to fetch releases for repo {repo}. HTTP {response.status_code}")
return None
raise InvException(
f"Unexpected response when fetching releases for repo '{repo}': HTTP"
f" {response.status_code}"
)
releases = response.json()
matching = []
@ -176,23 +240,41 @@ def get_latest_release(repo, semver_range):
tag = release.get("tag_name", "")
# Strip any prefix 'v' if necessary
version_str = tag.lstrip("v")
# Attempt to parse asset version as semver. If the project has a tag that does
# not conform to SemVer, just skip it.
try:
version = semver.VersionInfo.parse(version_str)
except ValueError:
logger.debug(
f"Skipping non SemVer-compliant version '{version_str}' from repo"
f" '{repo}'"
)
continue
# Skip prereleases and non-matching versions
if release["prerelease"]:
logger.debug(
f"Skipping prerelease version '{version_str}' from repo '{repo}'"
)
continue
elif not version.match(semver_range):
logger.debug(
f"Skipping version '{version_str}' from repo '{repo}' because it does"
f" not match the '{semver_range}' requirement"
)
continue
matching.append((release, version))
if not matching:
print(f"No releases match version requirement {semver_range} for repo {repo}")
return None
raise InvException(
f"No releases match version requirement {semver_range} for repo '{repo}'"
)
return max(matching, key=lambda x: x[1])[0]
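# Illustrative example: with semver_range = ">=0.14.0", tags v0.14.2 and v0.14.3
# are parsed and matched, a release marked as prerelease on GitHub is skipped, and
# a tag like "nightly" is skipped as non SemVer-compliant; max() then returns the
# release for 0.14.3.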
def get_download_url(release, name):
"""
Given the release JSON and an asset name, find the asset download URL by matching
filename. If the asset name contains "{version}", it will be formatted using the
@ -211,7 +293,8 @@ def resolve_asset_for_platform(release, name):
for asset in assets:
if asset.get("name") == expected_name:
return asset.get("browser_download_url")
raise InvException(f"Could not find an asset matching '{name}'")
def hash_asset(url):
@ -219,6 +302,15 @@ def hash_asset(url):
Download the asset using caching and return its SHA256 checksum.
The checksum is also stored in the cache as a .sha256 file.
"""
# If we have downloaded the file and hashed it before, return the checksum
# immediately.
checksum_file = checksum_file_path(url)
if checksum_file.exists():
logger.debug(f"Using cached checksum for URL: {url}")
with open(checksum_file, "r") as f:
return f.read()
# Else, download the file, hash it, and store the checksum in the cache.
cached_file = download_to_cache(url)
with open(cached_file, "rb") as f:
checksum = calc_checksum(f)
@ -234,20 +326,17 @@ def download_to_cache_and_verify(url, destination, expected_checksum):
If not, remove the cached file and raise an exception.
"""
cached_file = download_to_cache(url)
checksum_file = checksum_file_path(url)
with open(cached_file, "rb") as f:
computed_checksum = calc_checksum(f)
if computed_checksum != expected_checksum:
# Remove cache file and its checksum file
cached_file.unlink(missing_ok=True)
checksum_file.unlink(missing_ok=True)
raise InvException(
f"Hash mismatch for URL {url}: computed '{computed_checksum}',"
f" expected '{expected_checksum}'"
)
return cached_file
@ -266,7 +355,7 @@ def determine_extract_opts(extract):
globs = ["*"]
flatten = False
else:
raise Exception(f"Unexpected format for 'extract' field: {extract}")
raise InvException(f"Unexpected format for 'extract' field: {extract}")
return {
"globs": globs,
@ -284,12 +373,13 @@ def detect_archive_type(name):
return "tar"
if name.endswith(".zip") or name == "!zipball":
return "zip"
raise Exception(f"Unsupported archive type for extraction: {name}")
raise InvException(f"Unsupported archive type for extraction: {name}")
def flatten_extracted_files(destination):
"""
After extraction, move all files found in subdirectories of destination into
destination root.
"""
for root, dirs, files in os.walk(destination):
# Skip the root directory itself
@ -323,6 +413,7 @@ def extract_asset(archive_path, destination, options):
For tarfiles, use filter="data" when extracting to mitigate malicious tar entries.
"""
logger.debug(f"Extracting '{archive_path}' to '{destination}'...")
ft = options["filetype"]
globs = options["globs"]
flatten = options["flatten"]
@ -337,10 +428,10 @@ def extract_asset(archive_path, destination, options):
if any(fnmatch.fnmatch(m.name, glob) for glob in globs)
]
if not members:
raise Exception("Globs did not match any files in the archive")
raise InvException("Globs did not match any files in the archive")
tar.extractall(path=destination, members=members, filter="data")
except Exception as e:
raise Exception(f"Error extracting '{archive_path}': {e}")
raise InvException(f"Error extracting '{archive_path}': {e}") from e
elif ft == "zip":
try:
with zipfile.ZipFile(archive_path, "r") as zip_ref:
@ -350,17 +441,17 @@ def extract_asset(archive_path, destination, options):
if any(fnmatch.fnmatch(m, glob) for glob in globs)
]
if not members:
raise Exception("Globs did not match any files in the archive")
raise InvException("Globs did not match any files in the archive")
zip_ref.extractall(path=destination, members=members)
except Exception as e:
raise Exception(f"Error extracting zip archive: {e}")
raise InvException(f"Error extracting zip archive: {e}") from e
else:
raise Exception(f"Unsupported archive type for file {archive_path}")
raise InvException(f"Unsupported archive type for file {archive_path}")
if flatten:
flatten_extracted_files(destination)
print(f"Extraction of {archive_path} complete.")
logger.debug(f"Successfully extracted '{archive_path}'")
def get_platform_assets(assets, platform):
@ -387,6 +478,58 @@ def chmod_exec(path):
path.chmod(path.stat().st_mode | 0o111)
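# For example, a file downloaded with mode 0o644 ends up as 0o755: OR-ing with
# 0o111 sets the execute bit for user, group, and other, and keeps all other bits.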
def compute_asset_lock(asset_name, asset):
try:
repo = asset["repo"]
version_range = asset["version"]
asset_map = asset["platform"] # mapping platform -> asset file name
destination_str = asset["destination"]
executable = asset.get("executable", False)
extract = asset.get("extract", False)
except KeyError as e:
raise InvException(f"Required field {e} is missing")
if extract:
extract = determine_extract_opts(extract)
logger.debug(
f"Fetching a release that satisfies version range '{version_range}' for repo"
f" '{repo}'"
)
release = get_latest_release(repo, version_range)
version = release["tag_name"].lstrip("v")
logger.debug(f"Found release '{version}' for repo '{repo}'")
asset_lock_data = {}
# Process each defined platform key in the asset_map
for plat_key, plat_name in asset_map.items():
logger.debug(f"Getting download URL for asset '{asset_name}' of repo '{repo}'")
download_url = get_download_url(release, plat_name)
logger.debug(f"Found download URL: {download_url}")
if extract:
extract = extract.copy()
extract["filetype"] = detect_archive_type(plat_name)
logger.debug(
f"Hashing asset '{asset_name}' of repo '{repo}' for platform"
f" '{plat_key}'..."
)
checksum = hash_asset(download_url)
logger.debug(f"Computed the following SHA-256 checksum: {checksum}")
asset_lock_data[plat_key] = {
"repo": repo,
"download_url": download_url,
"version": version,
"checksum": checksum,
"executable": executable,
"destination": destination_str,
"extract": extract,
}
return asset_lock_data
# COMMAND FUNCTIONS
def cmd_lock(args):
"""
@ -416,64 +559,22 @@ def cmd_lock(args):
# and flatten = True|False.
assets_cfg = config.get("asset", {})
if not assets_cfg:
print("No assets defined under the [asset] section in the config file.")
sys.exit(1)
raise InvException(
"No assets defined under the [asset] section in the config file."
)
lock_assets = lock["assets"]
for asset_name, asset in assets_cfg.items():
repo = asset.get("repo")
version_range = asset.get("version")
asset_map = asset.get("platform") # mapping platform -> asset file name
executable = asset.get("executable")
destination_str = asset.get("destination")
extract = asset.get("extract", False)
logger.info(f"Processing asset '{asset_name}'...")
try:
lock_assets[asset_name] = compute_asset_lock(asset_name, asset)
except Exception as e:
raise InvException(
f"Error when processing asset '{asset_name}': {e}"
) from e
write_lock(lock)
print(f"Lock file '{LOCK_FILE}' updated.")
logger.info(f"Lock file '{LOCK_FILE}' updated.")
def cmd_sync(args):
@ -552,10 +653,7 @@ def cmd_sync(args):
if executable:
chmod_exec(destination)
except Exception as e:
raise InvException(f"Error processing asset '{asset_name}': {e}") from e
print("Downloads completed")
@ -573,20 +671,14 @@ def cmd_list(args):
print(f"{asset_name} {asset['version']} {asset['download_url']}")
def parse_args():
parser = argparse.ArgumentParser(description="Inventory")
subparsers = parser.add_subparsers(dest="command", required=True)
lock_parser = subparsers.add_parser("lock", help="Update lock file from config")
lock_parser.set_defaults(func=cmd_lock)
sync_parser = subparsers.add_parser("sync", help="Sync assets as per lock file")
sync_parser.add_argument(
"assets",
nargs="*",
@ -595,14 +687,66 @@ def main():
sync_parser.set_defaults(func=cmd_sync)
list_parser = subparsers.add_parser("list", help="List assets for a platform")
list_parser.set_defaults(func=cmd_list)
# Add common arguments.
for subparser in subparsers.choices.values():
subparser.add_argument(
"-p",
"--platform",
help=(
"The platform to choose when determining which inventory assets to work"
" on. Examples: windows/amd64, linux/amd64, darwin/amd64, darwin/arm64"
" (default: %(default)s)"
),
)
subparser.add_argument(
"-v",
"--verbose",
action="count",
default=0,
help="Enable verbose logging",
)
subparser.add_argument(
"-C",
"--directory",
help=(
"The working directory for the script (defaults to the current working"
" directory)"
),
)
args = parser.parse_args()
return args
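# Example invocations (script name hypothetical):
#   python inventory.py lock
#   python inventory.py sync --platform windows/amd64
#   python inventory.py list -p linux/amd64 -vv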
def setup_logging(verbose=False):
"""Simple way to setup logging.
Copied from: https://docs.python.org/3/howto/logging.html
"""
# specify level
if not verbose:
lvl = logging.WARNING
elif verbose == 1:
lvl = logging.INFO
else:
lvl = logging.DEBUG
logging.basicConfig(
level=lvl,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
def main():
args = parse_args()
setup_logging(args.verbose)
with report_error(verbose=args.verbose, fail=True):
return args.func(args)
if __name__ == "__main__":
sys.exit(main())