Gene regression

This set of classes/functions can be used to predict the gene expression profile based on:

  1. the cell-type and

  2. cell-specific covariates (such as location in space and local micro-environment).

The intended use case is to test whether the semantic features extracted by one of the Self-Supervised Learning (SSL) algorithms contain biologically relevant information.

We treat each cell-type separately. For each cell, the gene counts are modelled as a multinomial distribution:

\[c \sim \frac{N!}{n_1!\,n_2!\dots n_G!} \, p_1^{n_1} p_2^{n_2} \dots p_G^{n_G}\]

where \(N=\sum_{g=1}^G n_g\) is the total number of counts in the cell (sometimes referred to as the total UMI count) and \(p_g\), with \(\sum_{g=1}^G p_g=1\), are the probabilities of measuring each gene. When \(N\) is large and the \(p_g\) are small, the counts for each gene can be approximated by a Poisson distribution with rate \(r_g = N p_g\). Therefore the counts for cell \(n\) and gene \(g\) are modelled as:

\[c_{ng} \sim \text{Poi}( r_{ng} = N_n \, p_{ng})\]
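The Poisson approximation can be checked numerically. The following is a minimal sketch, assuming NumPy and made-up values for the number of genes, the total count and the gene probabilities (none of these come from the library): it draws multinomial counts and compares the per-gene mean and variance with the rate \(r_g = N p_g\).

```python
import numpy as np

rng = np.random.default_rng(0)

n_genes = 1000        # G (illustrative value)
total_counts = 2000   # N, total UMI count of one cell (illustrative value)
n_draws = 5000        # number of simulated cells sharing the same N and p

# Gene probabilities: positive, small, and summing to one.
p = rng.dirichlet(np.ones(n_genes))

# Multinomial draws: each row is one simulated cell and sums to N.
counts = rng.multinomial(total_counts, p, size=n_draws)

# For a Poisson variable the mean and the variance both equal the rate,
# so both empirical quantities should be close to r_g = N * p_g.
rate = total_counts * p
empirical_mean = counts.mean(axis=0)
empirical_var = counts.var(axis=0)

print("largest |mean - rate|:", np.abs(empirical_mean - rate).max())
print("sum of variances / sum of rates:", empirical_var.sum() / rate.sum())
```

The exact multinomial variance is \(N p_g (1 - p_g)\), so the agreement improves as the individual \(p_g\) get smaller.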

To account for noise and the presence of L (cell-specific) covariates, we model the probability as:

\[\log p_{ng} = \left( \beta_g^0 + \sum_l \beta_{lg} X_{nl} \right)\]

where \(\beta_g^0\) is a gene-specific intercept, \(X_{nl}\) are the cell covariates, and \(\beta_{lg}\) are the corresponding regression coefficients.
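As an illustration of the two equations above, the sketch below builds the Poisson rate \(r_{ng} = N_n \, p_{ng}\) from toy covariates and simulates counts. It uses NumPy arrays instead of torch tensors, and all sizes and parameter values are arbitrary assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_covariates, n_genes = 5, 3, 4   # N, L, G (toy sizes)

X_nl = rng.normal(size=(n_cells, n_covariates))            # cell covariates
beta0_g = rng.normal(size=n_genes) - 3.0                   # gene-specific intercepts
beta_lg = 0.1 * rng.normal(size=(n_covariates, n_genes))   # regression coefficients
total_counts_n = rng.integers(500, 2000, size=n_cells)     # N_n, total counts per cell

# log p_ng = beta0_g + sum_l beta_lg * X_nl
log_p_ng = beta0_g[None, :] + X_nl @ beta_lg

# Poisson rate r_ng = N_n * p_ng
rate_ng = total_counts_n[:, None] * np.exp(log_p_ng)

# Simulated counts c_ng ~ Poi(r_ng)
counts_ng = rng.poisson(rate_ng)
print(counts_ng.shape)
```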

We recap the dimensions of the variables involved in the full model (for K different cell-types):

  1. \(X_{nl}\) is a fixed covariate matrix of shape \(N \times L\) (i.e. cells by covariates)

  2. \(N_n\) is a fixed vector of shape \(N\) with the (observed) total counts in a cell.

  3. \(\beta_{kg}^0\) are the intercepts of the regression, of shape \(K \times G\) (i.e. cell-types by genes)

  4. \(\beta_{klg}\) are the regression coefficients of shape \(K \times L \times G\) (i.e. cell-types by covariates by genes)

Typical values are \(N \sim 10^3, G \sim 10^3, K\sim 10, L\sim 50\). The goal of the inference is to determine \(\beta^0\) and \(\beta\). We enforce a penalty (either L1 or L2) on the regression coefficients \(\beta\) to encourage them to be small. There is no prior on \(\beta^0\). Overall the model has one hyper-parameter (the strength of the regularization on \(\beta\)), which is determined by cross-validation.
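The objective described above can be sketched as a penalized Poisson negative log-likelihood. This is an illustrative NumPy version and not the library's implementation: the function names `log_rate` and `penalized_poisson_nll`, the un-normalized sum over cells and genes, and the way the penalty is scaled by `strength` are all assumptions.

```python
import numpy as np


def log_rate(total_counts_n, X_nl, cell_type_ids_n, beta0_kg, beta_klg):
    """log r_ng = log N_n + beta0_kg + sum_l beta_klg * X_nl, with k the type of cell n."""
    beta0_ng = beta0_kg[cell_type_ids_n]   # (N, G): intercepts of each cell's type
    beta_nlg = beta_klg[cell_type_ids_n]   # (N, L, G): coefficients of each cell's type
    linear_ng = np.einsum("nl,nlg->ng", X_nl, beta_nlg)
    return np.log(total_counts_n)[:, None] + beta0_ng + linear_ng


def penalized_poisson_nll(counts_ng, log_rate_ng, beta_klg, strength, penalty="L1"):
    # Poisson negative log-likelihood; the log(c!) term is dropped because
    # it does not depend on the parameters.
    nll = np.sum(np.exp(log_rate_ng) - counts_ng * log_rate_ng)
    if penalty == "L1":
        reg = np.abs(beta_klg).sum()
    elif penalty == "L2":
        reg = np.square(beta_klg).sum()
    else:
        raise ValueError(f"unknown penalty: {penalty!r}")
    # Only beta is penalized; the intercepts beta0 are left free.
    return nll + strength * reg


# Toy data with K cell types (all sizes and values are illustrative).
rng = np.random.default_rng(0)
n_cells, n_genes, k_cell_types, n_covariates = 200, 6, 3, 4
X_nl = rng.normal(size=(n_cells, n_covariates))
cell_type_ids_n = rng.integers(0, k_cell_types, size=n_cells)
total_counts_n = rng.integers(500, 2000, size=n_cells)
beta0_kg = rng.normal(size=(k_cell_types, n_genes)) - 3.0
beta_klg = 0.1 * rng.normal(size=(k_cell_types, n_covariates, n_genes))

true_log_rate_ng = log_rate(total_counts_n, X_nl, cell_type_ids_n, beta0_kg, beta_klg)
counts_ng = rng.poisson(np.exp(true_log_rate_ng))

loss_l1 = penalized_poisson_nll(counts_ng, true_log_rate_ng, beta_klg, strength=1.0, penalty="L1")
loss_l2 = penalized_poisson_nll(counts_ng, true_log_rate_ng, beta_klg, strength=1.0, penalty="L2")
print(loss_l1, loss_l2)
```

In a cross-validation loop, `strength` would be varied and the value giving the best held-out likelihood retained.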

See notebook3 for an example.

class GeneDataset(covariates: torch.Tensor, cell_type_ids: torch.Tensor, cell_type_props: torch.Tensor, counts: torch.Tensor, k_cell_types: int, cell_type_mapping: dict, gene_names: List[str])[source]

Container for organizing the gene expression data

cell_type_ids: torch.Tensor

long tensor with the cell_type_ids of shape (n)

cell_type_mapping: dict

dictionary with mapping from unique_cell_type to cell_type_ids

cell_type_props: torch.Tensor

float tensor with the cell type proportions of shape (n, k)

counts: torch.Tensor

long tensor with the count data of shape (n, g)

covariates: torch.Tensor

float tensor with the covariates of shape (n, l)

describe()[source]

Method which describes the content of the GeneDataset.

gene_names: List[str]

list of the gene names

k_cell_types: int

number of cell types
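To make the relation between cell_type_mapping, cell_type_ids and k_cell_types concrete, here is a small sketch using NumPy arrays in place of torch tensors. The cell-type names, the alphabetical ordering of the ids and the one-hot proportions are assumptions made for illustration; the library may build these fields differently.

```python
import numpy as np

# Hypothetical per-cell labels (n = 5 cells).
cell_types = np.array(["neuron", "glia", "neuron", "immune", "glia"])

# Mapping from each unique cell type to an integer id (alphabetical order here).
unique_cell_types = np.unique(cell_types)
cell_type_mapping = {str(name): idx for idx, name in enumerate(unique_cell_types)}

# Integer ids of shape (n) and the number of cell types k.
cell_type_ids = np.array([cell_type_mapping[str(name)] for name in cell_types])
k_cell_types = len(cell_type_mapping)

# Proportions of shape (n, k); one-hot rows are the simplest possible example.
cell_type_props = np.eye(k_cell_types)[cell_type_ids]

print(cell_type_mapping)
print(cell_type_ids)
```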

make_gene_dataset_from_anndata(anndata: scanpy.AnnData, cell_type_key: str, cell_type_prop_key: str, covariate_key: str, preprocess_strategy: str = 'raw', apply_pca: bool = False, n_components: int | float = 0.9) GeneDataset[source]

Convert an AnnData object into a GeneDataset object which can be used for gene regression.

Parameters:
  • anndata – AnnData object with the raw counts stored in anndata.X

  • cell_type_key – key corresponding to the cell type, i.e. cell_types = anndata.obs[cell_type_key]

  • covariate_key – key corresponding to the covariate, i.e. covariates = anndata.obsm[covariate_key]

  • cell_type_prop_keys – key corresponding to the proportions for all k cell types, i.e. prop of cell_type k = anndata.obs[cell_type_prop_keys[k]]

  • preprocess_strategy – either ‘center’, ‘z_score’ or ‘raw’. It describes how to preprocess the covariates. ‘raw’ (default) means no preprocessing.

  • apply_pca – if True, we compute the PCA of the covariates. This operation happens after the preprocessing.

  • n_components – Used only if apply_pca == True. If an integer, it specifies the dimensionality of the data after PCA. If a float in (0, 1), the dimensionality is selected automatically so that the explained variance is at least that value.

Returns:

GeneDataset – a GeneDataset object
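The covariate preprocessing and the two meanings of n_components can be sketched as follows. This is a NumPy-only illustration under stated assumptions: the helper names `preprocess_covariates` and `pca_reduce` are hypothetical, and the library's actual implementation may differ.

```python
import numpy as np


def preprocess_covariates(x, strategy="raw"):
    """Apply one of the three strategies column-wise to an (n, l) array."""
    if strategy == "raw":
        return x
    mean = x.mean(axis=0, keepdims=True)
    if strategy == "center":
        return x - mean
    if strategy == "z_score":
        std = x.std(axis=0, keepdims=True)
        return (x - mean) / np.where(std > 0, std, 1.0)  # guard against constant columns
    raise ValueError(f"unknown preprocess_strategy: {strategy!r}")


def pca_reduce(x, n_components):
    """Project x onto its leading principal components.

    An int fixes the output dimensionality; a float in (0, 1) keeps the smallest
    number of components whose cumulative explained variance reaches that value.
    """
    x_centered = x - x.mean(axis=0, keepdims=True)
    _, singular_values, vt = np.linalg.svd(x_centered, full_matrices=False)
    if isinstance(n_components, float):
        explained = np.cumsum(singular_values**2) / np.sum(singular_values**2)
        n_keep = int(np.searchsorted(explained, n_components)) + 1
        n_keep = min(n_keep, len(singular_values))
    else:
        n_keep = int(n_components)
    return x_centered @ vt[:n_keep].T


rng = np.random.default_rng(0)
# Toy covariates whose columns have very different scales.
covariates = rng.normal(size=(100, 5)) * np.array([10.0, 5.0, 1.0, 0.1, 0.1])

centered = preprocess_covariates(covariates, "center")
reduced = pca_reduce(centered, 0.9)   # PCA is applied after the preprocessing
print(reduced.shape)
```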

train_test_val_split(data: List[torch.Tensor] | List[numpy.ndarray] | GeneDataset, train_size: float = 0.8, test_size: float = 0.1, val_size: float = 0.1, n_splits: int = 1, random_state: int = None, stratify: bool = True, spatial: bool = False)[source]

Utility function used to split the data into train/test/val.

Parameters:
  • data – the data to split into train/test/val

  • train_size – the relative size of the train dataset

  • test_size – the relative size of the test dataset

  • val_size – the relative size of the val dataset

  • n_splits – how many times to split the data

  • random_state – specify the random state for reproducibility

  • stratify – If true, the train/test are stratified so that they contain approximately the same number of examples from each class. If data is a list of arrays, the 2nd array is assumed to represent the class. If data is a GeneDataset, the class is the cell_type.

  • spatial – If true, the train/test are stratified based on spatial coordinates of patches. If both spatial and stratify are true, a spatial split that best preserves stratification will be found.

Returns:

tuple – yields multiple splits of the data.

Example

>>> for train, test, val in train_test_val_split(data=[x, y, z]):
...     x_train, y_train, z_train = train
...     x_test, y_test, z_test = test
...     x_val, y_val, z_val = val
...     # ... do something ...

Example

>>> # gene_dataset is a GeneDataset instance
>>> for train, test, val in train_test_val_split(data=gene_dataset):
...     assert isinstance(train, GeneDataset)
...     assert isinstance(test, GeneDataset)
...     assert isinstance(val, GeneDataset)
...     # ... do something ...

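
The idea behind a stratified train/test/val split can be sketched with NumPy alone. This is not the library's implementation: `stratified_split_indices` is a hypothetical helper that produces a single split, ignores the spatial option, and returns index arrays rather than data.

```python
import numpy as np


def stratified_split_indices(labels, train_size=0.8, test_size=0.1, val_size=0.1,
                             random_state=None):
    """Return (train, test, val) index arrays with per-class proportions preserved."""
    if abs(train_size + test_size + val_size - 1.0) > 1e-8:
        raise ValueError("train_size, test_size and val_size must sum to 1")
    rng = np.random.default_rng(random_state)
    train, test, val = [], [], []
    for label in np.unique(labels):
        # Shuffle the indices of this class, then cut them into three blocks.
        idx = rng.permutation(np.flatnonzero(labels == label))
        n_train = int(round(train_size * len(idx)))
        n_test = int(round(test_size * len(idx)))
        train.append(idx[:n_train])
        test.append(idx[n_train:n_train + n_test])
        val.append(idx[n_train + n_test:])   # the remainder goes to validation
    return tuple(np.concatenate(part) for part in (train, test, val))


# Toy labels: 50 cells of type 0, 30 of type 1, 20 of type 2.
labels = np.repeat([0, 1, 2], [50, 30, 20])
train_idx, test_idx, val_idx = stratified_split_indices(labels, random_state=0)
print(len(train_idx), len(test_idx), len(val_idx))
```

The same indices can then be used to slice every array that shares the leading (cell) dimension.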
plot_gene_hist(cell_types_n, value1_ng, value2_ng=None, bins=20) matplotlib.pyplot.Figure[source]

Plot the per cell-type histogram. If value2_ng is defined, the two histograms are interleaved.

Parameters:
  • cell_types_n – tensor of shape N with the cell type labels (with K distinct values)

  • value1_ng – the first quantity to plot, of shape (N,G)

  • value2_ng – the second quantity to plot of shape (N,G) (optional)

  • bins – number of bins in the histogram

Returns:

fig – A figure with G rows and K columns where K is the number of distinct cell types.