gsnn.models.PathwayLatentRegularizer
- class gsnn.models.PathwayLatentRegularizer(*args: Any, **kwargs: Any)
Bases: Module
Per-pathway latent-factor auxiliary loss for GSNN.
For each ResBlock \(\ell\) and minibatch of size \(B\):
1. Reduce the cached activation \(A_\ell \in \mathbb{R}^{B \times N_{\text{func}} \times C_{pn}}\) to per-node scalars \(s_\ell \in \mathbb{R}^{B \times N_{\text{func}}}\) via \(\phi\).
2. Compute per-pathway scores \(S_\ell = \mathrm{normalize}(M\, s_\ell^\top)^\top \in \mathbb{R}^{B \times P}\), where \(M \in \mathbb{R}^{P \times N_{\text{func}}}\) is pathway_membership.
3. Standardize across the batch dimension and compute the member-by-pathway correlation matrix \(C \in \mathbb{R}^{N_{\text{func}} \times P}\).
4. Add the negative member-side correlation to \(L_{\text{sim}}\), and (if dissim_pairs is provided) the squared score-correlation of dissimilar pairs to \(L_{\text{dis}}\).
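The four steps above can be sketched in NumPy for a single ResBlock activation. This is an illustrative sketch under stated assumptions, not the library's PyTorch implementation: the function name pathway_aux_loss is hypothetical, phi is fixed to the 'mean' reduction, normalize is assumed to divide each pathway score by its total membership weight, the member-side correlation is assumed to be a membership-weighted mean of C, and the dissimilarity term is assumed to be the mean squared score correlation over the listed pairs.

```python
import numpy as np


def pathway_aux_loss(A, M, dissim_pairs=None, lambda_sim=0.1, lambda_dis=0.0, eps=1e-6):
    """Hypothetical single-layer sketch. A: (B, N_func, C_pn), M: (P, N_func)."""
    B = A.shape[0]

    # Step 1: reduce channels to per-node scalars (phi='mean').
    s = A.mean(axis=-1)                                        # (B, N_func)

    # Step 2: per-pathway scores. ASSUMPTION: normalize = divide by total membership weight.
    S = (M @ s.T / (M.sum(axis=1, keepdims=True) + eps)).T     # (B, P)

    # Step 3: standardize across the batch, then correlate members with pathways.
    def standardize(x):
        return (x - x.mean(axis=0)) / (x.std(axis=0) + eps)

    zs, zS = standardize(s), standardize(S)
    C = zs.T @ zS / B                                          # (N_func, P)

    # Step 4a. ASSUMPTION: membership-weighted mean of C, negated and scaled.
    L_sim = -lambda_sim * (M.T * C).sum() / (M.sum() + eps)

    # Step 4b. ASSUMPTION: mean squared score correlation over the listed pairs.
    L_dis = 0.0
    if dissim_pairs is not None:
        R = zS.T @ zS / B                                      # (P, P)
        i, j = dissim_pairs[:, 0], dissim_pairs[:, 1]
        L_dis = lambda_dis * np.mean(R[i, j] ** 2)
    return L_sim, L_dis


rng = np.random.default_rng(0)
A = rng.normal(size=(8, 5, 3))                   # B=8, N_func=5, C_pn=3
M = np.array([[1, 1, 1, 0, 0],
              [0, 0, 1, 1, 1]], dtype=float)     # P=2 pathways, hard membership
pairs = np.array([[0, 1]])                       # one dissimilar pathway pair
L_sim, L_dis = pathway_aux_loss(A, M, pairs, lambda_sim=0.1, lambda_dis=0.5)
```

Because the scores are standardized across the batch afterwards, any positive per-pathway rescaling in step 2 cancels out in this sketch, so the exact choice of normalize does not change C here.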
Cost is \(O(L \cdot B \cdot (N_{\text{func}} \cdot C_{pn} + P \cdot N_{\text{func}}))\) per minibatch, where \(L\) is the number of ResBlocks; the parameter count is either zero (phi='mean') or \(C_{pn}\) (phi='learned').
- Parameters:
model (GSNN) – Reference model. Used only for shape introspection (ResBlocks[0].channel_groups) and to toggle _store_activations.
pathway_membership (Tensor) – Float tensor of shape (P, N_func). Real-valued entries are allowed and treated as soft membership weights; binary 0/1 entries give classical hard membership. Rows correspond to pathways, columns to function-node indices in the same order as the model’s function node names.
dissim_pairs (Tensor, optional) – (M, 2) LongTensor of pathway index pairs whose scores should be encouraged to be uncorrelated. (default: None)
lambda_sim (float, optional) – Scaling for the similarity term. (default: 0.1)
lambda_dis (float, optional) – Scaling for the dissimilarity term. (default: 0.0)
phi (str or nn.Module, optional) – Per-node scalar reduction. 'mean' averages across channels (no parameters). 'learned' uses a single nn.Linear projection Linear(C_pn, 1) shared across layers. A custom nn.Module may also be passed; it must accept a tensor of shape (B, N_func, C_pn) and return (B, N_func, 1) or (B, N_func). (default: 'mean')
eps (float, optional) – Small constant for numerical stability in the standardization step. (default: 1e-6)
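As an illustration of the pathway_membership and dissim_pairs layouts described above, the following sketch builds both as NumPy arrays. The helper build_membership, the node names, and the pathway names are all hypothetical; the only facts taken from the parameter list are the (P, N_func) shape, the row/column ordering, the soft-weight convention, and the (M, 2) integer pair layout.

```python
import numpy as np


def build_membership(function_nodes, pathways):
    """Hypothetical helper: return a (P, N_func) float32 membership matrix.

    Rows follow the insertion order of `pathways`; columns follow `function_nodes`,
    which is assumed to match the model's function node order.
    """
    column = {name: k for k, name in enumerate(function_nodes)}
    M = np.zeros((len(pathways), len(function_nodes)), dtype=np.float32)
    for p, members in enumerate(pathways.values()):
        for name, weight in members.items():
            M[p, column[name]] = weight   # 1.0 = hard member, other values = soft weight
    return M


# Made-up node and pathway names, for illustration only.
function_nodes = ["EGFR", "KRAS", "MAPK1", "TP53"]
pathways = {
    "MAPK_signaling": {"EGFR": 1.0, "KRAS": 1.0, "MAPK1": 1.0},
    "p53_signaling": {"TP53": 1.0, "MAPK1": 0.3},   # soft membership for MAPK1
}

M = build_membership(function_nodes, pathways)        # shape (2, 4)
dissim_pairs = np.array([[0, 1]], dtype=np.int64)     # pathway 0 vs pathway 1
```

Both arrays can be converted with torch.from_numpy before being passed to the constructor; an int64 array becomes a LongTensor.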
- __init__(model, pathway_membership, dissim_pairs=None, lambda_sim=0.1, lambda_dis=0.0, phi='mean', eps=1e-06)
Methods
__init__(model, pathway_membership[, ...])
disable(model) – Disable activation caching and clear cached tensors on model.
enable(model) – Enable activation caching on every ResBlock of model.
loss(model) – Compute the auxiliary similarity / dissimilarity losses.
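A minimal sketch of how these methods might fit into a training and evaluation loop. Only enable, disable, and loss come from this page; training_step, evaluate, criterion, and optimizer are hypothetical names, and the model is assumed to be callable as model(x). All objects are passed in as arguments, so nothing here depends on GSNN internals.

```python
def training_step(model, regularizer, optimizer, criterion, x, y):
    """Hypothetical training step that adds the auxiliary losses to a task loss."""
    model.train()                              # loss() expects a training-mode forward pass
    optimizer.zero_grad()
    task_loss = criterion(model(x), y)         # forward pass populates the cached activations
    l_sim, l_dis = regularizer.loss(model)     # already scaled by lambda_sim / lambda_dis
    total = task_loss + l_sim + l_dis          # l_dis may be the plain float 0.0
    total.backward()
    optimizer.step()
    return total


def evaluate(model, regularizer, batches):
    """Hypothetical evaluation pass with activation caching switched off."""
    regularizer.disable(model)                 # stop caching and drop cached tensors
    model.eval()
    try:
        # In PyTorch this loop would normally also sit inside torch.no_grad().
        return [model(x) for x in batches]
    finally:
        regularizer.enable(model)              # restore caching before training resumes
```

The try/finally in evaluate is a design choice of this sketch: it re-enables caching even if a batch raises, so a later call to loss does not find empty caches.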
- disable(model)
Disable activation caching and clear cached tensors on model.
Useful at evaluation time to avoid retaining graph references. Returns self for chaining.
- enable(model)
Enable activation caching on every ResBlock of model.
Idempotent. Called automatically at construction time. Returns self for chaining.
- loss(model)
Compute the auxiliary similarity / dissimilarity losses.
Must be called after a training-mode forward pass so that each ResBlock has populated its _last_activation.
- Parameters:
model (GSNN) – The same model passed to __init__().
- Returns:
(L_sim, L_dis), both already scaled by their respective lambda_*. L_dis is 0.0 if no dissim_pairs were provided.
- Return type: