-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgenerate.py
208 lines (154 loc) · 7.05 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import cohere
import os
from image_captioning import predict_step
from dotenv import load_dotenv
from binary_tree import *
# Module-level Cohere client setup. This runs at import time.
# load_dotenv() is called before the lookup so that a key supplied through a
# .env file (python-dotenv) is visible to os.getenv below.
load_dotenv()
# os.getenv returns None when COHERE_API_KEY is not set; the value is passed
# to the client unchecked.
api_key = os.getenv('COHERE_API_KEY')
# Shared client used by every generate_text* function in this module.
co = cohere.Client(api_key)
def generate_text(captions, temp=0):
    """Generate a three-part first-person story from three image captions.

    Builds a prompt asking the Cohere ``command`` model for exactly four
    sentences per caption, with the three parts separated by ``###`` so the
    reply can be split with ``process_co_output``.

    Args:
        captions: Indexable holding at least three caption strings in story
            order. Only indices 0, 1 and 2 are read.
        temp: Sampling temperature forwarded to ``co.generate``. Defaults to
            0, which matches the value that used to be hard-coded.

    Returns:
        The text of the first generation in the API response.

    Raises:
        IndexError: If ``captions`` is a sequence with fewer than three items.
    """
    # NOTE: '\\n' is escaped on purpose. With a bare '\n' the f-string would
    # insert a real line break in the middle of the parsing expression shown
    # to the model instead of the two characters backslash + n.
    prompt = f"""Predict what I experienced based on these captions in order using first person narrative (I):
{captions[0]},
{captions[1]},
{captions[2]}
Ensure that for each caption, you generate exactly 4 sentences. Separate the output for each caption based
on the format below. Don't ask for feedback. Strictly adhere to the format and ensure you don't generate anything that strays from the
format below. Ensure that the storyline between the outputs for the captions is consistent.
Strictly adhere to the following output format and don't deviate from it:
"
[sentences for caption 1]
###
[sentences for caption 2]
###
[sentences for caption 3]
"
The output should be parse-able using [s.replace('\\n', '').replace('"', '') for s in coo.split('###')].
"""
    response = co.generate(
        model='command',
        prompt=prompt,
        # Honour the caller's temperature; float() keeps the default at 0.0.
        temperature=float(temp),
    )
    return response.generations[0].text
def generate_text_level_two(captions, prev, temp=0):
    """Generate stories for captions 2 and 3 given a fixed story for caption 1.

    The prompt repeats the already-generated text for the first caption and
    asks the Cohere ``command`` model to echo it, followed by four new
    sentences for each of the remaining two captions, separated by ``###``.

    Args:
        captions: Indexable holding at least three caption strings in story
            order. Only indices 0, 1 and 2 are read.
        prev: Indexable whose item 0 is the story text already generated for
            the first caption.
        temp: Sampling temperature forwarded to ``co.generate``. Defaults to
            0, which matches the value that used to be hard-coded.

    Returns:
        The text of the first generation in the API response.

    Raises:
        IndexError: If ``captions`` has fewer than three items or ``prev`` is
            an empty sequence.
    """
    # NOTE: '\\n' is escaped on purpose. With a bare '\n' the f-string would
    # insert a real line break in the middle of the parsing expression shown
    # to the model instead of the two characters backslash + n.
    prompt = f"""Predict what I experienced based on these captions in order using first person narrative (I):
{captions[0]},
{captions[1]},
{captions[2]}
The output for the first caption is: {prev[0]}
Ensure that for the remaining captions, you generate exactly 4 sentences. Separate the output for each caption based
on the format below. Don't ask for feedback. Strictly adhere to the format and ensure you don't generate anything that strays from the
format below. Ensure that the storyline between the outputs for the remaining captions is consistent with the
story introduced by the first caption.
Strictly adhere to the following output format and don't deviate from it:
"
{prev[0]}
###
[sentences for caption 2]
###
[sentences for caption 3]
"
The output should be parse-able using [s.replace('\\n', '').replace('"', '') for s in coo.split('###')].
"""
    response = co.generate(
        model='command',
        prompt=prompt,
        # Honour the caller's temperature; float() keeps the default at 0.0.
        temperature=float(temp),
    )
    return response.generations[0].text
def generate_text_level_three(captions, prev, temp=0):
    """Generate a story for caption 3 given fixed stories for captions 1 and 2.

    The prompt repeats the already-generated text for the first two captions
    and asks the Cohere ``command`` model to echo both, followed by four new
    sentences for the third caption, with the parts separated by ``###``.

    Args:
        captions: Indexable holding at least three caption strings in story
            order. Only indices 0, 1 and 2 are read.
        prev: Indexable whose items 0 and 1 are the story texts already
            generated for the first and second captions.
        temp: Sampling temperature forwarded to ``co.generate``. Defaults to
            0, which matches the value that used to be hard-coded.

    Returns:
        The text of the first generation in the API response.

    Raises:
        IndexError: If ``captions`` has fewer than three items or ``prev`` has
            fewer than two.
    """
    # NOTE: '\\n' is escaped on purpose. With a bare '\n' the f-string would
    # insert a real line break in the middle of the parsing expression shown
    # to the model instead of the two characters backslash + n.
    prompt = f"""Predict what I experienced based on these captions in order using first person narrative (I):
{captions[0]},
{captions[1]},
{captions[2]}
The output for the first caption is: {prev[0]}
The output for the second caption is: {prev[1]}
Ensure that for the remaining captions, you generate exactly 4 sentences. Separate the output for each caption based
on the format below. Generate only the sentences. Don't ask for feedback. Strictly adhere to the format and don't generate any additional
text asking acknowledging the instructions or asking if the output is satisfactory. Ensure that the storyline between
the outputs for the remaining captions is consistent with the story introduced by the first caption and the second caption.
Strictly adhere to the following output format and don't deviate from it:
"
{prev[0]}
###
{prev[1]}
###
[sentences for caption 3]
"
The output should be parse-able using [s.replace('\\n', '').replace('"', '') for s in coo.split('###')].
"""
    response = co.generate(
        model='command',
        prompt=prompt,
        # Honour the caller's temperature; float() keeps the default at 0.0.
        temperature=float(temp),
    )
    return response.generations[0].text
def process_co_output(coo):
    """Split raw model output into per-caption story segments.

    Echoes the raw text to stdout, splits it on the ``###`` separator and
    removes every newline and double-quote character from each piece.

    Args:
        coo: Raw text returned by one of the ``generate_text*`` functions.

    Returns:
        A list of cleaned segments, one per ``###``-delimited part. The list
        always has at least one element.
    """
    print(coo)
    # One translation table drops both unwanted characters in a single pass.
    unwanted = str.maketrans('', '', '\n"')
    return [segment.translate(unwanted) for segment in coo.split('###')]
def create_game_tree(imgs):
    """Build the fixed three-level, seven-node game tree for three images.

    Every node starts with an empty story ('' as ``val``). Level 0 uses
    ``imgs[0]``, both level-1 nodes use ``imgs[1]`` and all four level-2 nodes
    use ``imgs[2]``. Left children carry the state 'lose' and right children
    the state 'right'.

    Args:
        imgs: Indexable of at least three image paths.

    Returns:
        The root ``Node`` of the tree.
    """
    # NOTE(review): right children are tagged 'right' rather than 'win';
    # confirm against whatever consumes ``state`` before changing it.
    root = Node(imgs[0], '', 0)

    # level 1
    root.left = Node(imgs[1], '', 1, 'lose')
    root.right = Node(imgs[1], '', 1, 'right')

    # level 2: each level-1 node gets its own lose/right pair of leaves
    for parent in (root.left, root.right):
        parent.left = Node(imgs[2], '', 2, 'lose')
        parent.right = Node(imgs[2], '', 2, 'right')

    print("Creation completed.")
    return root
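'''
Reference sketch of the Node class used by create_game_tree and populate_tree
(presumably mirrors the class imported from binary_tree -- keep in sync):

class Node:
    def __init__(self, img, val, level, state = None):
        self.image = img # img_path
        self.caption = ''
        self.val = val # string/story for this image
        self.level = level # depth [0, 1, or 2]
        self.state = state # win/lose for all levels except 0
        self.left = None # lose
        self.right = None # win
'''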
def populate_tree(root, captions):
    """Attach a caption and a generated story to every node of the game tree.

    Makes four generation calls, in this order:

    1. ``generate_text`` fills root, root.right and root.right.right.
    2. ``generate_text_level_three`` (keeping stories 1 and 2 from step 1)
       fills root.right.left.
    3. ``generate_text_level_two`` (keeping story 1 from step 1) fills
       root.left and root.left.right.
    4. ``generate_text_level_three`` (keeping stories 1 and 2 from step 3)
       fills root.left.left.

    Args:
        root: Root of a tree shaped like the one built by
            ``create_game_tree`` (two full levels below the root).
        captions: Indexable of three caption strings, one per tree level.

    Returns:
        The same ``root`` object, mutated in place.

    Raises:
        IndexError: If a model reply splits into fewer than three segments.
    """
    def fill(node, caption, story):
        # Store one caption and its story on a single node.
        node.caption, node.val = caption, story

    # assume generate_text generates winning text
    print("Population begun.")

    # Step 1: root -> right -> right.right path.
    main_story = process_co_output(generate_text(captions))
    fill(root, captions[0], main_story[0])
    fill(root.right, captions[1], main_story[1])
    fill(root.right.right, captions[2], main_story[2])

    # Step 2: alternative third part that follows the first two main parts.
    alt_ending = process_co_output(
        generate_text_level_three(captions, main_story[:2])
    )
    fill(root.right.left, captions[2], alt_ending[2])

    # Step 3: branch that diverges straight after the root's story.
    left_story = process_co_output(
        generate_text_level_two(captions, main_story[:1])
    )
    fill(root.left, captions[1], left_story[1])
    fill(root.left.right, captions[2], left_story[2])

    # Step 4: alternative third part for the left branch.
    left_alt_ending = process_co_output(
        generate_text_level_three(captions, left_story[:2])
    )
    fill(root.left.left, captions[2], left_alt_ending[2])

    print('Population completed.')
    return root
# def _expand_tree(node, current_level, max_level):
# if current_level == max_level:
# return
# win_story = generate_text(f"Win story for level {current_level - 1}")
# lose_story = generate_text(f"Lose story for level {current_level - 1}")
# node.left = Node(generate_text(prompt + ""), current_level + 1, 'win')
# node.right = Node(generate_text(prompt + ""), current_level + 1, 'lose')
# # _expand_tree(node.left, current_level + 1, max_level)
# # _expand_tree(node.right, current_level + 1, max_level)
def print_tree(root, level=0, prefix="Root: ", state=""):
    """Print the tree rooted at ``root`` to stdout, one node per line.

    Each line shows the node's caption, state and level. The root line is
    ``<prefix><caption> - <state> (<level>)``; deeper nodes are indented by
    four spaces per level and printed as ``|-- <prefix><caption> - ...``.
    The node itself is printed first, then its left subtree, then its right.

    The recursive calls pass "Left: " / "Right: " as ``prefix``. Earlier
    versions dropped that label below the root, which made the two children
    indistinguishable in the output; it is now printed on every line.

    Args:
        root: Node-like object with ``caption``, ``state``, ``level``,
            ``left`` and ``right`` attributes, or None (prints nothing).
        level: Depth of ``root`` in the printout; controls the indentation.
        prefix: Label printed in front of the caption.
        state: Unused; kept so existing positional/keyword calls still work.
    """
    if root is None:
        return

    summary = f"{root.caption} - {root.state} ({root.level})"
    if level == 0:
        print(f"{prefix}{summary}")
    else:
        print(f"{' ' * (level * 4)}|-- {prefix}{summary}")

    print_tree(root.left, level + 1, "Left: ", root.state)
    print_tree(root.right, level + 1, "Right: ", root.state)
if __name__ == "__main__":
    # Demo run: caption three sample images, build the fixed-shape game tree,
    # fill it with generated stories and dump the result to stdout.
    image_paths_3 = [
        './images/biking.jpg',
        './images/monke.jpg',
        './images/rohan.jpeg',
    ]
    captions = predict_step(image_paths_3)
    # populate_tree returns the same root object it was given.
    root = populate_tree(create_game_tree(image_paths_3), captions)
    print_tree(root)