depot/third_party/nixpkgs/pkgs/servers/apache-airflow/update-providers.py
Default email 9c6ee729d6 Project import generated by Copybara.
GitOrigin-RevId: 6cee3b5893090b0f5f0a06b4cf42ca4e60e5d222
2023-07-15 19:15:38 +02:00

227 lines
7.5 KiB
Python
Executable file

#! /usr/bin/env python3
from itertools import chain
import json
import logging
from pathlib import Path
import os
import re
import subprocess
import sys
from typing import Dict, List, Optional, Set, TextIO
from urllib.request import urlopen
from urllib.error import HTTPError
import yaml
# Attribute path (evaluated against the Nixpkgs root with `nix-env -A`) of
# the package set that provider requirements are resolved against.
PKG_SET: str = "apache-airflow.pythonPackages"

# If some requirements are matched by multiple or no Python packages, the
# following can be used to choose the correct one.
# Keys are requirement names (version constraints already removed); values
# are attribute names inside PKG_SET.
PKG_PREFERENCES: Dict[str, str] = {
    "dnspython": "dnspython",
    "elasticsearch-dsl": "elasticsearch-dsl",
    "google-api-python-client": "google-api-python-client",
    "psycopg2-binary": "psycopg2",
    "requests_toolbelt": "requests-toolbelt",
}

# Requirements missing from the airflow provider metadata.
# Keys are provider ids; values are attribute names inside PKG_SET that are
# appended to that provider's resolved dependencies.
EXTRA_REQS: Dict[str, List[str]] = {
    "sftp": ["pysftp"],
}
def get_version(nix_file=None) -> str:
    """Return the Airflow version pinned in the package's ``default.nix``.

    Args:
        nix_file: Path of the Nix expression to read. Defaults to the
            ``default.nix`` located in the same directory as this script.

    Returns:
        The version string, e.g. ``"2.6.3"`` or ``"2.7.0b1"``.

    Raises:
        ValueError: if no ``version = "...";`` assignment is found.
    """
    if nix_file is None:
        # os.path.join yields the relative "default.nix" when the script is
        # invoked without a directory component (dirname == ""); plain
        # concatenation with "/default.nix" produced the absolute path
        # "/default.nix" in that case.
        nix_file = os.path.join(os.path.dirname(sys.argv[0]), "default.nix")
    with open(nix_file, encoding="utf-8") as fh:
        # A version consists of digits, dots, and possibly a "b" (for beta)
        m = re.search(r'version = "([\d\.b]+)";', fh.read())
    if m is None:
        raise ValueError(f"Could not find a version string in {nix_file}")
    return m.group(1)
def get_file_from_github(version: str, path: str):
    """Download ``path`` from the apache/airflow GitHub repository and parse it.

    Args:
        version: git ref (used verbatim in the raw.githubusercontent.com URL).
        path: repository-relative path of the file to fetch.

    Returns:
        The document parsed with ``yaml.safe_load``.

    Errors raised by ``urlopen`` (e.g. ``HTTPError``) propagate to the caller.
    """
    url = f"https://raw.githubusercontent.com/apache/airflow/{version}/{path}"
    with urlopen(url) as response:
        document = yaml.safe_load(response)
    return document
def repository_root() -> Path:
    """Return the root directory of the Nixpkgs checkout containing this script.

    The script lives in ``pkgs/servers/apache-airflow/``, three directory
    levels below the repository root, so three ``..`` components are needed.
    The previous four went one level above the checkout.
    """
    return Path(os.path.dirname(sys.argv[0])) / "../../.."
def dump_packages() -> Dict[str, Dict[str, str]]:
    """Return ``nix-env --json`` metadata for every package in ``PKG_SET``.

    ``nix-env`` is evaluated against the Nixpkgs checkout returned by
    ``repository_root()``, with ``allowAliases = false`` in the config. The
    result maps attribute paths to package metadata; ``name_to_attr_path``
    matches against each entry's ``"name"`` field.

    Raises:
        subprocess.CalledProcessError: if ``nix-env`` exits non-zero.
    """
    nix_env_cmd = [
        "nix-env",
        "-f", repository_root(),
        "-qa",
        "-A", PKG_SET,
        "--arg", "config", "{ allowAliases = false; }",
        "--json",
    ]
    raw_json = subprocess.check_output(nix_env_cmd)
    return json.loads(raw_json)
