-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathsplit_coco_dataset.py
executable file
·58 lines (42 loc) · 2.2 KB
/
split_coco_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 3 20:59:00 2022
@author: prabhakar
"""
import json
import funcy
import argparse
from sklearn.model_selection import train_test_split
def _split_fraction(value):
    """Parse the ``-s`` option value: a float strictly between 0 and 1.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a number or is not
            in the open interval (0, 1). argparse turns this into a usage
            error instead of letting a bad value reach train_test_split.
    """
    try:
        fraction = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "split must be a number in (0, 1), got {!r}".format(value)) from None
    # Written as "not (0 < f < 1)" so that NaN is rejected as well.
    if not 0.0 < fraction < 1.0:
        raise argparse.ArgumentTypeError(
            "split must be a number in (0, 1), got {!r}".format(value))
    return fraction


parser = argparse.ArgumentParser(description='Splits COCO annotations file into training and test sets.')
parser.add_argument('annotations', metavar='coco_annotations', type=str,
                    help='Path to COCO annotations file.')
parser.add_argument('train', type=str, help='Where to store COCO training annotations')
parser.add_argument('test', type=str, help='Where to store COCO test annotations')
parser.add_argument('-s', dest='split', type=_split_fraction, required=True,
                    help="Fraction of annotated images to put in the training set; a number in (0, 1)")

# Parse the command line only when executed as a script, so importing this
# module (e.g. to reuse save_coco/filter_annotations, or from a test runner)
# does not consume or reject the importer's sys.argv.
args = parser.parse_args() if __name__ == "__main__" else None
def save_coco(file, info, licenses, images, annotations, categories):
    """Write a COCO-style annotations document to ``file``.

    The five sections are stored under the top-level keys ``info``,
    ``licenses``, ``images``, ``annotations`` and ``categories``. The result
    is serialised as UTF-8 JSON with 2-space indentation and sorted keys,
    and any existing file at ``file`` is overwritten.
    """
    # sort_keys=True fixes the on-disk key order, so the literal order here
    # does not affect the output.
    payload = {
        'annotations': annotations,
        'categories': categories,
        'images': images,
        'info': info,
        'licenses': licenses,
    }
    with open(file, mode='w', encoding='UTF-8') as out:
        json.dump(payload, out, sort_keys=True, indent=2)
def filter_annotations(annotations, images):
    """Return the annotations that belong to one of ``images``.

    Args:
        annotations: List of COCO annotation dicts, each with an ``image_id``.
        images: List of COCO image dicts, each with an ``id``.

    Returns:
        A new list with the annotations whose ``int(image_id)`` equals
        ``int(id)`` of some image in ``images``, in their original order.
    """
    # A set gives O(1) membership tests. Testing against a list made this
    # O(len(annotations) * len(images)), which is prohibitively slow on
    # full-size COCO files.
    image_ids = {int(image['id']) for image in images}
    return [ann for ann in annotations if int(ann['image_id']) in image_ids]
def main(args):
    """Split a COCO annotations file into train and test annotation files.

    Steps:
        1. Read the COCO file at ``args.annotations``.
        2. Drop images that have no annotations.
        3. Split the remaining images randomly with sklearn's
           ``train_test_split`` (``train_size=args.split``).
        4. Write each subset, with only its own annotations, to
           ``args.train`` and ``args.test``.

    ``info``, ``licenses`` and ``categories`` are copied unchanged into both
    outputs.

    Args:
        args: Namespace with ``annotations``, ``train``, ``test`` (paths) and
            ``split`` (training fraction) attributes.
    """
    with open(args.annotations, 'rt', encoding='UTF-8') as annotations_file:
        coco = json.load(annotations_file)

    # 'info' and 'licenses' are optional in many exports. COCO stores
    # licenses as a list, so the fallback is an empty list, not a dict.
    info = coco.get('info', {})
    licenses = coco.get('licenses', [])
    images = coco['images']
    annotations = coco['annotations']
    categories = coco['categories']

    # Keep only images that have at least one annotation. Using a set makes
    # each lookup O(1). Ids are compared as ints on both sides, the same way
    # filter_annotations() matches annotations to images, so string-typed
    # image ids are not silently dropped.
    annotated_image_ids = {int(ann['image_id']) for ann in annotations}
    images = [image for image in images if int(image['id']) in annotated_image_ids]

    train_images, test_images = train_test_split(images, train_size=args.split)

    save_coco(args.train, info, licenses, train_images,
              filter_annotations(annotations, train_images), categories)
    save_coco(args.test, info, licenses, test_images,
              filter_annotations(annotations, test_images), categories)

    print("Saved {} entries in {} and {} in {}".format(
        len(train_images), args.train, len(test_images), args.test))
if __name__ == '__main__':
    # Script entry point: hand the module-level parsed arguments to main().
    main(args=args)