Class Imbalance

library(psptools)
library(pspdata)

library(dplyr)
library(ggplot2)

This article examines class imbalance in the training data when toxicity is binned into binary and multiclass categories, then demonstrates several techniques for mitigating it.

# Load the PSP sampling data, derive a year column, and keep only blue
# mussel (Mytilus) records
psp <- read_psp_data(fix_species=TRUE) |>
  mutate(year = format(date, format="%Y")) |>
  filter(species == "mytilus")

Binary

Here we predict the probability that toxicity falls above or below the regulatory closure limit of 80 µg/100 g, so tox_levels has a single cutoff at 80.

cfg <- list(
  configuration="test",
  image_list = list(tox_levels = c(0,80),
                    forecast_steps = 1,
                    n_steps = 3,
                    minimum_gap = 4,
                    maximum_gap = 10,
                    multisample_weeks="last",
                    toxins = c("gtx4", "gtx1", "dcgtx3", "gtx5", "dcgtx2", "gtx3", 
                               "gtx2", "neo", "dcstx", "stx", "c1", "c2")),
  model = list(balance_val_set=FALSE,   # baseline: no imbalance handling
               downsample=FALSE,
               use_class_weights=FALSE,
               dropout1 = 0.3,
               dropout2 = 0.3,
               batch_size = 32, 
               units1 = 32, 
               units2 = 32, 
               epochs = 128, 
               validation_split = 0.2,
               shuffle = TRUE,
               num_classes = 2,
               optimizer="adam",
               loss_function="categorical_crossentropy",
               model_metrics=c("categorical_accuracy")),
  train_test = list(split_by="year",
                    train = c("2015", "2016", "2017", "2018", "2019", "2020", "2021"), 
                    test = c("2014"))
)
binary <- transform_data(cfg, psp)
tibble(location = binary$train$locations, class = binary$train$classifications) |> 
  ggplot(aes(x=class)) +
  geom_bar()
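
For exact counts behind the bar chart, we can tabulate the classes directly (the counts themselves depend on the data release):

# tabulate per-class sample counts in the training set
tibble(class = binary$train$classifications) |>
  count(class)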

Multiclass

Now we predict one of more than two toxicity classes. The cutoffs 0, 10, 30, and 80 define four bins, matching num_classes = 4 in the model configuration.
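
As a rough illustration of how such cutoffs translate toxicity values into class labels (transform_data() does this internally; findInterval() is used here only as a sketch):

tox_levels <- c(0, 10, 30, 80)
# each made-up toxicity value gets the class of the highest cutoff it reaches
findInterval(c(5, 12, 45, 120), tox_levels) - 1
# 0 1 2 3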

cfg <- list(
  configuration="test",
  image_list = list(tox_levels = c(0,10,30,80),
                    forecast_steps = 1,
                    n_steps = 3,
                    minimum_gap = 4,
                    maximum_gap = 10,
                    multisample_weeks="last",
                    toxins = c("gtx4", "gtx1", "dcgtx3", "gtx5", "dcgtx2", "gtx3", 
                               "gtx2", "neo", "dcstx", "stx", "c1", "c2")),
  model = list(balance_val_set=FALSE,
               downsample=FALSE,
               use_class_weights=FALSE,
               dropout1 = 0.3,
               dropout2 = 0.3,
               batch_size = 32, 
               units1 = 32, 
               units2 = 32, 
               epochs = 128, 
               validation_split = 0.2,
               shuffle = TRUE,
               num_classes = 4,
               optimizer="adam",
               loss_function="categorical_crossentropy",
               model_metrics=c("categorical_accuracy")),
  train_test = list(split_by="year",
                    train = c("2015", "2016", "2017", "2018", "2019", "2020", "2021"), 
                    test = c("2014"))
)
multiclass <- transform_data(cfg, psp)
tibble(location = multiclass$train$locations, class = multiclass$train$classifications) |> 
  ggplot(aes(x=class)) +
  geom_bar()

Techniques for overcoming class imbalance

Downsampling

Downsampling evens out the class distribution by discarding samples from the more common classes. Since class 3 (the rarest) has only around 200 samples, we sample that many from each of the other classes; the tradeoff, as sketched below, is that most majority-class samples are thrown away.
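
Setting downsample=TRUE in the config below enables this inside psptools; the following is only a sketch of the idea, using the multiclass input from above:

# sample the size of the rarest class from every class,
# discarding the surplus from the common ones
class_vec <- multiclass$train$classifications
n_min <- min(table(class_vec))
tibble(index = seq_along(class_vec), class = class_vec) |>
  group_by(class) |>
  slice_sample(n = n_min) |>
  ungroup()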

cfg <- list(
  configuration="test",
  image_list = list(tox_levels = c(0,10,30,80),
                    forecast_steps = 1,
                    n_steps = 3,
                    minimum_gap = 4,
                    maximum_gap = 10,
                    multisample_weeks="last",
                    toxins = c("gtx4", "gtx1", "dcgtx3", "gtx5", "dcgtx2", "gtx3", 
                               "gtx2", "neo", "dcstx", "stx", "c1", "c2")),
  model = list(balance_val_set=FALSE,
               downsample=TRUE,   # sample the rarest class's size from every class
               use_class_weights=FALSE,
               dropout1 = 0.3,
               dropout2 = 0.3,
               batch_size = 32, 
               units1 = 32, 
               units2 = 32, 
               epochs = 128, 
               validation_split = 0.2,
               shuffle = TRUE,
               num_classes = 4,
               optimizer="adam",
               loss_function="categorical_crossentropy",
               model_metrics=c("categorical_accuracy")),
  train_test = list(split_by="year",
                    train = c("2015", "2016", "2017", "2018", "2019", "2020", "2021"), 
                    test = c("2014"))
)
downsampled <- transform_data(cfg, psp)
tibble(location = downsampled$train$locations, class = downsampled$train$classifications) |> 
  ggplot(aes(x=class)) +
  geom_bar()
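
A quick tabulation can confirm that the training classes are now evenly represented:

# per-class counts should now be (close to) equal
table(downsampled$train$classifications)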

Validation set balancing

keras::fit() lets us assign the validation samples manually through its validation_data argument, rather than holding out a random fraction with validation_split. We can therefore draw an even number of samples from each class. Balancing the validation set can also be combined with downsampling of the training set. (A sketch of the corresponding fit() call follows the plot below.)

cfg <- list(
  configuration="test",
  image_list = list(tox_levels = c(0,10,30,80),
                    forecast_steps = 1,
                    n_steps = 3,
                    minimum_gap = 4,
                    maximum_gap = 10,
                    multisample_weeks="last",
                    toxins = c("gtx4", "gtx1", "dcgtx3", "gtx5", "dcgtx2", "gtx3", 
                               "gtx2", "neo", "dcstx", "stx", "c1", "c2")),
  model = list(balance_val_set=TRUE,   # hold out a class-balanced validation set
               downsample=FALSE,
               use_class_weights=FALSE,
               dropout1 = 0.3,
               dropout2 = 0.3,
               batch_size = 32, 
               units1 = 32, 
               units2 = 32, 
               epochs = 128, 
               validation_split = 0.2,
               shuffle = TRUE,
               num_classes = 4,
               optimizer="adam",
               loss_function="categorical_crossentropy",
               model_metrics=c("categorical_accuracy")),
  train_test = list(split_by="year",
                    train = c("2015", "2016", "2017", "2018", "2019", "2020", "2021"), 
                    test = c("2014"))
)
balanced_val <- transform_data(cfg, psp)
str(balanced_val)
List of 3
 $ train:List of 7
  ..$ labels         : num [1:3565, 1:4] 0 0 1 1 1 1 0 1 1 0 ...
  ..$ image          : num [1:3565, 1:36] 0 0 0 0.465 0 ...
  ..$ classifications: num [1:3565] 1 1 0 0 0 0 1 0 0 1 ...
  ..$ toxicity       : num [1:3565] 11.879 10.786 0 3.365 0.431 ...
  ..$ locations      : chr [1:3565] "PSP10.2" "PSP27.43" "PSP27.22" "PSP19.13" ...
  ..$ dates          : num [1:3565] 16573 16978 16587 18071 17666 ...
  ..$ scaling_factors: NULL
 $ val  :List of 7
  ..$ labels         : num [1:891, 1:4] 1 1 1 1 1 0 0 1 1 1 ...
  ..$ image          : num [1:891, 1:36] 0 0 0 0 0 ...
  ..$ classifications: num [1:891] 0 0 0 0 0 2 3 0 0 0 ...
  ..$ toxicity       : num [1:891] 0.402 0 6.401 7.203 0.436 ...
  ..$ locations      : chr [1:891] "PSP26.15" "PSP25.13" "PSP14.2" "PSP12.06" ...
  ..$ dates          : num [1:891] 17022 16973 17673 17685 16560 ...
  ..$ scaling_factors: NULL
 $ test :List of 7
  ..$ labels         : num [1:1062, 1:4] 1 0 1 1 0 1 1 0 1 1 ...
  ..$ image          : num [1:1062, 1:36] 0 0.587 0 0 0 ...
  ..$ classifications: num [1:1062] 0 1 0 0 2 0 0 1 0 0 ...
  ..$ toxicity       : num [1:1062] 0.319 24.427 0.225 0.823 38.919 ...
  ..$ locations      : chr [1:1062] "PSP25.08" "PSP27.05" "PSP25.06" "PSP10.2" ...
  ..$ dates          : num [1:1062] 16237 16279 16293 16265 16218 ...
  ..$ scaling_factors: NULL
tibble(location = balanced_val$val$locations, class = balanced_val$val$classifications) |> 
  ggplot(aes(x=class)) +
  geom_bar()
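
With the split built this way, the balanced validation set would be supplied to keras::fit() through validation_data, bypassing validation_split. A minimal sketch, not psptools' internal training code, assuming a compiled keras model object named model (not constructed in this article):

history <- model |>
  keras::fit(
    x = balanced_val$train$image,
    y = balanced_val$train$labels,
    # hand fit() the balanced validation set directly
    validation_data = list(balanced_val$val$image,
                           balanced_val$val$labels),
    epochs = cfg$model$epochs,
    batch_size = cfg$model$batch_size,
    shuffle = cfg$model$shuffle
  )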

Weighted classes

keras::fit() also accepts a class_weight argument, which scales each class's contribution to the loss so that rarer classes are penalized more heavily when misclassified. psptools provides get_class_weights() to compute these weights from the training classifications. (A sketch of the fit() call follows the output below.)

cfg <- list(
  configuration="test",
  image_list = list(tox_levels = c(0,10,30,80),
                    forecast_steps = 1,
                    n_steps = 3,
                    minimum_gap = 4,
                    maximum_gap = 10,
                    multisample_weeks="last",
                    toxins = c("gtx4", "gtx1", "dcgtx3", "gtx5", "dcgtx2", "gtx3", 
                               "gtx2", "neo", "dcstx", "stx", "c1", "c2")),
  model = list(balance_val_set=FALSE,
               downsample=FALSE,
               use_class_weights=TRUE,   # apply per-class weights to the loss
               dropout1 = 0.3,
               dropout2 = 0.3,
               batch_size = 32, 
               units1 = 32, 
               units2 = 32, 
               epochs = 128, 
               validation_split = 0.2,
               shuffle = TRUE,
               num_classes = 4,
               optimizer="adam",
               loss_function="categorical_crossentropy",
               model_metrics=c("categorical_accuracy")),
  train_test = list(split_by="year",
                    train = c("2015", "2016", "2017", "2018", "2019", "2020", "2021"), 
                    test = c("2014"))
)
model_input <- transform_data(cfg, psp)
class_weights <- get_class_weights(model_input$train$classifications)

class_weights
$`0`
[1] 1

$`1`
[1] 8.679518

$`2`
[1] 16.00889

$`3`
[1] 16.83178
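
A minimal sketch of passing these weights to keras::fit() via its class_weight argument, again assuming a compiled keras model object named model:

history <- model |>
  keras::fit(
    x = model_input$train$image,
    y = model_input$train$labels,
    # named list of per-class weights computed above
    class_weight = class_weights,
    validation_split = cfg$model$validation_split,
    epochs = cfg$model$epochs,
    batch_size = cfg$model$batch_size
  )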