# Nix expression for functorch ("JAX-like composable function transforms for
# PyTorch"), built from the upstream GitHub sources with buildPythonPackage.
{ buildPythonPackage
, expecttest
, fetchFromGitHub
, lib
, ninja
, pytestCheckHook
  # NOTE(review): `python` is not referenced anywhere in this expression; it is
  # kept in the argument set so callers that pass it explicitly do not break.
, python
, torch
, pybind11
, which
}:

buildPythonPackage rec {
  pname = "functorch";
  version = "0.2.0";
  format = "setuptools";

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-33skKk5aAIHn+1149ifolXPA+tpQ+WROAZvwPeGBbrA=";
  };

  # Somewhat surprisingly, PyTorch is actually necessary for the build process:
  # `setup.py` imports `torch.utils.cpp_extension`.
  nativeBuildInputs = [
    ninja
    torch
    which
  ];

  buildInputs = [
    pybind11
  ];

  # functorch also needs torch at run time (the Python package imports it and
  # its compiled extension is built against it via torch.utils.cpp_extension).
  # Listing torch only in nativeBuildInputs makes it available solely during
  # the build -- where pythonImportsCheck happens to find it -- so it must be
  # propagated for consumers of this package to be able to import functorch.
  propagatedBuildInputs = [
    torch
  ];

  # Delete the in-tree package sources before running the tests -- presumably
  # so the tests import the installed package (with its compiled extension)
  # instead of the source checkout.
  preCheck = ''
    rm -rf functorch/
  '';

  checkInputs = [
    expecttest
    pytestCheckHook
  ];

  # See https://github.com/pytorch/functorch/issues/835.
  disabledTests = [
    # RuntimeError: ("('...', '') is in PyTorch's OpInfo db ", "but is not in functorch's OpInfo db. Please regenerate ", '... and add the new tests to ', 'denylists if necessary.')
    "test_coverage_bernoulli_cpu_float32"
    "test_coverage_column_stack_cpu_float32"
    "test_coverage_diagflat_cpu_float32"
    "test_coverage_flatten_cpu_float32"
    "test_coverage_linalg_lu_factor_cpu_float32"
    "test_coverage_linalg_lu_factor_ex_cpu_float32"
    "test_coverage_multinomial_cpu_float32"
    "test_coverage_nn_functional_dropout2d_cpu_float32"
    "test_coverage_nn_functional_feature_alpha_dropout_with_train_cpu_float32"
    "test_coverage_nn_functional_feature_alpha_dropout_without_train_cpu_float32"
    "test_coverage_nn_functional_kl_div_cpu_float32"
    "test_coverage_normal_cpu_float32"
    "test_coverage_normal_number_mean_cpu_float32"
    "test_coverage_pca_lowrank_cpu_float32"
    "test_coverage_round_decimals_0_cpu_float32"
    "test_coverage_round_decimals_3_cpu_float32"
    "test_coverage_round_decimals_neg_3_cpu_float32"
    "test_coverage_scatter_reduce_cpu_float32"
    "test_coverage_svd_lowrank_cpu_float32"

    # >       self.assertEqual(len(functorch_lagging_op_db), len(op_db))
    # E       AssertionError: Scalars are not equal!
    # E
    # E       Absolute difference: 19
    # E       Relative difference: 0.03525046382189239
    "test_functorch_lagging_op_db_has_opinfos_cpu"

    # RuntimeError: PyTorch not compiled with LLVM support!
    "test_bias_gelu"
    "test_binary_ops"
    "test_broadcast1"
    "test_broadcast2"
    "test_float_double"
    "test_float_int"
    "test_fx_trace"
    "test_int_long"
    "test_issue57611"
    "test_slice1"
    "test_slice2"
    "test_transposed1"
    "test_transposed2"
    "test_unary_ops"
  ];

  pythonImportsCheck = [ "functorch" ];

  meta = with lib; {
    description = "JAX-like composable function transforms for PyTorch";
    homepage = "https://pytorch.org/functorch";
    license = licenses.bsd3;
    maintainers = with maintainers; [ samuela ];
    # See https://github.com/NixOS/nixpkgs/pull/174248#issuecomment-1139895064.
    platforms = platforms.x86_64;
  };
}