Skip to content

Commit

Permalink
add context and word state loss
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhuang1024 committed Aug 29, 2023
1 parent 472682d commit 187524c
Show file tree
Hide file tree
Showing 15 changed files with 1,155 additions and 342 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ logs
datasets

.DS_Store
*.log
.vscode
logs_v2
heatmap
184_checkpoints
.ipynb_checkpoints
674 changes: 674 additions & 0 deletions LICENSE

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# SAM
[Semantic Graph Representation Learning for Handwritten Mathematical Expression Recognition (ICDAR 2023)](https://link.springer.com/chapter/10.1007/978-3-031-41676-7_9)

The code will be released after the ICDAR2023 meeting!

32 changes: 21 additions & 11 deletions config.yaml → config_v2.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
# 实验名称
experiment: "v1_l2-loss" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_ori_train_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1_ddp" # "ori_CAN_in_full_crohme"
experiment: "can-l2-context-word" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_ori_train_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1_ddp" # "ori_CAN_in_full_crohme"

sim_loss:
type: l2
use_flag: True
context_loss: False
word_state_loss: False


counting_decoder:
use_flag: True
in_channel: 684
out_channel: 111

# 随机种子
seed: 20211024

# 训练参数
epochs: 200
batch_size: 8 # 8
workers: 0 # 0
workers: 5 # 0
train_parts: 1
valid_parts: 1
valid_start: 20 # 1000000000
valid_start: 100 # 1000000000
save_start: 0 # 220

optimizer: Adadelta
Expand All @@ -22,9 +34,9 @@ eps: 1e-6
weight_decay: 1e-4
beta: 0.9

output_counting_feature: False
output_channel_attn_feature: False
counting_loss_ratio: 1
# output_counting_feature: False
# output_channel_attn_feature: False
# counting_loss_ratio: 1

dropout: True
dropout_ratio: 0.5
Expand Down Expand Up @@ -69,7 +81,7 @@ encoder:
out_channel: 684

decoder:
net: Decoder_v1
net: Decoder_v3
cell: 'GRU'
input_size: 256
hidden_size: 256
Expand All @@ -79,15 +91,13 @@ attention:
attention_dim: 512
word_conv_kernel: 1

sim_loss:
type: l2

whiten_type: None
max_step: 256

optimizer_save: False
finetune: False
finetune: True
checkpoint_dir: 'checkpoints'
checkpoint: ""
log_dir: 'logs'
log_dir: 'logs_v2'

22 changes: 22 additions & 0 deletions counting_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os

def gen_counting_label(labels, channel, tag):
b, t = labels.size()
counting_labels = torch.zeros((b, channel))
if tag:
ignore = [0, 1, 107, 108, 109, 110]
else:
ignore = []
for i in range(b):
for j in range(t):
k = labels[i][j]
if k in ignore:
continue
else:
counting_labels[i][k] += 1
return counting_labels.detach()
32 changes: 30 additions & 2 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import pickle as pkl
from torch.utils.data import DataLoader, Dataset, RandomSampler, DistributedSampler
from counting_utils import gen_counting_label


class HMERDataset(Dataset):
Expand Down Expand Up @@ -33,6 +34,10 @@ def __init__(self, params, image_path, label_path, words, is_train=True, use_aug
self.reverse_color = self.params['data_process']['reverse_color'] if 'data_process' in params else False
self.equal_range = self.params['data_process']['equal_range'] if 'data_process' in params else False

with open(self.params['matrix_path'], 'rb') as f:
matrix = pkl.load(f)
self.matrix = torch.Tensor(matrix)

def __len__(self):
# assert len(self.images) == len(self.labels)
return len(self.labels)
Expand All @@ -55,7 +60,28 @@ def __getitem__(self, idx):
words = self.words.encode(labels) + [0]
words = torch.LongTensor(words)
return image, words


def gen_matrix(self, labels):
(B, L), device = labels.shape, labels.device
matrix = []
for i in range(B):
_L = []
label = labels[i]
for x in range(L):
_T = []
for y in range(L):
if x == y:
_T.append(1.)
else:
if label[x] == label[y] or label[x] == 0 or label[y] == 0:
_T.append(0.)
else:
_T.append(self.matrix[label[x], label[y]])
_L.append(_T)
matrix.append(_L)
matrix = torch.tensor(matrix).to(device)
return matrix.detach()

def collate_fn(self, batch_images):
max_width, max_height, max_length = 0, 0, 0
batch, channel = len(batch_images), batch_images[0][0].shape[0]
Expand All @@ -80,7 +106,9 @@ def collate_fn(self, batch_images):
l = proper_items[i][1].shape[0]
labels[i][:l] = proper_items[i][1]
labels_masks[i][:l] = 1
return images, image_masks, labels, labels_masks
matrix = self.gen_matrix(labels)
counting_labels = gen_counting_label(labels, self.params['counting_decoder']['out_channel'], True)
return images, image_masks, labels, labels_masks, matrix, counting_labels


def get_crohme_dataset(params):
Expand Down
122 changes: 0 additions & 122 deletions inference.py

This file was deleted.

5 changes: 4 additions & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from models.decoder.decoder_v1 import Decoder_v1
from models.decoder.decoder_v1 import Decoder_v1
from models.decoder.decoder_v2 import Decoder_v1 as Decoder_v2
from models.decoder.decoder_v3 import Decoder_v1 as Decoder_v3
from models.counting import CountingDecoder as counting_decoder
62 changes: 51 additions & 11 deletions models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import models
from models.densenet import DenseNet

from einops.layers.torch import Rearrange
from traceback import print_exc

class Model(nn.Module):
def __init__(self, params=None):
def __init__(self, params={}):
super(Model, self).__init__()
self.params = params

Expand All @@ -21,11 +23,6 @@ def __init__(self, params=None):

self.use_label_mask = params['use_label_mask']
self.encoder = DenseNet(params=self.params)
# self.in_channel = params['counting_decoder']['in_channel']
# self.out_channel = params['counting_decoder']['out_channel']

# self.output_counting_feature = params['output_counting_feature'] if 'output_counting_feature' in params else False
# self.channel_attn_feature = params['output_channel_attn_feature'] if 'output_channel_attn_feature' in params else False

self.decoder = getattr(models, params['decoder']['net'])(params=self.params)
self.cross = nn.CrossEntropyLoss(reduction='none') if self.use_label_mask else nn.CrossEntropyLoss()
Expand All @@ -35,19 +32,62 @@ def __init__(self, params=None):
"""经过cnn后 长宽与原始尺寸比缩小的比例"""
self.ratio = params['densenet']['ratio']

def forward(self, images, images_mask, labels, labels_mask, is_train=True):
if self.params['context_loss'] or self.params['word_state_loss']:
self.cma_context = nn.Sequential(
nn.Linear(params['encoder']['out_channel'], params['decoder']['input_size']),
Rearrange("b l h->b h l"),
nn.BatchNorm1d(params['decoder']['input_size']),
Rearrange("b h l->b l h"),
nn.ReLU()
)
self.cma_word = nn.Sequential(
nn.Linear(params['decoder']['input_size'], params['decoder']['input_size']),
Rearrange("b l h->b h l"),
nn.BatchNorm1d(params['decoder']['input_size']),
Rearrange("b h l->b l h"),
nn.ReLU()
)

def forward(self, images, images_mask, labels, labels_mask, matrix=None, counting_labels=None, is_train=True):
cnn_features = self.encoder(images)

word_probs, word_alphas, embedding = self.decoder(cnn_features, labels, images_mask, labels_mask, is_train=is_train)

word_probs, word_alphas, embedding = self.decoder(cnn_features, labels, images_mask, labels_mask, counting_labels=counting_labels, is_train=is_train)

context_loss, word_state_loss, word_sim_loss, counting_loss = 0, 0, 0, 0
embedding, word_context_vec_list, word_out_state_list, _, counting_loss = embedding
if self.params['context_loss'] or self.params['word_state_loss'] and is_train:
if 'context_loss' in self.params and self.params['context_loss']:
word_context_vec_list = torch.stack(word_context_vec_list, 1)
context_embedding = self.cma_context(word_context_vec_list)
context_loss = self.cal_cam_loss_v2(context_embedding, labels, matrix)
if 'word_state_loss' in self.params and self.params['word_state_loss']:
word_out_state_list = torch.stack(word_out_state_list, 1)
word_state_embedding = self.cma_word(word_out_state_list)
word_state_loss = self.cal_cam_loss_v2(word_state_embedding, labels, matrix)

word_loss = self.cross(word_probs.contiguous().view(-1, word_probs.shape[-1]), labels.view(-1))
word_average_loss = (word_loss * labels_mask.view(-1)).sum() / (labels_mask.sum() + 1e-10) if self.use_label_mask else word_loss

word_sim_loss = self.cal_word_similarity(embedding)
if 'sim_loss' in self.params and self.params['sim_loss']['use_flag']:
word_sim_loss = self.cal_word_similarity(embedding)

return word_probs, (word_average_loss, word_sim_loss)
return word_probs, (word_average_loss, word_sim_loss, context_loss, word_state_loss, counting_loss)


def cal_cam_loss_v2(self, word_embedding, labels, matrix):
(B, L, H), device = word_embedding.shape, word_embedding.device

W = torch.matmul(word_embedding, word_embedding.transpose(-1, -2)) # B L L
denom = torch.matmul(word_embedding.unsqueeze(-2), word_embedding.unsqueeze(-1)).squeeze(-1) ** (0.5)
# B L 1 H @ B L H 1 -> B L 1 1
cosine = W / (denom @ denom.transpose(-1, -2))
sim_mask = matrix != 0
if self.sim_loss_type == 'l1':
loss = abs((cosine - matrix) * sim_mask)
else:
loss = (cosine - matrix) ** 2 * sim_mask
return loss.sum() / B / (labels != 0).sum()

def cal_word_similarity(self, word_embedding):

num = word_embedding @ word_embedding.transpose(1,0)
Expand Down
Loading

0 comments on commit 187524c

Please sign in to comment.