Skip to content

function to get prediction columns #1224

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.2.1.9003
Version: 1.2.1.9004
Authors@R: c(
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ export(.dat)
export(.extract_surv_status)
export(.extract_surv_time)
export(.facts)
export(.get_prediction_column_names)
export(.lvls)
export(.model_param_name_key)
export(.obs)
Expand Down
72 changes: 72 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,75 @@
}
# nocov end

# ------------------------------------------------------------------------------

#' Obtain names of prediction columns for a fitted model or workflow
#'
#' [.get_prediction_column_names()] returns a list that has the names of the
#' columns for the primary prediction types for a model.
#' @param x A fitted parsnip model (class `"model_fit"`) or a fitted workflow.
#' @param syms Should the column names be converted to symbols? Defaults to `FALSE`.
#' @return A list with elements `"estimate"` and `"probabilities"`.
#' @examplesIf !parsnip:::is_cran_check()
#' library(dplyr)
#' library(modeldata)
#' data("two_class_dat")
#'
#' levels(two_class_dat$Class)
#' lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat)
#'
#' .get_prediction_column_names(lr_fit)
#' .get_prediction_column_names(lr_fit, syms = TRUE)
#' @export
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make this function @keyword internal?

.get_prediction_column_names <- function(x, syms = FALSE) {
if (!inherits(x, c("model_fit", "workflow"))) {
cli::cli_abort("{.arg x} should be an object with class {.cls model_fit} or
{.cls workflow}, not {.obj_type_friendly {x}}.")
}

if (inherits(x, "workflow")) {
x <- x %>% extract_fit_parsnip(x)

Check warning on line 605 in R/misc.R

View check run for this annotation

Codecov / codecov/patch

R/misc.R#L605

Added line #L605 was not covered by tests
}
Comment on lines +604 to +606
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With an eye toward the release cascade and trying to keep things as modular as possible, I think it would be good to make this function only for model_spec objects. This call to extract_fit_parsnip() seems to be all that's necessary for workflow objects. I'd move that to workflows or wherever this is going to be needed.

model_spec <- extract_spec_parsnip(x)
model_engine <- model_spec$engine
model_mode <- model_spec$mode
model_type <- class(model_spec)[1]

# appropriate populate the model db
inst_res <- purrr::map(required_pkgs(x), rlang::check_installed)
predict_types <-
get_from_env(paste0(model_type, "_predict")) %>%
dplyr::filter(engine == model_engine & mode == model_mode) %>%
purrr::pluck("type")

if (length(predict_types) == 0) {
cli::cli_abort("Prediction information could not be found for this
{.fn {model_type}} with engine {.val {model_engine}} and mode
{.val {model_mode}}. Does a parsnip extension package need to
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That last question would be very nice in an i bullet rather than the error message itself.

be loaded?")
}

res <- list(estimate = character(0), probabilities = character(0))

if (model_mode == "regression") {
res$estimate <- ".pred"
} else if (model_mode == "classification") {
res$estimate <- ".pred_class"
if (any(predict_types == "prob")) {
res$probabilities <- paste0(".pred_", x$lvl)
}
} else if (model_mode == "censored regression") {
res$estimate <- ".pred_time"
if (any(predict_types %in% c("survival"))) {
res$probabilities <- ".pred"

Check warning on line 638 in R/misc.R

View check run for this annotation

Codecov / codecov/patch

R/misc.R#L635-L638

Added lines #L635 - L638 were not covered by tests
}
} else {
# Should be unreachable
cli::cli_abort("Unsupported model mode {model_mode}.")

Check warning on line 642 in R/misc.R

View check run for this annotation

Codecov / codecov/patch

R/misc.R#L642

Added line #L642 was not covered by tests
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know that this branch of the if/then will ever be encountered, given the error check above. However, we might hit is when we have a mode for quantile regression data so I'd err on the side of leaving it in, untested.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think that is reasonable

}

if (syms) {
res <- purrr::map(res, rlang::syms)
}
res
}
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,4 @@ reference:
- .extract_surv_status
- .extract_surv_time
- .model_param_name_key
- .get_prediction_column_names
33 changes: 33 additions & 0 deletions man/dot-get_prediction_column_names.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions tests/testthat/_snaps/misc.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,19 @@
Error in `check_outcome()`:
! For a censored regression model, the outcome should be a <Surv> object, not an integer vector.

# obtaining prediction columns

Code
.get_prediction_column_names(1)
Condition
Error in `.get_prediction_column_names()`:
! `x` should be an object with class <model_fit> or <workflow>, not a number.

---

Code
.get_prediction_column_names(unk_fit)
Condition
Error in `.get_prediction_column_names()`:
! Prediction information could not be found for this `linear_reg()` with engine "lm" and mode "Depeche". Does a parsnip extension package need to be loaded?

50 changes: 50 additions & 0 deletions tests/testthat/test-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,53 @@ test_that('check_outcome works as expected', {
check_outcome(1:2, cens_spec)
)
})

# ------------------------------------------------------------------------------

test_that('obtaining prediction columns', {
skip_if_not_installed("modeldata")
data(two_class_dat, package = "modeldata")

### classification
lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat)
expect_equal(
.get_prediction_column_names(lr_fit),
list(estimate = ".pred_class",
probabilities = c(".pred_Class1", ".pred_Class2"))
)
expect_equal(
.get_prediction_column_names(lr_fit, syms = TRUE),
list(estimate = list(quote(.pred_class)),
probabilities = list(quote(.pred_Class1), quote(.pred_Class2)))
)

### regression
ols_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars)
expect_equal(
.get_prediction_column_names(ols_fit),
list(estimate = ".pred",
probabilities = character(0))
)
expect_equal(
.get_prediction_column_names(ols_fit, syms = TRUE),
list(estimate = list(quote(.pred)),
probabilities = list())
)

### censored regression
# in extratests

### bad input
expect_snapshot(
.get_prediction_column_names(1),
error = TRUE
)

unk_fit <- ols_fit
unk_fit$spec$mode <- "Depeche"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😆

expect_snapshot(
.get_prediction_column_names(unk_fit),
error = TRUE
)

})
Loading