-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgather_align_steps_orig.py
402 lines (389 loc) · 18.3 KB
/
gather_align_steps_orig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
# Imports for transcript step extraction (OpenAI) and step-to-transcript alignment.
import os
import string
import json
import torch
import numpy as np
import openai
import random
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
#from nemo.collections.nlp.models import PunctuationCapitalizationModel
import argparse
from tqdm import tqdm
import spacy
from sentence_transformers import SentenceTransformer
import multiprocessing as mp
import _io
# Module-level singletons shared by the functions below.
nlp = spacy.load('en_core_web_sm')  # spaCy pipeline used for sentence splitting in align_text
# NOTE(review): sent_tokenizer is loaded here but never referenced in this file — possibly unused.
sent_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
def get_next_character(text_list, index1, index2):
    """Scan forward from (index1, index2) for the next non-whitespace character.

    Args:
        text_list: list of text segments.
        index1: segment index to start from.
        index2: character index within that segment.

    Returns:
        (char, seg_index, char_index) for the next non-space character, or
        (None, seg_index, char_index) when the end of text_list is reached.
    """
    while True:
        # Ran off the end of the segment list: nothing left to return.
        if index1 == len(text_list):
            return None, index1, index2
        # Exhausted the current segment: move to the start of the next one.
        if index2 == len(text_list[index1]):
            index1, index2 = index1 + 1, 0
            continue
        ch = text_list[index1][index2]
        if ch.isspace():
            index2 += 1
            continue
        return ch, index1, index2
def align_after_postprocess(postprocessed, original):
    """Map character offsets of `postprocessed` back onto the raw transcript.

    Walks the postprocessed text and the raw segments in lockstep, matching
    characters case-insensitively, and records where each matched character
    lives in the original.

    Args:
        postprocessed: the punctuated/capitalized transcript string.
        original: dict whose "text" entry is the list of raw caption segments.

    Returns:
        dict mapping a character offset in `postprocessed` to a
        (segment_index, char_index) pair into original["text"].
    """
    index_map = {}
    seg_idx, char_idx = 0, 0
    pos = 0
    processed = postprocessed  # .lower()
    while pos < len(processed):
        # Whitespace in the postprocessed text never maps anywhere.
        if processed[pos].isspace():
            pos += 1
            continue
        char, seg_idx, char_idx = get_next_character(original["text"], seg_idx, char_idx)
        if char is not None:
            # Peek at the position of the character after the current one so we
            # can advance past it on a successful match.
            _, nxt_seg, nxt_char = get_next_character(original["text"], seg_idx, char_idx + 1)
            target = char.upper().lower()
            one_char = processed[pos].upper().lower()
            # The two-character comparison handles case-folds that expand to
            # two characters in the postprocessed text.
            two_chars = processed[pos:pos + 2].upper().lower()
            if one_char == target or two_chars == target:
                index_map[pos] = (seg_idx, char_idx)
                seg_idx, char_idx = nxt_seg, nxt_char
        pos += 1
    return index_map
def mean_pooling(model_output, attention_mask):
    """Mean-pool token embeddings, counting only positions the mask keeps.

    Args:
        model_output: transformer output; element 0 holds the token embeddings
            of shape (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) mask of 0/1 values.

    Returns:
        (batch, hidden) tensor of masked-average embeddings.
    """
    embeddings = model_output[0]  # first element contains all token embeddings
    # Broadcast the mask across the hidden dimension.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = (embeddings * mask).sum(dim=1)
    # Clamp avoids division by zero for fully-masked rows.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
def encode_section(sent_model, sents, start, end):
section = ' '.join(sents[start:end])
return {(start, end): sent_model.encode([section])[0]}
def remove_punctuation(text):
    """Return *text* with every ASCII punctuation character removed.

    Uses a single str.translate pass instead of the original chain of 32
    str.replace calls, each of which copied the whole string.

    Args:
        text: input string (may be empty).

    Returns:
        the string with all characters in string.punctuation deleted.
    """
    return text.translate(str.maketrans('', '', string.punctuation))
