diff --git a/KNN/__main__.py b/KNN/__main__.py index 732c6c5..bbe8335 100644 --- a/KNN/__main__.py +++ b/KNN/__main__.py @@ -15,7 +15,7 @@ # limitations under the License. from KNN.knn import KNN from KNN.digit import Digit -import os +from pathlib import Path import sys import pygame import numpy as np @@ -33,9 +33,9 @@ def run(): digit_pixels = np.zeros((PIXEL_HEIGHT, PIXEL_WIDTH, 3), dtype=np.uint32) # Load the training data - os.chdir(os.path.dirname(os.path.abspath(__file__))) - digits_knn = KNN(Digit, './datasets/digits/digits.csv', - has_header=False) + digits_file = (Path(__file__).resolve().parent + / "datasets" / "digits" / "digits.csv") + digits_knn = KNN(Digit, digits_file, has_header=False) # Startup Pygame, create the window pygame.init() screen = pygame.display.set_mode(size=(PIXEL_WIDTH, PIXEL_HEIGHT), diff --git a/tests/test_knn.py b/tests/test_knn.py index 15be71d..a83e63f 100644 --- a/tests/test_knn.py +++ b/tests/test_knn.py @@ -31,7 +31,7 @@ def setUp(self) -> None: def test_nearest(self): k: int = 3 - fish_knn = KNN(Fish, str(self.data_file)) + fish_knn = KNN(Fish, self.data_file) test_fish: Fish = Fish("", 0.0, 30.0, 32.5, 38.0, 12.0, 5.0) nearest_fish: list[Fish] = fish_knn.nearest(k, test_fish) self.assertEqual(len(nearest_fish), k) @@ -42,7 +42,7 @@ def test_nearest(self): def test_classify(self): k: int = 5 - fish_knn = KNN(Fish, str(self.data_file)) + fish_knn = KNN(Fish, self.data_file) test_fish: Fish = Fish("", 0.0, 20.0, 23.5, 24.0, 10.0, 4.0) classify_fish: str = fish_knn.classify(k, test_fish) self.assertEqual(classify_fish, "Parkki")