Skip to content

Commit

Permalink
digit recognizer working
Browse files Browse the repository at this point in the history
  • Loading branch information
davecom committed May 2, 2024
1 parent 10dcfab commit 7d55e39
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 11 deletions.
58 changes: 53 additions & 5 deletions KNN/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from KNN.knn import KNN
from KNN.digit import Digit
import os
import sys
import pygame
import numpy as np

# Size of the drawing grid; used both for the Pygame window's logical size
# and for the dimensions of the pixel buffer.
PIXEL_WIDTH = 8
PIXEL_HEIGHT = 8
# Drawn pixels use 0-255 channel values (see WHITE). They are multiplied by
# P_TO_D before being handed to the KNN, and predictions are multiplied by
# D_TO_P before being displayed.
# NOTE(review): presumably the dataset stores intensities in 0-16 — confirm.
P_TO_D = 16 / 255 # pixel to digit scale factor
D_TO_P = 255 / 16 # digit to pixel scale factor
# Number of nearest neighbors passed to classify() and predict().
K = 9
# RGB value written into the buffer when a pixel is drawn with the mouse.
WHITE = (255, 255, 255)


def _drawing_to_digit_pixels(digit_pixels):
    # Turn the (x, y, rgb) drawing buffer into the flat, row-by-row vector of
    # scaled intensities that gets wrapped in a Digit for the KNN. Only the
    # first colour channel is read because every channel holds the same value.
    return digit_pixels.transpose((1, 0, 2))[:, :, 0].flatten() * P_TO_D


def run():
    """Open the digit-drawing window and run the event loop until it is closed.

    Controls:
        mouse (click or drag with the left button) -- draw white pixels
        c -- classify the drawing and print the result
        e -- erase the drawing
        p -- replace the drawing with the average of its K nearest neighbors
    """
    # Pygame surface arrays are indexed [x][y], so the first axis is the width.
    digit_pixels = np.zeros((PIXEL_WIDTH, PIXEL_HEIGHT, 3), dtype=np.uint32)
    # Load the training data (the path is relative to this file's directory)
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    digits_knn = KNN(Digit, './datasets/digits/digits.csv', has_header=False)
    # Startup Pygame, create the window
    pygame.init()
    screen = pygame.display.set_mode(
        size=(PIXEL_WIDTH, PIXEL_HEIGHT),
        flags=pygame.SCALED | pygame.RESIZABLE,
    )
    pygame.display.set_caption("Digit Recognizer")
    while True:
        pygame.surfarray.blit_array(screen, digit_pixels)
        pygame.display.flip()

        for event in pygame.event.get():
            # Handle keyboard events
            if event.type == pygame.KEYDOWN:
                key_name = pygame.key.name(event.key)
                if key_name == "c":  # classify the digit
                    pixels = _drawing_to_digit_pixels(digit_pixels)
                    classified_digit = digits_knn.classify(K, Digit("", pixels))
                    print(f"Classified as {classified_digit}")
                elif key_name == "e":  # erase the digit
                    digit_pixels.fill(0)
                elif key_name == "p":  # predict what the digit should look like
                    pixels = _drawing_to_digit_pixels(digit_pixels)
                    predicted_pixels = digits_knn.predict(K, Digit("", pixels), "pixels")
                    # Back to [x][y] layout and 0-255 range for display.
                    predicted_pixels = (
                        predicted_pixels.reshape((PIXEL_HEIGHT, PIXEL_WIDTH)).transpose((1, 0)) * D_TO_P
                    )
                    # Same value in all three channels gives a greyscale image.
                    digit_pixels = np.stack(
                        (predicted_pixels, predicted_pixels, predicted_pixels), axis=2
                    )
            # Handle mouse events
            elif ((event.type == pygame.MOUSEBUTTONDOWN) or
                  (event.type == pygame.MOUSEMOTION and pygame.mouse.get_pressed()[0])):
                x, y = event.pos
                # Check both bounds: a negative index would silently wrap
                # around and paint a pixel on the opposite edge.
                if 0 <= x < PIXEL_WIDTH and 0 <= y < PIXEL_HEIGHT:
                    digit_pixels[x, y] = WHITE
            elif event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()


if __name__ == "__main__":
    # The digit recognizer takes no command-line arguments. The previous
    # NanoBASIC argument parsing referenced ArgumentParser and execute, which
    # are not imported in this module, so it is dropped and the UI starts directly.
    run()
10 changes: 6 additions & 4 deletions KNN/digit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from math import dist
from typing import Self
import numpy as np


@dataclass
class Digit:
kind: str
pixels: list[int]
pixels: np.ndarray

@classmethod
def from_string_data(cls, data: list[str]) -> Self:
return cls(kind=data[64], pixels=[int(x) for x in data[:64]])
return cls(kind=data[64],
pixels=np.array(data[:64], dtype=np.uint32))

def distance(self, other: Self) -> float:
return dist(self.pixels, other.pixels)
tmp = self.pixels - other.pixels
return np.sqrt(np.dot(tmp.T, tmp))
5 changes: 3 additions & 2 deletions KNN/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import csv
from typing import Protocol, Self
from pathlib import Path
from collections.abc import Iterable


class DataPoint(Protocol):
Expand Down Expand Up @@ -56,10 +57,10 @@ def classify(self, k: int, data_point: DP) -> str:
kinds[neighbor.kind] += 1
else:
kinds[neighbor.kind] = 1
return max(kinds, key=kinds.get)
return max(kinds, key=kinds.get) # type: ignore

# Predict a property of a data point based on the k nearest neighbors
# Find the average of that property from the neighbors and return it
# The property may be a plain number or anything that supports + and /
# (e.g. a NumPy array, which is averaged element by element).
# Raises ValueError if there are no neighbors to average over.
def predict(self, k: int, data_point: DP, property_name: str) -> float | Iterable:
    neighbors = self.nearest(k, data_point)
    count = len(neighbors)
    if count == 0:
        # Without this check the division below fails with a bare ZeroDivisionError
        raise ValueError("cannot predict without at least one neighbor")
    return sum(getattr(neighbor, property_name) for neighbor in neighbors) / count

0 comments on commit 7d55e39

Please sign in to comment.