Introduction

scSurvival is a new, scalable, and interpretable tool specifically designed for survival analysis from single-cell cohort data. It first employs a feature extraction module based on a variational autoencoder and generative modeling to learn batch-invariant single-cell representations, and then aggregates cell-level features to the patient level to perform multi-head attention-based multiple instance Cox regression. scSurvival not only enables the integration of single-cell expression data with patient-level clinical variables to build accurate survival risk prediction models, but also identifies key cell subpopulations most associated with survival risk and characterizes their risk tendencies, thereby facilitating more refined downstream analyses.
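To make the aggregation step concrete, the sketch below illustrates the general idea of attention-based multiple-instance Cox regression in PyTorch: cell-level embeddings are weighted by learned attention scores, pooled into a patient-level representation, and mapped to a log-hazard trained with a Cox partial-likelihood loss. This is a conceptual sketch with made-up layer sizes, not the scSurvival implementation, which combines multi-head attention with the VAE-based feature extractor and clinical covariates described above.

# Conceptual sketch only (hypothetical layer sizes), not the scSurvival implementation:
# attention-based multiple-instance pooling of cell embeddings into a patient-level
# log-hazard, trained with a Cox partial-likelihood loss.
import torch
import torch.nn as nn

class AttentionMILCox(nn.Module):
    def __init__(self, in_dim=32, att_dim=16):
        super().__init__()
        # scores how much each cell contributes to the patient-level risk
        self.attention = nn.Sequential(
            nn.Linear(in_dim, att_dim), nn.Tanh(), nn.Linear(att_dim, 1))
        # maps the pooled patient embedding to a scalar log-hazard
        self.hazard = nn.Linear(in_dim, 1, bias=False)

    def forward(self, cells):                            # cells: (n_cells, in_dim), one patient
        a = torch.softmax(self.attention(cells), dim=0)  # (n_cells, 1) attention weights
        patient_emb = (a * cells).sum(dim=0)             # attention-weighted pooling over cells
        return self.hazard(patient_emb).squeeze(), a     # patient log-hazard, cell attentions

def neg_cox_partial_log_likelihood(log_hazard, time, status):
    # negative Cox partial log-likelihood over a mini-batch of patients (ties ignored)
    order = torch.argsort(time, descending=True)         # risk sets by decreasing survival time
    lh, st = log_hazard[order], status[order]
    log_risk = torch.logcumsumexp(lh, dim=0)             # log of the hazard sum in each risk set
    return -((lh - log_risk) * st).sum() / st.sum().clamp(min=1.0)

In scSurvival, the attention weights produced by this kind of pooling are exposed per cell (the attention column used throughout this tutorial), which is what makes the model interpretable at the cell-subpopulation level.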

0. Installation

If the “scSurvival” package is not yet installed, you can install it by running the following command in a terminal:

pip install -e ..

The scSurvival package is developed in Python. If you want to use it in R (version ≥ 4.1.0), you can do so with the help of the reticulate package. First, let’s load the required R packages:

rm(list=ls())
library(Seurat)
library(reticulate)
library(ggplot2)
library(gridExtra)
library(ggpubr)
library(dplyr)
library(splatter)
library(caret)
library(survival)
library(Hmisc)
library(survminer)
# set.seed(1)

Users can select the Python environment for calling scSurvival by clicking “Tools -> Global Options -> Python” in RStudio, or set it directly with reticulate:

use_python("xxx/python") #for python
use_virtualenv("xxx") #for virtual environment
use_condaenv("xxx") #for conda environment

Check whether the reticulate package is functioning properly.

print('hello python')
## hello python

scSurvival examples

The input data for scSurvival consist of a single‐cell cohort dataset—each patient is represented by a single‐cell gene‐expression matrix together with their corresponding survival information. Moreover, the single‐cell samples may be drawn from multiple batches. In this tutorial, we use the Splatter package to generate two simulated datasets (one without batch effects and one with batch effects) and demonstrate how to apply scSurvival for analysis.
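Concretely, scSurvival expects a cell-by-gene AnnData object whose obs contains a sample column mapping each cell to its patient, together with a survival DataFrame indexed by the same patient IDs and holding time and status columns. The snippet below is only a schematic with toy values (random expression, hypothetical patient IDs); the rest of the tutorial constructs these objects from the simulated data.

# Schematic of the inputs used throughout this tutorial (toy values only)
import numpy as np
import pandas as pd
import scanpy as sc

# cell-by-gene matrix: 3 patients with 1,000 cells each, 2,000 genes
X = np.random.rand(3000, 2000).astype(np.float32)
obs = pd.DataFrame({'sample': np.repeat(['patient1', 'patient2', 'patient3'], 1000)})
obs.index = obs.index.astype(str)
adata = sc.AnnData(X, obs=obs)

# survival table indexed by the same patient IDs
surv = pd.DataFrame({'time': [12.0, 30.0, 55.0],   # survival or follow-up time
                     'status': [1, 0, 1]},         # 1 = event observed, 0 = censored
                    index=['patient1', 'patient2', 'patient3'])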

Apply scSurvival to simulated data without batch effects

Simulated data generation

First, we generate a simulated single-cell cohort dataset without batch effects, together with the corresponding survival status and survival time. We use the Splatter package to simulate a single-cell expression matrix comprising three ground-truth risk-cell subpopulations (good.survival, bad.survival, and other) and convert both the simulated expression data and the group labels into a Seurat object. We then perform the standard single-cell preprocessing workflow, select the top 2,000 highly variable genes, and generate a UMAP visualization.

sim.groups <- splatSimulateGroups(batchCells = 10000, nGenes=5000,
                                  #group.prob = c(0.9, 0.05, 0.05),
                                  group.prob = c(0.7, 0.15, 0.15),
                                  de.prob = c(0.2, 0.06, 0.06), de.facLoc = c(0.1, 0.1, 0.1),
                                  de.facScale = 0.4,
                                  seed = 5)#


data <- CreateSeuratObject(counts = counts(sim.groups), project = 'Scissor_Single_Cell', min.cells = 100, min.features = 100)
data <- AddMetaData(object = data, metadata = sim.groups$Group, col.name = "sim.group")
data$sim.ground.truth <- recode(data$sim.group,'Group1'='other', 'Group2'='good.survival', 'Group3'='bad.survival')
print(data)
## An object of class Seurat 
## 4873 features across 10000 samples within 1 assay 
## Active assay: RNA (4873 features, 0 variable features)
##  1 layer present: counts
# preprocessing
data <- NormalizeData(object = data, normalization.method = "LogNormalize", 
                      scale.factor = 10000)
data <- FindVariableFeatures(object = data, selection.method = 'vst', nfeatures=2000)
var_features_genes = VariableFeatures(data)

data <- ScaleData(object = data)
data <- RunPCA(object = data, features = VariableFeatures(data))

data <- RunUMAP(object = data, dims = 1:10, n.neighbors = 5, min.dist=0.5, spread=1.5)
DimPlot(object = data, reduction = 'umap',  cols = c('grey','blue', 'red'), group.by = 'sim.group', pt.size = 0.5, label = T)

DimPlot(object = data, reduction = 'umap',  cols = c('grey','blue', 'red'), group.by = 'sim.ground.truth', pt.size = 0.5, label = T)

Next, we simulate single-cell expression data and survival information for 100 patients by sampling from the ground-truth risk-associated cell subpopulations. Each simulated patient is assigned 1,000 cells. The underlying principle is that patients with longer survival times have a higher proportion of good.survival cells and a lower proportion of bad.survival cells, while the proportion of other cells remains constant across all patients.

Additionally, we randomly select 10% of the patients to simulate censoring events. For these patients, the survival time is randomly shortened to a point prior to death, and their survival status is set to 0 (censored). The remaining patients are assigned a survival status of 1 (event occurred).

data_save_path <- './sim_data_wo_batch/'
scdata_save_path <- sprintf('%s/single_cell/', data_save_path)
dir.create(scdata_save_path, recursive=T)

Expression_pbmc <- as.matrix(data@assays[["RNA"]]@layers[["data"]])

rownames(Expression_pbmc) <- rownames(data)
colnames(Expression_pbmc) <- colnames(data)
Expression_pbmc <- Expression_pbmc[VariableFeatures(data), ]

###---simulation single cell expression data and survival info---------------------
set.seed(1)
sampled_cells = 1000
patient_num=100

other_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='other']
good_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='good.survival']
bad_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='bad.survival']
num_good <- length(good_cells)
num_bad <- length(bad_cells)

censor_prob = 0.1

status = NULL
surv_time = NULL

num_good_cond_cells = NULL
num_bad_cond_cells = NULL

for (i in 1:patient_num){
  ratio <- (i-1) / (patient_num-1)
  num_good_cond_cells_i = round(num_good * ratio)
  num_bad_cond_cells_i = round(num_bad * (1-ratio))
  condition_good_cells <- good_cells[sample(num_good, num_good_cond_cells_i , replace=TRUE)]
  condition_bad_cells <- bad_cells[sample(num_bad, num_bad_cond_cells_i, replace=TRUE)]
  condition_cells <- c(condition_good_cells, condition_bad_cells, other_cells)
  
  num_good_cond_cells = c(num_good_cond_cells, num_good_cond_cells_i)
  num_bad_cond_cells = c(num_bad_cond_cells, num_bad_cond_cells_i)
  
  Expression_condition = Expression_pbmc[, condition_cells]
  Expression_selected <- Expression_condition[, sample(ncol(Expression_condition),size=sampled_cells,replace=TRUE)]
  
  write.csv(Expression_selected, file = sprintf('%s/%d.csv', scdata_save_path, i))
  
  if (runif(1, min = 0, max = 1) < censor_prob){
    status = c(status, 0)
    surv_time = c(surv_time, sample(i, 1))
  }
  else{
    surv_time = c(surv_time, i)
    status = c(status, 1)
  }
}

patient_names <- paste0('patient', 1:patient_num)
surv_info <- data.frame(
  time=surv_time,
  status=status,
  num.good.cells = num_good_cond_cells,
  num.bad.cells = num_bad_cond_cells,
  row.names = patient_names
)

###---save---------------------
labels <- data$sim.ground.truth
labels <- as.data.frame(labels)
row.names(labels) <- colnames(data)
write.csv(labels, file=sprintf('%s/sim_groups.csv', data_save_path))
write.csv(Expression_pbmc, file = sprintf('%s/%s.csv', scdata_save_path, 'all_cells'))
write.csv(surv_info, file=sprintf('%s/surv_info.csv', data_save_path))

We can illustrate this simulation process by visualizing the single-cell expression profiles of a subset of patients:

plot_list <- list()
for (i in c(2, 10, 40, 60, 90, 99)){
  ratio <- (i-1) / (patient_num-1)
  num_good_cond_cells_i = round(num_good * ratio)
  num_bad_cond_cells_i = round(num_bad * (1-ratio))
  condition_good_cells <- good_cells[sample(num_good, num_good_cond_cells_i , replace=TRUE)]
  condition_bad_cells <- bad_cells[sample(num_bad, num_bad_cond_cells_i, replace=TRUE)]
  condition_cells <- c(condition_good_cells, condition_bad_cells, other_cells)
  
  p <- DimPlot(data[, condition_cells], group.by = 'sim.ground.truth', cols = c('grey','blue', 'red'), pt.size = 0.5) +
    ggtitle(paste("survival.time :", i, "months")) +
    theme(plot.title = element_text(size = 10))
  plot_list[[length(plot_list) + 1]] <- p
}
ggarrange(plotlist = plot_list, ncol = 3, nrow=2, common.legend = TRUE, legend = "bottom")

Run scSurvival to identify risk-associated cell subpopulations

Next, we load the simulated single-cell cohort data and corresponding survival information in a Python chunk and run scSurvival. We then visualize the distribution of the computed attention weights using a histogram.

from scSurvival import scSurvivalRun, PredictIndSample
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

import os 
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# load single cell cohort expression data
xs = []
samples = []
for i in range(1, 101):
    df = pd.read_csv(f'{r.scdata_save_path}/{i}.csv', index_col=0)
    xs.append(df.values.T)
    samples.extend(['patient%d' % i] * df.shape[1])

X = np.concatenate(xs, axis=0)
adata = sc.AnnData(X, obs=pd.DataFrame(samples, index=np.arange(X.shape[0]).astype(str), columns=['sample']))

# load survival information
clinic = pd.read_csv(f'{r.data_save_path}/surv_info.csv', index_col=0)
surv = clinic[['time', 'status']].copy()
surv['time'] = surv['time'].astype(float)
surv['status'] = surv['status'].astype(int)

# run scSurvival
adata, surv, model = scSurvivalRun(adata, 
                                sample_column='sample',
                                surv=surv,
                                feature_flavor='AE',
                                valid=False,
                                entropy_threshold=0.7,
                                pretrain_epochs=200,
                                epochs=500,
                                patience=100,
                                fitnetune_strategy='alternating', 
                                )  
sns.histplot(adata.obs['attention'], bins=50)
plt.show()

plt.close()

Result visualization

Normally, the returned adata already contains the computed results and can be used directly for visualization and downstream analysis. However, for speed and convenience, in this tutorial we visualize the cell-level inference results of scSurvival on the initial simulated single-cell dataset. First, we apply the trained model to this dataset to perform inference.

df = pd.read_csv(f'{r.scdata_save_path}/all_cells.csv', index_col=0)
x = df.values.T
adata_new = sc.AnnData(x, obs=pd.DataFrame(index=np.arange(x.shape[0]).astype(str)))
adata_new, _ = PredictIndSample(adata_new, adata, model)
## gene missing rate: 0.00%
## Added hazard and attention to adata_new.obs.

attention = adata_new.obs['attention'].values
hazard_adj = adata_new.obs['hazard_adj'].values

We then load the inference results into the Seurat object for visualization.

data$attention <- py$attention
data$hazard_adj <- py$hazard_adj

cols = c("blue","lightgrey", "red")
FeaturePlot(data, features = c('attention'), pt.size = 0.5) + scale_colour_gradientn(colours=c("lightgrey", "blue"))
## Scale for colour is already present.
## Adding another scale for colour, which will replace the existing scale.

FeaturePlot(data, features = c('hazard_adj'), pt.size = 0.5) + scale_colour_gradientn(colours=cols)
## Scale for colour is already present.
## Adding another scale for colour, which will replace the existing scale.

Furthermore, we can use the computed attention scores and cell-level risk scores (hazard_adj) to stratify cells into different risk groups. To begin, we apply K-means clustering to determine a threshold for the attention scores.

data = adata.obs['attention'].values.reshape(-1, 1)
kmeans = KMeans(n_clusters=2, random_state=42)
kmeans.fit(data)
## KMeans(n_clusters=2, random_state=42)
labels = kmeans.labels_
cluster_centers = kmeans.cluster_centers_
atten_thr = cluster_centers.flatten().mean()
print("attention cutoff:", atten_thr)
## attention cutoff: 0.5012851

Next, based on the attention threshold and the sign of the hazard_adj score, we classify cells into distinct risk groups. Specifically, we define three categories of cells: higher-risk, lower-risk, and inattentive.

atten_thr <- py$atten_thr
risk_group <- rep('inattentive', dim(data)[2])
risk_group[(data$hazard_adj > 0 & data$attention > atten_thr)] <- 'higher'
risk_group[(data$hazard_adj < 0 & data$attention > atten_thr)] <- 'lower'
data$surv.risk.group <- factor(risk_group, levels=c('higher', 'lower', 'inattentive'))

DimPlot(object = data, reduction = 'umap', cols = c('red','blue','grey'), group.by = 'surv.risk.group', pt.size = 0.5, label = T)

Based on the inferred cell risk groups, we can evaluate the accuracy of cell identification by computing a confusion matrix with respect to the ground-truth risk-associated subpopulations, as well as performance metrics such as precision, recall, and F1 score.

data$predicted.risk.group = recode(data$surv.risk.group, 'higher'='bad.survival', 'lower'='good.survival', 'inattentive'='other')
cm <- confusionMatrix(data$predicted.risk.group, data$sim.ground.truth)
## Warning in confusionMatrix.default(data$predicted.risk.group,
## data$sim.ground.truth): Levels are not in the same order for reference and
## data. Refactoring data to match.
precision <- mean(cm$byClass[, 'Pos Pred Value'])
recall <- mean(cm$byClass[, 'Sensitivity'])
f1_score <- 2 * precision * recall / (precision + recall)

# confusion matrix
cm_table <- table(data$sim.ground.truth, data$surv.risk.group)
cm_df <- as.data.frame(as.table(cm_table))
colnames(cm_df) <- c("actual", "predicted", "Freq")

# row normalization
cm_df <- cm_df %>%
  group_by(actual) %>%
  mutate(Proportion = Freq / sum(Freq)) %>% 
  ungroup() 

# visualization 
precision_str <- sprintf("Precision: %.3f", precision)
recall_str <- sprintf("Recall: %.3f", recall)
f1_str <- sprintf("F1-score: %.3f", f1_score)
ggplot(data = cm_df, aes(x = predicted, y = actual, fill = Proportion)) +
  geom_tile() +
  geom_text(aes(label = Freq), color = "white", size = 5) +  # still display the raw counts
  scale_fill_gradient(low = "white", high = "red") +  # fill colors follow the row-normalized proportions
  theme_minimal() +
  labs(title = "Confusion Matrix", x = "Detected", y = "Actual") +
  theme(
    axis.line = element_line(color = "black"),  
    axis.ticks = element_line(color = "black"), 
    axis.text = element_text(size = 12, color = "black"),  
    axis.title = element_text(size = 14, face = "bold")
  ) + 
  annotate("text", x = max(as.numeric(cm_df$predicted))-0.21, y = 3.4, label = precision_str, hjust = 0, size = 3, color = "black") +
  annotate("text", x = max(as.numeric(cm_df$predicted))-0.21, y = 3.2, label = recall_str, hjust = 0, size = 3, color = "black") +
  annotate("text", x = max(as.numeric(cm_df$predicted))-0.21, y = 3, label = f1_str, hjust = 0, size = 3, color = "black")

Evaluating patient-level risk prediction performance of scSurvival

scSurvival can also be used as a risk prediction tool to estimate patient survival directly from single-cell expression data. To assess its performance at the patient level, we employ K-fold cross-validation to evaluate the predictive accuracy of the scSurvival model in estimating patient survival risk.

from sklearn.model_selection import KFold
import io
import contextlib
f = io.StringIO()
from lifelines.utils import concordance_index
from scipy.stats import percentileofscore

patients = adata.obs['sample'].unique()

# K fold cross validation
cv_hazards_adj_cells = np.zeros((adata.shape[0], ))
surv['cv_hazards_adj_patient'] = 0.0
surv['cv_hazard_percentile_patient'] = 0.0
cindexs = []
surv_test_all_folds = []
kf = KFold(n_splits=5, shuffle=True, random_state=42)

for i, (train_index, test_index) in enumerate(kf.split(patients)):
    print(f'fold {i}, train_size: {train_index.shape[0]}, test_size: {test_index.shape[0]}')
    train_patients = patients[train_index]
    test_patients = patients[test_index]

    # train
    adata_train = adata[adata.obs['sample'].isin(train_patients), :]
    surv_train = surv.loc[surv.index.isin(train_patients), :].copy()
    adata_train, surv_train, model = scSurvivalRun(
        adata_train,
        sample_column='sample',
        surv=surv_train,
        # batch_key='batch',
        feature_flavor='AE',
        entropy_threshold=0.7,
        pretrain_epochs=200,
        epochs=500,
        patience=100,
        fitnetune_strategy='alternating', # jointly, alternating, alternating_lightly,
        )  
    train_cindex = concordance_index(surv_train['time'], -surv_train['patient_hazards'], surv_train['status'])
    print(f'train c-index: {train_cindex:.4f}')
    
    # test
    print('testing...')
    adata_test = adata[adata.obs['sample'].isin(test_patients), :]
    with contextlib.redirect_stdout(f):
        for test_patient in test_patients:
            adata_test_patient = adata_test[adata_test.obs['sample'] == test_patient, :].copy()
            adata_test_patient, patient_hazard = PredictIndSample(adata_test_patient, adata_train, model)
            cv_hazards_adj_cells[adata.obs['sample'] == test_patient] = adata_test_patient.obs['hazard_adj'].values
            surv.loc[surv.index == test_patient, 'cv_hazards_adj_patient'] = patient_hazard
            surv.loc[surv.index == test_patient, 'cv_hazard_percentile_patient'] = percentileofscore(surv_train['patient_hazards'], patient_hazard, kind='rank')

    surv_test = surv.loc[surv.index.isin(test_patients), :]
    c_index = concordance_index(surv_test['time'], -surv_test['cv_hazards_adj_patient'], surv_test['status'])
    cindexs.append(c_index)
    surv_test_all_folds.append(surv_test)

    print(f'c-index: {c_index:.4f}')
    print('='*50)

mean_cindex = np.mean(cindexs)
std_cindex = np.std(cindexs)
print(f'mean c-index: {mean_cindex:.4f} ± {std_cindex:.4f}')

By collecting the risk prediction results (percentiles of relative risk scores) from all folds, we can further evaluate the overall predictive accuracy and the risk stratification capability of the model across the entire patient cohort.

df <- py$surv
cindex <- rcorr.cens(-df$cv_hazard_percentile_patient, Surv(df$time, df$status))[['C Index']]
print(paste("Overall C-index:", cindex))

df$risk_group <- ifelse(df$cv_hazard_percentile_patient > median(df$cv_hazard_percentile_patient), "High risk", "Low risk")
fit <- survfit(Surv(time, status) ~ risk_group, data=df)
p <- ggsurvplot(fit,
           legend.title = "Risk group",
           legend.labs = c("High risk", "Low risk"),
           palette = c("red", "blue"),
           risk.table = F,
           pval = TRUE,
           conf.int = TRUE, 
           title = "KM curve of risk groups (5-fold CV)"
) 
p

Apply scSurvival to simulated data with batch effects

Simulated data generation

For the second example, we generate a simulated single-cell cohort dataset with batch effects. The simulation procedure is similar to the previous example, except that the cells are simulated from two batches (6,000 and 4,000 cells, respectively), so that Splatter introduces batch effects into the expression data. We then convert the simulated expression data and group labels into a Seurat object.

sim.groups <- splatSimulateGroups(batchCells = c(6000, 4000), nGenes=5000,
                                  #group.prob = c(0.9, 0.05, 0.05),
                                  group.prob = c(0.7, 0.15, 0.15),
                                  de.prob = c(0.2, 0.06, 0.06), de.facLoc = c(0.1, 0.1, 0.1),
                                  de.facScale = 0.4,
                                  seed = 5)#


data <- CreateSeuratObject(counts = counts(sim.groups), project = 'ScSurvival_Single_Cell', min.cells = 100, min.features = 100)
data <- AddMetaData(object = data, metadata = sim.groups$Group, col.name = "sim.group")

data$sim.ground.truth <- recode(data$sim.group,'Group1'='other', 'Group2'='good.survival', 'Group3'='bad.survival')
data$batch <- c(rep('Batch1', 6000), rep('Batch2', 4000))

data <- NormalizeData(object = data, normalization.method = "LogNormalize", 
                      scale.factor = 10000)
data <- FindVariableFeatures(object = data, selection.method = 'vst', nfeatures=2000)
var_features_genes = VariableFeatures(data)

data <- ScaleData(object = data)
data <- RunPCA(object = data, features = VariableFeatures(data))
data <- RunUMAP(object = data, dims = 1:10, n.neighbors = 5, min.dist=0.5, spread=1.5)
DimPlot(object = data, reduction = 'umap',  cols = c('grey','blue','red'), group.by = 'sim.group', pt.size = 0.5, label = F)

DimPlot(object = data, reduction = 'umap', group.by = 'batch', pt.size = 0.5, label = F)

We then simulate single-cell data for 50 patients from each of the two batches. Within each batch, the simulation of patients’ cellular composition and survival status follows the same approach as before.

data_save_path <- './sim_data_with_batch/'
scdata_save_path <- sprintf('%s/single_cell/', data_save_path)
dir.create(scdata_save_path, recursive=T)


Expression_pbmc <- as.matrix(data@assays[["RNA"]]@layers[["data"]])
rownames(Expression_pbmc) <- rownames(data)
colnames(Expression_pbmc) <- colnames(data)
Expression_pbmc <- Expression_pbmc[VariableFeatures(data), ]

###---simulation---------------------
set.seed(1)

sampled_cells = 1000
bulk_num=50

###---simulation from batch1---------------------
other_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='other' & data$batch=='Batch1']
good_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='good.survival' & data$batch=='Batch1']
bad_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='bad.survival' & data$batch=='Batch1']
num_good <- length(good_cells)
num_bad <- length(bad_cells)

bulk_condition = NULL
censor_prob = 0.1

status = NULL
surv_time = NULL

num_good_cond_cells = NULL
num_bad_cond_cells = NULL

for (i in 1:bulk_num){
  ratio <- (i-1) / (bulk_num-1)
  num_good_cond_cells_i = round(num_good * ratio)
  num_bad_cond_cells_i = round(num_bad * (1-ratio))
  condition_good_cells <- good_cells[sample(num_good, num_good_cond_cells_i , replace=TRUE)]
  condition_bad_cells <- bad_cells[sample(num_bad, num_bad_cond_cells_i, replace=TRUE)]
  condition_cells <- c(condition_good_cells, condition_bad_cells, other_cells)
  
  num_good_cond_cells = c(num_good_cond_cells, num_good_cond_cells_i)
  num_bad_cond_cells = c(num_bad_cond_cells, num_bad_cond_cells_i)
  
  Expression_condition = Expression_pbmc[, condition_cells]
  Expression_selected <- Expression_condition[, sample(ncol(Expression_condition),size=sampled_cells,replace=TRUE)]
  
  write.csv(Expression_selected, file = sprintf('%s/%d.csv', scdata_save_path, i))
  
  if (runif(1, min = 0, max = 1) < censor_prob){
    status = c(status, 0)
    surv_time = c(surv_time, sample(i, 1))
  }
  else{
    surv_time = c(surv_time, i)
    status = c(status, 1)
  }
}


###---simulation from batch2---------------------
other_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='other' & data$batch=='Batch2']
good_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='good.survival' & data$batch=='Batch2']
bad_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='bad.survival' & data$batch=='Batch2']
num_good <- length(good_cells)
num_bad <- length(bad_cells)

censor_prob = 0.1

for (i in 1:bulk_num){
  ratio <- (i-1) / (bulk_num-1)
  num_good_cond_cells_i = round(num_good * ratio)
  num_bad_cond_cells_i = round(num_bad * (1-ratio))
  condition_good_cells <- good_cells[sample(num_good, num_good_cond_cells_i , replace=TRUE)]
  condition_bad_cells <- bad_cells[sample(num_bad, num_bad_cond_cells_i, replace=TRUE)]
  condition_cells <- c(condition_good_cells, condition_bad_cells, other_cells)
  
  num_good_cond_cells = c(num_good_cond_cells, num_good_cond_cells_i)
  num_bad_cond_cells = c(num_bad_cond_cells, num_bad_cond_cells_i)
  
  Expression_condition = Expression_pbmc[, condition_cells]
  Expression_selected <- Expression_condition[, sample(ncol(Expression_condition),size=sampled_cells,replace=TRUE)]
  
  write.csv(Expression_selected, file = sprintf('%s/%d.csv', scdata_save_path, i+50))
  
  if (runif(1, min = 0, max = 1) < censor_prob){
    status = c(status, 0)
    surv_time = c(surv_time, sample(i, 1))
  }
  else{
    surv_time = c(surv_time, i)
    status = c(status, 1)
  }
}

bulk_names <- paste0('bulk', 1:(bulk_num*2))
surv_info <- data.frame(
  time=surv_time,
  status=status,
  num.good.cells = num_good_cond_cells,
  num.bad.cells = num_bad_cond_cells,
  row.names = bulk_names
)


###-----save---------------------
labels <- data$sim.ground.truth
labels <- as.data.frame(labels)
row.names(labels) <- colnames(data)
write.csv(labels, file=sprintf('%s/sim_groups.csv', data_save_path))

write.csv(Expression_pbmc, file = sprintf('%s/%s.csv', scdata_save_path, 'all_cells'))
write.csv(surv_info, file=sprintf('%s/surv_info.csv', data_save_path))

batch_ids <- as.data.frame(data$batch)
row.names(batch_ids) <- colnames(data)
colnames(batch_ids) <- 'batch_ids'
write.csv(batch_ids, file=sprintf('%s/batch_ids.csv', data_save_path))

Run scSurvival to identify risk-associated cell subpopulations across batches

When running scSurvival, specify the batch_key parameter to enable scSurvival to automatically handle batch effects.

from scSurvival import scSurvivalRun, PredictIndSample
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

import os 
os.environ['KMP_DUPLICATE_LIB_OK']='True'

xs = []
samples = []
batches = []
for i in range(1, 101):
    df = pd.read_csv(f'{r.scdata_save_path}/{i}.csv', index_col=0)
    xs.append(df.values.T)
    samples.extend(['bulk%d' % i] * df.shape[1])
    if i <= 50:
        batches.extend(['batch1'] * df.shape[1])
    else:
        batches.extend(['batch2'] * df.shape[1])

obs_df = pd.DataFrame({'sample': samples, 'batch': batches})
obs_df.index = np.arange(len(samples)).astype(str)

X = np.concatenate(xs, axis=0)
adata = sc.AnnData(X, obs=obs_df)
clinic = pd.read_csv(f'{r.data_save_path}/surv_info.csv', index_col=0)
surv = clinic[['time', 'status']].copy()
surv['time'] = surv['time'].astype(float)
surv['status'] = surv['status'].astype(int)

adata, surv, model = scSurvivalRun(adata, 
                                sample_column='sample',
                                surv=surv,
                                batch_key='batch', 
                                feature_flavor='AE',
                                entropy_threshold=0.7,
                                pretrain_epochs=200,
                                epochs=500,
                                patience=100,
                                fitnetune_strategy='alternating'
                                )  
sns.histplot(adata.obs['attention'], bins=50)
plt.show()

plt.close()

# calculate the attention threshold
from sklearn.cluster import KMeans
data = adata.obs['attention'].values.reshape(-1, 1)
kmeans = KMeans(n_clusters=2, random_state=42)

kmeans.fit(data)
labels = kmeans.labels_
cluster_centers = kmeans.cluster_centers_

atten_thr = cluster_centers.flatten().mean()
print("cutoff:", atten_thr)

Result visualization

As in the first example, we can use the trained model to perform inference on the initially simulated single-cell data and visualize the results.

df = pd.read_csv(f'{r.scdata_save_path}/all_cells.csv', index_col=0)
x = df.values.T
sim_group = pd.read_csv(f'{r.data_save_path}/sim_groups.csv', index_col=0)
sim_group = sim_group['labels'].values

batches = pd.read_csv(f'{r.data_save_path}/batch_ids.csv', index_col=0)
batches = batches['batch_ids'].values
batches= [each.lower() for each in batches]

adata_new = sc.AnnData(x, obs=pd.DataFrame(sim_group, index=np.arange(x.shape[0]).astype(str), columns=['sim_group']))
adata_new.obs['batch'] = batches

batches = model.le.transform(adata_new.obs['batch'].values)
exp = adata_new.X
h, a, cell_hazards, cell_hazards_weighted = model.predict_cells(exp, batch_labels=batches)
adata_new.obsm['X_ae'] = h.cpu().detach().numpy()

adata_new.obs['hazard'] = cell_hazards.cpu().detach().numpy()
adata_new.obs['attention'] = a.cpu().detach().numpy()
adata_new.obs['hazard_adj'] = cell_hazards_weighted.cpu().detach().numpy()

attention = adata_new.obs['attention'].values
hazard_adj = adata_new.obs['hazard_adj'].values
hazard = adata_new.obs['hazard'].values

X_ae = adata_new.obsm['X_ae']
data$attention <- py$attention
data$hazard_adj <- py$hazard_adj
data$hazard <- py$hazard

atten_thr <- py$atten_thr
risk_group <- rep('inattentive', dim(data)[2])
risk_group[(data$hazard_adj > 0 & data$attention > atten_thr)] <- 'higher'
risk_group[(data$hazard_adj < 0 & data$attention > atten_thr)] <- 'lower'

data$surv.risk.group <- factor(risk_group, levels=c('higher', 'lower', 'inattentive'))

cols = c("blue","lightgrey", "red")
FeaturePlot(data, features = c('attention'), pt.size = 0.5) + scale_colour_gradientn(colours=c("lightgrey", "blue"))
## Scale for colour is already present.
## Adding another scale for colour, which will replace the existing scale.

FeaturePlot(data, features = c('hazard_adj'), pt.size = 0.5) + scale_colour_gradientn(colours=cols)
## Scale for colour is already present.
## Adding another scale for colour, which will replace the existing scale.

DimPlot(object = data, reduction = 'umap', cols = c('red','blue','grey'), group.by = 'surv.risk.group', pt.size = 0.5, label = T)

At the same time, we can extract the cell embeddings (X_ae) generated by the scSurvival model and add them to the Seurat object as a new low-dimensional representation. Using this embedding, we generate new UMAP coordinates for visualization. As shown, the batch effects have been effectively corrected.

emd_ae <- py$X_ae
rownames(emd_ae) <- colnames(data)
data[["X_ae"]] <- CreateDimReducObject(embeddings = emd_ae, key = "AE_", assay = "RNA")
data <- RunUMAP(object = data, reduction = 'X_ae', dims = 1:ncol(emd_ae), n.neighbors = 5, min.dist=0.5, spread=1.5)

DimPlot(object = data, cols = c('red','blue'), group.by = 'batch', pt.size = 0.5, label = F)

FeaturePlot(data, features = c('attention'), pt.size = 0.5) + scale_colour_gradientn(colours=c("lightgrey", "blue"))

FeaturePlot(data, features = c('hazard_adj'), pt.size = 0.5) + scale_colour_gradientn(colours=cols)

DimPlot(object = data, reduction = 'umap', cols = c('red','blue','grey'), group.by = 'surv.risk.group', pt.size = 0.5, label = T)

Notes on predicting independent data with batch effects

When batch effects exist either within the training data or between the training and test datasets, directly applying a model trained by scSurvival to predict survival risk on independent data may yield inaccurate results. This is because the batch labels of the test data were never seen during training, which prevents the model from performing reliable inference. In such cases, the PredictIndSample function is not applicable.

To address this issue, we recommend combining the test data (along with their batch labels) with the training data and re-training the model using the scSurvivalRun function. This enables transfer learning by jointly learning from both datasets. Upon completion, scSurvivalRun will return an updated adata object and a surv DataFrame that include prediction results for the test data.
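The snippet below sketches this joint retraining, assuming adata_train and adata_test are AnnData objects built as in the examples above and surv_train is the survival table of the training patients. The anndata.concat call and the scSurvivalRun arguments mirror those used earlier in this tutorial; how scSurvivalRun treats test patients that have no survival rows is package behavior not reproduced here, so treat this strictly as a sketch.

# Sketch of joint retraining on training + independent test cells, each with its own batch label.
# adata_train, adata_test and surv_train are assumed to exist (hypothetical names).
import anndata as ad
from scSurvival import scSurvivalRun

adata_train.obs['batch'] = 'training_cohort'     # batch label for the training cells
adata_test.obs['batch'] = 'independent_cohort'   # batch label for the new test cells

# combine the two cohorts, keeping the genes they share
adata_all = ad.concat([adata_train, adata_test], join='inner')

adata_all, surv_all, model = scSurvivalRun(
    adata_all,
    sample_column='sample',
    surv=surv_train,                  # survival information of the training patients
    batch_key='batch',
    feature_flavor='AE',
    entropy_threshold=0.7,
    pretrain_epochs=200,
    epochs=500,
    patience=100,
    fitnetune_strategy='alternating',
)
# surv_all is expected to also contain the prediction results for the test patients.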

Package versions

R packages loaded in this tutorial:
Seurat 5.1.0
reticulate 1.38.0
ggplot2 3.5.1
gridExtra 2.3
ggpubr 0.6.0
dplyr 1.1.4
splatter 1.28.0
caret 6.0-94
survival 3.6-4
Hmisc 5.2-1
survminer 0.5.0

Python packages that scSurvival depends on:
torch 2.4.0+cu124
numpy 1.26.4
pandas 2.2.2
scanpy 1.10.2
scikit-learn 1.4.2
scipy 1.13.1
lifelines 0.30.0