DecisionTree Regression
zafercavdar committed Nov 26, 2017
1 parent 199a219 commit 4f77068
Showing 2 changed files with 286 additions and 0 deletions.
152 changes: 152 additions & 0 deletions 5-Decision Tree Regression/DTRegression.R
@@ -0,0 +1,152 @@
# Zafer Cavdar - COMP 421 Homework 5 - Decision Tree Regression
# Reference: Introduction to Machine Learning by Ethem Alpaydin (textbook)

# QUESTION 1 - Read the data and create the training and test sets.
data_set <- read.csv("hw05_data_set.csv")
x_all <- data_set$x
y_all <- data_set$y

set.seed(521)
train_indices <- sample(length(x_all), 100)
x_train <- x_all[train_indices]
y_train <- y_all[train_indices]
x_test <- x_all[-train_indices]
y_test <- y_all[-train_indices]
minimum_value <- floor(min(x_all)) - 2
maximum_value <- ceiling(max(x_all)) + 2
N_train <- length(x_train)
N_test <- length(x_test)

# QUESTION 2 - Implement the decision tree regression algorithm with pre-pruning parameter P.
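#
# The tree is stored in parallel vectors indexed by node id: the children of
# node m sit at positions 2 * m and 2 * m + 1. A node becomes a leaf when it
# holds at most P training points (or a single unique x value), and each leaf
# predicts the mean y of the points it contains.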

DecisionTreeRegression <- function(P) {
  # reset variables
  node_splits <- c()
  node_means <- c()

  # put all training instances into the root node
  node_indices <- list(1:N_train)
  is_terminal <- c(FALSE)
  need_split <- c(TRUE)
  # learning algorithm
  while (TRUE) {
    # find nodes that still need splitting
    split_nodes <- which(need_split)
    # stop once every node has become terminal
    if (length(split_nodes) == 0) {
      break
    }
    # find the best split position for each node
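    # A candidate split at position s sends points with x <= s to the left
    # branch and x > s to the right; its score is the total within-branch
    # squared error divided by the node size:
    #   score(s) = ( sum_left (y - mean_left)^2 + sum_right (y - mean_right)^2 ) / N_node
    # The candidate with the minimum score is selected.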
    for (split_node in split_nodes) {
      data_indices <- node_indices[[split_node]]
      need_split[split_node] <- FALSE
      node_mean <- mean(y_train[data_indices])
      unique_values <- sort(unique(x_train[data_indices]))
      # pre-pruning: stop splitting when the node holds at most P points
      # (or only one unique x value, so no split is possible)
      if (length(data_indices) <= P || length(unique_values) == 1) {
        is_terminal[split_node] <- TRUE
        node_means[split_node] <- node_mean
      } else {
        is_terminal[split_node] <- FALSE
        split_positions <- (unique_values[-1] + unique_values[-length(unique_values)]) / 2
        split_scores <- rep(0, length(split_positions))
        for (s in 1:length(split_positions)) {
          left_indices <- data_indices[which(x_train[data_indices] <= split_positions[s])]
          right_indices <- data_indices[which(x_train[data_indices] > split_positions[s])]
          total_error <- 0
          if (length(left_indices) > 0) {
            left_mean <- mean(y_train[left_indices])
            total_error <- total_error + sum((y_train[left_indices] - left_mean) ^ 2)
          }
          if (length(right_indices) > 0) {
            right_mean <- mean(y_train[right_indices])
            total_error <- total_error + sum((y_train[right_indices] - right_mean) ^ 2)
          }
          split_scores[s] <- total_error / (length(left_indices) + length(right_indices))
        }
        best_split <- split_positions[which.min(split_scores)]
        node_splits[split_node] <- best_split

        # create the left child using the selected split
        # (x <= split goes left, matching the scoring loop and get_prediction)
        left_indices <- data_indices[which(x_train[data_indices] <= best_split)]
        node_indices[[2 * split_node]] <- left_indices
        is_terminal[2 * split_node] <- FALSE
        need_split[2 * split_node] <- TRUE

        # create the right child using the selected split
        right_indices <- data_indices[which(x_train[data_indices] > best_split)]
        node_indices[[2 * split_node + 1]] <- right_indices
        is_terminal[2 * split_node + 1] <- FALSE
        need_split[2 * split_node + 1] <- TRUE
      }
    }
  }
  result <- list("splits" = node_splits, "means" = node_means, "is_terminal" = is_terminal)
  return(result)
}
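
# Quick sanity check (illustrative, not part of the assignment): with P set to
# N_train the root is never split, so the tree has a single leaf and
# DecisionTreeRegression(N_train)$means[1] should equal mean(y_train).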

# QUESTION 3 - Learn a DT with P = 10 and plot
P <- 10
result <- DecisionTreeRegression(P)
node_splits <- result$splits
node_means <- result$means
is_terminal <- result$is_terminal

# define regression function
get_prediction <- function(dp, is_terminal, node_splits, node_means){
  index <- 1
  while (TRUE) {
    if (is_terminal[index]) {
      return(node_means[index])
    } else {
      if (dp <= node_splits[index]) {
        index <- index * 2
      } else {
        index <- index * 2 + 1
      }
    }
  }
}
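
# Convenience wrapper (an assumed helper, not part of the original script):
# vectorizes get_prediction over a numeric vector of inputs.
predict_tree <- function(x_vector, is_terminal, node_splits, node_means) {
  sapply(x_vector, function(dp) get_prediction(dp, is_terminal, node_splits, node_means))
}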

# plot training data, test data, and the fitted step function in one figure
# (ylim uses the full data range so that test points are not clipped)
plot(x_train, y_train, type = "p", pch = 19, col = "blue",
     ylim = range(y_all), xlim = c(minimum_value, maximum_value),
     ylab = "y", xlab = "x", las = 1)
points(x_test, y_test, type = "p", pch = 19, col = "red")
legend(55, 85, legend = c("training", "test"),
       col = c("blue", "red"), pch = 19, cex = 0.5, bty = "o")
grid_interval <- 0.01
data_interval <- seq(from = minimum_value, to = maximum_value, by = grid_interval)
# draw the fit as a step function over a fine grid; loop stops one short of the
# last grid point so data_interval[b + 1] is always defined
for (b in 1:(length(data_interval) - 1)) {
  x_left <- data_interval[b]
  x_right <- data_interval[b + 1]
  # horizontal segment at the prediction for this grid cell
  lines(c(x_left, x_right),
        rep(get_prediction(x_left, is_terminal, node_splits, node_means), 2),
        lwd = 2, col = "black")
  # vertical segment where the prediction jumps between cells
  lines(c(x_right, x_right),
        c(get_prediction(x_left, is_terminal, node_splits, node_means),
          get_prediction(x_right, is_terminal, node_splits, node_means)),
        lwd = 2, col = "black")
}

# QUESTION 4 - Calculate RMSE for the test data points
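# RMSE = sqrt( sum_i (y_test[i] - y_hat[i])^2 / N_test )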
y_test_predicted <- sapply(X = 1:N_test, FUN = function(i) get_prediction(x_test[i], is_terminal, node_splits, node_means))
RMSE <- sqrt(sum((y_test - y_test_predicted) ^ 2) / length(y_test))
sprintf("RMSE is %s when P is %s", RMSE, P)

# QUESTION 5 - P vs RMSE
RMSEs <- sapply(X = 1:20, FUN = function(p) {
  # sprintf alone does not print inside a function, so use message for progress
  message(sprintf("Calculating RMSE for P = %d", p))
  result <- DecisionTreeRegression(p)
  node_splits <- result$splits
  node_means <- result$means
  is_terminal <- result$is_terminal
  y_test_predicted <- sapply(X = 1:N_test, FUN = function(i) get_prediction(x_test[i], is_terminal, node_splits, node_means))
  sqrt(sum((y_test - y_test_predicted) ^ 2) / length(y_test))
})

plot(1:20, RMSEs,
type = "o", lwd = 1, las = 1, pch = 1, lty = 2,
xlab = "P", ylab = "RMSE")
134 changes: 134 additions & 0 deletions 5-Decision Tree Regression/hw05_data_set.csv
@@ -0,0 +1,134 @@
"x","y"
26.4,-65.6
6.2,-2.7
20.2,-123.1
46.6,10.7
33.4,16
16.8,-77.7
10,-2.7
15.8,-21.5
13.6,-2.7
55,-2.7
16.2,-50.8
8.2,-2.7
26.2,-107.1
47.8,-14.7
11,-5.4
15.4,-32.1
16.8,-91.1
32.8,46.9
35.2,-54.9
43,14.7
16.4,-5.4
16,-26.8
19.2,-123.1
45,10.7
16.2,-21.5
38,10.7
25.4,-72.3
7.8,-2.7
28.4,-21.5
9.6,-2.7
14.6,-13.3
23.2,-123.1
8.8,-2.7
29.4,-17.4
39.4,-1.3
3.2,-2.7
15.6,-40.2
57.6,10.7
38,46.9
36.2,-37.5
23.4,-128.5
35.6,32.1
25.4,-44.3
24.2,-81.8
15.6,-21.5
33.8,45.6
21.2,-134
19.4,-72.3
13.8,0
17.8,-99.1
25,-64.4
17.6,-37.5
22,-123.1
4,-2.7
27.2,-24.2
31.2,8.1
17.6,-123.1
42.8,-10.7
17.6,-85.6
2.4,0
55.4,-2.7
17.6,-101.9
14.8,-2.7
13.2,-2.7
19.6,-127.2
52,10.7
40,-21.5
35.2,-16
47.8,-26.8
24.2,-95.1
35.4,69.6
32,54.9
55,10.7
30.2,36.2
16.6,-59
26,-5.4
15.4,-53.5
53.2,-14.7
15.8,-50.8
15.4,-54.9
10.6,-2.7
6.8,-1.3
2.6,-1.3
20.4,-117.9
6.6,-2.7
28.4,37.5
27,-16
10.2,-5.4
14.6,-5.4
31,75
11.4,0
14.6,-16
48.8,-13.3
18.6,-112.5
27.6,4
39.2,5.4
34.4,1.3
41.6,-10.7
44,-1.3
27.2,9.5
26.2,-21.5
19.4,-85.6
32,48.2
42.4,29.4
24.6,-53.5
36.2,22.8
3.6,0
21.4,-101.9
21.8,-108.4
14.6,-22.8
16.2,-61.7
16,-42.9
40.4,-13.3
14.6,-9.3
41.6,30.8
16.4,-80.4
44.4,0
25,-57.6
17.8,-104.4
25.6,-26.8
14.6,-5.4
8.8,-1.3
34.8,75
35.6,34.8
28.6,46.9
50.6,0
27.2,-45.6
18.6,-50.8
16.8,-71
24,-112.5
42.8,0
28.2,12
15.4,-22.8
