{ stdenv , lib , buildPythonPackage , fetchFromGitHub , fetchpatch , pythonAtLeast , pythonOlder , pytestCheckHook , setuptools , numpy , packaging , psutil , pyyaml , torch , evaluate , parameterized , transformers }: buildPythonPackage rec { pname = "accelerate"; version = "0.21.0"; format = "pyproject"; disabled = pythonOlder "3.7"; src = fetchFromGitHub { owner = "huggingface"; repo = pname; rev = "refs/tags/v${version}"; hash = "sha256-BwM3gyNhsRkxtxLNrycUGwBmXf8eq/7b56/ykMryt5w="; }; patches = [ # fix import error when torch>=2.0.1 and torch.distributed is disabled # https://github.com/huggingface/accelerate/pull/1800 (fetchpatch { url = "https://github.com/huggingface/accelerate/commit/32701039d302d3875c50c35ab3e76c467755eae9.patch"; hash = "sha256-Hth7qyOfx1sC8UaRdbYTnyRXD/VRKf41GtLc0ee1t2I="; }) ]; nativeBuildInputs = [ setuptools ]; propagatedBuildInputs = [ numpy packaging psutil pyyaml torch ]; nativeCheckInputs = [ evaluate parameterized pytestCheckHook transformers ]; preCheck = '' export HOME=$(mktemp -d) export PATH=$out/bin:$PATH ''; pytestFlagsArray = [ "tests" ]; disabledTests = [ # try to download data: "FeatureExamplesTests" "test_infer_auto_device_map_on_t0pp" ] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [ # usual aarch64-linux RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly "CheckpointTest" ] ++ lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [ # RuntimeError: torch_shm_manager: execl failed: Permission denied "CheckpointTest" ] ++ lib.optionals (pythonAtLeast "3.11") [ # python3.11 not yet supported for torch.compile "test_dynamo_extract_model" ]; disabledTestPaths = lib.optionals (!(stdenv.isLinux && stdenv.isx86_64)) [ # numerous instances of torch.multiprocessing.spawn.ProcessRaisedException: "tests/test_cpu.py" "tests/test_grad_sync.py" "tests/test_metrics.py" "tests/test_scheduler.py" ]; pythonImportsCheck = [ "accelerate" ]; meta = with lib; { homepage = "https://huggingface.co/docs/accelerate"; description = "A simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision"; changelog = "https://github.com/huggingface/accelerate/releases/tag/v${version}"; license = licenses.asl20; maintainers = with maintainers; [ bcdarwin ]; mainProgram = "accelerate"; }; }