depot/third_party/nixpkgs/pkgs/development/python-modules/triton/0003-nvidia-cudart-a-systempath.patch

46 lines
1.5 KiB
Diff

From 6f92d54e5a544bc34bb07f2808d554a71cc0e4c3 Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Sun, 13 Oct 2024 14:30:19 +0000
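Subject: [PATCH 3/3] nvidia: cudart a systempath

Include cuda.h with angle brackets instead of quotes, both in driver.c
and in the launcher source generated by make_launcher. In driver.py,
prepend the shlex-split value of the cudaToolkitIncludeDirs placeholder
to include_dir, ahead of the "include" directory next to driver.py.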
---
third_party/nvidia/backend/driver.c | 2 +-
third_party/nvidia/backend/driver.py | 5 +++--
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c
index 44524da27..fbdf0d156 100644
--- a/third_party/nvidia/backend/driver.c
+++ b/third_party/nvidia/backend/driver.c
@@ -1,4 +1,4 @@
-#include "cuda.h"
+#include <cuda.h>
#include <dlfcn.h>
#include <stdbool.h>
#define PY_SSIZE_T_CLEAN
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 30fbadb2a..65c0562ed 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -10,7 +10,8 @@ from triton.backends.compiler import GPUTarget
from triton.backends.driver import GPUDriver
dirname = os.path.dirname(os.path.realpath(__file__))
-include_dir = [os.path.join(dirname, "include")]
+import shlex
+include_dir = [*shlex.split("@cudaToolkitIncludeDirs@"), os.path.join(dirname, "include")]
libdevice_dir = os.path.join(dirname, "lib")
libraries = ['cuda']
@@ -149,7 +150,7 @@ def make_launcher(constants, signature, ids):
# generate glue code
params = [i for i in signature.keys() if i not in constants]
src = f"""
-#include \"cuda.h\"
+#include <cuda.h>
#include <stdbool.h>
#include <Python.h>
#include <dlfcn.h>
--
2.46.0