Skip to content

Commit

Permalink
Fix xlm-r loading (#1329)
Browse files Browse the repository at this point in the history
  • Loading branch information
zphang authored Jul 26, 2021
1 parent 2666eb9 commit 51e9be2
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
2 changes: 1 addition & 1 deletion jiant/proj/main/modeling/model_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def load_encoder_from_transformers_weights(
remainder_weights_dict = {}
load_weights_dict = {}
model_arch = ModelArchitectures.from_model_type(model_type=encoder.config.model_type)
encoder_prefix = model_arch.value + "."
encoder_prefix = model_arch.get_encoder_prefix() + "."
# Encoder
for k, v in weights_dict.items():
if k.startswith(encoder_prefix):
Expand Down
18 changes: 9 additions & 9 deletions jiant/scripts/benchmarks/xtreme/subscripts/d_write_configs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,70 +8,70 @@
mkdir -p ${BASE_PATH}/runconfigs

# XNLI
python jiant/scripts/postproc/xtreme/xtreme_runconfig_writer.py \
python jiant/scripts/benchmarks/xtreme/xtreme_runconfig_writer.py \
--xtreme_task xnli \
--task_config_base_path ${BASE_PATH}/tasks/configs \
--task_cache_base_path ${BASE_PATH}/cache/${MODEL_TYPE} \
--epochs 2 --train_batch_size 4 --gradient_accumulation_steps 8 \
--output_path ${BASE_PATH}/runconfigs/xnli.json

# PAWS-X
python jiant/scripts/postproc/xtreme/xtreme_runconfig_writer.py \
python jiant/scripts/benchmarks/xtreme/xtreme_runconfig_writer.py \
--xtreme_task pawsx \
--task_config_base_path ${BASE_PATH}/tasks/configs \
--task_cache_base_path ${BASE_PATH}/cache/${MODEL_TYPE} \
--epochs 5 --train_batch_size 4 --gradient_accumulation_steps 8 \
--output_path ${BASE_PATH}/runconfigs/pawsx.json

# UDPOS
python jiant/scripts/postproc/xtreme/xtreme_runconfig_writer.py \
python jiant/scripts/benchmarks/xtreme/xtreme_runconfig_writer.py \
--xtreme_task udpos \
--task_config_base_path ${BASE_PATH}/tasks/configs \
--task_cache_base_path ${BASE_PATH}/cache/${MODEL_TYPE} \
--epochs 10 --train_batch_size 4 --gradient_accumulation_steps 8 \
--output_path ${BASE_PATH}/runconfigs/udpos.json

# PANX
python jiant/scripts/postproc/xtreme/xtreme_runconfig_writer.py \
python jiant/scripts/benchmarks/xtreme/xtreme_runconfig_writer.py \
--xtreme_task panx \
--task_config_base_path ${BASE_PATH}/tasks/configs \
--task_cache_base_path ${BASE_PATH}/cache/${MODEL_TYPE} \
--epochs 10 --train_batch_size 4 --gradient_accumulation_steps 8 \
--output_path ${BASE_PATH}/runconfigs/panx.json

# XQuAD
python jiant/scripts/postproc/xtreme/xtreme_runconfig_writer.py \
python jiant/scripts/benchmarks/xtreme/xtreme_runconfig_writer.py \
--xtreme_task xquad \
--task_config_base_path ${BASE_PATH}/tasks/configs \
--task_cache_base_path ${BASE_PATH}/cache/${MODEL_TYPE} \
--epochs 2 --train_batch_size 4 --gradient_accumulation_steps 4 \
--output_path ${BASE_PATH}/runconfigs/xquad.json

# MLQA
python jiant/scripts/postproc/xtreme/xtreme_runconfig_writer.py \
python jiant/scripts/benchmarks/xtreme/xtreme_runconfig_writer.py \
--xtreme_task mlqa \
--task_config_base_path ${BASE_PATH}/tasks/configs \
--task_cache_base_path ${BASE_PATH}/cache/${MODEL_TYPE} \
--epochs 2 --train_batch_size 4 --gradient_accumulation_steps 4 \
--output_path ${BASE_PATH}/runconfigs/mlqa.json

# TyDiQA
python jiant/scripts/postproc/xtreme/xtreme_runconfig_writer.py \
python jiant/scripts/benchmarks/xtreme/xtreme_runconfig_writer.py \
--xtreme_task tydiqa \
--task_config_base_path ${BASE_PATH}/tasks/configs \
--task_cache_base_path ${BASE_PATH}/cache/${MODEL_TYPE} \
--epochs 2 --train_batch_size 4 --gradient_accumulation_steps 4 \
--output_path ${BASE_PATH}/runconfigs/tydiqa.json

# Bucc2018
python jiant/scripts/postproc/xtreme/xtreme_runconfig_writer.py \
python jiant/scripts/benchmarks/xtreme/xtreme_runconfig_writer.py \
--xtreme_task bucc2018 \
--task_config_base_path ${BASE_PATH}/tasks/configs \
--task_cache_base_path ${BASE_PATH}/cache/${MODEL_TYPE} \
--output_path ${BASE_PATH}/runconfigs/bucc2018.json

# Tatoeba
python jiant/scripts/postproc/xtreme/xtreme_runconfig_writer.py \
python jiant/scripts/benchmarks/xtreme/xtreme_runconfig_writer.py \
--xtreme_task tatoeba \
--task_config_base_path ${BASE_PATH}/tasks/configs \
--task_cache_base_path ${BASE_PATH}/cache/${MODEL_TYPE} \
Expand Down
6 changes: 6 additions & 0 deletions jiant/shared/model_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ class ModelArchitectures(Enum):
def from_model_type(cls, model_type: str):
return cls(model_type)

def get_encoder_prefix(self):
    """Return the prefix string associated with this architecture's encoder.

    For every architecture except ``"xlm-roberta"`` the prefix is simply
    the enum value itself. ``"xlm-roberta"`` is special-cased to
    ``"roberta"``.

    NOTE(review): presumably this is because XLM-R checkpoints store their
    encoder weights under ``roberta.*`` keys -- confirm against the
    checkpoint's state-dict keys.

    Returns:
        str: the prefix, without a trailing ``"."`` (callers append it).
    """
    # XLM-R is the one architecture whose prefix differs from its value.
    if self.value == "xlm-roberta":
        return "roberta"
    return self.value


TOKENIZER_CLASS_DICT = BiDict(
{
Expand Down

0 comments on commit 51e9be2

Please sign in to comment.