254 lines
7.3 KiB
Nix
254 lines
7.3 KiB
Nix
|
{ lib
|
||
|
, buildPythonPackage
|
||
|
, python
|
||
|
, fetchpatch
|
||
|
, fetchFromGitHub
|
||
|
, addOpenGLRunpath
|
||
|
, cmake
|
||
|
, cudaPackages
|
||
|
, llvmPackages
|
||
|
, pybind11
|
||
|
, gtest
|
||
|
, zlib
|
||
|
, ncurses
|
||
|
, libxml2
|
||
|
, lit
|
||
|
, filelock
|
||
|
, torchWithRocm
|
||
|
, pytest
|
||
|
, pytestCheckHook
|
||
|
, pythonRelaxDepsHook
|
||
|
, pkgsTargetTarget
|
||
|
}:
|
||
|
|
||
|
let
|
||
|
pname = "triton";
|
||
|
version = "2.0.0";
|
||
|
|
||
|
inherit (cudaPackages) cuda_cudart backendStdenv;
|
||
|
|
||
|
# A time may come we'll want to be cross-friendly
|
||
|
#
|
||
|
# Short explanation: we need pkgsTargetTarget, because we use string
|
||
|
# interpolation instead of buildInputs.
|
||
|
#
|
||
|
# Long explanation: OpenAI/triton downloads and vendors a copy of NVidia's
|
||
|
# ptxas compiler. We're not running this ptxas on the build machine, but on
|
||
|
# the user's machine, i.e. our Target platform. The second "Target" in
|
||
|
# pkgsTargetTarget maybe doesn't matter, because ptxas compiles programs to
|
||
|
# be executed on the GPU.
|
||
|
# Cf. https://nixos.org/manual/nixpkgs/unstable/#sec-cross-infra
|
||
|
ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas";
|
||
|
|
||
|
llvm = (llvmPackages.llvm.override {
|
||
|
llvmTargetsToBuild = [ "NATIVE" "NVPTX" ];
|
||
|
# Upstream CI sets these too:
|
||
|
# targetProjects = [ "mlir" ];
|
||
|
extraCMakeFlags = [
|
||
|
"-DLLVM_INSTALL_UTILS=ON"
|
||
|
];
|
||
|
});
|
||
|
in
|
||
|
buildPythonPackage {
|
||
|
inherit pname version;
|
||
|
|
||
|
format = "setuptools";
|
||
|
|
||
|
src = fetchFromGitHub {
|
||
|
owner = "openai";
|
||
|
repo = pname;
|
||
|
rev = "v${version}";
|
||
|
hash = "sha256-9GZzugab+Pdt74Dj6zjlEzjj4BcJ69rzMJmqcVMxsKU=";
|
||
|
};
|
||
|
|
||
|
patches = [
|
||
|
# Prerequisite for llvm15 patch
|
||
|
(fetchpatch {
|
||
|
url = "https://github.com/openai/triton/commit/2aba985daaa70234823ea8f1161da938477d3e02.patch";
|
||
|
hash = "sha256-LGv0+Ut2WYPC4Ksi4803Hwmhi3FyQOF9zElJc/JCobk=";
|
||
|
})
|
||
|
(fetchpatch {
|
||
|
url = "https://github.com/openai/triton/commit/e3941f9d09cdd31529ba4a41018cfc0096aafea6.patch";
|
||
|
hash = "sha256-A+Gor6qzFlGQhVVhiaaYOzqqx8yO2MdssnQS6TIfUWg=";
|
||
|
})
|
||
|
|
||
|
# Source: https://github.com/openai/triton/commit/fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a.patch
|
||
|
# The original patch adds ptxas binary, so we include our own clean copy
|
||
|
# Drop with the next update
|
||
|
./llvm15.patch
|
||
|
|
||
|
# TODO: there have been commits upstream aimed at removing the "torch"
|
||
|
# circular dependency, but the patches fail to apply on the release
|
||
|
# revision. Keeping the link for future reference
|
||
|
# Also cf. https://github.com/openai/triton/issues/1374
|
||
|
|
||
|
# (fetchpatch {
|
||
|
# url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch";
|
||
|
# hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig=";
|
||
|
# })
|
||
|
];
|
||
|
|
||
|
postPatch = ''
|
||
|
substituteInPlace python/setup.py \
|
||
|
--replace \
|
||
|
'= get_thirdparty_packages(triton_cache_path)' \
|
||
|
'= os.environ["cmakeFlags"].split()'
|
||
|
''
|
||
|
# Wiring triton=2.0.0 with llcmPackages_rocm.llvm=5.4.3
|
||
|
# Revisit when updating either triton or llvm
|
||
|
+ ''
|
||
|
substituteInPlace CMakeLists.txt \
|
||
|
--replace "nvptx" "NVPTX" \
|
||
|
--replace "LLVM 11" "LLVM"
|
||
|
sed -i '/AddMLIR/a set(MLIR_TABLEGEN_EXE "${llvmPackages.mlir}/bin/mlir-tblgen")' CMakeLists.txt
|
||
|
sed -i '/AddMLIR/a set(MLIR_INCLUDE_DIR ''${MLIR_INCLUDE_DIRS})' CMakeLists.txt
|
||
|
find -iname '*.td' -exec \
|
||
|
sed -i \
|
||
|
-e '\|include "mlir/IR/OpBase.td"|a include "mlir/IR/AttrTypeBase.td"' \
|
||
|
-e 's|include "mlir/Dialect/StandardOps/IR/Ops.td"|include "mlir/Dialect/Func/IR/FuncOps.td"|' \
|
||
|
'{}' ';'
|
||
|
substituteInPlace unittest/CMakeLists.txt --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
|
||
|
sed -i 's/^include.*$//' unittest/CMakeLists.txt
|
||
|
sed -i '/LINK_LIBS/i NVPTXInfo' lib/Target/PTX/CMakeLists.txt
|
||
|
sed -i '/LINK_LIBS/i NVPTXCodeGen' lib/Target/PTX/CMakeLists.txt
|
||
|
''
|
||
|
# TritonMLIRIR already links MLIRIR. Not transitive?
|
||
|
# + ''
|
||
|
# echo "target_link_libraries(TritonPTX PUBLIC MLIRIR)" >> lib/Target/PTX/CMakeLists.txt
|
||
|
# ''
|
||
|
# Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
|
||
|
+ ''
|
||
|
substituteInPlace bin/CMakeLists.txt \
|
||
|
--replace "add_subdirectory(FileCheck)" ""
|
||
|
|
||
|
rm cmake/FindLLVM.cmake
|
||
|
''
|
||
|
+
|
||
|
(
|
||
|
let
|
||
|
# Bash was getting weird without linting,
|
||
|
# but basically upstream contains [cc, ..., "-lcuda", ...]
|
||
|
# and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
|
||
|
old = [ "-lcuda" ];
|
||
|
new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cuda_cudart}/lib/stubs/" ];
|
||
|
|
||
|
quote = x: ''"${x}"'';
|
||
|
oldStr = lib.concatMapStringsSep ", " quote old;
|
||
|
newStr = lib.concatMapStringsSep ", " quote new;
|
||
|
in
|
||
|
''
|
||
|
substituteInPlace python/triton/compiler.py \
|
||
|
--replace '${oldStr}' '${newStr}'
|
||
|
''
|
||
|
)
|
||
|
# Triton seems to be looking up cuda.h
|
||
|
+ ''
|
||
|
sed -i 's|cu_include_dir = os.path.join.*$|cu_include_dir = "${cuda_cudart}/include"|' python/triton/compiler.py
|
||
|
'';
|
||
|
|
||
|
nativeBuildInputs = [
|
||
|
cmake
|
||
|
pythonRelaxDepsHook
|
||
|
|
||
|
# Requires torch (circular dependency) and probably needs GPUs:
|
||
|
# pytestCheckHook
|
||
|
|
||
|
# Note for future:
|
||
|
# These *probably* should go in depsTargetTarget
|
||
|
# ...but we cannot test cross right now anyway
|
||
|
# because we only support cudaPackages on x86_64-linux atm
|
||
|
lit
|
||
|
llvm
|
||
|
llvmPackages.mlir
|
||
|
];
|
||
|
|
||
|
buildInputs = [
|
||
|
gtest
|
||
|
libxml2.dev
|
||
|
ncurses
|
||
|
pybind11
|
||
|
zlib
|
||
|
];
|
||
|
|
||
|
propagatedBuildInputs = [
|
||
|
filelock
|
||
|
];
|
||
|
|
||
|
# Avoid GLIBCXX mismatch with other cuda-enabled python packages
|
||
|
preConfigure = ''
|
||
|
export CC="${backendStdenv.cc}/bin/cc";
|
||
|
export CXX="${backendStdenv.cc}/bin/c++";
|
||
|
|
||
|
# Upstream's setup.py tries to write cache somewhere in ~/
|
||
|
export HOME=$TMPDIR
|
||
|
|
||
|
# Upstream's github actions patch setup.cfg to write base-dir. May be redundant
|
||
|
echo "
|
||
|
[build_ext]
|
||
|
base-dir=$PWD" >> python/setup.cfg
|
||
|
|
||
|
# The rest (including buildPhase) is relative to ./python/
|
||
|
cd python/
|
||
|
|
||
|
# Work around download_and_copy_ptxas()
|
||
|
dst_cuda="$PWD/triton/third_party/cuda/bin"
|
||
|
mkdir -p "$dst_cuda"
|
||
|
ln -s "${ptxas}" "$dst_cuda/"
|
||
|
'';
|
||
|
|
||
|
# CMake is run by setup.py instead
|
||
|
dontUseCmakeConfigure = true;
|
||
|
cmakeFlags = [
|
||
|
"-DMLIR_DIR=${llvmPackages.mlir}/lib/cmake/mlir"
|
||
|
];
|
||
|
|
||
|
postFixup =
|
||
|
let
|
||
|
ptxasDestination = "$out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas";
|
||
|
in
|
||
|
# Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
|
||
|
''
|
||
|
rm -f ${ptxasDestination}
|
||
|
ln -s ${ptxas} ${ptxasDestination}
|
||
|
'';
|
||
|
|
||
|
checkInputs = [
|
||
|
cmake # ctest
|
||
|
];
|
||
|
dontUseSetuptoolsCheck = true;
|
||
|
preCheck =
|
||
|
# build/temp* refers to build_ext.build_temp (looked up in the build logs)
|
||
|
''
|
||
|
(cd /build/source/python/build/temp* ; ctest)
|
||
|
'' # For pytestCheckHook
|
||
|
+ ''
|
||
|
cd test/unit
|
||
|
'';
|
||
|
pythonImportsCheck = [
|
||
|
# Circular dependency on torch
|
||
|
# "triton"
|
||
|
# "triton.language"
|
||
|
];
|
||
|
|
||
|
# Ultimately, torch is our test suite:
|
||
|
passthru.tests = {
|
||
|
inherit torchWithRocm;
|
||
|
};
|
||
|
|
||
|
pythonRemoveDeps = [
|
||
|
# Circular dependency, cf. https://github.com/openai/triton/issues/1374
|
||
|
"torch"
|
||
|
|
||
|
# CLI tools without dist-info
|
||
|
"cmake"
|
||
|
"lit"
|
||
|
];
|
||
|
meta = with lib; {
|
||
|
description = "Development repository for the Triton language and compiler";
|
||
|
homepage = "https://github.com/openai/triton/";
|
||
|
platforms = lib.platforms.unix;
|
||
|
license = licenses.mit;
|
||
|
maintainers = with maintainers; [ SomeoneSerge ];
|
||
|
};
|
||
|
}
|