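Makes predictions by combining a fitted RGCCA object with a prediction model for classification or regression. The test blocks are projected onto the components of the fitted RGCCA object, and a model trained with caret predicts the response block from these projections.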
Usage
rgcca_predict(
  rgcca_res,
  blocks_test = rgcca_res$call$blocks,
  prediction_model = "lm",
  metric = NULL,
  ...
)
Arguments
- rgcca_res
A fitted RGCCA object (see rgcca).
- blocks_test
A list of test blocks from which to predict the associated response block. If the test response block is included in blocks_test, metrics are computed by comparing the predictions with the true values.
- prediction_model
A string giving the caret model used for prediction (default "lm"). See caret::modelLookup() for the list of available models.
- metric
A string indicating the metric of interest. It should be one of the following scores:
For classification: "Accuracy", "Kappa", "F1", "Sensitivity", "Specificity", "Pos_Pred_Value", "Neg_Pred_Value", "Precision", "Recall", "Detection_Rate", "Balanced_Accuracy".
For regression: "RMSE", "MAE".
- ...
Additional parameters to be passed to prediction_model.
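For reference, the two regression metrics and the simplest classification metric reduce to short base R formulas. The sketch below does not call caret; the helper names rmse, mae and accuracy are hypothetical, and caret's own implementations may treat edge cases (such as missing values) differently.

```r
# Hypothetical helpers illustrating the metrics listed above (not part of RGCCA or caret)
rmse <- function(obs, pred) sqrt(mean((obs - pred)^2))  # root mean squared error
mae <- function(obs, pred) mean(abs(obs - pred))         # mean absolute error
accuracy <- function(obs, pred) mean(obs == pred)        # share of correctly predicted labels

# Regression: observed vs predicted values
obs <- c(1, 2, 3, 4)
pred <- c(1.5, 2, 2.5, 4)
rmse(obs, pred)  # sqrt(mean(c(0.25, 0, 0.25, 0)))
mae(obs, pred)   # mean(c(0.5, 0, 0.5, 0))

# Classification: both factors must share the same levels
labels_obs <- factor(c("a", "b", "a", "b"))
labels_pred <- factor(c("a", "b", "b", "b"), levels = levels(labels_obs))
accuracy(labels_obs, labels_pred)
# Confusion table of predicted against observed labels
table(predicted = labels_pred, observed = labels_obs)
```

Accuracy is the proportion of matching labels, so here three of the four predictions are correct.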
Value
A list containing the following elements:
- score
The score obtained on the test response block. NA if the test response block is missing from blocks_test.
- model
A list of the models trained using caret to make the predictions and compute the scores.
- metric
A list of data.frames containing the scores obtained on the training and testing sets.
- confusion
For regression tasks, a list containing NA. For classification tasks, the confusion summaries produced by caret for the train and test sets.
- projection
A list of matrices containing the projections of the test blocks using the canonical components from the fitted RGCCA object. The response block is not projected.
- prediction
A list of data.frames with the predictions of the test and train response blocks.
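To make the projection and prediction steps concrete, here is a conceptual sketch in base R. It is not RGCCA's implementation: principal components from prcomp() stand in for the RGCCA block components, the data are simulated, and a single explanatory block is used. The training components feed lm() (the default prediction_model), and the test rows are projected with the training centring, scaling and loadings before predicting.

```r
set.seed(1)
# Simulated explanatory block X and response y (stand-ins for the RGCCA blocks)
n <- 47
X <- matrix(rnorm(n * 3), n, 3, dimnames = list(NULL, c("x1", "x2", "x3")))
y <- X[, 1] - 0.5 * X[, 2] + rnorm(n, sd = 0.1)
train <- 1:30
test <- 31:47

# PCA on the training rows stands in for the fitted RGCCA components
pca <- prcomp(X[train, ], center = TRUE, scale. = TRUE)
comp_train <- as.data.frame(pca$x[, 1:2])
# Project the test rows using the training centring, scaling and loadings
comp_test <- as.data.frame(predict(pca, newdata = X[test, ])[, 1:2])

# Fit the prediction model on the training components, then predict the test response
model <- lm(y ~ ., data = cbind(comp_train, y = y[train]))
pred_test <- predict(model, newdata = comp_test)
sqrt(mean((y[test] - pred_test)^2))  # test RMSE
```

The key design point is that the test data never influence the components: they are only projected with quantities estimated on the training rows, which is why the response block itself does not need to be projected.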
Examples
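library(RGCCA)
data("Russett")
blocks <- list(
  agriculture = Russett[, 1:3],
  industry = Russett[, 4:5],
  politic = Russett[, 6:8]
)
# Rows 1 to 30 for training, rows 31 to 47 for testing
X_train <- lapply(blocks, function(x) x[seq(1, 30), ])
X_test <- lapply(blocks, function(x) x[seq(31, 47), ])
# Fit RGCCA with the politic block (block 3) as the response
fit <- rgcca(X_train,
  tau = 1, ncomp = c(3, 2, 3), response = 3
)
# The response block is present in X_test, so metrics are computed
res <- rgcca_predict(fit, X_test)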