-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathresolve_conflicts.py
executable file
·133 lines (121 loc) · 5.66 KB
/
resolve_conflicts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
import sys
import argparse
# command-line interface: three required paths, one output path, and an optional flag
parser = argparse.ArgumentParser(description='Merge the segmented, classified plants from each overlapping image together.')
parser.add_argument(
"ortho", help="the path to the orthomosaic"
)
parser.add_argument(
"segments", help="the path to a directory containing (for each image in the orthomosaic) the coordinates of each segmented region"
)
parser.add_argument(
"predicts", help="the path to a directory containing tsv files (for each image in the orthomosaic) with the species class of each segmented region"
)
parser.add_argument(
"--no-labels", action='store_true', help="whether to include the labels of each segment in the output"
)
parser.add_argument(
"out", help="the classes of each segmented region within the orthomosaic"
)
args = parser.parse_args()
# normalize the directory arguments so that filenames can be appended directly
args.segments += '/' if not args.segments.endswith('/') else ''
args.predicts += '/' if not args.predicts.endswith('/') else ''
import os
import numpy as np
import pandas as pd
import import_labelme
from PIL import Image
# from matplotlib import pyplot as plt
# probability cutoff used at the end to convert the merged prob.1 into a 0/1 response
THRESHOLD = 0.5
Image.MAX_IMAGE_PIXELS = None # so that PIL doesn't complain when we open large files
def shoelace(coords):
    """Return the area of a polygon given as a list of (x, y) coordinates."""
    pts = np.asarray(coords)
    xs = pts[:, 0]
    ys = pts[:, 1]
    # shoelace formula: half the absolute value of the signed cross-sum
    # of each vertex with its neighbor (np.roll pairs vertex i with i-1)
    signed_twice_area = np.dot(xs, np.roll(ys, 1)) - np.dot(ys, np.roll(xs, 1))
    return np.abs(signed_twice_area) / 2
def resolve(segments):
    """Merge the predictions for one segment seen in multiple images.

    Given the rows of the predicts DataFrame that share a label (one row per
    camera/image, as a pandas DataFrame with 'area' and 'prob.1' columns and
    optionally a 'truth' column), return a single pandas Series with the
    area-weighted probability (and the truth value, if present).
    """
    # strategy: weight each probability by the relative size of its area.
    # work on a copy so we never mutate the caller's data (groupby.apply may
    # hand us a view of the original DataFrame)
    segments = segments.copy()
    # first calculate the relative size of each segment
    segments['area'] = segments['area'] / segments['area'].sum()
    # weight each probability by its fractional area and combine
    weighted_prob = (segments['prob.1'] * segments['area']).sum()
    # check: is this testing data? if so, we want to preserve the truth
    if 'truth' in segments:
        # use .iloc[0] (positional): the group is indexed by a (camera, label)
        # MultiIndex, so label-based [0] would rely on a deprecated fallback
        return pd.Series([segments['truth'].iloc[0], weighted_prob], index=['truth', 'prob.1'])
    else:
        return pd.Series([weighted_prob], index=['prob.1'])
def img_size(ortho=args.ortho):
    """Return the (width, height) of an image, defaulting to the orthomosaic.

    Bug fix: the 'ortho' parameter was previously ignored — the body always
    read args.ortho, so callers could never query a different image.
    """
    print('loading orthomosaic')
    # PIL parses only the file header here, so the whole image is no longer
    # loaded into memory just to learn its size (resolves the old TODO).
    # Image.size is already (width, height), matching the previous
    # shape[-2::-1] result for color images — and it is also correct for
    # grayscale images, where the old shape slice yielded a 1-tuple.
    return Image.open(ortho).size
# width/height of the orthomosaic, used to clip segment coords to the image bounds
img_shape = img_size()
# next, load the segments coords
print('loading segments')
# first, get a list of the segment files, sorted by their names
# TODO: also support .npy masks, instead of just JSON segments
segments_fnames = sorted([f for f in os.listdir(args.segments) if f.endswith('.json')])
# and then import them using labelme and convert each set of coords to an area
# NOTE(review): assumes import_labelme.main returns {label: coords} and that the
# img_shape argument clips segments to the image bounds — confirm against import_labelme
segments = {
segment[:-len('.json')]: {
label: shoelace(coords)
for label, coords in import_labelme.main(args.segments+segment, True, img_shape).items()
}
for segment in segments_fnames
}
# same import, but without clipping to the image bounds (full segment areas)
segments_complete = {
segment[:-len('.json')]: {
label: shoelace(coords)
for label, coords in import_labelme.main(args.segments+segment, True).items()
}
for segment in segments_fnames
}
# lastly, flatten the segments to a pandas df multi-indexed by cam and label
areas = pd.DataFrame.from_dict({
(cam, seg): [segments[cam][seg]]
for cam in segments for seg in segments[cam]
}).T
areas_complete = pd.DataFrame.from_dict({
(cam, seg): [segments_complete[cam][seg]]
for cam in segments_complete
for seg in segments_complete[cam]
}).T
# we created two different dataframes
# the areas_complete dataframe contains the sizes of each segment in the orthomosaic
# while the areas dataframe contains the sizes within each image
# so now we divide the two to get the fractional area of each segment in each image
areas = areas/areas_complete
areas.columns = ['area']
# also load the predicts
print('loading classification predictions')
# first, get a list of the classification files, sorted by their names
predicts = sorted([f for f in os.listdir(args.predicts) if f.endswith('.tsv')])
# check that there are an equal number of segments and predicts
# NOTE(review): assert is stripped under `python -O`; consider raising ValueError instead
assert len(segments) >= len(predicts), "There are less camera files in the segments dir than in the predicts dir."
# import them as a single large, multi-indexed pandas dataframe
# (outer key = camera name, so the result is indexed by (camera, row))
predicts = pd.concat(
{
predict[:-len('.tsv')] : pd.read_csv(args.predicts+predict, sep="\t")
for predict in predicts
}
)
# check that the number of segments is kosher before adding the areas
assert len(predicts) <= len(areas), "There are less segments among all of the files than there are classification predictions."
# add the areas of each segment as a column in the predicts df
# (joins on the shared (camera, label) index)
predicts = predicts.join(areas)
# the dataframe is multi-indexed by camera and label
predicts.index.names = ['camera', 'label']
# now, we can finally group the segments by their label and assign them a new class
print('resolving conflicts')
# merge each label's per-camera rows via resolve() (area-weighted probability)
results = predicts.groupby('label').apply(resolve)
# get the prob.0 column and add it before the prob.1 column
results.insert(list(results.columns).index('prob.1'), 'prob.0', (1 - results['prob.1']))
# add the response column back too: 1 if the merged probability clears THRESHOLD
results['response'] = (results['prob.1'] >= THRESHOLD).apply(int)
# last step: write the results to the outfile
print('saving results')
# write tab-separated; drop the label index column when --no-labels was given
results.to_csv(args.out, sep="\t", index=(not args.no_labels))