Skip to content

Commit

Permalink
Match factor levels in prediction and training
Browse files Browse the repository at this point in the history
  • Loading branch information
koenderks committed Nov 28, 2024
1 parent 07d2513 commit f4d175b
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions R/mlPrediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ is.jaspMachineLearning <- function(x) {

# also define methods for other objects
.mlPredictionReady <- function(model, dataset, options) {
if (!is.null(model)) {
if (!is.null(model) && !is.null(dataset)) {
modelVars <- model[["jaspVars"]][["encoded"]]$predictors
presentVars <- colnames(dataset)
ready <- all(modelVars %in% presentVars)
Expand All @@ -241,12 +241,21 @@ is.jaspMachineLearning <- function(x) {
}

.mlPredictionReadData <- function(dataset, options, model) {
dataset <- jaspBase::excludeNaListwise(dataset, options[["predictors"]])
if (options[["scaleVariables"]] && length(unlist(options[["predictors"]])) > 0) {
dataset <- .scaleNumericData(dataset)
if (length(options[["predictors"]]) == 0) {
dataset <- NULL
} else {
dataset <- jaspBase::excludeNaListwise(dataset, options[["predictors"]])
if (options[["scaleVariables"]] && length(unlist(options[["predictors"]])) > 0) {
dataset <- .scaleNumericData(dataset)
}
# Select only the predictors in the model to prevent accidental double column names
dataset <- dataset[, which(decodeColNames(colnames(dataset)) %in% model[["jaspVars"]][["decoded"]]$predictors)]
# Ensure the column names in the dataset match those in the training data
colnames(dataset) <- .matchDecodedNames(colnames(dataset), model)
# Ensure factor variables in dataset have same levels as those in the training data
factorColumns <- colnames(dataset)[sapply(dataset, is.factor)]
dataset[factorColumns] <- lapply(factorColumns, function(i) factor(dataset[[i]], levels = levels(model[["explainer"]]$data[[i]])))
}
dataset <- dataset[, which(decodeColNames(colnames(dataset)) %in% model[["jaspVars"]][["decoded"]]$predictors)] # Filter only predictors to prevent accidental double column names
colnames(dataset) <- .matchDecodedNames(colnames(dataset), model)
return(dataset)
}

Expand Down

0 comments on commit f4d175b

Please sign in to comment.