Files
nixpkgs/pkgs/development/python-modules/apex/fix-cuda-capabilities-selection.patch

165 lines
6.9 KiB
Diff

diff --git a/setup.py b/setup.py
index fa61b72..c31d1e0 100644
--- a/setup.py
+++ b/setup.py
@@ -405,24 +405,8 @@ if has_flag("--cuda_ext", "APEX_CUDA_EXT"):
)
if bare_metal_version >= Version("11.0"):
-
- cc_flag = []
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_70,code=sm_70")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_80,code=sm_80")
- if bare_metal_version >= Version("11.1"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_86,code=sm_86")
- if bare_metal_version >= Version("11.8"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_90,code=sm_90")
- if bare_metal_version >= Version("12.8"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_100,code=sm_100")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_120,code=sm_120")
-
+ # Architectures are selected via TORCH_CUDA_ARCH_LIST (torch appends the
+ # matching -gencode flags) instead of a hard-coded list.
ext_modules.append(
CUDAExtension(
name="fused_weight_gradient_mlp_cuda",
@@ -441,7 +425,7 @@ if has_flag("--cuda_ext", "APEX_CUDA_EXT"):
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
- ] + version_dependent_macros + cc_flag,
+ ] + version_dependent_macros,
},
)
)
@@ -530,14 +514,7 @@ if has_flag("--group_norm", "APEX_GROUP_NORM"):
sys.argv.remove("--group_norm")
raise_if_cuda_home_none("--group_norm")
- # CUDA group norm supports from SM70
- arch_flags = []
- # FIXME: this needs to be done more cleanly
- for arch in [70, 75, 80, 86, 90, 100, 120]:
- arch_flag = f"-gencode=arch=compute_{arch},code=sm_{arch}"
- arch_flags.append(arch_flag)
- arch_flags.append(arch_flag)
-
+ # Architectures are selected via TORCH_CUDA_ARCH_LIST instead of a hard-coded list.
ext_modules.append(
CUDAExtension(
name="group_norm_cuda",
@@ -549,7 +526,7 @@ if has_flag("--group_norm", "APEX_GROUP_NORM"):
"cxx": ["-O3", "-std=c++17"] + version_dependent_macros,
"nvcc": [
"-O3", "-std=c++17", "--use_fast_math", "--ftz=false",
- ] + arch_flags + version_dependent_macros,
+ ] + version_dependent_macros,
},
)
)
@@ -651,22 +628,7 @@ if has_flag("--fast_layer_norm", "APEX_FAST_LAYER_NORM"):
sys.argv.remove("--fast_layer_norm")
raise_if_cuda_home_none("--fast_layer_norm")
- cc_flag = []
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_70,code=sm_70")
-
- if bare_metal_version >= Version("11.0"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_80,code=sm_80")
- if bare_metal_version >= Version("11.8"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_90,code=sm_90")
- if bare_metal_version >= Version("12.8"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_100,code=sm_100")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_120,code=sm_120")
-
+ # Architectures are selected via TORCH_CUDA_ARCH_LIST instead of a hard-coded list.
ext_modules.append(
CUDAExtension(
name="fast_layer_norm",
@@ -689,7 +651,7 @@ if has_flag("--fast_layer_norm", "APEX_FAST_LAYER_NORM"):
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
- ] + version_dependent_macros + generator_flag + cc_flag,
+ ] + version_dependent_macros + generator_flag,
},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")],
)
@@ -703,17 +665,18 @@ if has_flag("--fmha", "APEX_FMHA"):
if bare_metal_version < Version("11.0"):
raise RuntimeError("--fmha only supported on sm_80 and sm_90 GPUs")
+ # The fmha kernels use sm_80 MMA instructions, so select architectures from
+ # TORCH_CUDA_ARCH_LIST but drop anything below 8.0 (torch would otherwise also
+ # try to build the unsupported sm_75 target).
cc_flag = []
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_80,code=sm_80")
- if bare_metal_version >= Version("11.8"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_90,code=sm_90")
- if bare_metal_version >= Version("12.8"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_100,code=sm_100")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_120,code=sm_120")
+    for arch in os.environ["TORCH_CUDA_ARCH_LIST"].replace(" ", ";").split(";"):
+        capability = arch.removesuffix("+PTX")  # "8.6+PTX" -> "8.6"; the PTX target is emitted separately below
+        if not capability or int(capability.split(".")[0]) < 8:  # major only: Version() rejects e.g. "10.0f"
+            continue
+        num = capability.replace(".", "")  # "8.6" -> "86", "9.0a" -> "90a"
+        cc_flag += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+        if arch.endswith("+PTX"):
+            cc_flag += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
ext_modules.append(
CUDAExtension(
@@ -755,25 +718,7 @@ if has_flag("--fast_multihead_attn", "APEX_FAST_MULTIHEAD_ATTN"):
sys.argv.remove("--fast_multihead_attn")
raise_if_cuda_home_none("--fast_multihead_attn")
- cc_flag = []
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_70,code=sm_70")
-
- if bare_metal_version >= Version("11.0"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_80,code=sm_80")
- if bare_metal_version >= Version("11.1"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_86,code=sm_86")
- if bare_metal_version >= Version("11.8"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_90,code=sm_90")
- if bare_metal_version >= Version("12.8"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_100,code=sm_100")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_120,code=sm_120")
-
+ # Architectures are selected via TORCH_CUDA_ARCH_LIST instead of a hard-coded list.
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
CUDAExtension(
@@ -800,8 +745,7 @@ if has_flag("--fast_multihead_attn", "APEX_FAST_MULTIHEAD_ATTN"):
"--use_fast_math",
]
+ version_dependent_macros
- + generator_flag
- + cc_flag,
+ + generator_flag,
},
include_dirs=[
os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass/include/"),