gsnn.gsnn.models.PathwayLatentRegularizer

class gsnn.gsnn.models.PathwayLatentRegularizer(model, pathway_membership, dissim_pairs=None, lambda_sim=0.1, lambda_dis=0.0, phi='mean', eps=1e-06)

Bases: Module

Per-pathway latent-factor auxiliary loss for GSNN.

For each ResBlock \(\ell\) and minibatch of size \(B\):

  1. Reduce the cached activation \(A_\ell \in \mathbb{R}^{B \times N_{\text{func}} \times C_{pn}}\) to per-node scalars \(s_\ell \in \mathbb{R}^{B \times N_{\text{func}}}\) via \(\phi\).

  2. Compute per-pathway scores \(S_\ell = \mathrm{normalize}(M\, s_\ell^\top)^\top \in \mathbb{R}^{B \times P}\), where \(M \in \mathbb{R}^{P \times N_{\text{func}}}\) is pathway_membership.

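  3. Standardize \(s_\ell\) and \(S_\ell\) across the batch dimension and compute the node-by-pathway correlation matrix \(C \in \mathbb{R}^{N_{\text{func}} \times P}\), whose entry \(C_{ip}\) is the batch correlation between function node \(i\) and pathway \(p\).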

  4. Add the negative correlation between each pathway and its member nodes (the member-side entries of \(C\)) to \(L_{\text{sim}}\) and, if dissim_pairs is provided, the squared correlation between the scores of each dissimilar pathway pair to \(L_{\text{dis}}\).
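The steps above can be sketched in NumPy for readability; the module itself works on PyTorch tensors. This is a minimal sketch under stated assumptions, not the gsnn implementation: it uses phi='mean', assumes normalize divides each pathway row of \(M\) by its total membership weight, scores the member-side correlation as the membership-weighted average of each column of \(C\) (then averages over pathways), and averages the squared correlations over the dissimilar pairs. Per-layer terms are summed over ResBlocks and scaled by lambda_sim and lambda_dis. The function name pathway_latent_losses is hypothetical.

```python
import numpy as np


def standardize(x, eps=1e-6):
    """Z-score each column of x across the batch (first) dimension."""
    return (x - x.mean(axis=0, keepdims=True)) / (x.std(axis=0, keepdims=True) + eps)


def pathway_latent_losses(activations, M, dissim_pairs=None,
                          lambda_sim=0.1, lambda_dis=0.0, eps=1e-6):
    """Hypothetical re-derivation of the auxiliary losses (not the gsnn code).

    activations: list of arrays, one per ResBlock, each (B, N_func, C_pn).
    M: (P, N_func) pathway_membership matrix (hard 0/1 or soft weights).
    dissim_pairs: optional (K, 2) integer array of pathway index pairs.
    """
    # Assumed normalization: a pathway score is the weighted mean of its members.
    M_norm = M / (M.sum(axis=1, keepdims=True) + eps)   # (P, N_func)

    L_sim, L_dis = 0.0, 0.0
    for A in activations:
        B = A.shape[0]
        s = A.mean(axis=2)                  # step 1: phi='mean' -> (B, N_func)
        S = (M_norm @ s.T).T                # step 2: pathway scores (B, P)
        s_z, S_z = standardize(s, eps), standardize(S, eps)
        C = s_z.T @ S_z / B                 # step 3: (N_func, P) correlations
        # Step 4a (assumed form): membership-weighted mean correlation per pathway.
        L_sim += -(M_norm.T * C).sum(axis=0).mean()
        # Step 4b: squared correlation between scores of each dissimilar pair.
        if dissim_pairs is not None and len(dissim_pairs) > 0:
            i, j = dissim_pairs[:, 0], dissim_pairs[:, 1]
            r = (S_z[:, i] * S_z[:, j]).mean(axis=0)
            L_dis += (r ** 2).mean()
    # Both terms are returned already scaled, mirroring the loss() contract.
    return lambda_sim * L_sim, lambda_dis * L_dis
```

When every pathway has exactly one member, each score is a rescaled copy of that member's scalar, so the correlation is close to 1 and L_sim approaches -lambda_sim per ResBlock.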

Cost is \(O(L \cdot B \cdot (N_{\text{func}} \cdot C_{pn} + P \cdot N_{\text{func}}))\) per minibatch, where \(L\) is the number of ResBlocks. The parameter count is zero for phi='mean' and \(C_{pn}\) for phi='learned'.
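To make the cost formula concrete, the helper below evaluates its leading-order term with constants dropped. The helper name regularizer_cost and the example sizes are hypothetical, chosen only for illustration.

```python
def regularizer_cost(n_layers: int, batch: int, n_func: int,
                     c_pn: int, n_pathways: int) -> int:
    """Leading-order op count O(L * B * (N_func * C_pn + P * N_func))."""
    reduce_term = n_func * c_pn          # step 1: phi over C_pn channels per node
    pathway_term = n_pathways * n_func   # steps 2-3: M s^T and the N_func x P correlations
    return n_layers * batch * (reduce_term + pathway_term)


# Hypothetical sizes: 4 ResBlocks, B=64, 1000 function nodes, 8 channels, 200 pathways.
print(regularizer_cost(4, 64, 1000, 8, 200))  # 53248000
```

With these sizes the pathway term (200,000 per sample per layer) is 25 times the reduction term (8,000). In general the \(P \cdot N_{\text{func}}\) term dominates whenever \(P > C_{pn}\).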

Parameters:
  • model (GSNN) – Reference model. Used only for shape introspection (ResBlocks[0].channel_groups) and to toggle _store_activations.

  • pathway_membership (Tensor) – Float tensor of shape (P, N_func). Real-valued entries are allowed and treated as soft membership weights; binary 0/1 entries give classical hard membership. Rows correspond to pathways, columns to function-node indices in the same order as the model’s function node names.

  • dissim_pairs (Tensor, optional) – (K, 2) LongTensor of pathway index pairs whose scores should be encouraged to be uncorrelated. (default: None)

  • lambda_sim (float, optional) – Scaling for the similarity term. (default: 0.1)

  • lambda_dis (float, optional) – Scaling for the dissimilarity term. (default: 0.0)

  • phi (str or nn.Module, optional) – Per-node scalar reduction. 'mean' averages across channels (no parameters). 'learned' uses a single nn.Linear(C_pn, 1) projection shared across all ResBlocks. A custom nn.Module may also be passed; it must accept a tensor of shape (B, N_func, C_pn) and return (B, N_func, 1) or (B, N_func). (default: 'mean')

  • eps (float, optional) – Small constant for numerical stability in the standardization step. (default: 1e-6)
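Because the columns of pathway_membership must follow the model's function-node order, it can be less error-prone to build the matrix from names than by hand. A minimal NumPy sketch, assuming the function-node names are available as a list in model order and pathways are given as a dict of member names; build_membership and build_dissim_pairs are hypothetical helpers, not part of gsnn.

```python
import numpy as np


def build_membership(func_node_names, pathways, weights=None):
    """Return a (P, N_func) float32 membership matrix and the pathway order.

    func_node_names: function-node names in the model's order (columns).
    pathways: dict mapping pathway name -> iterable of member node names.
    weights: optional dict mapping (pathway, node) -> soft weight; default 1.0.
    """
    col = {name: k for k, name in enumerate(func_node_names)}
    names = list(pathways)
    M = np.zeros((len(names), len(func_node_names)), dtype=np.float32)
    for p, pw in enumerate(names):
        for node in pathways[pw]:
            if node in col:  # members absent from the graph are skipped
                w = 1.0 if weights is None else weights.get((pw, node), 1.0)
                M[p, col[node]] = w
    return M, names


def build_dissim_pairs(names, pairs):
    """Map (pathway_a, pathway_b) name pairs to a (K, 2) int64 index array."""
    idx = {pw: k for k, pw in enumerate(names)}
    return np.array([[idx[a], idx[b]] for a, b in pairs], dtype=np.int64).reshape(-1, 2)
```

The results can then be converted with torch.as_tensor (float32 and int64 dtypes are preserved) before being passed as pathway_membership and dissim_pairs. Whether to skip unknown members or raise is a design choice; raising would catch naming mismatches earlier.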

__init__(model, pathway_membership, dissim_pairs=None, lambda_sim=0.1, lambda_dis=0.0, phi='mean', eps=1e-06)

Methods

  • __init__(model, pathway_membership[, ...])

  • disable(model) – Disable activation caching and clear cached tensors on model.

  • enable(model) – Enable activation caching on every ResBlock of model.

  • loss(model) – Compute the auxiliary similarity / dissimilarity losses.

disable(model)

Disable activation caching and clear cached tensors on model.

Useful at evaluation time to avoid retaining graph references. Returns self for chaining.

enable(model)

Enable activation caching on every ResBlock of model.

Idempotent. Called automatically at construction time. Returns self for chaining.

loss(model)

Compute the auxiliary similarity / dissimilarity losses.

Must be called after a training-mode forward pass so that each ResBlock has populated its _last_activation.

Parameters:

model (GSNN) – The same model passed to __init__().

Returns:

(L_sim, L_dis), already scaled by lambda_sim and lambda_dis respectively. L_dis is 0.0 if no dissim_pairs were provided.

Return type:

tuple