Skip to content

Commit

Permalink
Merge pull request #41 from KumarLabJax/update-prediction-format
Browse files Browse the repository at this point in the history
Update prediction format + some bugfixes
  • Loading branch information
SkepticRaven authored Sep 4, 2024
2 parents 1fe62a3 + 9eaeffd commit 271f29d
Show file tree
Hide file tree
Showing 37 changed files with 775 additions and 459 deletions.
99 changes: 32 additions & 67 deletions classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
import pandas as pd

from src import APP_NAME
from src.classifier import Classifier, ClassifierType
from src.classifier import Classifier
from src.types import ClassifierType, ProjectDistanceUnit
from src.cli import cli_progress_bar
from src.feature_extraction.features import IdentityFeatures
from src.pose_estimation import open_pose_file
from src.project import Project, load_training_data, ProjectDistanceUnit
from src.project import Project, load_training_data

DEFAULT_FPS = 30

Expand All @@ -38,19 +39,20 @@ def train_and_classify(
training_file_path: Path,
input_pose_file: Path,
out_dir: Path,
override_classifier: typing.Optional[ClassifierType] = None,
fps=DEFAULT_FPS,
feature_dir: typing.Optional[str] = None):
feature_dir: typing.Optional[str] = None,
cache_window: bool = False):
if not training_file_path.exists():
sys.exit(f"Unable to open training data\n")

classifier = train(training_file_path, override_classifier)
classify_pose(classifier, input_pose_file, out_dir, behavior, fps, feature_dir)
classifier = train(training_file_path)
classify_pose(classifier, input_pose_file, out_dir, classifier.behavior_name, fps, feature_dir, cache_window)


def classify_pose(classifier: Classifier, input_pose_file: Path, out_dir: Path,
behavior: str, fps=DEFAULT_FPS,
feature_dir: typing.Optional[str] = None):
feature_dir: typing.Optional[str] = None,
cache_window: bool = False):
pose_est = open_pose_file(input_pose_file)
pose_stem = get_pose_stem(input_pose_file)

Expand All @@ -70,7 +72,7 @@ def classify_pose(classifier: Classifier, input_pose_file: Path, out_dir: Path,
complete_as_percent=False, suffix='identities')

features = IdentityFeatures(
input_pose_file, curr_id, feature_dir, pose_est, fps=fps, op_settings=classifier_settings
input_pose_file, curr_id, feature_dir, pose_est, fps=fps, op_settings=classifier_settings, cache_window=cache_window
).get_features(classifier_settings['window_size'])
per_frame_features = pd.DataFrame(IdentityFeatures.merge_per_frame_features(features['per_frame']))
window_features = pd.DataFrame(IdentityFeatures.merge_window_features(features['window']))
Expand All @@ -95,69 +97,38 @@ def classify_pose(classifier: Classifier, input_pose_file: Path, out_dir: Path,

print(f"Writing predictions to {out_dir}")

behavior_out_dir = out_dir / Project.to_safe_name(behavior)
behavior_out_dir = out_dir
try:
behavior_out_dir.mkdir(parents=True, exist_ok=True)
except OSError as e:
sys.exit(f"Unable to create output directory: {e}")
behavior_out_path = behavior_out_dir / (pose_stem + '.h5')
behavior_out_path = behavior_out_dir / (pose_stem + '_behavior.h5')

Project.write_predictions(
behavior,
behavior_out_path,
prediction_labels,
prediction_prob,
pose_est
pose_est,
classifier
)


def train(
training_file: Path,
override_classifier: typing.Optional[ClassifierType] = None
) -> Classifier:

try:
loaded_training_data, _ = load_training_data(training_file)
except OSError as e:
sys.exit(f"Unable to open training data\n{e}")

behavior = loaded_training_data['behavior']

classifier = Classifier()
classifier.set_dict_settings(loaded_training_data['settings'])

# Override the classifier type
if override_classifier is not None:
classifier_type = override_classifier
else:
classifier_type = ClassifierType(
loaded_training_data['classifier_type'])

if classifier_type in classifier.classifier_choices():
classifier.set_classifier(classifier_type)
else:
print(f"Specified classifier type ({classifier_type.name}) "
"is unavailable, using default "
f"({classifier.classifier_type.name})")
classifier = Classifier.from_training_file(training_file)
classifier_settings = classifier.project_settings

print("Training classifier for:", behavior)
print("Training classifier for:", classifier.behavior_name)
print(" Classifier Type: "
f"{__CLASSIFIER_CHOICES[classifier.classifier_type]}")
print(f" Window Size: {loaded_training_data['settings']['window_size']}")
print(f" Social: {loaded_training_data['settings']['social']}")
print(f" Balanced Labels: {loaded_training_data['settings']['balance_labels']}")
print(f" Symmetric Behavior: {loaded_training_data['settings']['symmetric_behavior']}")
print(f" CM Units: {loaded_training_data['settings']['cm_units']}")

training_features = classifier.combine_data(loaded_training_data['per_frame'],
loaded_training_data['window'])
classifier.train(
{
'training_data': training_features,
'training_labels': loaded_training_data['labels']
},
behavior,
random_seed=loaded_training_data['training_seed']
)
print(f" Window Size: {classifier_settings['window_size']}")
print(f" Social: {classifier_settings['social']}")
print(f" Balanced Labels: {classifier_settings['balance_labels']}")
print(f" Symmetric Behavior: {classifier_settings['symmetric_behavior']}")
print(f" CM Units: {bool(classifier_settings['cm_units'])}")

return classifier

Expand Down Expand Up @@ -235,6 +206,12 @@ def classify_main():
help="Feature cache dir. If present, look here for features before "
"computing. If features need to be computed, they will be saved here."
)
parser.add_argument(
'--skip-window-cache',
help="Default will cache all features when --feature-dir is provided. Providing this flag will only cache per-frame features, reducing cache size at the cost of needing to re-calculate window features.",
default=False,
action='store_true'
)

