mirror of
https://github.com/NixOS/nixpkgs.git
synced 2026-06-09 23:03:47 +00:00
Diff: https://github.com/Dao-AILab/flash-attention/compare/v2.8.2...v2.8.3 Changelog: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.8.3
154 lines
3.8 KiB
Nix
154 lines
3.8 KiB
Nix
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  ninja,
  setuptools,
  torch,

  # dependencies
  cuda-bindings,
  einops,
  nvidia-cutlass-dsl,

  # tests
  apex,
  pytestCheckHook,
  sentencepiece,
  timm,
  transformers,
  writableTmpDirAsHomeHook,

  # passthru
  # NOTE(review): not referenced anywhere in this expression; `passthru.gpuCheck`
  # is derived from the local `package` binding instead. Kept so the argument set
  # stays unchanged for existing callers/overrides.
  flash-attn,
}:

let
  # CUDA settings are read from torch rather than taken as separate arguments.
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
  inherit (cudaPackages.flags) dropDots;

  # Capabilities joined with ";" in dotted form, e.g. 8.0;9.0;12.0
  dottedArchList = lib.concatStringsSep ";" cudaCapabilities;

  # The same capabilities with the dots dropped, e.g. 80;90;120
  plainArchList = lib.concatMapStringsSep ";" dropDots cudaCapabilities;

  # Environment variables exported to the build only when torch has CUDA support.
  cudaBuildEnv = {
    FORCE_BUILD = "TRUE";
    FLASH_ATTENTION_SKIP_CUDA_BUILD = "FALSE";
    TORCH_CUDA_ARCH_LIST = dottedArchList;
    FLASH_ATTN_CUDA_ARCHS = plainArchList;
  };

  # Build with the stdenv that torch itself uses.
  package = buildPythonPackage.override { stdenv = torch.stdenv; } (finalAttrs: {
    pname = "flash-attention";
    version = "2.8.3";
    pyproject = true;
    __structuredAttrs = true;

    src = fetchFromGitHub {
      owner = "Dao-AILab";
      repo = "flash-attention";
      tag = "v${finalAttrs.version}";
      fetchSubmodules = true;
      hash = "sha256-6I1O4E5K5IdbpzrXFHK06QVcOE8zuVkFE338ffk6N8M=";
    };

    patches = [
      # Workaround: nvidia-cutlass-dsl no longer provides cutlass.utils.ampere_helpers.
      ./drop-cutlass-ampere-utils.patch
    ];

    # MAX_JOBS follows the number of cores Nix grants the build; NVCC_THREADS is fixed at 2.
    preConfigure = ''
      export MAX_JOBS="$NIX_BUILD_CORES"
      export NVCC_THREADS=2
    '';

    env = lib.optionalAttrs cudaSupport cudaBuildEnv;

    build-system = [
      ninja
      setuptools
      torch
    ];

    nativeBuildInputs = [
      cudaPackages.cuda_nvcc
    ];

    # Trailing comments name the headers each package is pulled in for.
    buildInputs = [
      cudaPackages.cuda_cccl # <thrust/*>
      cudaPackages.libcublas # cublas_v2.h
      cudaPackages.libcurand # curand.h
      cudaPackages.libcusolver # cusolverDn.h
      cudaPackages.libcusparse # cusparse.h
      cudaPackages.cuda_cudart # cuda_runtime.h cuda_runtime_api.h
    ];

    dependencies = [
      # Needed by flash_attn/cute/interface.py
      cuda-bindings

      einops
      nvidia-cutlass-dsl
      torch
    ];

    pythonImportsCheck = [ "flash_attn" ];

    nativeCheckInputs = [
      apex
      pytestCheckHook
      sentencepiece
      timm
      transformers
      writableTmpDirAsHomeHook
    ];

    enabledTestPaths = [ "tests/" ];

    disabledTestPaths = [
      # These need `fused_dense_lib` and `dropout_layer_norm`, which are standalone Python
      # packages under csrc/ with their own setup.py. The top-level setup.py does not build
      # them, and they are not published on PyPI either.
      "tests/ops/test_dropout_layer_norm.py"
      "tests/ops/test_fused_dense.py"
      "tests/ops/test_fused_dense_parallel.py"

      # Relies on `RotaryEmbedding` from `transformers.models.gpt_neox.modeling_gpt_neox`,
      # which upstream transformers has since removed.
      "tests/layers/test_rotary.py"

      # Covers the ROCm composable_kernel backend; only the CUDA backend is built here.
      "tests/test_flash_attn_ck.py"

      # Shares its module name with tests/test_flash_attn.py (both import as
      # `test_flash_attn`). The CUTE-DSL variant is the one disabled, so the tests that
      # exercise the C++ extension we actually build still run.
      "tests/cute/test_flash_attn.py"
    ];

    # NOTE(review): presumably removes the in-tree `flash_attn` directory so the tests
    # import the installed package instead of the source checkout — confirm.
    preCheck = ''
      rm -rf flash_attn
    '';

    # The test suite needs a physical GPU, so it is off for the default build.
    doCheck = false;

    # Variant of this package that runs the tests on a builder advertising the "cuda" feature.
    passthru = {
      gpuCheck = package.overridePythonAttrs {
        doCheck = true;
        requiredSystemFeatures = [ "cuda" ];
      };
    };

    meta = {
      description = "Official implementation of FlashAttention and FlashAttention-2";
      homepage = "https://github.com/Dao-AILab/flash-attention/";
      changelog = "https://github.com/Dao-AILab/flash-attention/releases/tag/${finalAttrs.src.tag}";
      license = lib.licenses.bsd3;
      maintainers = [ lib.maintainers.jherland ];
      # Upstream requires either CUDA or ROCm. Couldn't get it to work with ROCm for now.
      broken = !cudaSupport;
    };
  });
in
package