#!/usr/bin/env python
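"""Update helper for the Azure CLI extensions packaged in nixpkgs.

Downloads the upstream extension index, filters it against a given
azure-cli version, rewrites extensions-generated.json and optionally
commits one change per extension to git.
"""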
import argparse
import base64
import datetime
import json
import logging
import os
import sys
from dataclasses import asdict, dataclass, replace
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from urllib.request import Request, urlopen

import git
from packaging.version import Version, parse

INDEX_URL = "https://azcliextensionsync.blob.core.windows.net/index1/index.json"

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class Ext:
pname: str
version: Version
url: str
hash: str
description: str


def _read_cached_index(path: Path) -> Tuple[datetime.datetime, str]:
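    """Return the cache timestamp and the raw JSON text of the cached index."""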
with open(path, "r") as f:
data = f.read()
j = json.loads(data)
    cache_date_str = j.get("cache_date")
if cache_date_str:
cache_date = datetime.datetime.fromisoformat(cache_date_str)
else:
cache_date = datetime.datetime.min
return cache_date, data


def _write_index_to_cache(data: Any, path: Path) -> None:
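    """Stamp the downloaded index with the current time and write it to the cache."""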
j = json.loads(data)
j["cache_date"] = datetime.datetime.now().isoformat()
with open(path, "w") as f:
json.dump(j, f, indent=2)


def _fetch_remote_index() -> bytes:
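    """Download the raw extension index from INDEX_URL."""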
r = Request(INDEX_URL)
with urlopen(r) as resp:
return resp.read()


def get_extension_index(cache_dir: Path) -> Dict[str, Any]:
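    """Return the parsed index, refreshing the cache if missing or older than a day."""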
index_file = cache_dir / "index.json"
os.makedirs(cache_dir, exist_ok=True)
try:
index_cache_date, index_data = _read_cached_index(index_file)
except FileNotFoundError:
logger.info("index has not been cached, downloading from source")
logger.info("creating index cache in %s", index_file)
_write_index_to_cache(_fetch_remote_index(), index_file)
return get_extension_index(cache_dir)
if (
index_cache_date
and datetime.datetime.now() - index_cache_date > datetime.timedelta(days=1)
):
logger.info(
"cache is outdated (%s), refreshing",
datetime.datetime.now() - index_cache_date,
)
_write_index_to_cache(_fetch_remote_index(), index_file)
return get_extension_index(cache_dir)
logger.info("using index cache from %s", index_file)
return json.loads(index_data)


def _read_extension_set(extensions_generated: Path) -> Set[Ext]:
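    """Load extensions-generated.json into a set of Ext with parsed versions."""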
with open(extensions_generated, "r") as f:
data = f.read()
parsed_exts = {Ext(**json_ext) for _pname, json_ext in json.loads(data).items()}
parsed_exts_with_ver = set()
for ext in parsed_exts:
ext2 = replace(ext, version=parse(ext.version))
parsed_exts_with_ver.add(ext2)
return parsed_exts_with_ver


def _write_extension_set(extensions_generated: Path, extensions: Set[Ext]) -> None:
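    """Write the extension set back to JSON, keyed and sorted by pname."""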
set_without_ver = {replace(ext, version=str(ext.version)) for ext in extensions}
ls = list(set_without_ver)
ls.sort(key=lambda e: e.pname)
with open(extensions_generated, "w") as f:
json.dump({ext.pname: asdict(ext) for ext in ls}, f, indent=2)
f.write("\n")


def _convert_hash_digest_from_hex_to_b64_sri(s: str) -> str:
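    """Convert a hex sha256 digest to an SRI "sha256-<base64>" string."""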
try:
b = bytes.fromhex(s)
except ValueError as err:
logger.error("not a hex value: %s", str(err))
raise err
return f"sha256-{base64.b64encode(b).decode('utf-8')}"


def _commit(repo: git.Repo, message: str, files: List[Path], actor: git.Actor) -> None:
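    """Stage the given files and commit them; skip the commit if nothing changed."""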
repo.index.add([str(f.resolve()) for f in files])
if repo.index.diff("HEAD"):
logger.info(f'committing to nixpkgs "{message}"')
repo.index.commit(message, author=actor, committer=actor)
else:
logger.warning("no changes in working tree to commit")


def _filter_invalid(o: Dict[str, Any]) -> bool:
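    """Return False for index entries missing the fields needed to package them."""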
if "metadata" not in o:
logger.warning("extension without metadata")
return False
metadata = o["metadata"]
if "name" not in metadata:
logger.warning("extension without name")
return False
if "version" not in metadata:
logger.warning(f"{metadata['name']} without version")
return False
if "azext.minCliCoreVersion" not in metadata:
logger.warning(
f"{metadata['name']} {metadata['version']} does not have azext.minCliCoreVersion"
)
return False
if "summary" not in metadata:
logger.info(f"{metadata['name']} {metadata['version']} without summary")
return False
if "downloadUrl" not in o:
logger.warning(f"{metadata['name']} {metadata['version']} without downloadUrl")
return False
if "sha256Digest" not in o:
logger.warning(f"{metadata['name']} {metadata['version']} without sha256Digest")
return False
return True


def _filter_compatible(o: Dict[str, Any], cli_version: Version) -> bool:
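    """Return True if cli_version satisfies the extension's azext.minCliCoreVersion."""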
minCliVersion = parse(o["metadata"]["azext.minCliCoreVersion"])
return cli_version >= minCliVersion


def _transform_dict_to_obj(o: Dict[str, Any]) -> Ext:
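    """Build an Ext from a raw index entry."""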
m = o["metadata"]
return Ext(
pname=m["name"],
version=parse(m["version"]),
url=o["downloadUrl"],
hash=_convert_hash_digest_from_hex_to_b64_sri(o["sha256Digest"]),
description=m["summary"].rstrip("."),
)


def _get_latest_version(
    versions: Iterable[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
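    """Return the entry with the highest version, or None if there is none."""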
return max(versions, key=lambda e: parse(e["metadata"]["version"]), default=None)


def processExtension(
    extVersions: List[Dict[str, Any]],
cli_version: Version,
ext_name: Optional[str] = None,
requirements: bool = False,
) -> Optional[Ext]:
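    """Return the latest valid and compatible version of an extension as an Ext.

    Entries that declare run_requires are skipped unless requirements is set;
    if ext_name is given, any other extension is skipped. Returns None if no
    version qualifies.
    """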
versions = filter(_filter_invalid, extVersions)
versions = filter(lambda v: _filter_compatible(v, cli_version), versions)
latest = _get_latest_version(versions)
if not latest:
return None
if ext_name and latest["metadata"]["name"] != ext_name:
return None
if not requirements and "run_requires" in latest["metadata"]:
return None
return _transform_dict_to_obj(latest)


def _diff_sets(
set_local: Set[Ext], set_remote: Set[Ext]
) -> Tuple[Set[Ext], Set[Ext], Set[Tuple[Ext, Ext]]]:
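    """Split the sets by pname into (local only, remote only, present in both)."""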
local_exts = {ext.pname: ext for ext in set_local}
remote_exts = {ext.pname: ext for ext in set_remote}
only_local = local_exts.keys() - remote_exts.keys()
only_remote = remote_exts.keys() - local_exts.keys()
both = remote_exts.keys() & local_exts.keys()
return (
{local_exts[pname] for pname in only_local},
{remote_exts[pname] for pname in only_remote},
{(local_exts[pname], remote_exts[pname]) for pname in both},
)


def _filter_updated(e: Tuple[Ext, Ext]) -> bool:
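    """Return True when the remote entry differs from the local one."""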
prev, new = e
return prev != new


def main() -> None:
sh = logging.StreamHandler(sys.stderr)
sh.setFormatter(
logging.Formatter(
"[%(asctime)s] [%(levelname)8s] --- %(message)s (%(filename)s:%(lineno)s)",
"%Y-%m-%d %H:%M:%S",
)
)
logging.basicConfig(level=logging.INFO, handlers=[sh])
parser = argparse.ArgumentParser(
prog="azure-cli.extensions-tool",
description="Script to handle Azure CLI extension updates",
)
    parser.add_argument(
        "--cli-version", type=str, required=True, help="version of azure-cli"
    )
parser.add_argument("--extension", type=str, help="name of extension to query")
parser.add_argument(
"--cache-dir",
type=Path,
help="path where to cache the extension index",
default=Path(os.getenv("XDG_CACHE_HOME", Path.home() / ".cache"))
/ "azure-cli-extensions-tool",
)
parser.add_argument(
"--requirements",
action=argparse.BooleanOptionalAction,
help="whether to list extensions that have requirements",
)
parser.add_argument(
"--commit",
action=argparse.BooleanOptionalAction,
help="whether to commit changes to git",
)
args = parser.parse_args()
repo = git.Repo(Path(".").resolve(), search_parent_directories=True)
# Workaround for https://github.com/gitpython-developers/GitPython/issues/1923
author = repo.config_reader().get_value("user", "name").lstrip('"').rstrip('"')
email = repo.config_reader().get_value("user", "email").lstrip('"').rstrip('"')
actor = git.Actor(author, email)
index = get_extension_index(args.cache_dir)
assert index["formatVersion"] == "1" # only support formatVersion 1
extensions_remote = index["extensions"]
cli_version = parse(args.cli_version)
extensions_remote_filtered = set()
for _ext_name, extension in extensions_remote.items():
        extension = processExtension(
            extension, cli_version, args.extension, args.requirements
        )
if extension:
extensions_remote_filtered.add(extension)
extension_file = (
Path(repo.working_dir) / "pkgs/by-name/az/azure-cli/extensions-generated.json"
)
extensions_local = _read_extension_set(extension_file)
    if args.extension:
        extensions_local_filtered = filter(
            lambda ext: args.extension == ext.pname, extensions_local
        )
    else:
        extensions_local_filtered = extensions_local
removed, init, updated = _diff_sets(
extensions_local_filtered, extensions_remote_filtered
)
updated = set(filter(_filter_updated, updated))
logger.info("initialized extensions:")
for ext in init:
logger.info(f" {ext.pname} {ext.version}")
logger.info("removed extensions:")
for ext in removed:
logger.info(f" {ext.pname} {ext.version}")
logger.info("updated extensions:")
for prev, new in updated:
logger.info(f" {prev.pname} {prev.version} -> {new.version}")
for ext in init:
extensions_local.add(ext)
commit_msg = f"azure-cli-extensions.{ext.pname}: init at {ext.version}"
_write_extension_set(extension_file, extensions_local)
if args.commit:
_commit(repo, commit_msg, [extension_file], actor)
for prev, new in updated:
extensions_local.remove(prev)
extensions_local.add(new)
commit_msg = (
f"azure-cli-extensions.{prev.pname}: {prev.version} -> {new.version}"
)
_write_extension_set(extension_file, extensions_local)
if args.commit:
_commit(repo, commit_msg, [extension_file], actor)
for ext in removed:
extensions_local.remove(ext)
# TODO: Add additional check why this is removed
# TODO: Add an alias to extensions manual?
commit_msg = f"azure-cli-extensions.{ext.pname}: remove"
_write_extension_set(extension_file, extensions_local)
if args.commit:
_commit(repo, commit_msg, [extension_file], actor)


if __name__ == "__main__":
main()