c7cb07f092
GitOrigin-RevId: 1536926ef5621b09bba54035ae2bb6d806d72ac8
155 lines
3.3 KiB
Nix
155 lines
3.3 KiB
Nix
{ lib
|
|
, stdenv
|
|
, buildPythonPackage
|
|
, fetchFromGitHub
|
|
, pythonOlder
|
|
, writeText
|
|
, setuptools
|
|
, wheel
|
|
, filelock
|
|
, huggingface-hub
|
|
, importlib-metadata
|
|
, numpy
|
|
, pillow
|
|
, regex
|
|
, requests
|
|
, safetensors
|
|
# optional dependencies
|
|
, accelerate
|
|
, datasets
|
|
, flax
|
|
, jax
|
|
, jaxlib
|
|
, jinja2
|
|
, peft
|
|
, protobuf
|
|
, tensorboard
|
|
, torch
|
|
# test dependencies
|
|
, parameterized
|
|
, pytest-timeout
|
|
, pytest-xdist
|
|
, pytestCheckHook
|
|
, requests-mock
|
|
, scipy
|
|
, sentencepiece
|
|
, torchsde
|
|
, transformers
|
|
}:
|
|
|
|
buildPythonPackage rec {
|
|
pname = "diffusers";
|
|
version = "0.26.3";
|
|
pyproject = true;
|
|
|
|
disabled = pythonOlder "3.8";
|
|
|
|
src = fetchFromGitHub {
|
|
owner = "huggingface";
|
|
repo = "diffusers";
|
|
rev = "refs/tags/v${version}";
|
|
hash = "sha256-1pIe1OU+vIrHM6KIZtHRMXklBZrugDV+I/OBNQYqvXI=";
|
|
};
|
|
|
|
nativeBuildInputs = [
|
|
setuptools
|
|
wheel
|
|
];
|
|
|
|
propagatedBuildInputs = [
|
|
filelock
|
|
huggingface-hub
|
|
importlib-metadata
|
|
numpy
|
|
pillow
|
|
regex
|
|
requests
|
|
safetensors
|
|
];
|
|
|
|
passthru.optional-dependencies = {
|
|
flax = [
|
|
flax
|
|
jax
|
|
jaxlib
|
|
];
|
|
torch = [
|
|
accelerate
|
|
torch
|
|
];
|
|
training = [
|
|
accelerate
|
|
datasets
|
|
jinja2
|
|
peft
|
|
protobuf
|
|
tensorboard
|
|
];
|
|
};
|
|
|
|
pythonImportsCheck = [
|
|
"diffusers"
|
|
];
|
|
|
|
# tests crash due to torch segmentation fault
|
|
doCheck = !(stdenv.isLinux && stdenv.isAarch64);
|
|
|
|
nativeCheckInputs = [
|
|
parameterized
|
|
pytest-timeout
|
|
pytest-xdist
|
|
pytestCheckHook
|
|
requests-mock
|
|
scipy
|
|
sentencepiece
|
|
torchsde
|
|
transformers
|
|
] ++ passthru.optional-dependencies.torch;
|
|
|
|
preCheck = let
|
|
# This pytest hook mocks and catches attempts at accessing the network
|
|
# tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
|
|
# cf. python3Packages.shap
|
|
conftestSkipNetworkErrors = writeText "conftest.py" ''
|
|
from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
|
|
import urllib3
|
|
|
|
class NetworkAccessDeniedError(RuntimeError): pass
|
|
def deny_network_access(*a, **kw):
|
|
raise NetworkAccessDeniedError
|
|
|
|
urllib3.connection.HTTPSConnection._new_conn = deny_network_access
|
|
|
|
def pytest_runtest_makereport(item, call):
|
|
tr = orig_pytest_runtest_makereport(item, call)
|
|
if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
|
|
tr.outcome = 'skipped'
|
|
tr.wasxfail = "reason: Requires network access."
|
|
return tr
|
|
'';
|
|
in ''
|
|
export HOME=$TMPDIR
|
|
cat ${conftestSkipNetworkErrors} >> tests/conftest.py
|
|
'';
|
|
|
|
pytestFlagsArray = [
|
|
"tests/"
|
|
];
|
|
|
|
disabledTests = [
|
|
# depends on current working directory
|
|
"test_deprecate_stacklevel"
|
|
# fails due to precision of floating point numbers
|
|
"test_model_cpu_offload_forward_pass"
|
|
# tries to run ruff which we have intentionally removed from nativeCheckInputs
|
|
"test_is_copy_consistent"
|
|
];
|
|
|
|
meta = with lib; {
|
|
description = "State-of-the-art diffusion models for image and audio generation in PyTorch";
|
|
homepage = "https://github.com/huggingface/diffusers";
|
|
changelog = "https://github.com/huggingface/diffusers/releases/tag/${src.rev}";
|
|
license = licenses.asl20;
|
|
maintainers = with maintainers; [ natsukium ];
|
|
};
|
|
}
|