-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
nsys-jax: bugfix and expanded testing (#1132)
- Fix for profiling traced code that is not attributed to a named file. - More test coverage. - Cleanup `nsys-jax` handling of `-o` and `-f` options.
- Loading branch information
Showing
10 changed files
with
264 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file modified
0
.github/container/jax_nsys/python/jax_nsys_analysis/summary.py
100755 → 100644
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import jax | ||
|
||
|
||
@jax.jit | ||
def distinctively_named_function(x): | ||
return x @ x.T | ||
|
||
|
||
square = jax.random.normal(jax.random.key(1), (32, 32)) | ||
for _ in range(5): | ||
square = distinctively_named_function(square) |
14 changes: 14 additions & 0 deletions
14
.github/container/jax_nsys_tests/jax_nsys_test_helpers/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import subprocess | ||
import tempfile | ||
|
||
|
||
def nsys_jax(command): | ||
""" | ||
Helper to run nsys-jax with a unique output file that will be automatically | ||
cleaned up on destruction. | ||
""" | ||
output = tempfile.NamedTemporaryFile(suffix=".zip") | ||
subprocess.run( | ||
["nsys-jax", "--force-overwrite", "--output", output.name] + command, check=True | ||
) | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import os | ||
import subprocess | ||
import sys | ||
import tempfile | ||
import zipfile | ||
|
||
helper_dir = os.path.join(os.path.dirname(__file__), "jax_nsys_test_helpers") | ||
if helper_dir not in sys.path: | ||
sys.path.insert(0, helper_dir) | ||
from jax_nsys_test_helpers import nsys_jax # noqa: E402 | ||
|
||
|
||
def test_program_without_gpu_activity(): | ||
""" | ||
Profiling a program that doesn't do anything should succeed. | ||
""" | ||
nsys_jax([sys.executable, "-c", "print('Hello world!')"]) | ||
|
||
|
||
def test_stacktrace_entry_with_file(): | ||
""" | ||
Test that if a source file appears in the traceback of a JITed JAX function then | ||
the source file is bundled into the nsys-jax output archive. | ||
""" | ||
with tempfile.TemporaryDirectory() as tmpdir: | ||
archive = f"{tmpdir}/out.zip" | ||
src_file = f"{tmpdir}/test.py" | ||
assert os.path.isabs(src_file), src_file | ||
src_code = "import jax\njax.jit(lambda x: x*2)(4)\n" | ||
with open(src_file, "w") as f: | ||
f.write(src_code) | ||
subprocess.run( | ||
["nsys-jax", "--output", archive, sys.executable, src_file], check=True | ||
) | ||
with zipfile.ZipFile(archive) as ifile: | ||
src_file_in_archive = f"sources{src_file}" | ||
assert src_file_in_archive in ifile.namelist() | ||
with ifile.open(src_file_in_archive, "r") as archived_file: | ||
assert archived_file.read().decode() == src_code | ||
|
||
|
||
def test_stacktrace_entry_without_file(): | ||
""" | ||
Test that tracing code that does not come from a named file works (bug 4931958). | ||
""" | ||
archive = nsys_jax(["python", "-c", "import jax; jax.jit(lambda x: x*2)(4)"]) | ||
with zipfile.ZipFile(archive.name) as ifile: | ||
# The combination of -c and JAX suppressing references to its own source code | ||
# should mean that no source code files are gathered. | ||
assert not any(x.startswith("sources/") for x in ifile.namelist()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from jax_nsys import ( | ||
ensure_compiled_protos_are_importable, | ||
load_profiler_data, | ||
) | ||
import os | ||
import pathlib | ||
import pytest # type: ignore | ||
import sys | ||
import tempfile | ||
import zipfile | ||
|
||
helper_dir = os.path.join(os.path.dirname(__file__), "jax_nsys_test_helpers") | ||
if helper_dir not in sys.path: | ||
sys.path.insert(0, helper_dir) | ||
from jax_nsys_test_helpers import nsys_jax # noqa: E402 | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def example_program(): | ||
""" | ||
Fixture that yields an extracted archive of the result of profiling | ||
example_program.py with nsys-jax. | ||
""" | ||
tmpdir = tempfile.TemporaryDirectory() | ||
archive = nsys_jax( | ||
[sys.executable, os.path.join(os.path.dirname(__file__), "example_program.py")] | ||
) | ||
old_dir = os.getcwd() | ||
os.chdir(tmpdir.name) | ||
try: | ||
with zipfile.ZipFile(archive) as zf: | ||
zf.extractall() | ||
finally: | ||
os.chdir(old_dir) | ||
# Make sure the protobuf bindings can be imported, the generated .py will go into | ||
# a temporary directory that is not currently cleaned up. The bindings cannot be | ||
# un-imported from the test process, so there is a tacit assumption that in a given | ||
# test session there will only be one set of .proto files and it doesn't matter | ||
# which ones are picked up. | ||
ensure_compiled_protos_are_importable(prefix=pathlib.Path(tmpdir.name)) | ||
return tmpdir | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def profiler_data(example_program): | ||
return load_profiler_data(pathlib.Path(example_program.name)) | ||
|
||
|
||
def test_comms(profiler_data): | ||
# example_program.py should contain no communication | ||
assert len(profiler_data.communication) == 0 | ||
|
||
|
||
def test_modules(profiler_data): | ||
test_func_mask = profiler_data.module["Name"] == "jit_distinctively_named_function" | ||
assert sum(test_func_mask) == 5 | ||
test_func_data = profiler_data.module[test_func_mask] | ||
assert test_func_data.index.names == ["ProgramId", "ProgramExecution", "Device"] | ||
# All executions should have the same program id | ||
program_ids = test_func_data.index.get_level_values("ProgramId") | ||
assert all(program_ids == program_ids[0]) | ||
# All executions should be on device 0 | ||
execution_devices = test_func_data.index.get_level_values("Device") | ||
assert all(execution_devices == 0) | ||
# Execution indices should count from 0..n-1 | ||
execution_indices = test_func_data.index.get_level_values("ProgramExecution") | ||
assert all(execution_indices == range(len(test_func_data))) |
Oops, something went wrong.