-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
example: datasets for image segmentation (#2925)
- Loading branch information
1 parent
e7a1780
commit a89d722
Showing
7 changed files
with
319 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
Image Segmentation | ||
====== | ||
|
||
Image segmentation is the process of partitioning a digital image into multiple image segments, also known as image regions or image objects (sets of pixels). The goal of segmentation is to simplify and/or change the representation of an image into something that is more meaningful and easier to analyze. | ||
|
||
In these examples, we will use Starwhale to evaluate a set of image segmentation models. | ||
|
||
What we learn | ||
------ | ||
|
||
- use the `@Starwhale.predict` and `@Starwhale.evaluate` decorators to define handlers for Starwhale Model to finish the image segmentation evaluation tasks. | ||
- build Starwhale Dataset by Starwhale Python SDK and use Starwhale Dataset Web Viewer. | ||
- use Starwhale Evaluation Summary Page to compare the algorithm quality of the different models. | ||
- build a unified Starwhale Runtime to run all models.
- use Starwhale `replicas` feature to speed up model evaluation.
- use `Starwhale.Image`, `Starwhale.COCOObjectAnnotation` and `Starwhale.BoundingBox` to represent Dataset type. | ||
|
||
Models | ||
------ | ||
|
||
- [Segment Anything Model](https://segment-anything.com/) | ||
- [GrabCut](https://docs.opencv.org/3.4/d8/d83/tutorial_py_grabcut.html) | ||
- [PSPNet](https://github.com/qubvel/segmentation_models) | ||
- [FCN](https://pytorch.org/vision/main/models/fcn.html) | ||
- [SAN](https://github.com/MendelXu/SAN) | ||
|
||
Datasets | ||
------ | ||
|
||
We will use the following datasets to evaluate models. | ||
|
||
- [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/) | ||
|
||
- Introduction: The dataset contains 20 object categories including vehicles, household, animals and others. Each image in this dataset has pixel-level segmentation annotations, bounding box annotations, and object class annotations. This dataset has been widely used as a benchmark for object detection, semantic segmentation, and classification tasks. In this example, we will use 2012 year dataset. | ||
- Size: 2913 segmentation images. | ||
- Dataset build command: `python3 datasets/pascal_voc.py` | ||
|
||
- [COCO 2017 Stuff](https://cocodataset.org/#stuff-2017) | ||
|
||
- Introduction: The COCO Stuff Segmentation Task is designed to push the state of the art in semantic segmentation of stuff classes. Whereas the object detection task addresses thing classes (person, car, elephant), this task focuses on stuff classes (grass, wall, sky). | ||
- Size: 5,000 validation images.
- Dataset build command: `python3 datasets/coco_stuff.py` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
data/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from __future__ import annotations | ||
|
||
import json | ||
from pathlib import Path | ||
from collections import defaultdict | ||
|
||
from starwhale import Image, dataset, init_logger | ||
from starwhale.utils import console | ||
from starwhale.base.data_type import COCOObjectAnnotation | ||
|
||
try: | ||
from .utils import extract, download | ||
except ImportError: | ||
from utils import extract, download | ||
|
||
# set Starwhale Python SDK logger level to 3 (DEBUG)
init_logger(3)

# All raw files live under <this file's dir>/data/coco2017-stuff.
ROOT = Path(__file__).parent
DATA_DIR = ROOT / "data" / "coco2017-stuff"
VAL_DIR_NAME = "val2017"  # folder holding the validation jpegs
ANNOTATION_DIR_NAME = "annotations"  # folder holding the stuff annotation files
|
||
|
||
def download_and_extract(root_dir: Path) -> None:
    """Download the COCO 2017-Stuff validation images and annotations.

    All archives are downloaded into *root_dir* and extracted there; steps
    whose output already exists are skipped by the ``download``/``extract``
    helpers. Network and filesystem I/O only; returns nothing.

    :param root_dir: destination directory for the raw dataset files.
    """
    # Bug fix: use root_dir consistently. The original mixed root_dir with the
    # module-level DATA_DIR global, silently ignoring the parameter for some
    # paths (harmless only because the sole caller passes DATA_DIR).
    images_zip = root_dir / "val2017.zip"
    # for speedup, fork from http://images.cocodataset.org/zips/val2017.zip
    download(
        "https://starwhale-examples.oss-cn-beijing.aliyuncs.com/dataset/coco2017-stuff/val2017.zip",
        images_zip,
    )
    extract(images_zip, root_dir, root_dir / VAL_DIR_NAME)

    ann_zip = root_dir / "stuff_annotations_trainval2017.zip"
    # for speedup, fork from http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
    download(
        "https://starwhale-examples.oss-cn-beijing.aliyuncs.com/dataset/coco2017-stuff/stuff_annotations_trainval2017.zip",
        ann_zip,
    )
    extract(ann_zip, root_dir, root_dir / ANNOTATION_DIR_NAME)
    # The annotations archive nests another zip with the pixelmap pngs.
    extract(
        root_dir / ANNOTATION_DIR_NAME / "stuff_val2017_pixelmaps.zip",
        root_dir / ANNOTATION_DIR_NAME,
        root_dir / ANNOTATION_DIR_NAME / "stuff_val2017_pixelmaps",
    )
|
||
|
||
def build(root_dir: Path):
    """Build the ``coco2017-stuff-val`` Starwhale dataset from *root_dir*.

    Reads the stuff annotation json, groups the COCO object annotations by
    image id, then writes one dataset row per image holding the original
    jpeg, its metadata, its annotations and its pixelmap png.
    """
    ann_json = root_dir / ANNOTATION_DIR_NAME / "stuff_val2017.json"
    with ann_json.open() as reader:
        meta = json.load(reader)

    # Group the flat annotation list by the image each annotation belongs to.
    anns_by_image = defaultdict(list)
    for item in meta["annotations"]:
        coco_obj = COCOObjectAnnotation(
            id=item["id"],
            image_id=item["image_id"],
            category_id=item["category_id"],
            area=item["area"],
            bbox=item["bbox"],
            iscrowd=item["iscrowd"],
        )
        coco_obj.segmentation = item["segmentation"]
        anns_by_image[item["image_id"]].append(coco_obj)

    console.log("start to build dataset")
    with dataset("coco2017-stuff-val") as ds:
        for img in meta["images"]:
            pixelmap_fpath = (
                root_dir
                / ANNOTATION_DIR_NAME
                / "stuff_val2017_pixelmaps"
                / img["file_name"].replace(".jpg", ".png")
            )
            ds[img["id"]] = {
                "image": Image(
                    fp=root_dir / VAL_DIR_NAME / img["file_name"],
                    shape=[img["width"], img["height"], 3],
                ),
                "coco_url": img["coco_url"],
                "date_captured": img["date_captured"],
                "annotations": anns_by_image[img["id"]],
                "pixelmaps": Image(fp=pixelmap_fpath),
            }

        ds.commit()

    console.log(f"{ds} has been built")
|
||
|
||
if __name__ == "__main__":
    # Fetch the raw COCO-Stuff files, then build the Starwhale dataset from them.
    download_and_extract(DATA_DIR)
    build(DATA_DIR)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from __future__ import annotations | ||
|
||
import xml.etree.ElementTree as ET | ||
from pathlib import Path | ||
|
||
from starwhale import Image, dataset, BoundingBox, init_logger | ||
from starwhale.utils import console | ||
|
||
try: | ||
from .utils import extract, download | ||
except ImportError: | ||
from utils import extract, download | ||
|
||
# set Starwhale Python SDK logger level to 3 (DEBUG)
init_logger(3)

# All raw files live under <this file's dir>/data/pascal-voc2012.
ROOT = Path(__file__).parent
DATA_DIR = ROOT / "data" / "pascal-voc2012"
|
||
|
||
def download_and_extract(root_dir: Path):
    """Download and extract the PASCAL VOC 2012 trainval archive into *root_dir*.

    Steps whose output already exists are skipped by the ``download``/
    ``extract`` helpers. Network and filesystem I/O only; returns nothing.

    :param root_dir: destination directory for the raw dataset files.
    """
    # Bug fix: use root_dir consistently. The original mixed root_dir with the
    # module-level DATA_DIR global, silently ignoring the parameter for some
    # paths (harmless only because the sole caller passes DATA_DIR).
    tar_fpath = root_dir / "VOCtrainval_11-May-2012.tar"
    # for speedup, fork from http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    download(
        "https://starwhale-examples.oss-cn-beijing.aliyuncs.com/dataset/pascal2012-voc/VOCtrainval_11-May-2012.tar",
        tar_fpath,
    )
    extract(tar_fpath, root_dir, root_dir / "VOCdevkit/VOC2012/SegmentationObject")
|
||
|
||
def build(root_dir: Path):
    """Build the ``pascal2012-voc-segmentation`` Starwhale dataset.

    PASCAL VOC 2012 dataset folder structure:
      - Annotations/2011_000234.xml        <-- annotations description
      - SegmentationObject/2011_000234.png <-- segmentation png
      - SegmentationClass/2011_000234.png  <-- segmentation png
      - JPEGImages/2011_000234.jpg         <-- original jpeg
      - ImageSets/{Action, Layout, Main, Segmentation}

    Only the segmentation-related files are used to build the dataset.
    """
    console.log("start to build dataset")
    voc_dir = root_dir / "VOCdevkit/VOC2012"
    with dataset("pascal2012-voc-segmentation") as ds:
        for seg_obj_path in (voc_dir / "SegmentationObject").iterdir():
            if not seg_obj_path.is_file() or seg_obj_path.suffix != ".png":
                continue

            stem = seg_obj_path.name.split(".")[0]
            xml_root = ET.parse(voc_dir / "Annotations" / f"{stem}.xml").getroot()
            size_node = xml_root.find("size")

            records = []
            for obj_node in xml_root.findall("object"):
                # Hoist the bndbox lookup instead of repeating it per coordinate.
                box = obj_node.find("bndbox")
                x_min, x_max = int(box.find("xmin").text), int(box.find("xmax").text)
                y_min, y_max = int(box.find("ymin").text), int(box.find("ymax").text)

                records.append(
                    {
                        "name": obj_node.find("name").text,
                        "bbox": BoundingBox(
                            x=x_min, y=y_min, width=x_max - x_min, height=y_max - y_min
                        ),
                        "difficult": obj_node.find("difficult").text,
                        "pose": obj_node.find("pose").text,
                    }
                )

            ds[stem] = {
                "segmentation_object": Image(seg_obj_path),
                "segmentation_class": Image(
                    voc_dir / "SegmentationClass" / f"{stem}.png"
                ),
                "original": Image(
                    voc_dir / "JPEGImages" / f"{stem}.jpg",
                    # NOTE(review): shape is recorded as [width, height, depth]
                    # straight from the XML — confirm the expected axis order.
                    shape=[
                        int(size_node.find("width").text),
                        int(size_node.find("height").text),
                        int(size_node.find("depth").text),
                    ],
                ),
                "segmented": xml_root.find("segmented").text,
                "objects": records,
            }
        ds.commit()
    console.log(f"{ds} has been built")
|
||
|
||
if __name__ == "__main__":
    # Fetch the raw PASCAL VOC files, then build the Starwhale dataset from them.
    download_and_extract(DATA_DIR)
    build(DATA_DIR)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from __future__ import annotations | ||
|
||
import tarfile | ||
import zipfile | ||
from pathlib import Path | ||
|
||
import requests | ||
from tqdm import tqdm | ||
|
||
from starwhale.utils import console | ||
|
||
|
||
def _extract_zip(from_path: Path, to_path: Path) -> None:
    """Extract every member of zip archive *from_path* into *to_path*,
    showing a tqdm progress bar named after the archive file."""
    with zipfile.ZipFile(from_path, "r", zipfile.ZIP_STORED) as z:
        # Perf fix: namelist() walks the whole central directory — the
        # original computed it twice (iterable= and total=); do it once.
        members = z.namelist()
        for member in tqdm(
            iterable=members,
            total=len(members),
            desc=f"extract {from_path.name}",
        ):
            z.extract(member=member, path=to_path)
|
||
|
||
def _extract_tar(from_path: Path, to_path: Path) -> None:
    """Extract every member of tar archive *from_path* into *to_path*,
    showing a tqdm progress bar named after the archive file."""
    with tarfile.open(from_path, "r") as t:
        # Perf fix: getmembers() scans the entire archive — the original
        # computed it twice (iterable= and total=); do it once.
        members = t.getmembers()
        for member in tqdm(
            iterable=members,
            total=len(members),
            desc=f"extract {from_path.name}",
        ):
            # NOTE(review): extract() without an extraction filter trusts the
            # archive not to contain path-traversal members; acceptable for
            # these pinned dataset mirrors, but consider filter="data" on
            # Python 3.12+.
            t.extract(member=member, path=to_path)
|
||
|
||
def extract(from_path: Path, to_path: Path, chk_path: Path) -> None:
    """Extract a ``.zip`` or ``.tar`` archive to *to_path*.

    :param from_path: archive file; must exist and end in ``.zip`` or ``.tar``.
    :param to_path: destination directory for the extracted content.
    :param chk_path: sentinel directory — when it already exists the archive
        is assumed to be extracted and the call is a no-op.
    :raises ValueError: if *from_path* does not exist or is not a supported
        archive type.
    """
    # Bug fix: the message said "invalid zip file" even when the rejected
    # input was a tar (or anything else) — say "archive" instead.
    if not from_path.exists() or from_path.suffix not in (".zip", ".tar"):
        raise ValueError(f"invalid archive file: {from_path}")

    if chk_path.exists() and chk_path.is_dir():
        console.log(f"skip extract {from_path}, dir {chk_path} already exists")
        return

    console.log(f"extract {from_path} to {to_path} ...")
    if from_path.suffix == ".zip":
        _extract_zip(from_path, to_path)
    else:
        # Guaranteed to be ".tar" by the guard above; the original's final
        # else-raise branch was unreachable dead code.
        _extract_tar(from_path, to_path)
|
||
|
||
def download(url: str, to_path: Path) -> None:
    """Stream *url* to *to_path* with a tqdm progress bar.

    A no-op when *to_path* already exists; parent directories are created
    as needed. Raises for non-2xx HTTP responses.
    """
    if to_path.exists():
        console.log(f"skip download {url}, file {to_path} already exists")
        return

    to_path.parent.mkdir(parents=True, exist_ok=True)

    with requests.get(url, timeout=60, stream=True) as resp:
        resp.raise_for_status()
        # content-length may be absent; tqdm treats 0 as "unknown total".
        total_bytes = int(resp.headers.get("content-length", 0))
        with tqdm(
            total=total_bytes,
            unit="B",
            unit_scale=True,
            desc=f"download {url}",
            initial=0,
            unit_divisor=1024,
        ) as progress, to_path.open("wb") as out:
            for chunk in resp.iter_content(chunk_size=8192):
                out.write(chunk)
                progress.update(len(chunk))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
gradio~=3.15.0 # starwhale online-server only supports gradio ~3.15.0 | ||
pillow | ||
tqdm | ||
requests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
name: image-segmentation | ||
mode: venv | ||
environment: | ||
cuda: 11.7 | ||
python: 3.9 | ||
starwhale: 0.6.2 | ||
dependencies: | ||
- requirements.txt |