diff --git a/setup.py b/setup.py
index fa61b72..c31d1e0 100644
--- a/setup.py
+++ b/setup.py
@@ -405,24 +405,8 @@ if has_flag("--cuda_ext", "APEX_CUDA_EXT"):
     )
 
     if bare_metal_version >= Version("11.0"):
-
-        cc_flag = []
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_70,code=sm_70")
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_80,code=sm_80")
-        if bare_metal_version >= Version("11.1"):
-            cc_flag.append("-gencode")
-            cc_flag.append("arch=compute_86,code=sm_86")
-        if bare_metal_version >= Version("11.8"):
-            cc_flag.append("-gencode")
-            cc_flag.append("arch=compute_90,code=sm_90")
-        if bare_metal_version >= Version("12.8"):
-            cc_flag.append("-gencode")
-            cc_flag.append("arch=compute_100,code=sm_100")
-            cc_flag.append("-gencode")
-            cc_flag.append("arch=compute_120,code=sm_120")
-
+        # Architectures are selected via TORCH_CUDA_ARCH_LIST (torch appends the
+        # matching -gencode flags) instead of a hard-coded list.
         ext_modules.append(
             CUDAExtension(
                 name="fused_weight_gradient_mlp_cuda",
@@ -441,7 +425,7 @@ if has_flag("--cuda_ext", "APEX_CUDA_EXT"):
                         "--expt-relaxed-constexpr",
                         "--expt-extended-lambda",
                         "--use_fast_math",
-                    ] + version_dependent_macros + cc_flag,
+                    ] + version_dependent_macros,
                 },
             )
         )
@@ -530,14 +514,7 @@ if has_flag("--group_norm", "APEX_GROUP_NORM"):
     sys.argv.remove("--group_norm")
     raise_if_cuda_home_none("--group_norm")
 
-    # CUDA group norm supports from SM70
-    arch_flags = []
-    # FIXME: this needs to be done more cleanly
-    for arch in [70, 75, 80, 86, 90, 100, 120]:
-        arch_flag = f"-gencode=arch=compute_{arch},code=sm_{arch}"
-        arch_flags.append(arch_flag)
-        arch_flags.append(arch_flag)
-
+    # Architectures are selected via TORCH_CUDA_ARCH_LIST instead of a hard-coded list.
     ext_modules.append(
         CUDAExtension(
             name="group_norm_cuda",
@@ -549,7 +526,7 @@ if has_flag("--group_norm", "APEX_GROUP_NORM"):
                 "cxx": ["-O3", "-std=c++17"] + version_dependent_macros,
                 "nvcc": [
                     "-O3", "-std=c++17", "--use_fast_math", "--ftz=false",
-                ] + arch_flags + version_dependent_macros,
+                ] + version_dependent_macros,
             },
         )
     )
@@ -651,22 +628,7 @@ if has_flag("--fast_layer_norm", "APEX_FAST_LAYER_NORM"):
     sys.argv.remove("--fast_layer_norm")
     raise_if_cuda_home_none("--fast_layer_norm")
 
-    cc_flag = []
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_70,code=sm_70")
-
-    if bare_metal_version >= Version("11.0"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_80,code=sm_80")
-    if bare_metal_version >= Version("11.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_90,code=sm_90")
-    if bare_metal_version >= Version("12.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_100,code=sm_100")
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_120,code=sm_120")
-
+    # Architectures are selected via TORCH_CUDA_ARCH_LIST instead of a hard-coded list.
     ext_modules.append(
         CUDAExtension(
             name="fast_layer_norm",
@@ -689,7 +651,7 @@ if has_flag("--fast_layer_norm", "APEX_FAST_LAYER_NORM"):
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
                     "--use_fast_math",
-                ] + version_dependent_macros + generator_flag + cc_flag,
+                ] + version_dependent_macros + generator_flag,
             },
             include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")],
         )
@@ -703,17 +665,33 @@ if has_flag("--fmha", "APEX_FMHA"):
     if bare_metal_version < Version("11.0"):
         raise RuntimeError("--fmha only supported on sm_80 and sm_90 GPUs")
 
+    # The fmha kernels use sm_80 MMA instructions, so build only the
+    # TORCH_CUDA_ARCH_LIST entries at or above 8.0. torch.utils.cpp_extension
+    # adds its own -gencode flags only when no nvcc flag mentions an arch, so
+    # passing cc_flag explicitly stops it from also building e.g. sm_75.
+    arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
+    if arch_list is None:
+        # Unset: fall back to the architectures this extension used to
+        # hard-code rather than failing with a KeyError.
+        arch_list = "8.0"
+        if bare_metal_version >= Version("11.8"):
+            arch_list += ";9.0"
+        if bare_metal_version >= Version("12.8"):
+            arch_list += ";10.0;12.0"
     cc_flag = []
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_80,code=sm_80")
-    if bare_metal_version >= Version("11.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_90,code=sm_90")
-    if bare_metal_version >= Version("12.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_100,code=sm_100")
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_120,code=sm_120")
+    for arch in arch_list.replace(" ", ";").split(";"):
+        capability = arch.removesuffix("+PTX")
+        if not capability or Version(capability) < Version("8.0"):
+            continue
+        num = capability.replace(".", "")
+        cc_flag += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+        if arch.endswith("+PTX"):
+            cc_flag += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
+    if not cc_flag:
+        # An empty cc_flag would let torch add its full, unfiltered arch list.
+        raise RuntimeError(
+            f"--fmha needs an architecture >= 8.0 in TORCH_CUDA_ARCH_LIST, got {arch_list!r}"
+        )
 
     ext_modules.append(
         CUDAExtension(
@@ -755,25 +733,7 @@ if has_flag("--fast_multihead_attn", "APEX_FAST_MULTIHEAD_ATTN"):
     sys.argv.remove("--fast_multihead_attn")
     raise_if_cuda_home_none("--fast_multihead_attn")
 
-    cc_flag = []
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_70,code=sm_70")
-
-    if bare_metal_version >= Version("11.0"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_80,code=sm_80")
-    if bare_metal_version >= Version("11.1"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_86,code=sm_86")
-    if bare_metal_version >= Version("11.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_90,code=sm_90")
-    if bare_metal_version >= Version("12.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_100,code=sm_100")
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_120,code=sm_120")
-
+    # Architectures are selected via TORCH_CUDA_ARCH_LIST instead of a hard-coded list.
     subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
     ext_modules.append(
         CUDAExtension(
@@ -800,8 +760,7 @@ if has_flag("--fast_multihead_attn", "APEX_FAST_MULTIHEAD_ATTN"):
                     "--use_fast_math",
                 ]
                 + version_dependent_macros
-                + generator_flag
-                + cc_flag,
+                + generator_flag,
             },
             include_dirs=[
                 os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass/include/"),