depot/third_party/nixpkgs/pkgs/development/python-modules/numpyro/default.nix
Default email 7e47f3658e Project import generated by Copybara.
GitOrigin-RevId: 1925c603f17fc89f4c8f6bf6f631a802ad85d784
2024-09-26 11:04:55 +00:00

130 lines
2.8 KiB
Nix

{
lib,
buildPythonPackage,
fetchFromGitHub,
# build-system
setuptools,
# dependencies
jax,
jaxlib,
multipledispatch,
numpy,
tqdm,
# tests
# Our current version of tensorflow (2.13.0) is too old and doesn't support python>=3.12
# We remove optional test dependencies that require tensorflow and skip the corresponding tests to
# avoid introducing a useless incompatibility with python 3.12:
# dm-haiku,
# flax,
# tensorflow-probability,
funsor,
graphviz,
optax,
pyro-api,
pytestCheckHook,
scikit-learn,
}:
buildPythonPackage rec {
pname = "numpyro";
version = "0.15.3";
pyproject = true;
src = fetchFromGitHub {
owner = "pyro-ppl";
repo = "numpyro";
rev = "refs/tags/${version}";
hash = "sha256-g+ep221hhLbCjQasKpiEAXkygI5A3Hglqo1tV8lv5eg=";
};
build-system = [ setuptools ];
dependencies = [
jax
jaxlib
multipledispatch
numpy
tqdm
];
nativeCheckInputs = [
# dm-haiku
# flax
funsor
graphviz
optax
pyro-api
pytestCheckHook
scikit-learn
# tensorflow-probability
];
pythonImportsCheck = [ "numpyro" ];
disabledTests = [
# AssertionError due to tolerance issues
"test_beta_binomial_log_prob"
"test_collapse_beta"
"test_cpu"
"test_gamma_poisson"
"test_gof"
"test_hpdi"
"test_kl_dirichlet_dirichlet"
"test_kl_univariate"
"test_mean_var"
# Tests want to download data
"data_load"
"test_jsb_chorales"
# RuntimeWarning: overflow encountered in cast
"test_zero_inflated_logits_probs_agree"
# NameError: unbound axis name: _provenance
"test_model_transformation"
# require dm-haiku
"test_flax_state_dropout_smoke"
"test_flax_module"
"test_random_module_mcmc"
# require flax
"test_haiku_state_dropout_smoke"
"test_haiku_module"
"test_random_module_mcmc"
# require tensorflow-probability
"test_modified_bessel_first_kind_vect"
"test_diag_spectral_density_periodic"
"test_kernel_approx_periodic"
"test_modified_bessel_first_kind_one_dim"
"test_modified_bessel_first_kind_vect"
"test_periodic_gp_one_dim_model"
"test_no_tracer_leak_at_lazy_property_sample"
# flaky on darwin
# TODO: uncomment at next release (0.15.4) as it has been fixed:
# https://github.com/pyro-ppl/numpyro/pull/1863
"test_change_point_x64"
];
disabledTestPaths = [
# require jaxns (unpackaged)
"test/contrib/test_nested_sampling.py"
# requires tensorflow-probability
"test/contrib/test_tfp.py"
"test/test_distributions.py"
];
meta = {
description = "Library for probabilistic programming with NumPy";
homepage = "https://num.pyro.ai/";
changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
license = lib.licenses.asl20;
maintainers = with lib.maintainers; [ fab ];
};
}