Files
nixpkgs/pkgs/development/python-modules/deep-gemm/patch-runtime-cuda-home-path.patch

28 lines
1.0 KiB
Diff

diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py
index a4633ae..6508865 100644
--- a/deep_gemm/__init__.py
+++ b/deep_gemm/__init__.py
@@ -62,21 +62,7 @@ from .utils import *
# Initialize CPP modules
def _find_cuda_home() -> str:
- # TODO: reuse PyTorch API later
- # For some PyTorch versions, the original `_find_cuda_home` will initialize CUDA, which is incompatible with process forks
- cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
- if cuda_home is None:
- # noinspection PyBroadException
- try:
- with open(os.devnull, 'w') as devnull:
- nvcc = subprocess.check_output(['which', 'nvcc'], stderr=devnull).decode().rstrip('\r\n')
- cuda_home = os.path.dirname(os.path.dirname(nvcc))
- except Exception:
- cuda_home = '/usr/local/cuda'
- if not os.path.exists(cuda_home):
- cuda_home = None
- assert cuda_home is not None
- return cuda_home
+ return "@cuda_home@"
deep_gemm_cpp.init(