{
  lib,
  # Used to reach `pkgs.flatbuffers` in the Bazel build inputs; the `flatbuffers`
  # argument further down is the one listed with the Python dependencies.
  pkgs,
  stdenv,

  # Build-time dependencies
  addOpenGLRunpath,
  bazel_5,
  binutils,
  buildBazelPackage,
  buildPythonPackage,
  cctools,
  curl,
  cython,
  fetchFromGitHub,
  git,
  IOKit,
  jsoncpp,
  nsync,
  openssl,
  pybind11,
  setuptools,
  symlinkJoin,
  wheel,
  which,

  # Python dependencies
  absl-py,
  flatbuffers,
  numpy,
  scipy,
  six,

  # Runtime dependencies
  double-conversion,
  giflib,
  grpc,
  libjpeg_turbo,
  protobuf,
  python,
  snappy,
  zlib,

  # CUDA: off by default; `cudaPackages` provides cudatoolkit, cudaFlags, cudnn and nccl.
  cudaSupport ? false,
  cudaPackages ? { },

  # MKL: on by default (adds Bazel's `--config=mkl_open_source_only`).
  mklSupport ? true,
}:
let
# CUDA components used by this expression. Let-bindings are lazy and every use of these
# is guarded by `cudaSupport`, so the empty default `cudaPackages` is not forced otherwise.
cudatoolkit = cudaPackages.cudatoolkit;
cudaFlags = cudaPackages.cudaFlags;
cudnn = cudaPackages.cudnn;
nccl = cudaPackages.nccl;

# The upstream git tag for a release is "<pname>-v<version>".
pname = "jaxlib";
version = "0.3.22";
# Package metadata, attached both to the Bazel build derivation and to the Python package.
meta = {
  description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
  homepage = "https://github.com/google/jax";
  license = lib.licenses.asl20;
  maintainers = [ lib.maintainers.ndl ];
  platforms = lib.platforms.unix;
  # Every aarch64 platform is marked broken. On aarch64-darwin the cause is
  # https://github.com/bazelbuild/rules_cc/pull/136, and even with that fix applied
  # the build does not work for everyone:
  # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
  # NOTE(review): the reason aarch64-linux is also marked broken is not recorded here — confirm.
  broken = stdenv.isAarch64;
};
# One store path merging the CUDA toolkit's `lib` and `out` outputs; Bazel receives it as
# CUDA_TOOLKIT_PATH and as part of TF_CUDA_PATHS.
cudatoolkit_joined = symlinkJoin {
  name = "${cudatoolkit.name}-merged";
  paths =
    let
      toolkitOutputs = [ cudatoolkit.lib cudatoolkit.out ];
      # Toolkits older than 11 keep some required libraries under targets/<system>
      # (e.g. targets/x86_64-linux); the reason is unknown, but including it works around that.
      legacyTargetDir = lib.optional (lib.versionOlder cudatoolkit.version "11")
        "${cudatoolkit}/targets/${stdenv.system}";
    in
    toolkitOutputs ++ legacyTargetDir;
};
# Host compiler for CUDA builds: the toolkit's C compiler merged with binutils, so that
# ar, dwp, nm, objcopy, objdump and strip share one bin/ directory with gcc
# (exported below as GCC_HOST_COMPILER_PREFIX / GCC_HOST_COMPILER_PATH).
cudatoolkit_cc_joined = symlinkJoin {
  name = cudatoolkit.cc.name + "-merged";
  paths = [ cudatoolkit.cc binutils.bintools ];
};
# CPU name handed to upstream's build/build_wheel.py via `--cpu`; build_wheel.py also uses it
# as the architecture part of the wheel's platform tag. `linuxArch` is "arm64" for every
# aarch64 platform, but upstream only accepts "arm64" on Darwin and expects "aarch64" on
# Linux (otherwise: `KeyError: ('Linux', 'arm64')`), so translate it for aarch64-linux.
# x86_64 and Darwin values are unchanged.
wheelArch =
  if stdenv.targetPlatform.isLinux && stdenv.targetPlatform.isAarch64
  then "aarch64"
  else stdenv.targetPlatform.linuxArch;

