Skip to content

Commit 97b1e47

Browse files
committed
get torch and python include paths without subprocesses
1 parent ce32b31 commit 97b1e47

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

setup.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
import os
22
import subprocess
3+
import sysconfig
4+
import torch
35
from setuptools import setup
4-
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
6+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, include_paths
57
from config import sources, target, kernels
68
target = target.lower()
79

810
# Set environment variables
911
thunderkittens_root = os.getenv('THUNDERKITTENS_ROOT', os.path.abspath(os.path.join(os.getcwd(), '.')))
10-
python_include = subprocess.check_output(['python', '-c', "import sysconfig; print(sysconfig.get_path('include'))"]).decode().strip()
11-
torch_include = subprocess.check_output(['python', '-c', "import torch; from torch.utils.cpp_extension import include_paths; print(' '.join(['-I' + p for p in include_paths()]))"]).decode().strip()
12+
python_include = sysconfig.get_path("include")
13+
torch_include = [f"-I{p}" + p for p in include_paths()]
1214
print('Thunderkittens root:', thunderkittens_root)
1315
print('Python include:', python_include)
14-
print('Torch include directories:', torch_include)
16+
print('Torch include directories:', " ".join(torch_include))
1517

1618
# CUDA flags
1719
cuda_flags = [
@@ -31,7 +33,7 @@
3133
f'-I{thunderkittens_root}/prototype',
3234
f'-I{python_include}',
3335
'-DTORCH_COMPILE'
34-
] + torch_include.split()
36+
] + torch_include
3537
cpp_flags = [
3638
'-std=c++20',
3739
'-O3'

0 commit comments

Comments
 (0)