Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

{ai}[gfbf/2024a] jax v0.4.34, ml_dtypes v0.5.0 w/ CUDA 12.6.0 WIP #21924

Draft
wants to merge 27 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9164c61
adding easyconfigs: jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb and patches:…
ThomasHoffmann77 Nov 28, 2024
179493f
add ml_dtypes v0.5.0
ThomasHoffmann77 Nov 28, 2024
d9668f2
fix style
ThomasHoffmann77 Nov 28, 2024
794c15d
checksums; add missing patch
ThomasHoffmann77 Nov 28, 2024
5724407
borrow pybind11/2.13.6 from PR #21864
ThomasHoffmann77 Nov 28, 2024
b910ecb
temporarily add pytest-xdist from #21879
ThomasHoffmann77 Nov 29, 2024
aa1ab42
Update jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Nov 29, 2024
44ebc27
Update easyconfigs.py
ThomasHoffmann77 Nov 29, 2024
31eeb75
Update easyconfigs.py
ThomasHoffmann77 Nov 29, 2024
39e03a3
Update jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Nov 30, 2024
a9ad131
Merge branch 'easybuilders:develop' into 20241128144208_new_pr_jax0435
ThomasHoffmann77 Dec 1, 2024
e6bb0e0
Update jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Dec 2, 2024
086c5ec
temporarily add SciPy-bundle with pybind11 builddependency
ThomasHoffmann77 Dec 2, 2024
54916df
Update jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Dec 2, 2024
fc5b969
Update jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Dec 2, 2024
849f9fb
test v0.4.34 with pybind11/2.12.0 builddep
ThomasHoffmann77 Dec 20, 2024
b0afceb
Delete easybuild/easyconfigs/s/SciPy-bundle/SciPy-bundle-2024.05-gfbf…
ThomasHoffmann77 Dec 20, 2024
2e75e9f
Delete easybuild/easyconfigs/p/pybind11/pybind11-2.13.6-GCC-13.3.0.eb
ThomasHoffmann77 Dec 20, 2024
b988776
Delete easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Dec 20, 2024
6db6d36
Update jax-0.4.34-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Dec 20, 2024
f3fc230
revert SciPy-bundle-2024.05-gfbf-2024a.eb
ThomasHoffmann77 Dec 20, 2024
45153cf
Update jax-0.4.34-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Jan 16, 2025
259590c
Update jax-0.4.34-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Jan 16, 2025
ef36815
also build jax_cuda12_plugin and jax_cuda12_pjrt
ThomasHoffmann77 Jan 17, 2025
26adb0e
set XLA_FLAGS xla_gpu_cuda_data_dir to $CUDA_HOME
ThomasHoffmann77 Jan 22, 2025
3695aff
fix style
ThomasHoffmann77 Jan 22, 2025
8fc8f35
add EC for Bazel v6.5.0
ThomasHoffmann77 Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions easybuild/easyconfigs/b/Bazel/Bazel-6.5.0-GCCcore-13.3.0.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# EasyBuild easyconfig for Bazel 6.5.0 on top of the GCCcore 13.3.0 toolchain.
# No 'easyblock' is set here.
# NOTE(review): presumably the software-specific Bazel easyblock is picked up by name - confirm.
name = 'Bazel'
version = '6.5.0'

homepage = 'https://bazel.io/'
description = """Bazel is a build tool that builds code quickly and reliably.
It is used to build the majority of Google's software."""

toolchain = {'name': 'GCCcore', 'version': '13.3.0'}

# "-dist" source archive taken from the GitHub release downloads of this version
source_urls = ['https://github.com/bazelbuild/%(namelower)s/releases/download/%(version)s']
sources = ['%(namelower)s-%(version)s-dist.zip']
# patch replaces assertEquals by assertEqual in the bundled examples/py_native tests
# (see the header of the patch file)
patches = ['Bazel-6.5.0_py3.12_pytest_assertEqual.patch']
# one {filename: SHA256} entry per source/patch, in the same order as sources + patches
checksums = [
    {'bazel-6.5.0-dist.zip': 'fc89da919415289f29e4ff18a5e01270ece9a6fe83cb60967218bac4a3bb3ed2'},
    {'Bazel-6.5.0_py3.12_pytest_assertEqual.patch': '2670dd5c393970ba20db2c98cf0208df7190ff339ccb66fee9a6d48aaaf3ede6'},
]

