Automatic dendrogram slicing for mixed-type data clustering

Supplementary material

Initialization

Libraries and functions needed

library(cluster)
library(data.table)
library(distances)
library(dynamicTreeCut)
library(factoextra)
library(FactoMineR)
library(fastcluster)
library(fastDummies)
library(fpc)
library(ggdendro)
library(ggrepel)
library(ggsurvfit)
library(gower)
library(graphics)
library(kableExtra)
library(janitor)
library(Matrix)
library(mclust)
library(NbClust)
library(palmerpenguins)
library(PCAmixdata)
library(rcompanion)
library(reticulate)
library(survival)
library(tidymodels)
library(tidyverse)

source("R/linkage_maker.R")
source("R/my_hclust.R")
source("R/single_div_step.R")
source("R/node_heights_compute.R")
source("R/is_terminal_node.R")
source("R/despota_evo.R")
source("R/mdist.R")
source("despota_GUDMM/GUDMM_Run.R")
source("despota_GUDMM/public_data.R")
tidymodels_prefer()
genom <- readRDS(file = "main_results/OutputCV1_PCAmerged.RDS")
genom_mix <- readRDS(file = "main_results/OutputCV1_PCAmixmerged.RDS")

1. On the equivalence of the bottom-up and top-down hierarchy levels

The Ward.D2 (and Ward.D) aggregation methods available in the standard hclust function are based on the Lance-Williams formula (which generalizes to any linkage function). To keep things simple, and to avoid computing the full hierarchy for each permutation, we use the Minimum Increase of Sum of Squares (MISSQ) computed at each split

\[ MISSQ:\frac{n_{L} \cdot n_{R}}{n_{L}+n_{R}}\|\mu_{L}-\mu_{R}\|^{2} = \sum_{{\bf x}_{i}\in (L\cup R)}\|{\bf x}_{i}-\mu_{(L\cup R)} \|^{2} - \sum_{{\bf x}_{i}\in (L)}\|{\bf x}_{i}-\mu_{(L)} \|^{2} - \sum_{{\bf x}_{i}\in (R)}\|{\bf x}_{i}-\mu_{(R)} \|^{2} \ \ \ \ (1) \] that is, the increase in the \(SS\) due to the merge.

The quantities in Equation (1) can also be computed via pairwise distances since, for a generic group \(A\), the following identity holds

\[ \sum_{x_{i}\in A}\|x_{i}-\mu_{A} \|^{2} = \frac{1}{n_{A}}\sum_{i=1}^{n_{A}}\sum_{j=1}^{n_{A}}\|x_{i}-x_{j} \|^{2} \ \ \ \ (2) \] where \(n_{A}\) indicates the group size.

  • To build the hierarchy, the right-hand side of Equation (2) is used.

  • When the increase in SS of a split has to be re-computed for the permutation test, the left-hand side of Equation (1) is the convenient form; a short numerical check of Equation (1) follows this list.
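As a quick sanity check (illustrative only, not part of the pipeline), the sketch below verifies Equation (1) on simulated data: the centroid-based expression on the left-hand side equals the increase in the within-group sum of squares caused by merging two groups.

set.seed(1)
# two small groups of bivariate points
grp_L = matrix(rnorm(10 * 2), ncol = 2)
grp_R = matrix(rnorm(6 * 2, mean = 2), ncol = 2)

# sum of squared deviations from the group centroid
ss = function(X) sum(scale(X, scale = FALSE)^2)

n_L = nrow(grp_L); n_R = nrow(grp_R)
mu_L = colMeans(grp_L); mu_R = colMeans(grp_R)

lhs = (n_L * n_R) / (n_L + n_R) * sum((mu_L - mu_R)^2)   # centroid form
rhs = ss(rbind(grp_L, grp_R)) - ss(grp_L) - ss(grp_R)    # increase in SS due to the merge
all.equal(lhs, rhs)                                      # TRUE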

Toy experiment: top-down computations of node heights

It is easy to compute \(h(L,R)\), \(h(L)\) and \(h(R)\) (the quantities one needs in DESPOTA) in a top-down approach: for the complete linkage, all it takes is the maximum distance overall and the maximum distances within the left- and right-hand side groups (\(L\) and \(R\)); a numerical check for this case is given after the toy dendrograms below.

The Ward-linkage case is less obvious, but it can be done; we show how in the following toy example.

  • Generate data and compute the pairwise Euclidean distances
set.seed(123)
toy_data=tibble(x1=round(runif(10,min=2,max=15)),
                x2=round(runif(10,min=1,max=4))
                )

toy_dist=dist(toy_data)
  • perform standard hclust on the toy data with:
    • ward linkage
    • complete linkage
    • single linkage
toy_hc_w = hclust(toy_dist,method="ward.D")
toy_hc_sing = hclust(toy_dist,method =  "single")
toy_hc_comp = hclust(toy_dist,method =  "complete")
  • extract from the hclust objects the data needed to draw the dendrograms
# ward
o_toy_hc_w = toy_hc_w
toy_hc_w = as.dendrogram(toy_hc_w)
toy_hc_w_ddata <- dendro_data(toy_hc_w, type = "rectangle")
# single
o_toy_hc_sing = toy_hc_sing
toy_hc_sing = as.dendrogram(toy_hc_sing)
toy_hc_sing_ddata <- dendro_data(toy_hc_sing, type = "rectangle")

# complete
o_toy_hc_comp = toy_hc_comp
toy_hc_comp = as.dendrogram(toy_hc_comp)
toy_hc_comp_ddata <- dendro_data(toy_hc_comp, type = "rectangle")
  • select the nodes of interest and highlight them in the dendrogram
hc_point_data_w = toy_hc_w_ddata$segments |> filter(y==max(y))

custom_hc_plot_w = toy_hc_w_ddata$segments |> ggplot() + 
  geom_segment(aes(x = x, y = y, xend = xend, yend = yend)) + 
  geom_segment(data=hc_point_data_w, 
               aes(x = x, y = y, xend = xend, yend = yend),
               alpha=.5,color="indianred",linewidth=4,
               inherit.aes = FALSE) +
  geom_point(data=hc_point_data_w, 
               aes(x = x, y = yend),color="dodgerblue",size=4,
               inherit.aes = FALSE)+
  theme_bw()+ggtitle("ward")

ggsave(file="MISSQ_comp.pdf",custom_hc_plot_w,height=7, width = 7)
hc_point_data_sing = toy_hc_sing_ddata$segments |> filter(y==max(y))

custom_hc_plot_sing = toy_hc_sing_ddata$segments |> ggplot() + 
  geom_segment(aes(x = x, y = y, xend = xend, yend = yend)) + 
  geom_segment(data=hc_point_data_sing, 
               aes(x = x, y = y, xend = xend, yend = yend),
               alpha=.5,color="indianred",linewidth=4,
               inherit.aes = FALSE) +
  geom_point(data=hc_point_data_sing, 
               aes(x = x, y = yend),color="dodgerblue",size=4,
               inherit.aes = FALSE)+
  theme_bw()+ggtitle("single")

hc_point_data_comp = toy_hc_comp_ddata$segments |> filter(y==max(y))

custom_hc_plot_comp = toy_hc_comp_ddata$segments |> ggplot() + 
  geom_segment(aes(x = x, y = y, xend = xend, yend = yend)) + 
  geom_segment(data=hc_point_data_comp, 
               aes(x = x, y = y, xend = xend, yend = yend),
               alpha=.5,color="indianred",linewidth=4,
               inherit.aes = FALSE) +
  geom_point(data=hc_point_data_comp, 
               aes(x = x, y = yend),color="dodgerblue",size=4,
               inherit.aes = FALSE)+
  theme_bw()+ggtitle("complete")
library(patchwork)
custom_hc_plot_w |  
custom_hc_plot_sing | 
custom_hc_plot_comp

par(mfrow=c(1,3))
plot(toy_hc_w);title("ward linkage")
plot(toy_hc_sing);title("single linkage")
plot(toy_hc_comp);title("complete linkage")
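As a quick check of the complete-linkage statement above (a sketch on the toy objects, not part of the pipeline): the root height equals the overall maximum distance, which is also the maximum distance between the two top-level groups, while the within-group maxima coincide with the heights of the root's two children visible in the dendrogram.

D_toy = as.matrix(toy_dist)
cl2_comp = cutree(o_toy_hc_comp, k = 2)       # the two groups below the root

max(o_toy_hc_comp$height)                     # root height
max(D_toy)                                    # overall maximum distance
max(D_toy[cl2_comp == 1, cl2_comp == 2])      # max distance between the two groups
max(D_toy[cl2_comp == 1, cl2_comp == 1])      # diameter of the first group
max(D_toy[cl2_comp == 2, cl2_comp == 2])      # diameter of the second group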

For the Ward linkage, the trick is to compute the quantity on the left-hand side of the MISSQ formula for the top node, \(MISSQ_{L\cup R}\), using the distances as in Equation (2), and to do the same one level down to obtain \(MISSQ_{L}\) and \(MISSQ_{R}\).

Here follow the computations by hand, as an illustrative example (the index sets below are read off the Ward dendrogram above).

L_set = c(5,8,2,4)
L_l_set = c(5,8)
L_r_set = c(2,4)

R_set = c(6,7,9,10,1,3)
R_l_set = c(6)
R_r_set = c(7,9,10,1,3)

within_L_R = ((1/length(L_set))*sum(as.matrix(toy_dist)[L_set, L_set])+
(1/length(R_set))*sum(as.matrix(toy_dist)[R_set,R_set]))

top_node = (1/nrow(toy_data))*sum(as.matrix(toy_dist))-within_L_R

within_L_lr = ((1/length(L_l_set))*sum(as.matrix(toy_dist)[L_l_set,L_l_set])+
(1/length(L_r_set))*sum(as.matrix(toy_dist)[L_r_set,L_r_set]))

L_node = (1/length(L_set))*sum(as.matrix(toy_dist)[L_set,L_set]) - within_L_lr

within_R_lr = ((1/length(R_l_set))*sum(as.matrix(toy_dist)[R_l_set,R_l_set])+
(1/length(R_r_set))*sum(as.matrix(toy_dist)[R_r_set,R_r_set]))

R_node = (1/length(R_set))*sum(as.matrix(toy_dist)[R_set,R_set]) - within_R_lr


tibble(node_name = c("top_node", "left_node","right_node"), 
       node_height=c(top_node, L_node, R_node), 
       hclust_nodes=o_toy_hc_w$height[c(9,6,8)]) |> kbl()
node_name    node_height  hclust_nodes
top_node       21.960782     21.960782
left_node       3.765029      3.765029
right_node      6.858485      6.858485

Top-down computations: penguins dataset

The function node_heights_compute takes as input the distance matrix, the labels defining the \(L\)/\(R\) split, and the labels defining the lower-level splits (the \(l\)/\(r\) subsets of the \(L\) set and of the \(R\) set, respectively).

We consider the Palmer Penguins dataset and compute the dendrogram with the ward.D2 linkage (the implementation consistent with Ward's original criterion; see Murtagh and Legendre).

pengs = palmerpenguins::penguins |> na.omit()|> 
  dplyr::select(where(is.numeric),-year) |> 
  mutate(across(.cols=everything(),~scale(.x)))

id_pengs= paste0("id_",1:nrow(pengs))

pengs_dist =  dist(pengs) 
attr(pengs_dist,"Labels")=id_pengs

hclust_solution = hclust(pengs_dist,method = "ward.D2")
plot(hclust_solution,cex=0.5)
rect.hclust(hclust_solution,k=4)

  • For the first split, we compute the \(rc_k\) test statistic: the \(F\) node is the root node of the dendrogram, and the left and right nodes are the third- and second-highest nodes of the tree, respectively.
top_nd=sort(hclust_solution$height,decreasing=T)[c(1)]
L_nd=sort(hclust_solution$height,decreasing=T)[c(3)]
R_nd=sort(hclust_solution$height,decreasing=T)[c(2)]

observed_rck = abs(L_nd-R_nd)/(top_nd-min(L_nd,R_nd))
observed_rck
## [1] 0.2292955
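For reference, the statistic computed in the chunk above is

\[ rc_k = \frac{|h(L) - h(R)|}{h(F) - \min\{h(L),\, h(R)\}} \]

where \(h(\cdot)\) denotes the node height.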

The heights of the top node (father) and of its children can also be obtained from the original pairwise distances, provided that the sets of observations originating from the \(L\)/\(R\) split (and from the splits one level below) are identified.

  • The \(L\) and \(R\) sets are obtained via a 2-cluster cut of the dendrogram.
hclust_clusters = cutree(hclust_solution,k=2)
LR_set = hclust_clusters
L_set=names(LR_set[which(LR_set==2)])
R_set=names(LR_set[which(LR_set==1)])
  • Similarly, the sets for \(L_l\), \(L_r\), \(R_l\) and \(R_r\) are obtained via a 4-cluster cut of the dendrogram. We manually re-aligned the clusters so that each subset is correctly associated with its label; a quick containment check follows the chunk below.
hclust_clusters_sub = cutree(hclust_solution,k=4)
lr_set = hclust_clusters_sub

L_l_set = names(lr_set[which(lr_set==2)])
L_r_set = names(lr_set[which(lr_set==3)])

R_l_set = names(lr_set[which(lr_set==4)])
R_r_set = names(lr_set[which(lr_set==1)])
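A quick containment check (illustrative only) confirms that the 4-cluster sets nest inside the corresponding 2-cluster sets:

all(L_l_set %in% L_set); all(L_r_set %in% L_set)   # expected TRUE if the alignment is correct
all(R_l_set %in% R_set); all(R_r_set %in% R_set)   # expected TRUE if the alignment is correct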
  • Given the sets and the original distance matrix, the function node_heights_compute is used to compute the \(F\), \(L\) and \(R\) node heights (the values of the linkage function computed on the corresponding sets).
out=node_heights_compute(distance_matrix = pengs_dist, method="ward.D2",
                         L_set = L_set,R_set = R_set,
                         L_l_set = L_l_set,
                         L_r_set = L_r_set,
                         R_l_set = R_l_set,
                         R_r_set = R_r_set)

Indeed, the node heights computed above correspond to those returned by hclust.

out |> mutate(hclust_height = sort(hclust_solution$height,decreasing=T)[c(1,3,2)] ) 
## # A tibble: 3 × 3
##   node     height hclust_height
##   <chr>     <dbl>         <dbl>
## 1 top_node   39.4          39.4
## 2 L_node     12.1          12.1
## 3 R_node     18.4          18.4
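As a further check (a sketch, assuming the column names shown in the output above), \(rc_k\) recomputed from the top-down node heights coincides with observed_rck:

h_top = out$height[out$node == "top_node"]
h_L   = out$height[out$node == "L_node"]
h_R   = out$height[out$node == "R_node"]
abs(h_L - h_R) / (h_top - min(h_L, h_R))   # matches observed_rck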

2. Application

The same data preparation and preprocessing pipeline as in Ellen et al. is implemented at this link.

The resulting dataset has been saved in .RDS format. Here follows the dimension reduction step.

Data re-coding and taming

genom <- readRDS("data/OutputCV1_train_FULL.RDS")

genom0 <- genom |>
  mutate(clin_gender = gender,
         clin_pack_years_smoked = ifelse(is.na(pack_years_smoked), median(pack_years_smoked,na.rm = TRUE),pack_years_smoked),
         clin_prior_malignancy = prior_malignancy,
         clin_volume = ifelse(is.na(volume), median(volume,na.rm = TRUE),volume),
         clin_censor_time1 = censor_time1,
         clin_ajcc_pathologic_stage = ifelse(is.na(ajcc_pathologic_stage), 4, ajcc_pathologic_stage),
         clin_vital_status = vital_status,
         clin_volume_low = ifelse(clin_volume <= 0.3200, 1, 2), # if volume <= median
         clin_pack_years_smoked_low = ifelse(clin_pack_years_smoked <= 30, 1, 2), # if packs smoked <= Q1
         clin_disease = as.factor(disease),
         clin_gender = as.factor(clin_gender),
         clin_prior_malignancy = as.factor(clin_prior_malignancy),
         clin_volume_low = as.factor(clin_volume_low),
         clin_pack_years_smoked_low = as.factor(clin_pack_years_smoked_low),
         clin_ajcc_pathologic_stage = as.factor(clin_ajcc_pathologic_stage),
         clin_vital_status = as.factor(clin_vital_status)
         ) |>
  dplyr::select(-gender,-pack_years_smoked,-prior_malignancy,-volume,-disease,
                -clin_pack_years_smoked,-clin_volume,
                -ajcc_pathologic_stage,-censor_time1,-vital_status
  )

clin0 <- genom0 |> dplyr::select(starts_with("clin_"))

Dimension reduction: principal component analysis

The continuous features are linearly combined via principal component analysis (PCA). The first 20 components are retained and replace the original continuous variables.

only_genom <- genom0 |> dplyr::select(-barcode,-starts_with("clin_"))
genomPCA <- PCA(only_genom,ncp = 20)
genomPCA_full <- cbind(genomPCA$ind$coord,clin0)


genom_full <- genom0
saveRDS(genom_full,file = "/OutputCV1_merged.RDS")
saveRDS(genomPCA_full,file = "/OutputCV1_PCAmerged.RDS")

As an alternative to the continuous-only dimension reduction step, a factor analysis for mixed data is applied to the full set of variables. We use the PCAmixdata package and keep the first 25 components: 20 for the continuous features (as in the PCA-based case), plus 5 extra components for the 5 categorical variables.

X <- genom0 |> 
  dplyr::select(-barcode,-clin_vital_status,-clin_censor_time1,-clin_disease) |> 
  splitmix()
X1 <- X$X.quanti
X2 <- X$X.quali

pcamix_mod <- PCAmix(X.quanti=X1, X.quali=X2,rename.level=TRUE,ndim = 40,
                     graph=FALSE)

genomPCAmix_full <- cbind(pcamix_mod$ind$coord[,1:30],clin0) 
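The retained variance can be inspected via the eigenvalue table returned by PCAmix (a quick illustrative check):

head(pcamix_mod$eig, 25)   # eigenvalues and cumulative percentage of explained variance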

Select the first 20 components together with the clinical factors


genom_red <- genom |>
  dplyr::select(Dim.1:Dim.20,
         clin_gender,clin_prior_malignancy,clin_volume_low,clin_pack_years_smoked_low,clin_ajcc_pathologic_stage) |>
  as_tibble()

Full dataset with no factors, to be used to compute the GUDMM distance


genom_red_nofct <- genom |> 
  mutate(
    across(starts_with("clin_"), ~ as.numeric(.x) )
  ) |>
  dplyr::select(Dim.1:Dim.20,
                clin_gender,clin_prior_malignancy,clin_volume_low,clin_pack_years_smoked_low,clin_ajcc_pathologic_stage) |>
  as_tibble()

Distance computations

# NOTE: before running this chunk, make sure reticulate and Python are properly installed on your machine.
# It is also important to point to the correct active Python binary via the following command:
# Sys.setenv(RETICULATE_PYTHON = "/usr/local/bin/python3")

GUDMM <- GUDMM_Run(genom_red_nofct, no_f_cont = 20, no_f_ord = 5)
dist_Gud <- as.dist(GUDMM$dist_mat)
attr(dist_Gud, "Labels") <- paste0("id_", 1:attr(dist_Gud, "Size"))

dist_Gow <- daisy(genom_red, metric = "gower")
attr(dist_Gow, "Labels") <- paste0("id_", 1:attr(dist_Gow, "Size"))

genom_continuos_reco <- recipe(data=genom_red,formula = ~.) |> 
  step_dummy(all_nominal(),one_hot = T) |> prep() |>
  bake(new_data=NULL)
dist_Euc <- daisy(genom_continuos_reco,metric="euclidean")
attr(dist_Euc, "Labels") <- paste0("id_", 1:attr(dist_Euc, "Size"))

genom_continuos_pcamix <- genom_mix |> 
  dplyr::select(1:25) 
dist_EucMix <- daisy(genom_continuos_pcamix,metric="euclidean")
attr(dist_EucMix, "Labels") <- paste0("id_", 1:attr(dist_EucMix, "Size"))
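A quick sanity check (illustrative only) that the four distance objects refer to the same number of observations:

sapply(list(gudmm = dist_Gud, gower = dist_Gow,
            euclidean = dist_Euc, pcamix = dist_EucMix),
       function(d) attr(d, "Size"))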

Run the simulations for analysis


# tidymodels_prefer()

despota_str_k1 <- expand_grid(
  distances = list(list(dist_Euc,"euclidean"),list(dist_EucMix,"PCAmix"),list(dist_Gow,"gower"),list(dist_Gud,"gudmm")),
  aggMethod = c("ward.D2"),
  M = 499, 
  alpha =  c(seq(0.01,0.20,by = 0.005),seq(0.3,0.5,by = 0.1)) 
  ) |> 
  mutate(
    seed = sample(1000:9999, n(), replace = FALSE),
    id = 1:n(),
    hclust = pmap(list(a = distances, b = aggMethod, e = id), 
                  .f = function(a,b,e){
                    # cat("HCLUST loop",e,"of",n(),"\n")
                    OUT <- hclust(d = a[[1]], method = b)
                    OUT
      }),
    despota = pmap(list(a = distances, b = aggMethod, c = hclust, mm = M,
              d = alpha, e = id), 
         .f = function(a,b,c,mm,d,e){
           cat("DESPOTA EVO loop ",e," of ",n()," (",a[[2]],", alpha = ",d,") \n", sep = "")
           start <- Sys.time()
           mod <- despota_evo(distance_matrix=a[[1]],hclust_solution=c,linkage_method=b,
                        top_down = TRUE, M=mm, alpha=d)
           OUT = list(
             "model" = mod,
             "labels" = mod |> 
                          unnest(cl_members) |> 
                          select(cluster)
             )
           cat(" ",format(Sys.time() - start)," --------------------- \n")
           
           OUT
           }),
    dist = map_chr(distances, ~ .x[[2]])
    )

despota_str_k1_sel <- despota_str_k1 |> 
  mutate(
     desp_nclus = map_dbl(despota, ~ length(table(.x$labels)))
     ) |> 
  group_by(dist,aggMethod,M,desp_nclus) |> 
  slice(which.min(alpha)) |> 
  ungroup() |> 
  filter(desp_nclus > 1) |> 
  select(-desp_nclus)

dim(despota_str_k1_sel);dim(despota_str_k1)
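To inspect which configurations were retained, one can list the smallest alpha kept for each distance and number of clusters (an illustrative check that mirrors the desp_nclus computation above):

despota_str_k1_sel |>
  mutate(k = map_dbl(despota, ~ length(table(.x$labels)))) |>
  dplyr::select(dist, aggMethod, alpha, k) |>
  arrange(dist, k)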

Comparisons

Simulation scheme for the paper


despota_eval1 <- despota_str_k1_sel |> mutate(
  despota_lab = map(despota, ~ .x$labels |>  pull(cluster)),
  despota_k = map_dbl(despota_lab,~ length(table(.x))),
  hier_lab2 = map2(hclust, despota_k, ~ cutree(.x, k = .y)),
  # ---------------------------------------------
  CST_d = map2(distances,despota_lab, ~ cluster.stats(.x[[1]],.y)),
  CST_h2 = map2(distances,hier_lab2, ~ cluster.stats(.x[[1]],.y)),
  ASW_d = map_dbl(CST_d, ~ .x$avg.silwidth),
  ASW_h2 = map_dbl(CST_h2, ~ .x$avg.silwidth)
) |> 
  select(-distances, -despota, -hclust, -CST_d, -CST_h2)
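The figure code further below also recodes "Calinski-Harabasz" and "Dunn" labels; those indices come from the same cluster.stats() objects. A minimal sketch of how they can be extracted for a single configuration (illustrative only, assuming the rows of despota_eval1 still align with despota_str_k1_sel, as in the chunk above):

cst_one <- cluster.stats(despota_str_k1_sel$distances[[1]][[1]],
                         despota_eval1$despota_lab[[1]])
cst_one$ch            # Calinski-Harabasz index
cst_one$dunn2         # Dunn index
cst_one$avg.silwidth  # average silhouette width, as in ASW_d above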


mat_sil <- despota_eval1 %>% 
  select(ASW_d,ASW_h2,dist,aggMethod,M,alpha,despota_k)  %>% 
  pivot_longer(cols = c(ASW_d,ASW_h2),names_to = "fit_type",values_to = "fit_vals") %>% 
    mutate(
      method = "Avg. Silhouette",
      fit_type  =recode(fit_type,
        "ASW_d" = "DESPOTA",
        "ASW_h2" = "HClust")
      )


plotFIN2 <- mat_sil |> 
  mutate(
    method = recode_factor(method,
                                "Avg. Silhouette" = "Avg. Silhouette",
                                "Calinski-Harabasz" = "Calinski-Harabasz",
                                "DUNN2" = "Dunn",
                                "DUNN" = "DUNN"),
    dist = recode_factor(dist,
                                "euclidean" = "Euclidean",
                                "gower" = "Gower-based",
                                "PCAmix" = "FAMD-based",
                                "gudmm" = "Entropy-based",
                         .ordered = TRUE),
    fit_type = recode_factor(fit_type,
                                "DESPOTA" = "DESPOTA",
                                "HClust" = "HCl")
         ) |>  
  filter(despota_k >= 2 & method %in% c("Avg. Silhouette")) |>  #aggMethod == "ward.D") |>  # & 
  ggplot(aes(x = despota_k, y  = fit_vals, colour = fit_type, group = fit_type, linetype = fit_type)) +
  geom_point() +
  geom_line() +
  facet_grid(. ~ dist) +
  xlab("k") +
  ylab("Avg. Silhouette") +
  theme_bw() +
  scale_x_continuous(breaks =seq(0, 30, by = 5)) +
  scale_colour_manual(values = c("#F75836","#3639F7")) +
  theme(legend.position = "bottom", legend.title = element_blank())

plotFIN2

# ggsave("gof_sil.pdf",
#        width = 6, height = 4.5)

DFsurv1 <- surv1 |> 
  select(id,dist,aggMethod,despota_k,despota_lab,hier_lab2) |>
  mutate(
    survival = rep(list(genom$clin_censor_time1),nrow(surv1)),
    deaths = rep(list(genom$clin_vital_status),nrow(surv1)),
    deaths = map(deaths, ~ recode(.x,
                "Dead" = 1, 
                "Alive" = 0
                )),
    disease = rep(list(genom$clin_disease),nrow(surv1)),
    memb_d = despota_lab,
    memb_h = surv1$hier_lab2,
    surv_d = pmap(list(sur = survival, dea = deaths, mem = memb_d), function(sur, dea, mem){ Surv(sur,dea) ~ mem}),
    surv_h = pmap(list(sur = survival, dea = deaths, mem = memb_h), function(sur, dea, mem){ Surv(sur,dea) ~ mem}),
    sdiff_d = map(surv_d, ~ survdiff(as.vector(.x))),
    sdiff_h = map(surv_h, ~ survdiff(as.vector(.x))),
    pval_d = map_dbl(sdiff_d, ~ .x$pvalue),
    pval_h = map_dbl(sdiff_h, ~ .x$pvalue)
)
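The log-rank p-values for the DESPOTA and horizontal-cut partitions can then be compared directly (illustrative inspection):

DFsurv1 |> dplyr::select(dist, despota_k, pval_d, pval_h)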

# Entropy-based with 4 clusters
plotsur1 <- DFsurv1 |> 
  filter(dist == "gudmm" & aggMethod == "ward.D2" & despota_k == 4) |> 
    mutate(
      surv_d2 = pmap(list(sur = survival, dea = deaths, mem = memb_d, dis = disease), function(sur, dea, mem, dis){ Surv(sur,dea) ~ mem + dis}),
    surv_h2 = pmap(list(sur = survival, dea = deaths, mem = memb_h, dis = disease), function(sur, dea, mem, dis){ Surv(sur,dea) ~ mem + dis}),
    fit_d = map(surv_d2, ~ survfit(.x)),
    fit_h = map(surv_h2, ~ survfit(.x)),
    gpd = map(surv_d, ~ survfit(.x) |>
        ggsurvfit() +
        theme(legend.position = "none")
  ),
    gph = map(surv_h, ~ survfit(.x) |>
        ggsurvfit() +
        theme(legend.position = "none")
  )
  ) |> select(gpd,gph,despota_lab,hier_lab2,surv_d,surv_h,sdiff_d,sdiff_h,fit_d,fit_h)  
table(plotsur1$despota_lab[[1]])
## 
##   1   2   3   4 
## 537  17  24   8
table(plotsur1$hier_lab2[[1]])
## 
##   1   2   3   4 
## 312 225  17  32

print(survfit(plotsur1$surv_d[[1]]), rmean= 730, scale=365)
## Call: survfit(formula = plotsur1$surv_d[[1]])
## 
##         n events rmean* se(rmean) median 0.95LCL 0.95UCL
## mem=1 537    182  1.793    0.0214  4.756   4.195    5.96
## mem=2  17     17  0.839    0.1541  0.712   0.326    1.65
## mem=3  24     21  1.067    0.1270  0.901   0.647    1.92
## mem=4   8      6  1.268    0.2005  1.189   0.970      NA
##     * restricted mean with upper limit =  2
print(survfit(plotsur1$surv_h[[1]]), rmean= 730, scale=365)
## Call: survfit(formula = plotsur1$surv_h[[1]])
## 
##         n events rmean* se(rmean) median 0.95LCL 0.95UCL
## mem=1 312     92  1.841    0.0244  6.258   5.014    9.21
## mem=2 225     90  1.724    0.0381  3.279   2.923    4.38
## mem=3  17     17  0.839    0.1541  0.712   0.326    1.65
## mem=4  32     27  1.119    0.1094  0.970   0.704    1.51
##     * restricted mean with upper limit =  2
print(plotsur1$fit_d[[1]], rmean= 730)
## Call: survfit(formula = .x)
## 
##                   n events rmean* se(rmean) median 0.95LCL 0.95UCL
## mem=1, dis=LUAD 296     94    659      10.2   1632    1498    2617
## mem=1, dis=LUSC 241     88    649      12.1   1912    1640    2625
## mem=2, dis=LUAD   9      9    363      76.4    275     187      NA
## mem=2, dis=LUSC   8      8    243      77.1    148      52      NA
## mem=3, dis=LUAD  15     13    417      60.4    340     250      NA
## mem=3, dis=LUSC   9      8    339      67.3    282     198      NA
## mem=4, dis=LUAD   7      5    513      64.3    447     370      NA
## mem=4, dis=LUSC   1      1    113       0.0    113      NA      NA
##     * restricted mean with upper limit =  730
print(plotsur1$fit_h[[1]], rmean= 730)
## Call: survfit(formula = .x)
## 
##                   n events rmean* se(rmean) median 0.95LCL 0.95UCL
## mem=1, dis=LUAD 162     46    680      11.2   2617    1653      NA
## mem=1, dis=LUSC 150     46    664      13.9   2284    1841      NA
## mem=2, dis=LUAD 134     48    633      17.9   1288    1171    1622
## mem=2, dis=LUSC  91     42    625      22.2   1107     927    2378
## mem=3, dis=LUAD   9      9    363      76.4    275     187      NA
## mem=3, dis=LUSC   8      8    243      77.1    148      52      NA
## mem=4, dis=LUAD  22     18    448      47.5    434     340    1115
## mem=4, dis=LUSC  10      9    314      64.3    236     166      NA
##     * restricted mean with upper limit =  730
summary(plotsur1$fit_d[[1]],time=(0:4)*182.5, scale=365)
## Call: survfit(formula = .x)
## 
##                 mem=1, dis=LUAD 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0    296       0    1.000  0.0000        1.000        1.000
##   0.5    265      14    0.951  0.0127        0.927        0.976
##   1.0    239      13    0.903  0.0178        0.869        0.938
##   1.5    183      11    0.858  0.0215        0.817        0.901
##   2.0    123      11    0.800  0.0263        0.750        0.853
## 
##                 mem=1, dis=LUSC 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0    241       0    1.000  0.0000        1.000        1.000
##   0.5    208      13    0.944  0.0152        0.914        0.974
##   1.0    191      11    0.893  0.0207        0.853        0.934
##   1.5    155      13    0.828  0.0259        0.779        0.880
##   2.0    125       9    0.775  0.0296        0.719        0.836
## 
##                 mem=2, dis=LUAD 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0      9       0    1.000   0.000       1.0000        1.000
##   0.5      7       2    0.778   0.139       0.5485        1.000
##   1.0      3       4    0.333   0.157       0.1323        0.840
##   1.5      3       0    0.333   0.157       0.1323        0.840
##   2.0      1       2    0.111   0.105       0.0175        0.705
## 
##                 mem=2, dis=LUSC 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0      8       0    1.000   0.000        1.000        1.000
##   0.5      4       4    0.500   0.177        0.250        1.000
##   1.0      3       1    0.375   0.171        0.153        0.917
##   1.5      1       2    0.125   0.117        0.020        0.782
## 
##                 mem=3, dis=LUAD 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0     15       0    1.000   0.000       1.0000        1.000
##   0.5     12       3    0.800   0.103       0.6212        1.000
##   1.0      7       5    0.467   0.129       0.2717        0.802
##   1.5      4       2    0.320   0.124       0.1498        0.684
##   2.0      3       1    0.240   0.116       0.0931        0.619
## 
##                 mem=3, dis=LUSC 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0      9       0    1.000   0.000       1.0000        1.000
##   0.5      6       2    0.750   0.153       0.5027        1.000
##   1.0      2       4    0.250   0.153       0.0753        0.830
##   1.5      2       0    0.250   0.153       0.0753        0.830
##   2.0      1       1    0.125   0.117       0.0200        0.782
## 
##                 mem=4, dis=LUAD 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0      7       0    1.000   0.000        1.000            1
##   0.5      7       0    1.000   0.000        1.000            1
##   1.0      5       1    0.833   0.152        0.583            1
##   1.5      2       3    0.333   0.192        0.108            1
##   2.0      2       0    0.333   0.192        0.108            1
## 
##                 mem=4, dis=LUSC 
##         time       n.risk      n.event     survival      std.err lower 95% CI 
##            0            1            0            1            0            1 
## upper 95% CI 
##            1
summary(plotsur1$fit_h[[1]],time=(0:4)*182.5, scale=365)
## Call: survfit(formula = .x)
## 
##                 mem=1, dis=LUAD 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0    162       0    1.000  0.0000        1.000        1.000
##   0.5    149       4    0.974  0.0126        0.950        0.999
##   1.0    137       6    0.934  0.0201        0.896        0.975
##   1.5    109       5    0.899  0.0249        0.851        0.949
##   2.0     77       7    0.834  0.0330        0.772        0.902
## 
##                 mem=1, dis=LUSC 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0    150       0    1.000  0.0000        1.000        1.000
##   0.5    133       7    0.951  0.0180        0.917        0.987
##   1.0    126       4    0.922  0.0225        0.879        0.968
##   1.5    105       9    0.854  0.0303        0.796        0.915
##   2.0     82       5    0.807  0.0351        0.741        0.879
## 
##                 mem=2, dis=LUAD 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0    134       0    1.000  0.0000        1.000        1.000
##   0.5    116      10    0.923  0.0234        0.878        0.970
##   1.0    102       7    0.865  0.0306        0.807        0.927
##   1.5     74       6    0.808  0.0364        0.739        0.882
##   2.0     46       4    0.759  0.0416        0.681        0.845
## 
##                 mem=2, dis=LUSC 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0     91       0    1.000  0.0000        1.000        1.000
##   0.5     75       6    0.931  0.0270        0.880        0.986
##   1.0     65       7    0.842  0.0403        0.767        0.925
##   1.5     50       4    0.784  0.0470        0.697        0.881
##   2.0     43       4    0.718  0.0532        0.621        0.831
## 
##                 mem=3, dis=LUAD 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0      9       0    1.000   0.000       1.0000        1.000
##   0.5      7       2    0.778   0.139       0.5485        1.000
##   1.0      3       4    0.333   0.157       0.1323        0.840
##   1.5      3       0    0.333   0.157       0.1323        0.840
##   2.0      1       2    0.111   0.105       0.0175        0.705
## 
##                 mem=3, dis=LUSC 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0      8       0    1.000   0.000        1.000        1.000
##   0.5      4       4    0.500   0.177        0.250        1.000
##   1.0      3       1    0.375   0.171        0.153        0.917
##   1.5      1       2    0.125   0.117        0.020        0.782
## 
##                 mem=4, dis=LUAD 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0     22       0    1.000  0.0000        1.000        1.000
##   0.5     19       3    0.864  0.0732        0.732        1.000
##   1.0     12       6    0.578  0.1074        0.401        0.832
##   1.5      6       5    0.330  0.1044        0.178        0.614
##   2.0      5       1    0.275  0.1005        0.134        0.563
## 
##                 mem=4, dis=LUSC 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0     10       0    1.000   0.000       1.0000        1.000
##   0.5      6       3    0.667   0.157       0.4200        1.000
##   1.0      2       4    0.222   0.139       0.0655        0.754
##   1.5      2       0    0.222   0.139       0.0655        0.754
##   2.0      1       1    0.111   0.105       0.0175        0.705



# figure Entropy-based
ggpubr::ggarrange(
    plotsur1$gpd[[1]] + ggtitle("DESPOTA"),
    plotsur1$gph[[1]] + ggtitle("HCl") + ylab(NULL),
    nrow = 1)

# ggsave("survival_gudmm.pdf",
#        width = 8, height = 4)




# EUCLIDEAN and 4 clusters
plotsur2 <- DFsurv1 |> 
  filter(dist == "euclidean" & aggMethod == "ward.D2" & despota_k == 4) |> 
      mutate(
    surv_d2 = pmap(list(sur = survival, dea = deaths, mem = memb_d, dis = disease), function(sur, dea, mem, dis){ Surv(sur,dea) ~ mem + dis}),
    surv_h2 = pmap(list(sur = survival, dea = deaths, mem = memb_h, dis = disease), function(sur, dea, mem, dis){ Surv(sur,dea) ~ mem + dis}),
    fit_d = map(surv_d2, ~ survfit(.x)),
    fit_h = map(surv_h2, ~ survfit(.x)),
    gpd = map(surv_d, ~ survfit(.x) |>
        ggsurvfit() +
        theme(legend.position = "none")
  ),
    gph = map(surv_h, ~ survfit(.x) |>
        ggsurvfit() +
        theme(legend.position = "none")
  )
  ) |> select(gpd,gph,despota_lab,hier_lab2,surv_d,surv_h,sdiff_d,sdiff_h,fit_d,fit_h)  
table(plotsur2$despota_lab[[1]])
## 
##   1   2   3   4 
##   1 469  93  23

print(survfit(plotsur2$surv_d[[1]]), rmean= 730, scale=365)
## Call: survfit(formula = plotsur2$surv_d[[1]])
## 
##         n events rmean* se(rmean) median 0.95LCL 0.95UCL
## mem=1   1      1  0.249    0.0000  0.249      NA      NA
## mem=2 469    150  1.809    0.0221  5.044   4.529    7.17
## mem=3  93     58  1.397    0.0721  1.715   1.282    2.73
## mem=4  23     17  1.429    0.1439  1.923   0.932    3.68
##     * restricted mean with upper limit =  2
print(plotsur2$fit_d[[1]], rmean= 730)
## Call: survfit(formula = .x)
## 
##                   n events rmean* se(rmean) median 0.95LCL 0.95UCL
## mem=1, dis=LUAD   1      1     91       0.0     91      NA      NA
## mem=2, dis=LUAD 238     68    675      10.1   1790    1531    3361
## mem=2, dis=LUSC 231     82    646      12.6   1984    1679    2625
## mem=3, dis=LUAD  81     46    538      27.1    807     500    1197
## mem=3, dis=LUSC  12     12    333      70.7    270     161      NA
## mem=4, dis=LUAD   7      6    506      83.6    624     260      NA
## mem=4, dis=LUSC  16     11    530      66.1    740     408      NA
##     * restricted mean with upper limit =  730
summary(plotsur2$fit_d[[1]],time=(0:4)*182.5, scale=365)
## Call: survfit(formula = .x)
## 
##                 mem=1, dis=LUAD 
##         time       n.risk      n.event     survival      std.err lower 95% CI 
##            0            1            0            1            0            1 
## upper 95% CI 
##            1 
## 
##                 mem=2, dis=LUAD 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0    238       0    1.000  0.0000        1.000        1.000
##   0.5    217       8    0.965  0.0120        0.942        0.989
##   1.0    198       8    0.929  0.0172        0.896        0.963
##   1.5    151       8    0.888  0.0216        0.847        0.932
##   2.0    103       8    0.835  0.0275        0.782        0.890
## 
##                 mem=2, dis=LUSC 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0    231       0    1.000  0.0000        1.000        1.000
##   0.5    198      14    0.937  0.0164        0.905        0.969
##   1.0    182      10    0.889  0.0215        0.847        0.932
##   1.5    148      12    0.826  0.0265        0.776        0.880
##   2.0    117      10    0.765  0.0308        0.707        0.828
## 
##                 mem=3, dis=LUAD 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0     81       0    1.000  0.0000        1.000        1.000
##   0.5     67      10    0.872  0.0379        0.800        0.949
##   1.0     52      12    0.709  0.0525        0.613        0.819
##   1.5     37       8    0.593  0.0577        0.490        0.718
##   2.0     25       4    0.523  0.0606        0.417        0.657
## 
##                 mem=3, dis=LUSC 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0     12       0    1.000   0.000       1.0000        1.000
##   0.5      8       4    0.667   0.136       0.4468        0.995
##   1.0      4       4    0.333   0.136       0.1498        0.742
##   1.5      3       1    0.250   0.125       0.0938        0.666
##   2.0      2       1    0.167   0.108       0.0470        0.591
## 
##                 mem=4, dis=LUAD 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0      7       0    1.000   0.000        1.000            1
##   0.5      7       0    1.000   0.000        1.000            1
##   1.0      4       3    0.571   0.187        0.301            1
##   1.5      4       0    0.571   0.187        0.301            1
##   2.0      1       2    0.214   0.178        0.042            1
## 
##                 mem=4, dis=LUSC 
##  time n.risk n.event survival std.err lower 95% CI upper 95% CI
##   0.0     16       0    1.000  0.0000        1.000        1.000
##   0.5     12       2    0.867  0.0878        0.711        1.000
##   1.0     10       2    0.722  0.1185        0.524        0.996
##   1.5      7       2    0.562  0.1361        0.349        0.903
##   2.0      7       0    0.562  0.1361        0.349        0.903


# figure Euclidean
ggpubr::ggarrange(
    plotsur2$gpd[[1]] +
      ggtitle("DESPOTA"),
    plotsur2$gph[[1]] +
      ggtitle("HCl") + ylab(NULL),
    nrow = 1)

# ggsave("survival_euclidean.pdf",
#        width = 8, height = 4)

Comparison between the Euclidean and Entropy-based solutions



fin_clust1d <- genom_red |> 
  select(-starts_with("Dim")) |> 
  mutate(
    d_lab = as_factor(plotsur1$despota_lab[[1]])
  )

fin_clust1h <- genom_red |> 
  select(-starts_with("Dim")) |> 
  mutate(
    d_lab = as_factor(plotsur1$hier_lab2[[1]])
  )

fin_clust2 <- genom_red |> 
  select(-starts_with("Dim")) |> 
  mutate(
    d_lab = as_factor(plotsur2$despota_lab[[1]])
  )



adjustedRandIndex(unlist(plotsur1$despota_lab),unlist(plotsur1$hier_lab2))
## [1] 0.2469297
adjustedRandIndex(unlist(plotsur2$despota_lab),unlist(plotsur2$hier_lab2))
## [1] 1
adjustedRandIndex(unlist(plotsur1$despota_lab),unlist(plotsur2$despota_lab))
## [1] 0.3779116
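The cross-tabulations behind these ARI values can be inspected directly (illustrative):

table(entropy_DESPOTA = unlist(plotsur1$despota_lab),
      entropy_HCl     = unlist(plotsur1$hier_lab2))
table(entropy_DESPOTA = unlist(plotsur1$despota_lab),
      euclid_DESPOTA  = unlist(plotsur2$despota_lab))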

# V-tests
table(fin_clust1d$d_lab)
## 
##   1   2   3   4 
## 537  17  24   8
catdes1d <- fin_clust1d |> 
  mutate(
    gender = recode_factor(clin_gender,
                           "1" = "Female",
                           "2" = "Male"),
    "light smoker" = recode_factor(clin_pack_years_smoked_low,
                                   "1" = "True",
                                   "2" = "False"),
    "prior malignancy" = recode_factor(clin_prior_malignancy,
                                       "1" = "No",
                                       "2" = "Yes"),
    "low cancer volume" = recode_factor(clin_volume_low,
                                        "1" = "True",
                                        "2" = "False"),
    disease = genom$clin_disease,
    stage = recode_factor(clin_ajcc_pathologic_stage,
                  "1" = "I",
                  "2" = "I",
                  "3" = "I",
                  "4" = "II",
                  "5" = "II",
                  "6" = "II",
                  "7" = "III",
                  "8" = "III",
                  "9" = "III",
                  "10" = "IV"
                  )
  ) |>
  select(d_lab,gender,"light smoker", "prior malignancy" ,"low cancer volume","disease", "stage") |> 
FactoMineR::catdes(num.var = 1, proba=1)


table(fin_clust1h$d_lab)
## 
##   1   2   3   4 
## 312 225  17  32
catdes1h <- fin_clust1h |> 
  mutate(
    gender = recode_factor(clin_gender,
                           "1" = "Female",
                           "2" = "Male"),
    "light smoker" = recode_factor(clin_pack_years_smoked_low,
                                   "1" = "True",
                                   "2" = "False"),
    "prior malignancy" = recode_factor(clin_prior_malignancy,
                                       "1" = "No",
                                       "2" = "Yes"),
    "low cancer volume" = recode_factor(clin_volume_low,
                                        "1" = "True",
                                        "2" = "False"),
    disease = genom$clin_disease,
    stage = recode_factor(clin_ajcc_pathologic_stage,
                  "1" = "I",
                  "2" = "I",
                  "3" = "I",
                  "4" = "II",
                  "5" = "II",
                  "6" = "II",
                  "7" = "III",
                  "8" = "III",
                  "9" = "III",
                  "10" = "IV"
                  )
  ) |>
  select(d_lab,gender,"light smoker", "prior malignancy" ,"low cancer volume","disease", "stage") |> 
FactoMineR::catdes(num.var = 1, proba=1)


table(fin_clust2$d_lab)
## 
##   1   2   3   4 
##   1 469  93  23
catdes2 <- fin_clust2 |> 
  mutate(
    gender = recode_factor(clin_gender,
                           "1" = "Female",
                           "2" = "Male"),
    "light smoker" = recode_factor(clin_pack_years_smoked_low,
                                   "1" = "True",
                                   "2" = "False"),
    "prior malignancy" = recode_factor(clin_prior_malignancy,
                                       "1" = "No",
                                       "2" = "Yes"),
    "low cancer volume" = recode_factor(clin_volume_low,
                                        "1" = "True",
                                        "2" = "False"),
    disease = genom$clin_disease,
    stage = recode_factor(clin_ajcc_pathologic_stage,
                  "1" = "I",
                  "2" = "I",
                  "3" = "I",
                  "4" = "II",
                  "5" = "II",
                  "6" = "II",
                  "7" = "III",
                  "8" = "III",
                  "9" = "III",
                  "10" = "IV"
                  )
  ) |>
  select(d_lab,gender,"light smoker", "prior malignancy" ,"low cancer volume","disease", "stage") |> 
FactoMineR::catdes(num.var = 1, proba=1)
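Each catdes object stores, for every cluster, the category-level v-tests and p-values that feed the heatmap below; for instance (illustrative):

head(catdes1d$category[[1]])   # v-tests for the first DESPOTA cluster (entropy-based)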

# data in grid version

table(fin_clust1d$d_lab) # new order: 1-3-2-4
## 
##   1   2   3   4 
## 537  17  24   8
table(fin_clust1h$d_lab) # new order: 1-2-4-3
## 
##   1   2   3   4 
## 312 225  17  32
table(fin_clust2$d_lab)  # new order: 2-3-4-1
## 
##   1   2   3   4 
##   1 469  93  23

# final heatmap
tibble(rbind(
  catdes1d$category[[1]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 1, method = "Entropy-based\n (DESPOTA)") ,
  catdes1d$category[[2]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 3, method = "Entropy-based\n (DESPOTA)") ,
  catdes1d$category[[3]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 2, method = "Entropy-based\n (DESPOTA)") ,
  catdes1d$category[[4]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 4, method = "Entropy-based\n (DESPOTA)") ,
  # 
  catdes1h$category[[1]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 1, method = "Entropy-based\n (Horizontal)") ,
  catdes1h$category[[2]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 2, method = "Entropy-based\n (Horizontal)") ,
  catdes1h$category[[3]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 4, method = "Entropy-based\n (Horizontal)") ,
  catdes1h$category[[4]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 3, method = "Entropy-based\n (Horizontal)") ,
  # 
  catdes2$category[[1]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 2, method = "Euclidean") ,
  catdes2$category[[2]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 3, method = "Euclidean"),
  catdes2$category[[3]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 4, method = "Euclidean"),
  catdes2$category[[4]] |> as.data.frame() |>  rownames_to_column() |> arrange(rowname) |> mutate(cluster = 1, method = "Euclidean")
)
) |> select(-"Cla/Mod",-"Mod/Cla",-"Global") |> 
  filter(rowname %in% c("stage=IV","stage=III","stage=II","stage=I","low.cancer.volume=True","prior.malignancy=Yes","light.smoker=True","gender=Male","gender=Female","disease=LUSC","disease=LUAD")) |> 
  mutate(
    p.value2 = cut(p.value, breaks=c(0, 0.001, 0.01, 0.05, 0.1, 1),
                       labels=c('***', '**', '*', '.', " "))
  ) |> 
  ggplot(aes(x = cluster, y = rowname, label = p.value2, fill = v.test)) +
  geom_tile() +
  geom_text() +
  facet_grid(. ~ method) +
  scale_fill_gradientn(colours = c("#F62D2D","white","#1034A6")) +
  ylab(NULL) +
  theme_minimal() +
  theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank(), strip.text = element_text(face = "bold"))

# ggsave("vtest_fig.pdf",
#        width = 7, height = 4)