Skip to content

Commit

Permalink
feat: adding PyTorch devshell
Browse files Browse the repository at this point in the history
feat(pytorch): adding the template to the flake template list

feat(python-pytorch): adding nix-community cache
  • Loading branch information
pierrot-lc committed Jan 22, 2025
1 parent 959eff5 commit 9bd0770
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 0 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Once your preferred template has been initialized, you can use the provided shel
| [Pulumi] | [`pulumi`](./pulumi/) |
| [Purescript] | [`purescript`](./purescript/) |
| [Python] | [`python`](./python/) |
| [PyTorch] | [`pytorch`](./python-pytorch/)
| [R] | [`r`](./r/) |
| [Ruby] | [`ruby`](./ruby/) |
| [Rust] | [`rust`](./rust/) |
Expand Down Expand Up @@ -269,6 +270,12 @@ A dev template that's fully customizable.
- [Python] 3.11.4
- [pip] 23.0.1

### [`pytorch`](./python-pytorch/)

- [Python] 3.12.8
- [PyTorch] 2.5.1
- [NumPy] 2.2.2

### [`r`](./r/)

- [R] 4.3.1
Expand Down Expand Up @@ -408,6 +415,7 @@ All of the templates have only the root [flake](./flake.nix) as a flake input. T
[nomad-autoscaler]: https://github.com/hashicorp/nomad-autoscaler
[nomad-pack]: https://github.com/hashicorp/nomad-pack
[npm]: https://npmjs.org
[numpy]: https://numpy.org/
[ocaml]: https://ocaml.org
[ocamlformat]: https://github.com/ocaml-ppx/ocamlformat
[odoc]: https://github.com/ocaml/odoc
Expand All @@ -427,6 +435,7 @@ All of the templates have only the root [flake](./flake.nix) as a flake input. T
[purescript-language-server]: https://github.com/nwolverson/purescript-language-server
[purs-tidy]: https://github.com/natefaubion/purescript-tidy
[python]: https://python.org
[pytorch]: https://pytorch.org/
[r]: https://r-project.org
[release]: https://github.com/NixOS/nixpkgs/releases/tag/22.11
[rmarkdown]: https://rmarkdown.rstudio.com
Expand Down
5 changes: 5 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@
description = "Python development environment";
};

pytorch = {
path = ./python-pytorch;
description = "PyTorch (with Python) development environment";
};

r = {
path = ./r;
description = "R development environment";
Expand Down
1 change: 1 addition & 0 deletions python-pytorch/.envrc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
use flake
25 changes: 25 additions & 0 deletions python-pytorch/flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

67 changes: 67 additions & 0 deletions python-pytorch/flake.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
{
description = "A Nix-flake-based PyTorch development environment";

# CUDA binaries are cached by the community.
nixConfig = {
extra-substituters = [
"https://nix-community.cachix.org"
];
extra-trusted-public-keys = [
"nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs="
];
};

inputs.nixpkgs.url = "https://flakehub.com/f/NixOS/nixpkgs/0.1.*.tar.gz";

outputs = {
self,
nixpkgs,
}: let
supportedSystems = ["x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin"];
forEachSupportedSystem = f:
nixpkgs.lib.genAttrs supportedSystems (system:
f {
pkgs = import nixpkgs {
inherit system;
config.allowUnfree = true;
};
});
in {
devShells = forEachSupportedSystem ({pkgs}: let
libs = [
# PyTorch and Numpy depends on the following libraries.
pkgs.cudaPackages.cudatoolkit
pkgs.cudaPackages.cudnn
pkgs.stdenv.cc.cc.lib
pkgs.zlib

# PyTorch also needs to know where your local "lib/libcuda.so" lives.
# If you're not on NixOS, you should provide the right path (likely
# another one).
"/run/opengl-driver"
];
in {
default = pkgs.mkShell {
packages = [
pkgs.python312
pkgs.python312Packages.venvShellHook
];

env = {
CC = "${pkgs.gcc}/bin/gcc"; # For `torch.compile`.
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath libs;
};

venvDir = ".venv";
postVenvCreation = ''
# This is run only when creating the virtual environment.
pip install torch==2.5.1 numpy==2.2.2
'';
postShellHook = ''
# This is run every time you enter the devShell.
python3 -c "import torch; print('CUDA available' if torch.cuda.is_available() else 'CPU only')"
'';
};
});
};
}

0 comments on commit 9bd0770

Please sign in to comment.