def align_text(text, original_text, steps, sent_model, num_workers, dtw=True, dtw_window_size=10000000000, dtw_start_offset=False):
    """Align each step string to a contiguous sentence span of `text`, then map
    those spans back to segment positions in the raw transcript.

    Args:
        text: postprocessed (punctuated/capitalized) transcript.
        original_text: dict whose "text" entry is the list of raw caption segments.
        steps: list of step strings; truncated to the number of sentences.
        sent_model: sentence-embedding model exposing .encode(list[str]).
        num_workers: >1 encodes candidate sections with a multiprocessing pool.
        dtw: True -> monotonic DP alignment; False -> greedy nearest-sentence match.
        dtw_window_size: max number of sentences one step's span may cover.
        dtw_start_offset: if True, pick a non-zero DP start sentence from the
            first step's similarity scores.

    Returns:
        dict mapping step index (1-based) to a 4-tuple
        (start_segment, start_offset, end_segment, end_offset) into
        original_text["text"].
    """
    doc = nlp(text)
    sents = [str(sent) for sent in list(doc.sents)]
    # Cannot align more steps than there are sentences.
    steps = steps[:len(sents)]
    step_embs = sent_model.encode(steps)
    # Normalize dotless-i so character-level alignment below can match it.
    text = text.replace('ı', 'i')
    if dtw:
        # dtw_matrix[i, start, end]: best cumulative score with step i assigned
        # to the sentence span [start, end).
        dtw_matrix = np.zeros((len(steps)+1, len(sents)+1, len(sents)+1))
        for i in range(len(steps)+1):
            for start in range(len(sents)+1):
                for end in range(len(sents)+1):
                    dtw_matrix[i,start,end] = -np.inf
        dtw_matrix[0,0,0] = 0
        # pointers[i, start, end]: start index of step i-1's span on the best
        # path into (i, start, end); used for the backtrace.
        pointers = -1*np.ones((len(steps)+1, len(sents)+1, len(sents)+1), dtype=np.int32)
        pointer_scores = -np.inf*np.ones((len(steps)+1, len(sents)+1, len(sents)+1), dtype=np.float32)
        start_sent_index = 0
        if dtw_start_offset:
            # Seed the DP at (roughly) the sentence most similar to the first step.
            single_sent_emb = np.stack([sent_model.encode([sent])[0,:] for sent in sents])
            start_scores = (step_embs[:1,:]*single_sent_emb).sum(1)
            start_sent_index = min(max(0, start_scores.argmax()-1), len(sents)-len(steps))
            dtw_matrix[0,start_sent_index,start_sent_index] = 0
        # section_emb[(start, end)]: embedding of ' '.join(sents[start:end]).
        section_emb = {}
        if num_workers == 1:
            # Encode all candidate sections serially, in batches of 16.
            batch = []
            for start in range(start_sent_index, len(sents)):
                for end in range(start+1, min(start+dtw_window_size+1, len(sents)+1)):
                    section = ' '.join(sents[start:end])
                    batch.append((start, end, section))
                    if len(batch) == 16 or (start == len(sents)-1 and end == len(sents)):
                        inputs = [item[-1] for item in batch]
                        outputs = sent_model.encode(inputs)
                        for item, output in zip(batch, outputs):
                            section_emb[item[:2]] = output
                        batch = []
            if len(batch) > 0:
                # Flush any remainder the in-loop condition did not catch.
                inputs = [item[-1] for item in batch]
                outputs = sent_model.encode(inputs)
                for item, output in zip(batch, outputs):
                    section_emb[item[:2]] = output
        else:
            # NOTE(review): the parallel path always starts from sentence 0
            # (not start_sent_index) and ships sent_model to each worker —
            # confirm both are intended.
            with mp.Pool(num_workers) as pool:
                section_emb_list = pool.starmap(encode_section, [(sent_model, sents, start, end) for start in range(0, len(sents)) for end in range(start+1, min(start+dtw_window_size+1, len(sents)+1))])
                for emb_dict in section_emb_list:
                    section_emb.update(emb_dict)
        # DP fill: score = dot(section embedding, step embedding) + best score of
        # any previous span ending exactly where this one starts.
        for i in range(1, len(steps)+1):
            for start in range(start_sent_index, len(sents)):
                for end in range(start+1, min(start+dtw_window_size+1, len(sents)+1)):
                    section = ' '.join(sents[start:end])
                    sentence_emb = section_emb[(start,end)] # sent_model.encode([section])[0]
                    step_emb = step_embs[i-1] # sent_model.encode([steps[i-1]])[0]
                    similarity = (sentence_emb*step_emb).sum().item()
                    best_prev_segment = dtw_matrix[i-1,:,start].argmax().item()
                    prev_segment_score = dtw_matrix[i-1,:,start].max().item()
                    # if prev_segment_score > dtw_matrix[i-1,start,end].item():
                    #     pointers[i,start,end] = best_prev_segment
                    # else:
                    #     pointers[i,start,end] = start
                    pointers[i,start,end] = best_prev_segment
                    pointer_scores[i,start,end] = prev_segment_score
                    last_max = np.max([prev_segment_score]) # , dtw_matrix[i-1,start,end]])
                    dtw_matrix[i,start,end] = similarity+last_max
        # print('good', i, [j for j in range(dtw_matrix.shape[1]) if dtw_matrix[i,j,:].max().item() > -np.inf])
        # Backtrace: force the last step to end at the final sentence, then walk
        # pointers backwards to recover each step's (start, end) span.
        end = dtw_matrix.shape[1]-1
        index = dtw_matrix.shape[0]-1
        start = dtw_matrix[index,:,end].argmax().item()
        print(dtw_matrix[index,:,:end].max().item())
        segments = {index: (start, end)}
        index -= 1
        while index > 0:
            # print(index+1, start, end)
            new_start = int(pointers[index+1,start,end])
            print(pointer_scores[index+1,start,end])
            if new_start != start:
                end = start
                start = new_start
            # else:
            #     print('bad', pointers[index+1,start,end], pointer_scores[index+1,start,end])
            segments[index] = (start, end)
            index -= 1
        print(start_sent_index, segments)
    else:
        # Greedy fallback: each step takes a 3-sentence window centered on its
        # single most similar sentence.
        sent_emb = sent_model.encode(sents)
        scores = torch.matmul(torch.from_numpy(step_embs), torch.from_numpy(sent_emb).t())
        matched_sentences = scores.argmax(dim=-1).tolist()
        segments = {}
        for i in range(1, len(steps)+1):
            print(steps[i-1], '|||', sents[matched_sentences[i-1]])
            segments[i] = (max(0, matched_sentences[i-1]-1), min(len(sents), matched_sentences[i-1]+2))
    # text_sans_punct = remove_punctuation(text)
    # assert text_sans_punct.lower() == ' '.join(original_text['text'])
    # Character-level map from the postprocessed text back to raw segments.
    postprocess_alignment = align_after_postprocess(text, original_text)
    aligned_segments = {}
    sents = list(doc.sents)  # re-list as spaCy spans for .start_char / .end_char
    for index in segments:
        # Back up over whitespace-only sentences so start_char lands on a real char.
        while str(sents[segments[index][0]]).isspace():
            segments[index] = (segments[index][0]-1, segments[index][1])
        start = sents[segments[index][0]].start_char
        # Advance to the first character that actually has an alignment entry.
        while start not in postprocess_alignment and start < len(text):
            start += 1
        if start not in postprocess_alignment:
            # Diagnostics printed just before the assertion below fails.
            print('A', sents[segments[index][0]])
            print('B', text[sents[segments[index][0]].start_char:], sents[segments[index][0]].start_char)
            print('C', text)
            print('D', ' '.join(original_text['text']))
            print(sents[segments[index][0]].start_char, sorted(list(postprocess_alignment.keys()))[-50:])
        assert start in postprocess_alignment
        # Retreat to the last aligned character at/before the span's end.
        end = sents[segments[index][1]-1].end_char-1
        while end not in postprocess_alignment and end >= 0:
            end -= 1
        assert end in postprocess_alignment
        # Concatenate the two (segment, offset) pairs into one 4-tuple.
        aligned_segments[index] = postprocess_alignment[start]+postprocess_alignment[end]
        print('aligned', ' '.join(original_text['text'][aligned_segments[index][0]:aligned_segments[index][2]+1]), sents[segments[index][0]:segments[index][1]])
    return aligned_segments
