From 97f5dc7f2441fb3f4def5b85ac18854d342ae0ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Arroyo=20Men=C3=A9ndez?= Date: Wed, 15 May 2019 12:49:32 +0200 Subject: [PATCH] ai/lda.py: creation --- ai/lda.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 ai/lda.py diff --git a/ai/lda.py b/ai/lda.py new file mode 100644 index 0000000..009dc08 --- /dev/null +++ b/ai/lda.py @@ -0,0 +1,56 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright (C) 2018 David Arroyo Menéndez + +# Author: David Arroyo Menéndez +# Maintainer: David Arroyo Menéndez + +# This file is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 3, or (at your option) +# any later version. + +# This file is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. + +# You should have received a copy of the GNU General Public License +# along with GNU Emacs; see the file COPYING. If not, write to +# the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, +# Boston, MA 02110-1301 USA, + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib import style +style.use('fivethirtyeight') +np.random.seed(seed=42) +# Create data +rectangles = np.array([[1,1.5,1.7,1.45,1.1,1.6,1.8],[1.8,1.55,1.45,1.6,1.65,1.7,1.75]]) +triangles = np.array([[0.1,0.5,0.25,0.4,0.3,0.6,0.35,0.15,0.4,0.5,0.48],[1.1,1.5,1.3,1.2,1.15,1.0,1.4,1.2,1.3,1.5,1.0]]) +circles = np.array([[1.5,1.55,1.52,1.4,1.3,1.6,1.35,1.45,1.4,1.5,1.48,1.51,1.52,1.49,1.41,1.39,1.6,1.35,1.55,1.47,1.57,1.48, + 1.55,1.555,1.525,1.45,1.35,1.65,1.355,1.455,1.45,1.55,1.485,1.515,1.525,1.495,1.415,1.395,1.65,1.355,1.555,1.475,1.575,1.485] + ,[1.3,1.35,1.33,1.32,1.315,1.30,1.34,1.32,1.33,1.35,1.30,1.31,1.35,1.33,1.32,1.315,1.38,1.34,1.28,1.23,1.25,1.29, + 1.35,1.355,1.335,1.325,1.3155,1.305,1.345,1.325,1.335,1.355,1.305,1.315,1.355,1.335,1.325,1.3155,1.385,1.345,1.285,1.235,1.255,1.295]]) +#Plot the data +fig = plt.figure(figsize=(10,10)) +ax0 = fig.add_subplot(111) +ax0.scatter(rectangles[0],rectangles[1],marker='s',c='grey',edgecolor='black') +ax0.scatter(triangles[0],triangles[1],marker='^',c='yellow',edgecolor='black') +ax0.scatter(circles[0],circles[1],marker='o',c='blue',edgecolor='black') +# Calculate the mean vectors per class +mean_rectangles = np.mean(rectangles,axis=1).reshape(2,1) # Creates a 2x1 vector consisting of the means of the dimensions +mean_triangles = np.mean(triangles,axis=1).reshape(2,1) +mean_circles = np.mean(circles,axis=1).reshape(2,1) +# Calculate the scatter matrices for the SW (Scatter within) and sum the elements up +scatter_rectangles = np.dot((rectangles-mean_rectangles),(rectangles-mean_rectangles).T) +# Mind that we do not calculate the covariance matrix here because then we have to divide by n or n-1 as shown below +#print((1/7)*np.dot((rectangles-mean_rectangles),(rectangles-mean_rectangles).T)) +#print(np.var(rectangles[0],ddof=0)) +scatter_triangles = np.dot((triangles-mean_triangles),(triangles-mean_triangles).T) +scatter_circles = np.dot((circles-mean_circles),(circles-mean_circles).T) +# Calculate the SW by adding the scatters within classes +SW = scatter_triangles+scatter_circles+scatter_rectangles +print(SW) +plt.show()