From 54baca13ffb26827d4b9531ed7b12ae7ff2091cf Mon Sep 17 00:00:00 2001
From: Peter Boerboom <peter.boerboom@venmo.com>
Date: Wed, 29 Nov 2017 12:28:56 -0500
Subject: [PATCH] Revise error handling approach.

---
 distributed_nose/plugin.py | 85 ++++++++++++++++----------------------
 tests/test_options.py      | 40 ++++++------------
 2 files changed, 48 insertions(+), 77 deletions(-)

diff --git a/distributed_nose/plugin.py b/distributed_nose/plugin.py
index a63d5aa..a1156aa 100644
--- a/distributed_nose/plugin.py
+++ b/distributed_nose/plugin.py
@@ -128,14 +128,9 @@ def configure(self, options, config):
             self.enabled = True
 
         if self.algorithm == self.ALGORITHM_LEAST_PROCESSING_TIME:
-            if not self.lpt_data_filepath:
-                logger.critical((
-                    "'--algorithm least-processing-time' requires "
-                    "'--lpt-data <lpt-data-filepath>' to be specified as "
-                    "well. Falling back to hash-ring algorithm."
-                ))
-                self.algorithm = self.ALGORITHM_HASH_RING
-            else:
+            assert self.lpt_data_filepath, "'--lpt-data' arg is set."
+
+            try:
                 # Set up the data structure for the nodes. Note that
                 # the 0th node is a dummy node. We do this since the nodes
                 # are 1-indexed and this prevents the need to do
@@ -148,51 +143,43 @@ def configure(self, options, config):
                     }
                     for _ in range(self.node_count + 1)
                 ]
-                try:
-                    with open(self.lpt_data_filepath) as f:
-                        self.lpt_data = json.load(f)
 
-                        # for now, lpt only operates at the class level
-                        self.hash_by_class = True
+                with open(self.lpt_data_filepath) as f:
+                    self.lpt_data = json.load(f)
 
-                        sorted_lpt_data = sorted(
-                            self.lpt_data.items(),
-                            key=lambda t: t[1]['duration'],
-                            reverse=True
-                        )
+                    # for now, lpt only operates at the class level
+                    self.hash_by_class = True
 
-                        for cls, data in sorted_lpt_data:
-                            node = min(
-                                self.lpt_nodes[1:],
-                                key=lambda n: n['processing_time']
-                            )
-                            node['processing_time'] += data['duration']
-                            node['classes'].add(cls)
-
-                except IOError:
-                    logger.critical(
-                        (
-                            "lpt-data file '%s' not found. "
-                            "Falling back to hash-ring algorithm."
-                        ),
-                        self.lpt_data_filepath
-                    )
-                    self.algorithm = self.ALGORITHM_HASH_RING
-                except ValueError as e:
-                    logger.critical(
-                        "%s. Falling back to hash-ring algorithm.",
-                        e
+                    sorted_lpt_data = sorted(
+                        self.lpt_data.items(),
+                        key=lambda t: t[1]['duration'],
+                        reverse=True
                     )
-                    self.algorithm = self.ALGORITHM_HASH_RING
-                except KeyError as e:
-                    logger.critical(
-                        (
-                            "%s. Invalid lpt data file. "
-                            "Falling back to hash-ring algorithm."
-                        ),
-                        e
-                    )
-                    self.algorithm = self.ALGORITHM_HASH_RING
+
+                    for cls, data in sorted_lpt_data:
+                        node = min(
+                            self.lpt_nodes[1:],
+                            key=lambda n: n['processing_time']
+                        )
+                        node['processing_time'] += data['duration']
+                        node['classes'].add(cls)
+
+            except IOError:
+                logger.critical(
+                    "lpt-data file '%s' not found. Aborting.",
+                    self.lpt_data_filepath
+                )
+                raise
+            except ValueError:
+                logger.critical(
+                    "Error decoding lpt-data file. Aborting."
+                )
+                raise
+            except KeyError:
+                logger.critical(
+                    "Invalid lpt data file. Aborting."
+                )
+                raise
 
         self.hash_ring = HashRing(range(1, self.node_count + 1))
 
diff --git a/tests/test_options.py b/tests/test_options.py
index 0d54ef3..665bbb1 100644
--- a/tests/test_options.py
+++ b/tests/test_options.py
@@ -113,7 +113,7 @@ def test_lpt_via_flag(self):
         )
         self.assertTrue(self.plugin.enabled)
 
-    def test_lpt_no_data_arg_falls_back(self):
+    def test_lpt_no_data_arg_aborts(self):
         env = {'NOSE_NODES': 6,
                'NOSE_NODE_NUMBER': 4}
         self.plugin.options(self.parser, env=env)
@@ -121,15 +121,11 @@ def test_lpt_no_data_arg_falls_back(self):
             '--algorithm=least-processing-time'
         ]
         options, _ = self.parser.parse_args(args)
-        self.plugin.configure(options, Config())
 
-        self.assertEqual(
-            self.plugin.algorithm,
-            DistributedNose.ALGORITHM_HASH_RING
-        )
-        self.assertTrue(self.plugin.enabled)
+        with self.assertRaises(AssertionError):
+            self.plugin.configure(options, Config())
 
-    def test_lpt_missing_data_file_falls_back(self):
+    def test_lpt_missing_data_file_aborts(self):
         LPT_DATA_FILEPATH = os.path.join(
             os.path.dirname(__file__),
             'no_such_file.json'
@@ -142,15 +138,11 @@ def test_lpt_missing_data_file_falls_back(self):
             '--lpt-data={}'.format(LPT_DATA_FILEPATH)
         ]
         options, _ = self.parser.parse_args(args)
-        self.plugin.configure(options, Config())
 
-        self.assertEqual(
-            self.plugin.algorithm,
-            DistributedNose.ALGORITHM_HASH_RING
-        )
-        self.assertTrue(self.plugin.enabled)
+        with self.assertRaises(IOError):
+            self.plugin.configure(options, Config())
 
-    def test_lpt_invalid_json_file_falls_back(self):
+    def test_lpt_invalid_json_file_aborts(self):
         LPT_DATA_FILEPATH = os.path.join(
             os.path.dirname(__file__),
             'lpt_data',
@@ -164,15 +156,11 @@ def test_lpt_invalid_json_file_falls_back(self):
             '--lpt-data={}'.format(LPT_DATA_FILEPATH)
         ]
         options, _ = self.parser.parse_args(args)
-        self.plugin.configure(options, Config())
 
-        self.assertEqual(
-            self.plugin.algorithm,
-            DistributedNose.ALGORITHM_HASH_RING
-        )
-        self.assertTrue(self.plugin.enabled)
+        with self.assertRaises(ValueError):
+            self.plugin.configure(options, Config())
 
-    def test_lpt_invalid_data_format_falls_back(self):
+    def test_lpt_invalid_data_format_aborts(self):
         LPT_DATA_FILEPATH = os.path.join(
             os.path.dirname(__file__),
             'lpt_data',
@@ -186,10 +174,6 @@ def test_lpt_invalid_data_format_falls_back(self):
             '--lpt-data={}'.format(LPT_DATA_FILEPATH)
         ]
         options, _ = self.parser.parse_args(args)
-        self.plugin.configure(options, Config())
 
-        self.assertEqual(
-            self.plugin.algorithm,
-            DistributedNose.ALGORITHM_HASH_RING
-        )
-        self.assertTrue(self.plugin.enabled)
+        with self.assertRaises(KeyError):
+            self.plugin.configure(options, Config())