example: datasets for image segmentation (#2925)
tianweidut authored Oct 31, 2023
1 parent e7a1780 commit a89d722
Showing 7 changed files with 319 additions and 0 deletions.
42 changes: 42 additions & 0 deletions example/image-segmentation/README.md
@@ -0,0 +1,42 @@
Image Segmentation
======

Image segmentation is the process of partitioning a digital image into multiple image segments, also known as image regions or image objects (sets of pixels). The goal of segmentation is to simplify and/or change the representation of an image into something that is more meaningful and easier to analyze.

In these examples, we will use Starwhale to evaluate a set of image segmentation models.

What we learn
------

- Use the `@Starwhale.predict` and `@Starwhale.evaluate` decorators to define handlers for the Starwhale Model that runs the image segmentation evaluation tasks (see the sketch after this list).
- Build a Starwhale Dataset with the Starwhale Python SDK and browse it in the Starwhale Dataset Web Viewer.
- Use the Starwhale Evaluation Summary Page to compare the quality of the different models.
- Build a unified Starwhale Runtime to run all the models.
- Use the Starwhale `replicas` feature to speed up model evaluation.
- Use `Starwhale.Image`, `Starwhale.COCOObjectAnnotation` and `Starwhale.BoundingBox` to represent the dataset data types.
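
The model handlers themselves are not part of this commit, so the snippet below is only a minimal sketch of the decorator pattern referenced in the first bullet. It assumes the `evaluation.predict` / `evaluation.evaluate` decorator form of the Starwhale Python SDK with the `replicas` and `needs` parameters; `run_model` and the metric aggregation are hypothetical placeholders.

```python
from starwhale import evaluation


@evaluation.predict(replicas=2)  # `replicas` fans the per-sample predictions out over workers
def predict(data: dict) -> dict:
    # `data` is one dataset row, e.g. {"image": Image, "annotations": [...], ...}
    mask = run_model(data["image"])  # hypothetical helper wrapping one of the models
    return {"mask": mask}


@evaluation.evaluate(needs=[predict])
def evaluate(predict_result_iter) -> None:
    # walk over the logged per-sample predictions and report summary
    # metrics (for example mean IoU) to the Evaluation Summary Page
    for result in predict_result_iter:
        ...
```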

Models
------

- [Segment Anything Model](https://segment-anything.com/)
- [GrabCut](https://docs.opencv.org/3.4/d8/d83/tutorial_py_grabcut.html)
- [PSPNet](https://github.com/qubvel/segmentation_models)
- [FCN](https://pytorch.org/vision/main/models/fcn.html)
- [SAN](https://github.com/MendelXu/SAN)

Datasets
------

We will use the following datasets to evaluate models.

- [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/)

  - Introduction: The dataset contains 20 object categories, including vehicles, household objects, animals, and others. Each image has pixel-level segmentation annotations, bounding-box annotations, and object-class annotations. The dataset is widely used as a benchmark for object detection, semantic segmentation, and classification tasks. In this example, we use the 2012 release.
  - Size: 2,913 segmentation images.
  - Dataset build command: `python3 datasets/pascal_voc.py`

- [COCO 2017 Stuff](https://cocodataset.org/#stuff-2017)

  - Introduction: The COCO Stuff Segmentation Task is designed to push the state of the art in semantic segmentation of stuff classes. Whereas the object detection task addresses thing classes (person, car, elephant), this task focuses on stuff classes (grass, wall, sky).
  - Size: 5,000 validation images.
  - Dataset build command: `python3 datasets/coco_stuff.py`
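
After a build command finishes, the dataset can be reopened with the same SDK for a quick sanity check before running any evaluation. The snippet below is only a rough sketch and is not part of this commit; it assumes a local build in the default project, and that iterating a dataset yields rows exposing `index` and `features`.

```python
from starwhale import dataset

# hypothetical sanity check after running `python3 datasets/coco_stuff.py`
ds = dataset("coco2017-stuff-val")
row = next(iter(ds))  # assumption: iteration yields rows with `index` / `features`
print(row.index, row.features["image"], len(row.features["annotations"]))
```
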
1 change: 1 addition & 0 deletions example/image-segmentation/datasets/.gitignore
@@ -0,0 +1 @@
data/*
92 changes: 92 additions & 0 deletions example/image-segmentation/datasets/coco_stuff.py
@@ -0,0 +1,92 @@
from __future__ import annotations

import json
from pathlib import Path
from collections import defaultdict

from starwhale import Image, dataset, init_logger
from starwhale.utils import console
from starwhale.base.data_type import COCOObjectAnnotation

try:
    from .utils import extract, download
except ImportError:
    from utils import extract, download

# set Starwhale Python SDK logger level to 3 (DEBUG)
init_logger(3)

ROOT = Path(__file__).parent
DATA_DIR = ROOT / "data" / "coco2017-stuff"
VAL_DIR_NAME = "val2017"
ANNOTATION_DIR_NAME = "annotations"


def download_and_extract(root_dir: Path) -> None:
    _zip_fpath = DATA_DIR / "val2017.zip"
    # for speedup, fork from http://images.cocodataset.org/zips/val2017.zip
    download(
        "https://starwhale-examples.oss-cn-beijing.aliyuncs.com/dataset/coco2017-stuff/val2017.zip",
        _zip_fpath,
    )
    extract(_zip_fpath, root_dir, root_dir / VAL_DIR_NAME)

    _zip_fpath = DATA_DIR / "stuff_annotations_trainval2017.zip"
    # for speedup, fork from http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
    download(
        "https://starwhale-examples.oss-cn-beijing.aliyuncs.com/dataset/coco2017-stuff/stuff_annotations_trainval2017.zip",
        _zip_fpath,
    )
    extract(_zip_fpath, DATA_DIR, DATA_DIR / ANNOTATION_DIR_NAME)
    # the annotations archive nests another zip that holds the pixelmap pngs
    extract(
        DATA_DIR / ANNOTATION_DIR_NAME / "stuff_val2017_pixelmaps.zip",
        DATA_DIR / ANNOTATION_DIR_NAME,
        DATA_DIR / ANNOTATION_DIR_NAME / "stuff_val2017_pixelmaps",
    )


def build(root_dir: Path):
    json_path = root_dir / ANNOTATION_DIR_NAME / "stuff_val2017.json"
    with json_path.open() as f:
        content = json.load(f)

    # group the COCO annotations by image id, so each dataset row only
    # carries the annotations of its own image
    annotations = defaultdict(list)
    for ann in content["annotations"]:
        coco_ann = COCOObjectAnnotation(
            id=ann["id"],
            image_id=ann["image_id"],
            category_id=ann["category_id"],
            area=ann["area"],
            bbox=ann["bbox"],
            iscrowd=ann["iscrowd"],
        )
        coco_ann.segmentation = ann["segmentation"]
        annotations[ann["image_id"]].append(coco_ann)

    console.log("start to build dataset")
    with dataset("coco2017-stuff-val") as ds:
        for image in content["images"]:
            ds[image["id"]] = {
                "image": Image(
                    fp=root_dir / VAL_DIR_NAME / image["file_name"],
                    shape=[image["width"], image["height"], 3],
                ),
                "coco_url": image["coco_url"],
                "date_captured": image["date_captured"],
                "annotations": annotations[image["id"]],
                "pixelmaps": Image(
                    fp=root_dir
                    / ANNOTATION_DIR_NAME
                    / "stuff_val2017_pixelmaps"
                    / image["file_name"].replace(".jpg", ".png"),
                ),
            }

        ds.commit()

    console.log(f"{ds} has been built")


if __name__ == "__main__":
    download_and_extract(DATA_DIR)
    build(DATA_DIR)
101 changes: 101 additions & 0 deletions example/image-segmentation/datasets/pascal_voc.py
@@ -0,0 +1,101 @@
from __future__ import annotations

import xml.etree.ElementTree as ET
from pathlib import Path

from starwhale import Image, dataset, BoundingBox, init_logger
from starwhale.utils import console

try:
    from .utils import extract, download
except ImportError:
    from utils import extract, download

# set Starwhale Python SDK logger level to 3 (DEBUG)
init_logger(3)

ROOT = Path(__file__).parent
DATA_DIR = ROOT / "data" / "pascal-voc2012"


def download_and_extract(root_dir: Path):
    _tar_fpath = DATA_DIR / "VOCtrainval_11-May-2012.tar"
    # for speedup, fork from http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    download(
        "https://starwhale-examples.oss-cn-beijing.aliyuncs.com/dataset/pascal2012-voc/VOCtrainval_11-May-2012.tar",
        _tar_fpath,
    )
    extract(_tar_fpath, root_dir, DATA_DIR / "VOCdevkit/VOC2012/SegmentationObject")


def build(root_dir: Path):
    """PASCAL VOC 2012 dataset folder structure:

    - Annotations
      - 2011_000234.xml  <-- annotation description
    - SegmentationObject
      - 2011_000234.png  <-- object segmentation png
    - SegmentationClass
      - 2011_000234.png  <-- class segmentation png
    - JPEGImages
      - 2011_000234.jpg  <-- original jpeg
    - ImageSets
      - Action
      - Layout
      - Main
      - Segmentation

    We only use the segmentation related files to build the dataset.
    """

    console.log("start to build dataset")
    data_dir = root_dir / "VOCdevkit/VOC2012"
    with dataset("pascal2012-voc-segmentation") as ds:
        for obj_path in (data_dir / "SegmentationObject").iterdir():
            if not obj_path.is_file() or obj_path.suffix != ".png":
                continue

            name = obj_path.name.split(".")[0]
            annotations = ET.parse(data_dir / "Annotations" / f"{name}.xml").getroot()
            size_node = annotations.find("size")

            objects = []
            for obj in annotations.findall("object"):
                # VOC stores corner coordinates; convert them to the
                # x/y/width/height form expected by BoundingBox
                x_min = int(obj.find("bndbox").find("xmin").text)
                x_max = int(obj.find("bndbox").find("xmax").text)
                y_min = int(obj.find("bndbox").find("ymin").text)
                y_max = int(obj.find("bndbox").find("ymax").text)

                objects.append(
                    {
                        "name": obj.find("name").text,
                        "bbox": BoundingBox(
                            x=x_min, y=y_min, width=x_max - x_min, height=y_max - y_min
                        ),
                        "difficult": obj.find("difficult").text,
                        "pose": obj.find("pose").text,
                    }
                )

            ds[name] = {
                "segmentation_object": Image(obj_path),
                "segmentation_class": Image(
                    data_dir / "SegmentationClass" / f"{name}.png"
                ),
                "original": Image(
                    data_dir / "JPEGImages" / f"{name}.jpg",
                    shape=[
                        int(size_node.find("width").text),
                        int(size_node.find("height").text),
                        int(size_node.find("depth").text),
                    ],
                ),
                "segmented": annotations.find("segmented").text,
                "objects": objects,
            }
        ds.commit()
    console.log(f"{ds} has been built")


if __name__ == "__main__":
    download_and_extract(DATA_DIR)
    build(DATA_DIR)
71 changes: 71 additions & 0 deletions example/image-segmentation/datasets/utils.py
@@ -0,0 +1,71 @@
from __future__ import annotations

import tarfile
import zipfile
from pathlib import Path

import requests
from tqdm import tqdm

from starwhale.utils import console


def _extract_zip(from_path: Path, to_path: Path) -> None:
    with zipfile.ZipFile(from_path, "r", zipfile.ZIP_STORED) as z:
        for file in tqdm(
            iterable=z.namelist(),
            total=len(z.namelist()),
            desc=f"extract {from_path.name}",
        ):
            z.extract(member=file, path=to_path)


def _extract_tar(from_path: Path, to_path: Path) -> None:
    with tarfile.open(from_path, "r") as t:
        for file in tqdm(
            iterable=t.getmembers(),
            total=len(t.getmembers()),
            desc=f"extract {from_path.name}",
        ):
            t.extract(member=file, path=to_path)


def extract(from_path: Path, to_path: Path, chk_path: Path) -> None:
    if not from_path.exists() or from_path.suffix not in (".zip", ".tar"):
        raise ValueError(f"invalid archive file: {from_path}")

    # skip the extraction when the check dir already exists
    if chk_path.exists() and chk_path.is_dir():
        console.log(f"skip extract {from_path}, dir {chk_path} already exists")
        return

    console.log(f"extract {from_path} to {to_path} ...")
    if from_path.suffix == ".zip":
        _extract_zip(from_path, to_path)
    elif from_path.suffix == ".tar":
        _extract_tar(from_path, to_path)
    else:
        raise ValueError(f"invalid archive file: {from_path}")


def download(url: str, to_path: Path) -> None:
    if to_path.exists():
        console.log(f"skip download {url}, file {to_path} already exists")
        return

    to_path.parent.mkdir(parents=True, exist_ok=True)

    with requests.get(url, timeout=60, stream=True) as r:
        r.raise_for_status()
        size = int(r.headers.get("content-length", 0))
        with tqdm(
            total=size,
            unit="B",
            unit_scale=True,
            desc=f"download {url}",
            initial=0,
            unit_divisor=1024,
        ) as pbar:
            with to_path.open("wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    pbar.update(len(chunk))
4 changes: 4 additions & 0 deletions example/image-segmentation/runtime/requirements.txt
@@ -0,0 +1,4 @@
gradio~=3.15.0 # starwhale online-server only supports gradio ~3.15.0
pillow
tqdm
requests
8 changes: 8 additions & 0 deletions example/image-segmentation/runtime/runtime.yaml
@@ -0,0 +1,8 @@
name: image-segmentation
mode: venv
environment:
  cuda: 11.7
  python: 3.9
  starwhale: 0.6.2
dependencies:
  - requirements.txt
