Skip to content

Commit

Permalink
sort prob table by prob, add option to sport by support; closes #58
Browse files Browse the repository at this point in the history
  • Loading branch information
mhtess committed Jun 28, 2018
1 parent ba3727f commit ccc92de
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
26 changes: 17 additions & 9 deletions R/rwebppl.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ install_webppl <- function(webppl_version) {
rwebppl_meta <- jsonlite::fromJSON(readLines(rwebppl_json))
rwebppl_meta$dependencies$webppl <- webppl_version
webppl_json <- file.path(rwebppl_path(), "js", "package.json")

# Executable bit should be tracked by git but chmod just in case
system2('chmod', args = c('+x', file.path(rwebppl_path(), "bash", "*")))

writeLines(jsonlite::toJSON(rwebppl_meta, auto_unbox = TRUE, pretty = TRUE),
webppl_json)
system2(file.path(rwebppl_path(), "bash", "install-webppl.sh"),
Expand Down Expand Up @@ -179,13 +179,18 @@ countSamples <- function(output, inference_opts) {
}
}

tidy_probTable <- function(output) {
tidy_probTable <- function(output, sort_by) {
if (class(output$support) == "data.frame") {
support <- output$support
} else {
support <- data.frame(support = output$support)
}
return(cbind(support, data.frame(prob = output$probs)))
unsorted_probTable <- cbind(support, data.frame(prob = output$probs))
if (sort_by == "prob") {
return(unsorted_probTable[with(unsorted_probTable, order(-prob)), ])
} else {
return(unsorted_probTable[with(unsorted_probTable, order(support)), ])
}
}

tidy_sampleList <- function(output, chains, chain, inference_opts) {
Expand All @@ -209,9 +214,9 @@ tidy_sampleList <- function(output, chains, chain, inference_opts) {
return(ggmcmc_samples)
}

tidy_output <- function(output, chains = NULL, chain = NULL, inference_opts = NULL) {
tidy_output <- function(output, chains = NULL, chain = NULL, inference_opts = NULL, sort_by = NULL) {
if (is_probTable(output)) {
return(tidy_probTable(output))
return(tidy_probTable(output, sort_by = sort_by))
} else if (is_sampleList(output)) {
# Drop redundant score column, if it exists
if ("score" %in% names(output)) {
Expand All @@ -237,12 +242,13 @@ tidy_output <- function(output, chains = NULL, chain = NULL, inference_opts = NU
#' @param inference_opts Options for inference
#' (see http://webppl.readthedocs.io/en/master/inference.html)
#' @param random_seed Seed for random number generator
#' @param sort_by Sort probability table by probability or support (enumeration only)
#' @param chains Number of chains (this run is one chain).
#' @param chain Chain number of this run.
run_webppl <- function(program_code = NULL, program_file = NULL, data = NULL,
data_var = "data", packages = NULL, model_var = "model",
inference_opts = NULL, chains = NULL, random_seed = NULL,
chain = 1) {
sort_by = "prob", chain = 1) {

# find location of rwebppl JS script, within rwebppl R package
script_path <- file.path(rwebppl_path(), "js/rwebppl")
Expand Down Expand Up @@ -327,7 +333,8 @@ run_webppl <- function(program_code = NULL, program_file = NULL, data = NULL,
output <- jsonlite::fromJSON(output_string, flatten = TRUE)
if (!is.null(names(output))) {
return(tidy_output(output, chains = chains,
chain = chain, inference_opts = inference_opts))
chain = chain, inference_opts = inference_opts,
sort_by = sort_by))
} else {
return(output)
}
Expand Down Expand Up @@ -358,7 +365,7 @@ globalVariables("i")
#' }
webppl <- function(program_code = NULL, program_file = NULL, data = NULL,
data_var = "data", packages = NULL, model_var = "model",
inference_opts = NULL, random_seed = NULL, chains = 1, cores = 1) {
inference_opts = NULL, random_seed = NULL, sort_by = "prob", chains = 1, cores = 1) {

run_fun <- function(k) run_webppl(program_code = program_code,
program_file = program_file,
Expand All @@ -368,6 +375,7 @@ webppl <- function(program_code = NULL, program_file = NULL, data = NULL,
model_var = model_var,
inference_opts = inference_opts,
random_seed = random_seed,
sort_by = sort_by,
chains = chains,
chain = k)
if (chains == 1) {
Expand Down
5 changes: 4 additions & 1 deletion man/run_webppl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/webppl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ccc92de

Please sign in to comment.