args = parser.parse_args(classify_args)

Expand All @@ -243,8 +220,7 @@ def classify_main():

if args.training is not None:
train_and_classify(Path(args.training), in_pose_path, out_dir,
override_classifier=args.classifier,
fps=args.fps, feature_dir=args.feature_dir)
fps=args.fps, feature_dir=args.feature_dir, cache_window=not args.skip_window_cache)
elif args.classifier is not None:

try:
Expand All @@ -268,32 +244,21 @@ def classify_main():
print(f" Social: {classifier_settings['social']}")
print(f" CM Units: {classifier_settings['cm_units']}")

classify_pose(classifier, in_pose_path, out_dir, behavior, fps=args.fps, feature_dir=args.feature_dir)
classify_pose(classifier, in_pose_path, out_dir, behavior, fps=args.fps, feature_dir=args.feature_dir, cache_window=not args.skip_window_cache)


def train_main():
# strip out the 'command' component from sys.argv
train_args = sys.argv[2:]

parser = argparse.ArgumentParser(prog=f"{script_name()} train")
classifier_group = parser.add_argument_group(
"optionally override the classifier specified in the training file:\n"
" (the following options are mutually exclusive)")
exclusive_group = classifier_group.add_mutually_exclusive_group(
required=False)
for classifer_type, classifier_str in __CLASSIFIER_CHOICES.items():
exclusive_group.add_argument(
f"--{classifer_type.name.lower().replace('_', '-')}",
action='store_const', const=classifer_type,
dest='classifier', help=f"Use {classifier_str}"
)
parser.add_argument('training_file',
help=f"Training h5 file exported by {APP_NAME}")
parser.add_argument('out_file',
help="output filename")

args = parser.parse_args(train_args)
classifier = train(args.training_file, args.classifier)
classifier = train(args.training_file)

print(f"Saving trained classifier to '{args.out_file}'")
classifier.save(Path(args.out_file))
Expand Down
24 changes: 20 additions & 4 deletions docs/features/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,19 @@

### Arena Corners

2 features from arena corners
5 features from arena corners

* distance to corner using the convex hull center
* distance to nearest wall using the convex hull center
* distance to arena center using the convex hull center
* bearing to corner using angle of the base neck - nose vector
* bearing to arena center using angle of the base neck - nose vector

### Water Spout (Lixit)

1 feature from lixit
12 features from lixit

* distance from nose to nearest lixit
* distance from each keypoint to nearest lixit

### Food Hopper

