2021-12-30 13:39:12 +00:00
{
  lib,
  pkgs,
  stdenv,

  # Build-time dependencies:
  addOpenGLRunpath,
  bazel_6,
  binutils,
  buildBazelPackage,
  buildPythonPackage,
  cctools,
  curl,
  cython,
  fetchFromGitHub,
  git,
  IOKit,
  jsoncpp,
  nsync,
  openssl,
  pybind11,
  setuptools,
  symlinkJoin,
  wheel,
  build,
  which,

  # Python dependencies:
  absl-py,
  flatbuffers,
  ml-dtypes,
  numpy,
  scipy,
  six,

  # Runtime dependencies:
  double-conversion,
  giflib,
  libjpeg_turbo,
  python,
  snappy,
  zlib,

  # Source of the default for `cudaSupport` below.
  config,

  # CUDA flags:
  cudaSupport ? config.cudaSupport,
  cudaPackagesGoogle,

  # MKL:
  mklSupport ? true,
}:
2021-09-18 10:52:07 +00:00
let
2024-01-02 11:29:13 +00:00
# Every CUDA component used below is taken from the `cudaPackagesGoogle` set.
inherit (cudaPackagesGoogle)
  backendStdenv
  cudaFlags
  cudatoolkit
  cudnn
  nccl
  ;

# Shared by the Bazel build stage and the final Python package.
pname = "jaxlib";
version = "0.4.23";
2021-12-30 13:39:12 +00:00
# Package metadata, inherited by both the Bazel stage and the final wheel package.
meta = {
  description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
  homepage = "https://github.com/google/jax";
  license = lib.licenses.asl20;
  maintainers = [ lib.maintainers.ndl ];
  platforms = lib.platforms.unix;
  # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
  # however even with that fix applied, it doesn't work for everyone:
  # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
  # The package is currently marked broken on every Darwin system.
  broken = stdenv.isDarwin;
};
# Single tree merging the `lib` and `out` outputs of cudatoolkit; used below as
# CUDA_TOOLKIT_PATH and as part of TF_CUDA_PATHS.
cudatoolkit_joined = symlinkJoin {
  name = "${cudatoolkit.name}-merged";
  paths =
    [
      cudatoolkit.lib
      cudatoolkit.out
    ]
    # For cudatoolkit older than 11, some of the required libs live in the
    # targets/<system> directory; adding it to the join works around that.
    ++ lib.optional (lib.versionOlder cudatoolkit.version "11") "${cudatoolkit}/targets/${stdenv.system}";
};
# Host compiler and binutils merged into one prefix; its bin/ directory is what
# GCC_HOST_COMPILER_PREFIX / GCC_HOST_COMPILER_PATH point at in the Bazel build.
cudatoolkit_cc_joined = symlinkJoin {
  name = "${cudatoolkit.cc.name}-merged";
  paths = [
    backendStdenv.cc
    # binutils supplies ar, dwp, nm, objcopy, objdump and strip.
    binutils.bintools
  ];
};
2021-09-18 10:52:07 +00:00
2023-04-12 12:48:02 +00:00
# Names passed to Bazel via TF_SYSTEM_LIBS (joined with "," further down).
# The list is copied from the TensorFlow derivation. Most entries are not really
# used when compiling jaxlib, but keeping the list 'as is' keeps it compatible
# with the TF derivation.
tf_system_libs = [
  "absl_py"
  "astor_archive"
  "astunparse_archive"

  # Disabled: not packaged in nixpkgs.
  # "com_github_googleapis_googleapis"
  # "com_github_googlecloudplatform_google_cloud_cpp"

  # Disabled: issue with transitive dependencies after
  # https://github.com/grpc/grpc/commit/f1d14f7f0b661bd200b7f269ef55dec870e7c108
  # "com_github_grpc_grpc"

  # Disabled because of:
  # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
  # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
  # "com_google_protobuf"

  # Disabled because it fails with:
  # external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
  # "com_googlesource_code_re2"

  "curl"
  "cython"
  "dill_archive"
  "double_conversion"
  "flatbuffers"
  "functools32_archive"
  "gast_archive"
  "gif"
  "hwloc"
  "icu"
  "jsoncpp_git"
  "libjpeg_turbo"
  "lmdb"
  "nasm"
  "opt_einsum_archive"
  "org_sqlite"
  "pasta"
  "png"

  # Disabled because of:
  # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
  # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
  # "pybind11"

  "six_archive"
  "snappy"
  "tblib_archive"
  "termcolor_archive"
  "typing_extensions_archive"
  "wrapt"
  "zlib"
];
2023-05-24 13:37:59 +00:00
# CPU name used for the `--cpu` build flag and for the wheel platform tag.
# Linux "arm64" is renamed to "aarch64"; passing it through unchanged produced
# `KeyError: ('Linux', 'arm64')`. Every other value is used as reported.
arch =
  let
    inherit (stdenv.hostPlatform) isLinux linuxArch;
  in
  if isLinux && linuxArch == "arm64" then "aarch64" else linuxArch;
