diff --git a/README.md b/README.md
index 9db9ed9..f8d72f1 100644
--- a/README.md
+++ b/README.md
@@ -59,6 +59,7 @@ Once your preferred template has been initialized, you can use the provided shel
 | [Pulumi] | [`pulumi`](./pulumi/) |
 | [Purescript] | [`purescript`](./purescript/) |
 | [Python] | [`python`](./python/) |
+| [PyTorch] | [`pytorch`](./python-pytorch/) |
 | [R] | [`r`](./r/) |
 | [Ruby] | [`ruby`](./ruby/) |
 | [Rust] | [`rust`](./rust/) |
@@ -269,6 +270,12 @@ A dev template that's fully customizable.
 - [Python] 3.11.4
 - [pip] 23.0.1
 
+### [`pytorch`](./python-pytorch/)
+
+- [Python] 3.12.8
+- [PyTorch] 2.5.1
+- [NumPy] 2.2.2
+
 ### [`r`](./r/)
 
 - [R] 4.3.1
@@ -408,6 +415,7 @@ All of the templates have only the root [flake](./flake.nix) as a flake input. T
 [nomad-autoscaler]: https://github.com/hashicorp/nomad-autoscaler
 [nomad-pack]: https://github.com/hashicorp/nomad-pack
 [npm]: https://npmjs.org
+[numpy]: https://numpy.org/
 [ocaml]: https://ocaml.org
 [ocamlformat]: https://github.com/ocaml-ppx/ocamlformat
 [odoc]: https://github.com/ocaml/odoc
@@ -427,6 +435,7 @@ All of the templates have only the root [flake](./flake.nix) as a flake input. T
 [purescript-language-server]: https://github.com/nwolverson/purescript-language-server
 [purs-tidy]: https://github.com/natefaubion/purescript-tidy
 [python]: https://python.org
+[pytorch]: https://pytorch.org/
 [r]: https://r-project.org
 [release]: https://github.com/NixOS/nixpkgs/releases/tag/22.11
 [rmarkdown]: https://rmarkdown.rstudio.com
diff --git a/flake.nix b/flake.nix
index e2bc394..c37c4af 100644
--- a/flake.nix
+++ b/flake.nix
@@ -258,6 +258,11 @@
       description = "Python development environment";
     };
 
+    pytorch = {
+      path = ./python-pytorch;
+      description = "PyTorch (with Python) development environment";
+    };
+
     r = {
       path = ./r;
       description = "R development environment";
diff --git a/python-pytorch/.envrc b/python-pytorch/.envrc
new file mode 100644
index 0000000..3550a30
--- /dev/null
+++ b/python-pytorch/.envrc
@@ -0,0 +1 @@
+use flake
diff --git a/python-pytorch/flake.lock b/python-pytorch/flake.lock
new file mode 100644
index 0000000..54f292f
--- /dev/null
+++ b/python-pytorch/flake.lock
@@ -0,0 +1,25 @@
+{
+  "nodes": {
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1737062831,
+        "narHash": "sha256-Tbk1MZbtV2s5aG+iM99U8FqwxU/YNArMcWAv6clcsBc=",
+        "rev": "5df43628fdf08d642be8ba5b3625a6c70731c19c",
+        "revCount": 738982,
+        "type": "tarball",
+        "url": "https://api.flakehub.com/f/pinned/NixOS/nixpkgs/0.1.738982%2Brev-5df43628fdf08d642be8ba5b3625a6c70731c19c/01947627-561b-7a9f-a379-f9ac4c680cb0/source.tar.gz"
+      },
+      "original": {
+        "type": "tarball",
+        "url": "https://flakehub.com/f/NixOS/nixpkgs/0.1.%2A.tar.gz"
+      }
+    },
+    "root": {
+      "inputs": {
+        "nixpkgs": "nixpkgs"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
diff --git a/python-pytorch/flake.nix b/python-pytorch/flake.nix
new file mode 100644
index 0000000..1d9bf3b
--- /dev/null
+++ b/python-pytorch/flake.nix
@@ -0,0 +1,67 @@
+{
+  description = "A Nix-flake-based PyTorch development environment";
+
+  # CUDA binaries are cached by the community.
+  nixConfig = {
+    extra-substituters = [
+      "https://nix-community.cachix.org"
+    ];
+    extra-trusted-public-keys = [
+      "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs="
+    ];
+  };
+
+  inputs.nixpkgs.url = "https://flakehub.com/f/NixOS/nixpkgs/0.1.*.tar.gz";
+
+  outputs = {
+    self,
+    nixpkgs,
+  }: let
+    supportedSystems = ["x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin"];
+    forEachSupportedSystem = f:
+      nixpkgs.lib.genAttrs supportedSystems (system:
+        f {
+          pkgs = import nixpkgs {
+            inherit system;
+            config.allowUnfree = true;
+          };
+        });
+  in {
+    devShells = forEachSupportedSystem ({pkgs}: let
+      libs = [
+        # PyTorch and NumPy depend on the following libraries.
+        pkgs.cudaPackages.cudatoolkit
+        pkgs.cudaPackages.cudnn
+        pkgs.stdenv.cc.cc.lib
+        pkgs.zlib
+
+        # PyTorch also needs to know where your local "lib/libcuda.so" lives.
+        # On non-NixOS systems, replace this path with the directory where
+        # your GPU driver installs libcuda.so.
+        "/run/opengl-driver"
+      ];
+    in {
+      default = pkgs.mkShell {
+        packages = [
+          pkgs.python312
+          pkgs.python312Packages.venvShellHook
+        ];
+
+        env = {
+          CC = "${pkgs.gcc}/bin/gcc"; # For `torch.compile`.
+          LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath libs;
+        };
+
+        venvDir = ".venv";
+        postVenvCreation = ''
+          # This is run only when creating the virtual environment.
+          pip install torch==2.5.1 numpy==2.2.2
+        '';
+        postShellHook = ''
+          # This is run every time you enter the devShell.
+          python3 -c "import torch; print('CUDA available' if torch.cuda.is_available() else 'CPU only')"
+        '';
+      };
+    });
+  };
+}
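Once the template is initialized, `nix develop` (or direnv via the provided `.envrc`) creates the `.venv` and `postVenvCreation` installs the pinned `torch` and `numpy` wheels. Below is a minimal smoke test you can run inside that shell; the filename and tensor shapes are illustrative only and are not part of the template:

```python
# smoke_test.py — illustrative check, run inside the dev shell (`nix develop`)
import numpy as np
import torch

# Round-trip a NumPy array through a torch tensor and do a small matmul.
x = torch.from_numpy(np.random.rand(2, 3).astype(np.float32))
print(x @ x.T)

# Mirrors the postShellHook check: reports True only when libcuda.so is
# reachable via LD_LIBRARY_PATH (e.g. /run/opengl-driver on NixOS).
print("CUDA available:", torch.cuda.is_available())
```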