diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index a4633ae..6508865 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -62,21 +62,7 @@ from .utils import * # Initialize CPP modules def _find_cuda_home() -> str: - # TODO: reuse PyTorch API later - # For some PyTorch versions, the original `_find_cuda_home` will initialize CUDA, which is incompatible with process forks - cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') - if cuda_home is None: - # noinspection PyBroadException - try: - with open(os.devnull, 'w') as devnull: - nvcc = subprocess.check_output(['which', 'nvcc'], stderr=devnull).decode().rstrip('\r\n') - cuda_home = os.path.dirname(os.path.dirname(nvcc)) - except Exception: - cuda_home = '/usr/local/cuda' - if not os.path.exists(cuda_home): - cuda_home = None - assert cuda_home is not None - return cuda_home + return "@cuda_home@" deep_gemm_cpp.init(