-
Notifications
You must be signed in to change notification settings - Fork 93
function to get prediction columns #1224
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -575,3 +575,75 @@ | |
} | ||
# nocov end | ||
|
||
# ------------------------------------------------------------------------------ | ||
|
||
#' Obtain names of prediction columns for a fitted model or workflow | ||
#' | ||
#' [.get_prediction_column_names()] returns a list that has the names of the | ||
#' columns for the primary prediction types for a model. | ||
#' @param x A fitted parsnip model (class `"model_fit"`) or a fitted workflow. | ||
#' @param syms Should the column names be converted to symbols? Defaults to `FALSE`. | ||
#' @return A list with elements `"estimate"` and `"probabilities"`. | ||
#' @examplesIf !parsnip:::is_cran_check() | ||
#' library(dplyr) | ||
#' library(modeldata) | ||
#' data("two_class_dat") | ||
#' | ||
#' levels(two_class_dat$Class) | ||
#' lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat) | ||
#' | ||
#' .get_prediction_column_names(lr_fit) | ||
#' .get_prediction_column_names(lr_fit, syms = TRUE) | ||
#' @export | ||
.get_prediction_column_names <- function(x, syms = FALSE) { | ||
if (!inherits(x, c("model_fit", "workflow"))) { | ||
cli::cli_abort("{.arg x} should be an object with class {.cls model_fit} or | ||
{.cls workflow}, not {.obj_type_friendly {x}}.") | ||
} | ||
|
||
if (inherits(x, "workflow")) { | ||
x <- x %>% extract_fit_parsnip(x) | ||
} | ||
Comment on lines
+604
to
+606
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With an eye toward the release cascade and trying to keep things as modular as possible, I think it would be good to make this function only for |
||
model_spec <- extract_spec_parsnip(x) | ||
model_engine <- model_spec$engine | ||
model_mode <- model_spec$mode | ||
model_type <- class(model_spec)[1] | ||
|
||
# appropriate populate the model db | ||
inst_res <- purrr::map(required_pkgs(x), rlang::check_installed) | ||
predict_types <- | ||
get_from_env(paste0(model_type, "_predict")) %>% | ||
dplyr::filter(engine == model_engine & mode == model_mode) %>% | ||
purrr::pluck("type") | ||
|
||
if (length(predict_types) == 0) { | ||
cli::cli_abort("Prediction information could not be found for this | ||
{.fn {model_type}} with engine {.val {model_engine}} and mode | ||
{.val {model_mode}}. Does a parsnip extension package need to | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That last question would be very nice in an |
||
be loaded?") | ||
} | ||
|
||
res <- list(estimate = character(0), probabilities = character(0)) | ||
|
||
if (model_mode == "regression") { | ||
res$estimate <- ".pred" | ||
} else if (model_mode == "classification") { | ||
res$estimate <- ".pred_class" | ||
if (any(predict_types == "prob")) { | ||
res$probabilities <- paste0(".pred_", x$lvl) | ||
} | ||
} else if (model_mode == "censored regression") { | ||
res$estimate <- ".pred_time" | ||
if (any(predict_types %in% c("survival"))) { | ||
res$probabilities <- ".pred" | ||
} | ||
} else { | ||
# Should be unreachable | ||
cli::cli_abort("Unsupported model mode {model_mode}.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know that this branch of the if/then will ever be encountered, given the error check above. However, we might hit is when we have a mode for quantile regression data so I'd err on the side of leaving it in, untested. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think that is reasonable
topepo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
if (syms) { | ||
res <- purrr::map(res, rlang::syms) | ||
} | ||
res | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,3 +111,4 @@ reference: | |
- .extract_surv_status | ||
- .extract_surv_time | ||
- .model_param_name_key | ||
- .get_prediction_column_names |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -249,3 +249,53 @@ test_that('check_outcome works as expected', { | |
check_outcome(1:2, cens_spec) | ||
) | ||
}) | ||
|
||
# ------------------------------------------------------------------------------ | ||
|
||
test_that('obtaining prediction columns', { | ||
skip_if_not_installed("modeldata") | ||
data(two_class_dat, package = "modeldata") | ||
|
||
### classification | ||
lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat) | ||
expect_equal( | ||
.get_prediction_column_names(lr_fit), | ||
list(estimate = ".pred_class", | ||
probabilities = c(".pred_Class1", ".pred_Class2")) | ||
) | ||
expect_equal( | ||
.get_prediction_column_names(lr_fit, syms = TRUE), | ||
list(estimate = list(quote(.pred_class)), | ||
probabilities = list(quote(.pred_Class1), quote(.pred_Class2))) | ||
) | ||
|
||
### regression | ||
ols_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars) | ||
expect_equal( | ||
.get_prediction_column_names(ols_fit), | ||
list(estimate = ".pred", | ||
probabilities = character(0)) | ||
) | ||
expect_equal( | ||
.get_prediction_column_names(ols_fit, syms = TRUE), | ||
list(estimate = list(quote(.pred)), | ||
probabilities = list()) | ||
) | ||
|
||
### censored regression | ||
# in extratests | ||
|
||
### bad input | ||
expect_snapshot( | ||
.get_prediction_column_names(1), | ||
error = TRUE | ||
) | ||
|
||
unk_fit <- ols_fit | ||
unk_fit$spec$mode <- "Depeche" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 😆 |
||
expect_snapshot( | ||
.get_prediction_column_names(unk_fit), | ||
error = TRUE | ||
) | ||
|
||
}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we make this function
@keyword internal
?