Expand Down Expand Up @@ -124,7 +127,9 @@ These features contain ciruclar measurements and need to be treated differently.
* Angles
* Bearings

# Methods of handling Missing Data
# Methods of handling Missing or Infinite Data

Features only propagate NaN (not a number) values forward to indicate missing or invalid data. Infinity and negative infinity values are converted to NaNs. Different parts of the software handle NaN values differently, as described below.

## Classifiers

Expand All @@ -149,3 +154,14 @@ This may have adverse effects for skew and kurtosis estimates, as the window may
## Signal Features

Per-frame features fill missing values with zeros before passing into the FFT.

# Extra Features calculated, but not used in a classifier

## Closest Objects

For calculating distances and bearings to nearby items, sometimes there are multiple items to choose from. For the following objects, we identify which object is closest by using the distance between the current mouse's convex hull centroid and the other object. These features are not available in trained classifiers.

* Closest mouse
* Closest mouse in field of view
* Closest arena corner
* Closest water spout (lixit)
32 changes: 21 additions & 11 deletions docs/user_guide/user_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -425,27 +425,37 @@ one video file.
#### Location

The prediction files are saved
in `<JABS project dir>/jabs/predictions/<behavior_name>/<video_name>.h5` if
in `<JABS project dir>/jabs/predictions/<video_name>.h5` if
they were generated by the JABS GUI. The `classify.py` script saves inference
files in `<out-dir>/<behavior_name>/<video_name>.h5`
files in `<out-dir>/<video_name>_behavior.h5`

#### Contents

The H5 file contains one group, called "predictions". This group contains three
datasets
The H5 file contains one group, called "predictions". This group contains one or more behavior prediction groups. Each behavior prediction group contains 3 datasets and 1 new group.

predictions

- predicted_class
- probabilities
- identity_to_track
- behavior_1
- predicted_class
- probabilities
- identity_to_track
- behavior_2
- ...

The file also has some attributes:

- version: This attribute contains an integer version number, and will be
incremented if an incompatible change is made to the file format.
- source_pose_major_version: integer containing the major version of the pose
file that was used for the prediction
The root file contains the following attributes:

- pose_file: filename of the pose file used during prediction
- pose_hash: blake2b hash of pose file
- version: prediction output version

Each behavior prediction group contains the following attributes:

- classifier_file: filename of the classifier file used to predict
- classifier_hash: blake2b hash of the classifier file
- app_version: JABS application version used to make predictions
- prediction_date: date when predictions were made

##### predicted_class

Expand Down
12 changes: 3 additions & 9 deletions initialize_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import src.pose_estimation
import src.feature_extraction
import src.project
from src.types import ProjectDistanceUnit
from src.cli import cli_progress_bar
from src.video_stream import VideoStream

Expand All @@ -27,15 +28,9 @@ def generate_files_worker(params: dict):
pose_est = project.load_pose_est(
project.video_path(params['video']))

if params['force_pixel_distance'] or project.distance_unit == src.project.ProjectDistanceUnit.PIXEL:
distance_scale_factor = 1
else:
distance_scale_factor = pose_est.cm_per_pixel

features = src.feature_extraction.IdentityFeatures(
params['video'], params['identity'], project.feature_dir, pose_est,
force=params['force'], distance_scale_factor=distance_scale_factor,
extended_features=project.extended_features
force=params['force'], op_settings=project.get_project_defaults()
)

# unlike per frame features, window features are not automatically
Expand Down Expand Up @@ -219,7 +214,6 @@ def feature_job_producer():
'project': project,
'force': args.force,
'window_sizes': window_sizes,
'force_pixel_distance': args.force_pixel_distances
})

# print the initial progress bar with 0% complete
Expand All @@ -245,7 +239,7 @@ def feature_job_producer():
print('\n' + '-' * 70)
if args.force_pixel_distances:
print("computed features using pixel distances")
elif distance_unit == src.project.ProjectDistanceUnit.PIXEL:
elif distance_unit == ProjectDistanceUnit.PIXEL:
print("One or more pose files did not have the cm_per_pixel attribute")
print(" Falling back to using pixel distances")
else:
Expand Down
1 change: 0 additions & 1 deletion src/classifier/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .classifier import ClassifierType
from .classifier import Classifier
Loading

0 comments on commit 271f29d

Please sign in to comment.