WIP: More polishing

Alex Pyrgiotis 2025-05-06 17:32:07 +03:00
parent 06697248fa
commit cbb76f3ca3
No known key found for this signature in database
GPG key ID: B6C15EBA0357C9AA


@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 """
-GitHub assets inventory
+Dangerzone assets inventory

 This script keeps an inventory of assets (currently GitHub release assets) in a TOML
 file, resolves their versions (via GitHub API and semver ranges), calculates file
@@ -8,20 +8,25 @@ checksums, and downloads assets based on a JSON “lock” file.
 """

 import argparse
+import contextlib
+import enum
 import fnmatch
 import hashlib
 import json
+import logging
 import os
 import platform
 import shutil
 import sys
 import tarfile
+import traceback
 import zipfile
 from pathlib import Path

 import requests
 import semver
 import toml
+from colorama import Back, Fore, Style
 from platformdirs import user_cache_dir


 # CONSTANTS
@@ -30,34 +35,71 @@ LOCK_FILE = "inventory.lock"
 GITHUB_API_URL = "https://api.github.com"

 # Determine the cache directory using platformdirs
-CACHE_ROOT = Path(user_cache_dir("gh_assets_manager"))
+CACHE_ROOT = Path(user_cache_dir("dz-inventory"))
+
+logger = logging.getLogger(__name__)
+
+
+class InvException(Exception):
+    """Inventory-specific error."""
+
+    pass


 # HELPER FUNCTIONS


+@contextlib.contextmanager
+def report_error(verbose=False, fail=False):
+    """Report errors in a more uniform way.
+
+    Report errors to the user, based on their type and the log verbosity. In case of
+    non-recoverable errors (defined by the user), exit with status code 1.
+    """
+    try:
+        yield
+    except InvException as e:
+        if not verbose:
+            print(f"{Fore.RED}{e}{Style.RESET_ALL}", file=sys.stderr)
+        elif verbose < 2:
+            traceback.print_exception(e, limit=1, chain=False)
+        else:
+            traceback.print_exception(e, chain=True)
+    except Exception:
+        logger.exception("An unknown error occurred:")
+    else:
+        return
+
+    if fail:
+        sys.exit(1)
+
+
 def read_config():
     try:
         with open(CONFIG_FILE, "r") as fp:
             return toml.load(fp)
     except Exception as e:
-        print(f"Could not load configuration file: {e}")
-        sys.exit(1)
+        raise InvException(f"Could not load configuration file '{CONFIG_FILE}'") from e
+
+
+def write_lock(lock_data):
+    config = read_config()
+    config_hash = hashlib.sha256(json.dumps(config).encode()).hexdigest()
+    lock_data["config_checksum"] = config_hash
+    with open(LOCK_FILE, "w") as fp:
+        json.dump(lock_data, fp, indent=2)


 def check_lock_stale(lock):
     config = read_config()
     config_hash = hashlib.sha256(json.dumps(config).encode()).hexdigest()
     if config_hash != lock["config_checksum"]:
-        raise Exception(
+        raise InvException(
             "You have made changes to the inventory since you last updated the lock"
             " file. You need to run the 'lock' command again."
         )


-def write_lock(lock_data):
-    with open(LOCK_FILE, "w") as fp:
-        json.dump(lock_data, fp, indent=2)
-
-
 def load_lock(check=True):
     try:
         with open(LOCK_FILE, "r") as fp:
@@ -66,8 +108,7 @@ def load_lock(check=True):
             check_lock_stale(lock)
         return lock
     except Exception as e:
-        print(f"Could not load lock file: {e}")
-        sys.exit(1)
+        raise InvException(f"Could not load lock file '{LOCK_FILE}': {e}") from e


 def calc_checksum(stream):
@@ -82,25 +123,43 @@ def calc_checksum(stream):
 def cache_file_path(url):
     """
-    Generate a safe cache file path for a given URL,
-    using sha256(url) as filename.
+    Generate a safe cache file path for a given URL, using the SHA-256 hash of the
+    URL, plus the asset name.
     """
+    # Calculate a unique hash for this URL, so that it doesn't clash with different
+    # versions of the same asset.
     url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()
+
+    # Get the name of the asset from the URL, which is typically the last part of the
+    # URL. However, if the asset is the GitHub-generated zipball/tarball, use a
+    # different naming scheme.
+    parsed = url.split("/")
+    asset_name = parsed[-1]
+    if parsed[-2] in ("zipball", "tarball"):
+        repo_name = parsed[-3]
+        asset_type = parsed[-2]
+        asset_name = f"{repo_name}-{asset_type}"
+
     CACHE_ROOT.mkdir(parents=True, exist_ok=True)
-    return CACHE_ROOT / url_hash
+    return CACHE_ROOT / f"{url_hash}-{asset_name}"
+
+
+def checksum_file_path(url):
+    """Generate the checksum filename for a given URL."""
+    path = cache_file_path(url)
+    return path.parent / (path.name + ".sha256")


 def store_checksum_in_cache(url, checksum):
     """Store the checksum in a file whose name is based on the URL hash."""
-    checksum_path = cache_file_path(url).with_suffix(".sha256")
-    with open(checksum_path, "w") as fp:
+    with open(checksum_file_path(url), "w") as fp:
         fp.write(checksum)


 def read_checksum_from_cache(url):
-    checksum_path = cache_file_path(url).with_suffix(".sha256")
-    if checksum_path.exists():
-        return checksum_path.read_text().strip()
+    checksum_file = checksum_file_path(url)
+    if checksum_file.exists():
+        return checksum_file.read_text().strip()
     return None
@@ -126,7 +185,7 @@ def download_to_cache(url):
     if cached:
         return cached

-    print(f"Downloading {url} into cache...")
+    logger.debug(f"Downloading {url} into cache...")
     response = requests.get(url, stream=True)
     response.raise_for_status()
@@ -137,11 +196,12 @@ def download_to_cache(url):
     with open(cached, "rb") as f:
         checksum = calc_checksum(f)
     store_checksum_in_cache(url, checksum)
-    print("Download to cache completed.")
+    logger.debug("Download to cache completed.")
     return cached


 def detect_platform():
+    """Detect the platform that the script runs in"""
     # Return a string like 'windows/amd64' or 'linux/amd64' or 'darwin/amd64'
     sys_platform = sys.platform
     if sys_platform.startswith("win"):
@@ -156,6 +216,7 @@ def detect_platform():
     machine = platform.machine().lower()
     # Normalize architecture names
     arch = {"x86_64": "amd64", "amd64": "amd64", "arm64": "arm64"}.get(machine, machine)
+
     return f"{os_name}/{arch}"
@@ -166,9 +227,12 @@ def get_latest_release(repo, semver_range):
     """
     url = f"{GITHUB_API_URL}/repos/{repo}/releases"
     response = requests.get(url)
+    response.raise_for_status()
     if response.status_code != 200:
-        print(f"Failed to fetch releases for repo {repo}. HTTP {response.status_code}")
-        return None
+        raise InvException(
+            f"Unexpected response when fetching releases for repo '{repo}': HTTP"
+            f" {response.status_code}"
+        )
     releases = response.json()

     matching = []
@@ -176,23 +240,41 @@ def get_latest_release(repo, semver_range):
         tag = release.get("tag_name", "")
         # Strip any prefix 'v' if necessary
         version_str = tag.lstrip("v")
+
+        # Attempt to parse the asset version as semver. If the project has a tag that
+        # does not conform to SemVer, just skip it.
         try:
             version = semver.VersionInfo.parse(version_str)
-            # Skip prereleases and non-matching versions
-            if release["prerelease"] or not version.match(semver_range):
-                continue
-            matching.append((release, version))
         except ValueError:
+            logger.debug(
+                f"Skipping non-SemVer-compliant version '{version_str}' from repo"
+                f" '{repo}'"
+            )
             continue

+        # Skip prereleases and non-matching versions
+        if release["prerelease"]:
+            logger.debug(
+                f"Skipping prerelease version '{version_str}' from repo '{repo}'"
+            )
+            continue
+        elif not version.match(semver_range):
+            logger.debug(
+                f"Skipping version '{version_str}' from repo '{repo}' because it does"
+                f" not match the '{semver_range}' requirement"
+            )
+            continue
+
+        matching.append((release, version))
+
     if not matching:
-        print(f"No releases match version requirement {semver_range} for repo {repo}")
-        return None
+        raise InvException(
+            f"No releases match version requirement {semver_range} for repo '{repo}'"
+        )

     return max(matching, key=lambda x: x[1])[0]


-def resolve_asset_for_platform(release, name):
+def get_download_url(release, name):
     """
     Given the release JSON and an asset name, find the asset download URL by matching
     filename. If the asset name contains "{version}", it will be formatted using the
@@ -211,7 +293,8 @@ def resolve_asset_for_platform(release, name):
     for asset in assets:
         if asset.get("name") == expected_name:
             return asset.get("browser_download_url")
-    return None
+
+    raise InvException(f"Could not find an asset with name '{name}'")


 def hash_asset(url):
@@ -219,6 +302,15 @@ def hash_asset(url):
     Download the asset using caching and return its SHA256 checksum.
     The checksum is also stored in the cache as a .sha256 file.
     """
+    # If we have downloaded the file and hashed it before, return the checksum
+    # immediately.
+    checksum_file = checksum_file_path(url)
+    if checksum_file.exists():
+        logger.debug(f"Using cached checksum for URL: {url}")
+        with open(checksum_file, "r") as f:
+            return f.read()
+
+    # Else, download the file, hash it, and store the checksum in the cache.
     cached_file = download_to_cache(url)
     with open(cached_file, "rb") as f:
         checksum = calc_checksum(f)
@@ -234,20 +326,17 @@ def download_to_cache_and_verify(url, destination, expected_checksum):
     If not, remove the cached file and raise an exception.
     """
     cached_file = download_to_cache(url)
+    checksum_file = checksum_file_path(url)
     with open(cached_file, "rb") as f:
         computed_checksum = calc_checksum(f)

     if computed_checksum != expected_checksum:
         # Remove cache file and its checksum file
-        try:
-            cached_file.unlink()
-            checksum_file = cached_file.with_suffix(".sha256")
-            if checksum_file.exists():
-                checksum_file.unlink()
-        except Exception:
-            pass
-        raise Exception(
-            f"Hash mismatch for URL {url}: computed '{computed_checksum}', expected '{expected_checksum}'"
+        cached_file.unlink(missing_ok=True)
+        checksum_file.unlink(missing_ok=True)
+        raise InvException(
+            f"Hash mismatch for URL {url}: computed '{computed_checksum}',"
+            f" expected '{expected_checksum}'"
         )
     return cached_file
@@ -266,7 +355,7 @@ def determine_extract_opts(extract):
         globs = ["*"]
         flatten = False
     else:
-        raise Exception(f"Unexpected format for 'extract' field: {extract}")
+        raise InvException(f"Unexpected format for 'extract' field: {extract}")

     return {
         "globs": globs,
@@ -284,12 +373,13 @@ def detect_archive_type(name):
         return "tar"
     if name.endswith(".zip") or name == "!zipball":
         return "zip"
-    raise Exception(f"Unsupported archive type for extraction: {name}")
+    raise InvException(f"Unsupported archive type for extraction: {name}")


 def flatten_extracted_files(destination):
     """
-    After extraction, move all files found in subdirectories of destination into destination root.
+    After extraction, move all files found in subdirectories of destination into
+    destination root.
     """
     for root, dirs, files in os.walk(destination):
         # Skip the root directory itself
@@ -323,6 +413,7 @@ def extract_asset(archive_path, destination, options):
     For tarfiles, use filter="data" when extracting to mitigate malicious tar entries.
     """
+    logger.debug(f"Extracting '{archive_path}' to '{destination}'...")
     ft = options["filetype"]
     globs = options["globs"]
     flatten = options["flatten"]
@@ -337,10 +428,10 @@ def extract_asset(archive_path, destination, options):
                     if any(fnmatch.fnmatch(m.name, glob) for glob in globs)
                 ]
                 if not members:
-                    raise Exception("Globs did not match any files in the archive")
+                    raise InvException("Globs did not match any files in the archive")
                 tar.extractall(path=destination, members=members, filter="data")
         except Exception as e:
-            raise Exception(f"Error extracting '{archive_path}': {e}")
+            raise InvException(f"Error extracting '{archive_path}': {e}") from e
     elif ft == "zip":
         try:
             with zipfile.ZipFile(archive_path, "r") as zip_ref:
@@ -350,17 +441,17 @@ def extract_asset(archive_path, destination, options):
                     if any(fnmatch.fnmatch(m, glob) for glob in globs)
                 ]
                 if not members:
-                    raise Exception("Globs did not match any files in the archive")
+                    raise InvException("Globs did not match any files in the archive")
                 zip_ref.extractall(path=destination, members=members)
         except Exception as e:
-            raise Exception(f"Error extracting zip archive: {e}")
+            raise InvException(f"Error extracting zip archive: {e}") from e
     else:
-        raise Exception(f"Unsupported archive type for file {archive_path}")
+        raise InvException(f"Unsupported archive type for file {archive_path}")

     if flatten:
         flatten_extracted_files(destination)
-    print(f"Extraction of {archive_path} complete.")
+    logger.debug(f"Successfully extracted '{archive_path}'")


 def get_platform_assets(assets, platform):
@@ -387,6 +478,58 @@ def chmod_exec(path):
     path.chmod(path.stat().st_mode | 0o111)


+def compute_asset_lock(asset_name, asset):
+    try:
+        repo = asset["repo"]
+        version_range = asset["version"]
+        asset_map = asset["platform"]  # mapping platform -> asset file name
+        destination_str = asset["destination"]
+        executable = asset.get("executable", False)
+        extract = asset.get("extract", False)
+    except KeyError as e:
+        raise InvException(f"Required field {e} is missing")
+
+    if extract:
+        extract = determine_extract_opts(extract)
+
+    logger.debug(
+        f"Fetching a release that satisfies version range '{version_range}' for repo"
+        f" '{repo}'"
+    )
+    release = get_latest_release(repo, version_range)
+    version = release["tag_name"].lstrip("v")
+    logger.debug(f"Found release '{version}' for repo '{repo}'")
+
+    asset_lock_data = {}
+    # Process each defined platform key in the asset_map
+    for plat_key, plat_name in asset_map.items():
+        logger.debug(f"Getting download URL for asset '{asset_name}' of repo '{repo}'")
+        download_url = get_download_url(release, plat_name)
+        logger.debug(f"Found download URL: {download_url}")
+
+        if extract:
+            extract = extract.copy()
+            extract["filetype"] = detect_archive_type(plat_name)
+
+        logger.debug(
+            f"Hashing asset '{asset_name}' of repo '{repo}' for platform"
+            f" '{plat_key}'..."
+        )
+        checksum = hash_asset(download_url)
+        logger.debug(f"Computed the following SHA-256 checksum: {checksum}")
+
+        asset_lock_data[plat_key] = {
+            "repo": repo,
+            "download_url": download_url,
+            "version": version,
+            "checksum": checksum,
+            "executable": executable,
+            "destination": destination_str,
+            "extract": extract,
+        }
+
+    return asset_lock_data
+
+
 # COMMAND FUNCTIONS
 def cmd_lock(args):
     """
@@ -416,64 +559,22 @@ def cmd_lock(args):
     # and flatten = True|False.
     assets_cfg = config.get("asset", {})
     if not assets_cfg:
-        print("No assets defined under the [asset] section in the config file.")
-        sys.exit(1)
+        raise InvException(
+            "No assets defined under the [asset] section in the config file."
+        )

-    for asset_name, asset in assets_cfg.items():
-        repo = asset.get("repo")
-        version_range = asset.get("version")
-        asset_map = asset.get("platform")  # mapping platform -> asset file name
-        executable = asset.get("executable")
-        destination_str = asset.get("destination")
-        extract = asset.get("extract", False)
-        if extract:
-            extract = determine_extract_opts(extract)
-
-        if not repo or not version_range or not asset_map or not destination_str:
-            print(f"Asset {asset_name} is missing required fields.")
-            continue
-
-        print(f"Processing asset '{asset_name}' for repo '{repo}' ...")
-        release = get_latest_release(repo, version_range)
-        if not release:
-            print(f"Could not resolve release for asset '{asset_name}'.")
-            continue
-
     lock_assets = lock["assets"]
-    asset_lock_data = {}
-    # Process each defined platform key in the asset_map
-    for plat_key, plat_name in asset_map.items():
-        download_url = resolve_asset_for_platform(release, plat_name)
-        if not download_url:
-            print(
-                f"Warning: No asset found for platform '{plat_key}' in repo '{repo}' for asset '{asset_name}'."
-            )
-            continue
-        if extract:
-            extract = extract.copy()
-            extract["filetype"] = detect_archive_type(plat_name)
-        print(f"Hashing asset '{asset_name}' for platform '{plat_key}'...")
-        checksum = hash_asset(download_url)
-        asset_lock_data[plat_key] = {
-            "repo": repo,
-            "download_url": download_url,
-            "version": release.get("tag_name").lstrip("v"),
-            "checksum": checksum,
-            "executable": executable,
-            "destination": destination_str,
-            "extract": extract,
-        }
-        if not asset_lock_data:
-            print(f"No valid platforms found for asset '{asset_name}'.")
-            continue
-        lock_assets[asset_name] = asset_lock_data
-
-    config_hash = hashlib.sha256(json.dumps(config).encode()).hexdigest()
-    lock["config_checksum"] = config_hash
+    for asset_name, asset in assets_cfg.items():
+        logger.info(f"Processing asset '{asset_name}'...")
+        try:
+            lock_assets[asset_name] = compute_asset_lock(asset_name, asset)
+        except Exception as e:
+            raise InvException(
+                f"Error when processing asset '{asset_name}': {e}"
+            ) from e

     write_lock(lock)
-    print(f"Lock file '{LOCK_FILE}' updated.")
+    logger.info(f"Lock file '{LOCK_FILE}' updated.")


 def cmd_sync(args):
@@ -552,10 +653,7 @@ def cmd_sync(args):
             if executable:
                 chmod_exec(destination)
         except Exception as e:
-            print(
-                f"Error processing asset '{asset_name}' for platform '{target_plat}': {e}"
-            )
-            continue
+            raise InvException(f"Error processing asset '{asset_name}': {e}") from e

     print("Downloads completed")
@@ -573,20 +671,14 @@ def cmd_list(args):
         print(f"{asset_name} {asset['version']} {asset['download_url']}")


-def main():
-    parser = argparse.ArgumentParser(description="GitHub Release Assets Manager")
+def parse_args():
+    parser = argparse.ArgumentParser(description="Inventory")
     subparsers = parser.add_subparsers(dest="command", required=True)

-    lock_parser = subparsers.add_parser(
-        "lock", help="Update lock file from config (without downloading)"
-    )
+    lock_parser = subparsers.add_parser("lock", help="Update lock file from config")
     lock_parser.set_defaults(func=cmd_lock)

     sync_parser = subparsers.add_parser("sync", help="Sync assets as per lock file")
-    sync_parser.add_argument(
-        "--platform",
-        help="Platform name/arch (e.g., windows/amd64) to download assets for",
-    )
     sync_parser.add_argument(
         "assets",
         nargs="*",
@@ -595,14 +687,66 @@ def main():
     sync_parser.set_defaults(func=cmd_sync)

     list_parser = subparsers.add_parser("list", help="List assets for a platform")
-    list_parser.add_argument(
-        "--platform",
-        help="Platform name/arch (e.g., windows/amd64) to list assets for",
-    )
     list_parser.set_defaults(func=cmd_list)

+    # Add common arguments.
+    for subparser in subparsers.choices.values():
+        subparser.add_argument(
+            "-p",
+            "--platform",
+            help=(
+                "The platform to choose when determining which inventory assets to work"
+                " on. Examples: windows/amd64, linux/amd64, darwin/amd64, darwin/arm64"
+                " (default: %(default)s)"
+            ),
+        )
+        subparser.add_argument(
+            "-v",
+            "--verbose",
+            action="count",
+            default=0,
+            help="Enable verbose logging",
+        )
+        subparser.add_argument(
+            "-C",
+            "--directory",
+            help=(
+                "The working directory for the script (defaults to the current working"
+                " directory)"
+            ),
+        )
+
     args = parser.parse_args()
-    args.func(args)
+    return args
+
+
+def setup_logging(verbose=False):
+    """Simple way to set up logging.
+
+    Copied from: https://docs.python.org/3/howto/logging.html
+    """
+    # Specify the log level.
+    if not verbose:
+        lvl = logging.WARN
+    elif verbose == 1:
+        lvl = logging.INFO
+    else:
+        lvl = logging.DEBUG
+
+    logging.basicConfig(
+        level=lvl,
+        format="%(asctime)s - %(levelname)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+
+def main():
+    args = parse_args()
+    setup_logging(args.verbose)
+    with report_error(verbose=args.verbose, fail=False):
+        return args.func(args)
+

 if __name__ == "__main__":
-    main()
+    sys.exit(main())
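
For reference, here is a minimal sketch of the kind of TOML entry this script consumes. The field names (repo, version, platform, destination, executable, extract) are the ones compute_asset_lock() reads in this commit; the config file name, the repo, and the asset file names below are hypothetical:

# Hypothetical single-asset config, parsed the same way read_config() would
# parse it. Field names match compute_asset_lock(); the values are made up.
import toml

config = toml.loads("""
[asset.cosign]
repo = "sigstore/cosign"    # GitHub repo to query for releases
version = ">=2.0.0"         # semver range passed to get_latest_release()
destination = "bin/cosign"  # where 'sync' places the downloaded file
executable = true           # chmod +x after download

[asset.cosign.platform]     # maps platform key -> release asset file name
"linux/amd64" = "cosign-linux-amd64"
"darwin/arm64" = "cosign-darwin-arm64"
"windows/amd64" = "cosign-windows-amd64.exe"
""")

asset = config["asset"]["cosign"]
assert asset["platform"]["linux/amd64"] == "cosign-linux-amd64"

With such a config in place, the lock subcommand resolves the newest matching release and records per-platform download URLs and SHA-256 checksums in inventory.lock, and sync -p linux/amd64 downloads and verifies only the assets for that platform.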