472aeafc57
GitOrigin-RevId: c31898adf5a8ed202ce5bea9f347b1c6871f32d1
49 lines
1.4 KiB
Diff
49 lines
1.4 KiB
Diff
diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py
|
|
index a30c8f53..e2078ff6 100644
|
|
--- a/tinygrad/runtime/autogen/cuda.py
|
|
+++ b/tinygrad/runtime/autogen/cuda.py
|
|
@@ -145,7 +145,19 @@ def char_pointer_cast(string, encoding='utf-8'):
|
|
|
|
|
|
_libraries = {}
|
|
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
|
|
+libcuda = None
|
|
+try:
|
|
+ libcuda = ctypes.CDLL('libcuda.so')
|
|
+except OSError:
|
|
+ pass
|
|
+try:
|
|
+ libcuda = ctypes.CDLL('@driverLink@/lib/libcuda.so')
|
|
+except OSError:
|
|
+ pass
|
|
+if libcuda is None:
|
|
+ raise RuntimeError(f"`libcuda.so` not found")
|
|
+
|
|
+_libraries['libcuda.so'] = libcuda
|
|
|
|
|
|
cuuint32_t = ctypes.c_uint32
|
|
diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py
|
|
index 6af74187..c5a6c6c4 100644
|
|
--- a/tinygrad/runtime/autogen/nvrtc.py
|
|
+++ b/tinygrad/runtime/autogen/nvrtc.py
|
|
@@ -10,7 +10,18 @@ import ctypes, ctypes.util
|
|
|
|
|
|
_libraries = {}
|
|
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
|
|
+libnvrtc = None
|
|
+try:
|
|
+ libnvrtc = ctypes.CDLL('libnvrtc.so')
|
|
+except OSError:
|
|
+ pass
|
|
+try:
|
|
+ libnvrtc = ctypes.CDLL('@libnvrtc@')
|
|
+except OSError:
|
|
+ pass
|
|
+if libnvrtc is None:
|
|
+ raise RuntimeError(f"`libnvrtc.so` not found")
|
|
+_libraries['libnvrtc.so'] = libnvrtc
|
|
def string_cast(char_pointer, encoding='utf-8', errors='strict'):
|
|
value = ctypes.cast(char_pointer, ctypes.c_char_p).value
|
|
if value is not None and encoding is not None:
|