2023-04-12 12:48:02 +00:00
|
|
|
{
  # nixpkgs plumbing
  lib,
  config,

  # builders and fetchers
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  runCommand,

  # build tooling and setup hooks
  setuptools,
  pythonRelaxDepsHook,
  pytestCheckHook, # not used by the build at present (its use is commented out)
  cmake,
  ninja,
  lit,
  llvm,
  addOpenGLRunpath, # only its `driverLink` path is used

  # native libraries linked into the extension
  pybind11,
  gtest,
  zlib,
  ncurses,
  libxml2,

  # Python interpreter and Python-level dependencies / reverse-dependency tests
  python,
  filelock,
  torchWithRocm,

  # CUDA toolchain; enabled globally through nixpkgs `config.cudaSupport` by default
  cudaPackages,
  cudaSupport ? config.cudaSupport,
}:
|
|
|
|
|
|
|
|
let
  # Absolute path to the `ptxas` binary from the CUDA compiler package.
  # Make sure cudaPackages is the right version each update (See python/setup.py)
  ptxas =
    let
      nvcc = cudaPackages.cuda_nvcc;
    in
    "${nvcc}/bin/ptxas";
in
|
2023-10-19 13:55:26 +00:00
|
|
|
# Triton Python package. Upstream's python/setup.py drives CMake itself, so the
# Nix side mostly patches out network downloads and feeds in our own flags;
# CUDA-specific wiring (ptxas, linker flags, compiler) is gated on cudaSupport.
buildPythonPackage rec {
  pname = "triton";
  version = "2.1.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "openai";
    repo = pname;
    rev = "v${version}";
    hash = "sha256-8UTUwLH+SriiJnpejdrzz9qIquP2zBp1/uwLdHmv0XQ=";
  };

  patches = [
    # fix overflow error
    (fetchpatch {
      url = "https://github.com/openai/triton/commit/52c146f66b79b6079bcd28c55312fc6ea1852519.patch";
      hash = "sha256-098/TCQrzvrBAbQiaVGCMaF3o5Yc3yWDxzwSkzIuAtY=";
    })
  ] ++ lib.optionals (!cudaSupport) [
    ./0000-dont-download-ptxas.patch
    # openai-triton wants to get the ptxas version even if ptxas is not
    # used, resulting in a "ptxas not found" error.
    ./0001-ptxas-disable-version-key-for-non-cuda-targets.patch
  ];

  nativeBuildInputs = [
    setuptools
    pythonRelaxDepsHook
    # pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  propagatedBuildInputs = [
    filelock
    # openai-triton uses setuptools at runtime:
    # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
    setuptools
  ];

  postPatch =
    let
      # Bash was getting weird without linting,
      # but basically upstream contains [cc, ..., "-lcuda", ...]
      # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
      old = [ "-lcuda" ];
      new = [
        "-lcuda"
        "-L${addOpenGLRunpath.driverLink}"
        "-L${cudaPackages.cuda_cudart}/lib/stubs/"
      ];

      quote = x: ''"${x}"'';
      oldStr = lib.concatMapStringsSep ", " quote old;
      newStr = lib.concatMapStringsSep ", " quote new;
    in
    # NOTE(review): `--replace` does not fail the build when a pattern is
    # missing, so re-check each substitution below after version bumps.
    ''
      # Use our `cmakeFlags` instead and avoid downloading dependencies
      substituteInPlace python/setup.py \
        --replace "= get_thirdparty_packages(triton_cache_path)" "= os.environ[\"cmakeFlags\"].split()"

      # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
      substituteInPlace bin/CMakeLists.txt \
        --replace "add_subdirectory(FileCheck)" ""

      # Don't fetch googletest
      substituteInPlace unittest/CMakeLists.txt \
        --replace "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" "" \
        --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
    ''
    + lib.optionalString cudaSupport ''
      # Use our linker flags
      substituteInPlace python/triton/common/build.py \
        --replace '${oldStr}' '${newStr}'
    '';

  preConfigure = ''
    # Ensure that the build process uses the requested number of cores
    export MAX_JOBS="$NIX_BUILD_CORES"

    # Upstream's setup.py tries to write cache somewhere in ~/
    export HOME=$(mktemp -d)

    # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
    echo "
    [build_ext]
    base-dir=$PWD" >> python/setup.cfg

    # The rest (including buildPhase) is relative to ./python/
    cd python
  '' + lib.optionalString cudaSupport ''
    # Avoid GLIBCXX mismatch with other cuda-enabled python packages
    # by compiling with the same toolchain as the rest of cudaPackages.
    export CC=${cudaPackages.backendStdenv.cc}/bin/cc;
    export CXX=${cudaPackages.backendStdenv.cc}/bin/c++;

    # Work around download_and_copy_ptxas()
    mkdir -p $PWD/triton/third_party/cuda/bin
    ln -s ${ptxas} $PWD/triton/third_party/cuda/bin
  '';

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
  postFixup = lib.optionalString cudaSupport ''
    rm -f $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
    ln -s ${ptxas} $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
  '';

  # ctest comes from cmake, a build-platform tool, so it belongs in
  # nativeCheckInputs rather than checkInputs (host-platform dependencies).
  nativeCheckInputs = [ cmake ]; # ctest
  dontUseSetuptoolsCheck = true;

  # NOTE(review): preCheck is normally invoked by a check hook such as
  # pytestCheckHook, which is commented out above; confirm it actually runs.
  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs).
    # Locate it from the unpacked source root instead of hard-coding
    # /build/source, which only exists inside the default Linux sandbox.
    (cd "$NIX_BUILD_TOP/$sourceRoot"/python/build/temp* && ctest)

    # For pytestCheckHook
    cd test/unit
  '';

  # Circular dependency on torch
  # pythonImportsCheck = [
  #   "triton"
  #   "triton.language"
  # ];

  # Ultimately, torch is our test suite:
  passthru.tests = {
    inherit torchWithRocm;
    # Stand-in for pythonImportsCheck, in case the circular dependency on torch
    # reappears and pythonImportsCheck has to be commented out again.
    import-triton = runCommand "import-triton" { nativeBuildInputs = [ (python.withPackages (ps: [ ps.openai-triton ])) ]; } ''
      python << \EOF
      import triton
      import triton.language
      EOF
      touch "$out"
    '';
  };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/openai/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  meta = {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/openai/triton";
    platforms = lib.platforms.linux;
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ SomeoneSerge Madouura ];
  };
}
|