chore: normalize code style and update notebook
Sieluna committed Feb 7, 2025
1 parent 58ee020 commit 4bb2456
Showing 11 changed files with 50 additions and 113 deletions.
1 change: 1 addition & 0 deletions notebooks/training.ipynb
@@ -305,6 +305,7 @@
 "encoder_depth: 4\n",
 "eos_token: 2\n",
 "epochs: 50\n",
+"export_onnx: true\n",
 "gamma: 0.9995\n",
 "heads: 8\n",
 "id: null\n",
8 changes: 4 additions & 4 deletions packages/phoebe.js
@@ -12,23 +12,23 @@ function createVenv() {
 }
 
 function activeVenv() {
-  const activatePath = path.resolve(VENV_PATH, `bin/${process.platform === "win32" ? "Activate.ps1" : "activate"}`);
+  const activatePath = path.resolve(VENV_PATH, process.platform === "win32" ? "Scripts/Activate.ps1" : "bin/activate");
 
   if (process.platform === "win32") {
-    execSync(activatePath, { stdio: "inherit" });
+    execSync(`powershell.exe "${activatePath}"`, { stdio: "inherit" });
   } else {
     execSync(`bash -c 'source "${activatePath}"'`, { stdio: "inherit" });
   }
 }
 
 function install() {
-  const pipPath = path.resolve(VENV_PATH, `bin/${process.platform === "win32" ? "pip.exe" : "pip"}`);
+  const pipPath = path.resolve(VENV_PATH, process.platform === "win32" ? "Scripts/pip.exe" : "bin/pip");
 
   execSync(`"${pipPath}" install -e "${WORKSPACE_PATH}"`, { stdio: "inherit" });
 }
 
 function execute(pkg) {
-  const pythonPath = path.resolve(VENV_PATH, `bin/${process.platform === "win32" ? "python.exe" : "python"}`);
+  const pythonPath = path.resolve(VENV_PATH, process.platform === "win32" ? "Scripts/python.exe" : "bin/python");
   const pkgPath = path.resolve(WORKSPACE_PATH, pkg);
 
   execSync(`"${pythonPath}" "${pkgPath}"`, { stdio: "inherit" });
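Note on the change above: Windows virtual environments place their executables and activation scripts under `Scripts/` (with `.exe`/`.ps1` suffixes), while POSIX venvs use `bin/`; the old code always looked under `bin/`, so activation and tool lookup broke on Windows. A minimal sketch of the same lookup in Python (a hypothetical helper, not part of this repo):

```python
import os
import sys

def venv_bin(venv_path: str, name: str) -> str:
    """Resolve an executable inside a virtualenv across platforms.

    Windows venvs use Scripts/ and an .exe suffix; POSIX venvs use bin/.
    """
    if sys.platform == "win32":
        return os.path.join(venv_path, "Scripts", f"{name}.exe")
    return os.path.join(venv_path, "bin", name)

# venv_bin(".venv", "pip") -> ".venv/bin/pip" or ".venv\Scripts\pip.exe"
```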
2 changes: 1 addition & 1 deletion packages/pyproject.toml
@@ -37,4 +37,4 @@ dependencies = [
 ]
 
 [tool.setuptools]
-packages = ["train", "collector"]
\ No newline at end of file
+packages = ["train", "collector"]
16 changes: 8 additions & 8 deletions packages/train/__main__.py
@@ -36,7 +36,7 @@ def train(args):
         model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
 
     def save_models(e, step=0):
-        torch.save(model.state_dict(), os.path.join(out_path, f'{args.name}_e{e+1:02d}_step{step:02d}.pth'))
+        torch.save(model.state_dict(), os.path.join(out_path, f'{args.name}_e{e + 1:02d}_step{step:02d}.pth'))
         yaml.dump(dict(args), open(os.path.join(out_path, 'config.yaml'), 'w+'))
         if args.export_onnx:
             onnx_path = os.path.join(out_path, f'{args.name}_e{e + 1:02d}_step{step:02d}.onnx')
@@ -49,7 +49,7 @@ def save_models(e, step=0):
                 (dummy_img, dummy_tgt_seq),
                 onnx_path,
                 export_params=True,
-                opset_version=12,
+                opset_version=14,
                 do_constant_folding=True,
                 input_names=["input", "tgt_seq"],
                 output_names=["output"],
@@ -71,8 +71,8 @@ def save_models(e, step=0):
                     opt.zero_grad()
                     total_loss = 0
                     for j in range(0, len(im), microbatch):
-                        tgt_seq, tgt_mask = seq['input_ids'][j:j+microbatch].to(device), seq['attention_mask'][j:j+microbatch].bool().to(device)
-                        loss = model.data_parallel(im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
+                        tgt_seq, tgt_mask = (seq['input_ids'][j:j + microbatch].to(device), seq['attention_mask'][j:j + microbatch].bool().to(device))
+                        loss = model.data_parallel(im[j:j + microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask) * microbatch / args.batchsize
                         loss.backward() # data parallism loss is a vector
                         total_loss += loss.item()
                     torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
@@ -81,15 +81,15 @@
                     dset.set_description('Loss: %.4f' % total_loss)
                     if args.wandb:
                         wandb.log({'train/loss': total_loss})
-                if (i+1+len(dataloader)*e) % args.sample_freq == 0:
-                    bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
+                if (i + 1 + len(dataloader) * e) % args.sample_freq == 0:
+                    bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches * e / args.epochs), name='val')
                     if bleu_score > max_bleu and token_accuracy > max_token_acc:
                         max_bleu, max_token_acc = bleu_score, token_accuracy
                         save_models(e, step=i)
-            if (e+1) % args.save_freq == 0:
+            if (e + 1) % args.save_freq == 0:
                 save_models(e, step=len(dataloader))
             if args.wandb:
-                wandb.log({'train/epoch': e+1})
+                wandb.log({'train/epoch': e + 1})
     except KeyboardInterrupt:
         if e >= 2:
             save_models(e, step=i)
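Besides spacing normalization, the export opset moves from 12 to 14. Opset 14 is the first to ship ops such as Trilu, which the triangular causal masks in transformer decoders commonly lower to, so the bump is likely what lets the decoder export cleanly (an inference on my part; the commit message does not say). A sketch of the export call pattern visible in the hunk, with illustrative dummy shapes; the diff truncates after `output_names`, so the `dynamic_axes` below are an assumption:

```python
import torch

def export_model(model, onnx_path: str, device: str = "cpu"):
    # Dummy inputs trace the graph; shapes here are illustrative only.
    dummy_img = torch.randn(1, 1, 192, 672, device=device)              # (B, C, H, W)
    dummy_tgt_seq = torch.zeros(1, 1, dtype=torch.long, device=device)  # a lone BOS token
    torch.onnx.export(
        model,
        (dummy_img, dummy_tgt_seq),
        onnx_path,
        export_params=True,
        opset_version=14,        # bumped from 12 in this commit
        do_constant_folding=True,
        input_names=["input", "tgt_seq"],
        output_names=["output"],
        # Assumed: keep batch, image size and sequence length dynamic.
        dynamic_axes={
            "input": {0: "batch", 2: "height", 3: "width"},
            "tgt_seq": {0: "batch", 1: "seq"},
            "output": {0: "batch", 1: "seq"},
        },
    )
```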
10 changes: 6 additions & 4 deletions packages/train/dataset/dataset.py
@@ -35,7 +35,8 @@ class Im2LatexDataset:
     data = defaultdict(lambda: [])
 
     def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, batchsize=16, max_seq_len=1024,
-                 max_dimensions=(1024, 512), min_dimensions=(32, 32), pad=False, keep_smaller_batches=False, test=False):
+                 max_dimensions=(1024, 512), min_dimensions=(32, 32), pad=False, keep_smaller_batches=False,
+                 test=False):
         """Generates a torch dataset from pairs of `equations` and `images`.
 
         Args:
@@ -91,7 +92,7 @@ def __iter__(self):
             info = np.array(self.data[k], dtype=object)
             p = torch.randperm(len(info)) if self.shuffle else torch.arange(len(info))
             for i in range(0, len(info), self.batchsize):
-                batch = info[p[i:i+self.batchsize]]
+                batch = info[p[i:i + self.batchsize]]
                 if len(batch.shape) == 1:
                     batch = batch[None, :]
                 if len(batch) < self.batchsize and not self.keep_smaller_batches:
@@ -108,7 +109,7 @@ def __next__(self):
         if self.i >= self.size:
             raise StopIteration
         self.i += 1
-        return self.prepare_data(self.pairs[self.i-1])
+        return self.prepare_data(self.pairs[self.i - 1])
 
     def prepare_data(self, batch):
         """loads images into memory
@@ -147,7 +148,7 @@ def prepare_data(self, batch):
             return None, None
         if self.pad:
            h, w = images.shape[2:]
-            images = F.pad(images, (0, self.max_dimensions[0]-w, 0, self.max_dimensions[1]-h), value=1)
+            images = F.pad(images, (0, self.max_dimensions[0] - w, 0, self.max_dimensions[1] - h), value=1)
         return tok, images
 
     def _get_size(self):
@@ -232,6 +233,7 @@ def generate_tokenizer(equations, output, vocab_size):
 
 if __name__ == '__main__':
     import argparse
+
    parser = argparse.ArgumentParser(description='Train model', add_help=False)
    parser.add_argument('-i', '--images', type=str, nargs='+', default=None, help='Image folders')
    parser.add_argument('-e', '--equations', type=str, nargs='+', default=None, help='equations text files')
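For orientation: `__iter__` walks same-size buckets of samples, permutes each bucket when shuffling, and slices the permutation into fixed-size batches, dropping remainders unless `keep_smaller_batches` is set; batching within a bucket keeps every image in a batch the same shape. A simplified, self-contained sketch of that bucketing (the dict layout is assumed to mirror the class's `data` defaultdict):

```python
import numpy as np
import torch

def iter_batches(data_by_size: dict, batchsize: int, shuffle: bool = True,
                 keep_smaller_batches: bool = False):
    """Yield batches of samples bucketed by image size.

    data_by_size maps (width, height) -> list of (equation, image_path) pairs.
    """
    for size, samples in data_by_size.items():
        info = np.array(samples, dtype=object)
        p = torch.randperm(len(info)) if shuffle else torch.arange(len(info))
        for i in range(0, len(info), batchsize):
            batch = info[p[i:i + batchsize]]
            if len(batch.shape) == 1:        # a single leftover sample
                batch = batch[None, :]
            if len(batch) < batchsize and not keep_smaller_batches:
                continue
            yield batch
```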
2 changes: 1 addition & 1 deletion packages/train/dataset/transforms.py
@@ -27,4 +27,4 @@
         # alb.Sharpen()
         ToTensorV2(),
     ]
-)
\ No newline at end of file
+)
20 changes: 9 additions & 11 deletions packages/train/models/hybrid.py
@@ -1,5 +1,4 @@
 import torch
-import torch.nn as nn
 
 from einops import repeat
 from timm.models.layers import StdConv2dSame
@@ -20,11 +19,11 @@ def forward_features(self, x):
 
         cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
         x = torch.cat((cls_tokens, x), dim=1)
-        h, w = h//self.patch_size, w//self.patch_size
-        pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
-        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
+        h, w = h // self.patch_size, w // self.patch_size
+        pos_emb_ind = repeat(torch.arange(h) * (self.width // self.patch_size - w), 'h -> (h w)', w=w) + torch.arange(h * w)
+        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
         x += self.pos_embed[:, pos_emb_ind]
-        #x = x + self.pos_embed
+        # x = x + self.pos_embed
         x = self.pos_drop(x)
 
         for blk in self.blocks:
@@ -35,15 +34,14 @@ def forward_features(self, x):
 
 
 def get_encoder(args):
-    backbone = ResNetV2(
-        layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,
-        preact=False, stem_type='same', conv_layer=StdConv2dSame)
-    min_patch_size = 2**(len(args.backbone_layers)+1)
+    backbone = ResNetV2(layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,
+                        preact=False, stem_type='same', conv_layer=StdConv2dSame)
+    min_patch_size = 2 ** (len(args.backbone_layers) + 1)
 
     def embed_layer(**x):
         ps = x.pop('patch_size', min_patch_size)
         assert ps % min_patch_size == 0 and ps >= min_patch_size, 'patch_size needs to be multiple of %i with current backbone configuration' % min_patch_size
-        return HybridEmbed(**x, patch_size=ps//min_patch_size, backbone=backbone)
+        return HybridEmbed(**x, patch_size=ps // min_patch_size, backbone=backbone)
 
     encoder = CustomVisionTransformer(
         img_size=(args.max_height, args.max_width),
@@ -55,4 +53,4 @@ def embed_layer(**x):
         num_heads=args.heads,
         embed_layer=embed_layer
     )
-    return encoder
\ No newline at end of file
+    return encoder
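The `pos_emb_ind` lines reindex position embeddings trained on a full-width patch grid for a narrower crop: with `W = self.width // self.patch_size`, the patch at row r, column c of an h-by-w grid needs table entry r·W + c, and r·(W - w) + (r·w + c) equals exactly that; index 0 stays reserved for the cls token, hence the +1 shift. A small sketch checking the identity:

```python
import torch
from einops import repeat

def pos_indices(h: int, w: int, grid_width: int) -> torch.Tensor:
    """Position-table indices for an h-by-w patch grid inside a wider grid.

    For the patch at row r, col c: r*(grid_width - w) + (r*w + c) == r*grid_width + c.
    Index 0 is the cls token's slot, hence the +1 shift.
    """
    ind = repeat(torch.arange(h) * (grid_width - w), 'h -> (h w)', w=w) + torch.arange(h * w)
    return torch.cat((torch.zeros(1), ind + 1), dim=0).long()

# Sanity check against the direct r*grid_width + c formula:
h, w, W = 3, 4, 8
expected = torch.tensor([r * W + c for r in range(h) for c in range(w)]) + 1
assert torch.equal(pos_indices(h, w, W)[1:], expected)
```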
56 changes: 2 additions & 54 deletions packages/train/models/transformer.py
@@ -1,60 +1,8 @@
-import torch
-import torch.nn.functional as F
-
 from x_transformers import TransformerWrapper, Decoder
-from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_k, top_p
-
-
-class CustomARWrapper(AutoregressiveWrapper):
-    def __init__(self, *args, **kwargs):
-        super(CustomARWrapper, self).__init__(*args, **kwargs)
-
-    @torch.no_grad()
-    def generate(self, start_tokens, seq_len=256, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
-        device = start_tokens.device
-        was_training = self.net.training
-        num_dims = len(start_tokens.shape)
-
-        if num_dims == 1:
-            start_tokens = start_tokens[None, :]
-
-        b, t = start_tokens.shape
-
-        self.net.eval()
-        out = start_tokens
-        mask = kwargs.pop('mask', None)
-        if mask is None:
-            mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)
-
-        for _ in range(seq_len):
-            x = out[:, -self.max_seq_len:]
-            mask = mask[:, -self.max_seq_len:]
-            # print('arw:',out.shape)
-            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
-
-            if filter_logits_fn in {top_k, top_p}:
-                filtered_logits = filter_logits_fn(logits, thres=filter_thres)
-                probs = F.softmax(filtered_logits / temperature, dim=-1)
-
-            sample = torch.multinomial(probs, 1)
-
-            out = torch.cat((out, sample), dim=-1)
-            mask = F.pad(mask, (0, 1), value=True)
-
-            if eos_token is not None and (torch.cumsum(out == eos_token, 1)[:, -1] >= 1).all():
-                break
-
-        out = out[:, t:]
-
-        if num_dims == 1:
-            out = out.squeeze(0)
-
-        self.net.train(was_training)
-        return out
-
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
 def get_decoder(args):
-    return CustomARWrapper(
+    return AutoregressiveWrapper(
         TransformerWrapper(
             num_tokens=args.num_tokens,
             max_seq_len=args.max_seq_len,
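The deleted `CustomARWrapper.generate` was a standard autoregressive sampling loop: crop the running sequence to the context window, take the last position's logits, filter them, sample with a temperature, append, and stop once every row has produced `eos_token`. The stock `AutoregressiveWrapper` that replaces it provides the same behaviour. A minimal sketch of that loop in plain torch (`net(x)` is assumed to return `(batch, time, vocab)` logits; a fixed top-k stands in for x_transformers' fraction-based `top_k` filter):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_autoregressive(net, start_tokens, seq_len=256, eos_token=None,
                          temperature=1.0, k=20):
    """Top-k temperature sampling, as the removed wrapper implemented it."""
    out = start_tokens
    prompt_len = start_tokens.shape[1]
    for _ in range(seq_len):
        logits = net(out)[:, -1, :]                       # next-token logits
        vals, idx = logits.topk(k, dim=-1)
        filtered = torch.full_like(logits, float('-inf'))
        filtered.scatter_(-1, idx, vals)                  # keep only the top k
        probs = F.softmax(filtered / temperature, dim=-1)
        sample = torch.multinomial(probs, 1)
        out = torch.cat((out, sample), dim=-1)
        if eos_token is not None and (out == eos_token).any(dim=-1).all():
            break
    return out[:, prompt_len:]                            # strip the prompt
```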
7 changes: 3 additions & 4 deletions packages/train/models/utils.py
@@ -26,14 +26,13 @@ def data_parallel(self, x: torch.Tensor, device_ids, output_device=None, **kwarg
         outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
         return nn.parallel.gather(outputs, output_device).mean()
 
-    def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
+    def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
         encoded = self.encoder(x)
-        out = self.decoder(tgt_seq, context=encoded, **kwargs)
-        return out
+        return self.decoder(tgt_seq, context=encoded, **kwargs)
 
     @torch.no_grad()
     def generate(self, x: torch.Tensor, temperature: float = 0.25):
-        return self.decoder.generate((torch.LongTensor([self.args.bos_token]*len(x))[:, None]).to(x.device), self.args.max_seq_len,
+        return self.decoder.generate((torch.LongTensor([self.args.bos_token] * len(x))[:, None]).to(x.device), self.args.max_seq_len,
                                      eos_token=self.args.eos_token, context=self.encoder(x), temperature=temperature)
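The `data_parallel` method whose tail is shown above follows torch's manual replicate/scatter/apply/gather pattern and then averages the per-device losses, which is why `__main__.py` notes that the data-parallel loss "is a vector" before it is reduced. A hedged reconstruction of that pattern as a free function:

```python
import torch.nn as nn

def data_parallel(module, x, device_ids, output_device=None, **kwargs):
    """Replicate the module, scatter inputs, run in parallel, gather and average."""
    if not device_ids:
        return module(x, **kwargs)
    if output_device is None:
        output_device = device_ids[0]
    replicas = nn.parallel.replicate(module, device_ids)
    inputs = nn.parallel.scatter(x, device_ids)        # slices along dim 0
    kwargs = nn.parallel.scatter(kwargs, device_ids)   # one kwargs dict per device
    replicas = replicas[:len(inputs)]                  # batch may not fill all devices
    outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
    return nn.parallel.gather(outputs, output_device).mean()
```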
25 changes: 8 additions & 17 deletions packages/train/models/vit.py
@@ -6,23 +6,12 @@
 
 
 class ViTransformerWrapper(nn.Module):
-    def __init__(
-        self,
-        *,
-        max_width,
-        max_height,
-        patch_size,
-        attn_layers,
-        channels=1,
-        num_classes=None,
-        dropout=0.,
-        emb_dropout=0.
-    ):
+    def __init__(self, *, max_width, max_height, patch_size, attn_layers, channels=1, num_classes=None, dropout=0., emb_dropout=0.):
         super().__init__()
         assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
         assert max_width % patch_size == 0 and max_height % patch_size == 0, 'image dimensions must be divisible by the patch size'
         dim = attn_layers.dim
-        num_patches = (max_width // patch_size)*(max_height // patch_size)
+        num_patches = (max_width // patch_size) * (max_height // patch_size)
         patch_dim = channels * patch_size ** 2
 
         self.patch_size = patch_size
@@ -36,7 +25,7 @@ def __init__(
 
         self.attn_layers = attn_layers
         self.norm = nn.LayerNorm(dim)
-        #self.mlp_head = FeedForward(dim, dim_out = num_classes, dropout = dropout) if exists(num_classes) else None
+        # self.mlp_head = FeedForward(dim, dim_out = num_classes, dropout = dropout) if exists(num_classes) else None
 
     def forward(self, img, **kwargs):
         p = self.patch_size
@@ -47,9 +36,11 @@ def forward(self, img, **kwargs):
 
         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
         x = torch.cat((cls_tokens, x), dim=1)
-        h, w = torch.tensor(img.shape[2:])//p
-        pos_emb_ind = repeat(torch.arange(h)*(self.max_width//p-w), 'h -> (h w)', w=w)+torch.arange(h*w)
-        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
+
+        h, w = torch.tensor(img.shape[2:]) // p
+        pos_emb_ind = repeat(torch.arange(h) * (self.max_width // p - w), 'h -> (h w)', w=w) + torch.arange(h * w)
+        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
+
         x += self.pos_embedding[:, pos_emb_ind]
         x = self.dropout(x)
16 changes: 7 additions & 9 deletions packages/train/utils/utils.py
@@ -23,8 +23,6 @@ def __init__(self, *args, **kwargs):
     def step(self, *args, **kwargs):
         pass
 
-# helper functions from lucidrains
-
 
 def exists(val):
     return val is not None
@@ -115,23 +113,23 @@ def pad(img: Image, divable: int = 32) -> Image:
     if data[..., -1].var() == 0:
         data = (data[..., 0]).astype(np.uint8)
     else:
-        data = (255-data[..., -1]).astype(np.uint8)
-    data = (data-data.min())/(data.max()-data.min())*255
+        data = (255 - data[..., -1]).astype(np.uint8)
+    data = (data - data.min()) / (data.max() - data.min()) * 255
     if data.mean() > threshold:
         # To invert the text to white
-        gray = 255*(data < threshold).astype(np.uint8)
+        gray = 255 * (data < threshold).astype(np.uint8)
     else:
-        gray = 255*(data > threshold).astype(np.uint8)
-        data = 255-data
+        gray = 255 * (data > threshold).astype(np.uint8)
+        data = 255 - data
 
     coords = cv2.findNonZero(gray) # Find all non-zero points (text)
     a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
-    rect = data[b:b+h, a:a+w]
+    rect = data[b:b + h, a:a + w]
     im = Image.fromarray(rect).convert('L')
     dims = []
     for x in [w, h]:
         div, mod = divmod(x, divable)
-        dims.append(divable*(div + (1 if mod > 0 else 0)))
+        dims.append(divable * (div + (1 if mod > 0 else 0)))
     padded = Image.new('L', dims, 255)
     padded.paste(im, (0, 0, im.size[0], im.size[1]))
     return padded
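`pad()` crops the image to the text's bounding box, then rounds each side up to the next multiple of `divable` (default 32) so the encoder's downsampling stride divides the input evenly, pasting the crop onto a white canvas of that size. The round-up step in isolation:

```python
def round_up(x: int, divable: int = 32) -> int:
    """Smallest multiple of divable that is >= x, as pad() computes dims."""
    div, mod = divmod(x, divable)
    return divable * (div + (1 if mod > 0 else 0))

assert round_up(65) == 96 and round_up(64) == 64 and round_up(1) == 32
```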
