---
title: "ML for Microbial Genomics"
author: "Janani Ravi | jravilab.github.io"
date: "`r Sys.Date()`"
format:
  html:
    toc: true
    number-sections: true
    code-fold: show
---
# MLHD @ICTS \| Aug 02, 2023
> This is a companion repo & webpage for the Microbial Genomics and ML workshop, first presented at the MLHD 2023 conference! You can access the material here: <https://jananiravi.github.io/2023-mlhd>
## Overview
> This session will cover the ideas, concepts, and insights needed to get started with building machine learning models in R on high-dimensional data, such as microbial genomics data. No prior knowledge of ML is required.
### Acknowledgments
- JRaviLab: Jacob Krol, Ethan Wolfe, Evan Brenner, Keenan Manpearl, Joseph Burke, Vignesh Sridhar, Jill Bilodeaux (contributed to the antimicrobial resistance project)
- Arjun Krishnan (contributed to the tidymodels qmd primer)
- R-Ladies, esp. R-Ladies East Lansing, R-Ladies Aurora; R/Bioconductor; rOpenSci (for all things R!)
- `tidymodels` resources by Julia Silge et al. \| <https://tidymodels.org>
## Install and load packages
To use the code in this document, you will need to install the following packages: `glmnet`, `ranger`, `tidymodels`, `tidyverse`, and `vip`.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
library(tidyverse)  # data wrangling & plotting
library(tidymodels) # modeling framework
library(glmnet)     # engine for penalized logistic regression
library(vip)        # to extract important features
library(ranger)     # engine for random forests
```
## Explore your data
Here, we will use microbial genomics data (e.g., gene presence/absence across multiple microbial genomes) wrangled and processed from the [BV-BRC](https://bv-brc.org/) to predict the antibiotic resistance phenotype of each sample (genome) based on the presence/absence of genes in that sample.
To make the dataset usable on your local desktop machine, we have pre-processed the data (using custom scripts that use NCBI/BV-BRC data and metadata, NCBI and BV-BRC CLI, Prokka for genome annotation, and Roary/CD-HIT for constructing the gene presence/absence matrix and gene clusters that serve as ML features). For this workshop, we have selected a subset of \~900 genomes from *Staphylococcus aureus*, and limited the data to `n` genes after filtering out core (present in \>95% of genomes) and unique (present in \<5% of genomes) genes.
The data is contained in the file `data/staph_penicillin_pangenome.csv`, with samples (genomes) along the rows and genes along the columns. The file also carries relevant metadata about the genomes and drugs. To get started, let's read this data into R using the `readr::read_delim` function.
### Read in the data file
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
# Can be set to read csv/tsv: any feature matrix file with metadata
# e.g., gpa-feature-matrix.tsv
gpa_featmat <- read_delim("data/staph_penicillin_pangenome.csv",
                          delim = ",", col_names = TRUE)
```
### Data exploration
Let's print the tibble and check its dimensions to get a quick feel for the data.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
gpa_featmat
dim(gpa_featmat)
```
Then, let's examine the `amr_pheno` column of this data frame, which tells us the antimicrobial resistance (AMR) phenotype (resistant/susceptible) of each sample (i.e., each row, one genome) for different drugs. We can tabulate the number and fraction of genomes per phenotype easily using the `count` and `mutate` functions from `dplyr`.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
gpa_featmat %>%
count(amr_pheno) %>%
mutate(prop = n/sum(n))
```
Before we proceed, let's also get a sense of the values in this feature matrix. Since there are thousands of genes, let's count how many genomes each gene is present in and visualize the distribution of those per-gene frequencies as a binned histogram.
```{r eval=TRUE, echo=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
gpa_sum <- gpa_featmat |>
  select(7:last_col()) |> # skip the first six metadata columns
  summarize(across(where(is.numeric), sum)) # no. of genomes carrying each gene
gpa_sum_long <- gpa_sum |>
  pivot_longer(cols = everything(), names_to = "gene")
ggplot(gpa_sum_long, aes(value)) +
  # geom_histogram(bins=10) +
  geom_bar() +
  scale_x_binned() + # bin the per-gene frequencies
  theme_minimal() +
  xlab("Number of genomes a gene is present in") +
  ylab("Number of genes")
```
## Feature matrices --\> ML
To keep the problem simple, we pick one drug of interest (penicillin) and define the problem as classifying whether a genome is resistant or susceptible to this antibiotic.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
pos_pheno <- "Resistant"
```
Then, we need to recode the `amr_pheno` variable as a binary indicator of whether a genome is resistant or not, and finally convert that variable into a factor so that the model knows to treat it as a class label for partitioning the samples.
### Set up the feature matrix and labels for the ML model
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
gpa_featmat_pheno <-
gpa_featmat %>%
mutate(amr_pheno = ifelse(amr_pheno==pos_pheno,
"Resistant", "Susceptible")) %>%
mutate(across(where(is.character), as.factor))
```
A critical quantity to be fully aware of when setting up an ML problem is class balance, i.e., the relative sizes of the positive (`"Resistant"`) and negative (`"Susceptible"`) classes.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
gpa_featmat_pheno %>%
count(amr_pheno) %>%
mutate(prop = n/sum(n))
```
We can see that, in our dataset, only xx% of the samples are "Resistant". Referred to as *class imbalance*, this scenario is extremely common in biomedicine and needs careful attention when analyzing and interpreting results.
### Data splitting
If we take the data from all samples and train an *AMR classification* ML model, we cannot easily tell how good the model is. So, let's reserve 25% of the samples as a *test set*, which we will hold out until the end of the project, at which point there should only be one or two models under serious consideration. The *test set* will be used as an unbiased source for measuring final model performance.
This is also the first step where we need to pay attention to class balance. We use *stratified* random samples so that both the splits contain nearly identical proportions of positive and negative samples.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
# The function `initial_split()` takes the original data and saves the information on how to make the partitions.
set.seed(123)
splits <- initial_split(data = gpa_featmat_pheno,
strata = amr_pheno)
# Within initial_split, you can specify proportion using "prop" and
# grouping/datasets to go into the same set using "group"
# The `training()` and `testing()` functions return the actual datasets.
gpa_other <- training(splits)
gpa_test <- testing(splits)
```
Let's check if we indeed did achieve stratified data splits.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
# other set proportions by AMR pheno
gpa_other %>%
count(amr_pheno) %>%
mutate(prop = n/sum(n))
```
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
# test set proportions by AMR pheno
gpa_test %>%
count(amr_pheno) %>%
mutate(prop = n/sum(n))
```
What's up with the `gpa_other` split, i.e., the part that isn't for testing? This split will be used to create two new datasets:
1. The set held out for the purpose of measuring performance, called the *validation set*, and
2. The remaining data used to fit the model, called the *training set*.
We'll use the `validation_split` function to allocate 20% of the `gpa_other` samples to the validation set and the remaining 80% to the training set. Note that this function too has the `strata` argument. Do you see why we need it here?
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
set.seed(234)
gpa_val <- validation_split(data = gpa_other,
                            strata = amr_pheno, # preserve class proportions
                            prop = 0.80) # 80% training; 20% validation
gpa_val
```
### Training ML models in R: Penalized logistic regression
Since our outcome variable `amr_pheno` is categorical, [logistic regression](https://en.wikipedia.org/wiki/Logistic_regression) would be a good first model to start with. Let's use a model that can perform feature selection during training. The [glmnet](https://cran.r-project.org/web/packages/glmnet/index.html) R package fits a generalized linear model via penalized maximum likelihood. This method of estimating the logistic regression slope parameters uses a *penalty* on the process so that the coefficients of less relevant predictors are driven towards zero. One of the `glmnet` penalization methods, called the [lasso method](https://en.wikipedia.org/wiki/Lasso_(statistics)), can set the predictor slopes exactly to zero if a large enough penalty is used.
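Concretely, with $\lambda$ as the `penalty` hyperparameter we tune below, glmnet's lasso (`mixture = 1`) estimates the intercept $\beta_0$ and coefficients $\beta$ by minimizing the penalized negative log-likelihood (for a binary outcome $y_i \in \{0, 1\}$ and gene presence/absence vector $x_i$):

$$
\min_{\beta_0,\,\beta}\; -\frac{1}{n}\sum_{i=1}^{n}\Big[\, y_i\,(\beta_0 + x_i^\top\beta) - \log\big(1 + e^{\beta_0 + x_i^\top\beta}\big) \Big] \;+\; \lambda \sum_{j=1}^{p} \lvert \beta_j \rvert
$$

Larger $\lambda$ values push more $\beta_j$ exactly to zero, which is what performs the feature selection.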
### Build the model
To specify a penalized logistic regression model that uses a feature selection penalty, we will use the `parsnip` package (part of `tidymodels`), which provides a tidy, unified interface to models so you can try a range of them without getting bogged down in the syntactical minutiae of the underlying packages.
Here, let's use it with the [glmnet engine](https://www.tidymodels.org/find/parsnip/):
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
# Build logistic regression model
lr_model <-
logistic_reg(penalty = tune(), # strength of regularization/penalty
mixture = 1) %>% # specifies a pure lasso model
set_engine("glmnet") # set to generalized linear models
```
We'll set the `penalty` argument to `tune()` as a placeholder for now. This is a model *hyperparameter* that we will [tune](https://www.tidymodels.org/start/tuning/) to find the best value for making predictions with our data. Setting `mixture` to a value of `1` means that the glmnet model will potentially remove irrelevant predictors and choose a simpler model: the lasso penalizes the sum of the absolute values of the beta coefficients, which can shrink some of them all the way to zero.
*You can try `mixture = 0` for L2 ridge regression, or a value between 0 and 1 for an elastic net combining L1 and L2.*
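For reference, here is a minimal sketch of those variants (not evaluated here; the `lr_model_ridge`/`lr_model_enet` names are illustrative): only the `mixture` argument changes.

```{r echo=TRUE, eval=FALSE, message=FALSE, warning=FALSE}
#| echo: true
#| output: false
# Ridge: pure L2 penalty (shrinks coefficients but does not zero them out)
lr_model_ridge <- logistic_reg(penalty = tune(), mixture = 0) %>%
  set_engine("glmnet")
# Elastic net: blend of L1 and L2 (here an even 50/50 mix)
lr_model_enet <- logistic_reg(penalty = tune(), mixture = 0.5) %>%
  set_engine("glmnet")
```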
### Create the recipe
Next, we're going to use the `recipes` package to build [dplyr](https://dplyr.tidyverse.org/)-like pipeable sequences of feature engineering steps to get our data ready for modeling. Recipes are built as a series of pre-processing steps, such as:
- converting qualitative predictors to indicator variables (also known as dummy variables),
- transforming data to be on a different scale (e.g., taking the logarithm of a variable),
- transforming whole groups of predictors together,
- extracting key features from raw variables (e.g., getting the day of the week out of a date variable),
and so on. Here, we're using it to set up the outcome variable as a function of gene presence and then do two things:
- `step_zv()` removes predictors that contain only a single unique value (e.g., all zeros), which carry no information and would break centering and scaling.
- `step_normalize()` centers and scales numeric variables. This matters for penalized models, where predictors should be on a common scale so the penalty treats them equally.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
lr_recipe <-
recipe(amr_pheno ~ ., data = gpa_other) %>% # specify data + labels
update_role(c(s_no, genome_id, assembly_accession, # genome attributes
antibiotic, drug_class), # drug attributes
new_role = "Supplementary") %>% # tag metadata not used for ML
step_zv(all_predictors()) %>% # remove predictors with only one value
# step_nzv(all_predictors()) # for near-zero variance
step_normalize(all_predictors()) # normalize all predictors
```
*Try with `step_nzv` instead of only `step_zv`.*
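Here is a sketch of that variant (the `lr_recipe_nzv` name is illustrative); note that `step_nzv()` drops near-zero-variance predictors as well as zero-variance ones, so it replaces `step_zv()` rather than complementing it:

```{r echo=TRUE, eval=FALSE, message=FALSE, warning=FALSE}
#| echo: true
#| output: false
lr_recipe_nzv <-
  recipe(amr_pheno ~ ., data = gpa_other) %>%
  update_role(c(s_no, genome_id, assembly_accession,
                antibiotic, drug_class),
              new_role = "Supplementary") %>%
  step_nzv(all_predictors()) %>% # zero- and near-zero-variance filter
  step_normalize(all_predictors())
```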
### Create the workflow
Let's bundle the model and recipe into a single `workflow()` object to make management of the R objects easier:
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
# Bundle the LR model with the recipe defined above
lr_workflow <- workflow() %>%
add_model(lr_model) %>%
add_recipe(lr_recipe)
```
### Create the grid for tuning
Before we fit this model, we need to set up a grid of `penalty` values to tune.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
# Try penalty values from 0.0001 to 0.1 to penalize complex models;
# larger penalties leave fewer features with non-zero coefficients
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 10))
lr_reg_grid
```
### Train and tune the model
The `tune::tune_grid()` function will help us train these 10 penalized logistic regression models and save the validation set predictions (via the call to `control_grid()`) so that diagnostic information is available after fitting the model. To quantify how well the model performs (on the *validation set*), let's first consider the [area under the ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) across the range of hyperparameters.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
lr_res <-
lr_workflow %>%
tune_grid(resamples = gpa_val, # using validation split
grid = lr_reg_grid,
control = control_grid(save_pred = TRUE),
metrics = metric_set(roc_auc))
#metrics = metric_set(pr_auc)) # if you want to optimize for AUPRC instead
```
#### Tune the model with cross-validation instead?
```{r echo=TRUE, eval=FALSE, message=FALSE, warning=FALSE}
#| echo: true
#| output: false
lr_res_cv <-
lr_workflow %>%
tune_grid(resamples = vfold_cv(gpa_other), # 10-fold CV instead of one validation split
grid = lr_reg_grid,
control = control_grid(save_pred = TRUE),
metrics = metric_set(roc_auc))
#metrics = metric_set(pr_auc)) # if you want to optimize for AUPRC instead
```
## Evaluation metrics
A plot of the area under the ROC curve against the range of penalty values will help us guess which value is best for the problem/dataset at hand.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
lr_plot <- lr_res %>%
collect_metrics() %>%
ggplot(aes(x = penalty, y = mean)) +
geom_point() +
geom_line() +
ylab("Area under the ROC Curve") +
#ylab("Area under the PR Curve") +
scale_x_log10(labels = scales::label_number()) +
theme_bw()
lr_plot
```
What is your interpretation of this plot? Write it here.
We can also tabulate these results to help pick the "best" hyperparameter.
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
top_models <-
lr_res %>%
show_best("roc_auc", n = 10) %>%
arrange(penalty)
top_models
```
Let's select the best value and visualize the validation set ROC curve. Why are we picking the 6^th^ value instead of the 1^st^ even though they have nearly identical performance metrics?
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
lr_best <- lr_res %>%
collect_metrics() %>%
arrange(penalty, mean) %>%
slice(6)
# Alternatively, you can just use
lr_best <- lr_res |>
select_best(metric = "roc_auc")
lr_best
```
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
lr_roc <- lr_res %>%
collect_predictions(parameters = lr_best) %>%
roc_curve(amr_pheno, .pred_Resistant) %>%
mutate(model = "Logistic Regression")
autoplot(lr_roc)
## Alternatively ...
# Select the best LR model
final_lr_model <- finalize_workflow(lr_workflow, lr_best)
# Fit the data
lr_fit <- final_lr_model %>% fit(data = gpa_other)
# Save predictions
lr_aug <- augment(lr_fit, gpa_test)
# Calculate AUROC
auroc <- lr_aug %>%
  roc_auc(truth = amr_pheno, .pred_Resistant) %>%
  pull(.estimate)
print(paste("AUROC:", auroc))
```
The area under the ROC curve has a nice property that it can be interpreted as a probability and has a close connection to a statistical test ([the Mann-Whitney U test](https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U)).
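As a quick illustration (a sketch reusing the `lr_aug` test-set predictions from the chunk above), the AUROC is the probability that a randomly chosen Resistant genome receives a higher predicted score than a randomly chosen Susceptible one, i.e., the normalized Mann-Whitney U statistic:

```{r echo=TRUE, eval=FALSE, message=FALSE, warning=FALSE}
#| echo: true
#| output: false
pos <- lr_aug %>% filter(amr_pheno == "Resistant") %>% pull(.pred_Resistant)
neg <- lr_aug %>% filter(amr_pheno == "Susceptible") %>% pull(.pred_Resistant)
# Fraction of (Resistant, Susceptible) pairs ranked correctly (ties count 1/2);
# this should match the roc_auc() estimate above
mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))
```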
### Selecting the top features
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
# Extract top 10 genes
n_top_genes <- 10
top_genes <- lr_fit %>% extract_fit_parsnip() %>%
  vip::vi() %>% slice(1:n_top_genes) %>%
  pull(Variable)
print(top_genes)
```
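If you prefer a plot over a vector of gene names, `vip::vip()` draws the same importance scores as a bar chart (a minimal sketch using the fitted model from above):

```{r echo=TRUE, eval=FALSE, message=FALSE, warning=FALSE}
#| echo: true
#| output: false
lr_fit %>%
  extract_fit_parsnip() %>%
  vip::vip(num_features = n_top_genes) # top genes by absolute coefficient
```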
### When you have imbalanced classes
However, this measure is not sensitive to class imbalance: it can come out high even if the model makes many mistakes on the minority positive class (which is typically the class of biomedical interest) while getting most of the majority negative class correct.
So, the final analysis we're going to do is to evaluate performance based on another metric, the [area under the Precision-Recall curve](https://en.wikipedia.org/wiki/Precision_and_recall), which is more sensitive to the minority positive class: it focuses on the fraction of positive predictions that are correct (precision) and the fraction of positive samples that are correctly predicted (recall).
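For reference, in terms of true positives (TP), false positives (FP), and false negatives (FN):

$$
\text{Precision} = \frac{TP}{TP + FP}, \qquad \text{Recall} = \frac{TP}{TP + FN}
$$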
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
lr_res_pr <- lr_workflow %>%
tune_grid(resamples = gpa_val,
grid = lr_reg_grid,
control = control_grid(save_pred = TRUE),
metrics = metric_set(pr_auc))
```
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
lr_plot_pr <- lr_res_pr %>%
collect_metrics() %>%
ggplot(aes(x = penalty, y = mean)) +
geom_point() +
geom_line() +
ylab("Area under the PR Curve") +
scale_x_log10(labels = scales::label_number()) +
theme_bw()
lr_plot_pr
```
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
lr_best_pr <- lr_res_pr %>%
collect_metrics() %>%
arrange(penalty) %>%
slice(6)
# Alternatively, you can just use
lr_best_pr <- lr_res_pr |>
select_best(metric = "pr_auc")
lr_best_pr
```
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE}
#| echo: true
#| output: true
lr_pr <- lr_res_pr %>%
collect_predictions(parameters = lr_best_pr) %>%
pr_curve(amr_pheno, .pred_Resistant) %>%
mutate(model = "Logistic Regression")
autoplot(lr_pr)
```
#### Retrieving your top features
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE, render=FALSE}
#| echo: true
#| output: true
# Select best LR model
best_lr_model_pr <- select_best(lr_res_pr, metric = "pr_auc")
final_lr_model_pr <- finalize_workflow(lr_workflow, best_lr_model_pr)
# Fit the data
lr_fit_pr <- final_lr_model_pr %>%
fit(data = gpa_other)
# Save predictions
lr_aug_pr <- augment(lr_fit_pr, gpa_test)
# Get AUPRC
auprc <- lr_aug_pr %>%
  pr_auc(truth = amr_pheno, .pred_Resistant) %>%
  pull(.estimate)
print(paste("AUPRC:", auprc))
## Extract top 10 genes
n_top_genes <- 10
top_genes_pr <- lr_fit_pr |>
  extract_fit_parsnip() |>
  vip::vi() |>
  slice(1:n_top_genes) |>
  pull(Variable)
print(top_genes_pr)
```
## Predicting AMR with Random Forest
```{r echo=TRUE, eval=TRUE, message=FALSE, warning=FALSE, render=FALSE}
#| echo: true
#| output: true
# Setting a seed enables our analysis to be reproducible when random numbers are used.
set.seed(569)
rf_splits <- initial_split(gpa_featmat_pheno,
#prop = train_test_split,
strata = amr_pheno)
# Create separate data frames for the training and testing sets.
gpa_train <- training(rf_splits)
gpa_test <- testing(rf_splits)
set.seed(234)
gpa_val <- validation_split(data = gpa_train,
                            strata = amr_pheno, # preserve class proportions
                            prop = 0.80) # 80% training; 20% validation
gpa_val
# Create recipe
rf_recipe <- recipe(amr_pheno ~ ., data = gpa_train) %>%
# To keep these columns but not use them as predictors or outcome
update_role(c(s_no, genome_id, assembly_accession, # genome attributes
antibiotic, drug_class), # drug attributes
new_role = "Supplementary") %>%
step_zv(all_predictors()) %>% # remove predictors with only one value
# step_nzv(all_predictors()) # for near-zero variance
step_normalize(all_predictors()) # normalize all predictors
# Build random forest model; mark mtry and min_n for tuning
num_trees <- 1000
rf_model <- rand_forest(trees = num_trees,
                        mtry = tune(), # no. of predictors sampled at each split
                        min_n = tune()) %>% # min. node size
  set_engine("ranger", importance = "impurity") %>%
  set_mode("classification")
# Create workflow
rf_workflow <- workflow() %>%
add_model(rf_model) %>%
add_recipe(rf_recipe)
# Specify the hyperparameter tuning grid;
# mtry must be a count of predictors (example values; scale to your no. of genes)
rf_grid <- crossing(mtry = c(5, 50, 500), min_n = c(2, 6, 12))
# Tune the model using cross-validation;
# try the 9 hyperparameter combinations; use AUROC as the evaluation metric.
rf_res <- tune_grid(rf_workflow,
resamples = vfold_cv(gpa_train),
grid = rf_grid,
control = control_grid(save_pred = T),
metrics = metric_set(roc_auc))
rf_best <- rf_res |>
select_best(metric = "roc_auc")
# Plot AUROC
rf_roc <- rf_res %>%
collect_predictions(parameters = rf_best) %>%
roc_curve(amr_pheno, .pred_Resistant) %>%
mutate(model = "Logistic Regression")
autoplot(rf_roc)
# Select best RF model
best_rf_model <- select_best(rf_res, metric = "roc_auc")
final_rf_model <- finalize_workflow(rf_workflow, best_rf_model)
# Fit the data
rf_fit <- final_rf_model %>% fit(data = gpa_train)
# Save predictions
rf_aug <- augment(rf_fit, gpa_test)
# Get AUROC
auroc <- rf_aug %>%
  roc_auc(truth = amr_pheno, .pred_Resistant) %>%
  pull(.estimate)
print(paste("AUROC:", auroc))
# Extract top 10 genes
n_top_genes <- 10
top_genes_rf <- rf_fit %>%
  extract_fit_parsnip() %>%
  vip::vi() %>% slice(1:n_top_genes) %>%
  pull(Variable)
print(top_genes_rf)
```
## Too many features?
Try dimensionality reduction with SVD --\> retrieve top PCs --\> find contributing features to the top PCs.
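A minimal sketch of that idea with `recipes` (the chunk is illustrative and not evaluated: `num_comp = 50` is an arbitrary choice, and `step_pca()` computes the components via PCA, i.e., SVD of the centered and scaled matrix):

```{r echo=TRUE, eval=FALSE, message=FALSE, warning=FALSE}
#| echo: true
#| output: false
pca_recipe <- recipe(amr_pheno ~ ., data = gpa_train) %>%
  update_role(c(s_no, genome_id, assembly_accession,
                antibiotic, drug_class),
              new_role = "Supplementary") %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors()) %>%
  step_pca(all_predictors(), num_comp = 50) # keep the top 50 PCs
pca_prep <- prep(pca_recipe)
tidy(pca_prep, number = 3) # step 3 = step_pca: gene loadings per component
```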
## Recap & Conclusions
- [x] Reproducible docs & code with `qmd`/`rmd`
- [x] Basic data cleanup to get the data ready for ML models
- [x] `tidymodels`
- [x] Building recipes and workflows
- [x] Calculating AUROC and AUPRC
- [x] Train-validate-test splits to optimize for the best hyperparameters
- [x] Picking the best models based on low penalty and high AUROC/AUPRC
- [x] Plotting AUROC/AUPRC
- [x] Logistic regression with L1 lasso regularization (and L2)
- [x] Random forest models
## How to contact us
- Website: <https://jravilab.github.io>
- Twitter: \@jravilab \@janani137
- Email: janani DOT ravi AT cuanschutz DOT edu
- Rendered material: <https://jananiravi.github.io/2023-mlhd>