Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
vis_keypoint --> vis_point
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyu2172 committed Feb 27, 2018
1 parent 18f2be3 commit 954eaf5
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 30 deletions.
2 changes: 1 addition & 1 deletion chainercv/visualizations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from chainercv.visualizations.vis_bbox import vis_bbox # NOQA
from chainercv.visualizations.vis_image import vis_image # NOQA
from chainercv.visualizations.vis_keypoint import vis_keypoint # NOQA
from chainercv.visualizations.vis_point import vis_point # NOQA
from chainercv.visualizations.vis_semantic_segmentation import vis_semantic_segmentation # NOQA
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,32 @@
from chainercv.visualizations.vis_image import vis_image


def vis_keypoint(img, keypoint, kp_mask=None, ax=None):
"""Visualize keypoints in an image.
def vis_point(img, point, mask=None, ax=None):
"""Visualize points in an image.
Example:
>>> import chainercv
>>> import matplotlib.pyplot as plot
>>> dataset = chainercv.datasets.CUBKeypointDataset()
>>> img, keypoint, kp_mask = dataset[0]
>>> chainercv.visualizations.vis_keypoint(img, keypoint, kp_mask)
>>> img, point, mask = dataset[0]
>>> chainercv.visualizations.vis_point(img, point, mask)
>>> plot.show()
Args:
img (~numpy.ndarray): An image of shape :math:`(3, height, width)`.
This is in RGB format and the range of its value is
:math:`[0, 255]`. This should be visualizable using
:obj:`matplotlib.pyplot.imshow(img)`
keypoint (~numpy.ndarray): An array with keypoint pairs whose shape is
:math:`(K, 2)`, where :math:`K` is
the number of keypoints in the array.
point (~numpy.ndarray): An array of point coordinates whose shape is
:math:`(P, 2)`, where :math:`P` is
the number of points.
The second axis corresponds to :math:`y` and :math:`x` coordinates
of the keypoint.
kp_mask (~numpy.ndarray, optional): A boolean array whose shape is
:math:`(K,)`. If :math:`i` th index is :obj:`True`, the
:math:`i` th keypoint is not displayed. If not specified,
all keypoints in :obj:`keypoint` will be displayed.
of the points.
mask (~numpy.ndarray, optional): A boolean array whose shape is
:math:`(P,)`. If :math:`i` th index is :obj:`True`, the
:math:`i` th point is not displayed. If not specified,
all points in :obj:`keypoint` will be displayed.
ax (matplotlib.axes.Axes, optional): If provided, plot on this axis.
Returns:
Expand All @@ -42,18 +42,18 @@ def vis_keypoint(img, keypoint, kp_mask=None, ax=None):
ax = vis_image(img, ax=ax)

_, H, W = img.shape
n_kp = len(keypoint)
n_point = len(point)

if kp_mask is None:
kp_mask = np.ones((n_kp,), dtype=np.bool)
if mask is None:
mask = np.ones((n_point,), dtype=np.bool)

cm = plot.get_cmap('gist_rainbow')

colors = [cm(1. * i / n_kp) for i in six.moves.range(n_kp)]
colors = [cm(1. * i / n_point) for i in six.moves.range(n_point)]

for i in range(n_kp):
if kp_mask[i]:
ax.scatter(keypoint[i][1], keypoint[i][0], c=colors[i], s=100)
for i in range(n_point):
if mask[i]:
ax.scatter(point[i][1], point[i][0], c=colors[i], s=100)

ax.set_xlim(left=0, right=W)
ax.set_ylim(bottom=H - 1, top=0)
Expand Down
6 changes: 3 additions & 3 deletions docs/source/reference/visualizations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ vis_image
~~~~~~~~~
.. autofunction:: vis_image

vis_keypoint
~~~~~~~~~~~~
.. autofunction:: vis_keypoint
vis_point
~~~~~~~~~
.. autofunction:: vis_point

vis_semantic_segmentation
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from chainer import testing

from chainercv.visualizations import vis_keypoint
from chainercv.visualizations import vis_point

try:
import matplotlib # NOQA
Expand All @@ -14,17 +14,17 @@


@testing.parameterize(
{'kp_mask': np.array([True, True, False])},
{'kp_mask': None}
{'mask': np.array([True, True, False])},
{'mask': None}
)
class TestVisKeypoint(unittest.TestCase):
class TestVisPoint(unittest.TestCase):

def test_vis_keypoint(self):
def test_vis_point(self):
if optional_modules:
img = np.random.randint(
0, 255, size=(3, 32, 32)).astype(np.float32)
keypoint = np.random.uniform(size=(3, 2)).astype(np.float32)
ax = vis_keypoint(img, keypoint, self.kp_mask)
point = np.random.uniform(size=(3, 2)).astype(np.float32)
ax = vis_point(img, point, self.mask)

self.assertTrue(isinstance(ax, matplotlib.axes.Axes))

Expand Down

0 comments on commit 954eaf5

Please sign in to comment.