This issue is about converting single model steps to iterated models, and how this affects the model's PyTree structure and the references made to its components.
Generally, models (such as `SimpleFeedback`) are defined as a single iteration, and then wrapped in an `Iterator` object -- which is more or less a `jax.lax` loop.
However:

- Almost the entire model PyTree is under `model.step.*`, from the user's perspective. For example, when passing a `where_train` when calling a `TaskTrainer`, users generally need to specify it like `lambda model: model.step.net`.
- Similarly, whenever performing model surgery or the like, most references will be to `model.step.*`.
This differs from the structure of the state PyTree. For example, we have `model.step.net` but `states.net`. This is because `Iterator` adds a time dimension to the arrays in `states`.
In certain cases, we might have a model whose top level is not an `Iterator`, but which we will try to interact with using code that refers to `model.step`.
Currently, all `AbstractModel`s provide a `step` property, which trivially returns `self` when the model is not an `Iterator`. `AbstractIterator` instead returns `self._step`, which is the field to which the rest of the model PyTree is actually assigned.
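The `step` property convention described above can be sketched as follows. These are simplified toy stand-ins, not the library's actual classes:

```python
from dataclasses import dataclass

@dataclass
class Model:
    name: str

    @property
    def step(self):
        # A non-iterated model is its own single step.
        return self

@dataclass
class Iterator:
    _step: Model
    n_steps: int

    @property
    def step(self):
        # Expose the wrapped single-step model.
        return self._step

net = Model("net")
iterated = Iterator(net, n_steps=100)
# Both `net.step` and `iterated.step` refer to the same single-step model,
# so code written against `model.step.*` works in either case.
```

This is what lets downstream code refer to `model.step.*` uniformly, whether or not the model happens to be wrapped.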
Should we be stricter about types, and (say) always assume that `TaskTrainer` is passed a model wrapped in an `Iterator`?
In `TaskTrainer._train_step` we have to get initial states for the model, for all trials in a batch.
We start by vmapping `model.init` to obtain a default state. Since the input state to an iterated model is the same as that to the model step, `Iterator.init` just returns `self.step.init`.
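A hedged sketch of the vmapping step, assuming an `init` that takes a PRNG key and returns a state PyTree (the function and state fields here are hypothetical stand-ins for the real model's method):

```python
import jax
import jax.numpy as jnp

def init(key):
    # A toy default state: effector at the origin, zero network activity.
    del key  # a real `init` might use the key for random initialization
    return {
        "effector_pos": jnp.zeros(2),
        "net_activity": jnp.zeros(50),
    }

batch_size = 8
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
states = jax.vmap(init)(keys)
# Each leaf gains a leading batch dimension: (8, 2) and (8, 50).
```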
After `_train_step` obtains this default state, it modifies parts of it using state initialization data provided for the current batch of training trials (by the `AbstractTask` object). Then, it is necessary to ask the model to make sure that the state is internally consistent.
For example, the `AbstractTask` will typically give an initial position for the effector (e.g. arm endpoint). From the effector position we need to infer and update the mechanical configuration (e.g. joint angles). This is only necessary prior to the first time step for the trials in the batch, after which the states will be internally consistent by virtue of the model's operations. Thus we have a method `AbstractStagedModel.state_consistency_update` which is called once in `_train_step`.
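For concreteness, a consistency update of this kind might look like the following for a two-link planar arm: given only an initial effector position, recover joint angles via inverse kinematics so the rest of the state agrees with it. The link lengths and state fields are illustrative assumptions, not the library's actual model:

```python
import jax.numpy as jnp

L1, L2 = 0.3, 0.3  # assumed link lengths (m)

def consistency_update(effector_pos):
    x, y = effector_pos
    r2 = x**2 + y**2
    # Elbow angle from the law of cosines, clipped for numerical safety.
    cos_elbow = jnp.clip((r2 - L1**2 - L2**2) / (2 * L1 * L2), -1.0, 1.0)
    elbow = jnp.arccos(cos_elbow)
    # Shoulder angle: direction to the endpoint, minus the wedge angle
    # contributed by the bent elbow.
    shoulder = jnp.arctan2(y, x) - jnp.arctan2(
        L2 * jnp.sin(elbow), L1 + L2 * jnp.cos(elbow)
    )
    return {
        "effector_pos": effector_pos,
        "joint_angles": jnp.array([shoulder, elbow]),
    }

state = consistency_update(jnp.array([0.3, 0.3]))
```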
So, do we add `def _state_consistency_update(self): return self.step.state_consistency_update` to `Iterator`, similarly to what we've done with `init`? Currently, `_train_step` calls `model._step.state_consistency_update`.
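The proposed forwarding, mirroring `init`, could be sketched like this with toy stand-in classes (the marker field in the toy state is hypothetical, just to make the delegation observable):

```python
class Step:
    def state_consistency_update(self, state):
        # In the real model, this would e.g. infer joint angles from an
        # effector position; here we just mark the state as consistent.
        return {**state, "consistent": True}

class Iterator:
    def __init__(self, step):
        self._step = step

    def state_consistency_update(self, state):
        # Delegate to the wrapped single-step model, as with `init`.
        return self._step.state_consistency_update(state)

model = Iterator(Step())
state = model.state_consistency_update({"effector_pos": (0.0, 0.0)})
```

With this in place, `_train_step` could call `model.state_consistency_update` without reaching into `model._step`.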
I have considered modifying `TaskTrainer` to handle the model iteration over time, so the user does not explicitly instantiate an `Iterator`, and can refer to `model.*` instead of `model.step.*`. This would make sense in light of `AbstractTask` providing model inputs as trajectories over time, which `Iterator` indexes from using `tree_take` -- such that it does not make sense to use a non-iterated model with `TaskTrainer`. Should this change be adopted?
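Whichever object owns the loop, the iteration itself amounts to scanning a single-step model over time while indexing each timestep's inputs out of a trajectory PyTree (akin to `tree_take`). A hedged sketch with a toy step function:

```python
import jax
import jax.numpy as jnp

def step(state, inputs):
    return state + inputs  # toy single-step dynamics

def iterate(init_state, input_trajectory, n_steps):
    def body(state, t):
        # Index this timestep's inputs from the input trajectory PyTree.
        inputs = jax.tree_util.tree_map(lambda x: x[t], input_trajectory)
        new_state = step(state, inputs)
        return new_state, new_state

    _, states = jax.lax.scan(body, init_state, jnp.arange(n_steps))
    # Note the leading time dimension on `states`, which is why the user
    # sees `states.net` where the model has `model.step.net`.
    return states

states = iterate(jnp.zeros(2), jnp.ones((10, 2)), n_steps=10)
```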
If the model step is no longer a component of an `AbstractIterator` but is merely passed to its `__call__` method, then the iterated model is described as a `Tuple[AbstractIterator, AbstractModel]`, rather than as simply an `AbstractIterator` which is composed of an `AbstractModel`.
In that case, `TaskTrainer` would not need to implement model iteration over time, but could still compose an instance of `AbstractIterator`, so that the user doesn't need to pass around a `Tuple[AbstractIterator, AbstractModel]` when they train the model.
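A minimal sketch of this alternative design, with toy stand-ins: the iterator no longer owns the step model, but receives it at call time, so the trained object is the pair (iterator, model):

```python
class Iterator:
    def __init__(self, n_steps):
        self.n_steps = n_steps

    def __call__(self, model, state):
        # Iterate the externally supplied single-step model.
        for _ in range(self.n_steps):
            state = model(state)
        return state

# A toy single-step "model": doubling the (scalar) state.
double = lambda state: state * 2

final = Iterator(n_steps=3)(double, 1)
```

Here the iterator holds no model parameters of its own, so serializing or training the model means handling the `(Iterator, model)` pair rather than a single composed object.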