Skip to content

Commit

Permalink
feat(sdk): support Image type to accept numpy and pillow image types (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut authored Dec 6, 2023
1 parent 24d55d8 commit f5cce0c
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 17 deletions.
2 changes: 1 addition & 1 deletion client/starwhale/api/_impl/dataset/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1774,7 +1774,7 @@ def _iter_records() -> t.Iterator[t.Tuple[str, t.Dict]]:
record["caption"] = caption_path.read_text().strip()

record["file"] = file_cls(
fp=p,
p,
display_name=p.name,
mime_type=MIMEType.create_by_file_suffix(p),
)
Expand Down
27 changes: 24 additions & 3 deletions client/starwhale/base/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy

from starwhale.utils import console
from starwhale.consts import SHORT_VERSION_CNT
from starwhale.utils.fs import DIGEST_SIZE, FilePosition
from starwhale.base.mixin import ASDictMixin
Expand Down Expand Up @@ -260,7 +261,7 @@ def to_tensor(self) -> t.Any:
class Image(BaseArtifact, SwObject):
def __init__(
self,
fp: _TArtifactFP = "",
fp: t.Any = "",
display_name: str = "",
shape: t.Optional[_TShape] = None,
mime_type: MIMEType = MIMEType.UNDEFINED,
Expand All @@ -271,8 +272,9 @@ def __init__(
) -> None:
self.as_mask = as_mask
self.mask_uri = mask_uri

super().__init__(
fp,
self._convert_pil_and_numpy(fp),
ArtifactType.Image,
display_name=display_name,
shape=shape or (None, None, 3),
Expand All @@ -281,6 +283,25 @@ def __init__(
link=link,
)

def _convert_pil_and_numpy(self, source: t.Any) -> t.Any:
try:
# pillow is optional for starwhale, so we need to check if it is installed
from PIL import Image as PILImage
except ImportError: # pragma: no cover
console.trace(
"pillow is not installed, skip try to convert PILImage and numpy.ndarray to bytes"
)
return source

if isinstance(source, (PILImage.Image, numpy.ndarray)):
image_bytes = io.BytesIO()
if isinstance(source, numpy.ndarray):
source = PILImage.fromarray(source)
source.save(image_bytes, format="PNG")
return image_bytes.getvalue()
else:
return source

def _do_validate(self) -> None:
if self.mime_type not in (
MIMEType.PNG,
Expand Down Expand Up @@ -338,7 +359,7 @@ def to_tensor(self) -> t.Any:
class GrayscaleImage(Image):
def __init__(
self,
fp: _TArtifactFP = "",
fp: t.Any = "",
display_name: str = "",
shape: t.Optional[_TShape] = None,
as_mask: bool = False,
Expand Down
6 changes: 1 addition & 5 deletions client/starwhale/integrations/huggingface/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,13 @@ def _transform_to_starwhale(data: t.Any, feature: t.Any) -> t.Any:
from PIL import Image as PILImage

if isinstance(data, PILImage.Image):
img_io = io.BytesIO()
data.save(img_io, format=data.format or "PNG")
img_fp = img_io.getvalue()

try:
data_mimetype = data.get_format_mimetype()
mime_type = MIMEType(data_mimetype)
except (ValueError, AttributeError):
mime_type = MIMEType.PNG
return Image(
fp=img_fp,
data,
shape=(data.height, data.width, len(data.getbands())),
mime_type=mime_type,
)
Expand Down
17 changes: 13 additions & 4 deletions client/tests/base/test_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def test_numpy_binary(self) -> None:
assert torch.equal(torch.from_numpy(np_array), b.to_tensor())

def test_image(self) -> None:
self.fs.create_file("path/to/file", contents="")

fp = io.StringIO("test")
img = Image(fp, display_name="t", shape=[28, 28, 3], mime_type=MIMEType.PNG)
assert img.to_bytes() == b"test"
Expand Down Expand Up @@ -159,7 +161,6 @@ def test_image(self) -> None:
assert _asdict["shape"] == [28, 28, 1]
assert _asdict["_raw_base64_data"] == base64.b64encode(b"test").decode()

self.fs.create_file("path/to/file", contents="")
img = GrayscaleImage(Path("path/to/file"), shape=[28, 28, 1]).carry_raw_data()
typ = data_store._get_type(img)
assert isinstance(typ, data_store.SwObjectType)
Expand All @@ -168,9 +169,9 @@ def test_image(self) -> None:
pixels = numpy.random.randint(
low=0, high=256, size=(100, 100, 3), dtype=numpy.uint8
)
image_bytes = io.BytesIO()
PILImage.fromarray(pixels, mode="RGB").save(image_bytes, format="PNG")
img = Image(image_bytes.getvalue())
pil_obj = PILImage.fromarray(pixels, mode="RGB")

img = Image(pil_obj)
pil_img = img.to_pil()
assert isinstance(pil_img, PILImage.Image)
assert pil_img.mode == "RGB"
Expand All @@ -182,6 +183,14 @@ def test_image(self) -> None:
l_array = img.to_numpy("L")
assert l_array.shape == (100, 100)

img = Image(pixels)
pil_img = img.to_pil()
assert isinstance(pil_img, PILImage.Image)
assert pil_img.mode == "RGB"
array = img.to_numpy()
assert isinstance(array, numpy.ndarray)
assert (array == pixels).all()

def test_swobject_subclass_init(self) -> None:
from starwhale.base import data_type

Expand Down
4 changes: 2 additions & 2 deletions client/tests/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def test_head(self, *args: t.Any) -> None:
(
"label-0",
{
"img": GrayscaleImage(fp=b"123"),
"img": GrayscaleImage(b"123"),
"label": 0,
},
)
Expand All @@ -486,7 +486,7 @@ def test_head(self, *args: t.Any) -> None:
(
"label-1",
{
"img": GrayscaleImage(fp=b"456"),
"img": GrayscaleImage(b"456"),
"label": 1,
},
)
Expand Down
2 changes: 1 addition & 1 deletion client/tests/sdk/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def test_data_row(self) -> None:
assert dr < dr_another
assert dr != dr_another

dr_third = DataRow(index=1, features={"data": Image(fp=b""), "label": 10})
dr_third = DataRow(index=1, features={"data": Image(b""), "label": 10})
assert dr >= dr_third

def test_data_row_exceptions(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion scripts/example/src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from starwhale import Image


def random_image() -> bytes:
def random_image() -> Image:
try:
return _random_image_from_pillow()
except ImportError:
Expand Down

0 comments on commit f5cce0c

Please sign in to comment.