-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from convince-project/option_learning
Option Learning Functionality (and other mods)
- Loading branch information
Showing
107 changed files
with
515,909 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
#!/usr/bin/env python3 | ||
""" A script to run REFINE-PLAN on the fake museum simulation example | ||
Author: Charlie Street | ||
Owner: Charlie Street | ||
""" | ||
|
||
from refine_plan.models.condition import Label, EqCondition, AndCondition, OrCondition | ||
from refine_plan.learning.option_learning import mongodb_to_yaml, learn_dbns | ||
from refine_plan.algorithms.semi_mdp_solver import synthesise_policy | ||
from refine_plan.models.state_factor import StateFactor | ||
from refine_plan.models.dbn_option import DBNOption | ||
from refine_plan.models.semi_mdp import SemiMDP | ||
from refine_plan.models.state import State | ||
import sys | ||
|
||
# Global map setup | ||
|
||
GRAPH = { | ||
"v1": {"e12": "v2", "e13": "v3", "e14": "v4"}, | ||
"v2": {"e12": "v1", "e23": "v3", "e25": "v5", "e26": "v6"}, | ||
"v3": { | ||
"e13": "v1", | ||
"e23": "v2", | ||
"e34": "v4", | ||
"e35": "v5", | ||
"e36": "v6", | ||
"e37": "v7", | ||
}, | ||
"v4": {"e14": "v1", "e34": "v3", "e46": "v6", "e47": "v7"}, | ||
"v5": {"e25": "v2", "e35": "v3", "e56": "v6", "e58": "v8"}, | ||
"v6": { | ||
"e26": "v2", | ||
"e36": "v3", | ||
"e46": "v4", | ||
"e56": "v5", | ||
"e67": "v7", | ||
"e68": "v8", | ||
}, | ||
"v7": { | ||
"e37": "v3", | ||
"e47": "v4", | ||
"e67": "v6", | ||
"e78": "v8", | ||
}, | ||
"v8": {"e58": "v5", "e68": "v6", "e78": "v7"}, | ||
} | ||
|
||
CORRESPONDING_DOOR = { | ||
"e12": None, | ||
"e14": None, | ||
"e58": "v5", | ||
"e78": "v7", | ||
"e13": None, | ||
"e36": "v3", | ||
"e68": "v6", | ||
"e25": "v2", | ||
"e47": "v4", | ||
"e26": "v2", | ||
"e35": "v3", | ||
"e46": "v4", | ||
"e37": "v3", | ||
"e23": None, | ||
"e34": None, | ||
"e56": None, | ||
"e67": None, | ||
} | ||
|
||
# Problem Setup | ||
INITIAL_LOC = "v1" | ||
GOAL_LOC = "v8" | ||
|
||
|
||
def _get_enabled_cond(sf_list, option): | ||
"""Get the enabled condition for an option. | ||
Args: | ||
sf_list: The list of state factors | ||
option: The option we want the condition for | ||
Returns: | ||
The enabled condition for the option | ||
""" | ||
sf_dict = {sf.get_name(): sf for sf in sf_list} | ||
|
||
door_locs = ["v{}".format(i) for i in range(2, 8)] | ||
|
||
if option == "check_door" or option == "open_door": | ||
enabled_cond = OrCondition() | ||
door_status = "unknown" if option == "check_door" else "closed" | ||
for door in door_locs: | ||
enabled_cond.add_cond( | ||
AndCondition( | ||
EqCondition(sf_dict["location"], door), | ||
EqCondition(sf_dict["{}_door".format(door)], door_status), | ||
) | ||
) | ||
return enabled_cond | ||
else: # edge navigation option | ||
enabled_cond = OrCondition() | ||
for node in GRAPH: | ||
if option in GRAPH[node]: | ||
enabled_cond.add_cond(EqCondition(sf_dict["location"], node)) | ||
door = CORRESPONDING_DOOR[option] | ||
if door != None: | ||
enabled_cond = AndCondition( | ||
enabled_cond, EqCondition(sf_dict["{}_door".format(door)], "open") | ||
) | ||
return enabled_cond | ||
|
||
|
||
def write_mongodb_to_yaml(mongo_connection_str): | ||
"""Learn the DBNOptions from the database. | ||
Args: | ||
mongo_connection_str: The MongoDB conenction string""" | ||
|
||
loc_sf = StateFactor("location", ["v{}".format(i) for i in range(1, 9)]) | ||
door_sfs = [ | ||
StateFactor("v2_door", ["unknown", "closed", "open"]), | ||
StateFactor("v3_door", ["unknown", "closed", "open"]), | ||
StateFactor("v4_door", ["unknown", "closed", "open"]), | ||
StateFactor("v5_door", ["unknown", "closed", "open"]), | ||
StateFactor("v6_door", ["unknown", "closed", "open"]), | ||
StateFactor("v7_door", ["unknown", "closed", "open"]), | ||
] | ||
|
||
print("Writing mongo database to yaml file") | ||
mongodb_to_yaml( | ||
mongo_connection_str, | ||
"refine-plan", | ||
"fake-museum-data", | ||
[loc_sf] + door_sfs, | ||
"../data/fake_museum/dataset.yaml", | ||
) | ||
|
||
|
||
def learn_options(): | ||
"""Learn the options from the YAML file.""" | ||
dataset_path = "../data/fake_museum/dataset.yaml" | ||
output_dir = "../data/fake_museum/" | ||
|
||
loc_sf = StateFactor("location", ["v{}".format(i) for i in range(1, 9)]) | ||
door_sfs = [ | ||
StateFactor("v2_door", ["unknown", "closed", "open"]), | ||
StateFactor("v3_door", ["unknown", "closed", "open"]), | ||
StateFactor("v4_door", ["unknown", "closed", "open"]), | ||
StateFactor("v5_door", ["unknown", "closed", "open"]), | ||
StateFactor("v6_door", ["unknown", "closed", "open"]), | ||
StateFactor("v7_door", ["unknown", "closed", "open"]), | ||
] | ||
|
||
learn_dbns(dataset_path, output_dir, [loc_sf] + door_sfs) | ||
|
||
|
||
def run_planner(): | ||
"""Run refine-plan and synthesise a BT. | ||
Returns: | ||
The refined BT | ||
""" | ||
|
||
loc_sf = StateFactor("location", ["v{}".format(i) for i in range(1, 9)]) | ||
door_sfs = [ | ||
StateFactor("v2_door", ["unknown", "closed", "open"]), | ||
StateFactor("v3_door", ["unknown", "closed", "open"]), | ||
StateFactor("v4_door", ["unknown", "closed", "open"]), | ||
StateFactor("v5_door", ["unknown", "closed", "open"]), | ||
StateFactor("v6_door", ["unknown", "closed", "open"]), | ||
StateFactor("v7_door", ["unknown", "closed", "open"]), | ||
] | ||
sf_list = [loc_sf] + door_sfs | ||
|
||
labels = [Label("goal", EqCondition(loc_sf, "v8"))] | ||
|
||
option_names = [ | ||
"e12", | ||
"e14", | ||
"e58", | ||
"e78", | ||
"e13", | ||
"e36", | ||
"e68", | ||
"e25", | ||
"e47", | ||
"e26", | ||
"e35", | ||
"e46", | ||
"e37", | ||
"e23", | ||
"e34", | ||
"e56", | ||
"e67", | ||
"check_door", | ||
"open_door", | ||
] | ||
|
||
assert len(set(option_names)) == 19 # Quick safety check | ||
|
||
init_state_dict = {sf: "unknown" for sf in door_sfs} | ||
init_state_dict[loc_sf] = "v1" | ||
init_state = State(init_state_dict) | ||
|
||
option_list = [] | ||
for option in option_names: | ||
print("Reading in option: {}".format(option)) | ||
t_path = "../data/fake_museum/{}_transition.bifxml".format(option) | ||
r_path = "../data/fake_museum/{}_reward.bifxml".format(option) | ||
option_list.append( | ||
DBNOption( | ||
option, t_path, r_path, sf_list, _get_enabled_cond(sf_list, option) | ||
) | ||
) | ||
|
||
print("Creating MDP...") | ||
semi_mdp = SemiMDP(sf_list, option_list, labels, initial_state=init_state) | ||
print("Synthesising Policy...") | ||
policy = synthesise_policy(semi_mdp, prism_prop='Rmin=?[F "goal"]') | ||
policy.write_policy("../data/fake_museum/fake_museum_refined_policy.yaml") | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
# write_mongodb_to_yaml(sys.argv[1]) | ||
# learn_options() | ||
run_planner() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
#!/usr/bin/env python | ||
""" Script for plotting the fake museum REFINE-PLAN results. | ||
Author: Charlie Street | ||
""" | ||
|
||
from scipy.stats import mannwhitneyu | ||
from pymongo import MongoClient | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import matplotlib | ||
import sys | ||
|
||
plt.rcParams["pdf.fonttype"] = 42 | ||
matplotlib.rcParams.update({"font.size": 40}) | ||
|
||
|
||
def read_results_for_method(collection, sf_names): | ||
"""Read the mongo results for a single method (i.e. a collection). | ||
Args: | ||
collection: The MongoDB collection | ||
sf_names: A list of state factor names | ||
Returns: | ||
results: The list of run durations | ||
""" | ||
|
||
# Group docs together by run_id | ||
docs_per_run = {} | ||
for doc in collection.find({}): | ||
if doc["run_id"] not in docs_per_run: | ||
docs_per_run[doc["run_id"]] = [] | ||
docs_per_run[doc["run_id"]].append(doc) | ||
|
||
# Sanity check each run | ||
results = [] | ||
for run_id in docs_per_run: | ||
total_duration = 0.0 | ||
in_order = sorted(docs_per_run[run_id], key=lambda d: d["date_started"]) | ||
assert in_order[0]["location0"] == "v1" | ||
assert in_order[-1]["locationt"] == "v8" | ||
|
||
for i in range(len(in_order) - 1): | ||
total_duration += in_order[i]["duration"] | ||
for sf in sf_names: | ||
assert ( | ||
in_order[i]["{}t".format(sf)] == in_order[i + 1]["{}0".format(sf)] | ||
) | ||
total_duration += in_order[-1]["duration"] | ||
results.append(total_duration) | ||
|
||
assert len(results) == 100 | ||
return results | ||
|
||
|
||
def print_stats(init_results, refined_results): | ||
"""Print the statistics for the initial and refined results. | ||
Args: | ||
init_results: The durations for the initial behaviour | ||
refined_results: The durations for the refined behaviour | ||
""" | ||
print( | ||
"INITIAL BEHAVIOUR: AVG COST: {}; VARIANCE: {}".format( | ||
np.mean(init_results), np.var(init_results) | ||
) | ||
) | ||
print( | ||
"REFINED BEHAVIOUR: AVG COST: {}; VARIANCE: {}".format( | ||
np.mean(refined_results), np.var(refined_results) | ||
) | ||
) | ||
p = mannwhitneyu( | ||
refined_results, | ||
init_results, | ||
alternative="less", | ||
)[1] | ||
print( | ||
"REFINED BT BETTER THAN INITIAL BT: p = {}, stat sig better = {}".format( | ||
p, p < 0.05 | ||
) | ||
) | ||
|
||
|
||
def set_box_colors(bp): | ||
plt.setp(bp["boxes"][0], color="tab:blue", linewidth=8.0) | ||
plt.setp(bp["caps"][0], color="tab:blue", linewidth=8.0) | ||
plt.setp(bp["caps"][1], color="tab:blue", linewidth=8.0) | ||
plt.setp(bp["whiskers"][0], color="tab:blue", linewidth=8.0) | ||
plt.setp(bp["whiskers"][1], color="tab:blue", linewidth=8.0) | ||
plt.setp(bp["fliers"][0], color="tab:blue") | ||
plt.setp(bp["medians"][0], color="tab:blue", linewidth=8.0) | ||
|
||
plt.setp(bp["boxes"][1], color="tab:red", linewidth=8.0) | ||
plt.setp(bp["caps"][2], color="tab:red", linewidth=8.0) | ||
plt.setp(bp["caps"][3], color="tab:red", linewidth=8.0) | ||
plt.setp(bp["whiskers"][2], color="tab:red", linewidth=8.0) | ||
plt.setp(bp["whiskers"][3], color="tab:red", linewidth=8.0) | ||
plt.setp(bp["medians"][1], color="tab:red", linewidth=8.0) | ||
|
||
|
||
def plot_box_plot(init_results, refined_results): | ||
"""Plot a box plot showing the initial and refined results. | ||
Args: | ||
init_results: The durations for the initial behaviour | ||
refined_results: The durations for the refined behaviour | ||
""" | ||
|
||
box = plt.boxplot( | ||
[init_results, refined_results], | ||
whis=[0, 100], | ||
positions=[1, 2], | ||
widths=0.6, | ||
) | ||
set_box_colors(box) | ||
|
||
plt.tick_params( | ||
axis="x", # changes apply to the x-axis | ||
which="both", # both major and minor ticks are affected | ||
bottom=True, # ticks along the bottom edge are off | ||
top=False, # ticks along the top edge are off | ||
labelbottom=True, # labels along the bottom edge are offcd | ||
labelsize=40, | ||
) | ||
plt.ylabel("Time to Reach Goal (s)") | ||
|
||
plt.xticks([1, 2], ["Initial BT", "Refined BT"]) | ||
|
||
plt.show() | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
sf_names = ["v{}_door".format(v) for v in range(2, 7)] | ||
sf_names = ["location"] + sf_names | ||
client = MongoClient(sys.argv[1]) | ||
db = client["refine-plan"] | ||
init_results = read_results_for_method(db["fake-museum-initial"], sf_names) | ||
refined_results = read_results_for_method(db["fake-museum-refined"], sf_names) | ||
print_stats(init_results, refined_results) | ||
plot_box_plot(init_results, refined_results) |
Oops, something went wrong.