Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast-forward dev branch #18

Merged
merged 13 commits into from
Dec 24, 2024
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ To train the fly, try the [distributed RL training script][ray-script], which us
[dmpo]: https://github.com/google-deepmind/acme/tree/master/acme/agents/tf/dmpo
[envs]: https://github.com/TuragaLab/flybody/blob/main/docs/fly-env-examples.ipynb
[ray-script]: https://github.com/TuragaLab/flybody/blob/main/flybody/train_dmpo_ray.py
[paper]: https://www.biorxiv.org/content/10.1101/2024.03.11.584515
[paper]: https://www.biorxiv.org/content/10.1101/2024.03.11.584515v2
[ray]: https://github.com/ray-project/ray
[tf]: https://github.com/tensorflow/tensorflow
[acme]: https://github.com/google-deepmind/acme
Expand Down Expand Up @@ -122,7 +122,8 @@ Follow these steps to install `flybody`:
```
2. Also, for the ML and Ray extensions, `LD_LIBRARY_PATH` may require an update, e.g.:
```bash
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/your/path/to/miniconda3/envs/flybody/lib
CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib
```

3. You may want to run `pytest` to test the main components of the `flybody` installation.
Expand All @@ -134,9 +135,9 @@ See our accompanying [publication][paper]. Thank you for your interest in our fl
title = {Whole-body simulation of realistic fruit fly locomotion with
deep reinforcement learning},
author = {Roman Vaxenburg and Igor Siwanowicz and Josh Merel and Alice A Robie and
Carmen Morrow and Guido Novati and Zinovia Stefanidi and Gwyneth M Card and
Michael B Reiser and Matthew M Botvinick and Kristin M Branson and
Yuval Tassa and Srinivas C Turaga},
Carmen Morrow and Guido Novati and Zinovia Stefanidi and Gert-Jan Both and
Gwyneth M Card and Michael B Reiser and Matthew M Botvinick and
Kristin M Branson and Yuval Tassa and Srinivas C Turaga},
journal = {bioRxiv},
doi = {https://doi.org/10.1101/2024.03.11.584515},
url = {https://www.biorxiv.org/content/10.1101/2024.03.11.584515},
Expand Down
505 changes: 505 additions & 0 deletions docs/controller-reuse-vision-flight.ipynb

Large diffs are not rendered by default.

59 changes: 22 additions & 37 deletions docs/sensory-input-tracking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,25 @@
"import tensorflow_probability as tfp\n",
"from acme import wrappers\n",
"\n",
"from flybody.download_data import figshare_download\n",
"from flybody.fly_envs import flight_imitation\n",
"from flybody.agents.utils_tf import TestPolicyWrapper"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "42428d7b-b9e6-48eb-afc3-62bc00b54a69",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]\n"
]
}
],
"source": [
"# Prevent tensorflow from stealing all the gpu memory.\n",
"physical_devices = tf.config.list_physical_devices('GPU')\n",
Expand All @@ -61,47 +70,26 @@
},
{
"cell_type": "markdown",
"id": "7a306f09-d56d-4eee-a3c5-86e86ec9bf4a",
"id": "cd1bf86b-7744-41bf-abe9-a23dbfabfa65",
"metadata": {},
"source": [
"# Download example data from `figshare`\n",
"## Download the WPG base pattern and a trained flight policy\n",
"\n",
"__Caution__: this cell will create a `flybody-data` directory and download ~3GB of data into it. You may want to delete the downloaded files when you're finished with the notebook."
"This cell will download the required data to local `flybody-data` directory. The `flybody` supplimentary data can also be accessed at <https://doi.org/10.25378/janelia.25309105>"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "00813f67-f7a3-44da-bf31-05b69820cb4e",
"execution_count": 10,
"id": "7e8e7661-0dce-4058-b958-ce57f341150c",
"metadata": {},
"outputs": [],
"source": [
"url_policy = 'https://janelia.figshare.com/ndownloader/files/44815195'\n",
"url_data = 'https://janelia.figshare.com/ndownloader/files/44869654'\n",
"\n",
"base_path = 'flybody-data'\n",
"os.makedirs(base_path, exist_ok=True)\n",
"figshare_download(['flight-imitation-dataset', 'trained-policies'])\n",
"\n",
"# Download and unzip policies.\n",
"path_policies = os.path.join(base_path, 'policies.zip')\n",
"response = requests.get(url_policy)\n",
"response.raise_for_status()\n",
"with open(path_policies, 'wb') as f:\n",
" f.write(response.content)\n",
"with zipfile.ZipFile(path_policies, 'r') as zip_ref:\n",
" for file in zip_ref.namelist():\n",
" if file.startswith('flight/'):\n",
" zip_ref.extract(member=file, path=base_path)\n",
"\n",
"# Donwload and unzip flight trajectories and WPG.\n",
"path_data = os.path.join(base_path, 'data.zip')\n",
"response = requests.get(url_data)\n",
"response.raise_for_status()\n",
"with open(path_data, 'wb') as f:\n",
" f.write(response.content)\n",
"with zipfile.ZipFile(path_data, 'r') as zip_ref:\n",
" zip_ref.extract(member='flight-dataset_saccade-evasion_augmented.hdf5', path=base_path)\n",
" zip_ref.extract(member='wing_pattern_fmech.npy', path=base_path)"
"wpg_path = 'flybody-data/datasets_flight-imitation/wing_pattern_fmech.npy'\n",
"ref_flight_path = 'flybody-data/datasets_flight-imitation/flight-dataset_saccade-evasion_augmented.hdf5'\n",
"flight_policy_path = 'flybody-data/trained-fly-policies/flight'"
]
},
{
Expand Down Expand Up @@ -132,9 +120,6 @@
}
],
"source": [
"wpg_path = os.path.join(base_path, 'wing_pattern_fmech.npy')\n",
"ref_flight_path = os.path.join(base_path, 'flight-dataset_saccade-evasion_augmented.hdf5')\n",
"\n",
"env = flight_imitation(\n",
" ref_path=ref_flight_path,\n",
" wpg_pattern_path=wpg_path,\n",
Expand Down Expand Up @@ -223,7 +208,7 @@
"outputs": [],
"source": [
"# Load flight policy.\n",
"policy = tf.saved_model.load(os.path.join(base_path, 'flight'))\n",
"policy = tf.saved_model.load(flight_policy_path)\n",
"policy = TestPolicyWrapper(policy)"
]
},
Expand Down Expand Up @@ -601,7 +586,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.16"
}
},
"nbformat": 4,
Expand Down
Loading
Loading