builddependencies = [
    ('binutils', '2.42'),
    ('Python', '3.12.3'),
    ('Zip', '3.0'),
]

# Java is taken from the SYSTEM toolchain (empty versionsuffix)
dependencies = [
    ('Java', '11.0.20', '', SYSTEM),
]

# enable the test step; testopts lists the example targets to run
# NOTE(review): how runtest/testopts are turned into a bazel command is decided by the easyblock - confirm
runtest = True
testopts = "-- //examples/cpp:hello-success_test //examples/py/... //examples/py_native:test //examples/shell/..."

moduleclass = 'devel'
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Thomas Hoffmann, EMBL Heidelberg, [email protected], 2025/01
# replace assertEquals by assertEqual
# https://docs.python.org/3/whatsnew/3.12.html#id3
diff -ru bazel-6.5.0/examples/py_native/fail.py bazel-6.5.0_pytest_assertEqual/examples/py_native/fail.py
--- bazel-6.5.0/examples/py_native/fail.py 1980-01-01 00:00:00.000000000 +0100
+++ bazel-6.5.0_pytest_assertEqual/examples/py_native/fail.py 2025-01-24 14:27:22.973336188 +0100
@@ -6,7 +6,7 @@
class TestGetNumber(unittest.TestCase):

def test_fail(self):
- self.assertEquals(GetNumber(), 0)
+ self.assertEqual(GetNumber(), 0)


if __name__ == '__main__':
diff -ru bazel-6.5.0/examples/py_native/test.py bazel-6.5.0_pytest_assertEqual/examples/py_native/test.py
--- bazel-6.5.0/examples/py_native/test.py 1980-01-01 00:00:00.000000000 +0100
+++ bazel-6.5.0_pytest_assertEqual/examples/py_native/test.py 2025-01-24 14:27:22.973336188 +0100
@@ -8,10 +8,10 @@
class TestGetNumber(unittest.TestCase):

def test_ok(self):
- self.assertEquals(GetNumber(), 42)
+ self.assertEqual(GetNumber(), 42)

def test_fib(self):
- self.assertEquals(Fib(5), 8)
+ self.assertEqual(Fib(5), 8)

if __name__ == '__main__':
unittest.main()
229 changes: 229 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.34-gfbf-2024a-CUDA-12.6.0.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
# Updated by: Pavel Tománek (INUITS)
# Updated by: Thomas Hoffmann (EMBL Heidelberg)
# jax + jaxlib (incl. CUDA plugins) as a single PythonBundle installation
easyblock = 'PythonBundle'

name = 'jax'
version = '0.4.34'
# version suffix follows the version of the CUDA dependency via the %(cudaver)s template
versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://jax.readthedocs.io/'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'gfbf', 'version': '2024a'}
# default list of CUDA compute capabilities to build for
# NOTE(review): presumably overridable via the EasyBuild configuration - confirm
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]

# Build-time only requirements.
# Bazel is intentionally not listed (see TODO below), so it is not provided as an EasyBuild module here.
builddependencies = [
    # ('Bazel', '7.4.1'), TODO: problems with @@local_config_python//:py3_runtime:
    # Error in fail: interpreter_path must be an absolute path
    # Bazel 6.5.0 (download) works.
    ('pybind11', '2.13.6'),
    ('pytest-xdist', '3.6.1'),
    ('git', '2.45.1'),  # bazel uses git to fetch repositories
    ('matplotlib', '3.9.2'),  # required for tests/lobpcg_test.py
    ('poetry', '1.8.3'),
    # Clang variant carrying the same -CUDA-<version> suffix as this easyconfig
    ('Clang', '18.1.8', versionsuffix)
]

