Use new ResourceEnvs from Haliax #444

Open
Wants to merge 236 commits into base: main
Changes from 1 commit (236 commits in total)
fff4dfb
wip
dlwh Oct 25, 2023
740ad68
wip
dlwh Nov 7, 2023
4ad74a6
almost got new logger working
dlwh Nov 7, 2023
ad708e3
move the metrics stuff to its own file
dlwh Nov 8, 2023
6930fa9
refactor and move stuff around
dlwh Nov 8, 2023
abf7ec3
use generic infrastructure for summary
dlwh Nov 8, 2023
547cea8
wip towards a clean tracker package
dlwh Nov 8, 2023
2f481ed
wip
dlwh Nov 9, 2023
0b080fb
remove more wandb deps
dlwh Nov 9, 2023
a324ae5
tiny cleanup
dlwh Nov 9, 2023
cfdcbb9
add some tests
dlwh Nov 9, 2023
2ddc558
migrate alpaca-lora to new logger
dlwh Nov 9, 2023
9b0df08
sort of get tb to work
dlwh Nov 10, 2023
4fd2526
wip
dlwh Nov 14, 2023
a608a65
wip
dlwh Nov 16, 2023
176e5fa
Merge remote-tracking branch 'origin/main' into generic_logger
dlwh Nov 17, 2023
8d34f6f
update configs, expose a method to find trackers
dlwh Nov 17, 2023
42d7f2c
use `trainer` more to set logging
dlwh Nov 17, 2023
b887761
test the tracker get name stuff
dlwh Nov 17, 2023
3ebd161
minor
dlwh Nov 17, 2023
0d2efbc
making speccing the loss function simpler
dlwh Nov 18, 2023
f085287
stop requiring a loss function for every model definition
dlwh Nov 18, 2023
f21cf4b
wip
dlwh Nov 19, 2023
01c8b87
jkacjkac
dlwh Nov 19, 2023
e374697
tweak
dlwh Nov 22, 2023
921acf8
register default hooks by default...
dlwh Nov 22, 2023
c8a5d6c
wip
dlwh Nov 24, 2023
639d334
make it so we can evaluate if we have a cache but no sources
dlwh Nov 24, 2023
a3fdbaf
Merge branch 'cache_only' into extensible_trainer
dlwh Nov 24, 2023
ec35e9b
about got the checkpoint refactor done
dlwh Nov 25, 2023
ed13502
about got the checkpoint refactor done
dlwh Nov 25, 2023
634407e
minor dead code removal
dlwh Nov 25, 2023
4208e03
fix tests
dlwh Nov 26, 2023
9584884
cleanup
dlwh Nov 26, 2023
5a18678
cleanup
dlwh Nov 26, 2023
c355106
minor
dlwh Nov 26, 2023
7a2ffc3
Merge branch 'extensible_trainer' into doremi
dlwh Nov 26, 2023
d2e0de1
wip
dlwh Nov 26, 2023
be99631
register default hooks by default...
dlwh Nov 22, 2023
5d033eb
wip
dlwh Nov 24, 2023
c4a9160
make it so we can evaluate if we have a cache but no sources
dlwh Nov 24, 2023
b888065
about got the checkpoint refactor done
dlwh Nov 25, 2023
c47ae97
about got the checkpoint refactor done
dlwh Nov 25, 2023
f0613c7
minor dead code removal
dlwh Nov 25, 2023
85c5678
fix tests
dlwh Nov 26, 2023
8f84822
cleanup
dlwh Nov 26, 2023
e54bad0
cleanup
dlwh Nov 26, 2023
85dd89b
minor
dlwh Nov 26, 2023
c61824e
generalize and extract the checkpoint loading logic so it can be used…
dlwh Nov 27, 2023
7391475
Revert "Temporarily Revert "Generic Tracker interface, support for TB…
dlwh Nov 28, 2023
2387f26
wip
dlwh Nov 28, 2023
6446bc0
just about workable logger stuff
dlwh Nov 28, 2023
1b821d1
fix logging of config with a new levanter.initialize
dlwh Nov 28, 2023
afb6459
missed a sopt
dlwh Nov 28, 2023
9d916bd
on second thought, don't use tb in small_fast
dlwh Nov 29, 2023
3d67552
Merge remote-tracking branch 'origin/dev' into extensible_trainer
dlwh Nov 30, 2023
4d8cd68
main->dev (#375)
dlwh Dec 1, 2023
272e1e1
Merge remote-tracking branch 'origin/dev' into extensible_trainer
dlwh Dec 1, 2023
48ccdd3
Merge remote-tracking branch 'origin/main' into dev
dlwh Dec 1, 2023
3b27a08
supporting new trainer in gsm8k example
dlwh Dec 1, 2023
dcbed88
Merge branch 'dev' into extensible_trainer
dlwh Dec 2, 2023
bbac4ef
Merge remote-tracking branch 'origin/main' into extensible_trainer
dlwh Dec 2, 2023
f2842e9
Add Sophia-H, some WIP support for Sophia-G (#372)
dlwh Dec 7, 2023
6d6ae21
Merge remote-tracking branch 'origin/main' into dev
dlwh Dec 10, 2023
83bea6e
fix missing test changes
dlwh Dec 10, 2023
92a615f
should use a tempdir
dlwh Dec 11, 2023
cbee427
update gsm8k lora for sophia refactors
dlwh Dec 11, 2023
e048581
Allow val change wandb dev (#384)
dlwh Dec 13, 2023
2bdf08b
oops
dlwh Dec 13, 2023
8f4aff3
do loss in fp32
dlwh Dec 14, 2023
91eb588
Merge remote-tracking branch 'origin/dev' into extensible_trainer
dlwh Dec 17, 2023
2002832
more dead code removal
dlwh Dec 17, 2023
efa70a1
refix merge issues
dlwh Dec 17, 2023
4cca0d1
refix merge issues
dlwh Dec 17, 2023
15e223d
Merge remote-tracking branch 'origin/main' into dev
dlwh Dec 19, 2023
904497b
Merge branch 'dev' into extensible_trainer
dlwh Dec 19, 2023
2a90f57
allow train_batch_size to be -1 if per_device_parallelism isn't -1
dlwh Dec 19, 2023
f05739a
wip
dlwh Dec 21, 2023
321bb30
Merge remote-tracking branch 'origin/main' into extensible_trainer
dlwh Dec 21, 2023
38db3d5
fix performance regression in trainer.py
dlwh Dec 21, 2023
9b2813b
wth
dlwh Dec 21, 2023
e014c45
mdkladmlkad
dlwh Dec 21, 2023
95a391f
jfakmfa
dlwh Dec 21, 2023
9f40f10
try this other approach to steps in TrainerState
dlwh Dec 21, 2023
6df53f4
fix checkpoint tests
dlwh Dec 21, 2023
94aa8fa
fix gsm8k
dlwh Dec 21, 2023
5af6cb2
update for new Haliax reduction functions
dlwh Dec 24, 2023
e2b086f
Merge branch 'extensible_trainer' into doremi
dlwh Dec 24, 2023
84d3b33
wip
dlwh Dec 24, 2023
85b42b0
refactor grad_accum to have a separate microbatched
dlwh Dec 25, 2023
c47c188
remove accumulate_gradients_sharded and just use microbatched directly
dlwh Dec 27, 2023
70b766f
add dtype for grad accum
dlwh Dec 27, 2023
57725ea
small refactor
dlwh Dec 27, 2023
85f777b
small refactor
dlwh Dec 27, 2023
f8d98fc
fix key handling in grad accum
dlwh Dec 27, 2023
5a8c77a
make sophia work with non-trainables again
dlwh Dec 28, 2023
ff59e51
factor out some methods in train_step
dlwh Dec 28, 2023
c1718dd
Merge branch 'extensible_trainer' into doremi
dlwh Dec 28, 2023
d7a060d
make the initialize_from logic just use load_checkpoint_or_initialize
dlwh Dec 29, 2023
8c44e64
on second thought load_from_checkpoint_or_initialize is the wrong abs…
dlwh Dec 30, 2023
72f1e47
wip
dlwh Dec 30, 2023
add3df4
on second thought load_from_checkpoint_or_initialize is the wrong abs…
dlwh Dec 30, 2023
3ba7bf1
Merge branch 'extensible_trainer' into doremi
dlwh Dec 30, 2023
b6535b5
wip factoring out the initial state stuff, again
dlwh Dec 30, 2023
0d6f357
almost ready to try out doremi
dlwh Dec 30, 2023
7395e3c
almost ready to try out doremi
dlwh Jan 2, 2024
08996e6
cleanup typing.overloads
dlwh Jan 3, 2024
710900c
use auto_sharded internally, undeprecate it b/c it has a point
dlwh Jan 3, 2024
5f9d96d
fix docs
dlwh Jan 4, 2024
04a74a1
use new dot syntax in doremi
dlwh Jan 4, 2024
3249ca1
Merge remote-tracking branch 'origin/main' into doremi
dlwh Jan 8, 2024
6a20c95
fix mixture init with prngkey
dlwh Jan 9, 2024
fd6d343
add a simple InMemoryDataset that takes a list
dlwh Jan 9, 2024
f5b8d00
make keyiterator support just an int seed
dlwh Jan 9, 2024
288e7fb
dumb bug in grad accum
dlwh Jan 9, 2024
c4da125
fix some dumb bugs in new trainer
dlwh Jan 9, 2024
9257597
test for doremi and associated fixes
dlwh Jan 9, 2024
317b10d
depend on haliax dev for levanter dev
dlwh Jan 9, 2024
e4d1385
fix gsm8k_lora
dlwh Jan 9, 2024
ddcdac7
add a small_pile configuration
dlwh Jan 9, 2024
792f769
make it len 2048
dlwh Jan 9, 2024
e16b3af
add doremi main
dlwh Jan 10, 2024
a272ca9
we install haliax from source with the pyprojec.toml
dlwh Jan 10, 2024
e8d4b9d
fix doremi test when doing multidevice
dlwh Jan 10, 2024
5c489c1
add a pile_mixture.yaml
dlwh Jan 10, 2024
1672148
add a config for the small pile mixture
dlwh Jan 10, 2024
f485c5f
reduce default rows per chunk and see if that helps with these big su…
dlwh Jan 10, 2024
b2d8a58
add some more logging to see if we can figure out why it's running ou…
dlwh Jan 10, 2024
f76e466
add some more logging to see if we can figure out why it's running ou…
dlwh Jan 11, 2024
fc78716
dumb
dlwh Jan 11, 2024
4927f67
don't run the slow tests in CI
dlwh Jan 11, 2024
1ceb00a
wip
dlwh Jan 12, 2024
bc7108c
move the script, make it read off fsspec
dlwh Jan 13, 2024
69ca4a4
update for reverted Haliax change
dlwh Jan 13, 2024
ff5cb6d
update for reverted Haliax change
dlwh Jan 13, 2024
d6bf2c0
update paths for pile mixture
dlwh Jan 15, 2024
cc6044c
fix new import
dlwh Jan 15, 2024
415158a
sigh
dlwh Jan 15, 2024
d2a90ae
isjfo
dlwh Jan 15, 2024
058a9e0
mdklmdlm
dlwh Jan 15, 2024
9f16fbe
make logging list names of caches
dlwh Jan 15, 2024
b80ef6a
lower resource requirements to see if this gets us processing faster
dlwh Jan 15, 2024
6983ff0
let's make the chunkcachebuilders free
dlwh Jan 15, 2024
5f42ad8
minimize use of optax internals
dlwh Jan 15, 2024
e6e8d27
fix a crash i don't understand
dlwh Jan 16, 2024
ab29e92
let's reduce requirements some more to see if we can keep everything …
dlwh Jan 16, 2024
83f0616
let's reduce requirements some more to see if we can keep everything …
dlwh Jan 16, 2024
def45cc
silly
dlwh Jan 16, 2024
de821ca
ok so we're ok maybe
dlwh Jan 16, 2024
cbddab8
don't fetch local
dlwh Jan 16, 2024
5d0f987
wtf
dlwh Jan 16, 2024
13cc556
what
dlwh Jan 16, 2024
5afac01
ok, think we figured it out
dlwh Jan 16, 2024
41ac362
less logging
dlwh Jan 16, 2024
c621a08
toward turning the reader process into an actor too
dlwh Jan 17, 2024
4d92af9
did we do it?
dlwh Jan 17, 2024
257dfa7
wandb: only force a step if commit is true
dlwh Jan 17, 2024
23865a1
don't crash if n == 0
dlwh Jan 17, 2024
70c00f1
wandb: maybe this gives the behavior i want?
dlwh Jan 17, 2024
4c54365
mklafmlkafml
dlwh Jan 17, 2024
1edeeef
Merge branch 'main' into dev
dlwh Jan 17, 2024
6148381
minimize use of optax internals
dlwh Jan 15, 2024
8274cad
what
dlwh Jan 18, 2024
b980c9f
actually this is probably better
dlwh Jan 18, 2024
36f25a0
actually this is probably better
dlwh Jan 18, 2024
4da7112
dumb
dlwh Jan 18, 2024
6147520
mkladmlkad
dlwh Jan 18, 2024
e166a78
fix key order for doremi
dlwh Jan 18, 2024
e6b581b
remove excess log
dlwh Jan 18, 2024
8c64be5
remove a redundant log message
dlwh Jan 18, 2024
e89e709
fixed more bugs
dlwh Jan 18, 2024
33600fd
almost there
dlwh Jan 18, 2024
efbdd31
don't log a value for domains with no data on a step
dlwh Jan 18, 2024
a810242
bring over the trainer-abstraction doc
dlwh Jan 30, 2024
e49fb38
remove the wrapped loss_fn thing from trainer
dlwh Jan 30, 2024
13dc392
factor out a take_opt_step. need to decide where to put it
dlwh Jan 30, 2024
514da05
explicitly expose microbatch_size, use it in microbatched
dlwh Jan 30, 2024
f797a85
comment about custom_jvp on microbatched
dlwh Jan 31, 2024
4301930
unneeded cast
dlwh Jan 31, 2024
d3416b1
rename to mixed-precision.md
dlwh Jan 31, 2024
9552909
cleanup ctors for BatchLoaders some
dlwh Jan 31, 2024
888d35e
misc cleanup
dlwh Jan 31, 2024
49a409b
wip
dlwh Jan 31, 2024
78d9342
stable point: migrating to resourceenvs
dlwh Jan 31, 2024
8e4e183
require the jamp branch
dlwh Jan 31, 2024
9a0ea6d
knknajkdnjakd
dlwh Jan 31, 2024
7c19f47
try this?
dlwh Jan 31, 2024
d98a885
cleanup and explain the issue
dlwh Jan 31, 2024
015dfb3
see if we get the just-in-time conversion to bf16 that we want
dlwh Jan 31, 2024
cddaf20
wtf
dlwh Feb 1, 2024
27949f8
bypass microbatching if we don't need it?
dlwh Feb 1, 2024
c3a9ce1
switch to using hnn.Embedding in gpt2, which means we get the mixed p…
dlwh Feb 1, 2024
0e91352
switch to using compute_envs where posisble use .shard instead
dlwh Feb 1, 2024
b57e1c7
please pre-commit
dlwh Feb 1, 2024
7fd46cb
ok maybe we can do it?
dlwh Feb 1, 2024
2ca4d97
sigh
dlwh Feb 1, 2024
b1e99e5
Merge branch 'dev' into use_jamp
dlwh Feb 1, 2024
a237a57
fix test_weight_decay_mask.py
dlwh Feb 1, 2024
5282694
use param_env everywhere
dlwh Feb 1, 2024
a013c4c
makldmlkad
dlwh Feb 1, 2024
3049f89
Merge remote-tracking branch 'origin/main' into use_jamp
dlwh Feb 2, 2024
e4fcd67
Merge remote-tracking branch 'origin/main' into dev
dlwh Feb 2, 2024
4a8d07a
Merge branch 'dev' into use_jamp
dlwh Feb 2, 2024
1983a1f
Merge remote-tracking branch 'origin/main' into dev
dlwh Feb 2, 2024
5312b87
Merge remote-tracking branch 'origin/main' into dev
dlwh Feb 2, 2024
90ed9cd
Merge remote-tracking branch 'origin/main' into use_jamp
dlwh Feb 2, 2024
74096fa
Merge branch 'dev' into use_jamp
dlwh Feb 2, 2024
58ca1d7
wip debugging devices
dlwh Feb 2, 2024
6ee6d8f
let's try this?
dlwh Feb 2, 2024
b485673
so confused
dlwh Feb 2, 2024
de3162b
sigh
dlwh Feb 2, 2024
343367f
ok i think i got it
dlwh Feb 2, 2024
1198bb2
Merge branch 'dev' into simple_first_cleanup
dlwh Feb 3, 2024
a2d5934
Merge branch 'simple_first_cleanup' into use_jamp
dlwh Feb 3, 2024
71b755e
wtf
dlwh Feb 5, 2024
07b5797
this async seems like a bad idea
dlwh Feb 5, 2024
b6e0c1d
log perf numbers?
dlwh Feb 5, 2024
a3f9c7f
more logging
dlwh Feb 5, 2024
ce2db7b
moar
dlwh Feb 5, 2024
ea57bde
oops
dlwh Feb 5, 2024
d352b37
reduce logging some, try to figure out this stupid books problem
dlwh Feb 5, 2024
0effef0
ka dkla dkl
dlwh Feb 5, 2024
1e85d16
admaldl
dlwh Feb 5, 2024
a25a8ce
fix the unnecessarily long time outs
dlwh Feb 6, 2024
7c163a8
break really long docs into shorter docs b/c tokenizers is quadratic
dlwh Feb 6, 2024
4125d3f
kmklamdklad
dlwh Feb 6, 2024
99a87e8
maybe don't do the workaround so often?
dlwh Feb 6, 2024
5245e10
is this the leak?!?
dlwh Feb 6, 2024
8a6f59b
update for latest datasets
dlwh Feb 6, 2024
002989b
add a test to ensure we use the workaround for llama tokenizer
dlwh Feb 6, 2024
3dfebe2
tweak timeouts in test
dlwh Feb 6, 2024
5a5a1f1
less spammy logging
dlwh Feb 6, 2024
4e6df52
cleanup, see if we can avoid crashing when one cache finishes
dlwh Feb 6, 2024
c2dccf2
tweaks to tokenization/shard_cache throughput (#456)
dlwh Feb 6, 2024
f7a3d0a
Merge remote-tracking branch 'origin/main' into use_jamp_harfleur
dlwh Feb 6, 2024
c95ba8b
Merge branch 'use_jamp_harfleur' into use_jamp
dlwh Feb 6, 2024
wip
dlwh committed Nov 10, 2023
commit 740ad6898d157815a9fea60c81b5c0677dc146a8
154 changes: 126 additions & 28 deletions src/levanter/logging.py
@@ -5,14 +5,14 @@
import os
import tempfile
import time
import typing
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional, Union

import draccus
import jax
import wandb
from draccus import field
from git import InvalidGitRepositoryError, NoSuchPathError, Repo
from optax import MultiStepsState
@@ -21,10 +21,73 @@
from levanter.utils.jax_utils import jnp_to_python


logger = pylogging.getLogger(__name__)
pylogger = pylogging.getLogger(__name__)

class LoggerSink(abc.ABC):
_global_logger: Optional["MetricsLogger"] = None


def log_metrics(metrics: dict[str, Any], *, step):
"""
Log metrics to the global logger.

:param metrics: Metrics to log
:param step: Step to log metrics at
"""
global _global_logger
if _global_logger is None:
raise RuntimeError("No global logger set")

_global_logger.log(metrics, step=step)


def jit_log_metrics(metrics, *, step=None):
"""uses jax effect callback to log to wandb from the host"""
jax.debug.callback(log_metrics, metrics, step=step)


def log_summary(metrics: dict[str, Any]):
"""
Log summary metrics to the global logger.

:param metrics: Metrics to log
"""
global _global_logger
if _global_logger is None:
raise RuntimeError("No global logger set")
_global_logger.log_summary(metrics)

@typing.overload
def global_logger() -> "MetricsLogger":
...


@typing.overload
def global_logger(logger: "MetricsLogger") -> contextlib.AbstractContextManager:
"""Context manager for setting the global logger"""
...


def global_logger(logger: Optional["MetricsLogger"] = None) -> Union["MetricsLogger", contextlib.AbstractContextManager]:
"""
Get or set the global logger.

:param logger: If provided, sets the global logger to this value.
:return: The global logger, or a context manager for setting the global logger.
"""
global _global_logger
if logger is None:
if _global_logger is None:
raise RuntimeError("No global logger set")
return _global_logger
else:
return _GlobalLoggerContextManager(logger)


class MetricsLogger(abc.ABC):
"""
A logger for logging metrics to some backend(s).
Meant to be used with the [global_logger][] context manager, but can also be used directly.
"""
@abc.abstractmethod
def init(self, run_id: Optional[str]):
pass
@@ -33,7 +96,6 @@ def init(self, run_id: Optional[str]):
def log_hyperparameters(self, hparams: dict[str, Any]):
pass


@abc.abstractmethod
def log(self, metrics: dict[str, Any], *, step):
"""
@@ -49,7 +111,47 @@ def log_summary(self, metrics: dict[str, Any]):
def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None):
pass

class WandbLoggerSink(LoggerSink):

class CompositeLogger(MetricsLogger):
def __init__(self, loggers: List[MetricsLogger]):
self.loggers = loggers

def init(self, run_id: Optional[str]):
for logger in self.loggers:
logger.init(run_id)

def log_hyperparameters(self, hparams: dict[str, Any]):
for logger in self.loggers:
logger.log_hyperparameters(hparams)

def log(self, metrics: dict[str, Any], *, step):
for logger in self.loggers:
logger.log(metrics, step=step)

def log_summary(self, metrics: dict[str, Any]):
for logger in self.loggers:
logger.log_summary(metrics)

def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None):
for logger in self.loggers:
logger.log_artifact(artifact, name=name, type=type)


class _GlobalLoggerContextManager(contextlib.AbstractContextManager):
def __init__(self, logger: "MetricsLogger"):
self.logger = logger

def __enter__(self):
global _global_logger
self.old_logger = _global_logger
_global_logger = self.logger

def __exit__(self, exc_type, exc_val, exc_tb):
global _global_logger
_global_logger = self.old_logger


class WandbLogger(MetricsLogger):
def __init__(self, config: 'WandbConfig'):
self.config = config
self._run = None
@@ -78,7 +180,7 @@ def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[s
self._run.log_artifact(artifact, name=name, type=type)


class TensorboardLoggerSink(LoggerSink):
class TensorboardLogger(MetricsLogger):

def __init__(self, logdir: Union[str, Path]):
self.logdir = logdir
@@ -102,14 +204,11 @@ def log_summary(self, metrics: dict[str, Any]):
for k, v in metrics.items():
self.writer.add_scalar(k, v, 0)


def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None):
pylogger.warning("TensorboardLoggerSink does not support logging artifacts yet")
pass





def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None):
if isinstance(opt_state, MultiStepsState):
opt_state = opt_state.inner_opt_state
@@ -121,10 +220,10 @@ def wrap_key(key):

if hasattr(opt_state, "hyperparams"):
params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()}
wandb.log(params, step=step)
log_metrics(params, step=step)


def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None:
def init_logging(path: Union[str, Path], level: int = pylogging.INFO) -> None:
"""
Initialize logging.Logger with the appropriate name, console, and file handlers.

@@ -147,14 +246,19 @@ def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None:

def save_xla_dumps_to_wandb(initial_time: float):
import os
if not is_wandb_available():
pylogger.warning("Wandb is not available, so we can't save XLA dumps")
return

import wandb

# attempt to parse xla_flags to see if we're dumping assembly files
flags = os.getenv("XLA_FLAGS", None)
if flags is not None and "xla_dump_to" in flags:
# parse the path
# this isn't robust to quotes
path = flags.split("xla_dump_to=")[1].split(" ")[0]
logger.info(f"Found xla_dump_to={path}, logging to wandb")
pylogger.info(f"Found xla_dump_to={path}, logging to wandb")
if wandb.run:
# only want to save the files that were generated during this run
# XLA_FLAGS has to be set before the first jax call, so we can't just set it in the middle of the run
@@ -166,7 +270,7 @@ def include_file(path: str):

wandb.run.log_code(root=path, name="xla_dumps", include_fn=include_file)
else:
logger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb")
pylogger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb")


@contextlib.contextmanager
@@ -184,20 +288,14 @@ def fn():
end = time.time()


@contextlib.contextmanager
def log_time_to_wandb(name: str, *, step=None):
with capture_time() as fn:
yield fn
wandb.log({name: fn()}, step=step)


def jittable_wandb_log(data, *, step=None):
"""uses jax effect callback to log to wandb from the host"""
if is_wandb_available():
jax.debug.callback(wandb.log, data, step=step)


def is_wandb_available():
try:
import wandb
except ImportError:
return False
return wandb is not None and wandb.run is not None


@@ -278,7 +376,7 @@ def init(self, run_id: Optional[str], hparams=None, **extra_hparams):

other_settings = dict()
if code_dir is not None:
logger.info(f"Setting wandb code_dir to {code_dir}")
pylogger.info(f"Setting wandb code_dir to {code_dir}")
other_settings["code_dir"] = code_dir
other_settings["git_root"] = code_dir
# for some reason, wandb isn't populating the git commit, so we do it here
@@ -287,7 +385,7 @@ def init(self, run_id: Optional[str], hparams=None, **extra_hparams):
other_settings["git_commit"] = repo.head.commit.hexsha
hparams_to_save["git_commit"] = repo.head.commit.hexsha
except (NoSuchPathError, InvalidGitRepositoryError):
logger.warning(f"Could not find git repo at {code_dir}")
pylogger.warning(f"Could not find git repo at {code_dir}")
pass

r = wandb.init(
@@ -324,7 +422,7 @@ def init(self, run_id: Optional[str], hparams=None, **extra_hparams):
for k, v in metadata_to_share.items():
setattr(r, k, v)

logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}")
pylogger.info(f"Synced wandb run information from process 0: {r.name} {r.id}")

if dataclasses.is_dataclass(hparams):
with tempfile.TemporaryDirectory() as tmpdir:
@@ -370,7 +468,7 @@ def _infer_experiment_git_root() -> Optional[str | os.PathLike[str]]:
top_git_root = repo.working_dir
break
except (NoSuchPathError, InvalidGitRepositoryError):
logger.debug(f"Skipping {dirname} since it's not a git root")
pylogger.debug(f"Skipping {dirname} since it's not a git root")
pass
return top_git_root
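
Taken together, this file replaces direct wandb calls with a pluggable MetricsLogger plus a process-global sink. Below is a minimal usage sketch of how the pieces compose; it is illustrative only, not part of the diff, and it assumes these names are importable from levanter.logging exactly as added above:

# Hypothetical usage sketch; uses only classes visible in this diff
# (TensorboardLogger, CompositeLogger, global_logger, log_metrics, log_summary).
from levanter.logging import (
    CompositeLogger,
    TensorboardLogger,
    global_logger,
    log_metrics,
    log_summary,
)

tracker = CompositeLogger([TensorboardLogger("logs/tb")])
tracker.init(run_id="demo-run")

# global_logger(logger) returns a context manager that installs the logger as
# the process-global sink; the module-level log_metrics/log_summary route to it.
with global_logger(tracker):
    for step in range(3):
        log_metrics({"train/loss": 1.0 / (step + 1)}, step=step)
    log_summary({"final/loss": 1.0 / 3})

This indirection is what lets call sites such as log_optimizer_hyperparams above emit metrics without knowing which backend (wandb, TensorBoard, or both) is active.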

4 changes: 2 additions & 2 deletions src/levanter/main/cache_dataset.py
@@ -8,7 +8,7 @@
from levanter.data.shard_cache import RichMetricsMonitor, WandbMetricsMonitor, build_cache
from levanter.data.text import BatchTokenizer, LMDatasetConfig
from levanter.distributed import RayConfig
from levanter.logging import init_logger
from levanter.logging import init_logging


logger = logging.getLogger(__name__)
@@ -22,7 +22,7 @@ class RayCachedLMDatasetConfig(LMDatasetConfig, RayConfig):
@levanter.config.main()
def main(args: RayCachedLMDatasetConfig):
"""Caches two different kinds of datasets. It can cache a dataset from a list of urls, or a dataset from a hf dataset"""
init_logger("cache_dataset.log")
init_logging("cache_dataset.log")
args.initialize()

tokenizer = args.the_tokenizer
2 changes: 1 addition & 1 deletion src/levanter/trainer.py
@@ -608,7 +608,7 @@ def _initialize_jax_config(self):

def _initialize_logging(self):
self.log_dir.mkdir(parents=True, exist_ok=True)
levanter.logging.init_logger(self.log_dir / f"{self.id}.log")
levanter.logging.init_logging(self.log_dir / f"{self.id}.log")

def _maybe_set_id(self):
# always do this so we don't get weird hangs if the id isn't set right
5 changes: 5 additions & 0 deletions src/levanter/utils/jax_utils.py
@@ -41,6 +41,11 @@ def use_cpu_device():
yield


def is_inside_jit():
"""Returns True if we're currently inside a jit"""
return isinstance(jnp.zeros(()), jax.core.Tracer)


def flops_estimate(fn, *args):
"""Estimates the flop count of a function using XLA/HLO fanciness. See
https://github.com/google/flax/discussions/1854"""
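
A short, hedged sketch of how this helper could pair with the logging changes above: is_inside_jit relies on the fact that, while a function is being traced under jax.jit, even a fresh jnp.zeros(()) comes back as a Tracer. The wrapper below is my own illustration, not code from this PR, and it assumes a global logger has already been installed (for example via the global_logger context manager shown earlier):

# Hypothetical example: choose between the host-callback path and a direct
# call depending on whether we are currently under jit tracing.
import jax
import jax.numpy as jnp

from levanter.logging import jit_log_metrics, log_metrics
from levanter.utils.jax_utils import is_inside_jit


def log_anywhere(metrics, *, step):
    if is_inside_jit():
        # inside a traced computation: defer to the host via jax.debug.callback
        jit_log_metrics(metrics, step=step)
    else:
        log_metrics(metrics, step=step)


@jax.jit
def train_step(x):
    loss = jnp.mean(x * x)
    log_anywhere({"train/loss": loss}, step=0)
    return loss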