Add node outputs cache + fix assert crashes #8

Open · wants to merge 1 commit into main
35 changes: 31 additions & 4 deletions hivemind_exp/dht_utils.py
@@ -1,3 +1,4 @@
from functools import lru_cache
from typing import Any
from hivemind.dht import DHT
from hivemind.utils import ValueWithExpiration
@@ -7,24 +8,49 @@
ROUND_STAGE_NUMBER_KEY = "rl_swarm_rs" # No subkeys. Coordinator publishes.

# Round and stage (e.g. 0_0) appended.
LEADERBOARD_KEY_PREFIX = "rl_swarm_leaderboard" # Subkey = Metric. Coordinator publishes.
REWARDS_KEY = "rl_swarm_rewards" # Subkey = Metric. Everyone publishes.
LEADERBOARD_KEY_PREFIX = (
    "rl_swarm_leaderboard"  # Subkey = Metric. Coordinator publishes.
)
REWARDS_KEY = "rl_swarm_rewards" # Subkey = Metric. Everyone publishes.

# Node UUID, round, and stage (e.g. abcde_0_0) appended.
OUTPUTS_KEY_PREFIX = "rl_swarm_outputs" # Subkey = Example Hash. Everyone publishes.


def leaderboard_key(round_num, stage) -> str:
    return f"{LEADERBOARD_KEY_PREFIX}_{round_num}_{stage}"


def rewards_key(round_num, stage) -> str:
    return f"{REWARDS_KEY}_{round_num}_{stage}"

def node_outputs_key(node: HivemindNode) -> str:
    return outputs_key(node.uuid, node.round_num, node.stage_num)

def outputs_key(node_uuid: str, round_num, stage) -> str:
    return f"{OUTPUTS_KEY_PREFIX}_{node_uuid}_{round_num}_{stage}"


def node_outputs_key(node: HivemindNode) -> str:
    return outputs_key(node.uuid, node.round_num, node.stage_num)


@lru_cache
def get_outputs(
    dht: DHT, node_uuid: str, r, s, get_cached_fn=None
) -> dict[str, tuple[float, dict]]:  # Q: (timestamp, outputs)
    # Try provided cache function first.
    if get_cached_fn:
        if outputs := get_cached_fn(r, s):
            return outputs

    # Try from DHT next to include peered outputs.
    if outputs := get_dht_value(dht, key=outputs_key(node_uuid, r, s), latest=False):
        return outputs

    raise ValueError(
        f"could not retrieve stage outputs for {node_uuid} at round {r} stage {s}"
    )


def get_round_and_stage(
    dht: DHT,
) -> tuple[int, int]:
@@ -35,6 +61,7 @@ def get_round_and_stage(
    round_num, stage = value
    return round_num, stage


def get_dht_value(dht: DHT, **kwargs) -> Any | None:
    wrapper = dht.get(**kwargs)
    if not wrapper:
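The new `get_outputs` helper replaces the bare `get_dht_value` + `assert` pattern: it consults a node-local cache function first, then the DHT, and raises `ValueError` instead of crashing on a failed assertion. Because it is wrapped in `@lru_cache`, repeated lookups for the same node, round, and stage return the memoized result. A minimal usage sketch (the wrapper function and its fall-back-to-empty-dict behaviour are illustrative, not part of this PR):

```python
from hivemind_exp.dht_utils import get_outputs


def load_stage_outputs(dht, node, r: int, s: int) -> dict:
    """Illustrative wrapper: prefer the node's own cache, then the DHT."""
    try:
        # node.get_stage_outputs is consulted first, so a node can always read
        # back its own samples even if the DHT entry has already expired.
        return get_outputs(dht, node.uuid, r, s, node.get_stage_outputs)
    except ValueError:
        # Neither the local cache nor the DHT had the outputs; degrade gracefully.
        return {}
```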
4 changes: 1 addition & 3 deletions hivemind_exp/gsm8k/generate_prompts.py
@@ -241,9 +241,7 @@ def fill_unknown_answers_opinions(values):
        for field in val:
            if field in FILLED_FIELDS:
                diff_keys = agent_set - val[field].keys()
                for (
                    agent
                ) in (
                for agent in (
                    diff_keys
                ):  # Fill with default values. TODO: Decide if this is a good choice.
                    val[field].update({agent: "No answer received..."})
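The only change here is unwrapping an over-split `for` statement; the logic still computes the set of agents missing an entry for a field and fills them with a default. A tiny standalone illustration of that set-difference step (the variable values are made up):

```python
agent_set = {"A", "B", "C"}
answers = {"A": "The answer is 42."}

# Agents that never published an answer for this field.
diff_keys = agent_set - answers.keys()  # {"B", "C"}
for agent in diff_keys:
    answers[agent] = "No answer received..."
```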
11 changes: 4 additions & 7 deletions hivemind_exp/gsm8k/stage_merger.py
@@ -1,4 +1,3 @@
import json
from typing import Any


@@ -10,9 +9,8 @@ def merge_stage1_question(outputs: dict[str, dict[str, Any]]):
        merged["question"] = o["question"]
        merged["answer"] = o["answer"]
        merged["agent_answers"].update(o["agent_answers"])
    for agent in (
        outputs.keys()
    ):  # Fill with default values. TODO: Decide if this is a good choice.
    # Fill with default values. TODO: Decide if this is a good choice.
    for agent in outputs:
        if agent not in merged["agent_answers"]:
            merged["agent_answers"].update({agent: "No answer received..."})
    return merged
@@ -33,9 +31,8 @@ def merge_stage2_question(outputs: dict[str, dict[str, Any]]):
            merged[col] = o[col]
        if "agent_opinion" in o:
            merged["agent_opinion"].update(o["agent_opinion"])
    for agent in (
        outputs.keys()
    ):  # Fill with default values. TODO: Decide if this is a good choice.
    # Fill with default values. TODO: Decide if this is a good choice.
    for agent in outputs:
        if agent not in merged["agent_opinion"]:
            merged["agent_opinion"].update({agent: "No feedback received..."})
    return merged
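Both mergers now iterate `for agent in outputs` directly and fill any missing agent entry with a placeholder string. A standalone sketch of that fill pattern (simplified; the helper name and call are illustrative, not the actual merger API):

```python
from typing import Iterable


def fill_missing(merged: dict, agents: Iterable[str], field: str, default: str) -> dict:
    # Ensure every known agent appears under `field`, even if it sent nothing.
    bucket = merged.setdefault(field, {})
    for agent in agents:
        bucket.setdefault(agent, default)
    return merged


# Agent "B" produced no answer, so it gets the placeholder.
merged = fill_missing({"agent_answers": {"A": "42"}}, ["A", "B"], "agent_answers", "No answer received...")
assert merged["agent_answers"]["B"] == "No answer received..."
```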
127 changes: 93 additions & 34 deletions hivemind_exp/gsm8k/stage_utils.py
@@ -1,58 +1,117 @@
from collections import defaultdict
from hivemind_exp.dht_utils import *
import logging
import time
from hivemind_exp.dht_utils import (
    DHT,
    HivemindNode,
    get_dht_value,
    get_outputs,
    rewards_key,
)
import hivemind_exp.gsm8k.stage1_rewards as stage1_rewards
import hivemind_exp.gsm8k.stage2_rewards as stage2_rewards
import hivemind_exp.gsm8k.stage3_rewards as stage3_rewards
from hivemind_exp.gsm8k.generate_prompts import *
from hivemind_exp.gsm8k.stage_merger import *
from hivemind_exp.gsm8k.generate_prompts import get_stage2_samples, get_stage3_samples
from hivemind_exp.gsm8k.stage_merger import (
    Any,
    merge_stage1_question,
    merge_stage2_question,
)
from hivemind_exp.utils import SingleStageData, StageData

logger = logging.getLogger(__name__)

def gsm8k_stage_data(dht, node, initial_train_dataset, initial_test_dataset):
    def cumulative_reward_0(**kwargs):
        return stage1_rewards.hivemind_cumulative_reward(node, **kwargs)
def merged_prev_stage_datasets(
    dht: DHT,
    node: HivemindNode,
    r: int,
    s: int,
    merge_fn,
    samples_fn,
    wait_interval=1,
    wait_timeout=5,
):
    merged_qs = []

    def cumulative_reward_1(**kwargs):
        return stage2_rewards.hivemind_cumulative_reward(node, **kwargs)
    # Retrieves and merges last stage samples locally and from DHT.
    def get_prev_rewards():
        return get_dht_value(dht, key=rewards_key(r, s - 1), latest=True, beam_size=1000)

    def cumulative_reward_2(**kwargs):
        return stage3_rewards.hivemind_cumulative_reward(node, **kwargs)
    prev_rewards: dict[str, Any] | None = get_prev_rewards()
    start_time = time.monotonic()
    while not prev_rewards and time.monotonic() - start_time < wait_timeout:
        logger.info(
            f"[{node.uuid}] Can't retrieve round {r} stage {s - 1} rewards; trying again in {wait_interval}s "
        )
        time.sleep(wait_interval)
        prev_rewards = get_prev_rewards()

    def stage_datasets_fn(r, s, merge_fn, samples_fn):
        prev_rewards: dict[str, Any] | None = get_dht_value(
            dht, key=rewards_key(r, s - 1), latest=True
    # Add the current node's local samples first.
    prev_outputs: dict[str, list] = defaultdict(list)
    try:
        prev_node_outputs = get_outputs(
            dht, node.uuid, r, s - 1, node.get_stage_outputs
        )
        for _, outputs in prev_node_outputs.values():
            prev_outputs[node.uuid].append(outputs)
    except ValueError:
        # Joined after the round has started.
        logger.info(
            f"[{node.uuid}] Could not retrieve local outputs for round {r} stage {s - 1}"
        )
        assert prev_rewards

        prev_outputs: dict[str, list] = defaultdict(list)
        for node_uuid in prev_rewards:
            prev_node_outputs: dict[str, tuple[float, dict]] | None = get_dht_value(
                dht, key=outputs_key(node_uuid, r, s - 1), latest=True
            )
            assert prev_node_outputs
            for _, outputs in prev_node_outputs.values():
                prev_outputs[node_uuid].append(outputs)
    # Add other nodes' samples iff rewards are available.
    if prev_rewards:
        node_uuids = prev_rewards.keys()
        for node_uuid in node_uuids:
            if node_uuid == node.uuid:
                continue
            try:
                prev_node_outputs = get_outputs(dht, node_uuid, r, s - 1)
                for _, outputs in prev_node_outputs.values():
                    prev_outputs[node_uuid].append(outputs)
            except ValueError:
                # Skip this node's answers for the current round and stage.
                logger.info(
                    f"[{node.uuid}] Found rewards published for node: {node_uuid} but no outputs!"
                )

    # Merge all samples.
    q_to_keyed_outputs: dict[str, dict[str, Any]] = defaultdict(dict)
    for node_uuid, all_outputs in prev_outputs.items():
        for outputs in all_outputs:
            q_to_keyed_outputs[outputs["question"]][node_uuid] = outputs

    for outputs in q_to_keyed_outputs.values():
        merged = merge_fn(outputs)
        merged_qs.append(merged)

        q_to_keyed_outputs: dict[str, dict[str, Any]] = defaultdict(dict)
        for node_uuid, all_outputs in prev_outputs.items():
            for outputs in all_outputs:
                q_to_keyed_outputs[outputs["question"]][node_uuid] = outputs
    return samples_fn(merged_qs)

        merged_qs = []
        for outputs in q_to_keyed_outputs.values():
            merged = merge_fn(outputs)
            merged_qs.append(merged)

        return samples_fn(merged_qs)
def gsm8k_stage_data(
    dht: DHT, node: HivemindNode, initial_train_dataset, initial_test_dataset
):
    def cumulative_reward_0(**kwargs):
        return stage1_rewards.hivemind_cumulative_reward(node, **kwargs)

    def cumulative_reward_1(**kwargs):
        return stage2_rewards.hivemind_cumulative_reward(node, **kwargs)

    def cumulative_reward_2(**kwargs):
        return stage3_rewards.hivemind_cumulative_reward(node, **kwargs)

    def stage2_datasets_fn(r, s):
        return stage_datasets_fn(r, s, merge_stage1_question, get_stage2_samples)
        return merged_prev_stage_datasets(
            dht, node, r, s, merge_stage1_question, get_stage2_samples
        )

    def stage3_datasets_fn(r, s):
        return stage_datasets_fn(r, s, merge_stage2_question, get_stage3_samples)
        return merged_prev_stage_datasets(
            dht, node, r, s, merge_stage2_question, get_stage3_samples
        )

    return StageData(
        max_rounds=100,  # note, this gets overridden from the config file
        stages=[
            SingleStageData(
                name="0",
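The crash-prone `assert prev_rewards` / `assert prev_node_outputs` calls are gone: `merged_prev_stage_datasets` polls the DHT for the previous stage's rewards for up to `wait_timeout` seconds, always includes the local node's own samples, and skips peers whose outputs cannot be retrieved. The retry loop in isolation looks roughly like this (a sketch with illustrative names; the defaults mirror `wait_interval=1`, `wait_timeout=5`):

```python
import time


def poll_until_truthy(fetch, wait_interval: float = 1.0, wait_timeout: float = 5.0):
    """Return the first truthy fetch() result, or the last (falsy) one after timing out."""
    result = fetch()
    deadline = time.monotonic() + wait_timeout
    while not result and time.monotonic() < deadline:
        time.sleep(wait_interval)  # Back off before asking the DHT again.
        result = fetch()
    return result
```

Because the result may still be empty after the timeout, the caller merges only the local node's outputs in that case instead of asserting.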
28 changes: 27 additions & 1 deletion hivemind_exp/tests/fake_data.py
@@ -1,3 +1,4 @@
from copy import deepcopy
from hivemind_exp.utils import COORDINATOR_KEY
from hivemind_exp.dht_utils import ROUND_STAGE_NUMBER_KEY

@@ -14,6 +15,9 @@
            {"content": "You are a pirate.", "role": "system"},
            {"content": QUESTION, "role": "user"},
        ],
        "agent_answers": {
            CK: "The meaning of life is 42.",
        },
    },
    {
        "question": QUESTION,
@@ -22,9 +26,31 @@
            {"content": "You are a cat.", "role": "system"},
            {"content": QUESTION, "role": "user"},
        ],
    }
        "agent_answers": {
            "0": "The meaning of life is to sleep.",
        },
    },
]


def samples_with_uuid(new_uuid, orig_samples=SAMPLES, field="agent_answers"):
    orig_uuids = (CK, "0", "1", "2")

    def replace(orig, value):
        if field in value:
            answers = value[field]
            if orig != new_uuid and orig in answers:
                answers[new_uuid] = answers[orig]
                del answers[orig]

    samples = deepcopy(orig_samples)
    for sample in samples:
        for orig in orig_uuids:
            replace(orig, sample)

    return samples


STAGE_1_OUTPUTS = {
    CK: {
        "question": "Carl is taking a class where the whole grade is based on four tests that are graded out of 100. He got an 80, a 75 and a 90 on his first three tests. If he wants an 85 average for the class, what is the minimum grade he needs to get on his last test?",
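`samples_with_uuid` deep-copies the fixture and re-keys its `agent_answers` entries under a caller-supplied UUID, which lets tests simulate outputs published by an arbitrary node. A short usage sketch (the UUID value here is made up):

```python
from hivemind_exp.tests.fake_data import samples_with_uuid

# Deep-copied, so the shared SAMPLES fixture is left untouched.
samples = samples_with_uuid("test-node-uuid")
# Each sample that carried agent answers now lists them under "test-node-uuid".
```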