-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcheckpoint_to_h5.py
96 lines (75 loc) · 2.8 KB
/
checkpoint_to_h5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python
"""Save the latest checkpoint of a generator to a h5 Keras model
"""
from __future__ import print_function
import os
import sys
import argparse
import tensorflow as tf
import utils
def _release_arg_file(handle):
    """Close a file object opened by ``argparse.FileType``.

    Only the path of these files is used by this script, so the handle
    itself is not needed. The standard streams (what ``FileType`` returns
    for ``-``) are left open so later ``print`` calls keep working.
    """
    if handle not in (sys.stdin, sys.stdout):
        handle.close()


def _build_checkpoint(config, config_generator):
    """Create the models/optimizers described by the configs and wrap them.

    A config containing ``discriminator_fn`` is treated as a DnGAN
    (generator + discriminator, one optimizer each); anything else is
    treated as a DAE (a single model and optimizer).

    Returns:
        ``(generator, checkpoint)`` where ``checkpoint`` is a
        ``tf.train.Checkpoint`` tracking the created objects.
    """
    # The keyword names passed to tf.train.Checkpoint identify the objects
    # in the checkpoint, so they presumably have to match the names used by
    # the training script that wrote it - do not rename them.
    print("Creating model...")
    if "discriminator_fn" in config:
        # DnGAN layout.
        generator = utils.get_generator(config_generator)
        discriminator = utils.get_discriminator(config)
        generator_optimizer, discriminator_optimizer = utils.get_optimiers(config)
        step = tf.Variable(0, dtype=tf.int64)
        checkpoint = tf.train.Checkpoint(
            generator_optimizer=generator_optimizer,
            discriminator_optimizer=discriminator_optimizer,
            generator=generator,
            discriminator=discriminator,
            step=step)
    else:
        # DAE layout: only the first optimizer returned is used.
        generator = utils.get_generator(config)
        optimizer, _ = utils.get_optimiers(config)
        step = tf.Variable(0, dtype=tf.int64)
        checkpoint = tf.train.Checkpoint(
            model=generator,
            optimizer=optimizer,
            step=step)
    return generator, checkpoint


def main(args):
    """Restore the latest checkpoint of a model and save its generator.

    Args:
        args: namespace from ``parse_args``. ``args.config`` and
            ``args.outfile`` are file objects of which only ``.name`` is
            used; ``args.checkpoints`` is the checkpoints root directory.
            Checkpoints are looked up in
            ``<args.checkpoints>/<config['model_name']>``.

    Returns:
        ``None`` on success, or ``1`` when no checkpoint is found in that
        directory (the value is passed to ``sys.exit`` by the caller).
    """
    config_path = args.config.name
    out_file = args.outfile.name
    # argparse already opened (and, for the output, truncated) both files;
    # release the handles so nothing keeps the output path open while
    # Keras writes to it.
    _release_arg_file(args.config)
    _release_arg_file(args.outfile)

    # Config
    config = utils.read_config(config_path)
    if 'generator_config' in config:
        # The generator config path is resolved relative to the main config.
        config_generator_path = os.path.join(os.path.dirname(config_path),
                                             config['generator_config'])
        config_generator = utils.read_config(config_generator_path)
    else:
        config_generator = config

    # Locate the checkpoint before building anything: latest_checkpoint
    # returns None when the directory holds no checkpoint, and restoring
    # None would silently leave the freshly initialised weights in place.
    checkpoints_dir = os.path.join(args.checkpoints, config['model_name'])
    checkpoint_file = tf.train.latest_checkpoint(checkpoints_dir)
    if checkpoint_file is None:
        print(f"No checkpoint found in \"{checkpoints_dir}\"", file=sys.stderr)
        return 1

    generator, checkpoint = _build_checkpoint(config, config_generator)

    # Restore; expect_partial() silences warnings about checkpointed
    # values (e.g. optimizer slots) that are never used here.
    print(f"Restoring checkpoint \"{checkpoint_file}\"...")
    checkpoint.restore(checkpoint_file).expect_partial()

    # Save the generator. NOTE(review): Keras picks the on-disk format from
    # the file extension - pass a ".h5" path to get an HDF5 file; confirm
    # against the installed Keras version.
    print(f"Saving model to \"{out_file}\"...")
    generator.save(out_file)
    return None
def parse_args(arguments):
    """Build the command-line parser and parse *arguments*.

    Args:
        arguments: list of CLI tokens, typically ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config`` (file opened for reading),
        ``outfile`` (file opened for writing - argparse opens and truncates
        it at parse time) and ``checkpoints`` (directory path string).
    """
    cli = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # Positional arguments: (destination, help text, open mode).
    positional_specs = (
        ('config', "Model config file", 'r'),
        ('outfile', "Output file", 'w'),
    )
    for dest, help_text, open_mode in positional_specs:
        cli.add_argument(dest, help=help_text, type=argparse.FileType(open_mode))
    cli.add_argument('-c', '--checkpoints',
                     default="checkpoints",
                     help="Checkpoints directory")
    namespace = cli.parse_args(arguments)
    return namespace
if __name__ == '__main__':
    # Parse the CLI first, then use main()'s return value as the exit status.
    cli_args = parse_args(sys.argv[1:])
    sys.exit(main(cli_args))