2023-05-24 13:37:59 +00:00
2023-04-12 12:48:02 +00:00
# Bazel stage: runs the `build_wheel` target, which writes the jaxlib wheel to $out.
# `rec` is required because `fetchAttrs` refers to `bazelRunTarget` and `bazelFlags`.
bazel-build = buildBazelPackage rec {
  name = "bazel-build-${pname}-${version}";

  # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
  bazel = bazel_6;

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
    rev = "refs/tags/${pname}-v${version}";
    hash = "sha256-PDa3yVH/sszGbWkVkJ+19FdOr3oqdYk+OdbeUTMTDuU=";
  };

  nativeBuildInputs = [
    cython
    pkgs.flatbuffers
    git
    setuptools
    wheel
    build
    which
  ] ++ lib.optional stdenv.isDarwin cctools;

  buildInputs =
    [
      curl
      double-conversion
      giflib
      jsoncpp
      libjpeg_turbo
      numpy
      openssl
      pkgs.flatbuffers
      pkgs.protobuf
      pybind11
      scipy
      six
      snappy
      zlib
    ]
    ++ lib.optionals cudaSupport [
      cudatoolkit
      cudnn
    ]
    # IOKit on Darwin, nsync everywhere else.
    ++ (if stdenv.isDarwin then [ IOKit ] else [ nsync ]);

  # Drop the upstream Bazel version pin; the Bazel used is the `bazel` attribute above.
  postPatch = ''
    rm -f .bazelversion
  '';

  bazelRunTarget = "//jaxlib/tools:build_wheel";
  runTargetFlags = [
    "--output_path=$out"
    "--cpu=${arch}"
  ];

  removeRulesCC = false;

  # Both are empty strings when CUDA support is disabled.
  GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
  GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";

  # The version is automatically set to ".dev" if this variable is not set.
  # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
  JAXLIB_RELEASE = "1";

  # Puts a do-nothing `ldconfig` first in PATH, then writes .jax_configure.bazelrc
  # through a shell heredoc. The pieces below are concatenated in order; the last
  # piece closes the heredoc.
  preConfigure = lib.concatStrings [
    ''
      # dummy ldconfig
      mkdir dummy-ldconfig
      echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
      chmod +x dummy-ldconfig/ldconfig
      export PATH="$PWD/dummy-ldconfig:$PATH"
      cat <<CFG > ./.jax_configure.bazelrc
      build --strategy=Genrule=standalone
      build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
      build --action_env=PYENV_ROOT
      build --python_path="${python}/bin/python"
      build --distinct_host_configuration=false
      build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
    ''
    (lib.optionalString (stdenv.hostPlatform.avxSupport && stdenv.hostPlatform.isUnix) ''
      build --config=avx_posix
    '')
    (lib.optionalString mklSupport ''
      build --config=mkl_open_source_only
    '')
    (lib.optionalString cudaSupport ''
      build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
      build --action_env CUDNN_INSTALL_PATH="${cudnn}"
      build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
      build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
      build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
      build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
    '')
    ''
      CFG
    ''
  ];

  # Make sure Bazel knows about our configuration flags during fetching so that the
  # relevant dependencies can be downloaded.
  bazelFlags =
    [ "-c opt" ]
    # bazel depends on the compiler frontend automatically selecting these flags based on file
    # extension but our clang doesn't.
    # https://github.com/NixOS/nixpkgs/issues/150655
    ++ lib.optionals stdenv.cc.isClang [
      "--cxxopt=-x"
      "--cxxopt=c++"
      "--host_cxxopt=-x"
      "--host_cxxopt=c++"
    ];

  # We intentionally overfetch so we can share the fetch derivation across all the different configurations
  fetchAttrs = {
    TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;

    # we have to force @mkl_dnn_v1 since it's not needed on darwin
    bazelTargets = [
      bazelRunTarget
      "@mkl_dnn_v1//:mkl_dnn"
    ];

    # `bazelFlags` on the right-hand side is the outer attribute defined above.
    bazelFlags =
      bazelFlags
      ++ [ "--config=avx_posix" ]
      # ideally we'd add this unconditionally too, but it doesn't work on darwin
      # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
      # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
      # have access to darwin machines
      ++ lib.optional cudaSupport "--config=cuda"
      ++ [ "--config=mkl_open_source_only" ];

    # Hash of the fetched dependencies, per system; throws for systems without an entry.
    sha256 =
      let
        depsHashes =
          if cudaSupport then
            {
              x86_64-linux = "sha256-q2wRaoCGnISEdtF6jDMk9Wccy/wTmLusVBI7dDATwi4=";
            }
          else
            {
              x86_64-linux = "sha256-0cDJ27HCi3J5xeT6TkTtfUzF/yESBYmEVG1r14kPdRs=";
              aarch64-linux = "sha256-WbaN8VYjeW0mDthmtoSTttqd4K/Z8dP5+VkTo10pLtU=";
            };
      in
      depsHashes.${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}");
  };

  buildAttrs = {
    outputs = [ "out" ];

    TF_SYSTEM_LIBS = lib.concatStringsSep "," (
      tf_system_libs
      # nsync fails to build on darwin
      ++ lib.optional (!stdenv.isDarwin) "nsync"
    );

    # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
    # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
    #    loading multiple extensions in the same python program due to duplicate protobuf DBs.
    # 2) Patch python path in the compiler driver.
    preBuild = lib.concatStrings [
      (lib.optionalString cudaSupport ''
        export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
        patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
      '')
      (lib.optionalString stdenv.isDarwin ''
        # Framework search paths aren't added by bintools hook
        # https://github.com/NixOS/nixpkgs/pull/41914
        export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '')
    ];
  };

  inherit meta;
};
2022-08-21 13:32:41 +00:00
# Platform tag that forms the last component of the wheel file name looked up in
# the Bazel output (`jaxlib-<version>-<cp>-<cp>-<platformTag>.whl`).
#   Linux          -> manylinux2014_<arch>
#   x86_64-darwin  -> macosx_10_9_<arch>
#   aarch64-darwin -> macosx_11_0_<arch>
# Any other system aborts evaluation with an explanatory message.
platformTag =
  if stdenv.hostPlatform.isLinux then
    "manylinux2014_${arch}"
  else if stdenv.system == "x86_64-darwin" then
    "macosx_10_9_${arch}"
  else if stdenv.system == "aarch64-darwin" then
    "macosx_11_0_${arch}"
  else
    # `stdenv.hostPlatform` is an attribute set and cannot be interpolated into a
    # string (evaluation would fail with "cannot coerce a set to a string" and the
    # message below would never be shown), so interpolate its `system` string.
    throw "Unsupported target platform: ${stdenv.hostPlatform.system}";
2021-12-30 13:39:12 +00:00
in
# Final package: installs the wheel produced by the `bazel-build` stage.
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  # Path of the wheel inside the Bazel output. The interpreter tag is "cp" followed
  # by the Python version with its dots removed, and appears twice in the file name.
  src =
    let
      pyTag = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}";
      wheelFile = "jaxlib-${version}-${pyTag}-${pyTag}-${platformTag}.whl";
    in
    "${bazel-build}/${wheelFile}";

  nativeBuildInputs = lib.optionals cudaSupport [ addOpenGLRunpath ];

  propagatedBuildInputs = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
    snappy
  ];

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
  # more info.
  # CUDA builds only: link ptxas into $out/bin, then for every shared object under
  # $out run addOpenGLRunpath and prepend the CUDA, cuDNN and NCCL lib directories
  # to its existing RPATH.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;

  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` alone loads surprisingly little. The two submodules below are
    # checked explicitly because importing them broke in the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];
}