diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py
index 359083a9..3cd5f7be 100644
--- a/tinygrad/runtime/autogen/cuda.py
+++ b/tinygrad/runtime/autogen/cuda.py
@@ -143,10 +143,36 @@ def char_pointer_cast(string, encoding='utf-8'):
     return ctypes.cast(string, ctypes.POINTER(ctypes.c_char))
 
 
+# Fallback paths tried when the dynamic loader cannot resolve a library by
+# its bare name. NOTE(review): the @...@ tokens look like build-time
+# substitution placeholders filled in by the packaging -- confirm.
+NAME_TO_PATHS = {
+    "libcuda.so": ["@driverLink@/lib/libcuda.so"],
+    "libnvrtc.so": ["@libnvrtc@"],
+}
+
+
+def _try_dlopen(name):
+    """Load shared library ``name``, falling back to NAME_TO_PATHS.
+
+    The bare name is tried first so the loader's normal search path wins;
+    the fallback paths registered for ``name`` are then tried in order.
+
+    Raises RuntimeError when every candidate fails; the last OSError is
+    chained as the cause so the loader's diagnostic is preserved.
+    """
+    last_error = None
+    for candidate in [name, *NAME_TO_PATHS.get(name, [])]:
+        try:
+            return ctypes.CDLL(candidate)
+        except OSError as err:
+            last_error = err
+    raise RuntimeError(f"{name} not found") from last_error
+
 
 _libraries = {}
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
+_libraries['libcuda.so'] = _try_dlopen('libcuda.so')
+_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so')
 
 
 cuuint32_t = ctypes.c_uint32