5e7c2d6cef
GitOrigin-RevId: f99e5f03cc0aa231ab5950a15ed02afec45ed51a
228 lines
7.5 KiB
Python
Executable file
228 lines
7.5 KiB
Python
Executable file
#! /usr/bin/env python3
|
|
|
|
from itertools import chain
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
from typing import Dict, List, Optional, Set, TextIO
|
|
from urllib.request import urlopen
|
|
from urllib.error import HTTPError
|
|
import yaml
|
|
|
|
# Attribute path of the package set to query: passed to `nix-env -A` and used
# as the prefix of the attribute paths this script resolves.
PKG_SET = "apache-airflow.pythonPackages"
|
|
|
|
# Explicit requirement name -> attribute name overrides. A requirement listed
# here is resolved directly to "<PKG_SET>.<value>" without searching the
# package dump, which settles requirements that would otherwise match several
# Python packages or none at all.
PKG_PREFERENCES = {
    "dnspython": "dnspython",
    "elasticsearch-dsl": "elasticsearch-dsl",
    "google-api-python-client": "google-api-python-client",
    "protobuf": "protobuf",
    "psycopg2-binary": "psycopg2",
    "requests_toolbelt": "requests-toolbelt",
}
|
|
|
|
# Additional requirements, keyed by provider name, that are appended to what
# the airflow provider metadata declares because the metadata omits them.
EXTRA_REQS = {
    "sftp": ["pysftp"],
}
|
|
|
|
|
|
def get_version() -> str:
    """Return the version string declared in the ``default.nix`` beside this script.

    The file is looked up in the directory component of ``sys.argv[0]``.

    Returns:
        The text captured from the first ``version = "...";`` assignment.

    Raises:
        OSError: if ``default.nix`` cannot be opened.
        ValueError: if the file contains no matching ``version`` assignment.
    """
    # os.path.join copes with an empty dirname (script started by bare file
    # name): it yields a path relative to the current directory, whereas plain
    # string concatenation would produce the absolute path "/default.nix".
    nix_file = os.path.join(os.path.dirname(sys.argv[0]), "default.nix")
    with open(nix_file) as fh:
        # A version consists of digits, dots, and possibly a "b" (for beta)
        m = re.search(r'version = "([\d\.b]+)";', fh.read())
    if m is None:
        raise ValueError(f"Could not find a version assignment in {nix_file}")
    return m.group(1)
|
|
|
|
|
|
def get_file_from_github(version: str, path: str):
    """Download ``path`` from the apache/airflow repository at ``version`` and parse it.

    The raw file is fetched from raw.githubusercontent.com and the open
    response is handed to ``yaml.safe_load``; the parsed document is returned.
    Errors raised by ``urlopen`` propagate to the caller.
    """
    url = f"https://raw.githubusercontent.com/apache/airflow/{version}/{path}"
    with urlopen(url) as response:
        document = yaml.safe_load(response)
    return document
|
|
|
|
|
|
def repository_root() -> Path:
    """Return the path three levels above the directory containing this script.

    The result is not normalised: it is the directory component of
    ``sys.argv[0]`` with a literal ``../../..`` appended.
    """
    script_dir = os.path.dirname(sys.argv[0])
    return Path(script_dir, "../../..")
|
|
|
|
|
|
def dump_packages() -> Dict[str, Dict[str, str]]:
    """Return the JSON listing of every package in ``PKG_SET`` as reported by ``nix-env``.

    ``nix-env -qa --json`` is run against the repository root with aliases
    disabled, and its output is decoded with ``json.loads``. Callers in this
    file treat the keys as attribute paths and read each entry's ``"name"``.

    Raises:
        subprocess.CalledProcessError: if ``nix-env`` exits with a non-zero status.
    """
    command = [
        "nix-env",
        "-f",
        repository_root(),
        "-qa",
        "-A",
        PKG_SET,
        "--arg",
        "config",
        "{ allowAliases = false; }",
        "--json",
    ]
    raw_listing = subprocess.check_output(command)
    return json.loads(raw_listing)
|
|
|
|
|
|
def remove_version_constraint(req: str) -> str:
    """Return the bare distribution name of a requirement specifier.

    Surrounding whitespace is ignored, and everything from the first version
    operator character (``=``, ``>``, ``<``, ``~``, ``!``), extras bracket
    (``[``), parenthesis (``(``), environment-marker separator (``;``) or
    whitespace onwards is dropped. For example ``"PyHive[hive]>=0.6.0"``
    becomes ``"PyHive"`` and ``"foo!=1.0"`` becomes ``"foo"``.
    """
    # DOTALL so that a stray newline inside the specifier cannot stop ".*"
    # short of the end of the string.
    return re.sub(r"[\s=><~!;\[(].*$", "", req.strip(), flags=re.DOTALL)
