Replace the ctypes.util.find_library() lookups for libcuda and libnvrtc
with explicit ctypes.CDLL() attempts: first the bare soname, then the
@driverLink@/lib/libcuda.so and @libnvrtc@ paths (placeholders, presumably
substituted before this patch is applied; confirm against the build
recipe). Each attempt ignores OSError. If both attempts succeed the handle
from the second one is used; if neither succeeds, RuntimeError is raised.
The resulting CDLL handle is stored directly in _libraries.

diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py
index a30c8f53..e2078ff6 100644
--- a/tinygrad/runtime/autogen/cuda.py
+++ b/tinygrad/runtime/autogen/cuda.py
@@ -145,7 +145,19 @@ def char_pointer_cast(string, encoding='utf-8'):
 
 
 _libraries = {}
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
+libcuda = None
+try:
+    libcuda = ctypes.CDLL('libcuda.so')
+except OSError:
+    pass
+try:
+    libcuda = ctypes.CDLL('@driverLink@/lib/libcuda.so')
+except OSError:
+    pass
+if libcuda is None:
+    raise RuntimeError("`libcuda.so` not found")
+
+_libraries['libcuda.so'] = libcuda
 
 
 cuuint32_t = ctypes.c_uint32
diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py
index 6af74187..c5a6c6c4 100644
--- a/tinygrad/runtime/autogen/nvrtc.py
+++ b/tinygrad/runtime/autogen/nvrtc.py
@@ -10,7 +10,18 @@ import ctypes, ctypes.util
 
 
 _libraries = {}
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
+libnvrtc = None
+try:
+    libnvrtc = ctypes.CDLL('libnvrtc.so')
+except OSError:
+    pass
+try:
+    libnvrtc = ctypes.CDLL('@libnvrtc@')
+except OSError:
+    pass
+if libnvrtc is None:
+    raise RuntimeError("`libnvrtc.so` not found")
+_libraries['libnvrtc.so'] = libnvrtc
 def string_cast(char_pointer, encoding='utf-8', errors='strict'):
     value = ctypes.cast(char_pointer, ctypes.c_char_p).value
     if value is not None and encoding is not None: