# jaxlib built from source: a Bazel build produces the wheel, which is then
# installed with buildPythonPackage. All arguments are expected to be supplied
# by callPackage; the whole set is also bound to `inputs` so that the original
# `stdenv` stays reachable after it is shadowed in the `let` block below.
#
# NOTE(review): `addOpenGLRunpath` is accepted but not referenced anywhere in
# this file (`autoAddDriverRunpath` is used instead); it is kept so existing
# `.override` calls keep evaluating.
{
  lib,
  pkgs,
  stdenv,

  # Build-time dependencies:
  addOpenGLRunpath,
  autoAddDriverRunpath,
  bazel_6,
  binutils,
  buildBazelPackage,
  buildPythonPackage,
  cctools,
  curl,
  cython,
  fetchFromGitHub,
  git,
  IOKit,
  jsoncpp,
  nsync,
  openssl,
  pybind11,
  setuptools,
  symlinkJoin,
  wheel,
  build,
  which,

  # Python dependencies:
  absl-py,
  flatbuffers,
  ml-dtypes,
  numpy,
  scipy,
  six,

  # Runtime dependencies:
  double-conversion,
  giflib,
  libjpeg_turbo,
  python,
  snappy,
  zlib,

  config,
  # CUDA flags:
  cudaSupport ? config.cudaSupport,
  cudaPackages,

  # MKL:
  mklSupport ? true,
}@inputs:

let
# CUDA-related values used below, pulled out of the selected cudaPackages set.
inherit (cudaPackages)
  cudaFlags
  cudaVersion
  cudnn
  nccl
  ;

pname = "jaxlib";
version = "0.4.28";

# It's necessary to consistently use backendStdenv when building with CUDA
# support, otherwise we get libstdc++ errors downstream
# Shadowing `stdenv` with a throw makes any accidental use fail loudly at
# evaluation time; the real stdenv remains available as `inputs.stdenv`.
stdenv = throw "Use effectiveStdenv instead";
effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;

# Package metadata, shared by the Bazel build derivation and the final Python package.
meta = with lib; {
  description = "JAX is Autograd and XLA, brought together for high-performance machine learning research";
  homepage = "https://github.com/google/jax";
  license = licenses.asl20;
  maintainers = with maintainers; [ ndl ];
  platforms = platforms.unix;
  # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
  # however even with that fix applied, it doesn't work for everyone:
  # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
  # NOTE: We always build with NCCL; if it is unsupported, then our build is broken.
  broken = effectiveStdenv.isDarwin || nccl.meta.unsupported;
};

# These are necessary at build time and run time.
# A single merged tree of the CUDA shared/static libraries; it is part of the
# build-time CUDA tree and is also added to the RPATH of the installed
# shared objects in postInstall.
cuda_libs_joined = symlinkJoin {
  name = "cuda-joined";
  paths = with cudaPackages; [
    cuda_cudart.lib # libcudart.so
    cuda_cudart.static # libcudart_static.a
    cuda_cupti.lib # libcupti.so
    libcublas.lib # libcublas.so
    libcufft.lib # libcufft.so
    libcurand.lib # libcurand.so
    libcusolver.lib # libcusolver.so
    libcusparse.lib # libcusparse.so
  ];
};
# These are only necessary at build time.
# Merged tree (libraries + nvcc + headers) that is handed to Bazel as
# CUDA_TOOLKIT_PATH and as the first entry of TF_CUDA_PATHS.
cuda_build_deps_joined = symlinkJoin {
  name = "cuda-build-deps-joined";
  paths = with cudaPackages; [
    cuda_libs_joined

    # Binaries
    cudaPackages.cuda_nvcc.bin # nvcc

    # Headers
    cuda_cccl.dev # block_load.cuh
    cuda_cudart.dev # cuda.h
    cuda_cupti.dev # cupti.h
    cuda_nvcc.dev # See https://github.com/google/jax/issues/19811
    cuda_nvml_dev # nvml.h
    cuda_nvtx.dev # nvToolsExt.h
    libcublas.dev # cublas_api.h
    libcufft.dev # cufft.h
    libcurand.dev # curand.h
    libcusolver.dev # cusolver_common.h
    libcusparse.dev # cusparse.h
  ];
};

# Host compiler and binutils merged into one prefix; its bin/ directory is
# exported as GCC_HOST_COMPILER_PREFIX / GCC_HOST_COMPILER_PATH for CUDA builds.
backend_cc_joined = symlinkJoin {
  name = "cuda-cc-joined";
  paths = [
    effectiveStdenv.cc
    binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
  ];
};

