scSurvival is a new, scalable, and interpretable tool specifically designed for survival analysis from single-cell cohort data. It first employs a feature extraction module based on a variational autoencoder and generative modeling to learn batch-invariant single-cell representations, and then aggregates cell-level features to the patient level to perform multi-head attention-based multiple instance Cox regression. scSurvival not only enables the integration of single-cell expression data with patient-level clinical variables to build accurate survival risk prediction models, but also identifies key cell subpopulations most associated with survival risk and characterizes their risk tendencies, thereby facilitating more refined downstream analyses.
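To make the aggregation step concrete, here is a small illustrative sketch (not scSurvival's actual implementation; every name in it is made up) of attention-based multiple-instance aggregation: each cell's embedding receives an attention weight, the attention-weighted average forms the patient representation, and a linear Cox head maps it to a log-hazard risk score.
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def patient_risk(cell_embeddings, w_attn, w_cox):
    """Toy attention-MIL pooling: cell_embeddings has shape (n_cells, d)."""
    attn = softmax(cell_embeddings @ w_attn)   # one attention weight per cell
    patient_feature = attn @ cell_embeddings   # attention-weighted average cell feature
    return float(patient_feature @ w_cox)      # linear Cox head: patient log-hazard

rng = np.random.default_rng(0)
cells = rng.normal(size=(1000, 16))            # e.g. 1,000 cells with 16-dimensional embeddings
print(patient_risk(cells, rng.normal(size=16), rng.normal(size=16)))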
If the “scSurvival” package is not yet installed, the user can run the following command in the terminal:
pip install -e ..
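Once the installation finishes, a quick way to confirm that the package is importable is to run the following in Python; these are the two functions used throughout this tutorial.
# Sanity check: both imports should succeed after installation.
from scSurvival import scSurvivalRun, PredictIndSample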
The algorithm package scSurvival is developed in Python. If you want to use it in R (version ≥ 4.1.0), you can do so with the help of the reticulate package. First, let’s load the required R packages:
rm(list=ls())
library(Seurat)
library(reticulate)
library(ggplot2)
library(gridExtra)
library(ggpubr)
library(dplyr)
library(splatter)
library(caret)
library(survival)
library(Hmisc)
library(survminer)
# set.seed(1)
Users can select the Python environment by clicking "Tools->Global Options->Python" in RStudio before calling scSurvival, or set it directly with reticulate:
use_python("xxx/python")   # use a specific Python interpreter
use_virtualenv("xxx")      # use a virtual environment
use_condaenv("xxx")        # use a conda environment
Check whether the reticulate package is functioning properly.
print('hello python')
## hello python
The input data for scSurvival consist of a single‐cell cohort dataset—each patient is represented by a single‐cell gene‐expression matrix together with their corresponding survival information. Moreover, the single‐cell samples may be drawn from multiple batches. In this tutorial, we use the Splatter package to generate two simulated datasets (one without batch effects and one with batch effects) and demonstrate how to apply scSurvival for analysis.
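Concretely, scSurvival expects the per-cell expression of all patients stacked into one AnnData object whose obs records the patient of origin (and, optionally, a batch label), plus a patient-indexed data frame with time and status columns. The toy sketch below uses arbitrary values and simply mirrors how these objects are assembled later in this tutorial.
import numpy as np
import pandas as pd
import scanpy as sc

# Toy inputs: 2 patients x 1,000 cells each, 2,000 genes (arbitrary values).
X = np.random.rand(2000, 2000)
obs = pd.DataFrame({'sample': ['patient1'] * 1000 + ['patient2'] * 1000},
                   index=np.arange(2000).astype(str))
adata = sc.AnnData(X, obs=obs)

# Patient-level survival information, indexed by the same patient IDs.
surv = pd.DataFrame({'time': [30.0, 12.0], 'status': [1, 0]},
                    index=['patient1', 'patient2'])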
First, we generate a simulated single-cell cohort dataset without batch effects, together with the corresponding survival status and survival time. We use the Splatter package to simulate a single-cell expression matrix comprising three ground-truth risk-cell subpopulations (good.survival, bad.survival, and other), and convert both the simulated expression data and the group labels into a Seurat object. We then perform the standard single-cell preprocessing workflow, select the top 2,000 highly variable genes, and generate the UMAP visualization.
sim.groups <- splatSimulateGroups(batchCells = 10000, nGenes=5000,
#group.prob = c(0.9, 0.05, 0.05),
group.prob = c(0.7, 0.15, 0.15),
de.prob = c(0.2, 0.06, 0.06), de.facLoc = c(0.1, 0.1, 0.1),
de.facScale = 0.4,
seed = 5)
data <- CreateSeuratObject(counts = counts(sim.groups), project = 'scSurvival_Single_Cell', min.cells = 100, min.features = 100)
data <- AddMetaData(object = data, metadata = sim.groups$Group, col.name = "sim.group")
data$sim.ground.truth <- recode(data$sim.group,'Group1'='other', 'Group2'='good.survival', 'Group3'='bad.survival')
print(data)
## An object of class Seurat
## 4873 features across 10000 samples within 1 assay
## Active assay: RNA (4873 features, 0 variable features)
## 1 layer present: counts
# preprocessing
data <- NormalizeData(object = data, normalization.method = "LogNormalize",
scale.factor = 10000)
data <- FindVariableFeatures(object = data, selection.method = 'vst', nfeatures=2000)
var_features_genes = VariableFeatures(data)
data <- ScaleData(object = data)
data <- RunPCA(object = data, features = VariableFeatures(data))
data <- RunUMAP(object = data, dims = 1:10, n.neighbors = 5, min.dist=0.5, spread=1.5)
DimPlot(object = data, reduction = 'umap', cols = c('grey','blue', 'red'), group.by = 'sim.group', pt.size = 0.5, label = T)
DimPlot(object = data, reduction = 'umap', cols = c('grey','blue', 'red'), group.by = 'sim.ground.truth', pt.size = 0.5, label = T)
Next, we simulate single-cell expression data and survival information for 100 patients by sampling from the ground-truth risk-associated cell subpopulations. Each simulated patient is assigned 1,000 cells. The underlying principle is that patients with longer survival times have a higher proportion of good.survival cells and a lower proportion of bad.survival cells, while the proportion of other cells remains constant across all patients.
Additionally, we randomly select 10% of the patients to simulate censoring events. For these patients, the survival time is randomly shortened to a point prior to death, and their survival status is set to 0 (censored). The remaining patients are assigned a survival status of 1 (event occurred).
data_save_path <- './sim_data_wo_batch/'
scdata_save_path <- sprintf('%s/single_cell/', data_save_path)
dir.create(scdata_save_path, recursive=T)
Expression_pbmc <- as.matrix(data@assays[["RNA"]]@layers[["data"]])
rownames(Expression_pbmc) <- rownames(data)
colnames(Expression_pbmc) <- colnames(data)
Expression_pbmc <- Expression_pbmc[VariableFeatures(data), ]
###---simulation single cell expression data and survival info---------------------
set.seed(1)
sampled_cells = 1000
patient_num=100
other_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='other']
good_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='good.survival']
bad_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='bad.survival']
num_good <- length(good_cells)
num_bad <- length(bad_cells)
censor_prob = 0.1
status = NULL
surv_time = NULL
num_good_cond_cells = NULL
num_bad_cond_cells = NULL
for (i in 1:patient_num){
ratio <- (i-1) / (patient_num-1)
num_good_cond_cells_i = round(num_good * ratio)
num_bad_cond_cells_i = round(num_bad * (1-ratio))
condition_good_cells <- good_cells[sample(num_good, num_good_cond_cells_i , replace=TRUE)]
condition_bad_cells <- bad_cells[sample(num_bad, num_bad_cond_cells_i, replace=TRUE)]
condition_cells <- c(condition_good_cells, condition_bad_cells, other_cells)
num_good_cond_cells = c(num_good_cond_cells, num_good_cond_cells_i)
num_bad_cond_cells = c(num_bad_cond_cells, num_bad_cond_cells_i)
Expression_condition = Expression_pbmc[, condition_cells]
Expression_selected <- Expression_condition[, sample(ncol(Expression_condition),size=sampled_cells,replace=TRUE)]
write.csv(Expression_selected, file = sprintf('%s/%d.csv', scdata_save_path, i))
if (runif(1, min = 0, max = 1) < censor_prob){
status = c(status, 0)
surv_time = c(surv_time, sample(i, 1))
}
else{
surv_time = c(surv_time, i)
status = c(status, 1)
}
}
patient_names <- paste0('patient', 1:patient_num)
surv_info <- data.frame(
time=surv_time,
status=status,
num.good.cells = num_good_cond_cells,
num.bad.cells = num_bad_cond_cells,
row.names = patient_names
)
###---save---------------------
labels <- data$sim.ground.truth
labels <- as.data.frame(labels)
row.names(labels) <- colnames(data)
write.csv(labels, file=sprintf('%s/sim_groups.csv', data_save_path))
write.csv(Expression_pbmc, file = sprintf('%s/%s.csv', scdata_save_path, 'all_cells'))
write.csv(surv_info, file=sprintf('%s/surv_info.csv', data_save_path))
We can illustrate this simulation process by visualizing the single-cell expression profiles of a subset of patients:
plot_list <- list()
for (i in c(2, 10, 40, 60, 90, 99)){
ratio <- (i-1) / (patient_num-1)
num_good_cond_cells_i = round(num_good * ratio)
num_bad_cond_cells_i = round(num_bad * (1-ratio))
condition_good_cells <- good_cells[sample(num_good, num_good_cond_cells_i , replace=TRUE)]
condition_bad_cells <- bad_cells[sample(num_bad, num_bad_cond_cells_i, replace=TRUE)]
condition_cells <- c(condition_good_cells, condition_bad_cells, other_cells)
p <- DimPlot(data[, condition_cells], group.by = 'sim.ground.truth', cols = c('grey','blue', 'red'), pt.size = 0.5) +
ggtitle(paste("survival.time :", i, "months")) +
theme(plot.title = element_text(size = 10))
plot_list[[length(plot_list) + 1]] <- p
}
ggarrange(plotlist = plot_list, ncol = 3, nrow=2, common.legend = TRUE, legend = "bottom")
Next, we load the simulated single-cell cohort data and corresponding survival information in a Python chunk and run scSurvival. We then visualize the distribution of the computed attention weights using a histogram.
from scSurvival import scSurvivalRun, PredictIndSample
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
# load single cell cohort expression data
xs = []
samples = []
for i in range(1, 101):
    df = pd.read_csv(f'{r.scdata_save_path}/{i}.csv', index_col=0)
    xs.append(df.values.T)
    samples.extend(['patient%d' % i] * df.shape[1])
X = np.concatenate(xs, axis=0)
adata = sc.AnnData(X, obs=pd.DataFrame(samples, index=np.arange(X.shape[0]).astype(str), columns=['sample']))
# load survival information
clinic = pd.read_csv(f'{r.data_save_path}/surv_info.csv', index_col=0)
surv = clinic[['time', 'status']].copy()
surv['time'] = surv['time'].astype(float)
surv['status'] = surv['status'].astype(int)
# run scSurvival
adata, surv, model = scSurvivalRun(adata,
sample_column='sample',
surv=surv,
feature_flavor='AE',
valid=False,
entropy_threshold=0.7,
pretrain_epochs=200,
epochs=500,
patience=100,
fitnetune_strategy='alternating',
)
sns.histplot(adata.obs['attention'], bins=50)
plt.show()
plt.close()
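Before moving on, it can be useful to glance at what the run attached to the returned objects. The per-cell attention column and the per-patient patient_hazards column are both used later in this tutorial; other columns added by scSurvivalRun, if any, will show up in the listing below.
# Quick inspection of the outputs used later in this tutorial.
print(adata.obs.columns.tolist())                           # per-cell results (includes 'attention')
print(adata.obs['attention'].describe())                    # distribution of attention weights
print(surv[['time', 'status', 'patient_hazards']].head())   # per-patient hazards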
Typically, the returned adata already contains the cell-level results and can be used directly for visualization and downstream analysis. Here, for speed and convenience, we instead visualize scSurvival's cell-level inference results on the initial simulated single-cell dataset from this tutorial. First, we apply the trained model to that dataset to perform inference.
df = pd.read_csv(f'{r.scdata_save_path}/all_cells.csv', index_col=0)
x = df.values.T
adata_new = sc.AnnData(x, obs=pd.DataFrame(index=np.arange(x.shape[0]).astype(str)))
adata_new, _ = PredictIndSample(adata_new, adata, model)
## gene missing rate: 0.00%
## Added hazard and attention to adata_new.obs.
attention = adata_new.obs['attention'].values
hazard_adj = adata_new.obs['hazard_adj'].values
We then load the inference results into the Seurat object for visualization.
data$attention <- py$attention
data$hazard_adj <- py$hazard_adj
cols = c("blue","lightgrey", "red")
FeaturePlot(data, features = c('attention'), pt.size = 0.5) + scale_colour_gradientn(colours=c("lightgrey", "blue"))
## Scale for colour is already present.
## Adding another scale for colour, which will replace the existing scale.
FeaturePlot(data, features = c('hazard_adj'), pt.size = 0.5) + scale_colour_gradientn(colours=cols)
## Scale for colour is already present.
## Adding another scale for colour, which will replace the existing scale.
Furthermore, we can use the computed attention scores and cell-level risk scores (hazard_adj) to stratify cells into different risk groups. To begin, we apply K-means clustering to determine a threshold for the attention scores.
data = adata.obs['attention'].values.reshape(-1, 1)
kmeans = KMeans(n_clusters=2, random_state=42)
kmeans.fit(data)
## KMeans(n_clusters=2, random_state=42)
labels = kmeans.labels_
cluster_centers = kmeans.cluster_centers_
atten_thr = cluster_centers.flatten().mean()
print("attention cutoff:", atten_thr)
## attention cutoff: 0.5012851
Next, based on the attention threshold and the sign of the hazard_adj score, we classify cells into distinct risk groups. Specifically, we define three categories of cells: higher-risk, lower-risk, and inattentive.
atten_thr <- py$atten_thr
risk_group <- rep('inattentive', dim(data)[2])
risk_group[(data$hazard_adj > 0 & data$attention > atten_thr)] <- 'higher'
risk_group[(data$hazard_adj < 0 & data$attention > atten_thr)] <- 'lower'
data$surv.risk.group <- factor(risk_group, levels=c('higher', 'lower', 'inattentive'))
DimPlot(object = data, reduction = 'umap', cols = c('red','blue','grey'), group.by = 'surv.risk.group', pt.size = 0.5, label = T)
Based on the inferred cell risk groups, we can evaluate the accuracy of cell identification by computing a confusion matrix with respect to the ground-truth risk-associated subpopulations, as well as performance metrics such as precision, recall, and F1 score.
data$predicted.risk.group = recode(data$surv.risk.group, 'higher'='bad.survival', 'lower'='good.survival', 'inattentive'='other')
cm <- confusionMatrix(data$predicted.risk.group, data$sim.ground.truth)
## Warning in confusionMatrix.default(data$predicted.risk.group,
## data$sim.ground.truth): Levels are not in the same order for reference and
## data. Refactoring data to match.
precision <- mean(cm$byClass[, 'Pos Pred Value'])
recall <- mean(cm$byClass[, 'Sensitivity'])
f1_score <- 2 * precision * recall / (precision + recall)
# confusion matrix
cm_table <- table(data$sim.ground.truth, data$surv.risk.group)
cm_df <- as.data.frame(as.table(cm_table))
colnames(cm_df) <- c("actual", "predicted", "Freq")
# row normalization
cm_df <- cm_df %>%
group_by(actual) %>%
mutate(Proportion = Freq / sum(Freq)) %>%
ungroup()
# visualization
precision_str <- sprintf("Precision: %.3f", precision)
recall_str <- sprintf("Recall: %.3f", recall)
f1_str <- sprintf("F1-score: %.3f", f1_score)
ggplot(data = cm_df, aes(x = predicted, y = actual, fill = Proportion)) +
geom_tile() +
geom_text(aes(label = Freq), color = "white", size = 5) + # still display the raw counts
scale_fill_gradient(low = "white", high = "red") + # fill color by row-normalized proportion
theme_minimal() +
labs(title = "Confusion Matrix", x = "Detected", y = "Actual") +
theme(
axis.line = element_line(color = "black"),
axis.ticks = element_line(color = "black"),
axis.text = element_text(size = 12, color = "black"),
axis.title = element_text(size = 14, face = "bold")
) +
annotate("text", x = max(as.numeric(cm_df$predicted))-0.21, y = 3.4, label = precision_str, hjust = 0, size = 3, color = "black") +
annotate("text", x = max(as.numeric(cm_df$predicted))-0.21, y = 3.2, label = recall_str, hjust = 0, size = 3, color = "black") +
annotate("text", x = max(as.numeric(cm_df$predicted))-0.21, y = 3, label = f1_str, hjust = 0, size = 3, color = "black")
scSurvival can also be used as a risk prediction tool to estimate patient survival directly from single-cell expression data. To assess its performance at the patient level, we employ K-fold cross-validation to evaluate the predictive accuracy of the scSurvival model in estimating patient survival risk.
from sklearn.model_selection import KFold
import io
import contextlib
f = io.StringIO()
from lifelines.utils import concordance_index
from scipy.stats import percentileofscore
patients = adata.obs['sample'].unique()
# K fold cross validation
cv_hazards_adj_cells = np.zeros((adata.shape[0], ))
surv['cv_hazards_adj_patient'] = 0.0
surv['cv_hazard_percentile_patient'] = 0.0
cindexs = []
surv_test_all_folds = []
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for i, (train_index, test_index) in enumerate(kf.split(patients)):
    print(f'fold {i}, train_size: {train_index.shape[0]}, test_size: {test_index.shape[0]}')
    train_patients = patients[train_index]
    test_patients = patients[test_index]

    # train
    adata_train = adata[adata.obs['sample'].isin(train_patients), :]
    surv_train = surv.loc[surv.index.isin(train_patients), :].copy()
    adata_train, surv_train, model = scSurvivalRun(
        adata_train,
        sample_column='sample',
        surv=surv_train,
        # batch_key='batch',
        feature_flavor='AE',
        entropy_threshold=0.7,
        pretrain_epochs=200,
        epochs=500,
        patience=100,
        fitnetune_strategy='alternating',  # jointly, alternating, alternating_lightly
    )
    train_cindex = concordance_index(surv_train['time'], -surv_train['patient_hazards'], surv_train['status'])
    print(f'train c-index: {train_cindex:.4f}')

    # test
    print('testing...')
    adata_test = adata[adata.obs['sample'].isin(test_patients), :]
    with contextlib.redirect_stdout(f):
        for test_patient in test_patients:
            adata_test_patient = adata_test[adata_test.obs['sample'] == test_patient, :].copy()
            adata_test_patient, patient_hazard = PredictIndSample(adata_test_patient, adata_train, model)
            cv_hazards_adj_cells[adata.obs['sample'] == test_patient] = adata_test_patient.obs['hazard_adj'].values
            surv.loc[surv.index == test_patient, 'cv_hazards_adj_patient'] = patient_hazard
            surv.loc[surv.index == test_patient, 'cv_hazard_percentile_patient'] = percentileofscore(surv_train['patient_hazards'], patient_hazard, kind='rank')
    surv_test = surv.loc[surv.index.isin(test_patients), :]
    c_index = concordance_index(surv_test['time'], -surv_test['cv_hazards_adj_patient'], surv_test['status'])
    cindexs.append(c_index)
    surv_test_all_folds.append(surv_test)
    print(f'c-index: {c_index:.4f}')
    print('=' * 50)
mean_cindex = np.mean(cindexs)
std_cindex = np.std(cindexs)
print(f'mean c-index: {mean_cindex:.4f} ± {std_cindex:.4f}')
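If you prefer to work only with the held-out predictions collected fold by fold, the per-fold test frames stored in surv_test_all_folds can be stacked into one cohort-wide table (across folds they cover every patient exactly once):
# Stack the held-out survival frames from all folds into a single data frame.
surv_cv = pd.concat(surv_test_all_folds)
print(surv_cv.shape)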
By collecting the risk prediction results (quantiles of relative risk scores) from all folds, we can further evaluate the overall predictive accuracy and the risk stratification capability of the model across the entire patient cohort.
df = py$surv
cindex <- rcorr.cens(-df$cv_hazard_percentile_patient, Surv(df$time, df$status))[['C Index']]
print(paste("Overall C-index:", cindex))
df$risk_group <- ifelse(df$cv_hazard_percentile_patient > median(df$cv_hazard_percentile_patient), "High risk", "Low risk")
fit <- survfit(Surv(time, status) ~ risk_group, data=df)
p <- ggsurvplot(fit,
legend.title = "Risk group",
legend.labs = c("High risk", "Low risk"),
palette = c("red", "blue"),
risk.table = F,
pval = TRUE,
conf.int = TRUE,
title = "KM curve of risk groups (5-fold CV)"
)
p
For the second example, we generate a simulated single-cell cohort dataset with batch effects. The simulation process is similar to the previous example, except that the cells are simulated in two batches (6,000 and 4,000 cells, respectively) so that batch effects are introduced between them. We then convert the simulated expression data and group labels into a Seurat object.
sim.groups <- splatSimulateGroups(batchCells = c(6000, 4000), nGenes=5000,
#group.prob = c(0.9, 0.05, 0.05),
group.prob = c(0.7, 0.15, 0.15),
de.prob = c(0.2, 0.06, 0.06), de.facLoc = c(0.1, 0.1, 0.1),
de.facScale = 0.4,
seed = 5)
data <- CreateSeuratObject(counts = counts(sim.groups), project = 'ScSurvival_Single_Cell', min.cells = 100, min.features = 100)
data <- AddMetaData(object = data, metadata = sim.groups$Group, col.name = "sim.group")
data$sim.ground.truth <- recode(data$sim.group,'Group1'='other', 'Group2'='good.survival', 'Group3'='bad.survival')
data$batch <- c(rep('Batch1', 6000), rep('Batch2', 4000))
data <- NormalizeData(object = data, normalization.method = "LogNormalize",
scale.factor = 10000)
data <- FindVariableFeatures(object = data, selection.method = 'vst', nfeatures=2000)
var_features_genes = VariableFeatures(data)
data <- ScaleData(object = data)
data <- RunPCA(object = data, features = VariableFeatures(data))
data <- RunUMAP(object = data, dims = 1:10, n.neighbors = 5, min.dist=0.5, spread=1.5)
DimPlot(object = data, reduction = 'umap', cols = c('grey','blue','red'), group.by = 'sim.group', pt.size = 0.5, label = F)
DimPlot(object = data, reduction = 'umap', group.by = 'batch', pt.size = 0.5, label = F)
We then simulate single-cell data for 50 patients from each of the two
batches. Within each batch, the simulation of patients’ cellular
composition and survival status follows the same approach as before.
data_save_path <- './sim_data_with_batch/'
scdata_save_path <- sprintf('%s/single_cell/', data_save_path)
dir.create(scdata_save_path, recursive=T)
Expression_pbmc <- as.matrix(data@assays[["RNA"]]@layers[["data"]])
rownames(Expression_pbmc) <- rownames(data)
colnames(Expression_pbmc) <- colnames(data)
Expression_pbmc <- Expression_pbmc[VariableFeatures(data), ]
###---simulation---------------------
set.seed(1)
sampled_cells = 1000
bulk_num=50
###---simulation from batch1---------------------
other_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='other' & data$batch=='Batch1']
good_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='good.survival' & data$batch=='Batch1']
bad_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='bad.survival' & data$batch=='Batch1']
num_good <- length(good_cells)
num_bad <- length(bad_cells)
bulk_condition = NULL
censor_prob = 0.1
status = NULL
surv_time = NULL
num_good_cond_cells = NULL
num_bad_cond_cells = NULL
for (i in 1:bulk_num){
ratio <- (i-1) / (bulk_num-1)
num_good_cond_cells_i = round(num_good * ratio)
num_bad_cond_cells_i = round(num_bad * (1-ratio))
condition_good_cells <- good_cells[sample(num_good, num_good_cond_cells_i , replace=TRUE)]
condition_bad_cells <- bad_cells[sample(num_bad, num_bad_cond_cells_i, replace=TRUE)]
condition_cells <- c(condition_good_cells, condition_bad_cells, other_cells)
num_good_cond_cells = c(num_good_cond_cells, num_good_cond_cells_i)
num_bad_cond_cells = c(num_bad_cond_cells, num_bad_cond_cells_i)
Expression_condition = Expression_pbmc[, condition_cells]
Expression_selected <- Expression_condition[, sample(ncol(Expression_condition),size=sampled_cells,replace=TRUE)]
write.csv(Expression_selected, file = sprintf('%s/%d.csv', scdata_save_path, i))
if (runif(1, min = 0, max = 1) < censor_prob){
status = c(status, 0)
surv_time = c(surv_time, sample(i, 1))
}
else{
surv_time = c(surv_time, i)
status = c(status, 1)
}
}
###---simulation from batch2---------------------
other_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='other' & data$batch=='Batch2']
good_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='good.survival' & data$batch=='Batch2']
bad_cells <- colnames(Expression_pbmc)[data$sim.ground.truth=='bad.survival' & data$batch=='Batch2']
num_good <- length(good_cells)
num_bad <- length(bad_cells)
censor_prob = 0.1
for (i in 1:bulk_num){
ratio <- (i-1) / (bulk_num-1)
num_good_cond_cells_i = round(num_good * ratio)
num_bad_cond_cells_i = round(num_bad * (1-ratio))
condition_good_cells <- good_cells[sample(num_good, num_good_cond_cells_i , replace=TRUE)]
condition_bad_cells <- bad_cells[sample(num_bad, num_bad_cond_cells_i, replace=TRUE)]
condition_cells <- c(condition_good_cells, condition_bad_cells, other_cells)
num_good_cond_cells = c(num_good_cond_cells, num_good_cond_cells_i)
num_bad_cond_cells = c(num_bad_cond_cells, num_bad_cond_cells_i)
Expression_condition = Expression_pbmc[, condition_cells]
Expression_selected <- Expression_condition[, sample(ncol(Expression_condition),size=sampled_cells,replace=TRUE)]
write.csv(Expression_selected, file = sprintf('%s/%d.csv', scdata_save_path, i+50))
if (runif(1, min = 0, max = 1) < censor_prob){
status = c(status, 0)
surv_time = c(surv_time, sample(i, 1))
}
else{
surv_time = c(surv_time, i)
status = c(status, 1)
}
}
bulk_names <- paste0('bulk', 1:(bulk_num*2))
surv_info <- data.frame(
time=surv_time,
status=status,
num.good.cells = num_good_cond_cells,
num.bad.cells = num_bad_cond_cells,
row.names = bulk_names
)
###-----save---------------------
labels <- data$sim.ground.truth
labels <- as.data.frame(labels)
row.names(labels) <- colnames(data)
write.csv(labels, file=sprintf('%s/sim_groups.csv', data_save_path))
write.csv(Expression_pbmc, file = sprintf('%s/%s.csv', scdata_save_path, 'all_cells'))
write.csv(surv_info, file=sprintf('%s/surv_info.csv', data_save_path))
batch_ids <- as.data.frame(data$batch)
row.names(batch_ids) <- colnames(data)
colnames(batch_ids) <- 'batch_ids'
write.csv(batch_ids, file=sprintf('%s/batch_ids.csv', data_save_path))
When running scSurvival on data from multiple batches, specify the batch_key parameter so that batch effects are handled automatically.
from scSurvival import scSurvivalRun, PredictIndSample
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
xs = []
samples = []
batches = []
for i in range(1, 101):
    df = pd.read_csv(f'{r.scdata_save_path}/{i}.csv', index_col=0)
    xs.append(df.values.T)
    samples.extend(['bulk%d' % i] * df.shape[1])
    if i <= 50:
        batches.extend(['batch1'] * df.shape[1])
    else:
        batches.extend(['batch2'] * df.shape[1])

obs_df = pd.DataFrame({'sample': samples, 'batch': batches})
obs_df.index = np.arange(len(samples)).astype(str)
X = np.concatenate(xs, axis=0)
adata = sc.AnnData(X, obs=obs_df)
clinic = pd.read_csv(f'{r.data_save_path}/surv_info.csv', index_col=0)
surv = clinic[['time', 'status']].copy()
surv['time'] = surv['time'].astype(float)
surv['status'] = surv['status'].astype(int)
adata, surv, model = scSurvivalRun(adata,
sample_column='sample',
surv=surv,
batch_key='batch',
feature_flavor='AE',
entropy_threshold=0.7,
pretrain_epochs=200,
epochs=500,
patience=100,
fitnetune_strategy='alternating'
)
sns.histplot(adata.obs['attention'], bins=50)
plt.show()
plt.close()
# calculate the attention threshold
from sklearn.cluster import KMeans
data = adata.obs['attention'].values.reshape(-1, 1)
kmeans = KMeans(n_clusters=2, random_state=42)
kmeans.fit(data)
labels = kmeans.labels_
cluster_centers = kmeans.cluster_centers_
atten_thr = cluster_centers.flatten().mean()
print("cutoff:", atten_thr)
As in the first example, we can use the trained model to perform inference on the initially simulated single-cell data and visualize the results.
df = pd.read_csv(f'{r.scdata_save_path}/all_cells.csv', index_col=0)
x = df.values.T
sim_group = pd.read_csv(f'{r.data_save_path}/sim_groups.csv', index_col=0)
sim_group = sim_group['labels'].values
batches = pd.read_csv(f'{r.data_save_path}/batch_ids.csv', index_col=0)
batches = batches['batch_ids'].values
batches= [each.lower() for each in batches]
adata_new = sc.AnnData(x, obs=pd.DataFrame(sim_group, index=np.arange(x.shape[0]).astype(str), columns=['sim_group']))
adata_new.obs['batch'] = batches
batches = model.le.transform(adata_new.obs['batch'].values)
exp = adata_new.X
h, a, cell_hazards, cell_hazards_weighted = model.predict_cells(exp, batch_labels=batches)
adata_new.obsm['X_ae'] = h.cpu().detach().numpy()
adata_new.obs['hazard'] = cell_hazards.cpu().detach().numpy()
adata_new.obs['attention'] = a.cpu().detach().numpy()
adata_new.obs['hazard_adj'] = cell_hazards_weighted.cpu().detach().numpy()
attention = adata_new.obs['attention'].values
hazard_adj = adata_new.obs['hazard_adj'].values
hazard = adata_new.obs['hazard'].values
X_ae = adata_new.obsm['X_ae']
data$attention <- py$attention
data$hazard_adj <- py$hazard_adj
data$hazard <- py$hazard
atten_thr <- py$atten_thr
risk_group <- rep('inattentive', dim(data)[2])
risk_group[(data$hazard_adj > 0 & data$attention > atten_thr)] <- 'higher'
risk_group[(data$hazard_adj < 0 & data$attention > atten_thr)] <- 'lower'
data$surv.risk.group <- factor(risk_group, levels=c('higher', 'lower', 'inattentive'))
cols = c("blue","lightgrey", "red")
FeaturePlot(data, features = c('attention'), pt.size = 0.5) + scale_colour_gradientn(colours=c("lightgrey", "blue"))
## Scale for colour is already present.
## Adding another scale for colour, which will replace the existing scale.
FeaturePlot(data, features = c('hazard_adj'), pt.size = 0.5) + scale_colour_gradientn(colours=cols)
## Scale for colour is already present.
## Adding another scale for colour, which will replace the existing scale.
DimPlot(object = data, reduction = 'umap', cols = c('red','blue','grey'), group.by = 'surv.risk.group', pt.size = 0.5, label = T)
At the same time, we can extract the cell embeddings (X_ae) generated by the scSurvival model and add them to the Seurat object as a new low-dimensional representation. Using this embedding, we generate new UMAP coordinates for visualization. As shown, the batch effects have been effectively corrected.
emd_ae <- py$X_ae
rownames(emd_ae) <- colnames(data)
data[["X_ae"]] <- CreateDimReducObject(embeddings = emd_ae, key = "AE_", assay = "RNA")
data <- RunUMAP(object = data, reduction = 'X_ae', dims = 1:ncol(emd_ae), n.neighbors = 5, min.dist=0.5, spread=1.5)
DimPlot(object = data, cols = c('red','blue'), group.by = 'batch', pt.size = 0.5, label = F)
FeaturePlot(data, features = c('attention'), pt.size = 0.5) + scale_colour_gradientn(colours=c("lightgrey", "blue"))
FeaturePlot(data, features = c('hazard_adj'), pt.size = 0.5) + scale_colour_gradientn(colours=cols)
DimPlot(object = data, reduction = 'umap', cols = c('red','blue','grey'), group.by = 'surv.risk.group', pt.size = 0.5, label = T)
When batch effects exist either within the training data or between the training and test datasets, directly applying a model trained by scSurvival to predict survival risk on independent data may yield inaccurate results. This is because the batch labels of the test data were never seen during training, which prevents the model from performing reliable inference. In such cases, the PredictIndSample function is not applicable.
To address this issue, we recommend combining the test data (along with their batch labels) with the training data and re-training the model using the scSurvivalRun function. This enables transfer learning by jointly learning from both datasets. Upon completion, scSurvivalRun will return an updated adata object and a surv DataFrame that include prediction results for the test data.
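A minimal sketch of this route is given below. The object names (adata_train, surv_train, adata_test) are placeholders, and the assumption that the unlabeled test patients are simply included in the combined AnnData while surv carries only the training patients is ours, not a documented guarantee of the scSurvivalRun API; please check the package documentation for the exact convention used to flag samples without survival information.
import anndata as ad

# Label the new cohort with its own batch identifier and pool it with the
# training cells (placeholder names; see the caveats above).
adata_test.obs['batch'] = 'external_batch'
adata_all = ad.concat([adata_train, adata_test], join='inner')

# Re-train on the pooled data so the model sees the new batch during training.
adata_all, surv_all, model = scSurvivalRun(
    adata_all,
    sample_column='sample',
    surv=surv_train,          # survival info for the training patients only (assumption)
    batch_key='batch',
    feature_flavor='AE',
    entropy_threshold=0.7,
    pretrain_epochs=200,
    epochs=500,
    patience=100,
    fitnetune_strategy='alternating',
)
# surv_all is then expected to carry prediction results for the test patients as well.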
R packages loaded in this tutorial:
Seurat 5.1.0
reticulate 1.38.0
ggplot2 3.5.1
gridExtra 2.3
ggpubr 0.6.0
dplyr 1.1.4
splatter 1.28.0
caret 6.0-94
survival 3.6-4
Hmisc 5.2-1
survminer 0.5.0
Python packages that scSurvival depends on:
torch 2.4.0+cu124
numpy 1.26.4
pandas 2.2.2
scanpy 1.10.2
scikit-learn 1.4.2
scipy 1.13.1
lifelines 0.30.0