Cliff investigation #5

Open · wants to merge 3 commits into base: master
37 changes: 36 additions & 1 deletion fairseq/criterions/wav2vec_criterion.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import math
import logging

import torch
import torch.nn.functional as F
@@ -12,6 +13,7 @@
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.logging.meters import safe_round

logger = logging.getLogger(__name__)

@register_criterion('wav2vec')
class Wav2vecCriterion(FairseqCriterion):
@@ -41,9 +43,19 @@ def forward(self, model, sample, reduce=True, log_pred=False):
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
#torch.set_printoptions(profile="full")
#logger.info("{}".format(sample['net_input']['source']))
#torch.set_printoptions(profile="default")

net_output = model(**sample['net_input'])
logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output)

torch.set_printoptions(profile="full")
logger.info("logits\n{}".format(logits))
minusinf = (logits == float("-inf")).sum(-1)
logger.info("minus infs\n{} / {}".format(minusinf, logits.size(-1)))
torch.set_printoptions(profile="default")

weights = None
if hasattr(model, 'get_target_weights') and not self.infonce:
@@ -55,6 +67,7 @@ def forward(self, model, sample, reduce=True, log_pred=False):

if self.infonce:
loss = F.cross_entropy(logits, target, reduction="sum" if reduce else "none",)
logger.info("cross entropy loss {}".format(loss))
else:
loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",)

@@ -75,6 +88,14 @@ def forward(self, model, sample, reduce=True, log_pred=False):
loss += p
losses.append(p)

#llll = loss.item() if reduce else loss
#if llll / sample_size >= 3.0:
# import ptvsd
# ptvsd.enable_attach(('0.0.0.0', 7310))
# print("Attach debugger now")
# ptvsd.wait_for_attach()
# logger.info("Loss per sample {} >= 3.0!\n".format(llll))

logging_output = {
'loss': loss.item() if reduce else loss,
'ntokens': sample_size,
@@ -105,6 +126,8 @@ def forward(self, model, sample, reduce=True, log_pred=False):

logging_output["correct"] = corr
logging_output["count"] = count
logging_output["num_correct"] = net_output["num_correct"]
logging_output["num_all"] = net_output["num_all"]

if log_pred:
logging_output['logits'] = logits.cpu().numpy()
@@ -129,6 +152,12 @@ def reduce_metrics(logging_outputs) -> None:
total = sum(log.get("count", 0) for log in logging_outputs)
metrics.log_scalar("_total", total)

num_correct = sum(log.get("num_correct", 0) for log in logging_outputs)
metrics.log_scalar("num_correct", num_correct)

num_all = sum(log.get("num_all", 0) for log in logging_outputs)
metrics.log_scalar("num_all", num_all)


if total > 0:
metrics.log_derived(
@@ -137,8 +166,14 @@ def reduce_metrics(logging_outputs) -> None:
if meters["_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"accuracy_2",
lambda meters: safe_round(meters["num_correct"].sum / meters["num_all"].sum, 5)
if meters["num_all"].sum > 0
else float("nan"),
)

builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'}
builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count', 'num_correct', 'num_all'}

for k in logging_outputs[0]:
if k not in builtin_keys:
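A note on the -inf counting added to forward() above: in the upstream wav2vec2 model this criterion pairs with, compute_preds assigns a score of float("-inf") to any sampled negative that turns out to be identical to the positive, so the per-row count logged here should correspond to the number of duplicate negatives at each masked timestep. A minimal sketch of the same counting idiom on a toy tensor (values are made up for illustration):

import torch

# toy logits: two masked timesteps, each with 1 positive + 4 negative scores
logits = torch.tensor([[0.9, 0.1, float("-inf"), 0.2, float("-inf")],
                       [0.8, 0.3, 0.1, 0.0, 0.2]])
minusinf = (logits == float("-inf")).sum(-1)    # duplicate negatives per row
print(minusinf.tolist(), "/", logits.size(-1))  # [2, 0] / 5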
52 changes: 50 additions & 2 deletions fairseq/models/wav2vec/wav2vec2_scribblelens.py
@@ -29,6 +29,8 @@
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import buffered_arange

logger = logging.getLogger(__name__)

@register_model("wav2vec2_scribblelens")
class Wav2Vec2ModelSL(BaseFairseqModel):
@staticmethod
@@ -502,6 +504,11 @@ def sample_negatives(self, y, num):
if self.cross_sample_negatives > 0 and self.n_negatives > 0:
neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

torch.set_printoptions(profile="full")
logger.info("neg_idxs:\n{}".format(neg_idxs))
#logger.info("neg_idxs unique:\n{}".format(torch.unique(neg_idxs, sorted=False).size()))
torch.set_printoptions(profile="default")

negs = y[neg_idxs.view(-1)]
negs = negs.view(
bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
@@ -511,8 +518,15 @@ def sample_negatives(self, y, num):
return negs, neg_idxs

def compute_preds(self, x, y, negatives):

neg_is_pos = (y == negatives).all(-1)
#torch.set_printoptions(profile="full")
#logger.info("y:\n{}".format(y))
#logger.info("negatives:\n{}".format(negatives))
#logger.info("neg_is_pos:\n{}".format(neg_is_pos))
#torch.set_printoptions(profile="default")


y = y.unsqueeze(0)
targets = torch.cat([y, negatives], dim=0)

@@ -557,6 +571,10 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):

features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)

#torch.set_printoptions(profile="full")
#logger.info("unmasked_features:\n{}".format(unmasked_features))
#torch.set_printoptions(profile="default")

num_vars = None
code_ppl = None
@@ -594,24 +612,32 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
return {"x": x, "padding_mask": padding_mask}

if self.quantizer:
q = self.quantizer(y, produce_targets=False)
q = self.quantizer(y, produce_targets=True)
y = q["x"]
num_vars = q["num_vars"]
code_ppl = q["code_perplexity"]
prob_ppl = q["prob_perplexity"]
curr_temp = q["temp"]
targets = q["targets"]

torch.set_printoptions(profile="full")
logger.info("quantizer targets:\n{}".format(targets))
torch.set_printoptions(profile="default")

y = self.project_q(y)

if self.negatives_from_everywhere:
logger.info("negatives_from_everywhere")
neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False)
negs, _ = self.sample_negatives(neg_cands, y.size(1))
negs = self.project_q(negs)

else:
logger.info("negatives_from_everywhere else block")
negs, _ = self.sample_negatives(y, y.size(1))

if self.codebook_negatives > 0:
logger.info("codebook_negatives > 0")
cb_negs = self.quantizer.sample_from_codebook(
y.size(0) * y.size(1), self.codebook_negatives
)
@@ -630,15 +656,22 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
negs, _ = self.sample_negatives(y, y.size(1))

x = x[mask_indices].view(x.size(0), -1, x.size(-1))

#torch.set_printoptions(profile="full")
#logger.info("mask_indices:\n{}".format(mask_indices))
#torch.set_printoptions(profile="default")

if self.target_glu:
y = self.target_glu(y)
negs = self.target_glu(negs)

x = self.final_proj(x)
num_correct, num_all = self.miara_acc(x, y)
x = self.compute_preds(x, y, negs)

result = {"x": x, "padding_mask": padding_mask, "features_pen": features_pen}
result["num_correct"] = num_correct
result["num_all"] = num_all

if prob_ppl is not None:
result["prob_perplexity"] = prob_ppl
@@ -648,6 +681,21 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):

return result

# x - encoder vectors
# y - quantizer vectors
def miara_acc(self, x, y):
x_size1 = x.size(1)
xx = x.repeat_interleave(x.size(1), 1) # BxTxC -> BxT^2xC
yy = y.repeat(1, y.size(1), 1) # BxTxC -> BxT^2xC
cos = torch.cosine_similarity(xx.float(), yy.float(), dim=-1).view(-1, x.size(1), x.size(1)) # BxT^2 -> BxTxT
maxi, _ = cos.max(dim=2, keepdim=True) # BxTx1
maxi = maxi * torch.eye(x.size(1), device=x.device).expand(cos.size(0), -1, -1) # BxTxT

num_correct = (cos == maxi).sum().item()
num_all = x.size(0) * x.size(1)
logger.info("{} / {}".format(num_correct, num_all))
return num_correct, num_all

def quantize(self, x):
assert self.quantizer is not None
x = self.feature_extractor(x)
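For reference, a minimal standalone sketch of what miara_acc above measures: for every masked timestep it asks whether the quantized target at the same position has the highest cosine similarity to the encoder output, which is what the eye-masked max comparison in the PR amounts to (up to exact-zero ties). Shapes and names below are illustrative only:

import torch

B, T, C = 2, 4, 8
x = torch.randn(B, T, C)  # encoder outputs at the masked positions
y = torch.randn(B, T, C)  # projected quantizer targets

xx = x.repeat_interleave(T, dim=1)  # BxTxC -> BxT^2xC
yy = y.repeat(1, T, 1)              # BxTxC -> BxT^2xC
# cos[b, i, j] = cosine similarity between x[b, i] and y[b, j]
cos = torch.cosine_similarity(xx, yy, dim=-1).view(B, T, T)

# a timestep counts as correct when its own target attains the row maximum,
# i.e. the diagonal entry is the argmax of its row
num_correct = (cos.argmax(dim=2) == torch.arange(T)).sum().item()
num_all = B * T
print(num_correct, "/", num_all)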
12 changes: 12 additions & 0 deletions fairseq/modules/gumbel_vector_quantizer.py
@@ -3,10 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)

class GumbelVectorQuantizer(nn.Module):
def __init__(
@@ -153,6 +155,12 @@ def forward(self, x, produce_targets=False):
.view(bsz * tsz, self.groups, -1)
)
hard_probs = torch.mean(hard_x.float(), dim=0)

torch.set_printoptions(profile="full")
logger.info("hard_probs:\n{}".format(hard_probs))
#logger.info("hard_probs unique:\n{}".format(torch.unique(hard_probs, sorted=False).size()))
torch.set_printoptions(profile="default")

result["code_perplexity"] = torch.exp(
-torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
).sum()
@@ -164,6 +172,10 @@ def forward(self, x, produce_targets=False):
-torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
).sum()

#torch.set_printoptions(profile="full")
#logger.info("avg_probs:\n{}".format(avg_probs))
#torch.set_printoptions(profile="default")

result["temp"] = self.curr_temp

if self.training:
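The hard_probs dump added above is the input to the code_perplexity term visible in the surrounding context: perplexity = exp(-sum p * log p) per group, summed over groups, and it can be read as the effective number of codebook entries in use. A per-group value near 1 means the quantizer has collapsed onto a single code, while a value near num_vars means the codes are used uniformly. A tiny illustration with made-up distributions:

import torch

def group_perplexity(hard_probs):
    # hard_probs: (groups, num_vars), relative usage frequency of each code
    return torch.exp(-torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1))

uniform = torch.full((1, 320), 1.0 / 320)  # every code used equally often
collapsed = torch.zeros(1, 320)
collapsed[0, 0] = 1.0                      # a single code absorbs everything
print(group_perplexity(uniform))    # ~320
print(group_perplexity(collapsed))  # ~1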
1 change: 1 addition & 0 deletions fairseq/trainer.py
@@ -468,6 +468,7 @@ def maybe_no_sync():
try:
with maybe_no_sync():
# forward and backward
logger.info("Batch {}/52".format(i))
loss, sample_size_i, logging_output = self.task.train_step(
sample=sample,
model=self.model,
10 changes: 10 additions & 0 deletions fairseq_cli/train.py
@@ -195,6 +195,16 @@ def train(args, trainer, task, epoch_itr):
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
)

#torch.set_printoptions(profile="full")
for name, p in trainer.model.named_parameters():
if "quantizer.vars" in name or "quantizer.weight_proj.weight" in name or "project_q.weight" in name:
torch.set_printoptions(profile="full")
logger.info("{}\n{}".format(name, p.data))
torch.set_printoptions(profile="default")
else:
logger.info("{}\n{}".format(name, p.data))
#torch.set_printoptions(profile="default")

trainer.begin_epoch(epoch_itr.epoch)

valid_losses = [None]
8 changes: 4 additions & 4 deletions train.py
@@ -11,8 +11,8 @@


if __name__ == '__main__':
# import ptvsd
# ptvsd.enable_attach(('0.0.0.0', 7321))
# print("Attach debugger now")
# ptvsd.wait_for_attach()
#import ptvsd
#ptvsd.enable_attach(('0.0.0.0', 7310))
#print("Attach debugger now")
#ptvsd.wait_for_attach()
cli_main()
39 changes: 36 additions & 3 deletions uwr_related/experiments/jch/scrib.sh
100644 → 100755
@@ -1,8 +1,39 @@
#python train.py --distributed-world-size 1 --update-freq 2 \
# /pio/scratch/2/jch/wav2vec/data/scribblelens \
# --save-dir /pio/lscratch/1/jch/fairseq/try_sl2 --num-workers 0 \
# --keep-last-epochs 3 \
# --tensorboard-logdir /pio/scratch/2/jch/wav2vec/runs/try_sl2 --log-format simple \
# --task scribblelens --criterion wav2vec --arch wav2vec2_scribblelens \
# --valid-subset test --pad-to-multiples-of 4 `#--max-sample-size 256` \
# --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \
# --conv-feature-layers '[(64, (3, 3), (1, 2), (1, 1)), (128, (5, 5), (2, 2), (2, 2)), (256, (3,3), (1, 1), (1, 1)), (256, (3,3), (1, 2), (1, 1)), (512, (3,3), (1, 1), (1, 1)), (512, (3,3), (1, 2), (1, 1)), (512, (3,2), (2, 1), (1, 0))]' \
# --final-dim 256 \
# --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce \
# --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay \
# --total-num-update 400000 --lr 0.0005 --warmup-updates 32000 \
# --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \
# --encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \
# --loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 \
# --num-negatives 100 --cross-sample-negatives 0 \
# `#--max-sample-size 250000 --min-sample-size 32000` \
# --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 10000 --max-update 400000 \
# --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d \
# --enable-padding # crashes without that, needs to make all lines same-size

RUN=$1
NUM=${RUN:5:3}
NUMVER=${RUN#"debug"}
echo $RUN $NUM $NUMVER
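# Illustration (hypothetical run name, not taken from this PR): invoking the script as
#   ./scrib.sh debug042_v1
# gives NUM=042 (${RUN:5:3} takes three characters starting at index 5) and
# NUMVER=042_v1 (${RUN#"debug"} strips the leading "debug" prefix); NUM picks which
# try_sl3 checkpoint gets symlinked as before.pt, NUMVER only names the log file.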

mkdir -p /pio/scratch/1/i273233/runs/$RUN
ln -s /pio/scratch/1/i273233/runs/try_sl3/checkpoint$NUM.pt /pio/scratch/1/i273233/runs/$RUN/before.pt

python train.py --distributed-world-size 1 --update-freq 2 \
/pio/scratch/2/jch/wav2vec/data/scribblelens \
--save-dir /pio/lscratch/1/jch/fairseq/try_sl2 --num-workers 0 \
--save-dir /pio/scratch/1/i273233/runs/$RUN --num-workers 0 \
--keep-last-epochs 3 \
--tensorboard-logdir /pio/scratch/2/jch/wav2vec/runs/try_sl2 --log-format simple \
--restore-file /pio/scratch/1/i273233/runs/$RUN/before.pt \
--tensorboard-logdir /pio/scratch/1/i273233/runs/$RUN --log-format simple \
--task scribblelens --criterion wav2vec --arch wav2vec2_scribblelens \
--valid-subset test --pad-to-multiples-of 4 `#--max-sample-size 256` \
--log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \
@@ -18,4 +49,6 @@ python train.py --distributed-world-size 1 --update-freq 2 \
`#--max-sample-size 250000 --min-sample-size 32000` \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 10000 --max-update 400000 \
--skip-invalid-size-inputs-valid-test --ddp-backend no_c10d \
--enable-padding # crashes without that, needs to make all lines same-size
--enable-padding \
`# crashes without that, needs to make all lines same-size` \
> ../logfile$NUMVER.txt