2024-06-05 15:53:02 +00:00
|
|
|
{
  jax,
  jaxlib,
  pkgs,
}:

# Smoke test: build a `jax-test-cuda` executable that checks JAX's first
# device is a GPU and runs a small matrix multiply on it.
pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
      jaxlib
    ];
  }
  ''
    import jax
    from jax import random

    # Fail fast (and show what JAX found) if it fell back to another backend.
    devices = jax.devices()
    assert devices[0].platform == "gpu", devices

    rng = random.PRNGKey(0)
    x = random.normal(rng, (100, 100))
    # JAX dispatches work asynchronously; block until the matmul finishes so
    # any device-side failure is raised before we report success.
    (x @ x).block_until_ready()

    print("success!")
  ''
|