diff --git a/rope.ipynb b/rope.ipynb index 8213213..593439b 100644 --- a/rope.ipynb +++ b/rope.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 4, + "execution_count": 156, "id": "1bd666a7-0ad1-4ae7-a56e-43429a1228d8", "metadata": { "tags": [] @@ -21,12 +21,13 @@ "from dreem.datasets import SleapDataset\n", "from dreem.models.transformer import *\n", "from dreem.models import VisualEncoder\n", - "from dreem.models import GlobalTrackingTransformer" + "from dreem.models import GlobalTrackingTransformer\n", + "from dreem.models.gtr_runner import GTRRunner" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 130, "id": "a8736593-71f7-4ab6-a594-eb52d2fd94ac", "metadata": { "tags": [] @@ -282,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 131, "id": "fc8aa9bf-7e83-4fa6-892e-8a7703777f95", "metadata": { "tags": [] @@ -296,35 +297,12 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 132, "id": "4903f6c3-1cc8-412b-b988-ebeb2757c3b7", "metadata": { "tags": [] }, - "outputs": [ - { - "ename": "FileNotFoundError", - "evalue": "[Errno 2] Unable to open file (unable to open file: name = '/home/jovyan/talmolab-smb/datasets/mot/microscopy/airyscan_proofread/Final/dreem-train/10-1.slp', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[7], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m train_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/home/jovyan/talmolab-smb/datasets/mot/microscopy/airyscan_proofread/Final/dreem-train\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# train_path = \"/Users/mustafashaikh/dreem-data/dreem-train\"\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mSleapDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_path\u001b[49m\u001b[43m,\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m10-1.slp\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_path\u001b[49m\u001b[43m,\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m10-1.mp4\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcrop_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m64\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclip_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m32\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43manchors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcentroid\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/GitHub/dreem/dreem/datasets/sleap_dataset.py:108\u001b[0m, in \u001b[0;36mSleapDataset.__init__\u001b[0;34m(self, slp_files, video_files, padding, crop_size, anchors, chunk, clip_length, mode, handle_missing, augmentations, n_chunks, seed, verbose)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;241m=\u001b[39m verbose\n\u001b[1;32m 106\u001b[0m \u001b[38;5;66;03m# if self.seed is not None:\u001b[39;00m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;66;03m# np.random.seed(self.seed)\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlabels \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[43msio\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_slp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mslp_file\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mslp_file\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mslp_files\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvideos \u001b[38;5;241m=\u001b[39m [imageio\u001b[38;5;241m.\u001b[39mget_reader(vid_file) \u001b[38;5;28;01mfor\u001b[39;00m vid_file \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvid_files]\n\u001b[1;32m 110\u001b[0m \u001b[38;5;66;03m# do we need this? would need to update with sleap-io\u001b[39;00m\n\u001b[1;32m 111\u001b[0m \n\u001b[1;32m 112\u001b[0m \u001b[38;5;66;03m# for label in self.labels:\u001b[39;00m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;66;03m# label.remove_empty_instances(keep_empty_frames=False)\u001b[39;00m\n", - "File \u001b[0;32m~/Documents/GitHub/dreem/dreem/datasets/sleap_dataset.py:108\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;241m=\u001b[39m verbose\n\u001b[1;32m 106\u001b[0m \u001b[38;5;66;03m# if self.seed is not None:\u001b[39;00m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;66;03m# np.random.seed(self.seed)\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlabels \u001b[38;5;241m=\u001b[39m [\u001b[43msio\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_slp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mslp_file\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m slp_file \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mslp_files]\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvideos \u001b[38;5;241m=\u001b[39m [imageio\u001b[38;5;241m.\u001b[39mget_reader(vid_file) \u001b[38;5;28;01mfor\u001b[39;00m vid_file \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvid_files]\n\u001b[1;32m 110\u001b[0m \u001b[38;5;66;03m# do we need this? would need to update with sleap-io\u001b[39;00m\n\u001b[1;32m 111\u001b[0m \n\u001b[1;32m 112\u001b[0m \u001b[38;5;66;03m# for label in self.labels:\u001b[39;00m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;66;03m# label.remove_empty_instances(keep_empty_frames=False)\u001b[39;00m\n", - "File \u001b[0;32m~/miniforge3/envs/dreem/lib/python3.11/site-packages/sleap_io/io/main.py:19\u001b[0m, in \u001b[0;36mload_slp\u001b[0;34m(filename)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mload_slp\u001b[39m(filename: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Labels:\n\u001b[1;32m 11\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Load a SLEAP dataset.\u001b[39;00m\n\u001b[1;32m 12\u001b[0m \n\u001b[1;32m 13\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;124;03m The dataset as a `Labels` object.\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mslp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread_labels\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniforge3/envs/dreem/lib/python3.11/site-packages/sleap_io/io/slp.py:1011\u001b[0m, in \u001b[0;36mread_labels\u001b[0;34m(labels_path)\u001b[0m\n\u001b[1;32m 1002\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mread_labels\u001b[39m(labels_path: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Labels:\n\u001b[1;32m 1003\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Read a SLEAP labels file.\u001b[39;00m\n\u001b[1;32m 1004\u001b[0m \n\u001b[1;32m 1005\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1009\u001b[0m \u001b[38;5;124;03m The processed `Labels` object.\u001b[39;00m\n\u001b[1;32m 1010\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1011\u001b[0m tracks \u001b[38;5;241m=\u001b[39m \u001b[43mread_tracks\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1012\u001b[0m videos \u001b[38;5;241m=\u001b[39m read_videos(labels_path)\n\u001b[1;32m 1013\u001b[0m skeletons \u001b[38;5;241m=\u001b[39m read_skeletons(labels_path)\n", - "File \u001b[0;32m~/miniforge3/envs/dreem/lib/python3.11/site-packages/sleap_io/io/slp.py:448\u001b[0m, in \u001b[0;36mread_tracks\u001b[0;34m(labels_path)\u001b[0m\n\u001b[1;32m 439\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mread_tracks\u001b[39m(labels_path: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mlist\u001b[39m[Track]:\n\u001b[1;32m 440\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Read `Track` dataset in a SLEAP labels file.\u001b[39;00m\n\u001b[1;32m 441\u001b[0m \n\u001b[1;32m 442\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 446\u001b[0m \u001b[38;5;124;03m A list of `Track` objects.\u001b[39;00m\n\u001b[1;32m 447\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 448\u001b[0m tracks \u001b[38;5;241m=\u001b[39m [json\u001b[38;5;241m.\u001b[39mloads(x) \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m \u001b[43mread_hdf5_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtracks_json\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m]\n\u001b[1;32m 449\u001b[0m track_objects \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 450\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m track \u001b[38;5;129;01min\u001b[39;00m tracks:\n", - "File \u001b[0;32m~/miniforge3/envs/dreem/lib/python3.11/site-packages/sleap_io/io/utils.py:21\u001b[0m, in \u001b[0;36mread_hdf5_dataset\u001b[0;34m(filename, dataset)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mread_hdf5_dataset\u001b[39m(filename: \u001b[38;5;28mstr\u001b[39m, dataset: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m np\u001b[38;5;241m.\u001b[39mndarray:\n\u001b[1;32m 12\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Read data from an HDF5 file.\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \n\u001b[1;32m 14\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;124;03m The data as an array.\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43mh5py\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mFile\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mr\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 22\u001b[0m data \u001b[38;5;241m=\u001b[39m f[dataset][()]\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n", - "File \u001b[0;32m~/miniforge3/envs/dreem/lib/python3.11/site-packages/h5py/_hl/files.py:562\u001b[0m, in \u001b[0;36mFile.__init__\u001b[0;34m(self, name, mode, driver, libver, userblock_size, swmr, rdcc_nslots, rdcc_nbytes, rdcc_w0, track_order, fs_strategy, fs_persist, fs_threshold, fs_page_size, page_buf_size, min_meta_keep, min_raw_keep, locking, alignment_threshold, alignment_interval, meta_block_size, **kwds)\u001b[0m\n\u001b[1;32m 553\u001b[0m fapl \u001b[38;5;241m=\u001b[39m make_fapl(driver, libver, rdcc_nslots, rdcc_nbytes, rdcc_w0,\n\u001b[1;32m 554\u001b[0m locking, page_buf_size, min_meta_keep, min_raw_keep,\n\u001b[1;32m 555\u001b[0m alignment_threshold\u001b[38;5;241m=\u001b[39malignment_threshold,\n\u001b[1;32m 556\u001b[0m alignment_interval\u001b[38;5;241m=\u001b[39malignment_interval,\n\u001b[1;32m 557\u001b[0m meta_block_size\u001b[38;5;241m=\u001b[39mmeta_block_size,\n\u001b[1;32m 558\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n\u001b[1;32m 559\u001b[0m fcpl \u001b[38;5;241m=\u001b[39m make_fcpl(track_order\u001b[38;5;241m=\u001b[39mtrack_order, fs_strategy\u001b[38;5;241m=\u001b[39mfs_strategy,\n\u001b[1;32m 560\u001b[0m fs_persist\u001b[38;5;241m=\u001b[39mfs_persist, fs_threshold\u001b[38;5;241m=\u001b[39mfs_threshold,\n\u001b[1;32m 561\u001b[0m fs_page_size\u001b[38;5;241m=\u001b[39mfs_page_size)\n\u001b[0;32m--> 562\u001b[0m fid \u001b[38;5;241m=\u001b[39m \u001b[43mmake_fid\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muserblock_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfapl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfcpl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mswmr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mswmr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 564\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(libver, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 565\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_libver \u001b[38;5;241m=\u001b[39m libver\n", - "File \u001b[0;32m~/miniforge3/envs/dreem/lib/python3.11/site-packages/h5py/_hl/files.py:235\u001b[0m, in \u001b[0;36mmake_fid\u001b[0;34m(name, mode, userblock_size, fapl, fcpl, swmr)\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m swmr \u001b[38;5;129;01mand\u001b[39;00m swmr_support:\n\u001b[1;32m 234\u001b[0m flags \u001b[38;5;241m|\u001b[39m\u001b[38;5;241m=\u001b[39m h5f\u001b[38;5;241m.\u001b[39mACC_SWMR_READ\n\u001b[0;32m--> 235\u001b[0m fid \u001b[38;5;241m=\u001b[39m \u001b[43mh5f\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflags\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfapl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfapl\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mr+\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 237\u001b[0m fid \u001b[38;5;241m=\u001b[39m h5f\u001b[38;5;241m.\u001b[39mopen(name, h5f\u001b[38;5;241m.\u001b[39mACC_RDWR, fapl\u001b[38;5;241m=\u001b[39mfapl)\n", - "File \u001b[0;32mh5py/_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[0;34m()\u001b[0m\n", - "File \u001b[0;32mh5py/_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[0;34m()\u001b[0m\n", - "File \u001b[0;32mh5py/h5f.pyx:102\u001b[0m, in \u001b[0;36mh5py.h5f.open\u001b[0;34m()\u001b[0m\n", - "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] Unable to open file (unable to open file: name = '/home/jovyan/talmolab-smb/datasets/mot/microscopy/airyscan_proofread/Final/dreem-train/10-1.slp', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)" - ] - } - ], + "outputs": [], "source": [ "# get sample crops from training data to pass through the network\n", "train_path = \"/home/jovyan/talmolab-smb/datasets/mot/microscopy/airyscan_proofread/Final/dreem-train\"\n", @@ -335,14 +313,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 133, "id": "27bfdb50-2eee-4207-8a16-6481a9905e90", "metadata": { "tags": [] }, "outputs": [], "source": [ - "# get a list of all instances; this is the format that the model pipeline uses as input data\n", + "# get a list of all instances in the first clip; this is the format that the model pipeline uses as input data\n", "ref_instances = []\n", "for frame in data[0]:\n", " for instance in frame.instances:\n", @@ -351,7 +329,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 134, "id": "8ea441b2-b12a-4f10-8821-aef889a063ba", "metadata": { "tags": [] @@ -365,7 +343,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 135, "id": "b863dbfe-d9fc-4ed1-bf97-3f304d3d03a6", "metadata": { "collapsed": true, @@ -378,10 +356,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 8, + "execution_count": 135, "metadata": {}, "output_type": "execute_result" }, @@ -403,7 +381,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 136, "id": "8b17fdb7", "metadata": { "tags": [] @@ -417,7 +395,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 137, "id": "7999fcef-953b-42cf-927c-f3b617f68157", "metadata": { "tags": [] @@ -460,7 +438,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 138, "id": "e299e8a0-61eb-4eee-901c-49aa7e678b3b", "metadata": { "tags": [] @@ -517,7 +495,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 139, "id": "75ec8cab-25b9-4e9e-a64a-b5dbe00cc81a", "metadata": { "tags": [] @@ -538,32 +516,61 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 140, "id": "f0823cf1-2a35-4920-a62e-896bd9dbb078", "metadata": { "tags": [] }, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'ref_instances' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[3], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# input data for transformer\u001b[39;00m\n\u001b[1;32m 2\u001b[0m ref_features \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat(\n\u001b[0;32m----> 3\u001b[0m [instance\u001b[38;5;241m.\u001b[39mfeatures \u001b[38;5;28;01mfor\u001b[39;00m instance \u001b[38;5;129;01min\u001b[39;00m \u001b[43mref_instances\u001b[49m], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 4\u001b[0m )\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;66;03m# create transformer instance to test embeddings \u001b[39;00m\n\u001b[1;32m 7\u001b[0m tfmr \u001b[38;5;241m=\u001b[39m Transformer()\n", - "\u001b[0;31mNameError\u001b[0m: name 'ref_instances' is not defined" - ] - } - ], + "outputs": [], "source": [ - "# input data for transformer\n", - "ref_features = torch.cat(\n", - " [instance.features for instance in ref_instances], dim=0\n", - " ).unsqueeze(0)\n", - "\n", "# create transformer instance to test embeddings \n", - "tfmr = Transformer()\n" + "tfmr = Transformer()" + ] + }, + { + "cell_type": "code", + "execution_count": 143, + "id": "5e0b9d31-34be-40f8-91dc-b91d59aee170", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "assoc = tfmr(ref_instances)" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "id": "9f29ca35-9ff2-4e9a-bba0-37a3a14ad522", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "gtr = GTRRunner()" + ] + }, + { + "cell_type": "code", + "execution_count": 160, + "id": "0aa3876a-6246-4d02-80a5-013d382f6d38", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "metrics = gtr._shared_eval_step(data[0],\"train\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aee0d129-83f2-4f76-b452-132391554b4c", + "metadata": {}, + "outputs": [], + "source": [ + "metrics" ] } ], @@ -571,7 +578,7 @@ "kernelspec": { "display_name": "dreem", "language": "python", - "name": "python3" + "name": "dreem" }, "language_info": { "codemirror_mode": {