Skip to content

Commit

Permalink
GOT can run again
Browse files Browse the repository at this point in the history
  • Loading branch information
j-luo93 committed Feb 8, 2021
1 parent 53392ba commit 9f510d1
Show file tree
Hide file tree
Showing 17 changed files with 741 additions and 139 deletions.
53 changes: 28 additions & 25 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import logging
import pickle
import re
from copy import deepcopy
from dataclasses import dataclass, field
from typing import ClassVar, Dict, List, Optional, Set, Tuple, Union

Expand All @@ -14,12 +16,12 @@
from pypheature.segment import Segment
from sound_law.data.alphabet import EOT, SOT, Alphabet
from sound_law.main import setup
from sound_law.rl.action import SoundChangeAction, SoundChangeActionSpace
from sound_law.rl.mcts_cpp import \
PyNull_abc # pylint: disable=no-name-in-module
from sound_law.rl.action import SoundChangeAction
from sound_law.rl.env import SoundChangeEnv
# from sound_law.rl.mcts_cpp import
# PyNull_abc # pylint: disable=no-name-in-module
from sound_law.rl.trajectory import VocabState
from sound_law.train.manager import OnePairManager
from copy import deepcopy

_fp = FeatureProcessor()
# NOTE(j_luo) We use `a` to represent the back vowel `ɑ`.
Expand Down Expand Up @@ -382,25 +384,26 @@ def get_actions(raw_rules: List[str], orders: List[str], refs: Optional[List[str
class PlainState:
"""This stores the plain vocabulary state (using str), as opposed to `VocabState` that is used by MCTS (using ids)."""

action_space: ClassVar[SoundChangeActionSpace] = None
# action_space: ClassVar[SoundChangeActionSpace] = None
env: ClassVar[SoundChangeEnv] = None
end_state: ClassVar[PlainState] = None
abc: ClassVar[Alphabet] = None

def __init__(self, segments: List[List[str]]):
self.segments = segments

@classmethod
def from_vocab_state(cls, vocab: VocabState) -> PlainState:
return cls(vocab.segment_list)
def __init__(self, node: VocabState):
    """Wrap a `VocabState` node and expose its segments.

    Args:
        node: The state to wrap. Its `segment_list` is stored as `self.segments`.
    """
    self.segments = node.segment_list
    # Keep the underlying node: `apply_action` hands it to `cls.env.apply_action`.
    self._node = node

def apply_action(self, action: SoundChangeAction) -> PlainState:
cls = type(self)
assert cls.action_space is not None
new_segments = list()
for seg in self.segments:
new_segments.append(cls.action_space.apply_action(seg, action))
assert cls.env is not None
try:
new_node = cls.env.apply_action(self._node, action.before_id, action.after_id,
action.pre_id, action.d_pre_id, action.post_id, action.d_post_id)
return cls(new_node)

return cls(new_segments)
except RuntimeError:
logging.warn("No site was targeted.")
return self

def dist_from(self, tgt_segments: List[List[str]]):
'''Returns the distance between the current state and a specified state of segments'''
Expand All @@ -411,10 +414,10 @@ def dist_from(self, tgt_segments: List[List[str]]):
for s1, s2 in zip(self.segments, tgt_segments):
s1 = [cls.abc[u] for u in s1] # pylint: disable=unsubscriptable-object
s2 = [cls.abc[u] for u in s2] # pylint: disable=unsubscriptable-object
dist += cls.action_space.word_space.get_edit_dist(s1, s2)
dist += cls.env.get_edit_dist(s1, s2)
return dist

@ property
@property
def dist(self) -> float:
'''Returns the distance between the current state and the end state'''
cls = type(self)
Expand Down Expand Up @@ -496,30 +499,30 @@ def simulate(raw_inputs: Optional[List[Tuple[List[str], List[str], List[str]]]]
gold.extend(get_actions(rows['w/ SS'], rows['order'], refs=rows['ref no.']))

# Simulate the actions and get the distance.
PlainState.action_space = manager.action_space
PlainState.end_state = PlainState.from_vocab_state(manager.env.end)
PlainState.env = manager.env
PlainState.end_state = PlainState(manager.env.end)
PlainState.abc = manager.tgt_abc
state = PlainState.from_vocab_state(manager.env.start)
state = PlainState(manager.env.start)
states = [state]
actions = list()
refs = list()
expanded_gold = list()

print(state.dist)
logging.info(f"Starting dist: {state.dist:.3f}")
for hr in gold:
if hr.expandable:
action_q = hr.specialize(state)
print(hr)
logging.warn(f"This is an expandable rule: {hr}")
else:
action_q = [hr.to_action()]
for action in action_q:
logging.info(f"Applying {action}")
state = state.apply_action(action)
states.append(state)
actions.append(action)
refs.append(hr.ref)
expanded_gold.append(action)
print(action)
print(state.dist)
logging.info(f"New dist: {state.dist:.3f}")

# NOTE(j_luo) We can only score based on expanded rules.
gold = expanded_gold
Expand Down
7 changes: 1 addition & 6 deletions scripts/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,7 @@ def get_record(ipas):
tgt_seqs = dl.entire_batch.tgt_seqs
t_arr = np.ascontiguousarray(tgt_seqs.ids.t().cpu().numpy()).astype("uint16")
t_lengths = np.ascontiguousarray(tgt_seqs.lengths.t().cpu().numpy())
py_ss = PySiteSpace(SOT_ID, EOT_ID, ANY_ID, EMP_ID, SYL_EOT_ID, ANY_S_ID, ANY_UNS_ID)
py_ws = PyWordSpace(py_ss, manager.tgt_abc.dist_mat, 2.0)
action_space = SoundChangeActionSpace(py_ss, py_ws, g.dist_threshold,
g.site_threshold, manager.tgt_abc)
env = SoundChangeEnv(action_space, py_ws, s_arr, s_lengths, t_arr, t_lengths, g.final_reward, g.step_penalty)
env = manager.env

init_n_chars = len(get_all_chars(env.start, manager.tgt_abc))
print(init_n_chars)
Expand All @@ -137,7 +133,6 @@ def get_record(ipas):
np.random.seed(args.random_seed)
for i in range(args.length):
while True:
env.action_space.set_action_allowed(state)
best_i = np.random.choice(state.num_actions)
print(state.num_actions, 'allowed.')
# for i, a in enumerate(state.action_allowed):
Expand Down
36 changes: 19 additions & 17 deletions sound_law/data/alphabet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
SYL_EOT = '<syl_EOT>'
ANY_S = '<any_s>'
ANY_UNS = '<any_uns>'
NULL = "<null>"
SOT_ID = 0
EOT_ID = 1
PAD_ID = 2
Expand All @@ -28,6 +29,7 @@
SYL_EOT_ID = 5
ANY_S_ID = 6
ANY_UNS_ID = 7
NULL_ID = 8

_ft = FeatureTable()

Expand Down Expand Up @@ -96,29 +98,29 @@ def __init__(self, lang: str, contents: List[List[str]], sources: Union[str, Lis

# Get vowel info.
n = len(self._id2unit)
self.vowel_mask = np.zeros(n, dtype=bool)
self.vowel_base = np.arange(n, dtype='uint16')
self.vowel_stress = np.zeros(n, dtype='int32')
self.stressed_vowel = np.arange(n, dtype='uint16')
self.unstressed_vowel = np.arange(n, dtype='uint16')
self.vowel_stress.fill(mcts_cpp.PyNoStress)
self.vowel_stress[ANY_S_ID] = mcts_cpp.PyStressed
self.vowel_stress[ANY_UNS_ID] = mcts_cpp.PyUnstressed
self.stressed_vowel[ANY_ID] = ANY_S_ID
self.unstressed_vowel[ANY_ID] = ANY_UNS_ID
self.is_vowel = np.zeros(n, dtype=bool)
self.unit_stress = np.zeros(n, dtype='int32')
self.unit2base = np.arange(n, dtype='uint16')
self.unit2stressed = np.arange(n, dtype='uint16')
self.unit2unstressed = np.arange(n, dtype='uint16')
self.unit_stress.fill(mcts_cpp.PyNoStress)
self.unit_stress[ANY_S_ID] = mcts_cpp.PyStressed
self.unit_stress[ANY_UNS_ID] = mcts_cpp.PyUnstressed
self.unit2stressed[ANY_ID] = ANY_S_ID
self.unit2unstressed[ANY_ID] = ANY_UNS_ID
for u in self._id2unit:
if u.endswith('{+}') or u.endswith('{-}'):
base = u[:-3]
base_id = self._unit2id[base]
i = self._unit2id[u]
self.vowel_mask[base_id] = True
self.vowel_mask[i] = True
self.vowel_base[i] = base_id
self.vowel_stress[i] = mcts_cpp.PyStressed if u[-2] == '+' else mcts_cpp.PyUnstressed
self.is_vowel[base_id] = True
self.is_vowel[i] = True
self.unit2base[i] = base_id
self.unit_stress[i] = mcts_cpp.PyStressed if u[-2] == '+' else mcts_cpp.PyUnstressed
if u.endswith('{+}'):
self.stressed_vowel[base_id] = i
self.unit2stressed[base_id] = i
else:
self.unstressed_vowel[base_id] = i
self.unit2unstressed[base_id] = i

self.stats: pd.DataFrame = pd.DataFrame.from_dict(cnt)
self.dist_mat = self.edges = self.cl_map = self.gb_map = None
Expand All @@ -133,7 +135,7 @@ def __init__(self, lang: str, contents: List[List[str]], sources: Union[str, Lis
orig_u2i = {u: i for i, u in enumerate(orig_units)}
new_ids = np.asarray([self[u] for u in orig_units] + [self[u] for u in units[base_n:]])
orig_ids = np.asarray(list(range(len(orig_units))) +
[orig_u2i[self[self.vowel_base[self[u]]]] for u in units[base_n:]])
[orig_u2i[self[self.unit2base[self[u]]]] for u in units[base_n:]])
self.dist_mat[new_ids.reshape(-1, 1), new_ids] = dist_mat[orig_ids.reshape(-1, 1), orig_ids]
self.edges = edges
self.cl_map = cl_map
Expand Down
35 changes: 23 additions & 12 deletions sound_law/rl/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from dev_misc import BT, add_argument, g, get_tensor, get_zeros
from dev_misc.utils import Singleton, pbar
from sound_law.data.alphabet import (ANY_ID, ANY_S_ID, ANY_UNS_ID, EMP, EMP_ID,
EOT_ID, SOT_ID, SYL_EOT_ID, Alphabet)
EOT_ID, NULL_ID, SOT_ID, SYL_EOT_ID,
Alphabet)

# pylint: disable=no-name-in-module
# from .mcts_cpp import PyAction
Expand All @@ -32,16 +33,25 @@
add_argument('ngram_path', dtype='path', msg='Path to the ngram list.')


@dataclass(eq=True, frozen=True)
class SoundChangeAction:
# class SoundChangeAction(PyAction):
"""One sound change rule."""

before_id: int
after_id: int
pre_id: int
d_pre_id: int
post_id: int
d_post_id: int
special_type: Optional[str] = None

abc: ClassVar[Alphabet] = None

def __hash__(self):
return self.action_id
# def __hash__(self):
# return hash(f"{self.special_type}: {self.before_id} {self.after_id} {self.pre_id} {self.d_pre_id} {self.post_id} {self.d_post_id}")

def __eq__(self, other: SoundChangeAction):
return self.action_id == other.action_id
# def __eq__(self, other: SoundChangeAction):
# return self.before_id == other.before_id and self.after_id == other.after_id and self.pre_id== self.other_id

@classmethod
def from_str(cls, before: str, after: str,
Expand Down Expand Up @@ -71,16 +81,17 @@ def to_int(unit: Union[None, str], before_or_after: str) -> int:
if unit == '##':
return SYL_EOT_ID
if unit is None:
return PyNull_abc
return cls.abc[unit]
return NULL_ID
return cls.abc[unit] # pylint: disable=unsubscriptable-object

return cls(cls.abc[before], to_int(after, 'a'),
return cls(cls.abc[before], to_int(after, 'a'), # pylint: disable=unsubscriptable-object
to_int(pre, 'b'), to_int(d_pre, 'b'),
to_int(post, 'a'), to_int(d_post, 'a'),
special_type=special_type)

def __repr__(self):
if self.action_id == PyStop:
# if self.action_id == PyStop:
if self.before_id == NULL_ID:
return 'STOP'

def get_str(idx: int):
Expand All @@ -105,10 +116,10 @@ def get_cond(cond):
ret = f'({ret})'
return ret

pre = get_cond(self.pre_cond)
pre = get_cond([idx for idx in [self.d_pre_id, self.pre_id] if idx != NULL_ID])
if pre:
pre = f'{pre} + '
post = get_cond(self.post_cond)
post = get_cond([idx for idx in [self.post_id, self.d_post_id] if idx != NULL_ID])
if post:
post = f' + {post}'

Expand Down
50 changes: 37 additions & 13 deletions sound_law/rl/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,60 @@

import numpy as np
import torch
import torch.nn as nn

from dev_misc import FT, LT, add_argument, g, get_tensor, get_zeros
from dev_misc.devlib import pad_to_dense
from dev_misc.devlib.named_tensor import NoName
from dev_misc.utils import handle_sequence_inputs
from sound_law.data.alphabet import PAD_ID
from sound_law.data.alphabet import PAD_ID, Alphabet, EMP

from .action import SoundChangeAction
from .agent import AgentInputs, VanillaPolicyGradient
from .mcts_cpp import PyEnv # pylint: disable=no-name-in-module
# pylint: disable=no-name-in-module
from .mcts_cpp import PyActionSpaceOpt, PyEnv, PyEnvOpt, PyWordSpaceOpt
# pylint: enable=no-name-in-module
from .trajectory import Trajectory, VocabState


class SoundChangeEnv(nn.Module, PyEnv):
class SoundChangeEnv(PyEnv):

tnode_cls = VocabState

add_argument(f'final_reward', default=1.0, dtype=float, msg='Final reward for reaching the end.')
add_argument(f'step_penalty', default=0.02, dtype=float, msg='Penalty for each step if not the end state.')

# pylint: disable=unused-argument
def __init__(self,
action_space: SoundChangeActionSpace,
word_space: PyWordSpace,
s_arr, s_lengths,
e_arr, e_lengths,
final_reward: float,
step_penalty: float):
nn.Module.__init__(self)
def register_changes(self, abc: Alphabet):
    """Register the sound changes described by `abc` with this environment.

    Side effects:
        * Sets the class variable `SoundChangeAction.abc` to `abc`.
        * Calls `self.register_permissible_change` once for every pair in
          `abc.edges`, and then once for every non-special unit paired with `EMP`.

    NOTE(review): the `register_*` methods are not defined in the visible part of
    this class; presumably they are inherited from `PyEnv` -- confirm.

    Args:
        abc: The alphabet whose units and edges are registered.
    """
    # `SoundChangeAction` reads its alphabet from this class variable.
    SoundChangeAction.abc = abc

    # Every unit that is not listed in `abc.special_units`.
    regular_units = [unit for unit in abc if unit not in abc.special_units]

    def register_pair(src: str, dst: str, cl: bool = False, gb: bool = False):
        """Register one `src` -> `dst` pair; `cl` takes precedence over `gb`."""
        src_id, dst_id = abc[src], abc[dst]
        if cl:
            self.register_cl_map(src_id, dst_id)
            return
        if not gb:
            self.register_permissible_change(src_id, dst_id)
            return
        # `gb` pairs are dispatched on the first letter of the source unit.
        if src.startswith('i'):
            self.register_gbj(src_id, dst_id)
        else:
            assert src.startswith('u')
            self.register_gbw(src_id, dst_id)

    # Unconditional changes first: edges of the alphabet, then unit -> EMP pairs.
    for src, dst in abc.edges:
        register_pair(src, dst)
    for unit in regular_units:
        register_pair(unit, EMP)

    # NOTE: registration of the `abc.cl_map` (cl=True) and `abc.gb_map` (gb=True) pairs
    # is currently disabled, as are the `set_vowel_info` / `set_glide_info` calls.

def forward(self, state: VocabState, best_i: int, action: SoundChangeAction) -> Tuple[VocabState, bool, float]:
    """Delegate to `self.step` with the same arguments and return its result.

    NOTE(review): `step` is not defined in the visible part of this class;
    presumably it is inherited from `PyEnv` -- confirm. The meaning of the
    returned tuple's elements should also be checked against `step`.
    """
    return self.step(state, best_i, action)
Expand Down
Loading

0 comments on commit 9f510d1

Please sign in to comment.