Showing 28 changed files with 1,527 additions and 579 deletions.
@@ -0,0 +1,58 @@
# Multi-task BiLSTM tagger over character ids: a shared encoder feeds one
# softmax/CRF head for BIO tags and one softmax head for attribute labels.
import tensorflow
if tensorflow.__version__.startswith('1.'):
    import tensorflow as tf
else:
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()


class MyModel(object):

    def __init__(self, embedding_dim, hidden_dim, vocab_size_char, vocab_size_bio, vocab_size_attr, O_tag_index, use_crf):

        self.inputs_seq = tf.placeholder(tf.int32, [None, None], name="inputs_seq")  # B * S
        self.inputs_seq_len = tf.placeholder(tf.int32, [None], name="inputs_seq_len")  # B
        self.outputs_seq_bio = tf.placeholder(tf.int32, [None, None], name='outputs_seq_bio')  # B * S
        self.outputs_seq_attr = tf.placeholder(tf.int32, [None, None], name='outputs_seq_attr')  # B * S

        with tf.variable_scope('embedding_layer'):
            embedding_matrix = tf.get_variable("embedding_matrix", [vocab_size_char, embedding_dim], dtype=tf.float32)
            embedded = tf.nn.embedding_lookup(embedding_matrix, self.inputs_seq)  # B * S * D

        with tf.variable_scope('encoder'):
            cell_fw = tf.nn.rnn_cell.LSTMCell(hidden_dim)
            cell_bw = tf.nn.rnn_cell.LSTMCell(hidden_dim)
            (rnn_fw_outputs, rnn_bw_outputs), (rnn_fw_final_state, rnn_bw_final_state) = \
                tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fw, cell_bw=cell_bw, inputs=embedded,
                                                sequence_length=self.inputs_seq_len, dtype=tf.float32)
            rnn_outputs = tf.add(rnn_fw_outputs, rnn_bw_outputs)  # B * S * D

        with tf.variable_scope('bio_projection'):
            logits_bio = tf.layers.dense(rnn_outputs, vocab_size_bio)  # B * S * V
            probs_bio = tf.nn.softmax(logits_bio, axis=-1)
            if not use_crf:
                preds_bio = tf.argmax(probs_bio, axis=-1, name="preds_bio")  # B * S
            else:
                log_likelihood, transition_matrix = tf.contrib.crf.crf_log_likelihood(logits_bio, self.outputs_seq_bio, self.inputs_seq_len)
                preds_bio, crf_scores = tf.contrib.crf.crf_decode(logits_bio, transition_matrix, self.inputs_seq_len)

        with tf.variable_scope('attr_projection'):
            logits_attr = tf.layers.dense(rnn_outputs, vocab_size_attr)  # B * S * V
            probs_attr = tf.nn.softmax(logits_attr, axis=-1)
            preds_attr = tf.argmax(probs_attr, axis=-1, name="preds_attr")  # B * S

        self.outputs = (preds_bio, preds_attr)

        with tf.variable_scope('loss'):
            if not use_crf:
                loss_bio = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits_bio, labels=self.outputs_seq_bio)  # B * S
                masks_bio = tf.sequence_mask(self.inputs_seq_len, dtype=tf.float32)  # B * S
                loss_bio = tf.reduce_sum(loss_bio * masks_bio, axis=-1) / tf.cast(self.inputs_seq_len, tf.float32)  # B
            else:
                loss_bio = -log_likelihood / tf.cast(self.inputs_seq_len, tf.float32)

            # Attribute loss is only counted on positions predicted as a non-O BIO tag.
            loss_attr = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits_attr, labels=self.outputs_seq_attr)  # B * S
            masks_attr = tf.cast(tf.not_equal(preds_bio, O_tag_index), tf.float32)  # B * S
            loss_attr = tf.reduce_sum(loss_attr * masks_attr, axis=-1) / (tf.reduce_sum(masks_attr, axis=-1) + 1e-5)  # B
            loss = loss_bio + loss_attr  # B

        self.loss = tf.reduce_mean(loss)

        with tf.variable_scope('opt'):
            # Note: `loss` is per-example (shape B); the optimizer minimizes the gradient of its sum.
            self.train_op = tf.train.AdamOptimizer().minimize(loss)
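A minimal smoke test can make the graph above easier to follow. The sketch below is not part of the commit; the vocabulary sizes, dimensions, O_tag_index, and random inputs are all illustrative assumptions. It builds the char-only model with use_crf=False and runs one forward pass to get BIO and attribute predictions.

# Smoke-test sketch (illustrative only; toy sizes, random ids).
import numpy as np

tf.reset_default_graph()
model = MyModel(embedding_dim=8, hidden_dim=8, vocab_size_char=100,
                vocab_size_bio=9, vocab_size_attr=5, O_tag_index=0, use_crf=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_chars = np.random.randint(0, 100, size=(2, 7)).astype(np.int32)  # B * S char ids
    batch_lens = np.array([7, 5], dtype=np.int32)                          # true length of each sequence
    preds_bio, preds_attr = sess.run(model.outputs,
                                     feed_dict={model.inputs_seq: batch_chars,
                                                model.inputs_seq_len: batch_lens})
    print(preds_bio.shape, preds_attr.shape)  # both (2, 7)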
@@ -0,0 +1,63 @@
# Same multi-task BiLSTM tagger, but word-level embeddings are concatenated to
# the character embeddings before the encoder (encoder input is B * S * 2D).
import tensorflow
if tensorflow.__version__.startswith('1.'):
    import tensorflow as tf
else:
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()


class MyModel(object):

    def __init__(self, embedding_dim, hidden_dim, vocab_size_char, vocab_size_word, vocab_size_bio, vocab_size_attr, O_tag_index, use_crf):

        self.inputs_seq_char = tf.placeholder(tf.int32, [None, None], name="inputs_seq_char")  # B * S
        self.inputs_seq_word = tf.placeholder(tf.int32, [None, None], name="inputs_seq_word")  # B * S
        self.inputs_seq_len = tf.placeholder(tf.int32, [None], name="inputs_seq_len")  # B
        self.outputs_seq_bio = tf.placeholder(tf.int32, [None, None], name='outputs_seq_bio')  # B * S
        self.outputs_seq_attr = tf.placeholder(tf.int32, [None, None], name='outputs_seq_attr')  # B * S

        with tf.variable_scope('embedding_layer'):
            embedding_matrix_char = tf.get_variable("embedding_matrix_char", [vocab_size_char, embedding_dim], dtype=tf.float32)
            embedding_matrix_word = tf.get_variable("embedding_matrix_word", [vocab_size_word, embedding_dim], dtype=tf.float32)
            embedded_char = tf.nn.embedding_lookup(embedding_matrix_char, self.inputs_seq_char)  # B * S * D
            embedded_word = tf.nn.embedding_lookup(embedding_matrix_word, self.inputs_seq_word)  # B * S * D
            embedded = tf.concat([embedded_char, embedded_word], axis=2)  # B * S * 2D
            self.embedding_matrix_word = embedding_matrix_word  # kept as an attribute for external access

        with tf.variable_scope('encoder'):
            cell_fw = tf.nn.rnn_cell.LSTMCell(hidden_dim)
            cell_bw = tf.nn.rnn_cell.LSTMCell(hidden_dim)
            (rnn_fw_outputs, rnn_bw_outputs), (rnn_fw_final_state, rnn_bw_final_state) = \
                tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fw, cell_bw=cell_bw, inputs=embedded,
                                                sequence_length=self.inputs_seq_len, dtype=tf.float32)
            rnn_outputs = tf.add(rnn_fw_outputs, rnn_bw_outputs)  # B * S * D

        with tf.variable_scope('bio_projection'):
            logits_bio = tf.layers.dense(rnn_outputs, vocab_size_bio)  # B * S * V
            probs_bio = tf.nn.softmax(logits_bio, axis=-1)
            if not use_crf:
                preds_bio = tf.argmax(probs_bio, axis=-1, name="preds_bio")  # B * S
            else:
                log_likelihood, transition_matrix = tf.contrib.crf.crf_log_likelihood(logits_bio, self.outputs_seq_bio, self.inputs_seq_len)
                preds_bio, crf_scores = tf.contrib.crf.crf_decode(logits_bio, transition_matrix, self.inputs_seq_len)

        with tf.variable_scope('attr_projection'):
            logits_attr = tf.layers.dense(rnn_outputs, vocab_size_attr)  # B * S * V
            probs_attr = tf.nn.softmax(logits_attr, axis=-1)
            preds_attr = tf.argmax(probs_attr, axis=-1, name="preds_attr")  # B * S

        self.outputs = (preds_bio, preds_attr)

        with tf.variable_scope('loss'):
            if not use_crf:
                loss_bio = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits_bio, labels=self.outputs_seq_bio)  # B * S
                masks_bio = tf.sequence_mask(self.inputs_seq_len, dtype=tf.float32)  # B * S
                loss_bio = tf.reduce_sum(loss_bio * masks_bio, axis=-1) / tf.cast(self.inputs_seq_len, tf.float32)  # B
            else:
                loss_bio = -log_likelihood / tf.cast(self.inputs_seq_len, tf.float32)

            loss_attr = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits_attr, labels=self.outputs_seq_attr)  # B * S
            masks_attr = tf.cast(tf.not_equal(preds_bio, O_tag_index), tf.float32)  # B * S
            loss_attr = tf.reduce_sum(loss_attr * masks_attr, axis=-1) / (tf.reduce_sum(masks_attr, axis=-1) + 1e-5)  # B
            loss = loss_bio + loss_attr  # B

        self.loss = tf.reduce_mean(loss)

        with tf.variable_scope('opt'):
            self.train_op = tf.train.AdamOptimizer().minimize(loss)
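The word-augmented variant keeps embedding_matrix_word as an attribute, which suggests the word-embedding table can be replaced with pretrained vectors after initialisation. Below is a hedged sketch of one way to do that in TF1; the toy sizes, the random stand-in vectors, and the names `pretrained`/`assign_op` are assumptions, not anything from the commit.

# Sketch: overwrite the word-embedding table with (stand-in) pretrained vectors.
import numpy as np

vocab_size_word, embedding_dim = 50, 8  # toy sizes for the sketch
tf.reset_default_graph()
model = MyModel(embedding_dim=embedding_dim, hidden_dim=8, vocab_size_char=100,
                vocab_size_word=vocab_size_word, vocab_size_bio=9, vocab_size_attr=5,
                O_tag_index=0, use_crf=False)

pretrained = np.random.randn(vocab_size_word, embedding_dim).astype(np.float32)  # stand-in for real word vectors
assign_op = tf.assign(model.embedding_matrix_word, pretrained)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # random init first
    sess.run(assign_op)                          # then overwrite with the pretrained table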
@@ -0,0 +1,157 @@
# Training and evaluation script for the multi-task LSTM tagger
# (character-feature variant defined above).
import os
import logging
import tensorflow
if tensorflow.__version__.startswith('1.'):
    import tensorflow as tf
else:
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

from model_multitask_lstm import MyModel
from utils import DataProcessor_MTL_LSTM as DataProcessor
from utils import load_vocabulary
from utils import extract_kvpairs_in_bioes
from utils import cal_f1_score


def valid(data_processor, max_batches=None, batch_size=1024):
    # Run the model over the validation data and compute P/R/F1 from the pairs
    # extracted out of the predicted and gold BIOES tag sequences.
    preds_kvpair = []
    golds_kvpair = []
    batches_sample = 0

    while True:
        inputs_seq_batch, inputs_seq_len_batch, outputs_seq_bio_batch, \
            outputs_seq_attr_batch = data_processor.get_batch(batch_size)

        feed_dict = {model.inputs_seq: inputs_seq_batch,
                     model.inputs_seq_len: inputs_seq_len_batch,
                     model.outputs_seq_bio: outputs_seq_bio_batch,
                     model.outputs_seq_attr: outputs_seq_attr_batch}

        preds_seq_bio_batch, preds_seq_attr_batch = sess.run(model.outputs, feed_dict)

        for pred_seq_bio, gold_seq_bio, pred_seq_attr, gold_seq_attr, input_seq, l in zip(preds_seq_bio_batch, outputs_seq_bio_batch,
                                                                                          preds_seq_attr_batch, outputs_seq_attr_batch,
                                                                                          inputs_seq_batch, inputs_seq_len_batch):
            pred_seq_bio = [i2w_bio[i] for i in pred_seq_bio[:l]]
            gold_seq_bio = [i2w_bio[i] for i in gold_seq_bio[:l]]
            char_seq = [i2w_char[i] for i in input_seq[:l]]
            pred_seq_attr = [i2w_attr[i] for i in pred_seq_attr[:l]]
            gold_seq_attr = [i2w_attr[i] for i in gold_seq_attr[:l]]
            pred_kvpair = extract_kvpairs_in_bioes(pred_seq_bio, char_seq, pred_seq_attr)
            gold_kvpair = extract_kvpairs_in_bioes(gold_seq_bio, char_seq, gold_seq_attr)
            preds_kvpair.append(pred_kvpair)
            golds_kvpair.append(gold_kvpair)

        if data_processor.end_flag:
            data_processor.refresh()
            break

        batches_sample += 1
        if (max_batches is not None) and (batches_sample >= max_batches):
            break

    p, r, f1 = cal_f1_score(preds_kvpair, golds_kvpair)
    logger.info("Valid Samples: {}".format(len(preds_kvpair)))
    logger.info("Valid P/R/F1: {} / {} / {}".format(round(p * 100, 2), round(r * 100, 2), round(f1 * 100, 2)))
    return p, r, f1


if __name__ == '__main__':

    ckpt_path = 'ckpts'
    if not os.path.exists(ckpt_path):
        os.makedirs(ckpt_path)

    # set logging
    log_file_path = os.path.join(ckpt_path, "train_multitask_lstm_run_log.txt")
    if os.path.exists(log_file_path):
        os.remove(log_file_path)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s | %(message)s", "%Y-%m-%d %H:%M:%S")
    chlr = logging.StreamHandler()
    chlr.setFormatter(formatter)
    fhlr = logging.FileHandler(log_file_path)
    fhlr.setFormatter(formatter)
    logger.addHandler(chlr)
    logger.addHandler(fhlr)

    logger.info("loading vocab...")
    w2i_char, i2w_char = load_vocabulary("data/vocab_char.txt")
    w2i_bio, i2w_bio = load_vocabulary("data/vocab_bio.txt")
    w2i_attr, i2w_attr = load_vocabulary("data/vocab_attr.txt")

    logger.info("loading data...")
    data_processor_train = DataProcessor("data/train/input.seq.char",
                                         "data/train/output.seq.bio",
                                         "data/train/output.seq.attr",
                                         w2i_char, w2i_bio, w2i_attr, shuffling=True)
    data_processor_valid = DataProcessor("data/test/input.seq.char",
                                         "data/test/output.seq.bio",
                                         "data/test/output.seq.attr",
                                         w2i_char, w2i_bio, w2i_attr, shuffling=True)

    logger.info("building model...")
    model = MyModel(embedding_dim=300, hidden_dim=300,
                    vocab_size_char=len(w2i_char), vocab_size_bio=len(w2i_bio),
                    vocab_size_attr=len(w2i_attr), O_tag_index=w2i_bio["O"], use_crf=False)

    logger.info("model params:")
    params_num_all = 0
    for variable in tf.trainable_variables():
        params_num = 1
        for dim in variable.shape:
            params_num *= dim
        params_num_all += params_num
        logger.info("\t {} {} {}".format(variable.name, variable.shape, params_num))
    logger.info("all params num: " + str(params_num_all))

    logger.info("start training...")
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(max_to_keep=10000)

        epoches = 0
        losses = []
        batches = 0
        best_f1 = 0
        batch_size = 32

        while epoches < 20:
            inputs_seq_batch, inputs_seq_len_batch, outputs_seq_bio_batch, \
                outputs_seq_attr_batch = data_processor_train.get_batch(batch_size)

            feed_dict = {model.inputs_seq: inputs_seq_batch,
                         model.inputs_seq_len: inputs_seq_len_batch,
                         model.outputs_seq_bio: outputs_seq_bio_batch,
                         model.outputs_seq_attr: outputs_seq_attr_batch}

            if batches == 0:
                logger.info("###### shape of a batch #######")
                logger.info("input_seq: " + str(inputs_seq_batch.shape))
                logger.info("input_seq_len: " + str(inputs_seq_len_batch.shape))
                logger.info("output_seq_bio: " + str(outputs_seq_bio_batch.shape))
                logger.info("output_seq_attr: " + str(outputs_seq_attr_batch.shape))
                logger.info("\n###### preview a sample #######")
                logger.info("input_seq:" + " ".join([i2w_char[i] for i in inputs_seq_batch[0]]))
                logger.info("input_seq_len :" + str(inputs_seq_len_batch[0]))
                logger.info("output_seq_bio: " + " ".join([i2w_bio[i] for i in outputs_seq_bio_batch[0]]))
                logger.info("output_seq_attr: " + " ".join([i2w_attr[i] for i in outputs_seq_attr_batch[0]]))

            loss, _ = sess.run([model.loss, model.train_op], feed_dict)
            losses.append(loss)
            batches += 1

            if data_processor_train.end_flag:
                data_processor_train.refresh()
                epoches += 1

            # Every 100 batches: log the running loss, save a checkpoint, and validate.
            if batches % 100 == 0:
                logger.info("")
                logger.info("Epoches: {}".format(epoches))
                logger.info("Batches: {}".format(batches))
                ave_loss = sum(losses) / len(losses) if len(losses) != 0 else 0.0
                logger.info("Loss: {}".format(ave_loss))
                losses = []

                ckpt_save_path = os.path.join(ckpt_path, "train_multitask_lstm-model-{}_{}".format(batches, ave_loss))
                logger.info("Path of ckpt: {}".format(ckpt_save_path))
                saver.save(sess, ckpt_save_path)

                p, r, f1 = valid(data_processor_valid, max_batches=10)
                if f1 > best_f1:
                    best_f1 = f1
                    logger.info("############# best performance now here ###############")
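For completeness, a hedged inference sketch: it assumes the graph, vocabularies, and variables from the script above are already built, plus one of the checkpoints the script saves. The checkpoint name, the example text, and the <UNK> fallback are assumptions about the data files, not details taken from the repository.

# Inference sketch (illustrative only): restore a saved checkpoint and tag one sequence.
import numpy as np

text = "some input sentence"                                      # replace with a real character sequence
ids = [w2i_char.get(c, w2i_char.get("<UNK>", 0)) for c in text]   # assumes an <UNK>-style entry; falls back to id 0

with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "ckpts/train_multitask_lstm-model-100_0.5")  # hypothetical checkpoint name
    preds_bio, preds_attr = sess.run(model.outputs,
                                     feed_dict={model.inputs_seq: np.array([ids], dtype=np.int32),
                                                model.inputs_seq_len: np.array([len(ids)], dtype=np.int32)})
    print([i2w_bio[i] for i in preds_bio[0]])    # predicted BIO tags
    print([i2w_attr[i] for i in preds_attr[0]])  # predicted attribute labels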