Skip to content

Commit

Permalink
object-specific active sampler learning
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsilver committed Dec 15, 2023
1 parent 33b1d21 commit e039075
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 10 deletions.
98 changes: 88 additions & 10 deletions predicators/approaches/active_sampler_learning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, List, Optional, \
Sequence, Set, Tuple
from typing import Type as TypingType

import dill as pkl
import numpy as np
Expand Down Expand Up @@ -280,19 +281,22 @@ def _learn_wrapped_samplers(self,
if CFG.active_sampler_learning_model in [
"myopic_classifier_mlp", "myopic_classifier_knn"
]:
learner: _WrappedSamplerLearner = _ClassifierWrappedSamplerLearner(
self._get_current_nsrts(), self._get_current_predicates(),
online_learning_cycle)
learner_cls: TypingType[_WrappedSamplerLearner] = \
_ClassifierWrappedSamplerLearner
elif CFG.active_sampler_learning_model == "myopic_classifier_ensemble":
learner = \
_ClassifierEnsembleWrappedSamplerLearner(
self._get_current_nsrts(), self._get_current_predicates(),
online_learning_cycle)
learner_cls = _ClassifierEnsembleWrappedSamplerLearner
else:
assert CFG.active_sampler_learning_model == "fitted_q"
learner = _FittedQWrappedSamplerLearner(
self._get_current_nsrts(), self._get_current_predicates(),
online_learning_cycle)
learner_cls = _FittedQWrappedSamplerLearner
if CFG.active_sampler_learning_object_specific_samplers:
learner: _WrappedSamplerLearner = \
_ObjectSpecificSamplerLearningWrapper(
learner_cls, self._get_current_nsrts(),
self._get_current_predicates(), online_learning_cycle)
else:
learner = learner_cls(self._get_current_nsrts(),
self._get_current_predicates(),
online_learning_cycle)
# Fit with the current data.
learner.learn(self._sampler_data)
wrapped_samplers = learner.get_samplers()
Expand Down Expand Up @@ -619,6 +623,51 @@ def _fit_regressor(self, nsrt_data: _OptionSamplerDataset) -> MLPRegressor:
return regressor


class _ObjectSpecificSamplerLearningWrapper(_WrappedSamplerLearner):
    """Wrapper for learning multiple object-specific samplers for each NSRT.

    The NSRT's data is partitioned by the ground objects that the sampled
    option was applied to, and one fresh base sampler learner is fit per
    grounding. The returned samplers dispatch on the objects they are
    called with.
    """

    def __init__(self,
                 base_sampler_learner_cls: TypingType[_WrappedSamplerLearner],
                 nsrts: Set[NSRT], predicates: Set[Predicate],
                 online_learning_cycle: Optional[int]) -> None:
        super().__init__(nsrts, predicates, online_learning_cycle)
        # Class (not instance): one learner is instantiated per grounding
        # inside _learn_nsrt_sampler().
        self._base_sampler_learner_cls = base_sampler_learner_cls

    def _learn_nsrt_sampler(
        self, nsrt_data: _OptionSamplerDataset,
        nsrt: NSRT) -> Tuple[NSRTSampler, NSRTSamplerWithEpsilonIndicator]:
        """Learn the new test-time and exploration samplers for a single NSRT
        and return them."""
        # Organize data by groundings, i.e., the tuple of objects that the
        # option was applied to.
        grounding_to_data: DefaultDict[
            Tuple[Object, ...], _OptionSamplerDataset] = defaultdict(list)
        for datum in nsrt_data:
            option = datum[1]
            grounding_to_data[tuple(option.objects)].append(datum)
        # Call the base class once per grounding.
        grounding_to_sampler: Dict[Tuple[Object, ...], NSRTSampler] = {}
        grounding_to_sampler_with_eps: Dict[Tuple[
            Object, ...], NSRTSamplerWithEpsilonIndicator] = {}
        for grounding, data in grounding_to_data.items():
            base_sampler = self._base_sampler_learner_cls(
                self._nsrts, self._predicates, self._online_learning_cycle)
            sampler_dataset = {nsrt.option: data}
            # Lazy %-formatting: the message is only built if logged.
            logging.info("Fitting object specific sampler: %s...", grounding)
            base_sampler.learn(sampler_dataset)
            samplers = base_sampler.get_samplers()
            # Only this NSRT's option was in the dataset, so exactly one
            # sampler pair should have been learned.
            assert len(samplers) == 1
            sampler, sampler_with_eps = samplers[nsrt]
            grounding_to_sampler[grounding] = sampler
            grounding_to_sampler_with_eps[grounding] = sampler_with_eps
        # Create wrapped samplers and samplers with epsilon indicator that
        # dispatch to the per-grounding samplers at call time.
        nsrt_sampler = _wrap_object_specific_samplers(grounding_to_sampler)
        nsrt_sampler_with_eps = _wrap_object_specific_samplers_with_epsilon(
            grounding_to_sampler_with_eps)
        return nsrt_sampler, nsrt_sampler_with_eps


# Helper functions.
def _wrap_sampler_test(base_sampler: NSRTSampler,
score_fn: _ScoreFn) -> NSRTSampler:
Expand Down Expand Up @@ -666,6 +715,35 @@ def _sample(state: State, goal: Set[GroundAtom], rng: np.random.Generator,
return _sample


def _wrap_object_specific_samplers(
    object_specific_samplers: Dict[Tuple[Object, ...], NSRTSampler]
) -> NSRTSampler:
    """Combine per-grounding samplers into a single NSRT sampler that
    dispatches based on the objects it is called with."""

    def _dispatching_sampler(state: State, goal: Set[GroundAtom],
                             rng: np.random.Generator,
                             objects: Sequence[Object]) -> Array:
        # Look up the sampler that was trained on exactly this grounding.
        chosen = object_specific_samplers[tuple(objects)]
        return chosen(state, goal, rng, objects)

    return _dispatching_sampler


def _wrap_object_specific_samplers_with_epsilon(
    object_specific_samplers: Dict[Tuple[Object, ...],
                                   NSRTSamplerWithEpsilonIndicator]
) -> NSRTSamplerWithEpsilonIndicator:
    """Combine per-grounding epsilon-indicator samplers into a single one
    that dispatches based on the objects it is called with."""

    def _dispatching_sampler(state: State, goal: Set[GroundAtom],
                             rng: np.random.Generator,
                             objects: Sequence[Object]) -> Tuple[Array, bool]:
        # Look up the sampler that was trained on exactly this grounding.
        chosen = object_specific_samplers[tuple(objects)]
        return chosen(state, goal, rng, objects)

    return _dispatching_sampler


def _vector_score_fn_to_score_fn(vector_fn: Callable[[Array], float],
nsrt: NSRT) -> _ScoreFn:
"""Helper for _classifier_to_score_fn() and _regressor_to_score_fn()."""
Expand Down
1 change: 1 addition & 0 deletions predicators/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ class GlobalSettings:
active_sampler_explorer_planning_progress_max_tasks = 10
active_sampler_explorer_planning_progress_max_replan_tasks = 5
active_sampler_learning_init_cycles_to_pursue_goal = 1
active_sampler_learning_object_specific_samplers = True

# grammar search invention parameters
grammar_search_grammar_includes_givens = True
Expand Down

0 comments on commit e039075

Please sign in to comment.