
Save spectral prediction model and model performance statistics
Source: R/save_model.R
Given a set of pretreatment methods, saves the best spectral
prediction model and model statistics to model.save.folder as
model.name.Rds and model.name_stats.csv, respectively. If only
one pretreatment method is supplied, results from that method are stored.
Usage
save_model(
df,
write.model = TRUE,
pretreatment = 1,
model.save.folder = NULL,
model.name = "PredictionModel",
best.model.metric = "RMSE",
k.folds = 5,
proportion.train = 0.7,
tune.length = 50,
model.method = "pls",
num.iterations = 10,
stratified.sampling = TRUE,
cv.scheme = NULL,
trial1 = NULL,
trial2 = NULL,
trial3 = NULL,
seed = 1,
verbose = TRUE,
save.model = lifecycle::deprecated(),
wavelengths = lifecycle::deprecated(),
autoselect.preprocessing = lifecycle::deprecated(),
preprocessing.method = lifecycle::deprecated()
)
Arguments
- df
data.frame object. The first column contains unique identifiers and the second contains reference values, followed by spectral columns. Include no other columns to the right of the spectra! Column names of spectra must start with "X" and the reference column must be named "reference".
- write.model
If TRUE, the trained model will be saved in .Rds format to the location specified by model.save.folder. If FALSE, the best model will be output by the function but will not be saved to a file. Default is TRUE.
- pretreatment
Number or list of numbers 1:13 corresponding to desired pretreatment method(s):
1. Raw data (default)
2. Standard normal variate (SNV)
3. SNV and first derivative
4. SNV and second derivative
5. First derivative
6. Second derivative
7. Savitzky–Golay filter (SG)
8. SNV and SG
9. Gap-segment derivative (window size = 11)
10. SG and first derivative (window size = 5)
11. SG and first derivative (window size = 11)
12. SG and second derivative (window size = 5)
13. SG and second derivative (window size = 11)
- model.save.folder
Path to folder where model will be saved. If not provided, will save to working directory.
- model.name
Name that model will be saved as in
model.save.folder. Default is "PredictionModel".- best.model.metric
Metric used to decide which model is best. Must be either "RMSE" or "Rsquared"
- k.folds
Number indicating the number of folds for k-fold cross-validation during model training. Default is 5.
- proportion.train
Fraction of samples to include in the training set. Default is 0.7.
- tune.length
Number delineating search space for tuning of the PLSR hyperparameter
ncomp. Must be set to 5 when using the random forest algorithm (model.method == rf). Default is 50.- model.method
Model type to use for training. Valid options include:
"pls": Partial least squares regression (Default)
"rf": Random forest
"svmLinear": Support vector machine with linear kernel
"svmRadial": Support vector machine with radial kernel
- num.iterations
Number of training iterations to perform. Default is 10.
- stratified.sampling
If TRUE, training and test sets will be selected using stratified random sampling. This term is only used if test.data == NULL. Default is TRUE.
- cv.scheme
A cross-validation (CV) scheme from Jarquín et al., 2017. Options for cv.scheme include:
"CV1": untested lines in tested environments
"CV2": tested lines in tested environments
"CV0": tested lines in untested environments
"CV00": untested lines in untested environments
- trial1
data.frame object that is for use only when cv.scheme is provided. Contains the trial to be tested in subsequent model training functions. The first column contains unique identifiers, the second contains genotypes, and the third contains reference values, followed by spectral columns. Include no other columns to the right of the spectra! Column names of spectra must start with "X", the reference column must be named "reference", and the genotype column must be named "genotype".
- trial2
data.frame object that is for use only when cv.scheme is provided. Contains a trial whose genotypes overlap with those of trial1 but that was grown in a different site/year (different environment). Formatting must be consistent with trial1.
- trial3
data.frame object that is for use only when cv.scheme is provided. Contains a trial that may or may not contain genotypes that overlap with trial1. Formatting must be consistent with trial1.
- seed
Integer to be used internally as input for set.seed(). Only used if stratified.sampling = TRUE. In all other cases, seed is set to the current iteration number. Default is 1.
- verbose
If TRUE, the number of rows removed through filtering will be printed to the console. Default is TRUE.
- save.model
DEPRECATED
save.model = FALSE is no longer supported; this function will always return a saved model.
- wavelengths
DEPRECATED
wavelengths is no longer supported; this information is now inferred from df column names.
- autoselect.preprocessing
DEPRECATED
autoselect.preprocessing = FALSE is no longer supported. If multiple pretreatment methods are supplied, the best will be automatically selected as the model to be saved.
- preprocessing.method
DEPRECATED
preprocessing.method has been renamed "pretreatment".
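The column layout required of df can be checked before calling save_model(). The following is a minimal sketch in base R, not part of waves: check_spectra_format() is a hypothetical helper, and the toy data frame is invented purely for illustration.

```r
# Hypothetical helper (not part of waves): checks the column layout that
# the `df` argument is documented to require and returns any problems found.
check_spectra_format <- function(df) {
  stopifnot(is.data.frame(df), ncol(df) >= 3)
  problems <- character(0)
  # First column: unique identifiers
  if (anyDuplicated(df[[1]]) > 0) {
    problems <- c(problems, "first column contains duplicated identifiers")
  }
  # Second column: reference values, named "reference"
  if (names(df)[2] != "reference") {
    problems <- c(problems, "second column is not named 'reference'")
  }
  # Remaining columns: spectra only, with names starting with "X"
  spectra.names <- names(df)[-(1:2)]
  if (!all(startsWith(spectra.names, "X"))) {
    problems <- c(problems, "non-spectral columns found to the right of 'reference'")
  }
  problems
}

# Invented toy data in the documented layout
toy <- data.frame(
  unique.id = c("a", "b", "c"),
  reference = c(30.1, 28.4, 33.0),
  X350 = c(0.41, 0.39, 0.44),
  X351 = c(0.42, 0.40, 0.45),
  stringsAsFactors = FALSE
)

check_spectra_format(toy)                      # no problems reported
check_spectra_format(cbind(toy, notes = "x"))  # flags the extra column
```

A check like this returns an empty character vector when the layout is valid, so it can be wrapped in stopifnot(length(...) == 0) ahead of a long training run.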
Value
A list containing the model statistics (best.model.stats, a data.frame) and
the trained model object (best.model).
If the parameter write.model is TRUE, both objects are also saved to
model.save.folder. To use the optimally trained model for
predictions, use the tuned parameters from $bestTune.
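Reloading the two saved files in a later session might look like the sketch below. It assumes the file names follow the pattern given in the description (model.name.Rds and model.name_stats.csv). Because running save_model() itself requires waves and spectral data, a small lm() fit and a one-row table stand in for the real outputs; the folder and model name are placeholders.

```r
# Stand-ins for the two objects save_model() returns; a real call would
# supply a trained model and its statistics table instead.
stand.in.model <- lm(dist ~ speed, data = cars)
stand.in.stats <- data.frame(Pretreatment = "Raw_data",
                             RMSEp_mean = NA_real_,
                             stringsAsFactors = FALSE)

model.save.folder <- tempdir()       # placeholder folder
model.name <- "my_prediction_model"  # placeholder name

# File names following the documented pattern
model.path <- file.path(model.save.folder, paste0(model.name, ".Rds"))
stats.path <- file.path(model.save.folder, paste0(model.name, "_stats.csv"))

# Write the stand-ins so that the reload step below has files to read
saveRDS(stand.in.model, model.path)
write.csv(stand.in.stats, stats.path, row.names = FALSE)

# Reload in a later session
restored.model <- readRDS(model.path)
restored.stats <- read.csv(stats.path)
```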
Details
Wrapper that uses pretreat_spectra,
format_cv, and train_spectra functions.
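Once every pretreatment has been trained, the wrapper has to choose one according to best.model.metric. The sketch below illustrates that selection step using the RMSEp_mean and R2p_mean values from the training summary printed in the Examples section. pick_best_pretreatment() is a hypothetical helper, and the use of these two columns as the selection criteria is an assumption rather than something taken from the waves source.

```r
# Values copied from the training summary in the Examples section
stats <- data.frame(
  Pretreatment = c("Raw_data", "SNV", "SNV1D", "SNV2D", "D1", "D2", "SG",
                   "SNVSG", "SGD1", "SG.D1W5", "SG.D1W11", "SG.D2W5",
                   "SG.D2W11"),
  RMSEp_mean = c(2.30, 1.83, 2.88, 4.28, 2.76, 4.28, 2.30,
                 1.80, 2.37, 2.33, 2.35, 4.07, 3.15),
  R2p_mean = c(0.717, 0.820, 0.559, 0.0717, 0.591, 0.0728, 0.716,
               0.828, 0.706, 0.721, 0.713, 0.0941, 0.497),
  stringsAsFactors = FALSE
)

# Hypothetical helper: lowest RMSE or highest R-squared wins (assumed rule)
pick_best_pretreatment <- function(stats, best.model.metric = "RMSE") {
  best.model.metric <- match.arg(best.model.metric, c("RMSE", "Rsquared"))
  best.row <- if (best.model.metric == "RMSE") {
    which.min(stats$RMSEp_mean)
  } else {
    which.max(stats$R2p_mean)
  }
  stats$Pretreatment[best.row]
}

pick_best_pretreatment(stats, "RMSE")
pick_best_pretreatment(stats, "Rsquared")
```

For this summary table both metrics point to the same row, which agrees with the "Best pretreatment technique" line in the example output.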
Author
Jenna Hershberger jmh579@cornell.edu
Examples
# \donttest{
library(magrittr)
test.model <- ikeogu.2017 %>%
  dplyr::filter(study.name == "C16Mcal") %>%
  dplyr::rename(reference = DMC.oven,
                unique.id = sample.id) %>%
  dplyr::select(unique.id, reference, dplyr::starts_with("X")) %>%
  na.omit() %>%
  save_model(
    df = .,
    write.model = FALSE,
    pretreatment = 1:13,
    model.name = "my_prediction_model",
    tune.length = 3,
    num.iterations = 3
  )
#> Warning: The `save.model` argument of `test_spectra()` is deprecated as of waves 0.2.0.
#> ℹ Models are now saved by default.
#> ℹ The deprecated feature was likely used in the waves package.
#> Please report the issue at <https://github.com/GoreLab/waves/issues>.
#> Warning: The `return.model` argument of `test_spectra()` is deprecated as of waves
#> 0.2.0.
#> ℹ Trained models are now returned by default.
#> ℹ The deprecated feature was likely used in the waves package.
#> Please report the issue at <https://github.com/GoreLab/waves/issues>.
#> Pretreatment initiated.
#> Training models...
#> Working on Raw_data
#> Warning: The `wavelengths` argument of `train_spectra()` is deprecated as of waves
#> 0.2.0.
#> ℹ Wavelength specification is now inferred from column names.
#> ℹ The deprecated feature was likely used in the waves package.
#> Please report the issue at <https://github.com/GoreLab/waves/issues>.
#> Warning: The `preprocessing` argument of `train_spectra()` is deprecated as of waves
#> 0.2.0.
#> ℹ Argument `preprocessing` is deprecated. Use `pretreatment` instead:
#> `pretreatment = 1:13` (all), or `pretreatment = 1` (raw only).
#> ℹ The deprecated feature was likely used in the waves package.
#> Please report the issue at <https://github.com/GoreLab/waves/issues>.
#> Warning: The `save.model` argument of `train_spectra()` is deprecated as of waves 0.2.0.
#> ℹ Models are now saved by default.
#> ℹ The deprecated feature was likely used in the waves package.
#> Please report the issue at <https://github.com/GoreLab/waves/issues>.
#> Loading required package: lattice
#>
#> Attaching package: ‘pls’
#> The following object is masked from ‘package:caret’:
#>
#> R2
#> The following object is masked from ‘package:stats’:
#>
#> loadings
#> Returning model...
#> Working on SNV
#> Returning model...
#> Working on SNV1D
#> Returning model...
#> Working on SNV2D
#> Returning model...
#> Working on D1
#> Returning model...
#> Working on D2
#> Returning model...
#> Working on SG
#> Returning model...
#> Working on SNVSG
#> Returning model...
#> Working on SGD1
#> Returning model...
#> Working on SG.D1W5
#> Returning model...
#> Working on SG.D1W11
#> Returning model...
#> Working on SG.D2W5
#> Returning model...
#> Working on SG.D2W11
#> Returning model...
#>
#> Training Summary:
#> # A tibble: 13 × 40
#> Pretreatment RMSEp_mean R2p_mean RPD_mean RPIQ_mean CCC_mean Bias_mean
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 Raw_data 2.30 0.717 1.86 2.41 0.817 0.299
#> 2 SNV 1.83 0.820 2.34 3.03 0.888 0.190
#> 3 SNV1D 2.88 0.559 1.47 1.90 0.717 0.351
#> 4 SNV2D 4.28 0.0717 0.988 1.28 0.205 0.635
#> 5 D1 2.76 0.591 1.53 1.99 0.738 0.463
#> 6 D2 4.28 0.0728 0.988 1.28 0.206 0.658
#> 7 SG 2.30 0.716 1.86 2.41 0.817 0.299
#> 8 SNVSG 1.80 0.828 2.39 3.09 0.893 0.167
#> 9 SGD1 2.37 0.706 1.81 2.34 0.807 0.370
#> 10 SG.D1W5 2.33 0.721 1.83 2.37 0.814 0.460
#> 11 SG.D1W11 2.35 0.713 1.82 2.35 0.811 0.395
#> 12 SG.D2W5 4.07 0.0941 1.04 1.35 0.225 0.370
#> 13 SG.D2W11 3.15 0.497 1.34 1.74 0.635 0.850
#> # ℹ 33 more variables: SEP_mean <dbl>, RMSEcv_mean <dbl>, R2cv_mean <dbl>,
#> # R2sp_mean <dbl>, best.ncomp_mean <dbl>, best.ntree_mean <dbl>,
#> # best.mtry_mean <dbl>, RMSEp_sd <dbl>, R2p_sd <dbl>, RPD_sd <dbl>,
#> # RPIQ_sd <dbl>, CCC_sd <dbl>, Bias_sd <dbl>, SEP_sd <dbl>, RMSEcv_sd <dbl>,
#> # R2cv_sd <dbl>, R2sp_sd <dbl>, best.ncomp_sd <dbl>, best.ntree_sd <dbl>,
#> # best.mtry_sd <dbl>, RMSEp_mode <dbl>, R2p_mode <dbl>, RPD_mode <dbl>,
#> # RPIQ_mode <dbl>, CCC_mode <dbl>, Bias_mode <dbl>, SEP_mode <dbl>, …
#>
#> Best pretreatment technique: SNVSG
summary(test.model$best.model)
#> Data: X dimension: 120 2141
#> Y dimension: 120 1
#> Fit method: kernelpls
#> Number of components considered: 3
#> TRAINING: % variance explained
#> 1 comps 2 comps 3 comps
#> X 64.48 87.94 91.43
#> reference 33.93 64.97 87.18
test.model$best.model.stats
#> # A tibble: 1 × 40
#> Pretreatment RMSEp_mean R2p_mean RPD_mean RPIQ_mean CCC_mean Bias_mean
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 SNVSG 1.80 0.828 2.39 3.09 0.893 0.167
#> # ℹ 33 more variables: SEP_mean <dbl>, RMSEcv_mean <dbl>, R2cv_mean <dbl>,
#> # R2sp_mean <dbl>, best.ncomp_mean <dbl>, best.ntree_mean <dbl>,
#> # best.mtry_mean <dbl>, RMSEp_sd <dbl>, R2p_sd <dbl>, RPD_sd <dbl>,
#> # RPIQ_sd <dbl>, CCC_sd <dbl>, Bias_sd <dbl>, SEP_sd <dbl>, RMSEcv_sd <dbl>,
#> # R2cv_sd <dbl>, R2sp_sd <dbl>, best.ncomp_sd <dbl>, best.ntree_sd <dbl>,
#> # best.mtry_sd <dbl>, RMSEp_mode <dbl>, R2p_mode <dbl>, RPD_mode <dbl>,
#> # RPIQ_mode <dbl>, CCC_mode <dbl>, Bias_mode <dbl>, SEP_mode <dbl>, …
# }