depot/third_party/nixpkgs/pkgs/development/python-modules/jax/test-cuda.nix
Default email fa5436e0a7 Project import generated by Copybara.
GitOrigin-RevId: e8057b67ebf307f01bdcc8fba94d94f75039d1f6
2024-06-05 17:53:02 +02:00

25 lines
332 B
Nix

# Defines a Python 3 script named `jax-test-cuda` via
# `pkgs.writers.writePython3Bin`. The script is a smoke test: it asserts that
# the first JAX device is a GPU, runs a small matrix product on random data,
# and prints "success!".
#
# Arguments:
#   jax, jaxlib - Python packages handed to the writer as the script's libraries.
#   pkgs        - package set providing `writers.writePython3Bin`.
{
  jax,
  jaxlib,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    # Python packages passed to the writer for use by the script below.
    libraries = [
      jax
      jaxlib
    ];
  }
  ''
    import jax
    from jax import random

    # Stop immediately unless the first JAX device reports the GPU platform.
    assert jax.devices()[0].platform == "gpu"

    rng = random.PRNGKey(0)
    x = random.normal(rng, (100, 100))
    # JAX dispatches device work asynchronously; wait for the product to be
    # computed so a failing GPU computation is not reported as success.
    (x @ x).block_until_ready()

    print("success!")
  ''