# Copy-paste from TF derivation.
# Most of these are not really used in jaxlib compilation but it's simpler to keep it
# 'as is' so that it's more compatible with TF derivation.
# The list is joined with "," and exported as TF_SYSTEM_LIBS for both the
# fetch and the build phases of the Bazel package.
tf_system_libs = [
  "absl_py"
  "astor_archive"
  "astunparse_archive"
  # Not packaged in nixpkgs
  # "com_github_googleapis_googleapis"
  # "com_github_googlecloudplatform_google_cloud_cpp"
  # Issue with transitive dependencies after https://github.com/grpc/grpc/commit/f1d14f7f0b661bd200b7f269ef55dec870e7c108
  # "com_github_grpc_grpc"
  # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
  # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
  # "com_google_protobuf"
  # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
  # "com_googlesource_code_re2"
  "curl"
  "cython"
  "dill_archive"
  "double_conversion"
  "flatbuffers"
  "functools32_archive"
  "gast_archive"
  "gif"
  "hwloc"
  "icu"
  "jsoncpp_git"
  "libjpeg_turbo"
  "lmdb"
  "nasm"
  "opt_einsum_archive"
  "org_sqlite"
  "pasta"
  "png"
  # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
  # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
  # "pybind11"
  "six_archive"
  "snappy"
  "tblib_archive"
  "termcolor_archive"
  "typing_extensions_archive"
  "wrapt"
  "zlib"
];

# CPU name passed to the wheel builder (`--cpu`) and embedded in the wheel's
# platform tag. On Linux, nixpkgs' "arm64" is rewritten to "aarch64" because
# the upstream build script rejects the former:
arch =
  # KeyError: ('Linux', 'arm64')
  if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then
    "aarch64"
  else
    effectiveStdenv.hostPlatform.linuxArch;

# Pinned XLA source tree with shebangs patched; nothing is compiled here.
# The result is injected into the Bazel build via --override_repository.
xla = effectiveStdenv.mkDerivation {
  pname = "xla-src";
  version = "unstable";

  src = fetchFromGitHub {
    owner = "openxla";
    repo = "xla";
    # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
    rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
    hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
  };

  dontBuild = true;

  # This is necessary for patchShebangs to know the right path to use.
  nativeBuildInputs = [ python ];

  # Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl
  postPatch = ''
    patchShebangs .
  '';

  installPhase = ''
    cp -r . $out
  '';
};

