# jaxlib: builds the `//build:build_wheel` Bazel target from the google/jax
# sources (see `bazel-build` below), then installs the produced wheel with
# buildPythonPackage.
{ lib
, pkgs
, stdenv
  # Build-time dependencies:
, addOpenGLRunpath
, bazel_5
, binutils
, buildBazelPackage
, buildPythonPackage
, cython
, fetchFromGitHub
, git
, jsoncpp
, pybind11
, setuptools
, symlinkJoin
, wheel
, which
  # Build-time and runtime CUDA dependencies (must be non-null when
  # cudaSupport = true; they are only referenced on the CUDA code paths):
, cudatoolkit ? null
, cudnn ? null
, nccl ? null
  # Python dependencies:
, absl-py
, flatbuffers
, numpy
, scipy
, six
  # Runtime dependencies:
, double-conversion
, giflib
, grpc
, libjpeg_turbo
, python
, snappy
, zlib
  # CUDA flags:
  # GPU targets, joined with "," into TF_CUDA_COMPUTE_CAPABILITIES when
  # cudaSupport is enabled.
, cudaCapabilities ? [ "sm_35" "sm_50" "sm_60" "sm_70" "sm_75" "compute_80" ]
, cudaSupport ? false
  # MKL:
  # When true, adds `--config=mkl_open_source_only` to the Bazel build flags.
, mklSupport ? true
}:

