From 52d409afd9422a11b12b6bc30e05b7e39e74b774 Mon Sep 17 00:00:00 2001 From: Gabriella Gonzalez Date: Fri, 20 Dec 2024 21:35:05 -0800 Subject: [PATCH] pythonPackages.mlx: upgrade and fix build --- pkgs/by-name/op/openmpi/package.nix | 5 +-- pkgs/by-name/pr/prrte/package.nix | 2 +- .../python-modules/mlx/default.nix | 31 ++++++++++++------- .../mlx/disable-accelerate.patch | 13 -------- 4 files changed, 23 insertions(+), 28 deletions(-) delete mode 100644 pkgs/development/python-modules/mlx/disable-accelerate.patch diff --git a/pkgs/by-name/op/openmpi/package.nix b/pkgs/by-name/op/openmpi/package.nix index 7fcca51c739e35..f8d2c824145429 100644 --- a/pkgs/by-name/op/openmpi/package.nix +++ b/pkgs/by-name/op/openmpi/package.nix @@ -26,6 +26,7 @@ cudaPackages, # Enable the Sun Grid Engine bindings enableSGE ? false, + enablePRRTE ? true, # Pass PATH/LD_LIBRARY_PATH to point to current mpirun by default enablePrefix ? false, # Enable libfabric support (necessary for Omnipath networks) on x86_64 linux @@ -85,6 +86,7 @@ stdenv.mkDerivation (finalAttrs: { zlib libevent hwloc + prrte ] ++ lib.optionals stdenv.hostPlatform.isLinux [ libnl @@ -92,7 +94,6 @@ stdenv.mkDerivation (finalAttrs: { pmix ucx ucc - prrte ] ++ lib.optionals cudaSupport [ cudaPackages.cuda_cudart ] ++ lib.optionals (stdenv.hostPlatform.isLinux || stdenv.hostPlatform.isFreeBSD) [ rdma-core ] @@ -119,7 +120,7 @@ stdenv.mkDerivation (finalAttrs: { "--with-pmix=${lib.getDev pmix}" "--with-pmix-libdir=${lib.getLib pmix}/lib" # Puts a "default OMPI_PRTERUN" value to mpirun / mpiexec executables - (lib.withFeatureAs stdenv.hostPlatform.isLinux "prrte" (lib.getBin prrte)) + (lib.withFeatureAs enablePRRTE "prrte" (lib.getBin prrte)) (lib.withFeature enableSGE "sge") (lib.enableFeature enablePrefix "mpirun-prefix-by-default") # TODO: add UCX support, which is recommended to use with cuda for the most robust OpenMPI build diff --git a/pkgs/by-name/pr/prrte/package.nix b/pkgs/by-name/pr/prrte/package.nix index 44f8fba390aefb..269c16f091cf85 100644 --- a/pkgs/by-name/pr/prrte/package.nix +++ b/pkgs/by-name/pr/prrte/package.nix @@ -75,6 +75,6 @@ stdenv.mkDerivation rec { homepage = "https://docs.prrte.org/"; license = lib.licenses.bsd3; maintainers = with lib.maintainers; [ markuskowa ]; - platforms = lib.platforms.linux; + platforms = lib.platforms.unix; }; } diff --git a/pkgs/development/python-modules/mlx/default.nix b/pkgs/development/python-modules/mlx/default.nix index ac90ebb898082f..904894ebb7a384 100644 --- a/pkgs/development/python-modules/mlx/default.nix +++ b/pkgs/development/python-modules/mlx/default.nix @@ -9,6 +9,9 @@ blas, lapack, setuptools, + nanobind, + openmpi, + apple-sdk_14, }: let @@ -25,29 +28,31 @@ let rev = "v3.11.3"; hash = "sha256-7F0Jon+1oWL7uqet5i1IgHX0fUw/+z0QwEcA3zs5xHg="; }; + fmt = fetchFromGitHub { + owner = "fmtlib"; + repo = "fmt"; + rev = "10.2.1"; + hash = "sha256-pEltGLAHLZ3xypD/Ur4dWPWJ9BGVXwqQyKcDWVmC3co="; + }; in buildPythonPackage rec { pname = "mlx"; - version = "0.18.0"; + version = "0.21.1"; src = fetchFromGitHub { owner = "ml-explore"; repo = "mlx"; rev = "refs/tags/v${version}"; - hash = "sha256-eFKjCrutqrmhZKzRrLq5nYl0ieqLvoXpbnTxA1NEhWo="; + hash = "sha256-wxv9bA9e8VyFv/FMh63sUTTNgkXHGQJNQhLuVynczZA="; }; pyproject = true; - patches = [ - # With Darwin SDK 11 we cannot include vecLib/cblas_new.h, this needs to wait for PR #229210 - # In the meantime, pretend Accelerate is not available and use blas/lapack instead. - ./disable-accelerate.patch - ]; - postPatch = '' substituteInPlace CMakeLists.txt \ - --replace "/usr/bin/xcrun" "${xcbuild}/bin/xcrun" \ + --replace-fail "/usr/bin/xcrun" "${xcbuild}/bin/xcrun" + substituteInPlace pyproject.toml \ + --replace-fail "nanobind==2.2.0" "nanobind" ''; dontUseCmakeConfigure = true; @@ -59,6 +64,7 @@ buildPythonPackage rec { (lib.cmakeBool "MLX_BUILD_METAL" false) (lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}") (lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json}") + (lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_FMT" "${fmt}") ]; }; @@ -67,14 +73,15 @@ buildPythonPackage rec { pybind11 xcbuild zsh - gguf-tools - nlohmann_json setuptools + nanobind + openmpi ]; buildInputs = [ blas lapack + apple-sdk_14 ]; meta = with lib; { @@ -83,6 +90,6 @@ buildPythonPackage rec { changelog = "https://github.com/ml-explore/mlx/releases/tag/v${version}"; license = licenses.mit; platforms = [ "aarch64-darwin" ]; - maintainers = with maintainers; [ viraptor ]; + maintainers = with maintainers; [ viraptor Gabriella439 ]; }; } diff --git a/pkgs/development/python-modules/mlx/disable-accelerate.patch b/pkgs/development/python-modules/mlx/disable-accelerate.patch deleted file mode 100644 index 693e7f41104d03..00000000000000 --- a/pkgs/development/python-modules/mlx/disable-accelerate.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index 2d6bef9..d099673 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -104,7 +104,7 @@ elseif (MLX_BUILD_METAL) - ${QUARTZ_LIB}) - endif() - --find_library(ACCELERATE_LIBRARY Accelerate) -+#find_library(ACCELERATE_LIBRARY Accelerate) - if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY) - message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") - set(MLX_BUILD_ACCELERATE ON)