-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathgenerate_poem.py
33 lines (28 loc) · 1.28 KB
/
generate_poem.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import argparse
from rupo.api import Engine
from rupo.settings import RU_STRESS_DEFAULT_MODEL, ZALYZNYAK_DICT, GENERATOR_MODEL_DIR
if __name__ == "__main__":
    # Command-line entry point: generate `--count` poems, one per seed in
    # range(count), printing each seed followed by its poem.

    # (flag, type, default) for every option, in the order --help lists them.
    # Every flag except --count is forwarded to Engine.generate_poem as a
    # keyword argument (argparse turns '--model-path' into 'model_path', etc.).
    option_specs = (
        ('--model-path', str, GENERATOR_MODEL_DIR),
        ('--token-vocab-path', str, None),
        ('--stress-vocab-path', str, None),
        ('--metre-schema', str, '+-'),
        ('--rhyme-pattern', str, 'abab'),
        ('--n-syllables', int, 8),
        ('--sampling-k', int, 50000),
        ('--beam-width', int, None),
        ('--temperature', float, 1.0),
        ('--last-text', str, ""),
        ('--count', int, 100),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)

    generation_options = vars(cli.parse_args())
    # --count only drives the loop below, so remove it before the remaining
    # options are splatted into generate_poem.
    n_poems = generation_options.pop('count')

    poem_engine = Engine()
    poem_engine.load(RU_STRESS_DEFAULT_MODEL, ZALYZNYAK_DICT)

    for seed in range(n_poems):
        print(seed)
        try:
            print(poem_engine.generate_poem(seed=seed, **generation_options))
        except AssertionError as failure:
            # An AssertionError for one seed is reported and the loop moves
            # on to the next seed instead of aborting the whole run.
            print("Error: ", failure)