library(psptools)
library(pspdata)
library(dplyr)
library(ggplot2)
Class Imbalance
Class imbalance in training data using binary and multiclass toxicity bins
psp <- read_psp_data(fix_species = TRUE) |>
  mutate(year = format(date, format = "%Y")) |>
  filter(species == "mytilus")
Binary
Predicting the probability of toxicity above or below the regulatory closure limit of 80 μg/100 g.
cfg <- list(
  configuration = "test",
  image_list = list(tox_levels = c(0, 80),
                    forecast_steps = 1,
                    n_steps = 3,
                    minimum_gap = 4,
                    maximum_gap = 10,
                    multisample_weeks = "last",
                    toxins = c("gtx4", "gtx1", "dcgtx3", "gtx5", "dcgtx2", "gtx3",
                               "gtx2", "neo", "dcstx", "stx", "c1", "c2")),
  model = list(balance_val_set = FALSE,
               downsample = FALSE,
               use_class_weights = FALSE,
               dropout1 = 0.3,
               dropout2 = 0.3,
               batch_size = 32,
               units1 = 32,
               units2 = 32,
               epochs = 128,
               validation_split = 0.2,
               shuffle = TRUE,
               num_classes = 4,
               optimizer = "adam",
               loss_function = "categorical_crossentropy",
               model_metrics = c("categorical_accuracy")),
  train_test = list(split_by = "year",
                    train = c("2015", "2016", "2017", "2018", "2019", "2020", "2021"),
                    test = c("2014"))
)

binary <- transform_data(cfg, psp)
tibble(location = binary$train$locations, class = binary$train$classifications) |>
  ggplot(aes(x = class)) +
  geom_bar()
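The bar chart shows that the below-limit class dominates. To put numbers on the imbalance rather than eyeballing it, we can tabulate the class counts and proportions; this is a quick check with dplyr, not part of the psptools workflow:

# Quick check: count and proportion of each binary class in the training set
tibble(class = binary$train$classifications) |>
  count(class) |>
  mutate(prop = n / sum(n))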
Multiclass
Predicting one of more than two toxicity classes. Here we use 0, 10, 30, and 80 as bin cutoffs, yielding four classes.
cfg <- list(
  configuration = "test",
  image_list = list(tox_levels = c(0, 10, 30, 80),
                    forecast_steps = 1,
                    n_steps = 3,
                    minimum_gap = 4,
                    maximum_gap = 10,
                    multisample_weeks = "last",
                    toxins = c("gtx4", "gtx1", "dcgtx3", "gtx5", "dcgtx2", "gtx3",
                               "gtx2", "neo", "dcstx", "stx", "c1", "c2")),
  model = list(balance_val_set = FALSE,
               downsample = FALSE,
               use_class_weights = FALSE,
               dropout1 = 0.3,
               dropout2 = 0.3,
               batch_size = 32,
               units1 = 32,
               units2 = 32,
               epochs = 128,
               validation_split = 0.2,
               shuffle = TRUE,
               num_classes = 4,
               optimizer = "adam",
               loss_function = "categorical_crossentropy",
               model_metrics = c("categorical_accuracy")),
  train_test = list(split_by = "year",
                    train = c("2015", "2016", "2017", "2018", "2019", "2020", "2021"),
                    test = c("2014"))
)

multiclass <- transform_data(cfg, psp)
tibble(location = multiclass$train$locations, class = multiclass$train$classifications) |>
  ggplot(aes(x = class)) +
  geom_bar()
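Counting the samples per class makes the skew explicit, and is where the figure of roughly 200 class-3 samples used in the next section comes from:

# Samples per class in the multiclass training set
tibble(class = multiclass$train$classifications) |>
  count(class)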
Techniques for overcoming class imbalance
Downsampling
Downsampling makes the class distribution even. Since we only have around 200 samples of class 3 (the rarest class), we sample that many from each of the other classes, as sketched below.
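Setting downsample=TRUE in the configuration below does this inside psptools. Purely to illustrate the mechanic, a hand-rolled version with dplyr might look like the following; train, n_rarest, and keep are our own names for this sketch, not package objects:

# Illustration only: psptools handles this when downsample = TRUE.
train <- tibble(index = seq_along(multiclass$train$classifications),
                class = multiclass$train$classifications)
n_rarest <- min(count(train, class)$n)   # size of the rarest class
keep <- train |>
  group_by(class) |>
  slice_sample(n = n_rarest) |>          # draw n_rarest samples per class
  ungroup() |>
  pull(index)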
cfg <- list(
  configuration = "test",
  image_list = list(tox_levels = c(0, 10, 30, 80),
                    forecast_steps = 1,
                    n_steps = 3,
                    minimum_gap = 4,
                    maximum_gap = 10,
                    multisample_weeks = "last",
                    toxins = c("gtx4", "gtx1", "dcgtx3", "gtx5", "dcgtx2", "gtx3",
                               "gtx2", "neo", "dcstx", "stx", "c1", "c2")),
  model = list(balance_val_set = FALSE,
               downsample = TRUE,
               use_class_weights = FALSE,
               dropout1 = 0.3,
               dropout2 = 0.3,
               batch_size = 32,
               units1 = 32,
               units2 = 32,
               epochs = 128,
               validation_split = 0.2,
               shuffle = TRUE,
               num_classes = 4,
               optimizer = "adam",
               loss_function = "categorical_crossentropy",
               model_metrics = c("categorical_accuracy")),
  train_test = list(split_by = "year",
                    train = c("2015", "2016", "2017", "2018", "2019", "2020", "2021"),
                    test = c("2014"))
)

downsampled <- transform_data(cfg, psp)
tibble(location = downsampled$train$locations, class = downsampled$train$classifications) |>
  ggplot(aes(x = class)) +
  geom_bar()
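A quick count confirms the training classes are now even, each at the size of the rarest class:

# Verify the downsampled training set has equal class counts
tibble(class = downsampled$train$classifications) |>
  count(class)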
Validation set balancing
The keras::fit() function lets us manually assign the samples in the validation set, rather than holding out a random percentage with the validation_split argument, so we can sample an even distribution of each class. Balancing the validation set can be combined with downsampling in the training set.
cfg <- list(
  configuration = "test",
  image_list = list(tox_levels = c(0, 10, 30, 80),
                    forecast_steps = 1,
                    n_steps = 3,
                    minimum_gap = 4,
                    maximum_gap = 10,
                    multisample_weeks = "last",
                    toxins = c("gtx4", "gtx1", "dcgtx3", "gtx5", "dcgtx2", "gtx3",
                               "gtx2", "neo", "dcstx", "stx", "c1", "c2")),
  model = list(balance_val_set = TRUE,
               downsample = FALSE,
               use_class_weights = FALSE,
               dropout1 = 0.3,
               dropout2 = 0.3,
               batch_size = 32,
               units1 = 32,
               units2 = 32,
               epochs = 128,
               validation_split = 0.2,
               shuffle = TRUE,
               num_classes = 4,
               optimizer = "adam",
               loss_function = "categorical_crossentropy",
               model_metrics = c("categorical_accuracy")),
  train_test = list(split_by = "year",
                    train = c("2015", "2016", "2017", "2018", "2019", "2020", "2021"),
                    test = c("2014"))
)

balanced_val <- transform_data(cfg, psp)
str(balanced_val)
List of 3
$ train:List of 7
..$ labels : num [1:3565, 1:4] 0 0 1 1 1 1 0 1 1 0 ...
..$ image : num [1:3565, 1:36] 0 0 0 0.465 0 ...
..$ classifications: num [1:3565] 1 1 0 0 0 0 1 0 0 1 ...
..$ toxicity : num [1:3565] 11.879 10.786 0 3.365 0.431 ...
..$ locations : chr [1:3565] "PSP10.2" "PSP27.43" "PSP27.22" "PSP19.13" ...
..$ dates : num [1:3565] 16573 16978 16587 18071 17666 ...
..$ scaling_factors: NULL
$ val :List of 7
..$ labels : num [1:891, 1:4] 1 1 1 1 1 0 0 1 1 1 ...
..$ image : num [1:891, 1:36] 0 0 0 0 0 ...
..$ classifications: num [1:891] 0 0 0 0 0 2 3 0 0 0 ...
..$ toxicity : num [1:891] 0.402 0 6.401 7.203 0.436 ...
..$ locations : chr [1:891] "PSP26.15" "PSP25.13" "PSP14.2" "PSP12.06" ...
..$ dates : num [1:891] 17022 16973 17673 17685 16560 ...
..$ scaling_factors: NULL
$ test :List of 7
..$ labels : num [1:1062, 1:4] 1 0 1 1 0 1 1 0 1 1 ...
..$ image : num [1:1062, 1:36] 0 0.587 0 0 0 ...
..$ classifications: num [1:1062] 0 1 0 0 2 0 0 1 0 0 ...
..$ toxicity : num [1:1062] 0.319 24.427 0.225 0.823 38.919 ...
..$ locations : chr [1:1062] "PSP25.08" "PSP27.05" "PSP25.06" "PSP10.2" ...
..$ dates : num [1:1062] 16237 16279 16293 16265 16218 ...
..$ scaling_factors: NULL
tibble(location = balanced_val$val$locations, class = balanced_val$val$classifications) |>
  ggplot(aes(x = class)) +
  geom_bar()
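This article stops at inspecting the data, but as a sketch of where the balanced validation set would plug in, here is a minimal keras model handed the balanced set via fit()'s validation_data argument. The architecture is our own assumption, loosely following cfg$model; psptools' own training routine may differ:

library(keras)

# A minimal sketch model, for illustration only
model <- keras_model_sequential() |>
  layer_dense(units = cfg$model$units1, activation = "relu",
              input_shape = ncol(balanced_val$train$image)) |>
  layer_dropout(rate = cfg$model$dropout1) |>
  layer_dense(units = cfg$model$units2, activation = "relu") |>
  layer_dropout(rate = cfg$model$dropout2) |>
  layer_dense(units = cfg$model$num_classes, activation = "softmax")

model |> compile(optimizer = cfg$model$optimizer,
                 loss = cfg$model$loss_function,
                 metrics = cfg$model$model_metrics)

# The balanced validation set is supplied directly through
# validation_data instead of using validation_split
history <- model |> fit(
  x = balanced_val$train$image,
  y = balanced_val$train$labels,
  validation_data = list(balanced_val$val$image, balanced_val$val$labels),
  epochs = cfg$model$epochs,
  batch_size = cfg$model$batch_size,
  shuffle = cfg$model$shuffle
)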
Weighted classes
keras::fit() also accepts a class_weight argument, which scales each class's contribution to the loss. psptools provides the function get_class_weights() to compute these weights from the training classifications.
cfg <- list(
  configuration = "test",
  image_list = list(tox_levels = c(0, 10, 30, 80),
                    forecast_steps = 1,
                    n_steps = 3,
                    minimum_gap = 4,
                    maximum_gap = 10,
                    multisample_weeks = "last",
                    toxins = c("gtx4", "gtx1", "dcgtx3", "gtx5", "dcgtx2", "gtx3",
                               "gtx2", "neo", "dcstx", "stx", "c1", "c2")),
  model = list(balance_val_set = FALSE,
               downsample = FALSE,
               use_class_weights = TRUE,
               dropout1 = 0.3,
               dropout2 = 0.3,
               batch_size = 32,
               units1 = 32,
               units2 = 32,
               epochs = 128,
               validation_split = 0.2,
               shuffle = TRUE,
               num_classes = 4,
               optimizer = "adam",
               loss_function = "categorical_crossentropy",
               model_metrics = c("categorical_accuracy")),
  train_test = list(split_by = "year",
                    train = c("2015", "2016", "2017", "2018", "2019", "2020", "2021"),
                    test = c("2014"))
)

model_input <- transform_data(cfg, psp)
class_weights <- get_class_weights(model_input$train$classifications)

class_weights
$`0`
[1] 1

$`1`
[1] 8.679518

$`2`
[1] 16.00889

$`3`
[1] 16.83178
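The weights are roughly inverse class frequencies, so the rare high-toxicity classes count for more during training. As a sketch of their use, reusing the illustrative model from the validation set balancing section (this article does not train one itself), they would be supplied through keras::fit()'s class_weight argument:

# Sketch only: class_weight takes a named list mapping class index to
# weight, which is the shape get_class_weights() returns
history <- model |> fit(
  x = model_input$train$image,
  y = model_input$train$labels,
  validation_split = cfg$model$validation_split,
  class_weight = class_weights,
  epochs = cfg$model$epochs,
  batch_size = cfg$model$batch_size,
  shuffle = cfg$model$shuffle
)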