depot/third_party/nixpkgs/pkgs/development/cuda-modules/generic-builders/multiplex.nix
Default email 504525a148 Project import generated by Copybara.
GitOrigin-RevId: bd645e8668ec6612439a9ee7e71f7eac4099d4f6
2024-01-02 12:29:13 +01:00

134 lines
4.3 KiB
Nix

# Generic "multiplex" builder.
# Given the arguments below, this file evaluates to an extension function
# (`final: prev: AttrSet`) that provides one derivation per release listed for
# the host's redist architecture, keyed by its versioned name, plus a default
# derivation under the plain `pname` attribute.
{
# callPackage-provided arguments
lib,
cudaVersion,
flags,
hostPlatform,
# Expected to be passed by the caller
mkVersionedPackageName,
# pname :: String
# Also used as the `redistName` handed to ./manifest.nix and as the attribute
# name of the default (unversioned) derivation.
pname,
# releasesModule :: Path
# A path to a module which provides a `releases` attribute
releasesModule,
# shimsFn :: {package, redistArch} -> AttrSet
# Evaluated with callPackage, receiving `package` and `redistArch`.
# The result must provide the `redistribRelease` and `featureRelease`
# attributes, which are forwarded to ./manifest.nix.
# The redistribRelease is only used in ./manifest.nix for the package version
# and the package description (which NVIDIA's manifest calls the "name").
# It's also used for fetching the source, but we override that since we can't
# re-use that portion of the functionality (different URLs, etc.).
# The featureRelease is used to populate meta.platforms (by way of looking at the attribute names)
# and to determine the outputs of the package.
# The default only exists so that a missing argument fails with a clear message.
shimsFn ? ({package, redistArch}: throw "shimsFn must be provided"),
# fixupFn :: Path | ({final, cudaVersion, mkVersionedPackageName, package, ...} -> overrideAttrs argument)
# A path (or nix expression) to be evaluated with callPackage and then
# provided to the package's overrideAttrs function.
# It must accept at least the following arguments:
# - final
# - cudaVersion
# - mkVersionedPackageName
# - package
# The default only exists so that a missing argument fails with a clear message.
fixupFn ? (
{
final,
cudaVersion,
mkVersionedPackageName,
package,
...
}:
throw "fixupFn must be provided"
),
}:
let
inherit (lib)
attrsets
lists
modules
strings
;
# Evaluate the shared module definitions in ../modules together with the
# caller-supplied releases module.
evaluatedModules =
  let
    moduleList = [
      ../modules
      releasesModule
    ];
  in
  modules.evalModules { modules = moduleList; };
# NOTE: Important types:
# - Releases: ../modules/${pname}/releases/releases.nix
# - Package: ../modules/${pname}/releases/package.nix
# FIXME: do this at the module system level
# Tag every release with the name of the platform (attribute name) it is listed under.
# propagatePlatforms :: AttrSet (List Package) -> AttrSet (List Package)
propagatePlatforms = builtins.mapAttrs (
  platform: releases: builtins.map (release: release // { inherit platform; }) releases
);
# Every release of `pname`, grouped by platform, with each release annotated
# with its platform name.
# See ../modules/${pname}/releases/releases.nix
releaseSets =
  let
    rawReleases = evaluatedModules.config.${pname}.releases;
  in
  propagatePlatforms rawReleases;
# Attribute name under which a given release is exposed in this package set.
# The name is produced by the caller-supplied `mkVersionedPackageName` from
# `pname` and the release's version.
# Patch version changes should not break the build, so we only use major and minor
# computeName :: Package -> String
computeName =
  { version, ... }:
  let
    versionedName = mkVersionedPackageName pname version;
  in
  versionedName;
# A package is supported when its platform name does not start with
# "unsupported" and `cudaVersion` lies within the inclusive range
# [minCudaVersion, maxCudaVersion] declared by the package.
# isSupported :: Package -> Bool
isSupported =
  package:
  let
    # The bindings are lazy, so the checks are still evaluated left to right
    # and stop at the first one that fails.
    platformOk = !(strings.hasPrefix "unsupported" package.platform);
    notTooOld = strings.versionAtLeast cudaVersion package.minCudaVersion;
    notTooNew = strings.versionAtLeast package.maxCudaVersion cudaVersion;
  in
  platformOk && notTooOld && notTooNew;
# The redistributable architecture name corresponding to the host system,
# as reported by `flags.getRedistArch`. Used to look up the releases for our
# platform.
redistArch =
  let
    inherit (hostPlatform) system;
  in
  flags.getRedistArch system;
# Flat list of the releases of every platform.
# allReleases :: List Package
allReleases = builtins.concatLists (builtins.attrValues releaseSets);
# The releases listed for our platform; empty when the platform has no entry.
# perSystemReleases :: List Package
perSystemReleases =
  if releaseSets ? ${redistArch} then
    releaseSets.${redistArch}
  else
    [ ];
# Ordering used to pick the default release: supported releases come before
# unsupported ones, and within the same support status newer versions come
# first.
# NOTE: builtins.sort requires a strict ordering ("less than"), so the
# comparator must return false for equal elements. Hence versionOlder rather
# than versionAtLeast, and an explicit comparison of the support status instead
# of a logical implication (which left some pairs mutually incomparable).
# preferable :: Package -> Package -> Bool
preferable =
  p1: p2:
  let
    supported1 = isSupported p1;
    supported2 = isSupported p2;
  in
  if supported1 != supported2 then
    # Exactly one of the two is supported: it wins regardless of version.
    supported1
  else
    strings.versionOlder p2.version p1.version;
# The most preferable release across all platforms; used for the default
# (unversioned) attribute.
# newest :: Package
newest =
  if allReleases == [ ] then
    # Fail with a descriptive message rather than an out-of-bounds index error.
    throw "${pname}: no releases found, cannot select a default version"
  else
    builtins.head (builtins.sort preferable allReleases);
# Given the `final` package set and the `package` being built, evaluate
# `fixupFn` through callPackage and return the result, which is meant to be
# handed to `overrideAttrs`.
overrideAttrsFixupFn =
  final: package:
  let
    fixupArgs = {
      inherit final cudaVersion mkVersionedPackageName package;
    };
  in
  final.callPackage fixupFn fixupArgs;
# The extension returned by this file. The second argument is ignored.
# The result maps each versioned name to its derivation and additionally
# exposes the derivation built from `newest` under the plain `pname` attribute.
extension =
  final: _:
  let
    # Turn a package into a name/value pair: the name is the versioned package
    # name, the value is the derivation from ./manifest.nix after applying the
    # fixup through overrideAttrs.
    buildPackage =
      package:
      let
        shims = final.callPackage shimsFn { inherit package redistArch; };
        unfixedDrv = final.callPackage ./manifest.nix {
          inherit pname;
          inherit (shims) redistribRelease featureRelease;
          redistName = pname;
        };
        fixup = overrideAttrsFixupFn final package;
      in
      attrsets.nameValuePair (computeName package) (unfixedDrv.overrideAttrs fixup);
    # versionedDerivations :: AttrSet Derivation
    versionedDerivations = builtins.listToAttrs (builtins.map buildPackage perSystemReleases);
    # Only the derivation of the pair is needed here; it is re-keyed under `pname`.
    defaultPair = buildPackage newest;
  in
  versionedDerivations // { ${pname} = defaultPair.value; };
in
extension