depot/third_party/nixpkgs/pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch
Default email bcb2f287e1 Project import generated by Copybara.
GitOrigin-RevId: d603719ec6e294f034936c0d0dc06f689d91b6c3
2024-06-20 20:27:18 +05:30

32 lines
1 KiB
Diff

diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py
index 359083a9..3cd5f7be 100644
--- a/tinygrad/runtime/autogen/cuda.py
+++ b/tinygrad/runtime/autogen/cuda.py
@@ -143,10 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'):
return ctypes.cast(string, ctypes.POINTER(ctypes.c_char))
+NAME_TO_PATHS = {  # fallback paths tried when loading a library by its bare name fails
+ "libcuda.so": ["@driverLink@/lib/libcuda.so"],  # NOTE(review): @driverLink@ looks like a build-time placeholder - confirm
+ "libnvrtc.so": ["@libnvrtc@"],  # NOTE(review): @libnvrtc@ presumably replaced with a full library path - confirm
+}
+def _try_dlopen(name):
+    """Load `name` with ctypes.CDLL, then each NAME_TO_PATHS fallback; raise RuntimeError if all fail."""
+    last_err = None
+    for candidate in [name, *NAME_TO_PATHS.get(name, [])]:
+        try:
+            return ctypes.CDLL(candidate)
+        except OSError as err:
+            # Remember the loader error so the final failure explains why.
+            last_err = err
+    # Same exception type and message as before; the last OSError is chained as the cause.
+    raise RuntimeError(f"{name} not found") from last_err
_libraries = {}
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
+_libraries['libcuda.so'] = _try_dlopen('libcuda.so')
+_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so')
cuuint32_t = ctypes.c_uint32