|
|
|
|
|
|
def name_to_attr_path(req: str, packages: Dict[str, Dict[str, str]]) -> Optional[str]:
    """Resolve a requirement name to the attribute path of the matching package.

    A requirement listed in ``PKG_PREFERENCES`` is resolved directly from that
    table. Otherwise every entry of ``packages`` is checked and the attribute
    path whose ``"name"`` has the form ``python<major>.<minor>-<req>-<version>``
    is returned. Matching is case-insensitive and treats ``-``, ``_`` and ``.``
    in the requirement name as interchangeable.

    Args:
        req: Requirement name without version constraints.
        packages: Mapping of attribute path to package info; each value must
            have a ``"name"`` entry.

    Returns:
        The matching attribute path, or ``None`` when nothing matches.

    Raises:
        ValueError: if more than one attribute path matches ``req``.
    """
    if req in PKG_PREFERENCES:
        return f"{PKG_SET}.{PKG_PREFERENCES[req]}"

    names = [req]
    # E.g. python-mpd2 is actually called python3.6-mpd2
    # instead of python-3.6-python-mpd2 inside Nixpkgs
    if req.startswith(("python-", "python_")):
        names.append(req[len("python-") :])

    attr_paths = []
    for name in names:
        # Escape every component so the requirement name is matched literally
        # (it is data, not a regex), while separators stay interchangeable.
        name_pattern = "[-_.]".join(
            re.escape(part) for part in re.split(r"[-_.]", name)
        )
        # python(major).(minor)-(pname)-(version or unstable-date)
        # we need the version qualifier, or we'll have multiple matches
        # (e.g. pyserial and pyserial-asyncio when looking for pyserial)
        pattern = re.compile(
            rf"^python\d+\.\d+-{name_pattern}-(?:\d|unstable-.*)", re.I
        )
        attr_paths.extend(
            attr_path
            for attr_path, package in packages.items()
            if pattern.match(package["name"])
        )

    # Let's hope there's only one derivation with a matching name. This is an
    # explicit raise rather than an assert so the check survives `python -O`.
    if len(attr_paths) > 1:
        raise ValueError(f"{req} matches more than one derivation: {attr_paths}")
    return attr_paths[0] if attr_paths else None
|
|
|
|
|
|
def provider_reqs_to_attr_paths(reqs: List, packages: Dict) -> List:
    """Translate a provider's requirement specifiers into package attribute names.

    Version constraints are removed first and requirements whose name starts
    with ``apache-airflow`` are ignored. Each remaining name is resolved with
    ``name_to_attr_path``; the ``PKG_SET`` prefix is cut off the result.
    Requirements that cannot be resolved are skipped with a logged warning.

    Returns:
        The resolved attribute names, in the order of ``reqs``.
    """
    bare_names = [remove_version_constraint(req) for req in reqs]
    prefix_length = len(PKG_SET + ".")

    resolved = []
    for name in bare_names:
        if re.match(r"^apache-airflow", name):
            continue
        attr_path = name_to_attr_path(name, packages)
        if attr_path is None:
            # If we can't find it, we just skip and warn the user
            logging.warning("Could not find package attr for %s", name)
            continue
        # Keep only the part after the "<PKG_SET>." prefix
        resolved.append(attr_path[prefix_length:])
    return resolved
|
|
|
|
|
|
def get_cross_provider_reqs(
    provider: str,
    provider_reqs: Dict,
    cross_provider_deps: Dict,
    seen: Optional[List] = None,
) -> Set:
    """Collect the requirements of ``provider`` and of every provider it depends on.

    Cross-provider dependencies are followed transitively. The dependency
    graph may contain cycles, so each provider is expanded at most once.

    Args:
        provider: Name of the provider to start from.
        provider_reqs: Mapping of provider name to its own requirements.
        cross_provider_deps: Mapping of provider name to the providers it
            depends on.
        seen: Providers that must not be followed (optional).

    Returns:
        The union of the requirements of all providers reached.

    Raises:
        KeyError: if a reached provider is missing from either mapping.
    """
    # Iterative traversal with a visited set: every provider is expanded once,
    # so dense or deep dependency graphs neither blow up combinatorially nor
    # exhaust the recursion limit.
    visited = set(seen or ())
    visited.add(provider)
    reqs: Set = set()
    pending = [provider]
    while pending:
        current = pending.pop()
        reqs.update(provider_reqs[current])
        for dep in cross_provider_deps[current]:
            if dep not in visited:
                visited.add(dep)
                pending.append(dep)
    return reqs