def remove_version_constraint(req: str) -> str:
    """Return the bare distribution name of a PEP 508 requirement string.

    Everything after the name is dropped: version specifiers (including
    ``!=``, which the old character-class regex left a stray ``!`` for),
    extras such as ``[gcp]``, environment markers after ``;`` and surrounding
    whitespace. A string without a recognisable leading name is returned
    with surrounding whitespace stripped.

    >>> remove_version_constraint("PyGithub!=1.58")
    'PyGithub'
    """
    # PEP 508 names start with a letter or digit and continue with letters,
    # digits, ".", "_" or "-"; the first other character ends the name.
    match = re.match(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)", req)
    return match.group(1) if match else req.strip()
def name_to_attr_path(req: str, packages: Dict[str, Dict[str, str]]) -> Optional[str]:
    """Map a Python requirement name to a Nixpkgs attribute path.

    Args:
        req: requirement name with any version constraint already removed.
        packages: ``nix-env --json`` dump mapping attribute paths to package
            metadata; only each entry's ``"name"`` field is read.

    Returns:
        The full attribute path (prefixed with ``PKG_SET``) of the single
        matching derivation, or ``None`` when nothing matches.
        ``PKG_PREFERENCES`` overrides take precedence over matching.

    Raises:
        ValueError: if more than one derivation matches ``req``.
    """
    if req in PKG_PREFERENCES:
        return f"{PKG_SET}.{PKG_PREFERENCES[req]}"
    names = [req]
    # E.g. python-mpd2 is actually called python3.6-mpd2
    # instead of python-3.6-python-mpd2 inside Nixpkgs
    if req.startswith(("python-", "python_")):
        names.append(req[len("python-"):])
    attr_paths = []
    for name in names:
        # Treat "-", "_" and "." as equivalent separators (as PEP 503 name
        # normalisation does) and escape every other character, so that
        # e.g. the "." in "zope.interface" is no longer a regex wildcard.
        name_re = "[-_.]".join(re.escape(part) for part in re.split(r"[-_.]", name))
        # python(major).(minor)-(pname)-(version or unstable-date)
        # we need the version qualifier, or we'll have multiple matches
        # (e.g. pyserial and pyserial-asyncio when looking for pyserial)
        pattern = re.compile(rf"^python\d+\.\d+-{name_re}-(?:\d|unstable-.*)", re.I)
        attr_paths.extend(
            attr_path
            for attr_path, package in packages.items()
            if pattern.match(package["name"])
        )
    # Only one derivation may match; an explicit raise (unlike assert) is
    # not stripped when Python runs with -O.
    if len(attr_paths) > 1:
        raise ValueError(f"{req} matches more than one derivation: {attr_paths}")
    return attr_paths[0] if attr_paths else None
def provider_reqs_to_attr_paths(reqs: List, packages: Dict) -> List:
    """Translate a provider's pip requirements into attribute names in ``PKG_SET``.

    Version constraints are stripped first. Requirements whose bare name
    starts with ``apache-airflow`` (Airflow itself) are dropped. Requirements
    that ``name_to_attr_path`` cannot resolve are skipped with a warning.

    Returns:
        Attribute names relative to ``PKG_SET`` (its ``"<PKG_SET>."`` prefix
        removed), in the order of ``reqs``.
    """
    bare_names = [remove_version_constraint(r) for r in reqs]
    wanted = [n for n in bare_names if not n.startswith("apache-airflow")]
    prefix_len = len(PKG_SET + ".")
    resolved = []
    for bare_name in wanted:
        full_path = name_to_attr_path(bare_name, packages)
        if full_path is None:
            # If we can't find it, we just skip and warn the user
            logging.warning("Could not find package attr for %s", bare_name)
            continue
        resolved.append(full_path[prefix_len:])
    return resolved
def get_cross_provider_reqs(
    provider: str,
    provider_reqs: Dict,
    cross_provider_deps: Dict,
    seen: Optional[List] = None,
) -> Set:
    """Return ``provider``'s requirements plus those of its cross-provider deps.

    Args:
        provider: provider id to resolve.
        provider_reqs: provider id -> direct requirements.
        cross_provider_deps: provider id -> ids of providers it depends on.
        seen: providers already on the current recursion path. A dependency
            found here is not expanded again, which breaks the circular
            cross-provider dependencies that exist in practice.

    Returns:
        The union of the direct requirements of ``provider`` and of every
        provider reachable from it.
    """
    on_path = seen if seen else []
    collected = set(provider_reqs[provider])
    child_path = on_path + [provider]
    for dep in cross_provider_deps[provider]:
        if dep in on_path:
            continue
        collected |= get_cross_provider_reqs(
            dep, provider_reqs, cross_provider_deps, child_path
        )
    return collected
