Gene regression
This set of classes/functions can be used to predict the gene expression profile based on:
the cell-type and
cell-specific covariates (such as location in space and local micro-environment).
The intended use case is to test whether the semantic features extracted by one of the self-supervised learning (SSL) algorithms contain biologically relevant information.
We treat each cell type separately. For each cell, the gene counts \(n_1, \dots, n_G\) are modelled as a multinomial distribution:
\[(n_1, \dots, n_G) \sim \text{Multinomial}(N;\, p_1, \dots, p_G)\]
where \(N=\sum_{g=1}^G n_g\) is the total number of counts in the cell (sometimes referred to as the total UMI count) and \(p_g\), with \(\sum_{g=1}^G p_g=1\), are the probabilities of measuring each gene. When \(N\) is large and the \(p_g\) are small, the counts for each gene can be approximated by a Poisson distribution with rate \(r_g = N p_g\). Therefore the counts \(c_{ng}\) for cell \(n\) and gene \(g\) are modelled as:
\[c_{ng} \sim \text{Poisson}(N_n \, p_{ng})\]
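The quality of the Poisson approximation can be checked numerically. Under a multinomial model the count of a single gene is marginally binomial, so the sketch below compares a binomial pmf with the matching Poisson pmf. The values of `total_counts` and `p_gene` are illustrative assumptions, not values taken from any dataset.

```python
import math


def binomial_pmf(k: int, n: int, p: float) -> float:
    """Marginal pmf of one gene's count under the multinomial model."""
    return math.comb(n, k) * p**k * (1.0 - p) ** (n - k)


def poisson_pmf(k: int, rate: float) -> float:
    """Poisson pmf with rate r = N * p."""
    return math.exp(-rate) * rate**k / math.factorial(k)


total_counts = 1000  # N, illustrative value
p_gene = 0.002       # p_g, illustrative value

# Largest pointwise gap between the two pmfs over the counts 0..20.
max_abs_diff = max(
    abs(binomial_pmf(k, total_counts, p_gene) - poisson_pmf(k, total_counts * p_gene))
    for k in range(21)
)
print(f"largest pmf difference for k in 0..20: {max_abs_diff:.2e}")
```

The gap shrinks as \(p_g\) decreases at fixed rate \(N p_g\), which is the regime described above.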
To account for noise and the presence of L (cell-specific) covariates, we model the probability as:
where \(\beta_g^0\) is a gene-specific intercept and \(X_{nl}\) are the cell covariates.
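As a sketch, one log-linear form consistent with the symbols defined in this section is given here. The log link and the additive noise term \(\epsilon_{ng}\) are assumptions; the text names the intercepts, covariates and regression coefficients but does not specify the link function or the noise distribution.

\[
\log p_{ng} = \beta_g^0 + \sum_{l=1}^{L} X_{nl}\,\beta_{lg} + \epsilon_{ng}
\]

In the full model with \(K\) cell types, the same form would apply per cell type, with \(\beta_g^0\) replaced by \(\beta_{kg}^0\) and \(\beta_{lg}\) by \(\beta_{klg}\).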
We recap the dimensions of the variables involved in the full model (for K different cell types):
\(X_{nl}\) is a fixed covariate matrix of shape \(N \times L\) (i.e. cells by covariates)
\(N_n\) is a fixed vector of shape \(N\) with the (observed) total counts in a cell.
\(\beta_{kg}^0\) are the regression intercepts, of shape \(K \times G\) (i.e. cell-types by genes)
\(\beta_{klg}\) are the regression coefficients of shape \(K \times L \times G\) (i.e. cell-types by covariates by genes)
Typical values are \(N \sim 10^3, G \sim 10^3, K\sim 10, L\sim 50\). The goal of the inference is to determine \(\beta^0\) and \(\beta\). We enforce a penalty (either L1 or L2) on the regression coefficients \(\beta\) to encourage them to be small. There is no prior on \(\beta^0\). Overall, the model has one hyper-parameter (the strength of the regularization on \(\beta\)), which is determined by cross-validation.
See notebook3 for an example.
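The objective described above (a Poisson likelihood with a penalty on \(\beta\) but not on \(\beta^0\)) can be sketched in plain Python for a single cell type. This is not the library's implementation: the function name `penalized_poisson_loss` is hypothetical, the log link is an assumption, any noise term is omitted, and the toy data are made up.

```python
import math
from typing import List


def penalized_poisson_loss(
    counts: List[List[int]],        # shape (n, g): observed counts
    total_counts: List[float],      # shape (n,): N_n, assumed strictly positive
    covariates: List[List[float]],  # shape (n, l): X_nl
    beta0: List[float],             # shape (g,): intercepts for one cell type
    beta: List[List[float]],        # shape (l, g): regression coefficients
    strength: float,                # regularization strength (the hyper-parameter)
    penalty: str = "l1",            # "l1" or "l2"
) -> float:
    """Poisson negative log-likelihood (up to a constant) plus a penalty on beta only."""
    nll = 0.0
    for n, x_n in enumerate(covariates):
        for g, b0 in enumerate(beta0):
            # Assumed log link: log p_ng = beta0_g + sum_l X_nl * beta_lg
            log_p = b0 + sum(x * beta[l][g] for l, x in enumerate(x_n))
            rate = total_counts[n] * math.exp(log_p)
            # -log Poisson(c | rate) without the log(c!) term, which is constant in beta
            nll += rate - counts[n][g] * math.log(rate)
    flat = [b for row in beta for b in row]
    if penalty == "l1":
        reg = sum(abs(b) for b in flat)
    elif penalty == "l2":
        reg = sum(b * b for b in flat)
    else:
        raise ValueError(f"unknown penalty: {penalty!r}")
    # The intercepts beta0 are deliberately left out of the penalty.
    return nll + strength * reg


# Toy data: 2 cells, 2 genes, 1 covariate (all values are made up).
counts = [[3, 1], [0, 4]]
total_counts = [4.0, 4.0]
covariates = [[1.0], [-1.0]]
beta0 = [math.log(0.5), math.log(0.5)]
beta = [[0.5, -0.5]]

args = (counts, total_counts, covariates, beta0, beta)
loss_plain = penalized_poisson_loss(*args, strength=0.0)
loss_l1 = penalized_poisson_loss(*args, strength=0.1, penalty="l1")
loss_l2 = penalized_poisson_loss(*args, strength=0.1, penalty="l2")
print(loss_plain, loss_l1, loss_l2)
```

In a cross-validation loop, one would minimize this loss on the training split for several values of `strength` and keep the value with the best held-out likelihood.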
- class GeneDataset(covariates: torch.Tensor, cell_type_ids: torch.Tensor, cell_type_props: torch.Tensor, counts: torch.Tensor, k_cell_types: int, cell_type_mapping: dict, gene_names: List[str])
Container for organizing the gene expression data
- cell_type_ids: torch.Tensor
long tensor with the cell_type_ids of shape (n)
- cell_type_mapping: dict
dictionary with mapping from unique_cell_type to cell_type_ids
- cell_type_props: torch.Tensor
float tensor with the cell type proportions of shape (n, k)
- counts: torch.Tensor
long tensor with the count data of shape (n, g)
- covariates: torch.Tensor
float tensor with the covariates of shape (n, l)
- gene_names: List[str]
list of the gene names
- k_cell_types: int
number of cell types
- make_gene_dataset_from_anndata(anndata: scanpy.AnnData, cell_type_key: str, cell_type_prop_key: str, covariate_key: str, preprocess_strategy: str = 'raw', apply_pca: bool = False, n_components: int | float = 0.9) GeneDataset
Convert an AnnData object into a GeneDataset object which can be used for gene regression.
- Parameters:
anndata – AnnData object with the raw counts stored in anndata.X
cell_type_key – key corresponding to the cell type, i.e. cell_types = anndata.obs[cell_type_key]
covariate_key – key corresponding to the covariate, i.e. covariates = anndata.obsm[covariate_key]
cell_type_prop_keys – key corresponding to the proportions for all k cell types, i.e. prop of cell_type k = anndata.obs[cell_type_prop_keys[k]]
preprocess_strategy – either ‘center’, ‘z_score’ or ‘raw’. It describes how to preprocess the covariates. ‘raw’ (default) means no preprocessing.
apply_pca – if True, we compute the PCA of the covariates. This operation happens after the preprocessing.
n_components – Used only if apply_pca == True. If an integer, it specifies the dimensionality of the data after PCA. If a float in (0, 1), the dimensionality is selected automatically so that the explained variance is at least that value.
- Returns:
GeneDataset – a GeneDataset object
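The `preprocess_strategy` and `n_components` options described above can be illustrated with a small stand-alone sketch. The helper names `preprocess_column` and `select_n_components` are hypothetical and this is not the library's code; in particular, the use of the population standard deviation for 'z_score' is an assumption.

```python
import math
from typing import List, Union


def preprocess_column(values: List[float], strategy: str = "raw") -> List[float]:
    """Apply 'raw', 'center' or 'z_score' to one (non-constant) covariate column."""
    if strategy == "raw":
        return list(values)
    mean = sum(values) / len(values)
    centered = [v - mean for v in values]
    if strategy == "center":
        return centered
    if strategy == "z_score":
        # Population standard deviation; the library may use a different convention.
        std = math.sqrt(sum(c * c for c in centered) / len(values))
        return [c / std for c in centered]
    raise ValueError(f"unknown preprocess_strategy: {strategy!r}")


def select_n_components(
    explained_variance_ratio: List[float], n_components: Union[int, float]
) -> int:
    """An int is used as-is; a float in (0, 1) is treated as a variance threshold."""
    if isinstance(n_components, int):
        return n_components
    cumulative = 0.0
    for i, ratio in enumerate(explained_variance_ratio, start=1):
        cumulative += ratio
        if cumulative >= n_components:
            return i  # smallest number of components reaching the threshold
    return len(explained_variance_ratio)


centered = preprocess_column([1.0, 2.0, 3.0], "center")
z_scored = preprocess_column([1.0, 2.0, 3.0], "z_score")
n_kept = select_n_components([0.6, 0.25, 0.1, 0.05], 0.9)
print(centered, z_scored, n_kept)
```

With the ratios above, the first two components explain 85% of the variance, so a threshold of 0.9 requires a third component.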
- train_test_val_split(data: List[torch.Tensor] | List[numpy.ndarray] | GeneDataset, train_size: float = 0.8, test_size: float = 0.1, val_size: float = 0.1, n_splits: int = 1, random_state: int = None, stratify: bool = True, spatial: bool = False)
Utility function used to split the data into train/test/val.
- Parameters:
data – the data to split into train/test/val
train_size – the relative size of the train dataset
test_size – the relative size of the test dataset
val_size – the relative size of the val dataset
n_splits – how many times to split the data
random_state – specify the random state for reproducibility
stratify – If true, the train/test are stratified so that they contain approximately the same number of examples from each class. If data is a list of arrays, the 2nd array is assumed to represent the class. If data is a GeneDataset, the class is the cell_type.
spatial – If true, the train/test are stratified based on spatial coordinates of patches. If both spatial and stratify are true, a spatial split that best preserves stratification will be found.
- Returns:
tuple – yields multiple splits of the data.
Example
>>> for train, test, val in train_test_val_split(data=[x, y, z]):
...     x_train, y_train, z_train = train
...     x_test, y_test, z_test = test
...     x_val, y_val, z_val = val
...     pass  # do something with the split
Example
>>> # gene_dataset is a GeneDataset instance
>>> for train, test, val in train_test_val_split(data=gene_dataset):
...     assert isinstance(train, GeneDataset)
...     assert isinstance(test, GeneDataset)
...     assert isinstance(val, GeneDataset)
...     pass  # do something with the split
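To make the `stratify` option concrete, here is a minimal sketch of a stratified three-way split over indices. The function `stratified_split_indices` is hypothetical and is not how the library implements `train_test_val_split`; it ignores the `spatial` and `n_splits` options and only shows the idea of splitting class by class.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def stratified_split_indices(
    labels: Sequence[int],
    train_size: float = 0.8,
    test_size: float = 0.1,
    val_size: float = 0.1,
    random_state: int = 0,
) -> Tuple[List[int], List[int], List[int]]:
    """Split indices class by class so each subset keeps roughly the same class mix."""
    total = train_size + test_size + val_size
    rng = random.Random(random_state)

    # Group the sample indices by class label.
    by_class: Dict[int, List[int]] = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    train: List[int] = []
    test: List[int] = []
    val: List[int] = []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_train = round(len(indices) * train_size / total)
        n_test = round(len(indices) * test_size / total)
        train.extend(indices[:n_train])
        test.extend(indices[n_train:n_train + n_test])
        val.extend(indices[n_train + n_test:])  # the remainder goes to validation
    return train, test, val


labels = [0] * 50 + [1] * 30  # made-up cell-type labels
train_idx, test_idx, val_idx = stratified_split_indices(labels)
print(len(train_idx), len(test_idx), len(val_idx))
```

Each class contributes to the three subsets in the requested proportions, so a rare cell type is not accidentally left out of the test or validation set.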
- plot_gene_hist(cell_types_n, value1_ng, value2_ng=None, bins=20) matplotlib.pyplot.Figure
Plot the per cell-type histogram. If value2_ng is defined, the two histograms are interleaved.
- Parameters:
cell_types_n – tensor of shape N with the cell type labels (with K distinct values)
value1_ng – the first quantity to plot, of shape (N,G)
value2_ng – the second quantity to plot of shape (N,G) (optional)
bins – number of bins in the histogram
- Returns:
fig – A figure with G rows and K columns where K is the number of distinct cell types.