|
|
|
|
|
|
def get_provider_reqs(version: str, packages: Dict) -> Dict:
    """Build the full requirement set of every airflow provider.

    ``generated/provider_dependencies.json`` is fetched for ``version``. For
    each provider its ``"deps"`` are resolved to package attribute names and
    extended with any ``EXTRA_REQS`` entry. Its ``"cross-providers-deps"`` are
    recorded with ``common.sql`` left out. The requirements of the providers
    it depends on are then merged in transitively.

    Returns:
        Mapping of provider name to the set of its direct and transitive
        requirements.
    """
    dependency_data = get_file_from_github(
        version, "generated/provider_dependencies.json"
    )

    direct_reqs = {}
    cross_deps = {}
    for name, data in dependency_data.items():
        resolved = list(provider_reqs_to_attr_paths(data["deps"], packages))
        direct_reqs[name] = resolved + EXTRA_REQS.get(name, [])
        cross_deps[name] = [
            dep for dep in data["cross-providers-deps"] if dep != "common.sql"
        ]

    # Add transitive cross-provider reqs
    return {
        name: get_cross_provider_reqs(name, direct_reqs, cross_deps)
        for name in direct_reqs
    }
|
|
|
|
|
|
def get_provider_yaml(version: str, provider: str) -> Dict:
    """Fetch and parse the ``provider.yaml`` of ``provider`` for airflow ``version``.

    Dots in the provider name are turned into directory separators to locate
    the file below ``airflow/providers/``. When the download fails with an
    ``HTTPError`` a warning is logged and an empty dict is returned instead.
    """
    provider_dir = provider.replace(".", "/")
    path = f"airflow/providers/{provider_dir}/provider.yaml"
    try:
        provider_yaml = get_file_from_github(version, path)
    except HTTPError:
        logging.warning("Couldn't get provider yaml for %s", provider)
        provider_yaml = {}
    return provider_yaml
|
|
|
|
|
|
def get_provider_imports(version: str, providers) -> Dict:
    """Collect the Python modules listed by each provider's ``provider.yaml``.

    For every provider the ``"python-modules"`` of all ``"hooks"`` entries are
    gathered first, followed by those of all ``"operators"`` entries. A
    section that is absent from the yaml contributes nothing.

    Returns:
        Mapping of provider name to its list of module names.
    """
    provider_imports = {}
    for provider in providers:
        provider_yaml = get_provider_yaml(version, provider)
        modules: List[str] = []
        for section in ("hooks", "operators"):
            if section not in provider_yaml:
                continue
            for entry in provider_yaml[section]:
                modules.extend(entry["python-modules"])
        provider_imports[provider] = modules
    return provider_imports
|
|
|
|
|
|
def to_nix_expr(provider_reqs: Dict, provider_imports: Dict, fh: TextIO) -> None:
    """Write the provider data to ``fh`` as a Nix attribute set.

    After a "generated" warning header, one attribute is written per provider
    (dots in its name replaced by underscores) holding a ``deps`` list and an
    ``imports`` list. The entries of each list are double-quoted and sorted.

    Args:
        provider_reqs: Mapping of provider name to its requirements.
        provider_imports: Mapping of provider name to its importable modules;
            must contain every key of ``provider_reqs``.
        fh: Writable text stream receiving the output.
    """

    def quoted_sorted(items) -> str:
        # Quote first, then sort the quoted strings, then join with spaces.
        return " ".join(sorted(f'"{item}"' for item in items))

    fh.write("# Warning: generated by update-providers.py, do not update manually\n")
    fh.write("{\n")
    for provider, reqs in provider_reqs.items():
        attr_name = provider.replace(".", "_")
        fh.write(f" {attr_name} = {{\n")
        fh.write(f" deps = [ {quoted_sorted(reqs)} ];\n")
        fh.write(f" imports = [ {quoted_sorted(provider_imports[provider])} ];\n")
        fh.write(" };\n")
    fh.write("}\n")
|
|
|
|
|
|
def main() -> None:
    """Regenerate ``providers.nix`` in the current working directory.

    Reads the airflow version, dumps the available Python packages, resolves
    the requirements and importable modules of every provider, and writes the
    result as a Nix expression.
    """
    logging.basicConfig(level=logging.INFO)

    airflow_version = get_version()
    nix_packages = dump_packages()
    logging.info("Generating providers.nix for version %s", airflow_version)

    reqs_by_provider = get_provider_reqs(airflow_version, nix_packages)
    imports_by_provider = get_provider_imports(
        airflow_version, reqs_by_provider.keys()
    )

    with open("providers.nix", "w") as out:
        to_nix_expr(reqs_by_provider, imports_by_provider, out)
|
|
|
|
|
|
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|