157 lines
5.6 KiB
Nix
157 lines
5.6 KiB
Nix
|
final: prev: let
|
||
|
|
||
|
inherit (final) callPackage;
|
||
|
inherit (prev) cudatoolkit cudaVersion lib pkgs;
|
||
|
|
||
|
### TensorRT
|
||
|
|
||
|
buildTensorRTPackage = args:
|
||
|
callPackage ./generic.nix { } args;
|
||
|
|
||
|
toUnderscore = str: lib.replaceStrings ["."] ["_"] str;
|
||
|
|
||
|
majorMinorPatch = str: lib.concatStringsSep "." (lib.take 3 (lib.splitVersion str));
|
||
|
|
||
|
tensorRTPackages = let
|
||
|
# Check whether a file is supported for our cuda version
|
||
|
isSupported = fileData: lib.elem cudaVersion fileData.supportedCudaVersions;
|
||
|
# Return the first file that is supported. In practice there should only ever be one anyway.
|
||
|
supportedFile = files: lib.findFirst isSupported null files;
|
||
|
|
||
|
# Compute versioned attribute name to be used in this package set
|
||
|
computeName = version: "tensorrt_${toUnderscore version}";
|
||
|
|
||
|
# Supported versions with versions as keys and file as value
|
||
|
supportedVersions = lib.recursiveUpdate
|
||
|
{
|
||
|
tensorrt = {
|
||
|
enable = false;
|
||
|
fileVersionCuda = null;
|
||
|
fileVersionCudnn = null;
|
||
|
fullVersion = "0.0.0";
|
||
|
sha256 = null;
|
||
|
tarball = null;
|
||
|
supportedCudaVersions = [ ];
|
||
|
};
|
||
|
}
|
||
|
(lib.mapAttrs' (version: attrs: lib.nameValuePair (computeName version) attrs)
|
||
|
(lib.filterAttrs (version: file: file != null) (lib.mapAttrs (version: files: supportedFile files) tensorRTVersions)));
|
||
|
|
||
|
# Add all supported builds as attributes
|
||
|
allBuilds = lib.mapAttrs (name: file: buildTensorRTPackage (lib.removeAttrs file ["fileVersionCuda"])) supportedVersions;
|
||
|
|
||
|
# Set the default attributes, e.g. tensorrt = tensorrt_8_4;
|
||
|
defaultName = computeName tensorRTDefaultVersion;
|
||
|
defaultBuild = lib.optionalAttrs (allBuilds ? ${defaultName}) { tensorrt = allBuilds.${computeName tensorRTDefaultVersion}; };
|
||
|
in {
|
||
|
inherit buildTensorRTPackage;
|
||
|
} // allBuilds // defaultBuild;
|
||
|
|
||
|
tarballURL =
|
||
|
{fullVersion, fileVersionCuda, fileVersionCudnn ? null} :
|
||
|
"TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}"
|
||
|
+ lib.optionalString (fileVersionCudnn != null) ".cudnn${fileVersionCudnn}"
|
||
|
+ ".tar.gz";
|
||
|
|
||
|
tensorRTVersions = {
|
||
|
"8.6.1" = [
|
||
|
rec {
|
||
|
fileVersionCuda = "12.0";
|
||
|
fullVersion = "8.6.1.6";
|
||
|
sha256 = "sha256-D4FXpfxTKZQ7M4uJNZE3M1CvqQyoEjnNrddYDNHrolQ=";
|
||
|
tarball = tarballURL { inherit fileVersionCuda fullVersion; };
|
||
|
supportedCudaVersions = [ "12.0" "12.1" ];
|
||
|
}
|
||
|
];
|
||
|
"8.5.3" = [
|
||
|
rec {
|
||
|
fileVersionCuda = "11.8";
|
||
|
fileVersionCudnn = "8.6";
|
||
|
fullVersion = "8.5.3.1";
|
||
|
sha256 = "sha256-BNeuOYvPTUAfGxI0DVsNrX6Z/FAB28+SE0ptuGu7YDY=";
|
||
|
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
|
||
|
supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" "11.8" ];
|
||
|
}
|
||
|
rec {
|
||
|
fileVersionCuda = "10.2";
|
||
|
fileVersionCudnn = "8.6";
|
||
|
fullVersion = "8.5.3.1";
|
||
|
sha256 = "sha256-WCt6yfOmFbrjqdYCj6AE2+s2uFpISwk6urP+2I0BnGQ=";
|
||
|
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
|
||
|
supportedCudaVersions = [ "10.2" ];
|
||
|
}
|
||
|
];
|
||
|
"8.5.2" = [
|
||
|
rec {
|
||
|
fileVersionCuda = "11.8";
|
||
|
fileVersionCudnn = "8.6";
|
||
|
fullVersion = "8.5.2.2";
|
||
|
sha256 = "sha256-Ov5irNS/JETpEz01FIFNMs9YVmjGHL7lSXmDpgCdgao=";
|
||
|
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
|
||
|
supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" "11.8" ];
|
||
|
}
|
||
|
rec {
|
||
|
fileVersionCuda = "10.2";
|
||
|
fileVersionCudnn = "8.6";
|
||
|
fullVersion = "8.5.2.2";
|
||
|
sha256 = "sha256-UruwQShYcHLY5d81lKNG7XaoUsZr245c+PUpUN6pC5E=";
|
||
|
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
|
||
|
supportedCudaVersions = [ "10.2" ];
|
||
|
}
|
||
|
];
|
||
|
"8.5.1" = [
|
||
|
rec {
|
||
|
fileVersionCuda = "11.8";
|
||
|
fileVersionCudnn = "8.6";
|
||
|
fullVersion = "8.5.1.7";
|
||
|
sha256 = "sha256-Ocx/B3BX0TY3lOj/UcTPIaXb7M8RFrACC6Da4PMGMHY=";
|
||
|
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
|
||
|
supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" "11.8" ];
|
||
|
}
|
||
|
rec {
|
||
|
fileVersionCuda = "10.2";
|
||
|
fileVersionCudnn = "8.6";
|
||
|
fullVersion = "8.5.1.7";
|
||
|
sha256 = "sha256-CcFGJhw7nFdPnSYYSxcto2MHK3F84nLQlJYjdIw8dPM=";
|
||
|
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
|
||
|
supportedCudaVersions = [ "10.2" ];
|
||
|
}
|
||
|
];
|
||
|
"8.4.0" = [
|
||
|
rec {
|
||
|
fileVersionCuda = "11.6";
|
||
|
fileVersionCudnn = "8.3";
|
||
|
fullVersion = "8.4.0.6";
|
||
|
sha256 = "sha256-DNgHHXF/G4cK2nnOWImrPXAkOcNW6Wy+8j0LRpAH/LQ=";
|
||
|
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
|
||
|
supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" ];
|
||
|
}
|
||
|
rec {
|
||
|
fileVersionCuda = "10.2";
|
||
|
fileVersionCudnn = "8.3";
|
||
|
fullVersion = "8.4.0.6";
|
||
|
sha256 = "sha256-aCzH0ZI6BrJ0v+e5Bnm7b8mNltA7NNuIa8qRKzAQv+I=";
|
||
|
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
|
||
|
supportedCudaVersions = [ "10.2" ];
|
||
|
}
|
||
|
];
|
||
|
};
|
||
|
|
||
|
# Default attributes
|
||
|
tensorRTDefaultVersion = {
|
||
|
"10.2" = "8.4.0";
|
||
|
"11.0" = "8.4.0";
|
||
|
"11.1" = "8.4.0";
|
||
|
"11.2" = "8.4.0";
|
||
|
"11.3" = "8.4.0";
|
||
|
"11.4" = "8.4.0";
|
||
|
"11.5" = "8.4.0";
|
||
|
"11.6" = "8.4.0";
|
||
|
"11.7" = "8.5.3";
|
||
|
"11.8" = "8.5.3";
|
||
|
"12.0" = "8.6.1";
|
||
|
"12.1" = "8.6.1";
|
||
|
}.${cudaVersion} or "8.4.0";
|
||
|
|
||
|
in tensorRTPackages
|