c7cb07f092
GitOrigin-RevId: 1536926ef5621b09bba54035ae2bb6d806d72ac8
17 lines
285 B
Nix
17 lines
285 B
Nix
# Smoke test for a CUDA-enabled JAX build. The generated `jax-test-cuda`
# program fails unless JAX's first device is a GPU and a small random
# matrix multiplication completes on it; on success it prints "success!".
{ jax
, jaxlib
, pkgs
}:

pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; } ''
import jax
from jax import random

# Fail loudly, showing what JAX found, if the first device is not a GPU.
devices = jax.devices()
assert devices[0].platform == "gpu", f"expected a GPU device, got {devices}"

rng = random.PRNGKey(0)
x = random.normal(rng, (100, 100))

# JAX dispatches work asynchronously: wait until the matmul has actually
# executed so that device-side failures surface before reporting success.
(x @ x).block_until_ready()

print("success!")
''