
DeText Training Manual

Training data format and preparation

DeText uses TFRecords format for training data. In general, the input data should have:

  • One field for "query" with name query
  • One field for "wide features" with name wide_ftrs
  • Multiple fields for "document fields" with name doc_<field name>
  • One field for "labels" with name label
  • [optional] Multiple fields for "user fields" with name usr_<field name>
  • [optional] One field for "sparse wide features indices" with name wide_ftrs_sp_idx and one field for "sparse wide features values" with name wide_ftrs_sp_val

We show an example of the prepared training data and explain the data format and shapes; a minimal serialization sketch follows the list.

  • query (string list containing only 1 string)
    • For each training sample, there should be 1 query field.
    • e.g. ["how do you del ##ete messages"]
  • wide_ftrs (float list)
    • There could be multiple dense wide features for each document, so the dense wide features form a 2-D array with shape [#documents, #dense wide features per document]. Since TFRecords only support 1-D FloatList, we flatten the dense wide features during preparation and transform them back to grouped features in train/data_fn.py by reshaping. Therefore the float list of wide_ftrs in the training data has #documents * #dense wide features per document = 4 * 3 = 12 entries. The dense wide features belong to each document sequentially, i.e., the first 3 wide features belong to the first document, the next 3 to the second document, and so on.
    • [0.305 0.264 0.180 0.192 0.136 0.027 0.273 0.273 0.377 0.233 0.264 0.227]
  • wide_ftrs_sp_idx (int list)
    • There could be multiple sparse wide features for each document, so the sparse wide feature indices form a 2-D array with shape [#documents, #max num of sparse wide features among documents]. Since TFRecords only support 1-D IntList, we flatten the sparse wide feature indices during preparation and transform them back to grouped features in train/data_fn.py by reshaping. Therefore the int list of wide_ftrs_sp_idx in the training data has sum_i(#documents in list i * #max sparse wide features in list i) entries. Within the same list, if the number of sparse features of document m is smaller than the max number of sparse wide features in the list, the sparse feature indices must be padded with 0. The example below shows wide_ftrs_sp_idx for a list with 4 documents where the maximum number of sparse wide features is 2. The sparse wide features belong to each document sequentially, i.e., the first 2 indices belong to the first document, the next 2 to the second document, and so on. Note that 0 should NEVER be used for wide_ftrs_sp_idx except for padding.
    • [3 2 5000 20 1 0 8 0]
  • wide_ftrs_sp_val (float list)
    • Sparse wide feature values have the same shape as, and must correspond one-to-one to, the sparse wide feature indices. I.e., if the sparse feature indices of list i are [1, 5, 2], then sparse feature values of [-5.0, 12.0, 11.0] mean that the sparse wide features for this list are [-5.0, 11.0, 0.0, 0.0, 12.0]. If this field is missing, the values corresponding to the sparse wide feature indices are set to 1 by default. Values corresponding to padding positions of the sparse wide feature indices must be set to 0.
    • [3 2 5000 20 1 0 8 0]
  • label (float list)
    • The labels corresponding to each document. In our example, 0 for documents without any click and 1 for documents with clicks.
    • [0 0 1 0 0 0 0 0 0 0]
  • doc_titles (string list)
    • Document text fields. The shape should be the same as label. There could be multiple doc_fields in the data. For example, we could also include a doc_description as a feature. If multiple doc_fields are present, the interaction features will be computed for each query-doc pair.
    • ["creating a linked ##in group", "edit your profile"...]

Customizing and training a DeText model

The following example (from run_detext.sh) shows how you could train a DeText CNN model for a search ranking task.

The train/dev/test datasets are prepared in the format mentioned in the previous section. More specifically, the following fields are used:

  • query
  • wide_ftrs
  • wide_ftrs_sp_idx
  • wide_ftrs_sp_val
  • doc_titles
  • label

The DeText model extracts deep features from both query and doc_titles using the CNN module. After the text representation, a cosine similarity interaction feature between the two fields is computed. The interaction feature is then concatenated with the wide_ftrs, and a dense hidden layer is added before computing the final LTR score. A conceptual sketch of this scoring path is shown below.
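
The following is a minimal conceptual sketch of that scoring path, assuming one query with D documents and CNN text embeddings of size E; the function name, tanh activation, and layer choices are illustrative assumptions, not the actual DeText implementation:

```python
import tensorflow as tf

# Conceptual sketch (not the actual DeText code) of the scoring path described above.
def score_documents(query_emb, doc_title_emb, wide_ftrs, num_hidden=100):
    """query_emb: [E]; doc_title_emb: [D, E]; wide_ftrs: [D, num_wide]."""
    # Cosine similarity between the query embedding and each document title embedding
    q = tf.nn.l2_normalize(query_emb, axis=-1)                          # [E]
    d = tf.nn.l2_normalize(doc_title_emb, axis=-1)                      # [D, E]
    cos_sim = tf.reduce_sum(q[None, :] * d, axis=-1, keepdims=True)     # [D, 1]
    # Concatenate the interaction feature with the dense wide features
    all_ftrs = tf.concat([cos_sim, wide_ftrs], axis=-1)                 # [D, 1 + num_wide]
    # Dense hidden layer followed by a linear layer producing one LTR score per document
    hidden = tf.keras.layers.Dense(num_hidden, activation="tanh")(all_ftrs)
    scores = tf.keras.layers.Dense(1)(hidden)                           # [D, 1]
    return tf.squeeze(scores, axis=-1)                                  # [D]
```

In DeText this pattern is applied per query over a batch, and the resulting per-document scores are fed to the LTR loss (softmax in the example run below).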

The following script is used for running the DeText training.

python run_detext.py \
--ftr_ext=cnn \ # deep module is CNN
--feature_names=query,label,wide_ftrs,doc_title \   # list all the feature names in the data
--learning_rate=0.001 \
--ltr=softmax \ # type of ltr loss
--max_len=32 \  # sentence max length
--min_len=3 \   # sentence min length
--num_fields=1 \    # the number of document fields (starting with 'doc_') used
--filter_window_sizes=2,3 \   # CNN filter sizes. Could be a list of different sizes.
--num_filters=50 \  # number of filters in CNN
--num_hidden=100 \  # size of hidden layer after the interaction layer
--num_train_steps=10 \
--num_units=32 \    # word embedding size
--num_wide=10 \     # number of wide features per document
--optimizer=bert_adam \
--pmetric=ndcg@10 \     # primary metric. This is used for evaluation during training. Best models are kept according to this metric.
--random_seed=11 \
--steps_per_stats=1 \
--steps_per_eval=2 \
--test_batch_size=2 \
--train_batch_size=2 \
--use_wide=True \   # whether to use wide_ftrs
--use_deep=True \   # whether to use the text features
--dev_file=hc_examples.tfrecord \
--test_file=hc_examples.tfrecord \
--train_file=hc_examples.tfrecord \
--vocab_file=vocab.txt \
--out_dir=detext-output/hc_cnn_f50_u32_h100

The primary parameters are annotated with comments above. The complete list of training parameters is given in the next section.

List of all DeText parameters

A complete list of training parameters that DeText provides is given below. Users can refer to this table for full customization when designing DeText models.

| Category | Parameter Name | Type | Choices | Default | Help |
|---|---|---|---|---|---|
| Network | ftr_ext | str | cnn, bert, lstm, lstm_lm | | NLP feature extraction module. |
| | num_units | int | | 128 | Word embedding size. |
| | num_units_for_id_ftr | int | | 128 | Id feature embedding size. |
| | num_hidden | str | | 0 | Hidden size. This could be a single number or a comma-separated list of numbers for multiple hidden layers. |
| | num_wide | int | | 0 | Number of wide features per doc. |
| | ltr_loss_fn | str | | pairwise | Learning-to-rank method. |
| | use_deep | str2bool | | TRUE | Whether to use deep features. |
| | elem_rescale | str2bool | | TRUE | Whether to perform elementwise rescaling. |
| | emb_sim_func | str | | inner | The approach to compute query/doc similarity scores: inner/hadamard/concat, or any combination of them separated by commas. |
| | num_classes | int | | 1 | Number of classes for multi-class classification tasks. |
| Sparse feature related | num_wide_sp | int | | None | Maximum number of sparse wide features. |
| | sp_emb_size | int | | 1 | Embedding size of sparse wide features. |
| CNN related | filter_window_sizes | str | | "1,2,3" | CNN filter window sizes. |
| | num_filters | int | | 100 | Number of CNN filters. |
| | explicit_empty | str2bool | | FALSE | Whether to explicitly model the empty string in CNN. |
| BERT related | lr_bert | float | | None | Learning rate factor for BERT components. |
| | bert_config_file | str | | None | BERT config file. |
| | bert_checkpoint | str | | None | Pretrained BERT model checkpoint. |
| LSTM related | unit_type | str | lstm | lstm | RNN cell unit type. Currently only lstm is supported. |
| | num_layers | int | | 1 | Number of RNN layers. |
| | num_residual_layers | int | | 0 | Number of residual layers from top to bottom. For example, if num_layers=4 and num_residual_layers=2, the last 2 RNN cells in the returned list will be wrapped with ResidualWrapper. |
| | forget_bias | float | | 1 | Forget bias of the RNN cell. |
| | rnn_dropout | float | | 0 | Dropout of the RNN cell. |
| | bidirectional | str2bool | | FALSE | Whether to use a bidirectional RNN. |
| | normalized_lm | str2bool | | FALSE | Whether to use normalized lm. This option only works for lstm_lm. |
| Optimizer | optimizer | str | sgd, adam, bert_adam, bert_lamb | sgd | sgd, adam, adam with weight decay (similar to BERT's optimizer implementation), or lamb with weight decay (see the paper "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes"). |
| | max_gradient_norm | float | | 5 | Clip gradients to this norm. |
| | learning_rate | float | | 1 | Learning rate. For Adam: 0.001. |
| | num_train_steps | int | | 1 | Number of steps to train. |
| | num_warmup_steps | int | | 0 | Number of steps for warmup. |
| | train_batch_size | int | | 32 | Training data batch size. |
| | test_batch_size | int | | 32 | Test data batch size. |
| Data | train_file | str | | None | Train file. |
| | dev_file | str | | None | Dev file. |
| | test_file | str | | None | Test file. |
| | out_dir | str | | None | Directory to store log/model files. |
| | std_file | str | | None | Feature standardization file. |
| Vocab related | vocab_file | str | | None | Vocab file. |
| | we_file | str | | None | Pretrained word embedding file. |
| | we_trainable | str2bool | | TRUE | Whether to train word embeddings. |
| | PAD | str | | [PAD] | Token for padding. |
| | SEP | str | | [SEP] | Token for sentence separation. |
| | CLS | str | | [CLS] | Token for start of sentence. |
| | UNK | str | | [UNK] | Token for unknown words. |
| | MASK | str | | [MASK] | Token for masked words. |
| | vocab_file_for_id_ftr | str | | None | Vocab file for id features. |
| | we_file_for_id_ftr | str | | None | Pretrained word embedding file for id features. |
| | we_trainable_for_id_ftr | str2bool | | TRUE | Whether to train word embeddings for id features. |
| | PAD_FOR_ID_FTR | str | | [PAD] | Padding token for id features. |
| | UNK_FOR_ID_FTR | str | | [UNK] | Unknown word token for id features. |
| MISC | random_seed | int | | 1234 | Random seed (>0, set a specific seed). |
| | steps_per_stats | int | | 100 | Number of training steps between printing statistics. |
| | steps_per_eval | int | | 1000 | Number of training steps between dataset evaluations. |
| | keep_checkpoint_max | int | >= 0 | 5 | The maximum number of recent checkpoint files to keep. If 0, all checkpoint files are kept. |
| | max_len | int | | 32 | Max sentence length. |
| | min_len | int | | 3 | Min sentence length. |
| | feature_names | str | | None | The feature names. |
| | lambda_metric | str | | None | Only supports ndcg. |
| | init_weight | float | | 0.1 | Weight initialization value. |
| | pmetric | str | | None | Primary metric. |
| | all_metrics | str | | precision@1,ndcg@10 | All metrics. |
| | score_rescale | str | | None | The mean and std of the previous model. |
| | save_model_scoring_modes | str | | mode_all_online_scoring | The decoding modes to generate savedmodel. Supported modes are: 'mode_all_online_scoring', 'mode_query_embedding', 'mode_doc_embedding', 'mode_sim_wide_scoring'. |
| | tokenization | str | | None | The tokenization performed for data preprocessing. Currently supported: punct/plain (no split). Note that this should be set correctly to ensure consistency for savedmodel. |
| | resume_training | str2bool | | TRUE | Whether to resume training from a checkpoint in out_dir. |
| | metadata_path | str | | None | The metadata_path for converted avro2tf avro data. |
| TFR related | tfr_metrics | str | | None | tf-ranking metrics. |
| | use_tfr_loss | str2bool | | FALSE | Whether to use tf-ranking loss. |
| | tfr_loss_fn | | tfr.losses.RankingLossKey.SOFTMAX_LOSS, tfr.losses.RankingLossKey.PAIRWISE_LOGISTIC_LOSS | tfr.losses.RankingLossKey.SOFTMAX_LOSS (softmax_loss) | tf-ranking loss function. |
| | tfr_lambda_weights | str | | None | tfr_lambda_weights parameter for the tfr loss function. |
| | use_horovod | str2bool | | FALSE | Whether to use Horovod for synchronous distributed training. |
| Multitask related | task_ids | str | | None | All types of task IDs used for multitask training, e.g. 1,2,3. |
| | task_weights | str | | None | Weights for each task specified in task_ids, e.g. 0.5,0.3,0.2. |