depot/third_party/nixpkgs/pkgs/development/python-modules/torch/gpu-checks.nix
Default email 98eb3e9ef5 Project import generated by Copybara.
GitOrigin-RevId: 00d80d13810dbfea8ab4ed1009b09100cca86ba8
2024-07-01 15:47:52 +00:00

40 lines
817 B
Nix

{
lib,
torchWithCuda,
torchWithRocm,
callPackage,
}:
let
# Builds one "is the accelerator usable from torch?" test.
#
# The Python script below imports torch and asserts that both
# `torch.cuda.is_available()` and `torch.version.<versionAttr>` are truthy,
# printing the two values on success and using them as the assertion
# message on failure.
#
# The script and its metadata are handed to `cudaPackages.writeGpuTestPython`.
# NOTE(review): that helper is defined outside this file; presumably it wraps
# the script into a runnable GPU test -- confirm against its definition.
accelAvailable =
  {
    # Passed through unchanged as `feature`; also used as the prefix of the test name.
    feature,
    # Attribute looked up on `torch.version` inside the script.
    versionAttr,
    # The torch build made available to the script via `libraries`.
    torch,
    # Not passed at the call sites in this file; expected to be filled in by callPackage.
    cudaPackages,
  }:
  let
    testName = "${feature}Available";

    # `${versionAttr}` is substituted by Nix at evaluation time; the `{...=}`
    # parts are Python f-string expressions evaluated when the test runs.
    script = ''
      import torch
      message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
      assert torch.cuda.is_available() and torch.version.${versionAttr}, message
      print(message)
    '';
  in
  cudaPackages.writeGpuTestPython {
    name = testName;
    libraries = [ torch ];
    inherit feature;
  } script;
in
let
  # Instantiate `accelAvailable` through callPackage for one accelerator.
  # Only feature, versionAttr and torch are supplied explicitly; any other
  # argument of `accelAvailable` is left for callPackage to resolve.
  mkAvailabilityTester =
    feature: versionAttr: torch:
    callPackage accelAvailable { inherit feature versionAttr torch; };
in
{
  # CUDA build: the script checks `torch.version.cuda`.
  tester-cudaAvailable = mkAvailabilityTester "cuda" "cuda" torchWithCuda;

  # ROCm build: the script checks `torch.version.hip`.
  tester-rocmAvailable = mkAvailabilityTester "rocm" "hip" torchWithRocm;
}