mirror of
https://github.com/NixOS/nixpkgs.git
synced 2026-06-09 14:53:47 +00:00
28 lines
1.0 KiB
Diff
28 lines
1.0 KiB
Diff
diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py
|
|
index a4633ae..6508865 100644
|
|
--- a/deep_gemm/__init__.py
|
|
+++ b/deep_gemm/__init__.py
|
|
@@ -62,21 +62,7 @@ from .utils import *
|
|
|
|
# Initialize CPP modules
|
|
def _find_cuda_home() -> str:
|
|
- # TODO: reuse PyTorch API later
|
|
- # For some PyTorch versions, the original `_find_cuda_home` will initialize CUDA, which is incompatible with process forks
|
|
- cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
|
|
- if cuda_home is None:
|
|
- # noinspection PyBroadException
|
|
- try:
|
|
- with open(os.devnull, 'w') as devnull:
|
|
- nvcc = subprocess.check_output(['which', 'nvcc'], stderr=devnull).decode().rstrip('\r\n')
|
|
- cuda_home = os.path.dirname(os.path.dirname(nvcc))
|
|
- except Exception:
|
|
- cuda_home = '/usr/local/cuda'
|
|
- if not os.path.exists(cuda_home):
|
|
- cuda_home = None
|
|
- assert cuda_home is not None
|
|
- return cuda_home
|
|
+ return "@cuda_home@"
|
|
|
|
|
|
deep_gemm_cpp.init(
|