mirror of
https://github.com/NixOS/nixpkgs.git
synced 2026-06-17 02:34:05 +00:00
165 lines
6.9 KiB
Diff
165 lines
6.9 KiB
Diff
diff --git a/setup.py b/setup.py
|
|
index fa61b72..c31d1e0 100644
|
|
--- a/setup.py
|
|
+++ b/setup.py
|
|
@@ -405,24 +405,8 @@ if has_flag("--cuda_ext", "APEX_CUDA_EXT"):
|
|
)
|
|
|
|
if bare_metal_version >= Version("11.0"):
|
|
-
|
|
- cc_flag = []
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_70,code=sm_70")
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_80,code=sm_80")
|
|
- if bare_metal_version >= Version("11.1"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_86,code=sm_86")
|
|
- if bare_metal_version >= Version("11.8"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_90,code=sm_90")
|
|
- if bare_metal_version >= Version("12.8"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_100,code=sm_100")
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_120,code=sm_120")
|
|
-
|
|
+ # TORCH_CUDA_ARCH_LIST now selects the targets (torch adds the -gencode flags).
|
|
+ # Nothing here gates them on the CUDA version, so list only archs this nvcc supports.
|
|
ext_modules.append(
|
|
CUDAExtension(
|
|
name="fused_weight_gradient_mlp_cuda",
|
|
@@ -441,7 +425,7 @@ if has_flag("--cuda_ext", "APEX_CUDA_EXT"):
|
|
"--expt-relaxed-constexpr",
|
|
"--expt-extended-lambda",
|
|
"--use_fast_math",
|
|
- ] + version_dependent_macros + cc_flag,
|
|
+ ] + version_dependent_macros,
|
|
},
|
|
)
|
|
)
|
|
@@ -530,14 +514,7 @@ if has_flag("--group_norm", "APEX_GROUP_NORM"):
|
|
sys.argv.remove("--group_norm")
|
|
raise_if_cuda_home_none("--group_norm")
|
|
|
|
- # CUDA group norm supports from SM70
|
|
- arch_flags = []
|
|
- # FIXME: this needs to be done more cleanly
|
|
- for arch in [70, 75, 80, 86, 90, 100, 120]:
|
|
- arch_flag = f"-gencode=arch=compute_{arch},code=sm_{arch}"
|
|
- arch_flags.append(arch_flag)
|
|
- arch_flags.append(arch_flag)
|
|
-
|
|
+ # Targets come from TORCH_CUDA_ARCH_LIST; the CUDA group norm kernels need sm_70 or newer.
|
|
ext_modules.append(
|
|
CUDAExtension(
|
|
name="group_norm_cuda",
|
|
@@ -549,7 +526,7 @@ if has_flag("--group_norm", "APEX_GROUP_NORM"):
|
|
"cxx": ["-O3", "-std=c++17"] + version_dependent_macros,
|
|
"nvcc": [
|
|
"-O3", "-std=c++17", "--use_fast_math", "--ftz=false",
|
|
- ] + arch_flags + version_dependent_macros,
|
|
+ ] + version_dependent_macros,
|
|
},
|
|
)
|
|
)
|
|
@@ -651,22 +628,7 @@ if has_flag("--fast_layer_norm", "APEX_FAST_LAYER_NORM"):
|
|
sys.argv.remove("--fast_layer_norm")
|
|
raise_if_cuda_home_none("--fast_layer_norm")
|
|
|
|
- cc_flag = []
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_70,code=sm_70")
|
|
-
|
|
- if bare_metal_version >= Version("11.0"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_80,code=sm_80")
|
|
- if bare_metal_version >= Version("11.8"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_90,code=sm_90")
|
|
- if bare_metal_version >= Version("12.8"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_100,code=sm_100")
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_120,code=sm_120")
|
|
-
|
|
+ # Targets come from TORCH_CUDA_ARCH_LIST (upstream built sm_70 and newer for this extension).
|
|
ext_modules.append(
|
|
CUDAExtension(
|
|
name="fast_layer_norm",
|
|
@@ -689,7 +651,7 @@ if has_flag("--fast_layer_norm", "APEX_FAST_LAYER_NORM"):
|
|
"--expt-relaxed-constexpr",
|
|
"--expt-extended-lambda",
|
|
"--use_fast_math",
|
|
- ] + version_dependent_macros + generator_flag + cc_flag,
|
|
+ ] + version_dependent_macros + generator_flag,
|
|
},
|
|
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")],
|
|
)
|
|
@@ -703,17 +665,18 @@ if has_flag("--fmha", "APEX_FMHA"):
|
|
if bare_metal_version < Version("11.0"):
|
|
raise RuntimeError("--fmha only supported on sm_80 and sm_90 GPUs")
|
|
|
|
+ # fmha kernels use sm_80 MMA instructions, so only TORCH_CUDA_ARCH_LIST entries >= 8.0
|
|
+ # are built; passing explicit -gencode flags also stops torch from adding e.g. sm_75 itself.
|
|
cc_flag = []
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_80,code=sm_80")
|
|
- if bare_metal_version >= Version("11.8"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_90,code=sm_90")
|
|
- if bare_metal_version >= Version("12.8"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_100,code=sm_100")
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_120,code=sm_120")
|
|
+ for arch in os.environ.get("TORCH_CUDA_ARCH_LIST", "").replace(" ", ";").split(";"):
|
|
+ capability = arch.removesuffix("+PTX")
|
|
+ if capability and Version(capability) >= Version("8.0"):
|
|
+ num = capability.replace(".", "")
|
|
+ cc_flag += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
|
|
+ if arch.endswith("+PTX"):
|
|
+ cc_flag += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
|
|
+ if not cc_flag:
|
|
+ raise RuntimeError("--fmha requires an architecture >= 8.0 in TORCH_CUDA_ARCH_LIST")
|
|
|
|
ext_modules.append(
|
|
CUDAExtension(
|
|
@@ -755,25 +718,7 @@ if has_flag("--fast_multihead_attn", "APEX_FAST_MULTIHEAD_ATTN"):
|
|
sys.argv.remove("--fast_multihead_attn")
|
|
raise_if_cuda_home_none("--fast_multihead_attn")
|
|
|
|
- cc_flag = []
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_70,code=sm_70")
|
|
-
|
|
- if bare_metal_version >= Version("11.0"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_80,code=sm_80")
|
|
- if bare_metal_version >= Version("11.1"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_86,code=sm_86")
|
|
- if bare_metal_version >= Version("11.8"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_90,code=sm_90")
|
|
- if bare_metal_version >= Version("12.8"):
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_100,code=sm_100")
|
|
- cc_flag.append("-gencode")
|
|
- cc_flag.append("arch=compute_120,code=sm_120")
|
|
-
|
|
+ # Targets come from TORCH_CUDA_ARCH_LIST (upstream built sm_70 and newer for this extension).
|
|
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
|
|
ext_modules.append(
|
|
CUDAExtension(
|
|
@@ -800,8 +745,7 @@ if has_flag("--fast_multihead_attn", "APEX_FAST_MULTIHEAD_ATTN"):
|
|
"--use_fast_math",
|
|
]
|
|
+ version_dependent_macros
|
|
- + generator_flag
|
|
- + cc_flag,
|
|
+ + generator_flag,
|
|
},
|
|
include_dirs=[
|
|
os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass/include/"),
|