gsnn.models.NodeActivity

Classes

NodeActivity(*args, **kwargs)

Node-wise channel attention, mediated by external features.

SignedMessagePassing(*args, **kwargs)

Aggregate scalar signals over function-function edges using stored signs (edge weights).
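As a rough illustration of signed aggregation, the sketch below multiplies each source node's scalar signal by the stored edge sign and sums the result into the destination node. The function name `signed_aggregate` and the edge layout (a 2 x E `edge_index` tensor plus one ±1 weight per edge) are assumptions made for illustration. This is not the gsnn implementation of SignedMessagePassing.

```python
import torch

def signed_aggregate(x: torch.Tensor, edge_index: torch.Tensor, edge_sign: torch.Tensor) -> torch.Tensor:
    """Hypothetical signed sum aggregation (not the gsnn implementation).

    x:          (B, N) scalar signal per function node
    edge_index: (2, E) long tensor of (source, destination) node indices
    edge_sign:  (E,) stored edge weights, e.g. +1 (activating) or -1 (inhibiting)
    returns:    (B, N) where out[:, j] = sum over edges i -> j of sign * x[:, i]
    """
    src, dst = edge_index
    messages = x[:, src] * edge_sign      # (B, E) signed messages, one per edge
    out = torch.zeros_like(x)
    out.index_add_(1, dst, messages)      # scatter-sum messages into destination nodes
    return out

# Toy graph: 0 -> 2 (activating), 1 -> 2 (inhibiting), 2 -> 0 (activating)
edge_index = torch.tensor([[0, 1, 2], [2, 2, 0]])
edge_sign = torch.tensor([1.0, -1.0, 1.0])
x = torch.tensor([[1.0, 3.0, 0.5]])
print(signed_aggregate(x, edge_index, edge_sign))
```

Node 2 receives +1.0 from node 0 and -3.0 from node 1, so the inhibiting edge dominates; node 1 has no incoming edges and stays at zero.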

class gsnn.models.NodeActivity.NodeActivity(*args: Any, **kwargs: Any)

Bases: Module

Node-wise channel attention, mediated by external features. The attention is computed once per forward pass, and the resulting node attention weights are applied at every layer.

This requires a FUNCTION_NODE feature matrix to be provided (e.g., mutation or expression features for function nodes), and the parameters are shared across all nodes. Note: missing values are not supported.
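The "compute once, apply at every layer" pattern can be sketched as follows. The `torch.nn.Linear` layers stand in for the model's actual hidden layers, and multiplying the hidden state element-wise by the gates after each layer is an assumption made for illustration; neither detail is taken from the gsnn source.

```python
import torch

torch.manual_seed(0)
B, n_channels, n_layers = 4, 6, 3

# Hypothetical stand-ins for the model's hidden layers (GSNN layers differ).
layers = torch.nn.ModuleList(
    [torch.nn.Linear(n_channels, n_channels) for _ in range(n_layers)]
)

# Per-channel gates in (0, 1), computed ONCE per forward pass from external features.
gates = torch.sigmoid(torch.randn(B, n_channels))

h = torch.randn(B, n_channels)
for layer in layers:
    h = torch.relu(layer(h))
    h = h * gates  # the same gates are reused at every layer
print(h.shape)
```

Computing the gates outside the layer loop means the external features are processed once, no matter how many layers the model has.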

Parameters:
  • channel_groups (Sequence[int] or Tensor) – A 1-D list/array mapping global channel index → node index. Length equals the total number of hidden channels across all nodes.

  • activity_dim (int) – The dimension of the function node activity features.

  • dropout (float, optional (default=0.0)) – Dropout probability applied inside the MLP.

  • temperature (float, optional (default=1.0)) – Sigmoid temperature applied to the logits before gating. Lower values produce sharper (closer to 0/1) gates; higher values produce softer (closer to 0.5) gates.
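The parameters above suggest a small MLP, shared by all function nodes, that maps each node's activity_dim features to one logit, followed by a temperature-scaled sigmoid. The class `NodeGateSketch` and its hidden width of 16 are hypothetical. This is a minimal sketch under those assumptions, not gsnn's NodeActivity.

```python
import torch
from torch import nn

class NodeGateSketch(nn.Module):
    """Hypothetical shared-MLP node gate (not gsnn's NodeActivity).

    One MLP is applied to every function node (parameters shared across nodes),
    producing one logit per node; the gate is sigmoid(logit / temperature).
    """

    def __init__(self, activity_dim: int, hidden: int = 16,
                 dropout: float = 0.0, temperature: float = 1.0):
        super().__init__()
        self.temperature = temperature
        self.mlp = nn.Sequential(
            nn.Linear(activity_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),  # dropout applied inside the MLP
            nn.Linear(hidden, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, Nf, activity_dim) -> logits: (B, Nf)
        logits = self.mlp(x).squeeze(-1)
        return torch.sigmoid(logits / self.temperature)

gate = NodeGateSketch(activity_dim=3, dropout=0.1, temperature=0.5)
g = gate(torch.randn(2, 5, 3))
print(g.shape)  # torch.Size([2, 5])

# Effect of temperature on the same logits.
logits = torch.tensor([-2.0, 0.0, 2.0])
sharp = torch.sigmoid(logits / 0.25)  # pushed toward 0 and 1
soft = torch.sigmoid(logits / 4.0)    # pulled toward 0.5
print(sharp, soft)
```

Dividing by a small temperature magnifies the logits, which gives the sharper gates described for `temperature`; a large temperature shrinks them toward 0 and the gates toward 0.5.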

forward(x: torch.Tensor)

Infer the node attention/activity.

Parameters:

x (Tensor of shape (B, Nf, activity_dim) or (B, Nf)) – Function node features. If activity_dim == 1, a 2-D tensor of shape (B, Nf) is also accepted and is unsqueezed internally. Nf must equal self.n_nodes.
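The input contract for `x` can be expressed as a small validation step. The helper `prepare_activity_input` is hypothetical; it only mirrors the documented rules (2-D input allowed when activity_dim == 1, and Nf must equal the number of function nodes) and is not part of gsnn.

```python
import torch

def prepare_activity_input(x: torch.Tensor, n_nodes: int, activity_dim: int) -> torch.Tensor:
    """Hypothetical helper mirroring the documented input contract of forward()."""
    if x.dim() == 2:
        if activity_dim != 1:
            raise ValueError("2-D input is only accepted when activity_dim == 1")
        x = x.unsqueeze(-1)  # (B, Nf) -> (B, Nf, 1)
    if x.dim() != 3 or x.shape[-1] != activity_dim:
        raise ValueError(f"expected shape (B, Nf, {activity_dim}), got {tuple(x.shape)}")
    if x.shape[1] != n_nodes:
        raise ValueError(f"expected Nf == {n_nodes}, got {x.shape[1]}")
    return x

x2d = torch.randn(8, 5)
print(prepare_activity_input(x2d, n_nodes=5, activity_dim=1).shape)  # torch.Size([8, 5, 1])
```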

Returns:

Tensor – Node activity/attention applied per channel.

Return type:

Tensor of shape (B, Nf * C_pn), where C_pn is the number of hidden channels per node.
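Because channel_groups assigns every hidden channel to a node, per-node gates of shape (B, Nf) can be expanded to per-channel gates by indexing with that mapping. The sketch assumes channel_groups is a long tensor and that every node owns the same number of channels C_pn, which yields the documented (B, Nf * C_pn) shape; the variable names are illustrative.

```python
import torch

Nf, C_pn, B = 3, 2, 4

# channel_groups[c] = index of the node that owns global channel c
channel_groups = torch.arange(Nf).repeat_interleave(C_pn)  # tensor([0, 0, 1, 1, 2, 2])

node_gates = torch.rand(B, Nf)                 # (B, Nf): one gate per function node
channel_gates = node_gates[:, channel_groups]  # (B, Nf * C_pn): each channel copies its node's gate
print(channel_gates.shape)  # torch.Size([4, 6])
```

Indexing with the mapping, rather than reshaping, also works when nodes own different numbers of channels; in that case the last dimension is simply len(channel_groups).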

get_alpha_mean()