Main Stack Code Explained With Loss

Overview

This guide focuses on the main execution path of the Stack model and adds the loss logic that is needed to understand training behavior.[web:11][cite:7] Stack is described as a single-cell foundation model that uses tabular attention so each cell representation is informed both by its own gene structure and by neighboring cells in the same context set.[web:10][web:13][web:19]

The most important files for this path are src/stack/models/core/base.py, which defines the top-level forward pass, src/stack/modules/attention.py, which defines TabularAttentionLayer, and src/stack/models/core/losses.py, which contains the reconstruction and evaluation logic.[cite:7][cite:9][cite:6]

Main forward code

The central model execution path is the forward() method of StateICLModelBase.[cite:7]

def forward(
    self,
    features: torch.Tensor,
    return_loss: bool = True,
) -> Dict[str, torch.Tensor]:
    batch_size, n_cells, _ = features.shape
    device = features.device

    original_features = features.clone()
    observed_lib_size = original_features.sum(dim=-1, keepdim=True)

    features = torch.log1p(features)

    masked_features, mask = self.apply_mask(features)

    tokens = self._reduce_and_tokenize(masked_features)
    x = self._run_attention_layers(tokens)
    final_cell_embeddings = x.reshape(batch_size, n_cells, -1)

    nb_mean, nb_dispersion, px_scale = self._compute_nb_parameters(
        final_cell_embeddings, observed_lib_size
    )

    result = {
        "nb_mean": nb_mean,
        "nb_dispersion": nb_dispersion,
        "px_scale": px_scale,
        "observed_lib_size": observed_lib_size,
        "mask": mask,
        "cell_embeddings": final_cell_embeddings,
        "masked_features": masked_features,
        "original_features": original_features,
    }

    if return_loss:
        recon_loss, _ = self._compute_reconstruction_loss(
            nb_mean, nb_dispersion, original_features, mask
        )
        sw_loss = self._compute_sw_loss(final_cell_embeddings)
        total_loss = recon_loss + self.sw_weight * sw_loss

        result.update(
            {
                "loss": total_loss,
                "recon_loss": recon_loss,
                "sw_loss": sw_loss,
            }
        )

        if not self.training:
            metrics = self._compute_eval_metrics(nb_mean, original_features, mask)
            result.update(metrics)
        else:
            zero = torch.tensor(0.0, device=device, dtype=nb_mean.dtype)
            result.update(
                {
                    "masked_mae": zero,
                    "masked_corr": zero,
                    "mask_rate": zero,
                }
            )

    return result

What the forward pass does

The model starts from a tensor of raw counts shaped as (batch_size, n_cells, n_genes) and preserves a clone of the original counts before any transformation.[cite:7] It computes the per-cell library size as observed_lib_size, log-transforms the counts with torch.log1p, and then masks a random subset of genes to create the reconstruction task.[cite:7]
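The preprocessing steps above can be sketched on a toy tensor. This is only an illustration: the masking rule here (zeroing roughly 40% of entries chosen uniformly at random) is an assumption, and the real apply_mask may select and encode hidden genes differently.

```python
import torch

torch.manual_seed(0)

# Toy raw counts shaped (batch_size, n_cells, n_genes)
counts = torch.randint(0, 20, (2, 3, 5)).float()

# Per-cell library size, taken from the untransformed counts
lib_size = counts.sum(dim=-1, keepdim=True)      # (2, 3, 1)

# Log transform applied before masking
log_counts = torch.log1p(counts)                 # (2, 3, 5)

# Hypothetical masking rule (assumption): hide ~40% of entries by zeroing them
mask = torch.rand_like(log_counts) < 0.4         # True = hidden
masked = log_counts.masked_fill(mask, 0.0)

print(lib_size.shape, masked.shape, mask.float().mean().item())
```

Because the library size and the clone of the counts are taken before log1p and masking, the loss can later compare predictions against untouched raw counts.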

The masked tensor is projected into a token grid by _reduce_and_tokenize, passed through _run_attention_layers, and flattened into one embedding per cell.[cite:7] These embeddings are decoded into nb_mean, nb_dispersion, and px_scale, which together define the count reconstruction distribution.[cite:7]

Where attention is used in the core

The core model creates a stack of tabular attention layers in __init__ using nn.ModuleList. The list comprehension constructs n_layers separate layer instances, so the same attention pattern is applied at every depth while each layer keeps its own parameters.[cite:7] Each element of self.layers is a TabularAttentionLayer, defined in src/stack/modules/attention.py.[cite:7][cite:9]

self.layers = nn.ModuleList(
    [
        TabularAttentionLayer(
            token_dim=token_dim,
            n_cells=n_cells,
            n_hidden=n_hidden,
            n_heads=n_heads,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
        )
        for _ in range(n_layers)
    ]
)

The actual call happens in _run_attention_layers, which loops through self.layers and applies each attention block to the token tensor in sequence.[cite:7]

def _run_attention_layers(
    self,
    tokens: torch.Tensor,
    gene_attn_mask: Optional[torch.Tensor] = None,
    return_attn: bool = False,
):
    attn_maps: List[torch.Tensor] = []
    x = tokens
    for layer in self.layers:
        x, attn = layer(x, self.gene_pos_embedding, gene_attn_mask, return_attn)
        if return_attn:
            attn_maps.append(attn)
    if return_attn:
        return x, attn_maps
    return x

This means the line x = self._run_attention_layers(tokens) in forward() is the exact moment where the core model enters the tabular attention stack.[cite:7][cite:9]

Main attention code

The key logic for tabular attention is implemented in TabularAttentionLayer.forward.[cite:9]

def forward(
    self,
    x: torch.Tensor,
    gene_pos_emb: torch.Tensor,
    gene_attn_mask: Optional[torch.Tensor] = None,
    return_attn: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    batch_size, n_cells, n_genes, token_dim = x.shape

    # View 1: fold cells into the batch axis -> (B * n_cells, n_genes, token_dim)
    x_cell = x.reshape(batch_size * n_cells, n_genes, token_dim)
    x_cell_with_pos = x_cell + gene_pos_emb.unsqueeze(0)
    cell_attn_out, _ = self.cell_attn(x_cell_with_pos)
    # Residual is added to x_cell, i.e. without the positional embedding
    x_cell = self.cell_norm(x_cell + cell_attn_out)

    # View 2: flatten each cell's tokens -> (B, n_cells, n_genes * token_dim)
    x = x_cell.reshape(batch_size, n_cells, n_genes, token_dim)
    x_gene = x.reshape(batch_size, n_cells, n_genes * token_dim)

    if return_attn:
        gene_attn_out, attn = self.gene_attn(x_gene, attn_mask=gene_attn_mask, return_attn=True)
    else:
        gene_attn_out, attn = self.gene_attn(x_gene, attn_mask=gene_attn_mask)
    x_gene = self.gene_norm(x_gene + gene_attn_out)

    # Token-wise MLP with residual, then restore the 4-D token grid
    x = x_gene.reshape(batch_size, n_cells, n_genes, token_dim)
    mlp_input = x.reshape(-1, token_dim)
    mlp_out = self.mlp(mlp_input)
    x = self.mlp_norm(mlp_input + mlp_out).reshape(batch_size, n_cells, n_genes, token_dim)

    return x, attn

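The first reshape folds the cell axis into the batch axis, so self.cell_attn attends over gene tokens within each cell. The second reshape flattens each cell's gene tokens into a single vector, so self.gene_attn attends across the cells of the same context set.[cite:9] The module names run opposite to what one might guess: cell_attn mixes information within a cell, and gene_attn mixes information between cells. Arc Institute’s Stack description emphasizes this alternating intra-cellular and inter-cellular information flow as the central novelty of the tabular attention design.[web:10][web:13][web:19]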
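The effect of the two reshapes is easiest to see from the shapes of the attention weights they produce. The sketch below uses PyTorch's nn.MultiheadAttention as a stand-in; this is an assumption, since Stack's own cell_attn and gene_attn modules are not shown here and may differ in interface and internals.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, n_cells, n_genes, token_dim = 2, 4, 3, 8
x = torch.randn(batch_size, n_cells, n_genes, token_dim)

# Stand-in attention modules (assumption: not Stack's own classes)
within_cell = nn.MultiheadAttention(token_dim, num_heads=2, batch_first=True)
across_cells = nn.MultiheadAttention(n_genes * token_dim, num_heads=2, batch_first=True)

# View 1: every cell is its own sequence of n_genes tokens
x_cell = x.reshape(batch_size * n_cells, n_genes, token_dim)
out_cell, w_cell = within_cell(x_cell, x_cell, x_cell)

# View 2: every context set is a sequence of n_cells tokens,
# where a cell's token is the concatenation of its gene tokens
x_gene = out_cell.reshape(batch_size, n_cells, n_genes * token_dim)
out_gene, w_gene = across_cells(x_gene, x_gene, x_gene)

print(w_cell.shape)  # (batch_size * n_cells, n_genes, n_genes)
print(w_gene.shape)  # (batch_size, n_cells, n_cells)
```

The first weight matrix relates genes to genes inside one cell, and the second relates cells to cells inside one context set, which is the alternation the layer is built around.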

Decoder code

The decoder that maps embeddings back to gene-count parameters is implemented in _compute_nb_parameters in base.py.[cite:7]

def _compute_nb_parameters(
    self,
    final_cell_embeddings: torch.Tensor,
    observed_lib_size: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    batch_size, n_cells, _ = final_cell_embeddings.shape
    flat_embeddings = final_cell_embeddings.reshape(batch_size * n_cells, -1)
    output = self.output_mlp(flat_embeddings)
    output = output.reshape(batch_size, n_cells, self.n_genes, 2)

    px_scale_logits = output[..., 0]
    nb_dispersion = F.softplus(output[..., 1])
    px_scale = F.softmax(px_scale_logits, dim=-1)
    nb_mean = px_scale * observed_lib_size
    return nb_mean, nb_dispersion, px_scale

This decoder produces a distribution over genes through px_scale, ensures positive dispersion with softplus, and rescales predicted gene fractions by the observed library size to get the Negative Binomial mean.[cite:7] The design keeps the total expected counts tied to the cell’s measured sequencing depth while letting the model learn gene-wise composition.[cite:7]
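The link between px_scale and the library size can be checked numerically. In this sketch a random tensor stands in for the output_mlp result and the library sizes are made-up values; only the softmax, softplus, and rescaling steps mirror the decoder shown above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, n_cells, n_genes = 2, 3, 5

# Stand-in for the output_mlp result: two numbers per gene
output = torch.randn(batch_size, n_cells, n_genes, 2)
lib_size = torch.tensor([[[100.0], [250.0], [40.0]],
                         [[10.0], [500.0], [75.0]]])   # (2, 3, 1), made-up values

px_scale = F.softmax(output[..., 0], dim=-1)   # gene fractions, sum to 1 per cell
nb_dispersion = F.softplus(output[..., 1])     # strictly positive
nb_mean = px_scale * lib_size                  # expected counts per gene

# Each cell's predicted means add up to its observed library size
print(nb_mean.sum(dim=-1).round())
print(lib_size.squeeze(-1))
```

Since the softmax sums to one over genes, the decoder can only redistribute a cell's total counts among genes; it cannot change the total.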

Loss code

The forward pass calls three loss-related helpers: _compute_reconstruction_loss, _compute_sw_loss, and _compute_eval_metrics.[cite:7][cite:6] The first two contribute to the training objective, while the last one is used mainly for evaluation outputs.[cite:7]

The key top-level loss lines in forward() are these.[cite:7]

recon_loss, _ = self._compute_reconstruction_loss(
    nb_mean, nb_dispersion, original_features, mask
)
sw_loss = self._compute_sw_loss(final_cell_embeddings)
total_loss = recon_loss + self.sw_weight * sw_loss

This means the final objective is a weighted sum of reconstruction quality and embedding regularization.[cite:7] The coefficient self.sw_weight decides how much the Sliced Wasserstein term contributes relative to the reconstruction term.[cite:7]

Reconstruction loss meaning

_compute_reconstruction_loss compares the decoded Negative Binomial distribution against the original raw counts and uses the mask to focus on the hidden entries that the model had to infer.[cite:7][cite:6] Conceptually, this is a masked reconstruction objective over genes, which is analogous to masked modeling in language or vision but adapted to count-valued single-cell expression data.[cite:7][web:15]

The inputs tell you exactly what it needs: nb_mean and nb_dispersion define the predicted distribution, original_features provides the true counts, and mask identifies where supervision should be applied.[cite:7] This design prevents the model from simply copying visible genes and instead pushes it to recover missing expression values using cell context and gene structure.[cite:7][web:19]
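A masked Negative Binomial loss of this kind can be sketched as follows. The body of _compute_reconstruction_loss is not shown in this guide, so the details here are assumptions: the dispersion is treated as an inverse-dispersion parameter theta (variance = mean + mean^2 / theta), the mask is True where genes were hidden, and the loss is averaged over masked entries only.

```python
import torch

def masked_nb_nll(mean, theta, counts, mask, eps=1e-8):
    """Mean Negative Binomial negative log-likelihood over masked entries.

    Assumes theta is an inverse dispersion and mask is True for hidden genes.
    """
    log_theta_mu = torch.log(theta + mean + eps)
    log_prob = (
        torch.lgamma(counts + theta)
        - torch.lgamma(theta)
        - torch.lgamma(counts + 1.0)
        + theta * (torch.log(theta + eps) - log_theta_mu)
        + counts * (torch.log(mean + eps) - log_theta_mu)
    )
    weights = mask.float()
    # Only hidden entries contribute; clamp avoids division by zero
    return -(log_prob * weights).sum() / weights.sum().clamp(min=1.0)

counts = torch.tensor([[0.0, 3.0, 10.0, 1.0]])
theta = torch.full_like(counts, 2.0)
mask = torch.tensor([[False, True, True, False]])

good = masked_nb_nll(counts.clamp(min=0.1), theta, counts, mask)
bad = masked_nb_nll(torch.full_like(counts, 50.0), theta, counts, mask)
print(good.item(), bad.item())  # the prediction closer to the counts scores lower
```

Predictions at unmasked positions have no effect on this value, which is what keeps the model from being rewarded for copying visible genes.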

Sliced Wasserstein loss meaning

_compute_sw_loss(final_cell_embeddings) adds a regularization term on the learned cell embeddings.[cite:7] In the model constructor, the base class initializes self.sw_distance = SlicedWassersteinDistance(n_proj=n_proj), which shows that this loss is based on a sliced Wasserstein distance module applied in embedding space.[cite:7]

The purpose of this term is to shape the embedding distribution so it does not degenerate while training only on masked reconstruction.[cite:7] In practice, it encourages a better-behaved representation geometry for cells, complementing the local gene reconstruction objective with a distribution-level constraint.[cite:7][web:43][web:44]
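For intuition, a sliced Wasserstein distance between two point clouds can be estimated by projecting both onto random directions and comparing the sorted projections. The sketch below is a generic estimator, not Stack's SlicedWassersteinDistance module, and the choice of a standard normal reference cloud is an assumption made only for illustration.

```python
import torch

def sliced_wasserstein(x, y, n_proj=64, seed=0):
    """Monte Carlo estimate of the squared sliced 2-Wasserstein distance
    between two equally sized point clouds of shape (n, dim)."""
    gen = torch.Generator().manual_seed(seed)
    # Random unit directions, one per column
    dirs = torch.randn(x.shape[1], n_proj, generator=gen)
    dirs = dirs / dirs.norm(dim=0, keepdim=True)
    # 1-D projections, sorted so that matching ranks are compared
    x_proj = (x @ dirs).sort(dim=0).values
    y_proj = (y @ dirs).sort(dim=0).values
    return ((x_proj - y_proj) ** 2).mean()

torch.manual_seed(0)
reference = torch.randn(256, 16)                 # hypothetical target distribution
collapsed = torch.randn(256, 16) * 0.1 + 3.0     # tight, shifted embedding cloud
spread = torch.randn(256, 16)                    # cloud resembling the reference

print(sliced_wasserstein(collapsed, reference).item())
print(sliced_wasserstein(spread, reference).item())
```

A collapsed or shifted embedding cloud yields a much larger value than one that resembles the reference, so adding this term to the loss penalizes degenerate embedding geometry.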

Evaluation metrics code path

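When return_loss is True and the model is in evaluation mode, forward() calls _compute_eval_metrics(nb_mean, original_features, mask) and merges those metrics into the return dictionary.[cite:7] When the model is in training mode, the same keys are filled with zero placeholders so downstream logging code sees a consistent schema without paying the cost of evaluation metrics every step.[cite:7] In the full forward() listing this branch sits inside the if return_loss: block, so neither the losses nor these metric keys are present when return_loss is False.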

if not self.training:
    metrics = self._compute_eval_metrics(nb_mean, original_features, mask)
    result.update(metrics)
else:
    zero = torch.tensor(0.0, device=device, dtype=nb_mean.dtype)
    result.update(
        {
            "masked_mae": zero,
            "masked_corr": zero,
            "mask_rate": zero,
        }
    )

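This pattern is useful because a training loop that requests the loss can always expect the same keys in result, even though the expensive metric calculations only happen during evaluation.[cite:7] The zeros written during training are placeholders, not measurements, so they should be left out of any averaged training metrics.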
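The same key-schema pattern can be reproduced in a few lines. TinyModel below is a hypothetical stand-in that mirrors only the branching on self.training; its loss and metric formulas are illustrative and are not taken from Stack.

```python
import torch
from torch import nn

class TinyModel(nn.Module):
    """Hypothetical stand-in that mirrors only the metric-key pattern."""

    METRIC_KEYS = ("masked_mae", "masked_corr", "mask_rate")

    def forward(self, pred, target, mask):
        result = {"loss": ((pred - target) ** 2).mean()}
        if not self.training:
            # Evaluation: compute real metrics on the masked entries
            pair = torch.stack([pred[mask], target[mask]])
            result["masked_mae"] = (pred - target).abs()[mask].mean()
            result["masked_corr"] = torch.corrcoef(pair)[0, 1]
            result["mask_rate"] = mask.float().mean()
        else:
            # Training: placeholder zeros keep the key set identical
            zero = torch.zeros((), dtype=pred.dtype, device=pred.device)
            result.update({key: zero for key in self.METRIC_KEYS})
        return result

model = TinyModel()
pred = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([1.5, 2.0, 2.0, 5.0])
mask = torch.tensor([True, False, True, True])

train_out = model.train()(pred, target, mask)
eval_out = model.eval()(pred, target, mask)
print(sorted(train_out) == sorted(eval_out))
```

Switching between model.train() and model.eval() changes the values under the metric keys but not the set of keys, so logging code needs no mode-specific branches.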

How to study the model

To understand the model deeply, read the code in this order: forward() in base.py, _run_attention_layers() in base.py, TabularAttentionLayer.forward() in attention.py, _compute_nb_parameters() in base.py, and then the helpers in losses.py.[cite:7][cite:9][cite:6] That order exactly follows the execution path from raw counts to embeddings to reconstruction parameters to losses.[cite:7]

For each block, track four things: tensor shape, biological meaning, modeling purpose, and whether the computation is used for training, inference, or evaluation.[cite:7][cite:9] This approach works especially well for Stack because most of the complexity comes from reshaping the same tensor to support different attention views and then decoding it back into count-space outputs.[cite:7][cite:9]
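One practical way to track tensor shapes while reading is to attach forward hooks and record what each submodule emits. The helper below, trace_shapes, is a hypothetical utility written for this guide, and the small nn.Sequential is a toy stand-in for the real network.

```python
import torch
from torch import nn

def trace_shapes(model, *inputs):
    """Run one forward pass and record each leaf module's output shape."""
    records, handles = [], []

    def make_hook(name):
        def hook(module, args, output):
            # Modules that return tuples (e.g. attention) are skipped here
            if isinstance(output, torch.Tensor):
                records.append((name, tuple(output.shape)))
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:        # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        # Always detach the hooks so the model is left unchanged
        for handle in handles:
            handle.remove()
    return records

# Toy model standing in for the real network (assumption)
toy = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 2))
for name, shape in trace_shapes(toy, torch.zeros(3, 5)):
    print(name, shape)
```

The recorded list follows execution order, so it can be read side by side with the code path described above.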
