2024-06-20 14:57:18 +00:00
|
|
|
diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py
|
2024-09-19 14:19:46 +00:00
|
|
|
index a30c8f53..e2078ff6 100644
|
2024-06-20 14:57:18 +00:00
|
|
|
--- a/tinygrad/runtime/autogen/cuda.py
|
|
|
|
+++ b/tinygrad/runtime/autogen/cuda.py
|
2024-09-19 14:19:46 +00:00
|
|
|
@@ -145,7 +145,19 @@ def char_pointer_cast(string, encoding='utf-8'):
 
 
 _libraries = {}
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
+# Try the explicit driver path first, then the default dynamic-linker search
+# path; keep the first library that loads so only one copy is opened.
+libcuda = None
+for _path in ('@driverLink@/lib/libcuda.so', 'libcuda.so'):
+  try:
+    libcuda = ctypes.CDLL(_path)
+    break
+  except OSError:
+    pass
+if libcuda is None:
+  raise RuntimeError("`libcuda.so` not found")
+
+_libraries['libcuda.so'] = libcuda
 
 
 cuuint32_t = ctypes.c_uint32
|
2024-09-19 14:19:46 +00:00
|
|
|
diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py
|
|
|
|
index 6af74187..c5a6c6c4 100644
|
|
|
|
--- a/tinygrad/runtime/autogen/nvrtc.py
|
|
|
|
+++ b/tinygrad/runtime/autogen/nvrtc.py
|
|
|
|
@@ -10,7 +10,18 @@ import ctypes, ctypes.util
 
 
 _libraries = {}
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
+# Try the explicit libnvrtc path first, then the default dynamic-linker search
+# path; keep the first library that loads so only one copy is opened.
+libnvrtc = None
+for _path in ('@libnvrtc@', 'libnvrtc.so'):
+  try:
+    libnvrtc = ctypes.CDLL(_path)
+    break
+  except OSError:
+    pass
+if libnvrtc is None:
+  raise RuntimeError("`libnvrtc.so` not found")
+_libraries['libnvrtc.so'] = libnvrtc
 def string_cast(char_pointer, encoding='utf-8', errors='strict'):
     value = ctypes.cast(char_pointer, ctypes.c_char_p).value
     if value is not None and encoding is not None:
|