def remove_repeat_ngrams(text_list, min_n=3, max_n=8, return_segment_ids=False):
    """Collapse immediately-repeated n-grams (min_n..max_n tokens) in a token stream.

    The segments of text_list are flattened into one token sequence; whenever
    an n-gram is an exact repeat of the n-gram immediately before it, the
    second copy is dropped (including any of its tokens already emitted).

    Args:
        text_list: list of text segments (each split on whitespace).
        min_n: smallest repeat length to detect.
        max_n: largest repeat length to detect (checked first).
        return_segment_ids: if True, also return the source-segment id of each
            surviving token.

    Returns:
        deduplicated text as a single space-joined string, optionally paired
        with the per-token segment-id list.
    """
    assert isinstance(text_list, list)
    # Flatten segments into parallel token / segment-id sequences.
    tokens, segment_ids = [], []
    for seg_id, segment in enumerate(text_list):
        for tok in segment.split():
            if len(tok) > 0:
                tokens.append(tok)
                segment_ids.append(seg_id)
    # Retained from the original implementation; currently written but never read.
    inside_segment = False
    num_streak_tokens = 0
    kept_tokens, kept_ids = [], []
    kept_indices = set()  # original indices of tokens currently in kept_tokens
    for i, tok in enumerate(tokens):
        duplicate = False
        # Prefer the longest repeat; stop at the first length that matches.
        for n in range(max_n, min_n - 1, -1):
            if i + 1 >= 2 * n and tokens[i + 1 - n:i + 1] == tokens[i + 1 - 2 * n:i + 1 - n]:
                inside_segment = True
                num_streak_tokens = min_n
                # Drop the already-emitted earlier tokens of this repeated n-gram.
                for back in range(1, n):
                    if i - back in kept_indices:
                        kept_tokens.pop()
                        kept_ids.pop()
                        kept_indices.remove(i - back)
                duplicate = True
                break
        if not duplicate:
            kept_tokens.append(tok)
            kept_indices.add(i)
            kept_ids.append(segment_ids[i])
    if return_segment_ids:
        return ' '.join(kept_tokens), kept_ids
    return ' '.join(kept_tokens)
