Update dependencies (#1351)
* versions

* black
zphang authored Jun 23, 2022
1 parent f9e0e7c commit e6d9062
Showing 58 changed files with 553 additions and 185 deletions.
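The two commit-message bullets above summarize the diff shown here: "versions" points at version bumps (the .circleci/config.yml hunk below moves the test image from python:3.7 to python:3.8), and "black" points at the formatter pass that produces every other hunk, since recent black releases honor the "magic trailing comma" and explode any call ending in a trailing comma onto one argument per line. A minimal sketch of that behavior follows; it is not part of the commit, and the sample snippet and the line_length=100 setting are illustrative assumptions, since the repository's black configuration is not shown in this diff.

import black

# A call written with a trailing comma before the closing parenthesis.
src = "f(\n    task_name, task_config_path, task.name,\n)\n"

# The magic trailing comma forces one argument per line, which is the
# pattern repeated throughout the hunks below.
print(black.format_str(src, mode=black.Mode(line_length=100)))
# Expected output:
# f(
#     task_name,
#     task_config_path,
#     task.name,
# )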
.circleci/config.yml (2 changes: 1 addition & 1 deletion)
@@ -6,7 +6,7 @@ orbs:
 jobs:
   test:
     docker:
-      - image: python:3.7
+      - image: python:3.8
     steps:
       - checkout
      - restore_cache:
jiant/proj/main/components/container_setup.py (9 changes: 7 additions & 2 deletions)
@@ -76,7 +76,11 @@ def create_task_dict(task_config_dict: dict, verbose: bool = True) -> Dict[str,
         if not task.name == task_name:
             warnings.warn(
                 "task {} from {} has conflicting names: {}/{}. Using {}".format(
-                    task_name, task_config_path, task_name, task.name, task_name,
+                    task_name,
+                    task_config_path,
+                    task_name,
+                    task.name,
+                    task_name,
                 )
             )
             task.name = task_name
@@ -205,7 +209,8 @@ def create_jiant_task_container(
     task_run_config = TaskRunConfig.from_dict(task_run_config)

     num_train_examples_dict = get_num_train_examples(
-        task_cache_dict=task_cache_dict, train_task_list=task_run_config.train_task_list,
+        task_cache_dict=task_cache_dict,
+        train_task_list=task_run_config.train_task_list,
     )
     task_sampler = jiant_task_sampler.create_task_sampler(
         sampler_config=sampler_config,
jiant/proj/main/components/evaluate.py (3 changes: 2 additions & 1 deletion)
@@ -10,7 +10,8 @@
 def write_val_results(val_results_dict, metrics_aggregator, output_dir, verbose=True):
     full_results_to_write = {
         "aggregated": jiant_task_sampler.compute_aggregate_major_metrics_from_results_dict(
-            metrics_aggregator=metrics_aggregator, results_dict=val_results_dict,
+            metrics_aggregator=metrics_aggregator,
+            results_dict=val_results_dict,
         ),
     }
     for task_name, task_results in val_results_dict.items():
jiant/proj/main/components/task_sampler.py (7 changes: 5 additions & 2 deletions)
@@ -137,7 +137,8 @@ def get_task_p(self, steps=None) -> np.ndarray:

         for i, task_name in enumerate(self.task_names):
             p_ls[i] = numexpr.evaluate(
-                self.task_to_unnormalized_prob_funcs_dict[task_name], local_dict={"t": t},
+                self.task_to_unnormalized_prob_funcs_dict[task_name],
+                local_dict={"t": t},
             )
         p_ls /= p_ls.sum()
         return p_ls
@@ -171,7 +172,9 @@ def create_task_sampler(
     elif sampler_type == "ProportionalMultiTaskSampler":
         assert len(sampler_config) == 1
         return ProportionalMultiTaskSampler(
-            task_dict=task_dict, rng=rng, task_to_num_examples_dict=task_to_num_examples_dict,
+            task_dict=task_dict,
+            rng=rng,
+            task_to_num_examples_dict=task_to_num_examples_dict,
         )
     elif sampler_type == "SpecifiedProbMultiTaskSampler":
         assert len(sampler_config) == 2
jiant/proj/main/components/write_configs.py (3 changes: 2 additions & 1 deletion)
@@ -20,7 +20,8 @@ def write_configs(config_dict, base_path):
         assert os.path.exists(path)
     for config_key in config_keys:
         py_io.write_json(
-            config_dict[config_key], os.path.join(base_path, f"{config_key}.json"),
+            config_dict[config_key],
+            os.path.join(base_path, f"{config_key}.json"),
         )
     py_io.write_json(config_dict, os.path.join(base_path, "full.json"))
     py_io.write_json(
jiant/proj/main/export_model.py (3 changes: 2 additions & 1 deletion)
@@ -14,7 +14,8 @@ class RunConfiguration(zconf.RunConfig):


 def export_model(
-    hf_pretrained_model_name_or_path: str, output_base_path: str,
+    hf_pretrained_model_name_or_path: str,
+    output_base_path: str,
 ):
     """Retrieve model and tokenizer from Transformers and save all necessary data
     Things saved:
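For context only, not part of the diff: the signature reflowed above is jiant's model-export entry point, which retrieves the model and tokenizer via Transformers and writes them under output_base_path. A minimal usage sketch, with the model name and output directory as illustrative placeholders:

from jiant.proj.main.export_model import export_model

# Retrieves the pretrained model and tokenizer through Transformers and
# saves them under output_base_path in the layout jiant expects.
export_model(
    hf_pretrained_model_name_or_path="roberta-base",  # placeholder model name
    output_base_path="./models/roberta-base",  # placeholder output directory
)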
jiant/proj/main/metarunner.py (10 changes: 7 additions & 3 deletions)
@@ -25,7 +25,9 @@ class ValState(ExtendedDataClassMixin):
     def new(self):
         # noinspection PyArgumentList
         return self.__class__(
-            score=self.score, metrics=self.metrics, train_state=self.train_state.new(),
+            score=self.score,
+            metrics=self.metrics,
+            train_state=self.train_state.new(),
         )

     def to_dict(self):
@@ -104,7 +106,8 @@ def yield_train_step(self):
             train_iterator = self.runner.run_train_context(verbose=self.verbose)
         else:
             train_iterator = self.runner.resume_train_context(
-                train_state=self.train_state, verbose=self.verbose,
+                train_state=self.train_state,
+                verbose=self.verbose,
             )
         for train_state in train_iterator:
             self.train_state = train_state
@@ -242,7 +245,8 @@ def eval_save(self):
             self.log_writer.write_entry("train_val_best", self.best_val_state.to_dict())
             del self.best_state_dict
             self.best_state_dict = copy_state_dict(
-                state_dict=get_model_for_saving(self.model).state_dict(), target_device=CPU_DEVICE,
+                state_dict=get_model_for_saving(self.model).state_dict(),
+                target_device=CPU_DEVICE,
             )
             if self.save_best_model:
                 self.save_best_model_with_metadata(val_metrics_dict=val_metrics_dict)
jiant/proj/main/modeling/model_setup.py (18 changes: 13 additions & 5 deletions)
@@ -105,14 +105,18 @@ def delegate_load(jiant_model, weights_dict: dict, load_mode: str):
     """
     if load_mode == "from_transformers":
         return load_encoder_from_transformers_weights(
-            encoder=jiant_model.encoder, weights_dict=weights_dict,
+            encoder=jiant_model.encoder,
+            weights_dict=weights_dict,
         )
     elif load_mode == "from_transformers_with_mlm":
         remainder = load_encoder_from_transformers_weights(
-            encoder=jiant_model.encoder, weights_dict=weights_dict, return_remainder=True,
+            encoder=jiant_model.encoder,
+            weights_dict=weights_dict,
+            return_remainder=True,
         )
         load_lm_heads_from_transformers_weights(
-            jiant_model=jiant_model, weights_dict=remainder,
+            jiant_model=jiant_model,
+            weights_dict=remainder,
         )
         return
     elif load_mode == "all":
@@ -121,11 +125,15 @@ def delegate_load(jiant_model, weights_dict: dict, load_mode: str):
         return load_encoder_only(jiant_model=jiant_model, weights_dict=weights_dict)
     elif load_mode == "partial_weights":
         return load_partial_heads(
-            jiant_model=jiant_model, weights_dict=weights_dict, allow_missing_head_weights=True,
+            jiant_model=jiant_model,
+            weights_dict=weights_dict,
+            allow_missing_head_weights=True,
         )
     elif load_mode == "partial_heads":
         return load_partial_heads(
-            jiant_model=jiant_model, weights_dict=weights_dict, allow_missing_head_model=True,
+            jiant_model=jiant_model,
+            weights_dict=weights_dict,
+            allow_missing_head_model=True,
         )
     elif load_mode == "partial":
         return load_partial_heads(
jiant/proj/main/modeling/primary.py (8 changes: 6 additions & 2 deletions)
@@ -75,7 +75,9 @@ def forward(self, batch: BatchMixin, task: Task, compute_loss: bool = False):
         taskmodel_key = self.task_to_taskmodel_map[task_name]
         taskmodel = self.taskmodels_dict[taskmodel_key]
         return taskmodel(
-            batch=batch, tokenizer=self.tokenizer, compute_loss=compute_loss,
+            batch=batch,
+            tokenizer=self.tokenizer,
+            compute_loss=compute_loss,
         ).to_dict()


@@ -105,7 +107,9 @@ def wrap_jiant_forward(
     is_multi_gpu = isinstance(jiant_model, nn.DataParallel)
     model_output = construct_output_from_dict(
         jiant_model(
-            batch=batch.to_dict() if is_multi_gpu else batch, task=task, compute_loss=compute_loss,
+            batch=batch.to_dict() if is_multi_gpu else batch,
+            task=task,
+            compute_loss=compute_loss,
         )
     )
     if is_multi_gpu and compute_loss:
jiant/proj/main/modeling/taskmodels.py (55 changes: 42 additions & 13 deletions)
@@ -79,12 +79,17 @@ def __init__(self, task, encoder, head: heads.ClassificationHead, **kwargs):

     def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = self.encoder.encode(
-            input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask,
+            input_ids=batch.input_ids,
+            segment_ids=batch.segment_ids,
+            input_mask=batch.input_mask,
         )
         logits = self.head(pooled=encoder_output.pooled)
         if compute_loss:
             loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),)
+            loss = loss_fct(
+                logits.view(-1, self.head.num_labels),
+                batch.label_id.view(-1),
+            )
             return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
         else:
             return LogitsOutput(logits=logits, other=encoder_output.other)
@@ -97,7 +102,9 @@ def __init__(self, task, encoder, head: heads.RegressionHead, **kwargs):

     def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = self.encoder.encode(
-            input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask,
+            input_ids=batch.input_ids,
+            segment_ids=batch.segment_ids,
+            input_mask=batch.input_mask,
         )
         # TODO: Abuse of notation - these aren't really logits (issue #1187)
         logits = self.head(pooled=encoder_output.pooled)
@@ -133,7 +140,10 @@ def forward(self, batch, tokenizer, compute_loss: bool = False):
            for j in range(len(encoder_output_other_ls[0])):
                reshaped_outputs.append(
                    [
-                        torch.stack([misc[j][layer_i] for misc in encoder_output_other_ls], dim=1,)
+                        torch.stack(
+                            [misc[j][layer_i] for misc in encoder_output_other_ls],
+                            dim=1,
+                        )
                        for layer_i in range(len(encoder_output_other_ls[0][0]))
                    ]
                )
@@ -168,12 +178,17 @@ def forward(self, batch, tokenizer, compute_loss: bool = False):
             TYPE: Description
         """
         encoder_output = self.encoder.encode(
-            input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask,
+            input_ids=batch.input_ids,
+            segment_ids=batch.segment_ids,
+            input_mask=batch.input_mask,
         )
         logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans)
         if compute_loss:
             loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),)
+            loss = loss_fct(
+                logits.view(-1, self.head.num_labels),
+                batch.label_id.view(-1),
+            )
             return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
         else:
             return LogitsOutput(logits=logits, other=encoder_output.other)
@@ -190,7 +205,9 @@ def __init__(self, task, encoder, head: heads.TokenClassificationHead, **kwargs)

     def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = self.encoder.encode(
-            input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask,
+            input_ids=batch.input_ids,
+            segment_ids=batch.segment_ids,
+            input_mask=batch.input_mask,
         )
         logits = self.head(unpooled=encoder_output.unpooled)
         # Ensure logits in valid range is at least self.offset_margin higher than others
@@ -199,7 +216,8 @@ def forward(self, batch, tokenizer, compute_loss: bool = False):
         if compute_loss:
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(
-                logits.transpose(dim0=1, dim1=2).flatten(end_dim=1), batch.gt_span_idxs.flatten(),
+                logits.transpose(dim0=1, dim1=2).flatten(end_dim=1),
+                batch.gt_span_idxs.flatten(),
             )
             return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
         else:
@@ -213,12 +231,17 @@ def __init__(self, task, encoder, head: heads.SpanComparisonHead, **kwargs):

     def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = self.encoder.encode(
-            input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask,
+            input_ids=batch.input_ids,
+            segment_ids=batch.segment_ids,
+            input_mask=batch.input_mask,
         )
         logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans)
         if compute_loss:
             loss_fct = nn.BCEWithLogitsLoss()
-            loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_ids.float(),)
+            loss = loss_fct(
+                logits.view(-1, self.head.num_labels),
+                batch.label_ids.float(),
+            )
             return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
         else:
             return LogitsOutput(logits=logits, other=encoder_output.other)
@@ -233,7 +256,9 @@ def __init__(self, task, encoder, head: heads.TokenClassificationHead, **kwargs)

     def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = self.encoder.encode(
-            input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask,
+            input_ids=batch.input_ids,
+            segment_ids=batch.segment_ids,
+            input_mask=batch.input_mask,
         )
         logits = self.head(unpooled=encoder_output.unpooled)
         if compute_loss:
@@ -254,7 +279,9 @@ def __init__(self, task, encoder, head: heads.QAHead, **kwargs):

     def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = self.encoder.encode(
-            input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask,
+            input_ids=batch.input_ids,
+            segment_ids=batch.segment_ids,
+            input_mask=batch.input_mask,
         )
         logits = self.head(unpooled=encoder_output.unpooled)
         if compute_loss:
@@ -300,7 +327,9 @@ def __init__(self, task, encoder, head: heads.AbstractPoolerHead, **kwargs):

     def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = self.encoder.encode(
-            input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask,
+            input_ids=batch.input_ids,
+            segment_ids=batch.segment_ids,
+            input_mask=batch.input_mask,
         )

         # A tuple of layers of hidden states
jiant/proj/main/preprocessing.py (8 changes: 6 additions & 2 deletions)
@@ -51,7 +51,9 @@ def smart_truncate(dataset: torch_utils.ListDataset, max_seq_length: int, verbos
     for datum in maybe_tqdm(dataset.data, desc="Smart truncate data", verbose=verbose):
         new_datum_ls.append(
             smart_truncate_datum(
-                datum=datum, max_seq_length=max_seq_length, max_valid_length=max_valid_length,
+                datum=datum,
+                max_seq_length=max_seq_length,
+                max_valid_length=max_valid_length,
             )
         )
     new_dataset = torch_utils.ListDataset(new_datum_ls)
@@ -70,7 +72,9 @@ def smart_truncate_cache(
         for datum in maybe_tqdm(chunk, desc="Smart truncate chunk-datum", verbose=verbose):
             new_chunk.append(
                 smart_truncate_datum(
-                    datum=datum, max_seq_length=max_seq_length, max_valid_length=max_valid_length,
+                    datum=datum,
+                    max_seq_length=max_seq_length,
+                    max_valid_length=max_valid_length,
                 )
             )
         torch.save(new_chunk, cache.get_chunk_path(chunk_i))