Skip to content

Commit

Permalink
added donot attack npc to takeru
Browse files Browse the repository at this point in the history
  • Loading branch information
kywch committed Apr 6, 2024
1 parent d029956 commit 6e08c06
Show file tree
Hide file tree
Showing 3 changed files with 687 additions and 0 deletions.
13 changes: 13 additions & 0 deletions agent_zoo/takeru/reward_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import numpy as np

from nmmo.entity.entity import EntityState

from reinforcement_learning.stat_wrapper import BaseStatWrapper

EntityAttr = EntityState.State.attr_name_to_col


class RewardWrapper(BaseStatWrapper):
def __init__(
Expand All @@ -13,12 +19,14 @@ def __init__(
explore_bonus_weight=0,
clip_unique_event=3,
disable_give=True,
donot_attack_dangerous_npc=True,
):
super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix)
self.stat_prefix = stat_prefix
self.explore_bonus_weight = explore_bonus_weight
self.clip_unique_event = clip_unique_event
self.disable_give = disable_give
self.donot_attack_dangerous_npc = donot_attack_dangerous_npc

def observation(self, agent_id, agent_obs):
"""Called before observations are returned from the environment
Expand All @@ -32,6 +40,11 @@ def observation(self, agent_id, agent_obs):
agent_obs["ActionTargets"]["GiveGold"]["Target"][:-1] = 0
agent_obs["ActionTargets"]["GiveGold"]["Price"][1:] = 0

if self.donot_attack_dangerous_npc is True:
# npc type: 1: passive, 2: neutral, 3: hostile
dangerours_npc_idxs = np.where(agent_obs["Entity"][:, EntityAttr["npc_type"]] > 1)
agent_obs["ActionTargets"]["Attack"]["Target"][dangerours_npc_idxs] = 0

return agent_obs

def reward_terminated_truncated_info(self, agent_id, reward, terminated, truncated, info):
Expand Down
Loading

0 comments on commit 6e08c06

Please sign in to comment.