Skip to content

Commit

Permalink
Check in of working BTW with loss and regularizers
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Hönel committed Jan 19, 2021
1 parent ae52752 commit 74984c2
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 47 deletions.
2 changes: 1 addition & 1 deletion models/BTW.R
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ L_final_log <- function(
lb <- min(theta_b_org)
ub <- max(theta_b_org)

loss <- loss + weightR3 * log(1 + sum(R(lb - vts) + R(vts - lb) + R(lb - vte) + R(vte - ub)))
loss <- loss + weightR4 * log(1 + sum(R(lb - vts) + R(vts - lb) + R(lb - vte) + R(vte - ub)))
}

loss
Expand Down
247 changes: 208 additions & 39 deletions models/BTW_updated.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

M_updated <- function(theta_b_org, theta_b, r, f, num_samples = 1e3, zNormalize = FALSE, valueForZeroLenIntervals = NA_real_) {
M_updated <- function(theta_b_org, theta_b, r, f, num_samples = 1e3, zNormalize = FALSE, valueForZeroLenIntervals = 0) {
stopifnot(length(theta_b_org) == length(theta_b))
stopifnot(is.function(r) && is.function(f))
stopifnot(num_samples >= length(theta_b))
Expand All @@ -16,88 +16,182 @@ M_updated <- function(theta_b_org, theta_b, r, f, num_samples = 1e3, zNormalize
theta_b_org[b + 1] - theta_b_org[b]
})

Q <- seq_len(length.out = length(theta_l))
X_r <- range(theta_b_org)
x_r <- seq(from = X_r[1], to = X_r[2], length.out = num_samples)
Q <- seq_len(length.out = length(theta_l))
X_q <- range(theta_b)
x_q <- seq(from = X_q[1], to = X_q[2], length.out = num_samples)

x_q_used <- c()

y <- Vectorize(r)(x_r)
y <- Vectorize(r)(x_q)
y_hat <- rep(NA_real_, num_samples)
q_idx <- rep(NA, num_samples)
X_q <- c()


i <- 1
for (q in Q) {
last_q <- q == rev(Q)[1]
b_os <- theta_b_org[q]
b_oe <- theta_b_org[q + 1]
b_qs <- theta_b[q]
b_qe <- theta_b[q + 1]
b_os <- theta_b_org[q]
b_oe <- theta_b_org[q + 1]

x_r_supp <- if (last_q) {
num_samples_in_q <- length(if (last_q) {
x_r[x_r >= b_os & x_r <= b_oe]
} else {
x_r[x_r >= b_os & x_r < b_oe]
})

q_idx[i:(i + num_samples_in_q - 1)] <- q

if ((b_qe - b_qs) == 0) {
y_hat[i:(i + num_samples_in_q - 1)] <- valueForZeroLenIntervals
x_q_used <- c(x_q_used, rep(b_qs, num_samples_in_q))
i <- i + num_samples_in_q
next
}
supp_len <- length(x_r_supp)
stopifnot(supp_len > 0)

bq_range <- range(b_qs, b_qe)
bq_rev <- b_qs > b_qe

x_q_supp <- seq(
from = bq_range[1], to = bq_range[2], length.out = num_samples_in_q)
x_q_used <- c(x_q_used, x_q_supp)

l_q <- theta_l[q]
phi_q <- if (q == 1) 0 else sum(theta_l[1:(q - 1)])
de_o <- theta_l_org[q]

samp_func <- Vectorize(function(x) {
temp <- de_o * (x - s - phi_q) / l_q + b_os
use_l_q <- if (l_q == 0) 1 else l_q

temp <- de_o * (x - s - phi_q) / use_l_q + b_os
f(temp)
})

# Let's check the new query interval.
if (l_q == 0) {
# It was requested to have a length of 0, but we need to
# sample 'supp_len' number of elements, which we cannot.
# We repeat 'valueForZeroLenIntervals' this element instead.
y_hat[i:(i + supp_len - 1)] <- valueForZeroLenIntervals
X_q <- c(X_q, rep(b_qs, supp_len))
} else {
# An interval of negative length means that it goes backward.
# This case should be caught by the regularizers. However, we
# can still return some meaningful values, by sampling from
# the range of the interval, and then reversing the values.

# This also works if b_qs > b_qe ..
x_q <- seq(from = b_qs, to = b_qe, length.out = supp_len)
X_q <- c(X_q, x_q) # append support used (before reversing)

y_hat[i:(i + supp_len - 1)] <- samp_func(x_r_supp)
use_vals <- samp_func(x_q_supp)
if (bq_rev) {
use_vals <- rev(use_vals)
}

q_idx[i:(i + supp_len - 1)] <- q
y_hat[i:(i + num_samples_in_q - 1)] <- use_vals

i <- i + supp_len
i <- i + num_samples_in_q
}

f_hat <- stats::approxfun(x = x_q_used, y = y_hat, ties = mean, rule = 2)
y_hat <- sapply(x_r, f_hat)

if (zNormalize) {
y <- (y - mean(y)) / sd(y)
y_hat <- (y_hat - mean(y_hat)) / sd(y_hat)
}

data.frame(
x_r = x_r,
X_q = X_q,
X = x_r,
X_q = x_q_used,
y = y,
y_hat = y_hat,
q_idx = factor(x = q_idx, levels = paste0(1:length(theta_l)), ordered = TRUE)
)
}


M_final_no_NA <- function(
theta_b_org, theta_b, r, f, num_samples = 1e3, zNormalize = FALSE
) {
stopifnot(length(theta_b_org) == length(theta_b))
stopifnot(is.function(r) && is.function(f))
stopifnot(num_samples >= length(theta_b))

X <- seq(from = min(theta_b), to = max(theta_b), length.out = num_samples)
y <- Vectorize(r)(X)

# Contrary to the pseudo-code, we will create a list of
# all intervals and scaled+translated functions to look
# them up, that's much faster.

f_primes <- list()
for (iIdx in seq_len(length(theta_b) - 1)) {
f_primes[[paste0(iIdx)]] <- (function(q) {

b_os <- theta_b_org[q]
b_oe <- theta_b_org[q + 1]
b_qs <- theta_b[q]
b_qe <- theta_b[q + 1]
de_o <- b_oe - b_os
de_q <- b_qe - b_qs
frac <- if (de_q == 0) de_o / .Machine$double.eps else de_o / de_q

function(x) {
temp <- (x - b_qs) * frac + b_os
if (is.na(temp)) {
stop(c(x, b_qs, frac, b_os))
}
f(temp)
}
})(iIdx)
}


# Now we can sample from all intervals. Note that this
# Model stays in an interval for as long as it is valid.
# If a subsequent interval overlaps, it will not switch
# into it.

temp <- sapply(seq_len(length.out = length(theta_b) - 1), function(t) {
range(theta_b[t], theta_b[t + 1])
})
interval_supports <- matrix(data = temp, ncol = 2, byrow = TRUE)

determine_interval_for_x <- function(x) {
for (r in seq_len(nrow(interval_supports))) {
is_last <- r == nrow(interval_supports)
if (is_last) {
return(paste0(r))
} else {
if (x >= interval_supports[r, 1] && x < interval_supports[r, 2]) {
return(paste0(r))
}
}
}
stop("Should never get here!")
}

y_hat <- rep(NA, num_samples)
int_idx <- rep(NA, num_samples)
for (xIdx in seq_len(num_samples)) {
iIdx <- determine_interval_for_x(X[xIdx])
y_hat[xIdx] <- f_primes[[iIdx]](X[xIdx])
int_idx[xIdx] <- iIdx
}

if (zNormalize) {
y <- (y - mean(y)) / sd(y)
y_hat <- (y_hat - mean(y_hat)) / sd(y_hat)
}

stopifnot(!any(is.na(y_hat)))
data.frame(
X = X, # return the support used
y = y,
y_hat = y_hat,
int_idx = factor(x = int_idx, levels = paste0(1:(length(theta_b) - 1)), ordered = TRUE)
)
}


L_updated_log <- function(
theta_b_org,
theta_b,
r, f,
weightErr = 1
weightErr = 1,
weightR4 = 1,
weightR5 = 1
) {
loss_raw <- 0
loss <- 0
res <- M_updated(theta_b_org = theta_b_org, theta_b = theta_b, r = r, f = f)
res <- M_final_no_NA(theta_b_org = theta_b_org, theta_b = theta_b, r = r, f = f)

# ######## KL
# idx_not_NA <- !(is.na(res$y) | is.na(res$y_hat))
Expand Down Expand Up @@ -133,9 +227,84 @@ L_updated_log <- function(
#loss <- weightErr * log(1 + sum(na.omit(res$y - res$y_hat)^2) / numData)


######## RSS
numData <- sum(complete.cases(res$y_hat))
loss <- loss + weightErr * log(1 + sum(na.omit(res$y - res$y_hat)^2))
####### RSS
loss_raw <- loss_raw + weightErr * sum((res$y - res$y_hat)^2)
loss <- loss + weightErr * log(1 + sum((res$y - res$y_hat)^2))




vts <- utils::head(theta_b, -1)
vte <- utils::tail(theta_b, -1)

s <- theta_b[1]
theta_l <- sapply(seq_len(length(theta_b) - 1), function(b) {
theta_b[b + 1] - theta_b[b]
})
theta_l_org <- sapply(seq_len(length(theta_b_org) - 1), function(b) {
theta_b_org[b + 1] - theta_b_org[b]
})

# ####### R1:
# p_phi <- function(v1, v2) H(R(v1 - v2)) * R(v1 - v2) + H(R(v2 - v1)) * R(v2 - v1)
#
# vts_o <- sort(vts)
# vts_or <- rev(vts_o)
# vte_o <- sort(vte)
# vte_or <- rev(vte_o)
#
# v_s <- p_phi(vts, vts_o)
# v_e <- p_phi(vte, vte_o)
# u_s <- p_phi(vts_o, vts_or)
# u_e <- p_phi(vte_o, vte_or)
#
# eps <- .Machine$double.eps
# temp <- sum(v_s + v_e) / sum(u_s + u_e)
# temp <- temp * (1 - eps)
#
# loss <- loss - log(1 - temp)


# ####### R2:
# X_r <- range(theta_b_org)
# X_q <- range(theta_b)
# eps <- .Machine$double.eps
# temp <- (X_q[2] - X_q[1]) / (X_r[2] - X_r[1])
# temp <- eps + temp * (1 - eps)
# loss <- loss + abs(log(temp))


# ###### R3:
# X_r <- range(theta_b_org)
# mu <- (X_r[2] - X_r[1]) / (length(theta_b_org) - 1)
# loss <- loss + log(1 + sum((vte - vts - mu)^2))



###### R4 (Box bounds):
lb <- min(theta_b_org)
ub <- max(theta_b_org)
temp <- abs(theta_b[theta_b < lb | theta_b > ub])
loss_raw <- loss_raw + weightR4 * (sum(1 + temp)^length(temp) - 1)
loss <- loss + weightR4 * log(sum(1 + temp)^length(temp))


###### R5 (neg Intervals):
neg_l <- abs(theta_l[theta_l < 0])
loss_raw <- loss_raw + weightR5 * (sum(1 + neg_l)^length(neg_l) - 1)
loss <- loss + weightR5 * log(sum(1 + neg_l)^length(neg_l))






loss
}
}







41 changes: 34 additions & 7 deletions notebooks/boundary-time-warping_final-update.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ $$
\\
&\;\text{(concatenate all models' results), and finally}
\\[1ex]
\mathsf{M}(\bm{\theta},\bm{\vartheta},\;r,f)=&\;\Big[\,\mathbf{y}^\top,\hat{\mathbf{y}}^\top\Big]\;\text{, compute the reference- and transformed query-signal.}
\mathsf{M}(\bm{\theta},\bm{\vartheta}_L,s,\;r,f)=&\;\Big[\,\mathbf{y}^\top,\hat{\mathbf{y}}^\top\Big]\;\text{, compute the reference- and transformed query-signal.}
\end{aligned}
$$

Expand Down Expand Up @@ -167,7 +167,7 @@ ggarrange(
ncol = 1,
plotBtw(data.frame(x = temp$X, y = temp$y)),
plotBtw(data.frame(x = temp$X, y = temp$y_hat), bounds = query_bounds),
plotBtw(data.frame(x = temp2$X, y = temp2$y_hat), bounds = temp2_bounds)
plotBtw(data.frame(x = temp$X, y = temp2$y_hat), bounds = temp2_bounds)
)
```

Expand Down Expand Up @@ -245,18 +245,45 @@ A simple model using only this loss for the data and no regularizers should conv
## Testing without gradient

```{r}
optR <- stats::optim(
control = (maxit = 1e3),
method = "L-BFGS-B",
cl <- parallel::makePSOCKcluster(4)
parallel::clusterExport(cl, varlist = c("M_updated", "query_bounds", "r", "f", "Stabilize", "signal_ref", "signal_query", "L_updated_log", "H", "R"))
set.seed(1337)
optRp <- optimParallel::optimParallel(
par = query_bounds,
fn = function(x) {
L_updated_log(theta_b_org = query_bounds, theta_b = x, r = r, f = f)
},
lower = rep(0, length(query_bounds)),
upper = rep(1, length(query_bounds)),
parallel = list(
cl = cl,
forward = FALSE,
loginfo = TRUE
)
)
parallel::stopCluster(cl)
optRp
```


```{r}
optR <- stats::optim(
control = list(maxit = 1e3, abstol = 1e-3),
method = "BFGS",
#lower = rep(0, length(query_bounds)),
#upper = rep(1, length(query_bounds)),
par = query_bounds,
fn = function(x) {
loss <- L_updated_log(
theta_b_org = query_bounds,
theta_b = x,
r = r,
f = f
f = f,
weightErr = .75,
weightR4 = .2,
weightR5 = 1
)
print(loss)
loss
Expand All @@ -265,7 +292,7 @@ optR <- stats::optim(
```

```{r}
temp3 <- M_updated(
temp3 <- M_final_no_NA(
theta_b_org = query_bounds,
theta_b = optR$par,
r = signal_ref,
Expand Down

0 comments on commit 74984c2

Please sign in to comment.