def process_video(video_id, args, input_steps, transcripts, tokenizer, output_queue, punct_cap_model=None):
    """Produce procedure steps for one video transcript and emit a JSON line.

    Loads (or receives) the raw transcript, deduplicates repeated n-grams,
    optionally punctuates/capitalizes it, obtains steps either from
    `input_steps` or from an OpenAI completion call, optionally aligns the
    steps back to the transcript, and writes the result.

    Args:
        video_id: id used to locate the transcript and name the output record.
        args: parsed CLI namespace (transcripts_path, no_formatting, no_align, ...).
        input_steps: optional {video_id: {"steps": [...]}} map; when given, the
            OpenAI call is skipped.
        transcripts: optional preloaded {video_id: transcript-dict} map.
        tokenizer: HF tokenizer used only to bound the transcript length.
        output_queue: a file object (written directly) or a multiprocessing
            queue (written via .put()).
        punct_cap_model: optional punctuation/capitalization model, used when
            no preformatted transcript file is available.

    Side effects: reads the module-level globals `finished` and `sent_model`
    that the __main__ block defines.
    """
    prompt = "Write the steps of the task that the person is demonstrating, based on the noisy transcript.\nTranscript: |||1\nSteps:\n1."
    print('here3')
    # Use the preloaded transcript when one was passed in...
    if transcripts is not None:
        original = transcripts[video_id]
    else:
        # ...otherwise parse <transcripts_path>/<video_id>.csv.
        f = open(os.path.join(args.transcripts_path, video_id+".csv"))
        lines = f.readlines()
        # text: caption strings; start/end: caption timestamps (seconds, per float()).
        original = {"text": [], "start": [], "end": []}
        for line in lines[1:]:  # skip the CSV header row
            parts = line.split(',')
            original["start"].append(float(parts[0]))
            original["end"].append(float(parts[1]))
            original["text"].append(parts[-1].strip())
    transcript = " ".join(original["text"])
    # Remove repeated n-grams (common in auto-captions), keeping each surviving
    # token's segment id so segment boundaries can be rebuilt below.
    deduplicated_text, new_segment_ids = remove_repeat_ngrams(original["text"], min_n=3, max_n=9, return_segment_ids=True)
    deduplicated_tokens = deduplicated_text.split()
    # Rebuild original["text"] from the deduplicated tokens, segment by segment.
    original["text"] = [[] for _ in range(len(original["text"]))]
    for token, new_id in zip(deduplicated_tokens, new_segment_ids):
        original["text"][new_id].append(token)
    original["text"] = [" ".join(lst) for lst in original["text"]]
    transcript = " ".join(original["text"])
    if not args.no_formatting:
        # Prefer a preformatted transcript file when it exists; otherwise run
        # the punctuation/capitalization model.
        if args.formatted_transcripts_path is not None:
            fname = os.path.join(args.formatted_transcripts_path, video_id+".txt")
        if args.formatted_transcripts_path is not None and os.path.exists(fname):
            f = open(fname)
            transcript = f.readlines()[0]
        else:
            transcript = punct_cap_model.add_punctuation_capitalization([transcript])[0]
    tokens = tokenizer(transcript)
    print(video_id, len(transcript), len(tokens["input_ids"]))
    # Chop 100 characters at a time until the transcript fits in 1600 tokens.
    while len(tokens["input_ids"]) > 1600:
        transcript = transcript[:-100]
        tokens = tokenizer(transcript)
    if args.input_steps_path is not None:
        # Steps were supplied externally; skip videos with no entry.
        if video_id not in input_steps:
            return
        steps = input_steps[video_id]["steps"]
    else:
        # Generate steps via the OpenAI completion API. `finished` is the
        # module-level set of already-processed video ids.
        if video_id in finished:
            return
        input_text = prompt.replace("|||1", transcript)
        steps = []
        num_attempts = 0
        while len(steps) == 0:
            response = openai.Completion.create(
                engine="text-babbage-001",
                prompt=input_text,
                temperature=0.7,
                max_tokens=256,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0
            )
            output = response["choices"][0]["text"].strip()
            num_attempts += 1
            steps = output.split("\n")
            # Accept the output only if every line after the first contains a
            # "." (numbered-list shape); strip the "N." prefixes. Otherwise
            # retry until max_attempts is exhausted.
            if all(["." in step for step in steps[1:]]):
                steps = steps[:1]+[step[step.index(".")+1:].strip() for step in steps[1:]]
            elif num_attempts < args.max_attempts:
                steps = []
    output_dict = {"video_id": video_id, "steps": steps, "transcript": transcript}
    if not args.no_align:
        # `sent_model` is the module-level global set up in __main__.
        segments = align_text(transcript, original, steps, sent_model, args.num_workers, not args.no_dtw, args.dtw_window_size)
        print(segments)
        output_dict["segments"] = segments
    # Write straight to a file handle, or enqueue for the listener process.
    if isinstance(output_queue, _io.TextIOWrapper):
        output_queue.write(json.dumps(output_dict)+'\n')
    else:
        output_queue.put(json.dumps(output_dict)+'\n')
