depot/third_party/nixpkgs/pkgs/development/python-modules/jax/test-cuda.nix
Default email c7cb07f092 Project import generated by Copybara.
GitOrigin-RevId: 1536926ef5621b09bba54035ae2bb6d806d72ac8
2024-02-29 21:09:43 +01:00

17 lines
285 B
Nix

# Smoke test for JAX with CUDA.
#
# Takes `jax`, `jaxlib` and `pkgs` and returns the result of
# `pkgs.writers.writePython3Bin`: a Python 3 script named `jax-test-cuda`
# with `jax` and `jaxlib` passed as its libraries. The script checks that the
# first JAX device is a GPU, runs a small matrix product on it and prints
# "success!".
{ jax
, jaxlib
, pkgs
}:
pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; } ''
import jax
from jax import random

# Fail early unless the first (default) JAX device is a GPU.
assert jax.devices()[0].platform == "gpu"

rng = random.PRNGKey(0)
x = random.normal(rng, (100, 100))
# JAX dispatches device work asynchronously; wait for the product so that
# errors raised while executing it surface before "success!" is printed.
(x @ x).block_until_ready()
print("success!")
''