def get_provider_reqs(version: str, packages: Dict) -> Dict:
    """Return every provider's transitive requirements as attribute-name sets.

    Reads ``generated/provider_dependencies.json`` from the Airflow
    repository at ``version``. Each provider's ``deps`` are resolved against
    ``packages`` and extended with ``EXTRA_REQS``. Requirements of the
    providers listed in ``cross-providers-deps`` are then folded in, with
    ``common.sql`` excluded from that cross-provider graph.

    Returns:
        provider id -> set of attribute names inside ``PKG_SET``.
    """
    dependency_data = get_file_from_github(
        version, "generated/provider_dependencies.json"
    )
    direct_reqs = {}
    cross_deps = {}
    for provider, info in dependency_data.items():
        resolved = list(provider_reqs_to_attr_paths(info["deps"], packages))
        direct_reqs[provider] = resolved + EXTRA_REQS.get(provider, [])
        cross_deps[provider] = [
            other for other in info["cross-providers-deps"] if other != "common.sql"
        ]
    # Add transitive cross-provider reqs
    return {
        provider: get_cross_provider_reqs(provider, direct_reqs, cross_deps)
        for provider in direct_reqs
    }
def get_provider_yaml(version: str, provider: str) -> Dict:
    """Fetch and parse ``provider.yaml`` for ``provider`` at ``version``.

    Dotted provider ids map to nested directories, e.g. ``microsoft.azure``
    -> ``airflow/providers/microsoft/azure/provider.yaml``. An ``HTTPError``
    is logged as a warning and an empty dict is returned instead.
    """
    provider_dir = "/".join(provider.split("."))
    path = f"airflow/providers/{provider_dir}/provider.yaml"
    try:
        parsed = get_file_from_github(version, path)
    except HTTPError:
        logging.warning("Couldn't get provider yaml for %s", provider)
        parsed = {}
    return parsed
def get_provider_imports(version: str, providers) -> Dict:
    """Collect the Python modules declared by each provider's hooks and operators.

    For every provider id in ``providers``, the ``python-modules`` entries of
    the ``hooks`` section and then of the ``operators`` section of its
    ``provider.yaml`` are concatenated. Missing sections contribute nothing.

    Returns:
        provider id -> list of module names, hooks first, then operators.
    """
    result = {}
    for provider in providers:
        yaml_data = get_provider_yaml(version, provider)
        modules: List[str] = []
        for section in ("hooks", "operators"):
            if section not in yaml_data:
                continue
            for entry in yaml_data[section]:
                modules.extend(entry["python-modules"])
        result[provider] = modules
    return result
def to_nix_expr(provider_reqs: Dict, provider_imports: Dict, fh: TextIO) -> None:
    """Write the generated ``providers.nix`` attribute set to ``fh``.

    Each provider id becomes an attribute (dots replaced by underscores)
    with a ``deps`` list and an ``imports`` list of double-quoted strings.
    Both lists are sorted so that the output is stable between runs.

    Args:
        provider_reqs: provider id -> iterable of attribute names.
        provider_imports: provider id -> iterable of strings to list under
            ``imports``; must contain every key of ``provider_reqs``.
        fh: text stream receiving the Nix expression.
    """

    def nix_string_list(values) -> str:
        # Sort the already-quoted strings, matching the emitted layout.
        return " ".join(sorted(f'"{value}"' for value in values))

    fh.write("# Warning: generated by update-providers.py, do not update manually\n")
    fh.write("{\n")
    for provider in provider_reqs:
        provider_name = provider.replace(".", "_")
        fh.write(f" {provider_name} = {{\n")
        fh.write(" deps = [ " + nix_string_list(provider_reqs[provider]) + " ];\n")
        fh.write(
            " imports = [ " + nix_string_list(provider_imports[provider]) + " ];\n"
        )
        fh.write(" };\n")
    fh.write("}\n")
def main() -> None:
    """Regenerate ``providers.nix`` for the Airflow version in ``default.nix``.

    The output is written into the directory containing this script
    (``os.path.dirname(sys.argv[0])``), not the current working directory.
    Running the script from elsewhere (e.g. the Nixpkgs root) therefore no
    longer drops ``providers.nix`` in the wrong place.
    """
    logging.basicConfig(level=logging.INFO)
    version = get_version()
    packages = dump_packages()
    logging.info("Generating providers.nix for version %s", version)
    provider_reqs = get_provider_reqs(version, packages)
    provider_imports = get_provider_imports(version, provider_reqs.keys())
    output_path = os.path.join(os.path.dirname(sys.argv[0]), "providers.nix")
    with open(output_path, "w", encoding="utf-8") as fh:
        to_nix_expr(provider_reqs, provider_imports, fh)
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()