diff --git a/flake.lock b/flake.lock index b293509..3bedff6 100644 --- a/flake.lock +++ b/flake.lock @@ -38,6 +38,24 @@ "type": "github" } }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1705309234, + "narHash": "sha256-uNRRNRKmJyCRC/8y1RqBkqWBLM034y4qN7EprSdmgyA=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "1ef2e671c3b0c19053962c07dbda38332dcebf26", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, "nixpkgs": { "locked": { "lastModified": 1709780214, @@ -118,7 +136,29 @@ "nixpkgs": [ "dream2nix", "nixpkgs" + ], + "rust-overlay": "rust-overlay" + } + }, + "rust-overlay": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": [ + "nixpkgs" ] + }, + "locked": { + "lastModified": 1713492869, + "narHash": "sha256-Zv+ZQq3X+EH6oogkXaJ8dGN8t1v26kPZgC5bki04GnM=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "1e9264d1214d3db00c795b41f75d55b5e153758e", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" } }, "slimlock": { @@ -142,6 +182,21 @@ "repo": "slimlock", "type": "github" } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } } }, "root": "root", diff --git a/flake.nix b/flake.nix index da5c263..f023610 100644 --- a/flake.nix +++ b/flake.nix @@ -4,9 +4,11 @@ dream2nix.url = "github:yorickvp/dream2nix"; nixpkgs.follows = "dream2nix/nixpkgs"; flake-parts.url = "github:hercules-ci/flake-parts"; + rust-overlay.url = "github:oxalica/rust-overlay"; + rust-overlay.inputs.nixpkgs.follows = "nixpkgs"; }; - outputs = { self, dream2nix, nixpkgs, flake-parts }@inputs: + outputs = { self, dream2nix, nixpkgs, flake-parts, rust-overlay }@inputs: flake-parts.lib.mkFlake { inherit inputs; } { systems = [ "x86_64-linux" ]; # debug = true; @@ -51,11 +53,18 @@ config.allowUnfree = true; inherit system; overlays = [ + (import rust-overlay) (final: prev: { pget = prev.callPackage ./pkgs/pget.nix { }; cognix-weights = prev.callPackage ./pkgs/cognix-weights {}; cognix-cli = prev.callPackage ./pkgs/cognix-cli {}; cog = prev.callPackage ./pkgs/cog.nix {}; + uv = prev.callPackage ./pkgs/uv.nix { + rustPlatform = prev.makeRustPlatform { + cargo = prev.rust-bin.stable.latest.minimal; + rustc = prev.rust-bin.stable.latest.minimal; + }; + }; lib = prev.lib.extend (finall: prevl: { trivial = prevl.trivial // { revisionWithDefault = default: nixpkgs.rev or default; @@ -82,7 +91,7 @@ path = if pkgs.lib.isDerivation path then path else "/dev/null"; }) config.legacyPackages); legacyPackages = { - inherit (pkgs) pget cognix-weights cognix-cli cog; + inherit (pkgs) pget cognix-weights cognix-cli cog uv; callCognix = import ./default.nix { inherit pkgs dream2nix; }; diff --git a/modules/cog.nix b/modules/cog.nix index 4969178..f51cd73 100644 --- a/modules/cog.nix +++ b/modules/cog.nix @@ -71,6 +71,7 @@ let pyEnvWithPip = config.python-env.public.pyEnv.override { postBuild = "$out/bin/python -m ensurepip"; }; + patchTorch = builtins.map (y: if builtins.match "torch==[0-9\.]+$" y == [] then "${y}.*" else y); in { imports = [ ./cog-interface.nix @@ -168,6 +169,7 @@ in { dream2nix.modules.dream2nix.pip pipOverridesModule (proxyLockModule config.lock.content) + ./uv-solver.nix ]; paths = { inherit (config.paths) projectRoot package; }; name = "cog-docker-env"; diff --git a/modules/uv-solver.nix b/modules/uv-solver.nix new file mode 100644 index 0000000..8cc903f --- /dev/null +++ b/modules/uv-solver.nix @@ -0,0 +1,48 @@ +{ lib, config, packageSets, ... }: let + cfg = config.pip.uv; + pkgs = packageSets.nixpkgs; + # bug: torch==2.1.0 does not resolve to torch==2.1.0+cpu + patchTorch = builtins.map (y: if builtins.match "torch==[0-9\.]+$" y == [] then "${y}.*" else y); + constraintsArgs = lib.optionals (cfg.constraints != []) [ + "--constraint" + (builtins.toFile "constraints.txt" (lib.concatMapStrings (x: "${x}\n") cfg.constraints)) + ]; + overridesArgs = lib.optionals (cfg.overrides != []) [ + "--override" + (builtins.toFile "overrides.txt" (lib.concatMapStrings (x: "${x}\n") cfg.overrides)) + ]; + extraArgs = constraintsArgs ++ overridesArgs ++ cfg.extraArgs; +in { + # todo: support env + options.pip.uv = with lib; { + enable = mkEnableOption "use uv solver"; + overrides = mkOption { + type = types.listOf types.str; + default = []; + }; + constraints = mkOption { + type = types.listOf types.str; + default = []; + }; + extraArgs = mkOption { + type = types.listOf types.str; + default = []; + }; + }; + config = lib.mkIf cfg.enable { + deps.fetchPipMetadataScript = pkgs.writeShellScript "fetch-pip-metadata-uv" '' + export UV_DUMP_DREAM2NIX="$out" + ${pkgs.uv}/bin/uv pip install \ + --dry-run \ + --reinstall \ + --index-strategy unsafe-highest \ + --break-system-packages \ + ${lib.optionalString (config.pip.pypiSnapshotDate != null) "--exclude-newer ${config.pip.pypiSnapshotDate}"} \ + --python ${config.deps.python}/bin/python \ + ${lib.escapeShellArgs extraArgs} \ + ${lib.escapeShellArgs (patchTorch config.pip.requirementsList)} + ''; + lock.invalidationData.solver = "uv"; + lock.invalidationData.extraArgs = extraArgs; + }; +} diff --git a/pkgs/uv.nix b/pkgs/uv.nix new file mode 100644 index 0000000..76af635 --- /dev/null +++ b/pkgs/uv.nix @@ -0,0 +1,24 @@ +{ rustPlatform, fetchFromGitHub, lib, cmake, openssl, pkg-config, perl }: +rustPlatform.buildRustPackage { + pname = "uv"; + version = "0.1.34.post"; + nativeBuildInputs = [ cmake pkg-config perl ]; + buildInputs = [ openssl ]; + src = fetchFromGitHub { + owner = "yorickvP"; + repo = "uv"; + rev = "6ef1d706379a3403d917b75d41697e98f9a0619d"; + hash = "sha256-U4mowALk3XKSzRkhbZT333v3mVlLBpeatl8mwfeg6uw="; + }; + cargoLock = { + lockFile = builtins.fetchurl { + url = "https://raw.githubusercontent.com/astral-sh/uv/6ef1d706379a3403d917b75d41697e98f9a0619d/Cargo.lock"; + sha256 = "12v38b50gyi5g7dz269vl4briw5jffw16mlsmiswh0f6q8nb64q4"; + }; + outputHashes = { + "async_zip-0.0.17" = "sha256-Q5fMDJrQtob54CTII3+SXHeozy5S5s3iLOzntevdGOs="; + "pubgrub-0.2.1" = "sha256-sqC7R2mtqymYFULDW0wSbM/MKCZc8rP7Yy/gaQpjYEI="; + }; + }; + doCheck = false; +}