diff --git a/chainercv/visualizations/__init__.py b/chainercv/visualizations/__init__.py index 1572a7ff8f..9449ed8a8e 100644 --- a/chainercv/visualizations/__init__.py +++ b/chainercv/visualizations/__init__.py @@ -1,4 +1,4 @@ from chainercv.visualizations.vis_bbox import vis_bbox # NOQA from chainercv.visualizations.vis_image import vis_image # NOQA -from chainercv.visualizations.vis_keypoint import vis_keypoint # NOQA +from chainercv.visualizations.vis_point import vis_point # NOQA from chainercv.visualizations.vis_semantic_segmentation import vis_semantic_segmentation # NOQA diff --git a/chainercv/visualizations/vis_keypoint.py b/chainercv/visualizations/vis_keypoint.py deleted file mode 100644 index 6974593652..0000000000 --- a/chainercv/visualizations/vis_keypoint.py +++ /dev/null @@ -1,60 +0,0 @@ -import numpy as np -import six - -from chainercv.visualizations.vis_image import vis_image - - -def vis_keypoint(img, keypoint, kp_mask=None, ax=None): - """Visualize keypoints in an image. - - Example: - - >>> import chainercv - >>> import matplotlib.pyplot as plot - >>> dataset = chainercv.datasets.CUBKeypointDataset() - >>> img, keypoint, kp_mask = dataset[0] - >>> chainercv.visualizations.vis_keypoint(img, keypoint, kp_mask) - >>> plot.show() - - Args: - img (~numpy.ndarray): An image of shape :math:`(3, height, width)`. - This is in RGB format and the range of its value is - :math:`[0, 255]`. This should be visualizable using - :obj:`matplotlib.pyplot.imshow(img)` - keypoint (~numpy.ndarray): An array with keypoint pairs whose shape is - :math:`(K, 2)`, where :math:`K` is - the number of keypoints in the array. - The second axis corresponds to :math:`y` and :math:`x` coordinates - of the keypoint. - kp_mask (~numpy.ndarray, optional): A boolean array whose shape is - :math:`(K,)`. If :math:`i` th index is :obj:`True`, the - :math:`i` th keypoint is not displayed. If not specified, - all keypoints in :obj:`keypoint` will be displayed. - ax (matplotlib.axes.Axes, optional): If provided, plot on this axis. - - Returns: - ~matploblib.axes.Axes: - Returns the Axes object with the plot for further tweaking. - - """ - import matplotlib.pyplot as plot - # Returns newly instantiated matplotlib.axes.Axes object if ax is None - ax = vis_image(img, ax=ax) - - _, H, W = img.shape - n_kp = len(keypoint) - - if kp_mask is None: - kp_mask = np.ones((n_kp,), dtype=np.bool) - - cm = plot.get_cmap('gist_rainbow') - - colors = [cm(1. * i / n_kp) for i in six.moves.range(n_kp)] - - for i in range(n_kp): - if kp_mask[i]: - ax.scatter(keypoint[i][1], keypoint[i][0], c=colors[i], s=100) - - ax.set_xlim(left=0, right=W) - ax.set_ylim(bottom=H - 1, top=0) - return ax diff --git a/chainercv/visualizations/vis_point.py b/chainercv/visualizations/vis_point.py new file mode 100644 index 0000000000..3f8c2a73c3 --- /dev/null +++ b/chainercv/visualizations/vis_point.py @@ -0,0 +1,62 @@ +from __future__ import division + +import numpy as np +import six + +from chainercv.visualizations.vis_image import vis_image + + +def vis_point(img, point, mask=None, ax=None): + """Visualize points in an image. + + Example: + + >>> import chainercv + >>> import matplotlib.pyplot as plot + >>> dataset = chainercv.datasets.CUBKeypointDataset() + >>> img, point, mask = dataset[0] + >>> chainercv.visualizations.vis_point(img, point, mask) + >>> plot.show() + + Args: + img (~numpy.ndarray): An image of shape :math:`(3, height, width)`. + This is in RGB format and the range of its value is + :math:`[0, 255]`. This should be visualizable using + :obj:`matplotlib.pyplot.imshow(img)` + point (~numpy.ndarray): An array of point coordinates whose shape is + :math:`(P, 2)`, where :math:`P` is + the number of points. + The second axis corresponds to :math:`y` and :math:`x` coordinates + of the points. + mask (~numpy.ndarray): A boolean array whose shape is + :math:`(P,)`. If :math:`i` th element is :obj:`True`, the + :math:`i` th point is not displayed. If not specified, + all points in :obj:`point` will be displayed. + ax (matplotlib.axes.Axes): If provided, plot on this axis. + + Returns: + ~matploblib.axes.Axes: + Returns the Axes object with the plot for further tweaking. + + """ + import matplotlib.pyplot as plot + # Returns newly instantiated matplotlib.axes.Axes object if ax is None + ax = vis_image(img, ax=ax) + + _, H, W = img.shape + n_point = len(point) + + if mask is None: + mask = np.ones((n_point,), dtype=np.bool) + + cm = plot.get_cmap('gist_rainbow') + + colors = [cm(i / n_point) for i in six.moves.range(n_point)] + + for i in range(n_point): + if mask[i]: + ax.scatter(point[i][1], point[i][0], c=colors[i], s=100) + + ax.set_xlim(left=0, right=W) + ax.set_ylim(bottom=H - 1, top=0) + return ax diff --git a/docs/source/reference/visualizations.rst b/docs/source/reference/visualizations.rst index a673a8d88f..c00682e6be 100644 --- a/docs/source/reference/visualizations.rst +++ b/docs/source/reference/visualizations.rst @@ -12,9 +12,9 @@ vis_image ~~~~~~~~~ .. autofunction:: vis_image -vis_keypoint -~~~~~~~~~~~~ -.. autofunction:: vis_keypoint +vis_point +~~~~~~~~~ +.. autofunction:: vis_point vis_semantic_segmentation ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/visualizations_tests/test_vis_keypoint.py b/tests/visualizations_tests/test_vis_point.py similarity index 57% rename from tests/visualizations_tests/test_vis_keypoint.py rename to tests/visualizations_tests/test_vis_point.py index 2cab9580b5..ffdddff089 100644 --- a/tests/visualizations_tests/test_vis_keypoint.py +++ b/tests/visualizations_tests/test_vis_point.py @@ -4,7 +4,7 @@ from chainer import testing -from chainercv.visualizations import vis_keypoint +from chainercv.visualizations import vis_point try: import matplotlib # NOQA @@ -14,17 +14,17 @@ @testing.parameterize( - {'kp_mask': np.array([True, True, False])}, - {'kp_mask': None} + {'mask': np.array([True, True, False])}, + {'mask': None} ) -class TestVisKeypoint(unittest.TestCase): +class TestVisPoint(unittest.TestCase): - def test_vis_keypoint(self): + def test_vis_point(self): if optional_modules: img = np.random.randint( 0, 255, size=(3, 32, 32)).astype(np.float32) - keypoint = np.random.uniform(size=(3, 2)).astype(np.float32) - ax = vis_keypoint(img, keypoint, self.kp_mask) + point = np.random.uniform(size=(3, 2)).astype(np.float32) + ax = vis_point(img, point, self.mask) self.assertTrue(isinstance(ax, matplotlib.axes.Axes))