# Runtime requirements. CUDA and cuDNN come from the SYSTEM toolchain;
# cuDNN and NCCL use the same -CUDA-<version> suffix as this easyconfig.
dependencies = [
    ('CUDA', '12.6.0', '', SYSTEM),  # 12.6.2 ?
    ('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
    ('NCCL', '2.22.3', versionsuffix),
    ('Python', '3.12.3'),
    ('SciPy-bundle', '2024.05'),  # 2024.11 ?
    ('absl-py', '2.1.0'),
    ('flatbuffers-python', '24.3.25'),
    ('ml_dtypes', '0.5.0'),
    ('zlib', '1.3.1'),
]

# XLA and other tarballs are downloaded by EasyBuild so that Bazel does not fetch them during the build:
# instead of being unpacked, each archive is copied as-is into %(builddir)s/archives
local_extract_cmd = ' && '.join([
    'mkdir -p %(builddir)s/archives',
    'cp %s %(builddir)s/archives',
])
# note: the commits below *must* be exactly the same ones used upstream
# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
local_xla_commit = 'cd6e808c59f53b40a99df1f1b860db9a3e598bff'
# TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'  # TODO: still required?
# TODO: add other downloads

# Options for the jaxlib build, one list entry per option, in the order they are passed.
_jaxlib_opts = [
    # use sources downloaded by EasyBuild
    '--bazel_options="--distdir=%(builddir)s/archives"',
    # use dependencies from EasyBuild
    '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11"',
    '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include"',
    # avoid warning (treated as error) in upb/table.c; TODO: still required?
    '--bazel_options="--copt=-Wno-maybe-uninitialized"',
    # TODO: avoid clang (?) by also passing '--nouse_clang'
    '--cuda_version=%(cudaver)s',
    '--python_bin_path=$EBROOTPYTHON/bin/python3',
    # do not use hermetic CUDA/cuDNN/NCCL; this requires action_env=CPATH=$EBROOTCUDA/extras/CUPTI/include
    # and a patch of external/xla/xla/tsl/cuda/cupti_stub.cc and jaxlib/gpu/vendor.h (#include <cupti.h>)
    '--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN"',
    '--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL"',
    '--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA"',
    '--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include"',
    # required?
    '--bazel_options="--action_env=JAXLIB_RELEASE=1"',
]
# every option is followed by a single space, so further options can be appended directly
_jaxlib_buildopts = ' '.join(_jaxlib_opts) + ' '

# Extra options for the second build pass, which builds the CUDA plugins.
_plugins_opts = [
    '--enable_cuda',
    '--build_gpu_plugin',
    # not used for now: '--gpu_plugin_cuda_version=12'
    '--build_gpu_pjrt_plugin',
    '--build_gpu_kernel_plugin=cuda',
]
_plugins_buildopts = ' '.join(_plugins_opts) + ' '

# get rid of .devDate versionsuffix by hard-coding the version in setup.py; TODO: find a better way
# (exporting JAX_RELEASE and JAXLIB_RELEASE instead does not seem to work)
_no_devtag = " sed -i \"s/version=__version__/version='%(version)s'/g\" setup.py && "

# jaxlib is built in two passes from the same jax source tree (same sources, same start_dir):
# the first pass builds jaxlib itself, the second pass adds the CUDA plugin build options.
components = [
    ('jaxlib', version, {
        'sources': [
            {
                'source_urls': ['https://github.com/google/jax/archive/'],
                'filename': 'jax-v%(version)s.tar.gz',
            },
            {
                # XLA tarball for the pinned commit; copied to the archives dir via local_extract_cmd
                'source_urls': ['https://github.com/openxla/xla/archive'],
                'download_filename': '%s.tar.gz' % local_xla_commit,
                'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
                'extract_cmd': local_extract_cmd,
            },
            {
                # TF runtime tarball for the pinned commit; copied to the archives dir via local_extract_cmd
                'source_urls': ['https://github.com/tensorflow/runtime/archive'],
                'download_filename': '%s.tar.gz' % local_tfrt_commit,
                'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
                'extract_cmd': local_extract_cmd,
            },
        ],
        # NOTE(review): patch files are named after jax 0.4.35 but are applied to the 0.4.34 sources -
        # confirm that they apply cleanly
        'patches': [
            'jax-0.4.35_easyblock_compat.patch',
            'jax-0.4.35_fix-pybind11-systemlib_cupti.patch',
            'jax-0.4.35_version.patch',
        ],
        'checksums': [
            {'jax-v0.4.34.tar.gz':
             'd3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb'},
            {'xla-cd6e808c.tar.gz':
             '65cb6d63ef4083b35775052636cb9c629f86db6947c8b91711923ba31dbdcde8'},
            {'tf_runtime-0aeefb16.tar.gz':
             'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
            {'jax-0.4.35_easyblock_compat.patch':
             'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'},
            {'jax-0.4.35_fix-pybind11-systemlib_cupti.patch':
             '51369589193be60dc94ec2de1b35d0a9268288578903fb05d41b6d1a8c9df460'},
            {'jax-0.4.35_version.patch':
             'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
        ],
        'start_dir': 'jax-jax-v%(version)s',
        'buildopts': _jaxlib_buildopts,
        # before building: symlink $EBROOTCUDA/extras/CUPTI into third_party/gpus/cuda/extras
        # (relative link), then hard-code the version in setup.py (_no_devtag)
        'prebuildopts': ' mkdir third_party/gpus/cuda/extras/ -p && ' +
                        'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&' +
                        _no_devtag
    }),
    # build jaxlib first and then plugins in 2nd iteration:
    # same sources and start_dir as above, but no patches are listed for this pass
    # NOTE(review): presumably this pass relies on the tree already patched by the first pass - confirm
    ('jaxlib', version, {
        'sources': [
            {
                'source_urls': ['https://github.com/google/jax/archive/'],
                'filename': 'jax-v%(version)s.tar.gz',
            },
            {
                'source_urls': ['https://github.com/openxla/xla/archive'],
                'download_filename': '%s.tar.gz' % local_xla_commit,
                'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
                'extract_cmd': local_extract_cmd,
            },
            {
                'source_urls': ['https://github.com/tensorflow/runtime/archive'],
                'download_filename': '%s.tar.gz' % local_tfrt_commit,
                'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
                'extract_cmd': local_extract_cmd,
            },
        ],
        'checksums': [
            {'jax-v0.4.34.tar.gz':
             'd3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb'},
            {'xla-cd6e808c.tar.gz':
             '65cb6d63ef4083b35775052636cb9c629f86db6947c8b91711923ba31dbdcde8'},
            {'tf_runtime-0aeefb16.tar.gz':
             'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
        ],
        'start_dir': 'jax-jax-v%(version)s',
        # jaxlib options plus the options that enable the CUDA plugin builds
        'buildopts': _jaxlib_buildopts + _plugins_buildopts,
        'prebuildopts': _no_devtag
    }),

]
# failing:
# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex128 FAILED [ 98%]
# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex64 FAILED [ 98%]
# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex128 FAILED [ 99%]
# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex64 FAILED [ 99%]
# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex128 - AssertionError:
# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex64 - AssertionError:
# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex128 - AssertionError:
# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex64 - AssertionError:
# tests/nn_test.py::NNFunctionsTest::testDotProductAttentionMask7 FAILED [ 10%]
# FAILED tests/nn_test.py::NNFunctionsTest::testDotProductAttentionMask7 - AssertionError:
#

# Some tests require an isolated run: TODO: still required?
local_isolated_tests = [
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
    'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' +
    '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
]
# deliberately not testing in parallel, as that results in (additional) failing tests;
# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing,
# see https://github.com/google/jax/issues/7323 and
# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst;
# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs;
# use NVIDIA_TF32_OVERRIDE=0 to avoid losing numerical precision by disabling TF32 Tensor Cores;
local_test_exports = [
    "NVIDIA_TF32_OVERRIDE=0",
    "CUDA_VISIBLE_DEVICES=0",
    "XLA_PYTHON_CLIENT_ALLOCATOR=platform",
    "JAX_ENABLE_X64=true",
]
# environment setup: one 'export VAR=value;' per entry, prepended to the test command
local_test_env = ''.join(['export %s;' % x for x in local_test_exports])
# run all tests at once except for local_isolated_tests ...
local_test_cmds = ['pytest -vv tests %s' % ' '.join(['--deselect %s' % x for x in local_isolated_tests])]
# ... and run the remaining local_isolated_tests separately, one pytest call per test
local_test_cmds += ['pytest -vv %s' % x for x in local_isolated_tests]
# chain the commands with '&&'; joining a list (rather than hard-coding a trailing '&& ') keeps
# the command valid even if local_isolated_tests is emptied
# note: the extension below sets 'runtest': False, so this command is currently only prepared, not run
local_test = local_test_env + ' && '.join(local_test_cmds)

use_pip = True

# jax itself is installed as the only extension, from the same GitHub source tarball
# that is used for the jaxlib components
exts_list = [
    (name, version, {
        # NOTE(review): patch file is named after jax 0.4.35 but applied to the 0.4.34 sources - confirm
        'patches': ['jax-0.4.35_version.patch'],
        # hard-code the version in setup.py before installing (see _no_devtag)
        'preinstallopts': _no_devtag,
        # tests are disabled for this extension
        'runtest': False,
        'source_tmpl': '%(name)s-v%(version)s.tar.gz',
        'source_urls': ['https://github.com/google/jax/archive/'],
        'checksums': [
            {'jax-v0.4.34.tar.gz': 'd3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb'},
            {'jax-0.4.35_version.patch': 'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
        ],
    }),
]
# Check that the CUDA plugin can be imported. The plugin module is named after the major version
# of CUDA (e.g. jax_cuda12_plugin); the %(cudamajver)s template provides that value directly,
# replacing a shell pipeline that derived it from $EBVERSIONCUDA with echo + awk.
sanity_check_commands = [
    "python -c 'import jax_cuda%(cudamajver)s_plugin'",
]
sanity_pip_check = True

# Point XLA to the CUDA installation through XLA_FLAGS when the module is loaded.
# The value is taken from $CUDA_HOME at module load time.
# NOTE(review): presumably set by the CUDA dependency module - confirm
# TODO: patch to set default XLA_FLAGS
modluafooter = """
setenv("XLA_FLAGS", "--xla_gpu_cuda_data_dir=" .. os.getenv("CUDA_HOME"));
"""

# same setting for Tcl module files
modtclfooter = """
setenv XLA_FLAGS --xla_gpu_cuda_data_dir=$::env(CUDA_HOME)
"""

# TODO: sanity check paths


moduleclass = 'ai'
21 changes: 21 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.35_easyblock_compat.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Thomas Hoffmann, EMBL Heidelberg, [email protected], 2024/11
# add dummy parameters to build/build.py for cudnn_path and cuda_path, which are set by default by the jaxlib easyblock.
diff -ru jax-jax-v0.4.35/build/build.py jax-jax-v0.4.35_easyblockcompat/build/build.py
--- jax-jax-v0.4.35/build/build.py 2024-10-22 21:00:23.000000000 +0200
+++ jax-jax-v0.4.35_easyblockcompat/build/build.py 2024-11-19 12:35:46.524479324 +0100
@@ -549,6 +549,15 @@
help_str="Same as update_requirements, but will consider dev, nightly "
"and pre-release versions of packages.")

+ parser.add_argument(
+ "--cuda_path",
+ default="dummy",
+ help="compatibility with jaxlib.py easyblock")
+ parser.add_argument(
+ "--cudnn_path",
+ default="dummy",
+ help="compatibility with jaxlib.py easyblock")
+
args = parser.parse_args()

logging.basicConfig()
Loading
Loading