From 3cfa49c59fb4288d5fd1688945c1bf364ee4337b Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 4 Dec 2024 11:54:50 -0500 Subject: [PATCH 1/5] function to get prediction columns --- NAMESPACE | 1 + R/misc.R | 67 ++++++++++++++++++++++++++ man/dot-get_prediction_column_names.Rd | 33 +++++++++++++ tests/testthat/_snaps/misc.md | 16 ++++++ tests/testthat/test-misc.R | 50 +++++++++++++++++++ 5 files changed, 167 insertions(+) create mode 100644 man/dot-get_prediction_column_names.Rd diff --git a/NAMESPACE b/NAMESPACE index 0782892d6..ea8dce360 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -185,6 +185,7 @@ export(.dat) export(.extract_surv_status) export(.extract_surv_time) export(.facts) +export(.get_prediction_column_names) export(.lvls) export(.model_param_name_key) export(.obs) diff --git a/R/misc.R b/R/misc.R index 3582d8ca2..3a4694276 100644 --- a/R/misc.R +++ b/R/misc.R @@ -575,3 +575,70 @@ is_cran_check <- function() { } # nocov end +# ------------------------------------------------------------------------------ + +#' Obtain names of prediction columns for a fitted model or workflow +#' +#' [.get_prediction_column_names()] returns a list that has the names of the +#' columns for the primary prediction types for a model. +#' @param x A fitted model (class `"model_fit"`) or a fitted workflow. +#' @param syms Should the column names be converted to symbols? +#' @return A list with elements `"estimate"` and `"probabilities"`. 
+#' @examplesIf !parsnip:::is_cran_check() +#' library(dplyr) +#' library(modeldata) +#' data("two_class_dat") +#' +#' levels(two_class_dat$Class) +#' lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat) +#' +#' .get_prediction_column_names(lr_fit) +#' .get_prediction_column_names(lr_fit, syms = TRUE) +#' @export +.get_prediction_column_names <- function(x, syms = FALSE) { + if (!inherits(x, c("model_fit", "workflow"))) { + cli::cli_abort("{.arg x} should be an object with class {.cls model_fit} or + {.cls workflow}, not {.obj_type_friendly {x}}.") + } + model_spec <- extract_spec_parsnip(x) + model_engine <- model_spec$engine + model_mode <- model_spec$mode + model_type <- class(model_spec)[1] + + # check that required packages are installed, to appropriately populate the model db + inst_res <- purrr::map(required_pkgs(x), rlang::check_installed) + predict_types <- + get_from_env(paste0(model_type, "_predict")) %>% + dplyr::filter(engine == model_engine & mode == model_mode) %>% + purrr::pluck("type") + + if (length(predict_types) == 0) { + cli::cli_abort("Prediction information could not be found for this + {.fn {model_type}} with engine {.val {model_engine}} and mode + {.val {model_mode}}. 
Does a parsnip extension package need to + be loaded?") + } + + res <- list(estimate = character(0), probabilities = character(0)) + + if (model_mode == "regression") { + res$estimate <- ".pred" + } else if (model_mode == "classification") { + res$estimate <- ".pred_class" + if (any(predict_types == "prob")) { + res$probabilities <- paste0(".pred_", x$lvl) + } + } else if (model_mode == "censored regression") { + res$estimate <- ".pred_time" + if (any(predict_types %in% c("survival"))) { + res$probabilities <- ".pred" + } + } else { + cli::cli_abort("Unsupported model mode {model_mode}.") + } + + if (syms) { + res <- purrr::map(res, rlang::syms) + } + res +} diff --git a/man/dot-get_prediction_column_names.Rd b/man/dot-get_prediction_column_names.Rd new file mode 100644 index 000000000..36fba2c08 --- /dev/null +++ b/man/dot-get_prediction_column_names.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/misc.R +\name{.get_prediction_column_names} +\alias{.get_prediction_column_names} +\title{Obtain names of prediction columns for a fitted model or workflow} +\usage{ +.get_prediction_column_names(x, syms = FALSE) +} +\arguments{ +\item{x}{A fitted model (class \code{"model_fit"}) or a fitted workflow.} + +\item{syms}{Should the column names be converted to symbols?} +} +\value{ +A list with elements \code{"estimate"} and \code{"probabilities"}. +} +\description{ +\code{\link[=.get_prediction_column_names]{.get_prediction_column_names()}} returns a list that has the names of the +columns for the primary prediction types for a model. 
+} +\examples{ +\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +library(dplyr) +library(modeldata) +data("two_class_dat") + +levels(two_class_dat$Class) +lr_fit <- logistic_reg() \%>\% fit(Class ~ ., data = two_class_dat) + +.get_prediction_column_names(lr_fit) +.get_prediction_column_names(lr_fit, syms = TRUE) +\dontshow{\}) # examplesIf} +} diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index b6b1f918c..b221b1dde 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -227,3 +227,19 @@ Error in `check_outcome()`: ! For a censored regression model, the outcome should be a object, not an integer vector. +# obtaining prediction columns + + Code + .get_prediction_column_names(1) + Condition + Error in `.get_prediction_column_names()`: + ! `x` should be an object with class or , not a number. + +--- + + Code + .get_prediction_column_names(unk_fit) + Condition + Error in `.get_prediction_column_names()`: + ! Prediction information could not be found for this `linear_reg()` with engine "lm" and mode "Depeche". Does a parsnip extension package need to be loaded? 
+ diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index e689bcbcf..901c92748 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -249,3 +249,53 @@ test_that('check_outcome works as expected', { check_outcome(1:2, cens_spec) ) }) + +# ------------------------------------------------------------------------------ + +test_that('obtaining prediction columns', { + skip_if_not_installed("modeldata") + data(two_class_dat, package = "modeldata") + + ### classification: class estimate plus one probability column per level + lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat) + expect_equal( + .get_prediction_column_names(lr_fit), + list(estimate = ".pred_class", + probabilities = c(".pred_Class1", ".pred_Class2")) + ) + expect_equal( + .get_prediction_column_names(lr_fit, syms = TRUE), + list(estimate = list(quote(.pred_class)), + probabilities = list(quote(.pred_Class1), quote(.pred_Class2))) + ) + + ### regression: a single .pred estimate and no probability columns + ols_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars) + expect_equal( + .get_prediction_column_names(ols_fit), + list(estimate = ".pred", + probabilities = character(0)) + ) + expect_equal( + .get_prediction_column_names(ols_fit, syms = TRUE), + list(estimate = list(quote(.pred)), + probabilities = list()) + ) + + ### censored regression + # not tested in this file; noted as living in extratests + + ### bad input: a non-model object, then a fit whose mode is unknown + expect_snapshot( + .get_prediction_column_names(1), + error = TRUE + ) + + unk_fit <- ols_fit + unk_fit$spec$mode <- "Depeche" + expect_snapshot( + .get_prediction_column_names(unk_fit), + error = TRUE + ) + +}) From a834e1d1a1ddf3e468343ce22e0212748e250bfb Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 4 Dec 2024 12:01:21 -0500 Subject: [PATCH 2/5] forgotten pkgdown entry --- _pkgdown.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/_pkgdown.yml b/_pkgdown.yml index c79ecca06..3596da3f3 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -111,3 +111,4 @@ reference: - .extract_surv_status - .extract_surv_time - .model_param_name_key + - .get_prediction_column_names From 
df13ba8905532a438a938a6b297f61e586411232 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 4 Dec 2024 13:10:28 -0500 Subject: [PATCH 3/5] also, bump version number --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index fded77f48..bbc018a48 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.2.1.9003 +Version: 1.2.1.9004 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), From d26a108f5ba56d694b0c3b044f6cfa16264148c5 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 4 Dec 2024 13:22:34 -0500 Subject: [PATCH 4/5] fix for workflows --- R/misc.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/misc.R b/R/misc.R index 3a4694276..dc993cb55 100644 --- a/R/misc.R +++ b/R/misc.R @@ -600,6 +600,10 @@ is_cran_check <- function() { cli::cli_abort("{.arg x} should be an object with class {.cls model_fit} or {.cls workflow}, not {.obj_type_friendly {x}}.") } + + if (inherits(x, "workflow")) { + x <- extract_fit_parsnip(x) + } model_spec <- extract_spec_parsnip(x) model_engine <- model_spec$engine model_mode <- model_spec$mode From 61a6a1e3688d9a6368d6bfbdd93e2f9337771c0b Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Sun, 8 Dec 2024 07:39:49 -0500 Subject: [PATCH 5/5] Apply suggestions from code review Co-authored-by: Emil Hvitfeldt --- R/misc.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/misc.R b/R/misc.R index dc993cb55..263a9b0f2 100644 --- a/R/misc.R +++ b/R/misc.R @@ -581,8 +581,8 @@ is_cran_check <- function() { #' #' [.get_prediction_column_names()] returns a list that has the names of the #' columns for the primary prediction types for a model. -#' @param x A fitted model (class `"model_fit"`) or a fitted workflow. -#' @param syms Should the column names be converted to symbols? 
+#' @param x A fitted parsnip model (class `"model_fit"`) or a fitted workflow. +#' @param syms Should the column names be converted to symbols? Defaults to `FALSE`. #' @return A list with elements `"estimate"` and `"probabilities"`. #' @examplesIf !parsnip:::is_cran_check() #' library(dplyr) @@ -638,6 +638,7 @@ is_cran_check <- function() { res$probabilities <- ".pred" } } else { + # Should be unreachable cli::cli_abort("Unsupported model mode {model_mode}.") }