Skip to content

Commit

Permalink
Add R2
Browse files Browse the repository at this point in the history
  • Loading branch information
sims1253 committed Nov 15, 2022
1 parent 9b9a8a8 commit b6ad1af
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ S3method(loo_compare,default)
S3method(loo_model_weights,default)
S3method(loo_moment_match,default)
S3method(loo_predictive_metric,default)
S3method(loo_r2,default)
S3method(loo_subsample,"function")
S3method(nobs,psis_loo_ss)
S3method(plot,loo)
Expand Down Expand Up @@ -108,6 +109,7 @@ export(loo_model_weights.default)
export(loo_moment_match)
export(loo_moment_match.default)
export(loo_predictive_metric)
export(loo_r2)
export(loo_subsample)
export(loo_subsample.function)
export(mcse_loo)
Expand Down
98 changes: 98 additions & 0 deletions R/loo_r2.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#' Leave-one-out R2
#'
#' `loo_r2()` computes a leave-one-out R2 estimate from a matrix of
#' predictions. The pointwise log-likelihoods are handed to [psis()] and the
#' resulting importance weights are used to average each observation's squared
#' prediction error before it is compared with the total sum of squares of `y`.
#'
#' @param y A numeric vector of observations. Length should be equal to the
#'   number of columns in `ypred`.
#' @param ypred A numeric matrix of predictions with one column per
#'   observation.
#' @param log_lik A matrix of pointwise log-likelihoods. Should be of same
#'   dimension as `ypred`.
#' @param r_eff A vector of relative effective sample size estimates containing
#'   one element per observation. See [psis()] for more details.
#' @param ... Arguments passed on to methods.
#'
#' @return A list with the following components:
#' \describe{
#'   \item{`estimate`}{
#'     R2 estimate: the sum of the pointwise components.
#'   }
#'   \item{`pointwise`}{
#'     Pointwise components of the R2, one per observation.
#'   }
#'   \item{`se`}{
#'     Standard error of the R2, computed as
#'     `sqrt(length(pointwise) * var(pointwise))`.
#'   }
#'   \item{`diagnostics`}{
#'     The `diagnostics` element of the object returned by [psis()].
#'   }
#' }
#' @export
#' @examples
#' \donttest{
#' if (requireNamespace("rstanarm", quietly = TRUE)) {
#'   # Use rstanarm package to quickly fit a model and get both a log-likelihood
#'   # matrix and draws from the posterior predictive distribution
#'   library("rstanarm")
#'   data("mtcars")
#'
#'   fit <- stan_glm(mpg ~ cyl + hp + wt, data = mtcars, refresh = 0)
#'   ll <- log_lik(fit)
#'   # relative_eff() expects likelihood values, i.e. exp(log_lik)
#'   r_eff <- relative_eff(exp(ll), chain_id = rep(1:4, each = 1000))
#'
#'   ypred <- posterior_predict(fit)
#'   # Leave-one-out R2
#'   r2 <- loo_r2(
#'     y = mtcars$mpg,
#'     ypred = ypred,
#'     log_lik = ll,
#'     r_eff = r_eff
#'   )
#' }
#' }
loo_r2 <- function(y, ypred, log_lik, r_eff, ...) {
  UseMethod("loo_r2")
}

#' @rdname loo_r2
#' @export
loo_r2.default <-
  function(y,
           ypred,
           log_lik,
           r_eff = NULL,
           ...) {
    # `...` is accepted only so the method's signature is consistent with the
    # generic; it is not used here.
    #
    # The checks below are kept as separate, literal expressions because
    # stopifnot() builds its error message from the text of the failing one.
    stopifnot(
      is.numeric(y),
      is.numeric(ypred),
      identical(ncol(ypred), length(y)),
      identical(dim(ypred), dim(log_lik))
    )

    # psis() is given the negated log-likelihoods; its log weights are
    # exponentiated here and normalised per observation inside .r2().
    psis_object <- psis(-log_lik, r_eff = r_eff)
    pointwise <- .r2(
      y = y,
      ypred = ypred,
      weights = exp(psis_object$log_weights)
    )

    list(
      estimate = sum(pointwise),
      pointwise = pointwise,
      # Standard error of a sum of length(pointwise) pointwise terms.
      se = sqrt(length(pointwise) * var(pointwise)),
      diagnostics = psis_object$diagnostics
    )
  }

