depot/pkgs/development/libraries/science/math/tensorrt/extension.nix

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

157 lines
5.6 KiB
Nix
Raw Normal View History

final: prev: let
inherit (final) callPackage;
inherit (prev) cudatoolkit cudaVersion lib pkgs;
### TensorRT
buildTensorRTPackage = args:
callPackage ./generic.nix { } args;
toUnderscore = str: lib.replaceStrings ["."] ["_"] str;
majorMinorPatch = str: lib.concatStringsSep "." (lib.take 3 (lib.splitVersion str));
tensorRTPackages = let
# Check whether a file is supported for our cuda version
isSupported = fileData: lib.elem cudaVersion fileData.supportedCudaVersions;
# Return the first file that is supported. In practice there should only ever be one anyway.
supportedFile = files: lib.findFirst isSupported null files;
# Compute versioned attribute name to be used in this package set
computeName = version: "tensorrt_${toUnderscore version}";
# Supported versions with versions as keys and file as value
supportedVersions = lib.recursiveUpdate
{
tensorrt = {
enable = false;
fileVersionCuda = null;
fileVersionCudnn = null;
fullVersion = "0.0.0";
sha256 = null;
tarball = null;
supportedCudaVersions = [ ];
};
}
(lib.mapAttrs' (version: attrs: lib.nameValuePair (computeName version) attrs)
(lib.filterAttrs (version: file: file != null) (lib.mapAttrs (version: files: supportedFile files) tensorRTVersions)));
# Add all supported builds as attributes
allBuilds = lib.mapAttrs (name: file: buildTensorRTPackage (lib.removeAttrs file ["fileVersionCuda"])) supportedVersions;
# Set the default attributes, e.g. tensorrt = tensorrt_8_4;
defaultName = computeName tensorRTDefaultVersion;
defaultBuild = lib.optionalAttrs (allBuilds ? ${defaultName}) { tensorrt = allBuilds.${computeName tensorRTDefaultVersion}; };
in {
inherit buildTensorRTPackage;
} // allBuilds // defaultBuild;
tarballURL =
{fullVersion, fileVersionCuda, fileVersionCudnn ? null} :
"TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}"
+ lib.optionalString (fileVersionCudnn != null) ".cudnn${fileVersionCudnn}"
+ ".tar.gz";
tensorRTVersions = {
"8.6.1" = [
rec {
fileVersionCuda = "12.0";
fullVersion = "8.6.1.6";
sha256 = "sha256-D4FXpfxTKZQ7M4uJNZE3M1CvqQyoEjnNrddYDNHrolQ=";
tarball = tarballURL { inherit fileVersionCuda fullVersion; };
supportedCudaVersions = [ "12.0" "12.1" ];
}
];
"8.5.3" = [
rec {
fileVersionCuda = "11.8";
fileVersionCudnn = "8.6";
fullVersion = "8.5.3.1";
sha256 = "sha256-BNeuOYvPTUAfGxI0DVsNrX6Z/FAB28+SE0ptuGu7YDY=";
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" "11.8" ];
}
rec {
fileVersionCuda = "10.2";
fileVersionCudnn = "8.6";
fullVersion = "8.5.3.1";
sha256 = "sha256-WCt6yfOmFbrjqdYCj6AE2+s2uFpISwk6urP+2I0BnGQ=";
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
supportedCudaVersions = [ "10.2" ];
}
];
"8.5.2" = [
rec {
fileVersionCuda = "11.8";
fileVersionCudnn = "8.6";
fullVersion = "8.5.2.2";
sha256 = "sha256-Ov5irNS/JETpEz01FIFNMs9YVmjGHL7lSXmDpgCdgao=";
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" "11.8" ];
}
rec {
fileVersionCuda = "10.2";
fileVersionCudnn = "8.6";
fullVersion = "8.5.2.2";
sha256 = "sha256-UruwQShYcHLY5d81lKNG7XaoUsZr245c+PUpUN6pC5E=";
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
supportedCudaVersions = [ "10.2" ];
}
];
"8.5.1" = [
rec {
fileVersionCuda = "11.8";
fileVersionCudnn = "8.6";
fullVersion = "8.5.1.7";
sha256 = "sha256-Ocx/B3BX0TY3lOj/UcTPIaXb7M8RFrACC6Da4PMGMHY=";
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" "11.8" ];
}
rec {
fileVersionCuda = "10.2";
fileVersionCudnn = "8.6";
fullVersion = "8.5.1.7";
sha256 = "sha256-CcFGJhw7nFdPnSYYSxcto2MHK3F84nLQlJYjdIw8dPM=";
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
supportedCudaVersions = [ "10.2" ];
}
];
"8.4.0" = [
rec {
fileVersionCuda = "11.6";
fileVersionCudnn = "8.3";
fullVersion = "8.4.0.6";
sha256 = "sha256-DNgHHXF/G4cK2nnOWImrPXAkOcNW6Wy+8j0LRpAH/LQ=";
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" ];
}
rec {
fileVersionCuda = "10.2";
fileVersionCudnn = "8.3";
fullVersion = "8.4.0.6";
sha256 = "sha256-aCzH0ZI6BrJ0v+e5Bnm7b8mNltA7NNuIa8qRKzAQv+I=";
tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
supportedCudaVersions = [ "10.2" ];
}
];
};
# Default attributes
tensorRTDefaultVersion = {
"10.2" = "8.4.0";
"11.0" = "8.4.0";
"11.1" = "8.4.0";
"11.2" = "8.4.0";
"11.3" = "8.4.0";
"11.4" = "8.4.0";
"11.5" = "8.4.0";
"11.6" = "8.4.0";
"11.7" = "8.5.3";
"11.8" = "8.5.3";
"12.0" = "8.6.1";
"12.1" = "8.6.1";
}.${cudaVersion} or "8.4.0";
in tensorRTPackages