NCF pipeline refactor (take 2) and initial TPU port. (tensorflow#4935)
* intermediate commit

* ncf now working

* reorder pipeline

* allow batched decode for file backed dataset

* fix bug

* more tweaks

* parallelize false negative generation

* shared pool hack

* workers ignore sigint

* intermediate commit

* simplify buffer backed dataset creation to fixed length record approach only. (more cleanup needed)

* more tweaks

* simplify pipeline

* fix misplaced cleanup() calls. (validation works!)

* more tweaks

* sixify memoryview usage

* more sixification

* fix bug

* add future imports

* break up training input pipeline

* more pipeline tuning

* first pass at moving negative generation to async

* refactor async pipeline to use files instead of ipc

* refactor async pipeline

* move expansion and concatenation from reduce worker to generation workers

* abandon complete async due to interactions with the tensorflow threadpool

* cleanup

* remove performance_comparison.py

* experiment with rough generator + interleave pipeline

* yet more pipeline tuning

* update on-the-fly pipeline

* refactor preprocessing, and move train generation behind a GRPC server

* fix leftover call

* intermediate commit

* intermediate commit

* fix index error in data pipeline, and add logging to train data server

* make sharding more robust to imbalance

* correctly sample with replacement

* file buffers are no longer needed for this branch

* tweak sampling methods

* add README for data pipeline

* fix eval sampling, and vectorize eval metrics

* add spillover and static training batch sizes

* clean up cruft from earlier iterations

* rough delint

* delint 2 / n

* add type annotations

* update run script

* make run.sh a bit nicer

* change embedding initializer to match reference

* rough pass at pure estimator model_fn

* impose static shape hack (revisit later)

* refinements

* fix dir error in run.sh

* add documentation

* add more docs and fix an assert

* old data test is no longer valid. Keeping it around as reference for the new one

* rough draft of data pipeline validation script

* don't rely on shuffle default

* tweaks and documentation

* add separate eval batch size for performance

* initial commit

* terrible hacking

* mini hacks

* missed a bug

* messing about trying to get TPU running

* TFRecords based TPU attempt

* bug fixes

* don't log remotely

* more bug fixes

* TPU tweaks and bug fixes

* more tweaks

* more adjustments

* rework model definition

* tweak data pipeline

* refactor async TFRecords generation

* temp commit to run.sh

* update log behavior

* fix logging bug

* add check for subprocess start to avoid cryptic hangs

* unify deserialize and make it TPU compliant

* delint

* remove gRPC pipeline code

* fix logging bug

* delint and remove old test files

* add unit tests for NCF pipeline

* delint

* clean up run.sh, and add run_tpu.sh

* forgot the most important line

* fix run.sh bugs

* yet more bash debugging

* small tweak to add keras summaries to model_fn

* Clean up sixification issues

* address PR comments

* delinting is never over
Taylor Robie authored Jul 30, 2018
1 parent a88b89b commit 6518c1c
Showing 18 changed files with 1,767 additions and 944 deletions.
8 changes: 6 additions & 2 deletions official/datasets/movielens.py
@@ -133,8 +133,12 @@ def _progress(count, block_size, total_size):
      _regularize_20m_dataset(temp_dir)

    for fname in tf.gfile.ListDirectory(temp_dir):
-      tf.gfile.Copy(os.path.join(temp_dir, fname),
-                    os.path.join(data_subdir, fname))
+      if not tf.gfile.Exists(os.path.join(data_subdir, fname)):
+        tf.gfile.Copy(os.path.join(temp_dir, fname),
+                      os.path.join(data_subdir, fname))
+      else:
+        tf.logging.info("Skipping copy of {}, as it already exists in the "
+                        "destination folder.".format(fname))

  finally:
    tf.gfile.DeleteRecursively(temp_dir)
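The change above makes the copy step idempotent: files already present in `data_subdir` are skipped instead of re-copied. Below is a standalone sketch of the same pattern, assuming the TF 1.x `tf.gfile`/`tf.logging` APIs used elsewhere in `movielens.py`; the helper name `copy_if_missing` is hypothetical and not part of the diff.

```python
import os

import tensorflow as tf


def copy_if_missing(src_dir, dst_dir):
  """Copies each file from src_dir to dst_dir unless it already exists."""
  for fname in tf.gfile.ListDirectory(src_dir):
    dst = os.path.join(dst_dir, fname)
    if tf.gfile.Exists(dst):
      # Skipping existing files keeps repeated downloads idempotent.
      tf.logging.info("Skipping copy of %s, as it already exists.", fname)
      continue
    tf.gfile.Copy(os.path.join(src_dir, fname), dst)
```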
5 changes: 3 additions & 2 deletions official/recommendation/README.md
@@ -43,15 +43,16 @@ In both datasets, the timestamp is represented in seconds since midnight Coordin
### Download and preprocess dataset
To download the dataset, please install Pandas package first. Then issue the following command:
```
-python movielens_dataset.py
+python ../datasets/movielens.py
```
Arguments:
* `--data_dir`: Directory where to download and save the preprocessed data. By default, it is `/tmp/movielens-data/`.
* `--dataset`: The dataset name to be downloaded and preprocessed. By default, it is `ml-1m`.

Use the `--help` or `-h` flag to get a full list of possible arguments.

-Note the ml-20m dataset is large (the rating file is ~500 MB), and it may take several minutes (~10 mins) for data preprocessing.
+Note the ml-20m dataset is large (the rating file is ~500 MB), and it may take several minutes (~2 mins) for data preprocessing.
+Both the ml-1m and ml-20m datasets will be coerced into a common format when downloaded.
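
As an aside (not part of the README diff above), the documented flags might be combined like this; the data directory and dataset name below are only illustrative:

```
python ../datasets/movielens.py --data_dir /tmp/movielens-data --dataset ml-20m
```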

### Train and evaluate model
To train and evaluate the model, issue the following command:
67 changes: 67 additions & 0 deletions official/recommendation/constants.py
@@ -0,0 +1,67 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Central location for NCF specific values."""

import os
import time


# ==============================================================================
# == Main Thread Data Processing ===============================================
# ==============================================================================
class Paths(object):
"""Container for various path information used while training NCF."""

def __init__(self, data_dir, cache_id=None):
self.cache_id = cache_id or int(time.time())
self.data_dir = data_dir
self.cache_root = os.path.join(
self.data_dir, "{}_ncf_recommendation_cache".format(self.cache_id))
self.train_shard_subdir = os.path.join(self.cache_root,
"raw_training_shards")
self.train_shard_template = os.path.join(self.train_shard_subdir,
"positive_shard_{}.pickle")
self.train_epoch_dir = os.path.join(self.cache_root, "training_epochs")
self.eval_data_subdir = os.path.join(self.cache_root, "eval_data")
self.eval_raw_file = os.path.join(self.eval_data_subdir, "raw.pickle")
self.eval_record_template_temp = os.path.join(self.eval_data_subdir,
"eval_records.temp")
self.eval_record_template = os.path.join(
self.eval_data_subdir, "padded_eval_batch_size_{}.tfrecords")
self.subproc_alive = os.path.join(self.cache_root, "subproc.alive")


APPROX_PTS_PER_TRAIN_SHARD = 128000

# In both datasets, each user has at least 20 ratings.
MIN_NUM_RATINGS = 20

# The number of negative examples attached with a positive example
# when performing evaluation.
NUM_EVAL_NEGATIVES = 999

# ==============================================================================
# == Subprocess Data Generation ================================================
# ==============================================================================
CYCLES_TO_BUFFER = 3  # The number of train cycles worth of data to "run ahead"
                      # of the main training loop.

READY_FILE = "ready.json"
TRAIN_RECORD_TEMPLATE = "train_{}.tfrecords"

TIMEOUT_SECONDS = 3600 * 2  # If the train loop goes more than two hours without
                            # consuming an epoch of data, this is a good
                            # indicator that the main thread is dead and the
                            # subprocess is orphaned.
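
For orientation (not part of the commit), here is a minimal sketch of how the `Paths` container above might be instantiated; the import path and the literal `data_dir`/`cache_id` values are assumptions.

```python
from official.recommendation import constants  # assumed module path

# Build the cache layout rooted under data_dir.
paths = constants.Paths(data_dir="/tmp/movielens-data", cache_id=1234)

print(paths.cache_root)                         # .../1234_ncf_recommendation_cache
print(paths.train_shard_template.format(0))     # .../raw_training_shards/positive_shard_0.pickle
print(paths.eval_record_template.format(2048))  # .../padded_eval_batch_size_2048.tfrecords
```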