# internal ----------------------------------------------------------------

# Pointwise R2 components.
#
# For observation n the component is
#   1 / N - E[(y[n] - ypred[, n])^2] / sum((y - mean(y))^2)
# where N is length(y) and the expectation is taken over the rows of `ypred`,
# so that the components sum to 1 - (expected residual SS) / (total SS).
#
# @param y Numeric vector of observations.
# @param ypred Numeric matrix of predictions with one column per element of
#   `y`.
# @param weights Optional matrix with the same dimensions as `ypred`. Each
#   column is normalised by its own sum, so the weights need not be normalised
#   beforehand. If NULL, all rows are weighted equally.
# @return An unnamed numeric vector with one element per observation.
#
# NOTE(review): a constant `y` makes the total sum of squares zero, so the
# result is then non-finite.
.r2 <- function(y, ypred, weights = NULL) {
  n_obs <- length(y)
  ss_y <- sum((y - mean(y))^2)

  # Entry [s, n] holds (y[n] - ypred[s, n])^2.
  y_rows <- matrix(y, nrow = nrow(ypred), ncol = n_obs, byrow = TRUE)
  sq_err <- (y_rows - ypred)^2

  if (is.null(weights)) {
    # Equal weights: a plain mean over rows, so the result matches what
    # uniform `weights` would give and does not grow with the number of rows.
    expected_sq_err <- colMeans(sq_err)
  } else {
    expected_sq_err <- colSums(weights * sq_err) / colSums(weights)
  }

  unname(1 / n_obs - expected_sq_err / ss_y)
}
63 changes: 63 additions & 0 deletions man/loo_r2.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 47 additions & 0 deletions tests/testthat/test_loo_r2.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Tests for loo_r2(): input validation, shape of the returned list, and a
# regression check against stored reference results.
options(mc.cores = 1)
set.seed(123)
context("r2")

LL <- example_loglik_matrix()
chain_id <- rep(1:2, each = dim(LL)[1] / 2)
r_eff <- relative_eff(exp(LL), chain_id)
# NOTE(review): `psisd_obj` is not referenced by any expectation below; it is
# left in place so the fixture setup before the rnorm() draws is unchanged.
psisd_obj <- psis(-LL, r_eff = r_eff, cores = 2)

# Random predictions and observations with the same layout as LL
# (one column per observation).
yrep <- matrix(rnorm(length(LL)), nrow = nrow(LL), ncol = ncol(LL))
y <- rnorm(ncol(LL))


r2 <- loo_r2(
  y = y,
  ypred = yrep,
  log_lik = LL,
  r_eff = r_eff
)

test_that("loo_r2 stops with incorrect inputs", {
  # The prediction matrix in this file is `yrep`; no object named `ypred`
  # exists here, so the expectations must pass `yrep`.
  expect_error(
    loo_r2(as.character(y), yrep, LL, r_eff = r_eff),
    "is.numeric(y) is not TRUE",
    fixed = TRUE
  )

  expect_error(
    loo_r2(y, as.character(yrep), LL, r_eff = r_eff),
    "is.numeric(ypred) is not TRUE",
    fixed = TRUE
  )

  # Wrong number of columns relative to length(y).
  ypred_invalid <- matrix(rnorm(9), nrow = 3)
  expect_error(
    loo_r2(y, ypred_invalid, LL, r_eff = r_eff),
    "identical(ncol(ypred), length(y)) is not TRUE",
    fixed = TRUE
  )

  # Right number of columns but wrong number of rows relative to log_lik.
  ypred_invalid <- matrix(rnorm(2 * ncol(LL)), nrow = 2)
  expect_error(
    loo_r2(y, ypred_invalid, LL, r_eff = r_eff),
    "identical(dim(ypred), dim(log_lik)) is not TRUE",
    fixed = TRUE
  )
})


test_that("loo_r2 return types are correct", {
  expect_type(r2, "list")
  expect_named(r2, c("estimate", "pointwise", "se", "diagnostics"))
})

test_that("loo_r2 results haven't changed", {
  expect_equal_to_reference(r2, "reference-results/loo_r2.rds")
})

0 comments on commit b6ad1af

Please sign in to comment.