# Bazel build that produces the jaxlib wheel in $out. It is split by
# buildBazelPackage into a fixed-output fetch step (fetchAttrs) and the actual
# build (buildAttrs). `rec` is required: `name`-independent attributes such as
# `bazelRunTarget` and `bazelFlags` are referenced again inside fetchAttrs.
bazel-build = buildBazelPackage rec {
  name = "bazel-build-${pname}-${version}";

  # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
  bazel = bazel_6;

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
    rev = "refs/tags/${pname}-v${version}";
    hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
  };

  nativeBuildInputs = [
    cython
    pkgs.flatbuffers
    git
    setuptools
    wheel
    build
    which
  ] ++ lib.optionals effectiveStdenv.isDarwin [ cctools ];

  buildInputs =
    [
      curl
      double-conversion
      giflib
      jsoncpp
      libjpeg_turbo
      numpy
      openssl
      pkgs.flatbuffers
      pkgs.protobuf
      pybind11
      scipy
      six
      snappy
      zlib
    ]
    ++ lib.optionals effectiveStdenv.isDarwin [ IOKit ]
    ++ lib.optionals (!effectiveStdenv.isDarwin) [ nsync ];

  # We don't want to be quite so picky regarding bazel version
  postPatch = ''
    rm -f .bazelversion
  '';

  bazelRunTarget = "//jaxlib/tools:build_wheel";
  runTargetFlags = [
    "--output_path=$out"
    "--cpu=${arch}"
    # This has no impact whatsoever...
    "--jaxlib_git_hash='12345678'"
  ];

  removeRulesCC = false;

  # Empty strings when CUDA is disabled.
  GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${backend_cc_joined}/bin";
  GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${backend_cc_joined}/bin/gcc";

  # The version is automatically set to ".dev" if this variable is not set.
  # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
  JAXLIB_RELEASE = "1";

  # The script below is assembled from several indented strings. Together they
  # form ONE shell here-document (`cat <<CFG … CFG`); each fragment has its
  # indentation stripped independently, so the closing `CFG` lands at column 0.
  preConfigure =
    # Dummy ldconfig to work around "Can't open cache file /nix/store/<hash>-glibc-2.38-44/etc/ld.so.cache" error
    ''
      mkdir dummy-ldconfig
      echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
      chmod +x dummy-ldconfig/ldconfig
      export PATH="$PWD/dummy-ldconfig:$PATH"
    ''
    +
      # Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
      # for more info. We assume
      # * `cpu = None`
      # * `enable_nccl = True`
      # * `target_cpu_features = "release"`
      # * `rocm_amdgpu_targets = None`
      # * `enable_rocm = False`
      # * `build_gpu_plugin = False`
      # * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
      #
      # Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
      # instead of duplicating the logic here. Perhaps we can leverage the
      # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
      ''
        cat <<CFG > ./.jax_configure.bazelrc
        build --strategy=Genrule=standalone
        build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
        build --action_env=PYENV_ROOT
        build --python_path="${python}/bin/python"
        build --distinct_host_configuration=false
        build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
      ''
    + lib.optionalString cudaSupport ''
      build --config=cuda
      build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
      build --action_env CUDNN_INSTALL_PATH="${cudnn}"
      build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnn},${nccl}"
      build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}"
      build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
      build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
    ''
    +
      # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
      # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
      # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
      # for upstream's version.
      lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix)
        ''
          build --config=avx_posix
        ''
    + lib.optionalString mklSupport ''
      build --config=mkl_open_source_only
    ''
    + ''
      CFG
    '';

  # Make sure Bazel knows about our configuration flags during fetching so that the
  # relevant dependencies can be downloaded.
  bazelFlags =
    [
      "-c opt"
      # See https://bazel.build/external/advanced#overriding-repositories for
      # information on --override_repository flag.
      "--override_repository=xla=${xla}"
    ]
    ++ lib.optionals effectiveStdenv.cc.isClang [
      # bazel depends on the compiler frontend automatically selecting these flags based on file
      # extension but our clang doesn't.
      # https://github.com/NixOS/nixpkgs/issues/150655
      "--cxxopt=-x"
      "--cxxopt=c++"
      "--host_cxxopt=-x"
      "--host_cxxopt=c++"
    ];

  # We intentionally overfetch so we can share the fetch derivation across all the different configurations
  fetchAttrs = {
    TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
    # we have to force @mkl_dnn_v1 since it's not needed on darwin
    bazelTargets = [
      bazelRunTarget
      "@mkl_dnn_v1//:mkl_dnn"
    ];
    # `bazelFlags` on the right-hand side is the outer (rec) attribute.
    bazelFlags =
      bazelFlags
      ++ [
        "--config=avx_posix"
        "--config=mkl_open_source_only"
      ]
      ++ lib.optionals cudaSupport [
        # ideally we'd add this unconditionally too, but it doesn't work on darwin
        # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
        # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
        # have access to darwin machines
        "--config=cuda"
      ];
    # Fixed-output hash of the fetched dependencies, keyed by system; systems
    # without an entry fail evaluation with an explicit message.
    sha256 =
      (
        if cudaSupport then
          { x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k="; }
        else
          {
            x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ=";
            aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA=";
          }
      ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
  };

  buildAttrs = {
    outputs = [ "out" ];

    TF_SYSTEM_LIBS = lib.concatStringsSep "," (
      tf_system_libs
      ++ lib.optionals (!effectiveStdenv.isDarwin) [
        "nsync" # fails to build on darwin
      ]
    );

    # Note: we cannot do most of this patching at `patch` phase as the deps
    # are not available yet. Framework search paths aren't added by bintools
    # hook. See https://github.com/NixOS/nixpkgs/pull/41914.
    preBuild = lib.optionalString effectiveStdenv.isDarwin ''
      export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
      substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
        --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
      substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
        --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
    '';
  };

  inherit meta;
};

# Platform component of the wheel file name produced by the Bazel build; it is
# used to locate the `.whl` that buildPythonPackage installs.
platformTag =
  if effectiveStdenv.hostPlatform.isLinux then
    "manylinux2014_${arch}"
  else if effectiveStdenv.system == "x86_64-darwin" then
    "macosx_10_9_${arch}"
  else if effectiveStdenv.system == "aarch64-darwin" then
    "macosx_11_0_${arch}"
  else
    # `hostPlatform` is an attribute set and cannot be interpolated directly
    # (that would fail with a coercion error and hide this message), so report
    # its `system` string instead.
    throw "Unsupported target platform: ${effectiveStdenv.hostPlatform.system}";

in
# Final Python package: installs the wheel produced by `bazel-build` and, for
# CUDA builds, wires up ptxas and the CUDA library search paths.
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  src =
    let
      # CPython tag without the dot in the version, e.g. "3.11" -> "cp311".
      cp = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}";
    in
    "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";

  # Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # for more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      patchelf --add-rpath "${
        lib.makeLibraryPath [
          cuda_libs_joined
          cudnn
          nccl
        ]
      }" "$lib"
    done
  '';

  nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];

  dependencies = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;
}