Files
nixpkgs/pkgs/development/python-modules/amd-aiter/default.nix

181 lines
4.6 KiB
Nix

{
  lib,

  # builders, fetchers and trivial builders
  buildPythonPackage,
  fetchFromGitHub,
  runCommand,

  # build backend and setup hooks
  setuptools,
  setuptools-scm,
  writableTmpDirAsHomeHook,

  # python build/runtime dependencies
  einops,
  ninja,
  numpy,
  packaging,
  pandas,
  psutil,
  pybind11,
  torch,

  # interpreter and ROCm toolchain/library set
  python,
  rocmPackages,

  # the package set's own amd-aiter (presumably this package); only used by
  # passthru.tests to build a python environment containing it
  amd-aiter,
}:
let
  # ROCm libraries whose headers the JIT-compiled aiter modules may include.
  # Order is kept stable: it determines include-path search order.
  rocmJitHeaderPackages = with rocmPackages; [
    clr
    hipblas
    hipblas-common
    hipblaslt
    hipcub
    hipfft
    hipsolver
    hipsparse
    rocblas
    rocprim
    rocsolver
    rocsparse
    rocthrust
  ];

  # Colon-separated list of include directories (via lib.makeIncludePath) for
  # the ROCm libraries above plus pybind11. Used as the default value of
  # NIX_AITER_ROCM_INCL in the cpp_extension.py patch.
  defaultRocmIncl = lib.makeIncludePath (rocmJitHeaderPackages ++ [ pybind11 ]);
in
buildPythonPackage (finalAttrs: {
  pname = "amd-aiter";
  version = "0.1.11.post1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "ROCm";
    repo = "aiter";
    tag = "v${finalAttrs.version}";
    hash = "sha256-9CCSmEw0kIoxERhtkKhBkaAGx42kCssH7IPjTgbg9LA=";
  };

  postPatch = ''
    # rmdir only succeeds on an empty directory (presumably the unfetched
    # composable_kernel git submodule); point it at the CK source that
    # nixpkgs already provides instead.
    rmdir 3rdparty/composable_kernel
    ln -sf ${rocmPackages.composable_kernel.src} 3rdparty/composable_kernel

    # Drop the pinned flydsl requirement (NOTE(review): presumably because it
    # is not packaged in nixpkgs).
    substituteInPlace pyproject.toml \
      --replace-fail '"flydsl==0.0.1.dev95158637"' ""

    # TODO: upstream fix. get_git_commit_id_short() assumes a git clone;
    # fails in hermetic builds, should fall back gracefully.
    substituteInPlace csrc/cpp_itfs/utils.py \
      --replace-fail \
        'commit_id = get_git_commit_id_short()' \
        'commit_id = "${finalAttrs.version}"'

    # Extra JIT include directories are read from NIX_AITER_ROCM_INCL, a
    # colon-separated list, defaulting to the headers of the ROCm packages
    # known to be needed by aiter. Empty entries (NIX_AITER_ROCM_INCL="" or a
    # stray ':') are skipped so no empty include path reaches the compiler.
    substituteInPlace aiter/jit/utils/cpp_extension.py \
      --replace-fail \
        'paths.append(_join_rocm_home("include"))' \
        'paths.append(_join_rocm_home("include")); paths.extend(p for p in os.environ.get("NIX_AITER_ROCM_INCL", "${defaultRocmIncl}").split(":") if p)'

    # setuptools runs setup.py twice (metadata + wheel). prepare_packaging()
    # copies 3rdparty/ (with nix store read-only files) into aiter_meta/, then
    # the second run can't rmtree or overwrite them. So: only prepare when
    # aiter_meta/ does not exist yet, and never delete an existing one.
    substituteInPlace setup.py \
      --replace-fail \
        $'prepare_packaging()\n\n\nclass' \
        $'if not os.path.exists("aiter_meta"): prepare_packaging()\n\n\nclass' \
      --replace-fail \
        'if os.path.exists("aiter_meta") and os.path.isdir("aiter_meta"):' \
        'if False:'
  '';

  env = {
    # NOTE(review): presumably selects aiter's ROCm/HIP build path in setup.py.
    BUILD_TARGET = "rocm";
    # NOTE(review): presumably disables ahead-of-time kernel builds, leaving
    # modules to be JIT-compiled at runtime (as exercised by passthru.tests).
    PREBUILD_KERNELS = "0";
    # ROCm installation root; presumably what _join_rocm_home resolves against.
    ROCM_PATH = "${rocmPackages.clr}";
    # src is a GitHub tarball without .git, so setuptools-scm cannot derive it.
    SETUPTOOLS_SCM_PRETEND_VERSION = finalAttrs.version;
  };

  build-system = [
    ninja
    packaging
    pandas
    psutil
    pybind11
    setuptools
    setuptools-scm
  ];

  buildInputs = [ rocmPackages.clr ];

  nativeBuildInputs = [
    rocmPackages.hipcc
    writableTmpDirAsHomeHook
  ];

  # ninja and pybind11 are runtime dependencies too: they are used when aiter
  # JIT-compiles its modules.
  dependencies = [
    einops
    ninja
    numpy
    packaging
    pandas
    psutil
    pybind11
    torch
  ];

  # Most tests and imports require a GPU and writable $HOME for JIT cache
  doCheck = false;

  # Compile-only smoke tests: JIT-build selected aiter modules for CDNA3.
  # Defined only when torch is built with ROCm support.
  passthru.tests = lib.optionalAttrs torch.rocmSupport (
    let
      # GPU architecture the test modules are compiled for (CDNA3).
      testArch = "gfx942";

      # mkJitTest name moduleName: a runCommand derivation that JIT-compiles
      # the aiter module `moduleName` via aiter.jit.core.build_module at build
      # time; $out is just an empty marker file.
      mkJitTest =
        name: moduleName:
        runCommand "amd-aiter-jit-${name}"
          {
            nativeBuildInputs = [
              (python.withPackages (_: [ amd-aiter ]))
              rocmPackages.clr
              writableTmpDirAsHomeHook
            ];
            env = {
              CXX = "amdclang++";
              GPU_ARCHS = testArch;
              PYTORCH_ROCM_ARCH = testArch;
            };
          }
          ''
            export AITER_JIT_DIR=$(mktemp -d)
            python3 -c "
            from aiter.jit.core import get_args_of_build, build_module
            args = get_args_of_build('${moduleName}')
            build_module(
              '${moduleName}',
              args['srcs'], args['flags_extra_cc'], args['flags_extra_hip'],
              args['blob_gen_cmd'], args['extra_include'], args['extra_ldflags'],
              args['verbose'], args['is_python_module'], args['is_standalone'],
              args['torch_exclude'],
            )
            print('JIT compile of ${moduleName} finished')
            "
            touch $out
          '';
    in
    {
      jit-module-opus-sort = mkJitTest "opus-sort" "module_moe_sorting_opus";
      jit-module-mhc = mkJitTest "mhc" "module_mhc";
    }
  );

  meta = {
    description = "AI Tensor Engine for ROCm";
    homepage = "https://github.com/ROCm/aiter";
    license = lib.licenses.mit;
    # NOTE(review): binaryNativeCode presumably accounts for prebuilt kernel
    # objects shipped in the upstream source tree — confirm when bumping.
    sourceProvenance = with lib.sourceTypes; [
      fromSource
      binaryNativeCode
    ];
    maintainers = with lib.maintainers; [ lach ];
    teams = [ lib.teams.rocm ];
    platforms = lib.platforms.linux;
  };
})