Skip to content

Commit

Permalink
nsys-jax: bugfix and expanded testing (#1132)
Browse files Browse the repository at this point in the history
- Fix for profiling traced code that is not attributed to a named file.
- More test coverage.
- Cleanup `nsys-jax` handling of `-o` and `-f` options.
  • Loading branch information
olupton authored Nov 1, 2024
1 parent 1dad010 commit b1103a0
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 89 deletions.
11 changes: 9 additions & 2 deletions .github/container/Dockerfile.base
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,20 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/
## Add helper scripts for profiling with Nsight Systems
##
## The scripts saved to /opt/jax_nsys are embedded in the output archives
## written by nsys-jax, while the nsys-jax wrapper is used inside the container.
## written by nsys-jax, while the nsys-jax and nsys-jax-combine scripts are
## only used inside the containers.
###############################################################################

ADD nsys-jax nsys-jax-combine /usr/local/bin/
ADD jax_nsys/ /opt/jax_nsys
# The jax_nsys package should be installed inside the containers, so nsys-jax
# can eagerly execute analysis recipes (--nsys-jax-analysis) in the container
# environment, without an extra layer of virtual environment indirection.
RUN echo "-e /opt/jax_nsys/python/jax_nsys" > /opt/pip-tools.d/requirements-nsys-jax.in
# install-protoc should be embedded in output archives and also be runnable inside containers
RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/
# The tests should be available for execution inside the containers, but should
# not be embedded in the output archives.
ADD jax_nsys_tests/ /opt/jax_nsys_tests

###############################################################################
## Copy manifest file to the container
Expand Down
4 changes: 4 additions & 0 deletions .github/container/jax_nsys/python/jax_nsys/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ dependencies = [
"uncertainties", # communication analysis recipe
]
requires-python = ">= 3.10"
[project.optional-dependencies]
test = [
"pytest"
]
Empty file modified .github/container/jax_nsys/python/jax_nsys_analysis/summary.py
100755 → 100644
Empty file.
11 changes: 11 additions & 0 deletions .github/container/jax_nsys_tests/example_program.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Small JAX program that the nsys-jax tests profile."""

import jax

# How many times the JIT-compiled function is executed below.
NUM_ITERATIONS = 5


@jax.jit
def distinctively_named_function(x):
    """Return the matrix product of ``x`` with its own transpose."""
    x_transposed = x.T
    return x @ x_transposed


# Start from a fixed-seed random 32x32 matrix and feed each result back in.
square = jax.random.normal(jax.random.key(1), (32, 32))
for _ in range(NUM_ITERATIONS):
    square = distinctively_named_function(square)
14 changes: 14 additions & 0 deletions .github/container/jax_nsys_tests/jax_nsys_test_helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import subprocess
import tempfile


def nsys_jax(command):
    """
    Run ``nsys-jax`` on ``command`` and return the temporary output file.

    ``command`` is the list of arguments appended to the nsys-jax invocation.
    The output archive path is the ``name`` attribute of the returned
    ``tempfile.NamedTemporaryFile``; the file is removed automatically when that
    object is closed or destroyed, so callers must hold on to it while they need
    the archive. ``subprocess.CalledProcessError`` is raised if nsys-jax exits
    with a non-zero status.
    """
    archive_file = tempfile.NamedTemporaryFile(suffix=".zip")
    # NOTE(review): --force-overwrite is presumably needed because the temporary
    # file already exists on disk when nsys-jax starts -- confirm.
    nsys_jax_prefix = ["nsys-jax", "--force-overwrite", "--output", archive_file.name]
    full_command = nsys_jax_prefix + command
    subprocess.run(full_command, check=True)
    return archive_file
50 changes: 50 additions & 0 deletions .github/container/jax_nsys_tests/test_basics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
import subprocess
import sys
import tempfile
import zipfile

# Make the local ``jax_nsys_test_helpers`` package importable. Python resolves
# ``import jax_nsys_test_helpers`` by looking for that name *inside* each
# sys.path entry, so the entry that has to be present is the directory that
# contains the package, not the package directory itself.
helper_dir = os.path.join(os.path.dirname(__file__), "jax_nsys_test_helpers")
helper_parent_dir = os.path.dirname(helper_dir)
if helper_parent_dir not in sys.path:
    sys.path.insert(0, helper_parent_dir)
from jax_nsys_test_helpers import nsys_jax  # noqa: E402


def test_program_without_gpu_activity():
    """
    nsys-jax should succeed on a program that only prints a message.
    """
    trivial_command = [sys.executable, "-c", "print('Hello world!')"]
    # The returned temporary file is not needed; success is simply not raising.
    nsys_jax(trivial_command)


def test_stacktrace_entry_with_file():
    """
    A source file that shows up in the traceback of a JITed JAX function must be
    bundled, unmodified, into the nsys-jax output archive.
    """
    program_text = "import jax\njax.jit(lambda x: x*2)(4)\n"
    with tempfile.TemporaryDirectory() as workdir:
        program_path = f"{workdir}/test.py"
        output_zip = f"{workdir}/out.zip"
        # The expected archive member name is built by prefixing this path, so
        # it has to be absolute.
        assert os.path.isabs(program_path), program_path
        with open(program_path, "w") as program_file:
            program_file.write(program_text)
        profile_command = ["nsys-jax", "--output", output_zip]
        profile_command += [sys.executable, program_path]
        subprocess.run(profile_command, check=True)
        with zipfile.ZipFile(output_zip) as bundle:
            expected_member = f"sources{program_path}"
            assert expected_member in bundle.namelist()
            with bundle.open(expected_member, "r") as member:
                bundled_text = member.read().decode()
            assert bundled_text == program_text


def test_stacktrace_entry_without_file():
    """
    Test that tracing code that does not come from a named file works (bug 4931958).
    """
    # Launch the same interpreter that is running the tests (as the other tests
    # in this file do) instead of whatever "python" resolves to on PATH, which
    # may be missing or may be an interpreter without JAX installed.
    archive = nsys_jax(
        [sys.executable, "-c", "import jax; jax.jit(lambda x: x*2)(4)"]
    )
    with zipfile.ZipFile(archive.name) as ifile:
        # The combination of -c and JAX suppressing references to its own source code
        # should mean that no source code files are gathered.
        assert not any(x.startswith("sources/") for x in ifile.namelist())
67 changes: 67 additions & 0 deletions .github/container/jax_nsys_tests/test_example_program.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from jax_nsys import (
ensure_compiled_protos_are_importable,
load_profiler_data,
)
import os
import pathlib
import pytest # type: ignore
import sys
import tempfile
import zipfile

# Make the local ``jax_nsys_test_helpers`` package importable. Python resolves
# ``import jax_nsys_test_helpers`` by looking for that name *inside* each
# sys.path entry, so the entry that has to be present is the directory that
# contains the package, not the package directory itself.
helper_dir = os.path.join(os.path.dirname(__file__), "jax_nsys_test_helpers")
helper_parent_dir = os.path.dirname(helper_dir)
if helper_parent_dir not in sys.path:
    sys.path.insert(0, helper_parent_dir)
from jax_nsys_test_helpers import nsys_jax  # noqa: E402


@pytest.fixture(scope="module")
def example_program():
    """
    Fixture that returns a ``tempfile.TemporaryDirectory`` holding the extracted
    archive produced by profiling example_program.py with nsys-jax.

    The directory lives for as long as the returned object is referenced, which
    for this module-scoped fixture means for the tests in this module.
    """
    tmpdir = tempfile.TemporaryDirectory()
    archive = nsys_jax(
        [sys.executable, os.path.join(os.path.dirname(__file__), "example_program.py")]
    )
    # Open the archive by path and extract directly into the target directory.
    # This avoids changing the process-wide working directory.
    with zipfile.ZipFile(archive.name) as zf:
        zf.extractall(path=tmpdir.name)
    # Make sure the protobuf bindings can be imported, the generated .py will go into
    # a temporary directory that is not currently cleaned up. The bindings cannot be
    # un-imported from the test process, so there is a tacit assumption that in a given
    # test session there will only be one set of .proto files and it doesn't matter
    # which ones are picked up.
    ensure_compiled_protos_are_importable(prefix=pathlib.Path(tmpdir.name))
    return tmpdir


@pytest.fixture(scope="module")
def profiler_data(example_program):
    """
    Module-scoped fixture that loads the profiler data from the directory provided
    by the ``example_program`` fixture.
    """
    extracted_dir = pathlib.Path(example_program.name)
    return load_profiler_data(extracted_dir)


def test_comms(profiler_data):
    """example_program.py should contain no communication."""
    num_communication_entries = len(profiler_data.communication)
    assert num_communication_entries == 0


def test_modules(profiler_data):
    """
    Check the module records of the JIT-compiled function from example_program.py:
    five executions, one program id, device 0, execution indices 0..n-1.
    """
    module_table = profiler_data.module
    is_test_func = module_table["Name"] == "jit_distinctively_named_function"
    assert sum(is_test_func) == 5
    selected = module_table[is_test_func]
    selected_index = selected.index
    assert selected_index.names == ["ProgramId", "ProgramExecution", "Device"]
    # Every execution must share a single program id
    ids = selected_index.get_level_values("ProgramId")
    assert all(ids == ids[0])
    # Every execution must have run on device 0
    devices = selected_index.get_level_values("Device")
    assert all(devices == 0)
    # Execution indices must count up from 0 to n-1
    executions = selected_index.get_level_values("ProgramExecution")
    expected_executions = range(len(selected))
    assert all(executions == expected_executions)
Loading

0 comments on commit b1103a0

Please sign in to comment.