Smartpooling #7

Open · wants to merge 5 commits into master
Changes from 1 commit
63 changes: 61 additions & 2 deletions fairseq/models/wav2vec/wav2vec2_scribblelens.py
@@ -298,6 +298,17 @@ def add_args(parser):
"--conv-bias", action="store_true", help="include bias in conv encoder"
)

parser.add_argument(
"--smartpooling", action="store_true", help="whether to perform smartpooling"
)

parser.add_argument(
"--smartpooling-factor",
type=float,
default=3,
help="factor by which the sequence's length will be reduced in smartpooling"
)

def __init__(self, args):
super().__init__()
self.args = args
@@ -312,6 +323,9 @@ def __init__(self, args):
conv_bias=args.conv_bias,
)

self.smartpooling = args.smartpooling
self.smartpooling_factor = args.smartpooling_factor
self.smartpooling_filters = torch.tensor([[[[-1,1],[1,-1]]]]).float()  # 2x2 mixed-difference kernel used by warp() to turn cumulative lengths into interpolation weights
self.post_extract_proj = (
nn.Linear(self.embed, args.encoder_embed_dim)
if self.embed != args.encoder_embed_dim and not args.quantize_input
@@ -403,6 +417,11 @@ def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
return state_dict

def to(self, *args, **kwargs):
self = super().to(*args, **kwargs)
self.smartpooling_filters = self.smartpooling_filters.to(*args, **kwargs)
janchorowski marked this conversation as resolved.
return self

@classmethod
def build_model(cls, args, task=None):
"""Build a new model instance."""
@@ -525,6 +544,43 @@ def compute_preds(self, x, y, negatives):

return logits

def smartpool(self, features, padding_mask=None):
features_tmp = F.pad(features, (0, 0, 1, 0))  # prepend a zero frame so every frame has a predecessor
new_lens = (features_tmp[:, 1:, :] - features_tmp[:, :-1, :]).abs().sum(dim=2)  # per-frame novelty: L1 distance to the previous frame
new_lens = new_lens / new_lens.sum(1, keepdim=True) * (features.size(1) / self.smartpooling_factor)  # rescale so the lengths sum to T / smartpooling_factor

features, interp_weights = self.warp(features, new_lens)
if padding_mask is not None:
Review comment (Member): The padding mask marks which columns lie outside the data, so we should not shorten the padding itself; what needs to be done is to shorten the unpadded parts, then check which sequence is longest and pad the rest up to that length.

padding_mask = padding_mask.unsqueeze(2)
padding_mask = interp_weights @ padding_mask.float()
padding_mask = (padding_mask > 0).squeeze(2)

return features, padding_mask
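
The review comment above suggests warping only the unpadded part of each sequence and then re-padding everything to the longest warped length. A minimal sketch of that idea, assuming a boolean padding_mask (True where padded) and reusing the smartpool method above; smartpool_respecting_padding is a hypothetical helper, not part of this diff:

import torch

def smartpool_respecting_padding(model, features, padding_mask):
    # Warp only the real (unpadded) prefix of every sequence, one at a time.
    lengths = (~padding_mask).sum(dim=1)
    pooled = [model.smartpool(feats[: int(T)].unsqueeze(0))[0].squeeze(0)
              for feats, T in zip(features, lengths)]
    # Re-pad all warped sequences to the new maximum length.
    new_max = max(p.size(0) for p in pooled)
    out = features.new_zeros(features.size(0), new_max, features.size(2))
    new_padding = torch.ones(features.size(0), new_max,
                             dtype=torch.bool, device=features.device)
    for i, p in enumerate(pooled):
        out[i, : p.size(0)] = p
        new_padding[i, : p.size(0)] = False  # False marks real frames
    return out, new_padding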

def warp(self, X, new_lens):
new_lens_cs = new_lens.cumsum(1)
# This really searches for the low boundary of each new pixel
pixel_contributions = new_lens_cs.view(1, -1, 1) - torch.arange(torch.round(new_lens_cs[0, -1]).item(), device=X.device).view(1, 1, -1)
pixel_contributions = pixel_contributions.view(X.size(0), X.size(1), pixel_contributions.size(2))
# Zero out the negative contributions, i.e. pixels which come before each row
pixel_contributions = torch.max(torch.tensor(0.0, device=X.device), pixel_contributions)

# pixel_contributions now holds the cumulative pixel lengths for every pixel in each row

pixel_contributions = pixel_contributions.unsqueeze(1)
interp_weights = F.conv2d(pixel_contributions, self.smartpooling_filters, padding=1)
interp_weights = interp_weights[:,:,:-1,1:] # Removing padding
interp_weights = interp_weights.squeeze(1)

# Each column of interp_weights corresponds to a new element; its values are the weights given to the original frames.

interp_weights = interp_weights.transpose(1, 2)
Xnew = interp_weights @ X
return Xnew, interp_weights
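
For intuition, after the final transpose interp_weights has one row per output frame and one column per input frame: input frame t owns a segment of length new_lens[t] on the warped axis, output frame j covers the unit interval [j, j+1), and the weight is the overlap of the two segments. Below is a small standalone sketch of that overlap computation (an illustrative loop version, not the conv-based code above) with an assumed new_lens = [0.5, 1.5, 1.0]:

import torch

def overlap_weights(new_lens):
    # Boundaries of the segment owned by each input frame on the warped axis.
    cs = torch.cat([torch.zeros(1), new_lens.cumsum(0)])
    T_new = int(torch.round(cs[-1]).item())
    W = torch.zeros(T_new, len(new_lens))
    for j in range(T_new):              # output frame j covers [j, j+1)
        for t in range(len(new_lens)):  # input frame t covers [cs[t], cs[t+1])
            lo = max(float(j), cs[t].item())
            hi = min(float(j + 1), cs[t + 1].item())
            W[j, t] = max(0.0, hi - lo)
    return W

print(overlap_weights(torch.tensor([0.5, 1.5, 1.0])))
# tensor([[0.5000, 0.5000, 0.0000],
#         [0.0000, 1.0000, 0.0000],
#         [0.0000, 0.0000, 1.0000]])

Here the first output frame mixes input frames 0 and 1 equally while the other two copy single frames; the conv2d with the [[-1, 1], [1, -1]] kernel computes the same overlaps, up to rounding of the last boundary.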

def forward(self, source, padding_mask=None, mask=True, features_only=False):
# padding_mask = None # JCh: padding_mask prob need to be True where the data is padded. mask=True => data invalid

@@ -541,8 +597,7 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):

features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()


if padding_mask is not None:
assert padding_mask.size(1) == 1
padding_mask = padding_mask.squeeze(1)
@@ -552,6 +607,10 @@
padding_mask = padding_mask[:, ::scale]
assert np.all(padding_mask.shape == features.shape[:-1])

if self.smartpooling:
features, padding_mask = self.smartpool(features, padding_mask=padding_mask)
unmasked_features = features.clone()

if self.post_extract_proj is not None:
features = self.post_extract_proj(features)

30 changes: 30 additions & 0 deletions uwr_related/experiments/jdzikowski/scrib.sh
@@ -0,0 +1,30 @@
# [!] needs to be run from the fairseq main folder
RUN="smartpooling"
RUNDIR="/pio/scratch/1/i273233/runs"
mkdir -p $RUNDIR/$RUN

python train.py --distributed-world-size 1 --update-freq 2 \
/pio/scratch/1/i283340/MGR/NewSetup/DistSup/data `#path to Scribblelens data folder` \
--vocab-path ./fairseq/data/handwriting/tasman.alphabet.plus.space.mode5.json `#alphabet file` \
--save-dir $RUNDIR/$RUN --num-workers 0 \
`#--restore-file $RUNDIR/$RUN/before.pt` \
--keep-last-epochs 3 \
--tensorboard-logdir $RUNDIR/$RUN --log-format simple \
--task scribblelens --criterion wav2vec --arch wav2vec2_scribblelens \
--valid-subset test --pad-to-multiples-of 4 `#--max-sample-size 256` \
--log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \
--conv-feature-layers '[(64, (3, 3), (1, 2), (1, 1)), (128, (5, 5), (2, 2), (2, 2)), (256, (3,3), (1, 1), (1, 1)), (256, (3,3), (1, 2), (1, 1)), (512, (3,3), (1, 1), (1, 1)), (512, (3,3), (1, 2), (1, 1)), (512, (3,2), (2, 1), (1, 0))]' \
--final-dim 256 \
--latent-vars 320 --latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce \
--optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay \
--total-num-update 400000 --lr 0.0005 --warmup-updates 32000 \
--mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \
--encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \
--loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 \
--num-negatives 100 --cross-sample-negatives 0 \
`#--max-sample-size 250000 --min-sample-size 32000` \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 10000 --max-update 400000 \
--skip-invalid-size-inputs-valid-test --ddp-backend no_c10d \
--labels `#can be removed for no labels` \
--enable-padding `# crashes without this; needed to make all lines the same size` \
--smartpooling --smartpooling-factor 3.0
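
Note: with --smartpooling-factor 3.0, smartpool rescales the per-frame lengths so they sum to T / 3, so a feature map of, say, 600 encoder frames is warped to roughly round(600 / 3) = 200 frames before the transformer; the 600-frame figure is only an illustration, the reduction ratio is what the flag controls.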