r"""Pathway latent-factor regularizer for :class:`gsnn.models.GSNN`.
Implements the latent-factor objective described in
``docs/notes/functional_dis_and_similarity.md`` (section 5).
Given a pathway membership matrix ``M`` of shape ``(P, N_func)`` over the
function nodes of a :class:`GSNN` model, the regularizer encourages members of
the same pathway to *co-vary across the batch dimension* with a shared
per-pathway latent score, and (optionally) penalizes correlation between
designated dissimilar pathway pairs. This biases the model toward solutions
consistent with prior pathway co-membership without modifying the network
topology or adding hub nodes.
The regularizer is opt-in and fully backward compatible: it does not modify
:class:`GSNN`. It only toggles the existing :pyobj:`ResBlock._store_activations`
flag and reads :pyobj:`ResBlock._last_activation` after a training-mode forward
pass.
Example
-------
>>> from gsnn.models import GSNN, PathwayLatentRegularizer
>>> model = GSNN(edge_index_dict, node_names_dict, channels=8, layers=3)
>>> reg = PathwayLatentRegularizer(model, pathway_membership=M, lambda_sim=0.1)
>>> for x, y in loader:
...     yhat = model(x)
...     L_main = mse(y, yhat)
...     L_sim, L_dis = reg.loss(model)
...     (L_main + L_sim + L_dis).backward()
"""
import torch
import torch.nn as nn
class PathwayLatentRegularizer(nn.Module):
    r"""Per-pathway latent-factor auxiliary loss for GSNN.

    For each :class:`ResBlock` :math:`\ell` and minibatch of size :math:`B`:

    1. Reduce the cached activation
       :math:`A_\ell \in \mathbb{R}^{B \times N_{\text{func}} \times C_{pn}}`
       to per-node scalars
       :math:`s_\ell \in \mathbb{R}^{B \times N_{\text{func}}}` via
       :math:`\phi`.
    2. Compute per-pathway scores
       :math:`S_\ell = \mathrm{normalize}(M\, s_\ell^\top)^\top \in
       \mathbb{R}^{B \times P}`.
    3. Standardize across the batch dimension and compute the
       member-by-pathway correlation matrix
       :math:`C \in \mathbb{R}^{N_{\text{func}} \times P}`.
    4. Add the negative member-side correlation to :math:`L_{\text{sim}}`,
       and (if ``dissim_pairs`` is provided) the squared score-correlation of
       dissimilar pairs to :math:`L_{\text{dis}}`.

    Cost is :math:`O(L \cdot B \cdot (N_{\text{func}} \cdot C_{pn} + P \cdot
    N_{\text{func}}))` per minibatch and parameter count is either zero
    (``phi='mean'``) or :math:`C_{pn}` (``phi='learned'``).

    Args:
        model (GSNN): Reference model. Used only for shape introspection
            (``ResBlocks[0].channel_groups``) and to toggle
            :pyobj:`_store_activations`.
        pathway_membership (Tensor): Float tensor of shape
            :obj:`(P, N_func)`. Real-valued entries are allowed and treated as
            soft membership weights; binary 0/1 entries give classical hard
            membership. Rows correspond to pathways, columns to function-node
            indices in the same order as the model's ``function`` node names.
        dissim_pairs (Tensor, optional): :obj:`(M, 2)` LongTensor of pathway
            index pairs whose scores should be encouraged to be uncorrelated
            (``M`` here is the number of pairs). Every index must lie in
            ``[-P, P)``. (default: :obj:`None`)
        lambda_sim (float, optional): Scaling for the similarity term.
            (default: :obj:`0.1`)
        lambda_dis (float, optional): Scaling for the dissimilarity term.
            (default: :obj:`0.0`)
        phi (str or nn.Module, optional): Per-node scalar reduction.
            ``'mean'`` averages across channels (no parameters). ``'learned'``
            uses a single :class:`nn.Linear` projection ``Linear(C_pn, 1)``
            shared across layers. A custom :class:`nn.Module` may also be
            passed; it must accept a tensor of shape
            :obj:`(B, N_func, C_pn)` and return :obj:`(B, N_func, 1)` or
            :obj:`(B, N_func)`. (default: :obj:`'mean'`)
        eps (float, optional): Small constant for numerical stability in the
            standardization step. (default: :obj:`1e-6`)

    Raises:
        ValueError: If ``model`` has no ``ResBlocks``; if the channel count of
            the first block is not a multiple of the number of function
            nodes; if ``pathway_membership`` is not 2-D with ``N_func``
            columns; if ``dissim_pairs`` is not of shape ``(M, 2)`` or holds
            an index outside ``[-P, P)``; or if ``phi`` is not recognised.
    """

    def __init__(
        self,
        model,
        pathway_membership,
        dissim_pairs=None,
        lambda_sim=0.1,
        lambda_dis=0.0,
        phi="mean",
        eps=1e-6,
    ):
        super().__init__()
        if not hasattr(model, "ResBlocks") or len(model.ResBlocks) == 0:
            raise ValueError(
                "model must be a GSNN instance with at least one ResBlock."
            )

        # Derive per-node and per-channel sizes from the first block. GSNN
        # currently uses uniform channel counts across function nodes; this is
        # the same assumption made by ResBlock and NodeAttention internally.
        channel_groups = model.ResBlocks[0].channel_groups
        self.n_func = int(channel_groups.max().item() + 1)
        n_channels = int(channel_groups.numel())
        # A non-divisible channel count would be silently truncated by ``//``
        # and only surface later as an opaque reshape failure in ``loss``;
        # reject it here with an actionable message instead.
        if n_channels % self.n_func != 0:
            raise ValueError(
                f"channel_groups has {n_channels} channels, which is not a "
                f"multiple of the {self.n_func} function nodes; a uniform "
                "number of channels per function node is required."
            )
        self.c_pn = n_channels // self.n_func

        if not isinstance(pathway_membership, torch.Tensor):
            pathway_membership = torch.as_tensor(pathway_membership)
        if pathway_membership.dim() != 2:
            raise ValueError(
                "pathway_membership must be 2-D (P, N_func); got "
                f"shape {tuple(pathway_membership.shape)}."
            )
        if pathway_membership.size(1) != self.n_func:
            raise ValueError(
                f"pathway_membership has {pathway_membership.size(1)} columns; "
                f"expected {self.n_func} (= number of function nodes)."
            )
        self.register_buffer("M", pathway_membership.float())

        if dissim_pairs is not None:
            if not isinstance(dissim_pairs, torch.Tensor):
                dissim_pairs = torch.as_tensor(dissim_pairs)
            if dissim_pairs.dim() != 2 or dissim_pairs.size(-1) != 2:
                raise ValueError(
                    "dissim_pairs must be of shape (M, 2); got "
                    f"shape {tuple(dissim_pairs.shape)}."
                )
            dissim_pairs = dissim_pairs.long()
            # Fail fast on indices that cannot address a pathway column.
            # Negative indices wrap around under tensor indexing, so the
            # accepted range is [-P, P).
            n_pathways = pathway_membership.size(0)
            if dissim_pairs.numel() > 0:
                lowest = int(dissim_pairs.min().item())
                highest = int(dissim_pairs.max().item())
                if lowest < -n_pathways or highest >= n_pathways:
                    raise ValueError(
                        "dissim_pairs contains pathway indices outside "
                        f"[-{n_pathways}, {n_pathways}); got min={lowest}, "
                        f"max={highest}."
                    )
            self.register_buffer("dissim_pairs", dissim_pairs)
        else:
            self.dissim_pairs = None

        self.lambda_sim = float(lambda_sim)
        self.lambda_dis = float(lambda_dis)
        self.eps = float(eps)

        # ``self.phi is None`` encodes the parameter-free channel mean.
        if phi == "mean":
            self.phi = None
        elif phi == "learned":
            self.phi = nn.Linear(self.c_pn, 1, bias=False)
        elif isinstance(phi, nn.Module):
            self.phi = phi
        else:
            raise ValueError(
                "phi must be 'mean', 'learned', or an nn.Module; got "
                f"{phi!r}."
            )

        self.enable(model)

    def enable(self, model):
        r"""Enable activation caching on every ResBlock of ``model``.

        Idempotent. Called automatically at construction time. Returns
        ``self`` for chaining.
        """
        for blk in model.ResBlocks:
            blk._store_activations = True
        return self

    def disable(self, model):
        r"""Disable activation caching and clear cached tensors on ``model``.

        Useful at evaluation time to avoid retaining graph references. Returns
        ``self`` for chaining.
        """
        for blk in model.ResBlocks:
            blk._store_activations = False
            if hasattr(blk, "_last_activation"):
                blk._last_activation = None
        return self

    def _reduce(self, A):
        r"""Reduce a cached activation tensor to per-(sample, node) scalars.

        Args:
            A (Tensor): Cached activation of shape :obj:`(B, N_func * C_pn)`
                or :obj:`(B, N_func * C_pn, 1)`.

        Returns:
            Tensor: :obj:`(B, N_func)`.
        """
        A = A.squeeze(-1).reshape(A.size(0), self.n_func, self.c_pn)
        if self.phi is None:
            return A.mean(-1)
        out = self.phi(A)
        # Accept both (B, N_func, 1) and (B, N_func) from the projection.
        if out.dim() == 3 and out.size(-1) == 1:
            out = out.squeeze(-1)
        return out

    def loss(self, model):
        r"""Compute the auxiliary similarity / dissimilarity losses.

        Must be called *after* a training-mode forward pass so that each
        :class:`ResBlock` has populated its :pyobj:`_last_activation`.

        Args:
            model (GSNN): The same model passed to :py:meth:`__init__`.

        Returns:
            tuple: ``(L_sim, L_dis)`` — both already scaled by their
            respective :obj:`lambda_*`. ``L_dis`` is :obj:`0.0` if no
            ``dissim_pairs`` were provided.

        Raises:
            RuntimeError: If any block has no cached ``_last_activation``.
        """
        L_sim = torch.zeros((), device=self.M.device)
        L_dis = torch.zeros((), device=self.M.device)
        n_layers = 0

        # Layer-independent quantities, computed once.
        member_mask = self.M.T  # (N_func, P)
        member_norm = member_mask.sum().clamp_min(1)
        pathway_size = self.M.sum(1).clamp_min(1)  # (P,)
        has_pairs = (
            self.dissim_pairs is not None and self.dissim_pairs.numel() > 0
        )

        for blk in model.ResBlocks:
            A = getattr(blk, "_last_activation", None)
            if A is None:
                raise RuntimeError(
                    "ResBlock._last_activation is missing. Run a forward "
                    "pass with the model in training mode before calling "
                    "loss(); the regularizer's enable() must have set "
                    "_store_activations=True (this happens automatically "
                    "at construction)."
                )
            n_layers += 1

            # Activations cached in another floating dtype (e.g. half under
            # mixed precision, or double) would make the matmul with ``M``
            # fail; align them with the regularizer's dtype. This is a no-op
            # when the dtypes already match.
            A = A.to(dtype=self.M.dtype)

            s = self._reduce(A)  # (B, N_func)
            S = (self.M @ s.t()).t() / pathway_size  # (B, P)

            # Population (biased) standardization across the batch so that
            # the normalized inner product below is a Pearson correlation.
            s_std = (s - s.mean(0, keepdim=True)) / (
                s.std(0, unbiased=False, keepdim=True) + self.eps
            )
            S_std = (S - S.mean(0, keepdim=True)) / (
                S.std(0, unbiased=False, keepdim=True) + self.eps
            )

            # member <-> pathway-score correlation
            corr = (s_std.t() @ S_std) / s_std.size(0)  # (N_func, P)
            L_sim = L_sim - (corr * member_mask).sum() / member_norm

            if has_pairs:
                Sp = S_std[:, self.dissim_pairs[:, 0]]
                Sq = S_std[:, self.dissim_pairs[:, 1]]
                pair_corr = (Sp * Sq).mean(0)  # (M,)
                L_dis = L_dis + (pair_corr ** 2).mean()

        # Average over layers so the scale does not depend on model depth.
        if n_layers > 1:
            L_sim = L_sim / n_layers
            L_dis = L_dis / n_layers
        return self.lambda_sim * L_sim, self.lambda_dis * L_dis