Correct mkdocs warnings, and a few pyright errors
Reduce pyright errors from ~325 to 290.

- Fix docstring indentation to correct Mkdocs (Griffe) warnings.
- Fix broken mkdocstrings crossrefs in example notebooks.
- Rename `spec` to `stage` in `AbstractStagedModel.__call__`.
- Add `ModelStageCallable` and `OtherStageCallable` protocols for `ModelStage.callable`.
- Get rid of `in_where` properties that override the `in_where` field of `AbstractIntervenor`; Pyright doesn't like this (see the sketch below this list).
- Correct a few small typing issues.
    - Improve type handling in `_convert_feedback_spec`
    - Add `property` decorator in a couple of places where it's missing
    - Fix `Mapping` annotations lacking two arguments
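
As a sketch of the `in_where` bullet above (hypothetical stand-in classes, not code from this commit; the real class is `AbstractIntervenor` in `feedbax.intervene`): pyright rejects a property that overrides an attribute the parent class declares as a plain field.

```python
from collections.abc import Callable
from typing import Any


class IntervenorSketch:
    # Stand-in for AbstractIntervenor, which declares `in_where` as a field.
    in_where: Callable[[Any], Any]
    out_where: Callable[[Any], Any]


class PropertyOverride(IntervenorSketch):
    # The removed pattern: pyright reports an incompatible override here,
    # since a `property` is not assignable to the base class's plain
    # `Callable` attribute.
    out_where: Callable[[Any], Any] = lambda state: state

    @property
    def in_where(self) -> Callable[[Any], Any]:
        return self.out_where


class ExplicitFields(IntervenorSketch):
    # The pattern this commit moves to: `in_where` stays an ordinary field,
    # specified explicitly even when it duplicates `out_where`.
    in_where: Callable[[Any], Any] = lambda state: state
    out_where: Callable[[Any], Any] = lambda state: state
```

The `Mapping` fix in the last bullet is the same kind of change at the annotation level: a bare `Mapping` becomes, e.g., `Mapping[str, Any]`, so the checker knows the key and value types (the exact type arguments here are placeholders).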
mlprt committed Feb 28, 2024
1 parent b473e1b commit bd252fa
Showing 30 changed files with 343 additions and 335 deletions.
2 changes: 2 additions & 0 deletions .git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Migrate code style to Black
+b473e1b077050ac7ca4961231c17e9eeedb72999
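
(To make a local `git blame` respect this file, run `git config blame.ignoreRevsFile .git-blame-ignore-revs`, or pass `--ignore-revs-file` per invocation; standard Git behavior, not part of this diff.)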
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -38,5 +38,6 @@
}
],
"explorer.autoReveal": false,
-"scm.autoReveal": false
+"scm.autoReveal": false,
+"files.trimTrailingWhitespace": true
}
2 changes: 2 additions & 0 deletions docs/api/intervene.md
@@ -1,5 +1,7 @@
# Interventions

+::: feedbax.intervene.CurlFieldParams
+
::: feedbax.intervene.CurlField

::: feedbax.intervene.AddNoise
40 changes: 21 additions & 19 deletions docs/index.md
@@ -10,7 +10,7 @@ Feedbax makes it easy to:
- alter the activity of a single unit in a neural network;
- perturb the sensory feedback received by a network;
- add any kind of noise to any part of a model's state;
-- swap out components of models, and write new components;
+- swap out components of models, and write new components;
- [train replicates](/feedbax/examples/4_vmap) of a model in parallel;
- specify which parts of the model are [trainable](/feedbax/examples/1_train/#selecting-part-of-the-model-to-train), or available to a controller as feedback;
<!-- - track the progress of a training run in Tensorboard. -->
@@ -19,42 +19,44 @@ Feedbax was designed for feedback control of biomechanical models by neural networks.

## Feedbax is a JAX library

-Feedbax uses JAX and [Equinox](https://docs.kidger.site/equinox/) for the structural [advantages](/feedbax/examples/pytrees/) they provide.
+Feedbax uses JAX and [Equinox](https://docs.kidger.site/equinox/), because of their [features](/feedbax/examples/pytrees/), which are very convenient for scientific analyses. Still, if you've never used JAX before, you might find it (and Feedbax) a little strange at first.
+<!--
+One disadvantage of JAX is a lack of GPU support on Windows, though it is possible to use the GPU through the Windows Subsystem for Linux (WSL). -->

-One disadvantage of JAX is a lack of GPU support on Windows, though it is possible to use the GPU through the Windows Subsystem for Linux (WSL).
-
-If you prefer to use PyTorch, check out `MotorNet`!
+For a library that's similar to Feedbax but written in PyTorch, please check out [`MotorNet`](https://github.com/OlivierCodol/MotorNet)!

## Installation

Pip TODO.

`python -m pip install`

### Installing from source

## Development

I've developed Feedbax over the last few months, as I've learned JAX. My short-term objective has been to serve my own use case—graduate research in the neuroscience of motor control—but I have also tried to make design choices in pursuit of reusability and generality.

-By making the library open source now, I hope to receive some feedback about those decisions. To make that easier I've created GitHub [issues](https://github.com/mlprt/feedbax/issues) documenting my choices and uncertainties. The issues largely belong to one of a few categories:
+By making the library open source now, I hope to receive some feedback about those decisions. To make that easier I've created GitHub [issues](https://github.com/mlprt/feedbax/issues) documenting my choices and uncertainties. The issues largely fall into a few categories:

-1. Structural issues: Perhaps some of the abstractions I've chosen are clumsy. Depending on the effort involved, I'm still willing to initiate major structural changes at this point.
-2. Typing issues: I've tried to err on the side of typing things a bit too much. At least it may still serve as documentation? I'm still learning the limits of typing in Python.
-3. Feature issues: There are many small improvements and additions that could be made, especially pre-built models and tasks.
-4.
-
-If you are a researcher in optimal control or reinforcement learning, I'd be particularly interested to hear what you think about
-
-- whether you foresee any problems in applying RL formalisms given the way Feedbax is modularized
+1. Structure: Some of the abstractions I've chosen are probably clumsy. It would be good to know about that, at this point. Maybe we can make some changes for the better! In approximate order of significance: #19, #12, #1, #21.
+2. Features: There are many small additions that could be made, especially to the pre-built models and tasks. There are also a few major improvements which I am anticipating in the near future, such as *online learning* (#21).
+3. Typing: Typing in Feedbax is a mess, at the moment. I have been learning to use the typing system recently. However, I haven't been constraining myself with type checker errors. I know I've done some things that probably won't work. See issues. (#7, #8, #9, #11)

If you are an experienced Python or JAX user:

- Do
- Any low-hanging fruit re: abstraction
- Anything obviously clumsy I am doing with PyTrees
- Typing
- Performance issues

!!! Success ""
    Ask questions or make suggestions about any part of the code or documentation!
+If you are a researcher in optimal control or reinforcement learning, I'd be particularly interested to hear what you think about
+
+- whether you foresee any problems in applying RL formalisms given the way Feedbax is modularized
+
+!!! Note
+    For comments on the documentation, I have specifically enabled the Giscus commenting system so that GitHub users can comment on pages directly. You can also participate via the Discussions tab on GitHub.

-## Acknowledgments 
+## Acknowledgments

Special thanks to [Patrick Kidger](https://github.com/patrick-kidger), whose JAX libraries and their documentation often serve as examples to me.

16 changes: 8 additions & 8 deletions examples/5_model_stages.ipynb
@@ -89,7 +89,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
-"First, notice that `SimpleFeedback` is an Equinox `Module` subclass. It's not obvious from this code alone, but [`Mechanics`](feedbax.mechanics.Mechanics) and [`Channel`](feedbax.channel.Channel) are also `Module` subclasses, with their own parameters and submodules. \n",
+"First, notice that `SimpleFeedback` is an Equinox `Module` subclass. It's not obvious from this code alone, but [`Mechanics`][feedbax.mechanics.Mechanics] and [`Channel`][feedbax.channel.Channel] are also `Module` subclasses, with their own parameters and submodules. \n",
"\n",
"Observe the following about `__call__`. It:\n",
"\n",
@@ -375,7 +375,7 @@
"source": [
"Another advantage of staged models is that it's easy to print out a tree of operations, showing the sequence in which they're performed.\n",
"\n",
-"Feedbax provides the function [`pprint_model_spec`](feedbax.pprint_model_spec) for this purpose."
+"Feedbax provides the function [`pprint_model_spec`][feedbax.pprint_model_spec] for this purpose."
]
},
{
@@ -464,7 +464,7 @@
" net2: SimpleStagedNetwork\n",
" channel: Channel\n",
" intervenors: dict[str, Sequence[AbstractIntervenor]]\n",
-" \n",
+"\n",
" @property\n",
" def model_spec(self) -> OrderedDict[str, ModelStage]:\n",
" return OrderedDict({\n",
@@ -484,7 +484,7 @@
" where_output=lambda state: state.net2,\n",
" ),\n",
" })\n",
-" \n",
+"\n",
" def init(self, *, key: PRNGKeyArray | None = None) -> NetworkLoopState:\n",
" keys = jax.random.split(key, 3)\n",
" return NetworkLoopState(\n",
@@ -512,8 +512,8 @@
"\n",
"\n",
"def setup(\n",
-" net1_hidden_size, \n",
-" net2_hidden_size, \n",
+" net1_hidden_size,\n",
+" net2_hidden_size,\n",
" channel_delay=5,\n",
" channel_noise_std=0.05,\n",
" *,\n",
@@ -531,8 +531,8 @@
" key=key2\n",
" )\n",
" channel = Channel(\n",
-" channel_delay, \n",
-" channel_noise_std, \n",
+" channel_delay,\n",
+" channel_noise_std,\n",
" input_proto=jnp.zeros(net2_hidden_size)\n",
" )\n",
"\n",
48 changes: 18 additions & 30 deletions examples/6_intervening_2.ipynb
@@ -117,7 +117,7 @@
}
],
"source": [
-"import jax \n",
+"import jax\n",
"\n",
"from feedbax.xabdeef import point_mass_nn_simple_reaches\n",
"\n",
@@ -373,7 +373,7 @@
],
"source": [
"_ = plot_reach_trajectories(\n",
-" task.eval(model_clamp_pre, key=key_eval), \n",
+" task.eval(model_clamp_pre, key=key_eval),\n",
" trial_specs=task.validation_trials\n",
")"
]
@@ -403,7 +403,7 @@
],
"source": [
"_ = plot_reach_trajectories(\n",
-" task.eval(model_clamp_post, key=key_eval), \n",
+" task.eval(model_clamp_post, key=key_eval),\n",
" trial_specs=task.validation_trials\n",
")"
]
@@ -452,7 +452,7 @@
"\n",
"# The x and y variables are stored in the same array. We only perturb x,\n",
"# so make a mask that's 1 at x and 0 at y.\n",
-"# (This will still work if we switch impulse_var to 1 (that is, perturb the velocity) \n",
+"# (This will still work if we switch impulse_var to 1 (that is, perturb the velocity)\n",
"# because the velocity array has the same x/y shape as the position array.)\n",
"array_mask = jnp.zeros((2,)).at[impulse_dim].set(1)\n",
"\n",
@@ -522,7 +522,7 @@
"from feedbax.intervene import ConstantInput, schedule_intervenor, TimeSeriesParam\n",
"\n",
"task_fb_impulse, model_fb_impulse = schedule_intervenor(\n",
-" task, model, \n",
+" task, model,\n",
" intervenor=ConstantInput.with_params(\n",
" scale=impulse_amp,\n",
" arrays=array_mask,\n",
@@ -578,7 +578,7 @@
],
"source": [
"plot_reach_trajectories(\n",
-" task_fb_impulse.eval(model_fb_impulse, key=key_eval), \n",
+" task_fb_impulse.eval(model_fb_impulse, key=key_eval),\n",
" trial_specs=task.validation_trials,\n",
" straight_guides=True, # Show dashed lines for \"ideal\" straight reaches\n",
")"
@@ -622,27 +622,27 @@
"class CurlFieldParams(AbstractIntervenorInput):\n",
" \"\"\"Parameters for a curl force field.\"\"\"\n",
" amplitude: float = 0.\n",
-" active: bool = True \n",
-" \n",
+" active: bool = True\n",
+"\n",
"\n",
"class CurlField(AbstractIntervenor[MechanicsState, CurlFieldParams]):\n",
" \"\"\"Apply a curl force field to a mechanical effector.\"\"\"\n",
-" \n",
+"\n",
" params: CurlFieldParams = CurlFieldParams()\n",
-" in_where: Callable[[MechanicsState], Array] = lambda state: state.effector.vel \n",
+" in_where: Callable[[MechanicsState], Array] = lambda state: state.effector.vel\n",
" out_where: Callable[[MechanicsState], Array] = lambda state: state.effector.force\n",
" operation: Callable[[ArrayLike, ArrayLike], ArrayLike] = lambda x, y: x + y\n",
" label: str = \"CurlField\"\n",
-" \n",
+"\n",
" def transform(\n",
-" self, \n",
-" params: CurlFieldParams, \n",
-" substate_in: Array, \n",
-" *, \n",
-" key: Optional[PRNGKeyArray] = None \n",
+" self,\n",
+" params: CurlFieldParams,\n",
+" substate_in: Array,\n",
+" *,\n",
+" key: Optional[PRNGKeyArray] = None\n",
" ) -> Array:\n",
" \"\"\"Transform velocity into curl force.\"\"\"\n",
-" scale = params.amplitude * jnp.array([-1, 1]) \n",
+" scale = params.amplitude * jnp.array([-1, 1])\n",
" return scale * substate_in[..., ::-1]"
]
},
@@ -681,19 +681,7 @@
"\n",
" For many kinds of interventions, the input and the output substates are identical. For example, when adding noise to a state, the input is the substate to which to add noise, and the output is the noise to be added to the same substate.\n",
" \n",
-" If you want to ensure that users cannot specify a separate `in_where` and `out_where`, you can override the `in_where` field with a property, like so:\n",
-" \n",
-" ```python\n",
-" class SomeIntervenor(AbstractIntervenor):\n",
-" out_where: Callable = ...\n",
-" # Other field definitions\n",
-" \n",
-" @property\n",
-" def in_where(self):\n",
-" return self.out_where\n",
-" ```\n",
-" \n",
-" However, this will lead to an error if the user tries to pass an `in_where` keyword argument when constructing the intervenor."
+" For now, this must be explicitly specified."
]
},
{
42 changes: 21 additions & 21 deletions examples/losses.ipynb
@@ -19,7 +19,7 @@
" \n",
" Some loss terms may only be calculated for a subset of time steps. For a reaching task, we might include a loss term that penalizes the square of the velocity but only on the final time step, because we want the point mass to stop at the goal position rather than simply pass through it. It would not make sense to apply this loss term at time steps in the middle of the reach, when the point mass ought to be moving at a non-zero velocity toward the goal!\n",
"\n",
-"Common loss terms, such as [`EffectorPositionLoss`][feedbax.task.EffectorPositionLoss] and [`EffectorFinalVelocityLoss`][feedbax.task.EffectorFinalVelocityLoss], are defined in `feedbax.loss`. A loss function with multiple weighted terms can be defined in an algebraic manner.\n"
+"Common loss terms, such as [`EffectorPositionLoss`][feedbax.loss.EffectorPositionLoss] and [`EffectorFinalVelocityLoss`][feedbax.loss.EffectorFinalVelocityLoss], are defined in `feedbax.loss`. A loss function with multiple weighted terms can be defined in an algebraic manner.\n"
]
},
{
@@ -29,16 +29,16 @@
"outputs": [],
"source": [
"from feedbax.loss import (\n",
-" EffectorFinalVelocityLoss, \n",
-" EffectorPositionLoss, \n",
-" NetworkActivityLoss, \n",
-" NetworkOutputLoss, \n",
+" EffectorFinalVelocityLoss,\n",
+" EffectorPositionLoss,\n",
+" NetworkActivityLoss,\n",
+" NetworkOutputLoss,\n",
")\n",
"\n",
"loss_func = (\n",
-" 1.0 * EffectorPositionLoss() \n",
+" 1.0 * EffectorPositionLoss()\n",
" + 1.0 * EffectorFinalVelocityLoss()\n",
-" + 1e-5 * NetworkOutputLoss() \n",
+" + 1e-5 * NetworkOutputLoss()\n",
" + 1e-5 * NetworkActivityLoss()\n",
")"
]
@@ -71,7 +71,7 @@
" dict(\n",
" effector_position=EffectorPositionLoss(),\n",
" effector_final_velocity=EffectorFinalVelocityLoss(),\n",
-" nn_output=NetworkOutputLoss(), \n",
+" nn_output=NetworkOutputLoss(),\n",
" nn_hidden=NetworkActivityLoss(),\n",
" ),\n",
" weights=dict(\n",
@@ -86,7 +86,7 @@
" [\n",
" EffectorPositionLoss(),\n",
" EffectorFinalVelocityLoss(),\n",
-" NetworkOutputLoss(), \n",
+" NetworkOutputLoss(),\n",
" NetworkActivityLoss(),\n",
" ],\n",
" weights=[1.0, 1.0, 1e-5, 1e-5],\n",
@@ -112,7 +112,7 @@
" [\n",
" EffectorPositionLoss(),\n",
" EffectorFinalVelocityLoss(),\n",
-" NetworkOutputLoss(), \n",
+" NetworkOutputLoss(),\n",
" ],\n",
" weights=[1.0, 1.0, 1e-5],\n",
")\n",
@@ -319,8 +319,8 @@
],
"source": [
"model, train_history = context.train(\n",
-" n_batches=500, \n",
-" batch_size=250, \n",
+" n_batches=500,\n",
+" batch_size=250,\n",
" log_step=125,\n",
" key=jax.random.PRNGKey(1),\n",
")"
@@ -356,7 +356,7 @@
}
],
"source": [
-"import equinox as eqx \n",
+"import equinox as eqx\n",
"\n",
"eqx.tree_pprint(train_history.loss)"
]
@@ -396,7 +396,7 @@
}
],
"source": [
-"from feedbax.plot import plot_losses \n",
+"from feedbax.plot import plot_losses\n",
"\n",
"plot_losses(train_history)"
]
@@ -444,20 +444,20 @@
" label: str = \"effector_position\"\n",
"\n",
" def term(\n",
-" self, \n",
-" states: AbstractState, \n",
+" self,\n",
+" states: AbstractState,\n",
" trial_specs: AbstractTaskTrialSpec,\n",
" ) -> Array:\n",
-" \n",
-" # Sum over length of variable vector \n",
+"\n",
+" # Sum over length of variable vector\n",
" loss = jnp.sum(\n",
-" (states.some_variable - trial_specs.some_target) ** 2, \n",
+" (states.some_variable - trial_specs.some_target) ** 2,\n",
" axis=-1\n",
" )\n",
-" \n",
+"\n",
" # Sum over time (if calculated for multiple time steps)\n",
" loss = jnp.sum(loss, axis=-1)\n",
-" \n",
+"\n",
" return loss"
]
},