diff --git a/halo/halo.py b/halo/halo.py index 640d1af..6e0140c 100644 --- a/halo/halo.py +++ b/halo/halo.py @@ -261,6 +261,8 @@ def _get_spinner(self, spinner): Contains frames and interval defining spinner """ default_spinner = Spinners['dots'].value + windows_spinners = ['balloon', 'balloon2', 'bouncingBar', 'dqpb', 'flip', 'layer', 'line', 'pipe', + 'simpleDots', 'simpleDotsScrolling', 'star2', 'shark', 'toggle13'] if spinner and type(spinner) == dict: return spinner @@ -271,7 +273,10 @@ def _get_spinner(self, spinner): else: return default_spinner else: - return Spinners['line'].value + if all([is_text_type(spinner), spinner in Spinners.__members__, spinner in windows_spinners]): + return Spinners[spinner].value + else: + return Spinners['line'].value def _get_text(self, text): """Creates frames based on the selected animation diff --git a/requirements-dev.txt b/requirements-dev.txt index cac91bf..6bf545f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ coverage==4.4.1 nose==1.3.7 -pylint==1.7.2 +pylint==1.9.4 tox==2.8.2 twine==1.12.1 diff --git a/tests/test_halo.py b/tests/test_halo.py index d357c0d..dde97c1 100644 --- a/tests/test_halo.py +++ b/tests/test_halo.py @@ -551,6 +551,33 @@ def test_redirect_stdout(self): self.assertIn('foo', output[0]) + def test_windows_whitelist(self): + """Test whitelist of Windows-compatible spinners + """ + if not is_supported(): + instance = Halo() + default_spinner_value = "line" + + instance.spinner = default_spinner_value + self.assertEqual(default_spinner, instance.spinner) + + instance.spinner = "balloon" + self.assertEqual(Spinners['balloon'].value, instance.spinner) + + instance.spinner = "monkey" + self.assertEqual(default_spinner, instance.spinner) + + spinner = Halo(text='foo', spinner='balloon2', stream=self._stream) + frames_ = [get_coded_text(frame_) for frame_ in Spinners['balloon2'].value['frames']] + + spinner.start() + time.sleep(1) + spinner.stop() + output = self._get_test_output()['text'] + + for i in range(len(frames_)): + self.assertEqual(output[i], '{0} foo'.format(frames_[i])) + def tearDown(self): pass