#!/usr/bin/env python import argparse import base64 import datetime import json import logging import os import sys from dataclasses import asdict, dataclass, replace from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from urllib.request import Request, urlopen import git from packaging.version import Version, parse INDEX_URL = "https://azcliextensionsync.blob.core.windows.net/index1/index.json" logger = logging.getLogger(__name__) @dataclass(frozen=True) class Ext: pname: str version: Version url: str hash: str description: str def _read_cached_index(path: Path) -> Tuple[datetime.datetime, Any]: with open(path, "r") as f: data = f.read() j = json.loads(data) cache_date_str = j["cache_date"] if cache_date_str: cache_date = datetime.datetime.fromisoformat(cache_date_str) else: cache_date = datetime.datetime.min return cache_date, data def _write_index_to_cache(data: Any, path: Path): j = json.loads(data) j["cache_date"] = datetime.datetime.now().isoformat() with open(path, "w") as f: json.dump(j, f, indent=2) def _fetch_remote_index(): r = Request(INDEX_URL) with urlopen(r) as resp: return resp.read() def get_extension_index(cache_dir: Path) -> Set[Ext]: index_file = cache_dir / "index.json" os.makedirs(cache_dir, exist_ok=True) try: index_cache_date, index_data = _read_cached_index(index_file) except FileNotFoundError: logger.info("index has not been cached, downloading from source") logger.info("creating index cache in %s", index_file) _write_index_to_cache(_fetch_remote_index(), index_file) return get_extension_index(cache_dir) if ( index_cache_date and datetime.datetime.now() - index_cache_date > datetime.timedelta(days=1) ): logger.info( "cache is outdated (%s), refreshing", datetime.datetime.now() - index_cache_date, ) _write_index_to_cache(_fetch_remote_index(), index_file) return get_extension_index(cache_dir) logger.info("using index cache from %s", index_file) return json.loads(index_data) def _read_extension_set(extensions_generated: Path) -> Set[Ext]: with open(extensions_generated, "r") as f: data = f.read() parsed_exts = {Ext(**json_ext) for _pname, json_ext in json.loads(data).items()} parsed_exts_with_ver = set() for ext in parsed_exts: ext2 = replace(ext, version=parse(ext.version)) parsed_exts_with_ver.add(ext2) return parsed_exts_with_ver def _write_extension_set(extensions_generated: Path, extensions: Set[Ext]) -> None: set_without_ver = {replace(ext, version=str(ext.version)) for ext in extensions} ls = list(set_without_ver) ls.sort(key=lambda e: e.pname) with open(extensions_generated, "w") as f: json.dump({ext.pname: asdict(ext) for ext in ls}, f, indent=2) f.write("\n") def _convert_hash_digest_from_hex_to_b64_sri(s: str) -> str: try: b = bytes.fromhex(s) except ValueError as err: logger.error("not a hex value: %s", str(err)) raise err return f"sha256-{base64.b64encode(b).decode('utf-8')}" def _commit(repo: git.Repo, message: str, files: List[Path], actor: git.Actor) -> None: repo.index.add([str(f.resolve()) for f in files]) if repo.index.diff("HEAD"): logger.info(f'committing to nixpkgs "{message}"') repo.index.commit(message, author=actor, committer=actor) else: logger.warning("no changes in working tree to commit") def _filter_invalid(o: Dict[str, Any]) -> bool: if "metadata" not in o: logger.warning("extension without metadata") return False metadata = o["metadata"] if "name" not in metadata: logger.warning("extension without name") return False if "version" not in metadata: logger.warning(f"{metadata['name']} without version") return False if "azext.minCliCoreVersion" not in metadata: logger.warning( f"{metadata['name']} {metadata['version']} does not have azext.minCliCoreVersion" ) return False if "summary" not in metadata: logger.info(f"{metadata['name']} {metadata['version']} without summary") return False if "downloadUrl" not in o: logger.warning(f"{metadata['name']} {metadata['version']} without downloadUrl") return False if "sha256Digest" not in o: logger.warning(f"{metadata['name']} {metadata['version']} without sha256Digest") return False return True def _filter_compatible(o: Dict[str, Any], cli_version: Version) -> bool: minCliVersion = parse(o["metadata"]["azext.minCliCoreVersion"]) return cli_version >= minCliVersion def _transform_dict_to_obj(o: Dict[str, Any]) -> Ext: m = o["metadata"] return Ext( pname=m["name"], version=parse(m["version"]), url=o["downloadUrl"], hash=_convert_hash_digest_from_hex_to_b64_sri(o["sha256Digest"]), description=m["summary"].rstrip("."), ) def _get_latest_version(versions: dict) -> dict: return max(versions, key=lambda e: parse(e["metadata"]["version"]), default=None) def processExtension( extVersions: dict, cli_version: Version, ext_name: Optional[str] = None, requirements: bool = False, ) -> Optional[Ext]: versions = filter(_filter_invalid, extVersions) versions = filter(lambda v: _filter_compatible(v, cli_version), versions) latest = _get_latest_version(versions) if not latest: return None if ext_name and latest["metadata"]["name"] != ext_name: return None if not requirements and "run_requires" in latest["metadata"]: return None return _transform_dict_to_obj(latest) def _diff_sets( set_local: Set[Ext], set_remote: Set[Ext] ) -> Tuple[Set[Ext], Set[Ext], Set[Tuple[Ext, Ext]]]: local_exts = {ext.pname: ext for ext in set_local} remote_exts = {ext.pname: ext for ext in set_remote} only_local = local_exts.keys() - remote_exts.keys() only_remote = remote_exts.keys() - local_exts.keys() both = remote_exts.keys() & local_exts.keys() return ( {local_exts[pname] for pname in only_local}, {remote_exts[pname] for pname in only_remote}, {(local_exts[pname], remote_exts[pname]) for pname in both}, ) def _filter_updated(e: Tuple[Ext, Ext]) -> bool: prev, new = e return prev != new def main() -> None: sh = logging.StreamHandler(sys.stderr) sh.setFormatter( logging.Formatter( "[%(asctime)s] [%(levelname)8s] --- %(message)s (%(filename)s:%(lineno)s)", "%Y-%m-%d %H:%M:%S", ) ) logging.basicConfig(level=logging.INFO, handlers=[sh]) parser = argparse.ArgumentParser( prog="azure-cli.extensions-tool", description="Script to handle Azure CLI extension updates", ) parser.add_argument( "--cli-version", type=str, help="version of azure-cli (required)" ) parser.add_argument("--extension", type=str, help="name of extension to query") parser.add_argument( "--cache-dir", type=Path, help="path where to cache the extension index", default=Path(os.getenv("XDG_CACHE_HOME", Path.home() / ".cache")) / "azure-cli-extensions-tool", ) parser.add_argument( "--requirements", action=argparse.BooleanOptionalAction, help="whether to list extensions that have requirements", ) parser.add_argument( "--commit", action=argparse.BooleanOptionalAction, help="whether to commit changes to git", ) args = parser.parse_args() repo = git.Repo(Path(".").resolve(), search_parent_directories=True) # Workaround for https://github.com/gitpython-developers/GitPython/issues/1923 author = repo.config_reader().get_value("user", "name").lstrip('"').rstrip('"') email = repo.config_reader().get_value("user", "email").lstrip('"').rstrip('"') actor = git.Actor(author, email) index = get_extension_index(args.cache_dir) assert index["formatVersion"] == "1" # only support formatVersion 1 extensions_remote = index["extensions"] cli_version = parse(args.cli_version) extensions_remote_filtered = set() for _ext_name, extension in extensions_remote.items(): extension = processExtension(extension, cli_version, args.extension) if extension: extensions_remote_filtered.add(extension) extension_file = ( Path(repo.working_dir) / "pkgs/by-name/az/azure-cli/extensions-generated.json" ) extensions_local = _read_extension_set(extension_file) extensions_local_filtered = set() if args.extension: extensions_local_filtered = filter( lambda ext: args.extension == ext.pname, extensions_local ) else: extensions_local_filtered = extensions_local removed, init, updated = _diff_sets( extensions_local_filtered, extensions_remote_filtered ) updated = set(filter(_filter_updated, updated)) logger.info("initialized extensions:") for ext in init: logger.info(f" {ext.pname} {ext.version}") logger.info("removed extensions:") for ext in removed: logger.info(f" {ext.pname} {ext.version}") logger.info("updated extensions:") for prev, new in updated: logger.info(f" {prev.pname} {prev.version} -> {new.version}") for ext in init: extensions_local.add(ext) commit_msg = f"azure-cli-extensions.{ext.pname}: init at {ext.version}" _write_extension_set(extension_file, extensions_local) if args.commit: _commit(repo, commit_msg, [extension_file], actor) for prev, new in updated: extensions_local.remove(prev) extensions_local.add(new) commit_msg = ( f"azure-cli-extensions.{prev.pname}: {prev.version} -> {new.version}" ) _write_extension_set(extension_file, extensions_local) if args.commit: _commit(repo, commit_msg, [extension_file], actor) for ext in removed: extensions_local.remove(ext) # TODO: Add additional check why this is removed # TODO: Add an alias to extensions manual? commit_msg = f"azure-cli-extensions.{ext.pname}: remove" _write_extension_set(extension_file, extensions_local) if args.commit: _commit(repo, commit_msg, [extension_file], actor) if __name__ == "__main__": main()