# Bazel build of upstream's //build:build_wheel target. Its install phase runs the built
# tool, which writes the jaxlib wheel into $out.
bazel-build = buildBazelPackage {
  name = "bazel-build-${pname}-${version}";

  bazel = bazel_5;

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    rev = "${pname}-v${version}";
    hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs=";
  };

  nativeBuildInputs = [
    cython
    pkgs.flatbuffers
    git
    setuptools
    wheel
    which
  ] ++ lib.optionals stdenv.isDarwin [
    cctools
  ];

  buildInputs = [
    curl
    double-conversion
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    numpy
    openssl
    pkgs.flatbuffers
    protobuf
    pybind11
    scipy
    six
    snappy
    zlib
  ] ++ lib.optionals cudaSupport [
    cudatoolkit
    cudnn
  ] ++ lib.optionals stdenv.isDarwin [
    IOKit
  ] ++ lib.optionals (!stdenv.isDarwin) [
    nsync
  ];

  # Remove upstream's pinned Bazel version file; this build uses `bazel_5` from nixpkgs.
  postPatch = ''
    rm -f .bazelversion
  '';

  bazelTarget = "//build:build_wheel";

  removeRulesCC = false;

  # Host compiler used for CUDA builds (empty strings when CUDA is disabled).
  GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
  GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";

  # Put a no-op `ldconfig` first on PATH, then write .jax_configure.bazelrc pointing Bazel at
  # the nixpkgs Python and protobuf (and, with CUDA, at the toolkit, cuDNN and NCCL).
  preConfigure = ''
    # dummy ldconfig
    mkdir dummy-ldconfig
    echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
    chmod +x dummy-ldconfig/ldconfig
    export PATH="$PWD/dummy-ldconfig:$PATH"
    cat <<CFG > ./.jax_configure.bazelrc
    build --strategy=Genrule=standalone
    build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
    build --action_env=PYENV_ROOT
    build --python_path="${python}/bin/python"
    build --distinct_host_configuration=false
    build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
  '' + lib.optionalString cudaSupport ''
    build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
    build --action_env CUDNN_INSTALL_PATH="${cudnn}"
    build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
    build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
    build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
    build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${cudaFlags.cudaRealCapabilitiesCommaString}"
  '' + ''
    CFG
  '';

  # Copied from the TensorFlow derivation. Most entries are not used when compiling jaxlib,
  # but keeping the list as-is makes it easier to keep in sync with TensorFlow.
  TF_SYSTEM_LIBS = lib.concatStringsSep "," ([
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    "boringssl"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    "com_github_grpc_grpc"
    "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ] ++ lib.optionals (!stdenv.isDarwin) [
    "nsync" # fails to build on darwin
  ]);

  # Make sure Bazel knows about our configuration flags during fetching so that the
  # relevant dependencies can be downloaded.
  bazelFlags = [
    "-c opt"
  ] ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
    "--config=avx_posix"
  ] ++ lib.optionals cudaSupport [
    "--config=cuda"
  ] ++ lib.optionals mklSupport [
    "--config=mkl_open_source_only"
  ] ++ lib.optionals stdenv.cc.isClang [
    # bazel depends on the compiler frontend automatically selecting these flags based on file
    # extension but our clang doesn't.
    # https://github.com/NixOS/nixpkgs/issues/150655
    "--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
  ];

  fetchAttrs = {
    sha256 =
      if cudaSupport then
        "sha256-n8wo+hD9ZYO1SsJKgyJzUmjRlsz45WT6tt5ZLleGvGY="
      else {
        x86_64-linux = "sha256-A0A18kxgGNGHNQ67ZPUzh3Yq2LEcRV7CqR9EfP80NQk=";
        aarch64-linux = "sha256-mU2jzuDu89jVmaG/M5bA3jSd7n7lDi+h8sdhs1z8p1A=";
        x86_64-darwin = "sha256-9nNTpetvjyipD/l8vKlregl1j/OnZKAcOCoZQeRBvts=";
        aarch64-darwin = "sha256-dOGUsdFImeOLcZ3VtgrNnd8A/HgIs/LYuH9GQV7A+78=";
      }.${stdenv.system} or (throw "unsupported system ${stdenv.system}");
  };

  buildAttrs = {
    outputs = [ "out" ];

    # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
    # 1) Fix pybind11 include paths.
    # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
    #    loading multiple extensions in the same python program due to duplicate protobuf DBs.
    # 3) Patch python path in the compiler driver.
    # 4) On Darwin, add the IOKit framework search path and point rules_cc's toolchain
    #    templates at cctools' install_name_tool/libtool instead of /usr/bin.
    preBuild = ''
      for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
        sed -i 's@include/pybind11@pybind11@g' $src
      done
    '' + lib.optionalString cudaSupport ''
      patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
    '' + lib.optionalString stdenv.isDarwin ''
      # Framework search paths aren't added by bintools hook
      # https://github.com/NixOS/nixpkgs/pull/41914
      export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
      substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
        --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
      substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
        --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
    '' + (if stdenv.cc.isGNU then ''
      sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
    '' else if stdenv.cc.isClang then ''
      sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
    '' else throw "Unsupported stdenv.cc: ${stdenv.cc}");

    # `wheelArch` (not raw `linuxArch`) so aarch64-linux passes "aarch64" as upstream expects.
    installPhase = ''
      ./bazel-bin/build/build_wheel --output_path=$out --cpu=${wheelArch}
    '';
  };

  inherit meta;
};

# Platform tag in the file name of the wheel written by build_wheel. Its architecture part
# must agree with the `--cpu` value passed there, hence both use `wheelArch`.
platformTag =
  if stdenv.targetPlatform.isLinux then
    "manylinux2014_${wheelArch}"
  else if stdenv.system == "x86_64-darwin" then
    "macosx_10_9_${wheelArch}"
  else if stdenv.system == "aarch64-darwin" then
    "macosx_11_0_${wheelArch}"
  # Interpolate `.system` (a plain string) rather than the platform attrset itself.
  else throw "Unsupported target platform: ${stdenv.targetPlatform.system}";
in
buildPythonPackage {
  inherit pname version meta;
  format = "wheel";

  # Pick the wheel written by the Bazel build, using the CPython tag derived from
  # `python.pythonVersion` with the dots removed (e.g. "3.10" -> "cp310") and `platformTag`.
  src =
    let
      pyTag = "cp" + builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion;
    in
    "${bazel-build}/jaxlib-${version}-${pyTag}-${pyTag}-${platformTag}.whl";

  nativeBuildInputs = lib.optionals cudaSupport [ addOpenGLRunpath ];

  propagatedBuildInputs = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    numpy
    scipy
    six
    snappy
  ];

  # CUDA builds only. jaxlib searches $PATH for "ptxas", so expose the toolkit's copy in
  # $out/bin (see https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621).
  # Then every shared object goes through addOpenGLRunpath and gets the CUDA toolkit,
  # cuDNN and NCCL library directories prepended to its existing rpath.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  pythonImportsCheck = [ "jaxlib" ];

  # With CUDA, skip the ELF patching step: otherwise the RPATH entries added in
  # postInstall are stripped and libcudart.so.11.0 can no longer be found.
  dontPatchELF = cudaSupport;
}