def output_listener(output_queue, output_filename):
    """Drain `output_queue` into `output_filename` until the 'kill' sentinel.

    Appends when the file already exists, creates it otherwise. Fix: the
    original computed `mode` and then ignored it, hard-coding 'a+'; the
    computed mode is now actually used (and the file handle is still closed
    via the context manager).

    Args:
        output_queue: queue-like object whose .get() yields strings; the
            literal string 'kill' terminates the loop.
        output_filename: path of the file to write.
    """
    mode = 'a+' if os.path.exists(output_filename) else 'w'
    with open(output_filename, mode) as fout:
        while True:
            output = output_queue.get()
            if output == 'kill':
                break
            fout.write(output)
            # Flush per record so progress survives an abrupt shutdown.
            fout.flush()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_list_path")                 # file of comma-separated video ids
    parser.add_argument("--transcripts_path")                # dir of <id>.csv files, or a single .json
    parser.add_argument("--formatted_transcripts_path")      # optional dir of pre-punctuated <id>.txt
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=None)
    parser.add_argument("--max_attempts", type=int, default=1)   # OpenAI retry budget
    parser.add_argument("--no_formatting", action="store_true")
    parser.add_argument("--output_path")                     # JSONL output, also used for resume
    parser.add_argument("--cpu", action="store_true")
    parser.add_argument("--no_align", action="store_true")
    parser.add_argument("--input_steps_path", type=str, default=None)
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--no_dtw", action="store_true")
    parser.add_argument("--dtw_window_size", type=int, default=1000000)
    args = parser.parse_args()
    # sent_model is read as a module-level global by process_video/align_text,
    # so it is only created when alignment is requested.
    if not args.no_align:
        if args.cpu:
            sent_model = SentenceTransformer('sentence-transformers/paraphrase-mpnet-base-v2').cpu()
        else:
            sent_model = SentenceTransformer('sentence-transformers/paraphrase-mpnet-base-v2').cuda()
    # sent_model = AutoModel.from_pretrained('sentence-transformers/paraphrase-mpnet-base-v2').cuda()
    '''
    if not args.no_formatting:
        punct_cap_model = PunctuationCapitalizationModel.from_pretrained("punctuation_en_bert")
        if args.cpu:
            punct_cap_model = punct_cap_model.cpu()
    '''
    # GPT-2 tokenizer is only used to bound transcript length in process_video.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    f = open(args.video_list_path)
    #video_ids = [line.strip().split()[0].split('.')[0] for line in lines]
    content = f.read()
    print("CONTENT:", content)
    video_ids = content.split(",")
    print("VIDEO_IDS:", video_ids)
    transcripts = None
    # A .json transcripts_path means one preloaded dict instead of per-id CSVs.
    if args.transcripts_path[-5:] == ".json":
        f = open(args.transcripts_path)
        transcripts = json.load(f)
    # Truncate at end_index first, then drop the prefix before start_index.
    if args.end_index is not None:
        video_ids = video_ids[:args.end_index]
    video_ids = video_ids[args.start_index:]
    # Resume support: collect ids already present in the output file and append.
    finished = set()
    if os.path.exists(args.output_path):
        fout = open(args.output_path)
        written_lines = fout.readlines()
        fout.close()
        for line in written_lines:
            try:
                datum = json.loads(line)
                finished.add(datum['video_id'])
            except:
                pass  # tolerate partial/corrupt trailing lines
        fout = open(args.output_path, 'a')
    else:
        fout = open(args.output_path, 'w')
    input_steps = None
    if args.input_steps_path is not None:
        # JSONL of {"video_id": ..., "steps": [...]}, re-keyed by video_id.
        f = open(args.input_steps_path)
        lines = f.readlines()
        input_steps = [json.loads(line) for line in lines]
        input_steps = {datum["video_id"]: datum for datum in input_steps}
    # Disabled multiprocessing scaffolding, kept for reference.
    """manager = mp.Manager()
    q = manager.Queue()
    pool = mp.Pool(args.num_workers+2)
    watcher = pool.apply_async(output_listener, (q, args.output_path))
    print('here1', pool._processes)
    jobs = []"""
    # Serial processing loop; already-finished ids are skipped.
    for video_id in tqdm(video_ids):
        if video_id in finished:
            continue
        # job = pool.apply_async(process_video, (video_id, args, input_steps, transcripts, tokenizer, punct_cap_model, q))
        process_video(video_id, args, input_steps, transcripts, tokenizer, fout)
        # print('here', len(jobs))
        # jobs.append(job)
    """for job in jobs:
        job.get()
    q.put('kill')
    pool.close()
    pool.join()"""
    fout.close()