Distances for categorical data

The delta framework

# install the catdist package from GitHub (uncomment on first run)
# library(devtools)
# install_github("alfonsoIodiceDE/catdist_package/catdist")
library(kableExtra)
library(catdist)
library(tidyverse)
library(tidymodels)
library(tidytext)
library(cluster)
library(aricode)
library(ggplot2)
# tidymodels_prefer()

This document provides the code to replicate the experiments from the paper

A general framework for implementing distances for categorical variables

by Michel van de Velden, Alfonso Iodice D’Enza, Angelos Markos and Carlo Cavicchia

Data structures

The considered distances are described in Table 1 of the paper and listed below.

distance              description
SM                    simple matching
IOF                   inverse occurrence frequency
OF                    occurrence frequency
Goodall 3             Goodall 3
Goodall 4             Goodall 4
chi-2                 chi-square distance
VE                    variable entropy
VM                    variable mutability
KL                    Kullback-Leibler
TVD                   total variation distance
Supervised TVD        one-vs-all variation of TVD
Supervised TVD full   all-vs-all variation of TVD
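
To illustrate the framework, the sketch below uses hypothetical toy data and assumes the cdist() interface used later in this document, that is, cdist(x, y, method) returning a list whose distance_mat element holds the pairwise distances. It computes the simple matching distance matrix for a small categorical data set.

# hypothetical toy data: two categorical predictors and a class label
toy <- tibble::tibble(
  x1 = factor(c("a", "a", "b", "b", "c", "c", "a", "b")),
  x2 = factor(c("u", "v", "u", "v", "u", "v", "v", "u")),
  response = factor(c("g1", "g1", "g1", "g1", "g2", "g2", "g2", "g2"))
)
toy_dist <- cdist(x = toy |> dplyr::select(-response),
                  y = toy |> dplyr::pull(response),
                  method = "matching")
toy_dist$distance_mat   # 8 x 8 matrix of pairwise simple matching distances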

Synthetic data set

We consider an \(n=1000\) by \(Q=12\) categorical data set.

A different data set is generated for the mild and strong structure scenarios, that is, for different levels of discriminant power of the active variables.

Of the \(Q=12\) attributes considered, \(8\) are noise.

The synthetic data sets are generated as described in the Cluster Correspondence Analysis paper.
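
A quick inspection of the generated data sets (a sketch; it assumes, as in the rest of this document, that each data set stores the class labels in a response column, so the expected size is 1000 rows by 13 columns: the 12 attributes plus the response):

data(simcatdat1)
data(simcatdat2)
dim(simcatdat1)              # expected: 1000 x 13 (12 attributes + response)
table(simcatdat1$response)   # class sizes in the mild-structure data set
table(simcatdat2$response)   # class sizes in the strong-structure data set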

Real data

Supervised learning: KNN-based experiment

Analysis and results on synthetic data sets

For each scenario and each considered distance, a KNN classifier is trained over a grid of values \(K=\{1, 3, 7, 13, 21, 31, 43\}\).

set.seed(123)
replicates = 20

# grid of neighborhood sizes for the KNN classifier
k_vec = c(1, 3, 7, 13, 21, 31, 43)

# distance measures compared in the experiment
selected_distances = c(
  "tot_var_dist", "gifi_chi2", "supervised", "supervised_full", "matching", 
  "eskin", "goodall_3", "goodall_4", "iof", "of", "lin", 
  "var_entropy", "var_mutability", "kullback-leibler"
)

# synthetic data sets: mild (simcatdat1) and strong (simcatdat2) cluster structure
data(simcatdat1)
data(simcatdat2)

# pair each scenario with its data set, true cluster labels and the distances to compare
cat_data_structure <- tibble(
  structure = c("mild", "strong"),
  cat_data = list(simcatdat1, simcatdat2),
  true_clusters = list(simcatdat1$response, simcatdat2$response)
) |>
  dplyr::mutate(
    distance_method = list(selected_distances, selected_distances)
  )

The hyper-parameter tuning is done via 5-fold cross-validation, and the process is repeated 20 times.

cat_data_resamples = cat_data_structure |> 
  dplyr::mutate(
    cv_iterations = purrr::map(1:n(), ~ 1:replicates)
  ) |> unnest(cv_iterations) |> 
  mutate(
    cross_validation = purrr::map(.x=cat_data,~vfold_cv(.x, v= 5, strata=response)),
    fold_id = purrr::map(.x=cross_validation,~.x$id),
    fold_splits = purrr::map(.x=cross_validation,~.x$splits),
    k_par = purrr::map(1:n(),~k_vec)
  ) |>  unnest(cols = distance_method) |>  unnest(cols = k_par) |> 
  unnest(cols = c(fold_id,fold_splits)) |> 
  dplyr::select(-cross_validation)
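
As a sanity check of the resampling grid (assuming no rows are dropped by the unnesting), the number of fold-level rows should equal 2 structures x 20 replicates x 14 distances x 7 values of K x 5 folds = 19,600:

nrow(cat_data_resamples)
2 * replicates * length(selected_distances) * length(k_vec) * 5   # 19600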

The results are produced and the metrics (accuracy and ARI) are computed for each fold/distance/K combination.

cat_data_knn_results = cat_data_resamples |> 
  mutate(
    # fit the KNN classifier on the analysis part of each fold and predict the assessment part
    knn_results = purrr::pmap(.l = list(fold_splits, k_par, distance_method),
                              .f = ~cdistKNN(train_df = analysis(..1),
                                             assess_df = assessment(..1),
                                             k = ..2, method = ..3)),
    predictions = purrr::map(.x = knn_results, ~.x$.pred),
    truth = purrr::map(.x = knn_results, ~.x$truth),
    same_levs = purrr::map2_dbl(.x = predictions, .y = truth,
                                ~(length(levels(.x)) == length(levels(.y)))),
    # align the prediction levels with the truth before computing the metrics
    predictions = purrr::map2(.x = predictions, .y = truth,
                              ~factor(.x, levels = levels(.y))),
    knn_accuracy = purrr::map2_dbl(.x = predictions, .y = truth,
                                   ~accuracy_vec(truth = .y, estimate = .x)),
    # ARI between predicted and true labels
    measure = purrr::map2_dbl(.x = predictions, .y = truth,
                              .f = function(x = .x, y = .y) clustComp(c1 = x, c2 = y)$ARI)
  )
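
To see what a single fit returns, a minimal single-split sketch (using the cdistKNN() call pattern from the block above; k = 7 and the simple matching distance are arbitrary choices):

one_split <- cat_data_resamples$fold_splits[[1]]
one_fit <- cdistKNN(train_df = analysis(one_split),
                    assess_df = assessment(one_split),
                    k = 7, method = "matching")
head(one_fit$.pred)    # predicted class labels for the assessment fold
head(one_fit$truth)    # corresponding true labels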

Then, for each scenario and distance measure, the metrics are first averaged over the CV folds; descriptive statistics are then computed over the 20 replicates.

cat_knn_PM_scenarios_accuracies = cat_data_knn_results |> 
  group_by(structure ,distance_method,k_par,cv_iterations) |>
  summarise(cv_accuracy=mean(knn_accuracy),
            cv_ARI=mean(measure),.groups="rowwise") |> ungroup()|>
  group_by(structure,distance_method,k_par) |> 
  summarise(sd_cv_accuracy=sd(cv_accuracy),
            min_accuracy=min(cv_accuracy),
            Q1_accuracy = quantile(cv_accuracy,probs=.25),
            Q2_accuracy = quantile(cv_accuracy,probs=.5),
            Q3_accuracy = quantile(cv_accuracy,probs=.75),
            max_accuracy=max(cv_accuracy),
            cv_accuracy=mean(cv_accuracy),
            sd_cv_ARI=sd(cv_ARI),
            min_ARI=min(cv_ARI),
            Q1_ARI = quantile(cv_ARI,probs=.25),
            Q2_ARI = quantile(cv_ARI,probs=.5),
            Q3_ARI = quantile(cv_ARI,probs=.75),
            max_ARI=max(cv_ARI),
            cv_ARI=mean(cv_ARI),
            .groups="rowwise") |>
  ungroup() |> group_by(structure,distance_method)
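
For instance, the per-K distribution of CV accuracies for one distance and scenario can be inspected as follows (an inspection sketch; the total variation distance and the strong-structure scenario are arbitrary choices):

cat_knn_PM_scenarios_accuracies |>
  dplyr::filter(structure == "strong", distance_method == "tot_var_dist") |>
  dplyr::select(k_par, Q1_accuracy, Q2_accuracy, Q3_accuracy)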

Finally, the tuned classifier accuracy is computed by selecting, for each distance method and data set, the value of K with the highest CV accuracy.

cat_knn_PM_scenarios_accuracies_tuned = cat_knn_PM_scenarios_accuracies |>
  slice_max(cv_accuracy,with_ties = FALSE)|>
  mutate(
    distance_method = str_to_title(distance_method)
  )

The results are displayed in the plots below (accuracy first, then ARI).

# helper to strip the reorder_within() suffix (everything after ":") from axis labels
labelize <- Vectorize(function(x) stringr::word(x, start = 1, sep = "\\:"))

cat_knn_PM_scenarios_accuracies_tuned |> ungroup() |> 
  mutate(scenario=paste0(structure,"_"),
         distance_method = distance_method |> fct_recode("SM"="Matching",
                                                         "IOF"="Iof",
                                                         "OF"="Of",
                                                         "Goodall 3"="Goodall_3",
                                                         "Goodall 4"="Goodall_4",
                                                         "Chi2" = "Gifi_chi2",
                                                         "VE"="Var_entropy",
                                                         "VM"="Var_mutability",
                                                         "KL"="Kullback-Leibler",
                                                         "TVD"="Tot_var_dist",
                                                         "Supervised TVD"="Supervised",
                                                         "Supervised TVD full"="Supervised_full"
         ) |> as.character(),
         ordered_dist = reorder_within(x =distance_method,by = Q2_accuracy,
                                       within = structure, sep = ":")
  ) |> 
  rename(`distance measure` = `distance_method`) |> 
  ggplot(aes(x = Q2_accuracy, y = ordered_dist))+theme_minimal() +
  geom_point(aes(colour = `distance measure`, size = k_par))+
  geom_linerange(aes(xmin = Q1_accuracy,xmax = Q3_accuracy, y = ordered_dist),
                 inherit.aes=FALSE)+
  facet_wrap(.~structure, scales = "free_y")+
  theme(axis.text.y =  element_text(size=6),
        axis.text.x =  element_text(angle=90, size=5),
        legend.position = "bottom",
        strip.text.x.top = element_text(size=8)
  )  + 
  scale_y_discrete(label=labelize)+
  xlab("accuracy") + ylab("")+
  scale_size(range=c(.5,2.5),
             guide=guide_legend(title="neighbors",title.position="top",
                                ncol=2))+
  scale_color_discrete(breaks=c("SM","Eskin","Lin","IOF","OF","Goodall 3","Goodall 4","VE","VM","TVD","Chi2","KL","Supervised TVD","Supervised TVD full"),
                       guide=guide_legend(title.position="top"))

cat_knn_PM_scenarios_accuracies_tuned |> ungroup() |> 
  mutate(scenario=paste0(structure,"_"),
         distance_method = distance_method |> fct_recode("SM"="Matching",
                                                         "IOF"="Iof",
                                                         "OF"="Of",
                                                         "Goodall 3"="Goodall_3",
                                                         "Goodall 4"="Goodall_4",
                                                         "Chi2" = "Gifi_chi2",
                                                         "VE"="Var_entropy",
                                                         "VM"="Var_mutability",
                                                         "KL"="Kullback-Leibler",
                                                         "TVD"="Tot_var_dist",
                                                         "Supervised TVD"="Supervised",
                                                         "Supervised TVD full"="Supervised_full"
         ) |> as.character(),
         ordered_dist = reorder_within(x =distance_method,by = Q2_ARI,
                                       within = structure, sep = ":")
  ) |> 
  rename(`distance measure` = `distance_method`) |> 
  ggplot(aes(x = Q2_ARI, y = ordered_dist))+theme_minimal() +
  geom_point(aes(colour = `distance measure`, size = k_par))+
  geom_linerange(aes(xmin = Q1_ARI,xmax = Q3_ARI, y = ordered_dist),
                 inherit.aes=FALSE)+
  facet_wrap(.~structure, scales = "free_y")+
  theme(axis.text.y =  element_text(size=6),
        axis.text.x =  element_text(angle=90, size=5),
        legend.position = "bottom",
        strip.text.x.top = element_text(size=8)
  )  + 
  scale_y_discrete(label=labelize)+
  xlab("test ARI") + ylab("")+
  scale_size(range=c(.5,2.5),
             guide=guide_legend(title="neighbors",title.position="top",
                                ncol=2))+
  scale_color_discrete(breaks=c("SM","Eskin","Lin","IOF","OF","Goodall 3","Goodall 4","VE","VM","TVD","Chi2","KL","Supervised TVD","Supervised TVD full"),
                       guide=guide_legend(title.position="top"))

Analysis and results on real data sets

set.seed(123)
data(australian)
data(wbcd)
data(vote)
data(tictactoe)
data(balance)
data(tae)
data(lymphography)
data(soybean)
data(cars)
replicates = 20
dataset_names = c("vote","australian","wbcd", "tictactoe","balance","tae", "lymphography","soybean","cars")

benchmark_data = tibble(
  dataset_name = dataset_names,
  # the australian data set keeps only a subset of its attributes (columns 2, 3, 7, 10, 13, 14 are dropped)
  prepped_data = list(vote, australian[,-c(2,3,7,10,13,14)], wbcd, tictactoe, balance, tae, lymphography, soybean, cars),
  distance_method = rep(list(selected_distances), length(dataset_names)) 
)
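
A quick overview of the benchmark data sets (a sketch computing the same quantities used later for the facet labels: number of observations, number of categorical attributes excluding the response, and number of classes):

benchmark_data |>
  mutate(
    n = purrr::map_int(prepped_data, nrow),
    p = purrr::map_int(prepped_data, ncol) - 1,   # exclude the response column
    classes = purrr::map_int(prepped_data, ~ length(levels(.x$response)))
  ) |>
  dplyr::select(dataset_name, n, p, classes)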

For each data set and distance considered, we train a KNN classifier, tuning the number of neighbors on the grid \(K=\{1, 3, 7, 13, 21\}\) via 5-fold cross-validation. The process is repeated 20 times.

benchmark_data_resamples = benchmark_data |> 
  dplyr::mutate(
    cv_iterations = purrr::map(1:n(), ~ 1:replicates)
  ) |> unnest(cv_iterations) |> 
  mutate(
    cross_validation = purrr::map(.x=prepped_data,~vfold_cv(.x, v= 5, strata=response)),
    fold_id = purrr::map(.x=cross_validation,~.x$id),
    fold_splits = purrr::map(.x=cross_validation,~.x$splits),
    k_par = purrr::map(1:n(), ~ c(1,3,7,13,21)) 
  ) |> unnest(cols = distance_method) |> unnest(cols = k_par) |> unnest(cols = c(fold_id,fold_splits)) |> 
  dplyr::select(-cross_validation)

Apply the KNN classifier to all fold/distance/data set combinations.

benchmark_results_training = benchmark_data_resamples |> 
  mutate(knn_results = purrr::pmap(.l = list(fold_splits, k_par,distance_method),
                            .f = ~cdistKNN(train_df = analysis(..1),
                                           assess_df = assessment(..1),
                                           k=..2, method=..3)
  )
  )

Collect the cross-validation estimates of accuracy and ARI.

accuracy_results_knn_single_rep_accuracy_ungrouped = benchmark_results_training |>
  mutate(
    classes  = purrr::map_dbl(.x=prepped_data,~length(levels(.x$response))),
    predictions = purrr::map(.x=knn_results,~.x$.pred),
    truth = purrr::map(.x=knn_results,~.x$truth),
    same_levs = purrr::map2_dbl(.x = predictions,.y=truth,~(length(levels(.x))==length(levels(.y)))),
    # align prediction levels with the truth; predictions outside the truth levels become NA
    predictions = purrr::map2(.x=predictions,.y=truth,~factor(.x,levels=levels(.y))),
    knn_accuracy=purrr::map2_dbl(.x=predictions,.y=truth,~accuracy_vec(truth=.y,estimate=.x)),
    # ARI between predicted and true labels; NA predictions are dropped first
    measure =  purrr::map2_dbl(.x = predictions,.y = truth,
                               .f=function(x=.x,y=.y){
                                 na_pos=which(is.na(x))
                                 if(length(na_pos)>0){
                                   x=x[-na_pos]
                                   y=y[-na_pos]
                                 }
                                 clustComp(c1 = x, c2 = y)$ARI
                               }
    )
  )

accuracy_results_knn_single_rep_accuracy=accuracy_results_knn_single_rep_accuracy_ungrouped|> 
  group_by(dataset_name,distance_method,k_par,cv_iterations) |>
  summarise(cv_accuracy=mean(knn_accuracy),
            cv_ARI=mean(measure)
  ) |> ungroup()



accuracy_results_knn = accuracy_results_knn_single_rep_accuracy |>
  group_by(dataset_name,distance_method,k_par) |> 
  summarise(sd_cv_accuracy=sd(cv_accuracy),
            min_accuracy=min(cv_accuracy),
            Q1_accuracy = quantile(cv_accuracy,probs=.25),
            Q2_accuracy = quantile(cv_accuracy,probs=.5),
            Q3_accuracy = quantile(cv_accuracy,probs=.75),
            max_accuracy=max(cv_accuracy),
            cv_accuracy=mean(cv_accuracy),
            sd_cv_ARI=sd(cv_ARI),
            min_ARI=min(cv_ARI),
            Q1_ARI = quantile(cv_ARI,probs=.25),
            Q2_ARI = quantile(cv_ARI,probs=.5),
            Q3_ARI = quantile(cv_ARI,probs=.75),
            max_ARI=max(cv_ARI),
            cv_ARI=mean(cv_ARI)
  ) |> ungroup() 

Select the best K for each distance method and data set

accuracy_results_tuned = accuracy_results_knn |>
  group_by(dataset_name,distance_method) |> 
  slice_max(cv_accuracy,with_ties = FALSE)|>
  mutate(
    distance_method = str_to_title(distance_method),
    datasets=as.factor(dataset_name)
  )
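
The selected number of neighbors and the corresponding CV metrics can be inspected per data set and distance (an inspection sketch):

accuracy_results_tuned |>
  dplyr::select(dataset_name, distance_method, k_par, cv_accuracy, cv_ARI) |>
  arrange(dataset_name, desc(cv_accuracy))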

Further preparation of the data structures and facet labels for the figures.

long_levels = benchmark_data |> 
  mutate(
    cv_iterations = purrr::map(1:n(), ~ 1:replicates)
  ) |> unnest(cv_iterations) |> 
  mutate(
    class = c(rep(2,replicates),rep(2,replicates),rep(2,replicates),rep(2,replicates),rep(3,replicates),
              rep(3,replicates),rep(4,replicates),rep(19,replicates),rep(4,replicates)),
    data_n = purrr::map_dbl(prepped_data,nrow),
    data_p = purrr::map_dbl(prepped_data,ncol)-1,
    long_name = purrr::pmap_chr(.l=list(dataset_name,data_n,data_p,class),.f= ~paste0(..1," (n: ",..2,
                                                                                      ", p: ",..3,", cl: ",..4,")"))
  )|> pull(long_name) |> as.factor() |> levels()

long_levels[5]="soybean (n: 307, p: 35, cl: 19)"

nds_tuned=accuracy_results_tuned |> pull(dataset_name) |> as.factor()

levels(nds_tuned) = long_levels
accuracy_results_tuned$nds = nds_tuned
accuracy_results_tuned=accuracy_results_tuned |>
  mutate(datasets = nds)
accuracy_results_tuned |> ungroup() |> 
  mutate(distance_method = distance_method |> fct_recode("SM"="Matching",
                                                     "IOF"="Iof",
                                                     "OF"="Of",
                                                     "Goodall 3"="Goodall_3",
                                                     "Goodall 4"="Goodall_4",
                                                     "Chi2" = "Gifi_chi2",
                                                     "VE"="Var_entropy",
                                                     "VM"="Var_mutability",
                                                     "KL"="Kullback-Leibler",
                                                     "TVD"="Tot_var_dist",
                                                     "Supervised TVD"="Supervised",
                                                     "Supervised TVD full"="Supervised_full"
                                                     ) |> as.character(),
         ordered_dist = reorder_within(x =distance_method,by = Q2_accuracy,
                                       within = datasets, sep = ":")
         ) |> 
  rename(`distance measure` = `distance_method`) |> 
  ggplot(aes(x = Q2_accuracy, y = ordered_dist))+theme_minimal() +
  geom_point(aes(colour = `distance measure`, size = k_par))+
  geom_linerange(aes(xmin = Q1_accuracy,xmax = Q3_accuracy, y = ordered_dist), inherit.aes=FALSE)+
  facet_wrap(~datasets, scales = "free_y", nrow=3, ncol=3)+
  theme(axis.text.y =  element_text(size=5),
        axis.text.x =  element_text(angle=90, size=6),
        legend.position = "bottom",
        strip.text.x.top = element_text(size=8)
  )  + 
  scale_y_discrete(label=labelize)+
  xlab("accuracy") + ylab("")+
  scale_size(range=c(.5,2.5),
             guide=guide_legend(title="neighbors",title.position="top",
                                ncol=2))+
  scale_color_discrete(breaks=c("SM","Eskin","Lin","IOF","OF","Goodall 3","Goodall 4","VE","VM","TVD","Chi2","KL","Supervised TVD","Supervised TVD full"),
                       guide=guide_legend(title.position="top"))

accuracy_results_tuned |> ungroup() |> 
  mutate(distance_method = distance_method |> fct_recode("SM"="Matching",
                                                     "IOF"="Iof",
                                                     "OF"="Of",
                                                     "Goodall 3"="Goodall_3",
                                                     "Goodall 4"="Goodall_4",
                                                     "Chi2" = "Gifi_chi2",
                                                     "VE"="Var_entropy",
                                                     "VM"="Var_mutability",
                                                     "KL"="Kullback-Leibler",
                                                     "TVD"="Tot_var_dist",
                                                     "Supervised TVD"="Supervised",
                                                     "Supervised TVD full"="Supervised_full"
                                                     ) |> as.character(),
         # ordered_dist = reorder_within(x =distance_method,by = cv_accuracy,
         #                               within = datasets, sep = ":")
         ordered_dist = reorder_within(x =distance_method,by = Q2_ARI,
                                       within = datasets, sep = ":")
         ) |> 
  rename(`distance measure` = `distance_method`) |> 
  # ggplot(aes(x = cv_accuracy, y = ordered_dist))+theme_minimal() +
  ggplot(aes(x = Q2_ARI, y = ordered_dist))+theme_minimal() +
  geom_point(aes(colour = `distance measure`, size = k_par))+
  # geom_linerange(aes(xmin = cv_accuracy-sd_cv_accuracy,xmax = cv_accuracy+sd_cv_accuracy, y = ordered_dist), inherit_aes=FALSE)+
  geom_linerange(aes(xmin = Q1_ARI,xmax = Q3_ARI, y = ordered_dist), inherit.aes=FALSE)+
  facet_wrap(~datasets, scales = "free_y", nrow=3, ncol=3)+
  theme(axis.text.y =  element_text(size=5),
        axis.text.x =  element_text(angle=90, size=6),
        legend.position = "bottom",
        strip.text.x.top = element_text(size=8)
  )  + 
  scale_y_discrete(label=labelize)+
  # scale_size(range=c(.5,2.5))+#guides(size="none")+
  # labs(title="KNN", subtitle = "5-fold CV") +
  xlab("test ARI") + ylab("")+
  scale_size(range=c(.5,2.5),
             guide=guide_legend(title="neighbors",title.position="top",
                                ncol=2))+
  scale_color_discrete(breaks=c("SM","Eskin","Lin","IOF","OF","Goodall 3","Goodall 4","VE","VM","TVD","Chi2","KL","Supervised TVD","Supervised TVD full"),
                       guide=guide_legend(title.position="top"))

Unsupervised learning: PAM-based experiment

For each replicate and fold, PAM (k-medoids) is fitted on the training-fold distance matrix computed with each distance measure; the test-fold observations are then assigned to the closest medoid via predict_pam, and the agreement between the resulting partition and the true classes is measured with the ARI.

Analysis and results on synthetic data sets

pam_synth_results_list=list()
set.seed(123)
for(cv_iter in 1:replicates){
  
  
  cat_data_resamples_pam = cat_data_structure |> 
    mutate(
      cross_validation = purrr::map(.x=cat_data,~vfold_cv(.x, v= 5, strata=response)),
      fold_id = purrr::map(.x=cross_validation,~.x$id),
      fold_splits = purrr::map(.x=cross_validation,~.x$splits),
      n_clusters  = 4
    )|> unnest(cols = distance_method) |> unnest(cols = c(fold_id,fold_splits)) |> 
  dplyr::select(-cross_validation,-cat_data) |>
    mutate(train_fold=purrr::map(.x=fold_splits,.f=~analysis(.x)),
           test_fold=purrr::map(.x=fold_splits,.f=~assessment(.x))
    ) |> 
    dplyr::select(-fold_splits)   |>
    mutate(
      distance_res = purrr::pmap(.l=list(..1 = train_fold,..2 = distance_method),
                          .f=~cdist(x = ..1 |> select(-response),
                                    y = ..1 |> pull(response),
                                    method = ..2)),
      distance_mat = purrr::map(.x=distance_res,~.x$distance_mat),
      delta = purrr::map(.x=distance_res,~.x$delta),
      delta_names = purrr::map(.x=distance_res,~.x$delta_names)
    )
  
  cat_data_results_pam = cat_data_resamples_pam |> 
    mutate(
      pam_solution = purrr::map2(.x=distance_mat,.y=n_clusters,.f=~pam(x = .x,k = .y)),
      pam_clust = purrr::map(.x=pam_solution, .f= ~return(.x$clustering |> factor())),
      pam_medoid_ids = purrr::map(.x=pam_solution, ~return(.x$id.med)),
      true_clust_test = purrr::map(.x = test_fold, ~return(.x |> pull(response))),
      test_data = purrr::map(.x = test_fold, ~return(.x |> select(-response))),
      train_data = purrr::map(.x = train_fold, ~return(.x |> select(-response))),
      medoids = purrr::map2(.x = train_data,.y=pam_medoid_ids, ~return(.x[.y,])),
      clust_test = purrr::pmap(.l=list(..1 = medoids,..2 = test_data,..3=delta, ..4 = delta_names),
                        ~predict_pam(medoids = ..1, newdata=..2, delta=..3, delta_names = ..4)),
      measure =  purrr::map2_dbl(.x = true_clust_test,.y = clust_test,
                                 .f=function(x=.x,y=.y) clustComp(c1 = x, c2 = y)$ARI
      )
    )
  
  
  
  pam_synth_results_list[[cv_iter]] = cat_data_results_pam |> dplyr::select(structure,distance_method,measure,fold_id) |>
    group_by(structure,distance_method) |> 
    summarise(ARI=mean(measure)) |> ungroup() |> mutate(cv_iteration=cv_iter)
}
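
Each element of pam_synth_results_list holds, for one replicate, the fold-averaged test ARI per structure and distance (a quick check):

length(pam_synth_results_list)   # one summary tibble per replicate (20)
pam_synth_results_list[[1]]      # per-structure, per-distance mean ARI, replicate 1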

Preparing the data structures and the corresponding labels for the plots

pam_synth_results = bind_rows(pam_synth_results_list) |> group_by(structure,distance_method) |>
  summarise(sd_ARI=sd(ARI),
            min_ARI=min(ARI),
            Q1_ARI = quantile(ARI,probs=.25),
            Q2_ARI = quantile(ARI,probs=.5),
            Q3_ARI = quantile(ARI,probs=.75),
            max_ARI=max(ARI),
            ARI=mean(ARI)
            ) |> ungroup() |> 
  mutate(distance_method = str_to_title(distance_method))
pam_synth_results |> 
  mutate(scenario=paste0(structure),
         distance_method = distance_method |> fct_recode("SM"="Matching",
                                                     "IOF"="Iof",
                                                     "OF"="Of",
                                                     "Goodall 3"="Goodall_3",
                                                     "Goodall 4"="Goodall_4",
                                                     "Chi2" = "Gifi_chi2",
                                                     "VE"="Var_entropy",
                                                     "VM"="Var_mutability",
                                                     "KL"="Kullback-Leibler",
                                                     "TVD"="Tot_var_dist",
                                                     "Supervised TVD"="Supervised",
                                                     "Supervised TVD full"="Supervised_full"
                                                     ) |> as.character(),
    # ordered_distances = reorder_within(distance_method,ARI,structure,sep = ":")) |>          
    ordered_distances = reorder_within(distance_method,Q2_ARI,structure,sep = ":")) |> 
  rename(`distance measure` = `distance_method`) |> 
  ggplot(aes(x = Q2_ARI, y = ordered_distances))+theme_minimal() +
  geom_point(aes(colour=`distance measure`),size=2) +
  # geom_linerange(aes(xmin = ARI-sd_ARI,xmax = ARI+sd_ARI, y = ordered_distances),inherit.aes=FALSE)+
  geom_linerange(aes(xmin = Q1_ARI,xmax = Q3_ARI, y = ordered_distances),inherit.aes=FALSE)+
  theme(axis.text.y =  element_text(size=6),
        axis.text.x =  element_text(angle=90,size=5),
        legend.position = "bottom") +
  xlim(-0.05,1)+
  facet_wrap(.~structure,scales = "free") + 
  scale_y_discrete(label=labelize)+
  # labs(title="PAM clustering",subtitle = "k-medoids")+
  xlab("test ARI")+ ylab("")+
  scale_color_discrete(breaks=c("SM","Eskin","Lin","IOF","OF","Goodall 3","Goodall 4","VE","VM","TVD","Chi2","KL","Supervised TVD","Supervised TVD full"))

Analysis and results on real data sets

pam_results_list=list()
set.seed(123)
for(cv_iter in 1:replicates){
  
  benchmark_data_resamples_pam = benchmark_data |> 
    mutate(
      cross_validation = purrr::map(.x=prepped_data,~vfold_cv(.x, v= 5, strata=response)),
      fold_id = purrr::map(.x=cross_validation,~.x$id),
      fold_splits = purrr::map(.x=cross_validation,~.x$splits),
      n_clusters  = purrr::map_dbl(.x=prepped_data,~length(levels(.x$response)))
    )|> unnest(cols = distance_method) |> unnest(cols = c(fold_id,fold_splits)) |> #|> unnest(cols = k_par)
  dplyr::select(-cross_validation,-prepped_data) |>
    mutate(train_fold=purrr::map(.x=fold_splits,.f=~analysis(.x)),
           test_fold=purrr::map(.x=fold_splits,.f=~assessment(.x))
    ) |> 
    dplyr::select(-fold_splits)|>
    
    mutate(
      distance_res = purrr::pmap(.l=list(..1 = train_fold,..2 = distance_method),
                                 .f=~cdist(x = ..1 |> dplyr::select(-response),
                                           y = ..1 |> pull(response),
                                           method = ..2)),
      distance_mat = purrr::map(.x=distance_res,~.x$distance_mat),
      delta = purrr::map(.x=distance_res,~.x$delta),
      delta_names = purrr::map(.x=distance_res,~.x$delta_names)
    )
  
  benchmark_results_pam = benchmark_data_resamples_pam |> 
    mutate(
      pam_solution = purrr::map2(.x=distance_mat,.y=n_clusters,.f=~pam(x = .x,k = .y))
    )
  
  benchmark_results_pam = benchmark_results_pam |> 
    mutate(
      pam_clust = purrr::map(.x=pam_solution, .f= ~return(.x$clustering |> factor())),
      pam_medoid_ids = purrr::map(.x=pam_solution, ~return(.x$id.med)),
      true_clust_test = purrr::map(.x = test_fold, ~return(.x |> pull(response))),
      test_data = purrr::map(.x = test_fold, ~return(.x |> select(-response))),
      train_data = purrr::map(.x = train_fold, ~return(.x |> select(-response))),
      medoids = purrr::map2(.x = train_data,.y=pam_medoid_ids, ~return(.x[.y,]))
    )
  
  benchmark_results_pam = benchmark_results_pam |> 
    mutate(
      clust_test = purrr::pmap(.l=list(..1 = medoids,..2 = test_data,..3=delta, ..4 = delta_names),
                               ~predict_pam(medoids = ..1, newdata=..2, delta=..3, delta_names = ..4)),
      measure =  purrr::map2_dbl(.x = true_clust_test,.y = clust_test,
                                 .f=function(x=.x,y=.y) clustComp(c1 = x, c2 = y)$ARI
      )
    )
  
  pam_results_list[[cv_iter]] = benchmark_results_pam |> dplyr::select(dataset_name,distance_method,measure,fold_id) |> group_by(dataset_name,distance_method) |> 
    summarise(ARI=mean(measure)) |> ungroup() |> mutate(cv_iteration=cv_iter)
}
long_levels = benchmark_data |> mutate(
  class = c(2,2,2,2,3,3,4,19,4),
  data_n = purrr::map_dbl(prepped_data,nrow),
  data_p = purrr::map_dbl(prepped_data,ncol)-1,
  long_name = purrr::pmap_chr(.l=list(dataset_name,data_n,data_p,class),.f= ~paste0(..1," (n: ",..2,
                                                                         ", p: ",..3,", cl: ",..4,")"))
)|> pull(long_name) |> as.factor() |> levels()

long_levels[5]="soybean (n: 307, p: 35, cl: 19)"

pam_results=bind_rows(pam_results_list) |> group_by(dataset_name,distance_method) |> 
  summarise(sd_ARI=sd(ARI),
            min_ARI=min(ARI),
            Q1_ARI = quantile(ARI,probs=.25),
            Q2_ARI = quantile(ARI,probs=.5),
            Q3_ARI = quantile(ARI,probs=.75),
            max_ARI=max(ARI),
            ARI=mean(ARI)) |> 
  mutate(distance_method = str_to_title(distance_method),
         datasets=as.factor(dataset_name)
         )

nds = pam_results |> pull(datasets) |> as.factor()
levels(nds) = long_levels

pam_results$nds = nds

pam_results = pam_results |> 
  mutate(datasets = nds)
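
Before plotting, the distance attaining the highest median test ARI on each data set can be inspected (an inspection sketch):

pam_results |>
  group_by(datasets) |>
  slice_max(Q2_ARI, with_ties = FALSE) |>
  dplyr::select(datasets, distance_method, Q2_ARI)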