forked from szakrewsky/quality-feature-extraction
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeep_features.py
29 lines (24 loc) · 905 Bytes
/
deep_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import cv2
import decimal
import json
import ijson.backends.yajl2_cffi as ijson
from sklearn_theano.feature_extraction import OverfeatTransformer
# Module-level Overfeat feature extractor, built once and reused for every
# image processed by the loop further down in this script.
# NOTE(review): only output layer index 8 is requested; which layer of the
# network that index denotes is not visible here -- confirm against the
# sklearn_theano OverfeatTransformer documentation.
tr = OverfeatTransformer(output_layers=[8])
class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``decimal.Decimal`` values as floats.

    The standard encoder rejects ``Decimal`` objects with ``TypeError``.
    NOTE(review): presumably needed because the ijson parser used by this
    script yields ``Decimal`` for non-integer JSON numbers -- confirm.
    """

    def default(self, o):
        """Return a JSON-serializable replacement for *o*.

        Args:
            o: An object the base encoder could not serialize natively.

        Returns:
            ``float(o)`` when *o* is a ``decimal.Decimal``.

        Raises:
            TypeError: Raised by the base class for any other unsupported
                type.
        """
        if isinstance(o, decimal.Decimal):
            return float(o)
        # Defer to the base implementation so unsupported types still raise.
        return super(DecimalEncoder, self).default(o)
# Stream every record of ds.json, attach an Overfeat feature vector under the
# 'deep' key, and write the augmented records to ds_deep.json as one JSON
# array. Records are parsed lazily and dumped one at a time, so the output
# array brackets and separators are written by hand.
#
# The input is opened in binary mode because ijson parses a byte stream.
with open('../workspace/ds.json', 'rb') as inh, \
        open('../workspace/ds_deep.json', 'w') as outh:
    outh.write('[')
    for i, item in enumerate(ijson.items(inh, 'item')):
        # Single-argument form prints identically on Python 2 and Python 3.
        print('running %d' % (i + 1))
        if i > 0:
            # Comma goes before every element except the first.
            outh.write(',')
        path = 'set1/' + item['file']
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals a missing or undecodable file by returning
            # None instead of raising; fail here with the offending path
            # rather than with an opaque error inside cv2.resize.
            raise IOError('could not read image: %s' % path)
        # NOTE(review): 231x231 is presumably the input size the Overfeat
        # model expects, and cv2.imread yields BGR channel order -- confirm
        # both against what the transformer was trained on.
        img = cv2.resize(img, (231, 231))
        # First row of the transform output, converted to a plain list so
        # that json can serialize it.
        item['deep'] = tr.transform(img)[0].tolist()
        json.dump(item, outh, cls=DecimalEncoder)
    outh.write(']')