depot/third_party/nixpkgs/pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch
Default email 472aeafc57 Project import generated by Copybara.
GitOrigin-RevId: c31898adf5a8ed202ce5bea9f347b1c6871f32d1
2024-10-04 18:56:33 +02:00

49 lines
1.4 KiB
Diff

diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py
index a30c8f53..e2078ff6 100644
--- a/tinygrad/runtime/autogen/cuda.py
+++ b/tinygrad/runtime/autogen/cuda.py
@@ -145,7 +145,19 @@ def char_pointer_cast(string, encoding='utf-8'):
_libraries = {}
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
+libcuda = None
+# Use the first candidate that loads: the substituted driver path, then a bare
+# `libcuda.so` lookup. Stopping at the first hit avoids dlopen-ing both copies.
+for _libcuda_path in ('@driverLink@/lib/libcuda.so', 'libcuda.so'):
+  try:
+    libcuda = ctypes.CDLL(_libcuda_path)
+    break
+  except OSError:
+    pass
+if libcuda is None:
+  raise RuntimeError("`libcuda.so` not found")
+
+_libraries['libcuda.so'] = libcuda
cuuint32_t = ctypes.c_uint32
diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py
index 6af74187..c5a6c6c4 100644
--- a/tinygrad/runtime/autogen/nvrtc.py
+++ b/tinygrad/runtime/autogen/nvrtc.py
@@ -10,7 +10,18 @@ import ctypes, ctypes.util
_libraries = {}
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
+libnvrtc = None
+# Use the first candidate that loads: the substituted store path, then a bare
+# `libnvrtc.so` lookup. Stopping at the first hit avoids dlopen-ing both copies.
+for _libnvrtc_path in ('@libnvrtc@', 'libnvrtc.so'):
+  try:
+    libnvrtc = ctypes.CDLL(_libnvrtc_path)
+    break
+  except OSError:
+    pass
+if libnvrtc is None:
+  raise RuntimeError("`libnvrtc.so` not found")
+_libraries['libnvrtc.so'] = libnvrtc
def string_cast(char_pointer, encoding='utf-8', errors='strict'):
value = ctypes.cast(char_pointer, ctypes.c_char_p).value
if value is not None and encoding is not None: