mirror of
https://github.com/NixOS/nixpkgs.git
synced 2026-06-15 01:33:50 +00:00
181 lines
4.6 KiB
Nix
181 lines
4.6 KiB
Nix
{
|
|
lib,
|
|
amd-aiter,
|
|
buildPythonPackage,
|
|
einops,
|
|
fetchFromGitHub,
|
|
ninja,
|
|
numpy,
|
|
packaging,
|
|
pandas,
|
|
psutil,
|
|
pybind11,
|
|
python,
|
|
rocmPackages,
|
|
runCommand,
|
|
setuptools,
|
|
setuptools-scm,
|
|
torch,
|
|
writableTmpDirAsHomeHook,
|
|
}:
|
|
let
|
|
# Provide a default set of include paths needed by aiter at runtime for JIT modules
|
|
defaultRocmIncl = lib.makeIncludePath (
|
|
(with rocmPackages; [
|
|
clr
|
|
hipblas
|
|
hipblas-common
|
|
hipblaslt
|
|
hipcub
|
|
hipfft
|
|
hipsolver
|
|
hipsparse
|
|
rocblas
|
|
rocprim
|
|
rocsolver
|
|
rocsparse
|
|
rocthrust
|
|
])
|
|
++ [
|
|
pybind11
|
|
]
|
|
);
|
|
in
|
|
buildPythonPackage (finalAttrs: {
|
|
pname = "amd-aiter";
|
|
version = "0.1.11.post1";
|
|
pyproject = true;
|
|
|
|
src = fetchFromGitHub {
|
|
owner = "ROCm";
|
|
repo = "aiter";
|
|
tag = "v${finalAttrs.version}";
|
|
hash = "sha256-9CCSmEw0kIoxERhtkKhBkaAGx42kCssH7IPjTgbg9LA=";
|
|
};
|
|
|
|
postPatch = ''
|
|
rmdir 3rdparty/composable_kernel
|
|
ln -sf ${rocmPackages.composable_kernel.src} 3rdparty/composable_kernel
|
|
|
|
substituteInPlace pyproject.toml \
|
|
--replace-fail '"flydsl==0.0.1.dev95158637"' ""
|
|
|
|
# TODO: upstream fix. get_git_commit_id_short() assumes a git clone;
|
|
# fails in hermetic builds, should fall back gracefully.
|
|
substituteInPlace csrc/cpp_itfs/utils.py \
|
|
--replace-fail \
|
|
'commit_id = get_git_commit_id_short()' \
|
|
'commit_id = "${finalAttrs.version}"'
|
|
|
|
# NIX_AITER_ROCM_INCL is a colon-separated list of include dirs
|
|
# defaults to packages known to be needed by aiter
|
|
substituteInPlace aiter/jit/utils/cpp_extension.py \
|
|
--replace-fail \
|
|
'paths.append(_join_rocm_home("include"))' \
|
|
'paths.append(_join_rocm_home("include")); paths.extend(os.environ.get("NIX_AITER_ROCM_INCL", "${defaultRocmIncl}").split(":"))'
|
|
|
|
# setuptools runs setup.py twice (metadata + wheel). prepare_packaging()
|
|
# copies 3rdparty/ (with nix store read-only files) into aiter_meta/, then
|
|
# the second run can't rmtree or overwrite them.
|
|
substituteInPlace setup.py \
|
|
--replace-fail \
|
|
$'prepare_packaging()\n\n\nclass' \
|
|
$'if not os.path.exists("aiter_meta"): prepare_packaging()\n\n\nclass' \
|
|
--replace-fail \
|
|
'if os.path.exists("aiter_meta") and os.path.isdir("aiter_meta"):' \
|
|
'if False:'
|
|
'';
|
|
|
|
env = {
|
|
BUILD_TARGET = "rocm";
|
|
PREBUILD_KERNELS = "0";
|
|
ROCM_PATH = "${rocmPackages.clr}";
|
|
SETUPTOOLS_SCM_PRETEND_VERSION = finalAttrs.version;
|
|
};
|
|
|
|
build-system = [
|
|
ninja
|
|
packaging
|
|
pandas
|
|
psutil
|
|
pybind11
|
|
setuptools
|
|
setuptools-scm
|
|
];
|
|
|
|
buildInputs = [ rocmPackages.clr ];
|
|
|
|
nativeBuildInputs = [
|
|
rocmPackages.hipcc
|
|
writableTmpDirAsHomeHook
|
|
];
|
|
|
|
dependencies = [
|
|
einops
|
|
ninja
|
|
numpy
|
|
packaging
|
|
pandas
|
|
psutil
|
|
pybind11
|
|
torch
|
|
];
|
|
|
|
# Most tests and imports require a GPU and writable $HOME for JIT cache
|
|
doCheck = false;
|
|
|
|
# Test JIT module builds for CDNA3 iff rocm enabled for torch
|
|
passthru.tests = lib.optionalAttrs torch.rocmSupport (
|
|
let
|
|
mkJitTest =
|
|
name: moduleName:
|
|
runCommand "amd-aiter-jit-${name}"
|
|
{
|
|
nativeBuildInputs = [
|
|
(python.withPackages (_: [ amd-aiter ]))
|
|
rocmPackages.clr
|
|
writableTmpDirAsHomeHook
|
|
];
|
|
env = {
|
|
CXX = "amdclang++";
|
|
GPU_ARCHS = "gfx942";
|
|
PYTORCH_ROCM_ARCH = "gfx942";
|
|
};
|
|
}
|
|
''
|
|
export AITER_JIT_DIR=$(mktemp -d)
|
|
python3 -c "
|
|
from aiter.jit.core import get_args_of_build, build_module
|
|
args = get_args_of_build('${moduleName}')
|
|
build_module(
|
|
'${moduleName}',
|
|
args['srcs'], args['flags_extra_cc'], args['flags_extra_hip'],
|
|
args['blob_gen_cmd'], args['extra_include'], args['extra_ldflags'],
|
|
args['verbose'], args['is_python_module'], args['is_standalone'],
|
|
args['torch_exclude'],
|
|
)
|
|
print('JIT compile of ${moduleName} finished')
|
|
"
|
|
touch $out
|
|
'';
|
|
in
|
|
{
|
|
jit-module-opus-sort = mkJitTest "opus-sort" "module_moe_sorting_opus";
|
|
jit-module-mhc = mkJitTest "mhc" "module_mhc";
|
|
}
|
|
);
|
|
|
|
meta = {
|
|
description = "AI Tensor Engine for ROCm";
|
|
homepage = "https://github.com/ROCm/aiter";
|
|
license = lib.licenses.mit;
|
|
sourceProvenance = with lib.sourceTypes; [
|
|
fromSource
|
|
binaryNativeCode
|
|
];
|
|
maintainers = with lib.maintainers; [ lach ];
|
|
teams = [ lib.teams.rocm ];
|
|
platforms = lib.platforms.linux;
|
|
};
|
|
})
|