let
pname = "jaxlib";
version = "0.3.0";
# Package metadata, inherited by both the intermediate Bazel derivation and
# the final Python package.
meta = with lib; {
  description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
  homepage = "https://github.com/google/jax";
  license = licenses.asl20;
  maintainers = with maintainers; [ ndl ];
};
# One tree combining the CUDA toolkit's `lib` and `out` outputs. It is handed
# to Bazel as CUDA_TOOLKIT_PATH / TF_CUDA_PATHS. It is only referenced under
# `cudaSupport`, so (Nix being lazy) it is never forced when cudatoolkit is null.
cudatoolkit_joined = symlinkJoin {
  name = "${cudatoolkit.name}-merged";
  paths = [
    cudatoolkit.lib
    cudatoolkit.out
  ] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
    # for some reason some of the required libs are in the targets/x86_64-linux
    # directory; not sure why but this works around it
    "${cudatoolkit}/targets/${stdenv.system}"
  ];
};
# Host toolchain for the CUDA build: `cudatoolkit.cc` plus binutils, exposed
# to Bazel through GCC_HOST_COMPILER_PREFIX / GCC_HOST_COMPILER_PATH.
cudatoolkit_cc_joined = symlinkJoin {
  name = "${cudatoolkit.cc.name}-merged";
  paths = [
    cudatoolkit.cc
    binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
  ];
};
# Intermediate derivation: runs Bazel on the jax sources and writes the built
# jaxlib wheel into $out (consumed by the final package's `src`).
bazel-build = buildBazelPackage {
  name = "bazel-build-${pname}-${version}";

  bazel = bazel_5;

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    rev = "${pname}-v${version}";
    sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72";
  };

  nativeBuildInputs = [
    cython
    pkgs.flatbuffers
    git
    setuptools
    wheel
    which
  ];

  buildInputs = [
    double-conversion
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    numpy
    pkgs.flatbuffers
    pkgs.protobuf
    pybind11
    scipy
    six
    snappy
    zlib
  ] ++ lib.optionals cudaSupport [
    cudatoolkit
    cudnn
  ];

  # Remove upstream's pinned Bazel version file; presumably so the Bazel given
  # in `bazel` above is accepted as-is.
  postPatch = ''
    rm -f .bazelversion
  '';

  bazelTarget = "//build:build_wheel";

  removeRulesCC = false;

  GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
  GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";

  # Put a no-op `ldconfig` first on PATH, then generate .jax_configure.bazelrc
  # with the Python settings and, when cudaSupport is set, the CUDA settings.
  # The three indented strings are concatenated into a single heredoc, so the
  # closing `CFG` must end up at column 0 after indentation stripping.
  preConfigure = ''
    # dummy ldconfig
    mkdir dummy-ldconfig
    echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
    chmod +x dummy-ldconfig/ldconfig
    export PATH="$PWD/dummy-ldconfig:$PATH"
    cat <<CFG > ./.jax_configure.bazelrc
    build --strategy=Genrule=standalone
    build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
    build --action_env=PYENV_ROOT
    build --python_path="${python}/bin/python"
    build --distinct_host_configuration=false
  '' + lib.optionalString cudaSupport ''
    build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
    build --action_env CUDNN_INSTALL_PATH="${cudnn}"
    build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
    build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
    build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
    build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${lib.concatStringsSep "," cudaCapabilities}"
  '' + ''
    CFG
  '';

  # Copy-paste from TF derivation.
  # Most of these are not really used in jaxlib compilation but it's simpler to keep it
  # 'as is' so that it's more compatible with TF derivation.
  TF_SYSTEM_LIBS = lib.concatStringsSep "," [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    "boringssl"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    "com_github_grpc_grpc"
    "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "enum34_archive"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    # "nsync" # not packaged in nixpkgs
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "pcre"
    "png"
    "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];

  # Make sure Bazel knows about our configuration flags during fetching so that the
  # relevant dependencies can be downloaded.
  bazelFetchFlags = bazel-build.bazelBuildFlags;

  # Each conditional group uses `lib.optionals` (not `lib.optional`) so its
  # flags are spliced in as plain strings rather than as nested lists.
  bazelBuildFlags = [
    "-c opt"
  ] ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
    "--config=avx_posix"
  ] ++ lib.optionals cudaSupport [
    "--config=cuda"
  ] ++ lib.optionals mklSupport [
    "--config=mkl_open_source_only"
  ];

  # Hash of the fetched Bazel dependencies. There are separate values for the
  # CUDA and non-CUDA builds, presumably because `--config=cuda` in the fetch
  # flags changes what gets downloaded.
  fetchAttrs = {
    sha256 =
      if cudaSupport then
        "1k0rjxqjm703gd9navwzx5x3874b4dxamr62m1fxhm79d271zxis"
      else
        "0ivah1w41jcj13jm740qzwx5h0ia8vbj71pjgd0zrfk3c92kll41";
  };

  buildAttrs = {
    outputs = [ "out" ];

    # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
    # 1) Fix pybind11 include paths.
    # 2) Force static protobuf linkage to prevent crashes on loading multiple extensions
    #    in the same python program due to duplicate protobuf DBs.
    # 3) Patch python path in the compiler driver.
    # 4) Patch tensorflow sources to work with later versions of protobuf. See
    #    https://github.com/google/jax/issues/9534. Note that this should be
    #    removed on the next release after 0.3.0.
    preBuild = ''
      for src in ./jaxlib/*.{cc,h}; do
        sed -i 's@include/pybind11@pybind11@g' $src
      done
      sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
        --replace "status.message()" "std::string{status.message()}"
    '' + lib.optionalString cudaSupport ''
      patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
    '';

    # The built build_wheel tool writes the jaxlib wheel directly into $out.
    installPhase = ''
      ./bazel-bin/build/build_wheel --output_path=$out --cpu=${stdenv.targetPlatform.linuxArch}
    '';
  };

  inherit meta;
};
in
# Final package: installs the wheel produced by `bazel-build`.
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  # Wheel file name: the python tag is the interpreter version with dots
  # removed (e.g. "3.9" -> "cp39"); the platform tag uses the target arch.
  src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl";

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
  # more info.
  # The loop below prepends the CUDA, cuDNN and NCCL lib dirs to the RPATH of
  # every installed shared object.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  nativeBuildInputs = lib.optional cudaSupport addOpenGLRunpath;

  propagatedBuildInputs = [
    absl-py
    double-conversion
    flatbuffers
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [ "jaxlib" ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;
}