# Smoke test for a CUDA-enabled jax/jaxlib build.
#
# Produces an executable `jax-test-cuda` that imports jax, checks that the
# first visible device is a GPU, and runs a small matrix multiply on it,
# printing "success!" only if all of that worked. Building the derivation
# does not execute the script; run the resulting binary on a machine with
# a GPU.
#
# Arguments:
#   jax, jaxlib: the Python packages under test. NOTE(review): jaxlib is
#     presumably the CUDA-enabled variant; confirm at the call site.
#   pkgs: package set providing `writers.writePython3Bin`.
{
  jax,
  jaxlib,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
      jaxlib
    ];
  }
  ''
    import jax
    from jax import random

    # Fail loudly (and say what was found instead) if no GPU is visible.
    devices = jax.devices()
    assert devices[0].platform == "gpu", (
        f"expected a GPU device, got {devices}"
    )

    rng = random.PRNGKey(0)
    x = random.normal(rng, (100, 100))

    # JAX dispatches work asynchronously; wait for the product to actually
    # be computed so device/runtime errors surface before "success!".
    y = (x @ x).block_until_ready()
    assert y.shape == (100, 100)

    print("success!")
  ''