nix : add cuda, use a symlinked toolkit for cmake (#3202)

This commit is contained in:
Erik Scholz 2023-09-25 13:48:30 +02:00 committed by GitHub
parent c091cdfb24
commit a98b1633d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -35,6 +35,20 @@
); );
pkgs = import nixpkgs { inherit system; }; pkgs = import nixpkgs { inherit system; };
nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ]; nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
cudatoolkit_joined = with pkgs; symlinkJoin {
# HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit
# see https://github.com/NixOS/nixpkgs/issues/224291
# copied from jaxlib
name = "${cudaPackages.cudatoolkit.name}-merged";
paths = [
cudaPackages.cudatoolkit.lib
cudaPackages.cudatoolkit.out
] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [
# for some reason some of the required libs are in the targets/x86_64-linux
# directory; not sure why but this works around it
"${cudaPackages.cudatoolkit}/targets/${system}"
];
};
llama-python = llama-python =
pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]); pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
postPatch = '' postPatch = ''
@ -70,6 +84,13 @@
"-DLLAMA_CLBLAST=ON" "-DLLAMA_CLBLAST=ON"
]; ];
}; };
packages.cuda = pkgs.stdenv.mkDerivation {
inherit name src meta postPatch nativeBuildInputs postInstall;
buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ];
cmakeFlags = cmakeFlags ++ [
"-DLLAMA_CUBLAS=ON"
];
};
packages.rocm = pkgs.stdenv.mkDerivation { packages.rocm = pkgs.stdenv.mkDerivation {
inherit name src meta postPatch nativeBuildInputs postInstall; inherit name src meta postPatch nativeBuildInputs postInstall;
buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ]; buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];