Source code for gsnn.interpret.extract_entity_function
import copy

import numpy as np
import scipy
import scipy.sparse
import torch
import torch_geometric as pyg

from gsnn.models.utils import get_conv_indices
class dense_func_node(torch.nn.Module):
    """Dense, stand-alone replica of a single GSNN function node.

    Computes ``lin_out(node_mlp(act(lin_in(x))))`` where ``act`` is the
    normalisation and the non-linearity applied in the order given by
    ``norm_first`` (mirroring the ``gsnn.models.GSNN.ResBlock`` ordering).

    Parameters
    ----------
    lin_in : torch.nn.Module
        Maps the node's incoming edge values to its hidden channels.
    lin_out : torch.nn.Module
        Maps the node's hidden channels to its outgoing edges.
    nonlin : torch.nn.Module or type
        Non-linearity. A module instance is used as-is; anything else is
        called with no arguments to build one (e.g. ``torch.nn.ReLU``).
    norm : str
        ``'none'`` (identity) or ``'softmax'`` (per-feature softmax over
        dim 1, an approximation of SoftmaxGroupNorm).
    node_mlp : torch.nn.Module, optional
        Applied to the hidden activations reshaped to ``(batch, 1, channels)``.

    Raises
    ------
    NotImplementedError
        For ``norm`` in ``'layer'``, ``'batch'``, ``'groupbatch'``,
        ``'edgebatch'`` (their parameters are not extracted from the GSNN).
    ValueError
        For any other unrecognised ``norm``.
    """

    def __init__(self, lin_in, lin_out, nonlin, norm, node_mlp=None):
        super().__init__()
        self.lin_in = lin_in
        self.lin_out = lin_out
        # `nonlin` may be an instantiated module or a class -- handle both
        self.nonlin = nonlin if isinstance(nonlin, torch.nn.Module) else nonlin()
        self.node_mlp = node_mlp  # Optional node MLP

        # ------------------------------------------------------------------
        # Normalisation layers (mirrors gsnn.models.GSNN.ResBlock logic).
        # Unsupported norms raise instead of `assert False`: asserts vanish
        # under `python -O`, which would silently build norm layers that do
        # not match the trained GSNN.
        # TODO: support layer/batch norm by copying params from the gsnn norm.
        # ------------------------------------------------------------------
        if norm == 'layer':
            raise NotImplementedError('Layer norm not implemented for extracted single-node functions')
        if norm == 'batch':
            raise NotImplementedError('Batch norm not implemented for extracted single-node functions')
        if norm in ('groupbatch', 'edgebatch'):
            raise NotImplementedError('Group/edge batch norm not implemented for extracted single-node functions')

        if norm == 'softmax':
            # Approximate SoftmaxGroupNorm with per-feature softmax.
            self.norm = torch.nn.Softmax(dim=1)
            self.norm_first = False  # match ResBlock behaviour
        elif norm == 'none':
            self.norm = torch.nn.Identity()
            self.norm_first = True
        else:
            raise ValueError(f"Unrecognized norm type '{norm}'")

    def forward(self, x):
        """Forward pass reproducing ResBlock ordering of norm / nonlin.

        ``x`` must be 2-D ``(batch, n_inputs)`` whenever ``norm='softmax'`` or
        a node MLP is present, since both operate on dimension 1.
        """
        x = self.lin_in(x)
        if self.norm_first:
            x = self.nonlin(self.norm(x))
        else:
            x = self.norm(self.nonlin(x))

        if self.node_mlp is not None:
            batch_size = x.size(0)
            # Node MLP expects a singleton channel axis: (B, C) -> (B, 1, C)
            x = self.node_mlp(x.unsqueeze(1))
            # Back to (B, C); reshape (not view) tolerates non-contiguous output
            x = x.reshape(batch_size, -1)

        return self.lin_out(x)
def _sparse_block_to_linear(sparse_lin, src_subset, dst_subset, size):
    """Densify the ``src_subset -> dst_subset`` slice of a sparse GSNN layer into a ``torch.nn.Linear``.

    ``sparse_lin`` is read the same way for both ResBlock layers: ``indices``
    (row 0 = source index, row 1 = destination index), ``values`` (one weight
    per index pair) and an optional ``bias`` indexed by destination.
    Returned layer maps ``len(src_subset)`` inputs to ``len(dst_subset)``
    outputs, with columns/rows in the (sorted) order of the subsets.
    """
    sub_indices, sub_values = pyg.utils.bipartite_subgraph(
        subset=(src_subset, dst_subset),
        edge_index=sparse_lin.indices,
        edge_attr=sparse_lin.values.detach(),
        relabel_nodes=True,
        return_edge_mask=False,
        size=size,
    )
    n_in, n_out = len(src_subset), len(dst_subset)
    sub_indices = sub_indices.detach().cpu().numpy()
    # COO -> dense; duplicate (row, col) entries are summed by scipy.
    weight = scipy.sparse.coo_array(
        (sub_values.detach().cpu().numpy(), (sub_indices[0], sub_indices[1])),
        shape=(n_in, n_out),
    ).toarray()

    bias = getattr(sparse_lin, 'bias', None)
    # Without a source bias the dense layer must not keep nn.Linear's
    # randomly initialised default bias.
    linear = torch.nn.Linear(n_in, n_out, bias=bias is not None)
    linear.weight = torch.nn.Parameter(torch.tensor(weight.T, dtype=torch.float32))
    if bias is not None:
        # reshape(-1) rather than squeeze(): a 1-element bias must stay 1-D.
        node_bias = bias.detach().cpu()[dst_subset].reshape(-1).to(torch.float32).clone()
        linear.bias = torch.nn.Parameter(node_bias)
    return linear


def _extract_node_mlp(block, node_idx):
    """Return an independent copy of ``block.mlp`` restricted to ``node_idx``, or ``None``.

    Each ``BatchNorm1d`` in the shared node MLP holds one entry per function
    node; it is replaced by a 1-feature ``BatchNorm1d`` carrying only that
    node's running statistics / affine parameters. Every other layer is
    deep-copied so the result shares no modules with the parent model.
    """
    if not getattr(block, 'node_mlp', False):
        return None
    shared_mlp = getattr(block, 'mlp', None)
    if shared_mlp is None:
        return None

    node_slice = slice(node_idx, node_idx + 1)
    layers = []
    for module in shared_mlp:
        if not isinstance(module, torch.nn.BatchNorm1d):
            # A shared module would alias the parent's weights and have its
            # train/eval flag flipped by the extracted function's `.eval()`.
            layers.append(copy.deepcopy(module))
            continue

        node_bn = torch.nn.BatchNorm1d(
            1,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats,
        )
        with torch.no_grad():
            if module.track_running_stats:
                node_bn.running_mean.copy_(module.running_mean[node_slice])
                node_bn.running_var.copy_(module.running_var[node_slice])
                node_bn.num_batches_tracked.copy_(module.num_batches_tracked)
            if module.affine:
                node_bn.weight.copy_(module.weight[node_slice])
                node_bn.bias.copy_(module.bias[node_slice])
        layers.append(node_bn)
    return torch.nn.Sequential(*layers)


def extract_entity_function(node, model, data, layer=0):
    r"""Extract the *stand-alone* MLP that implements a single GSNN function node.

    Given a trained :class:`~gsnn.models.GSNN.GSNN` model and the graph that
    was used to train it, this helper rebuilds the linear-nonlinear sequence
    that corresponds to a single *function* node at a particular layer. The
    returned module consumes the latent representations of the node's input
    edges and produces the values sent along its outgoing edges, and is
    intended to replicate the behaviour inside the parent GSNN.

    Parameters
    ----------
    node : str
        Name of the *function* node to extract (must exist in
        ``data.node_names_dict['function']``).
    model : gsnn.models.GSNN.GSNN
        Reference GSNN model. Its weights are copied; the returned module
        shares no submodules with ``model``. Note that ``model.cpu()`` is
        called, which moves ``model`` itself to the CPU (in place).
    data : torch_geometric.data.HeteroData
        Heterogeneous graph object used for training.
    layer : int, optional (default=0)
        Index of the GSNN layer (``ResBlocks[layer]``) from which to extract
        the node-specific sub-network.

    Returns
    -------
    func : dense_func_node
        A dense two-layer network ``func(x_in) -> x_out`` in eval mode. A
        ``'softmax'`` model norm is approximated by a per-feature softmax;
        ``'layer'``/``'batch'``/``'groupbatch'``/``'edgebatch'`` norms raise
        ``NotImplementedError``.
    meta : dict
        Dictionary with

        * ``'input_edge_names'`` -- ``numpy.ndarray`` of str: source-node name
          of each incoming edge, in ``func`` input order
        * ``'output_edge_names'`` -- ``numpy.ndarray`` of str: target-node name
          of each outgoing edge, in ``func`` output order

    Raises
    ------
    ValueError
        If ``node`` is not in ``data.node_names_dict['function']``.

    Example
    -------
    >>> func_node, meta = extract_entity_function('func3', model, data, layer=1)
    >>> y = func_node(torch.randn(1, len(meta['input_edge_names'])))
    >>> print(meta['output_edge_names'])
    """
    # NOTE: Module.cpu() is in-place, so this also moves the caller's model.
    model = model.cpu()
    block = model.ResBlocks[layer]

    function_names = data.node_names_dict['function']
    try:
        # Function nodes come first in the homogeneous indexing (see
        # hetero2homo in GSNN and the `node_names` order below), so the list
        # position is also the homogeneous node index.
        node_idx = function_names.index(node)
    except ValueError:
        raise ValueError(f"'{node}' is not in data.node_names_dict['function']") from None

    # NOTE: these are EDGE indices (not node indices).
    row, col = model.edge_index.detach().cpu()
    input_edges = (col == node_idx).nonzero(as_tuple=True)[0]
    output_edges = (row == node_idx).nonzero(as_tuple=True)[0]

    node_names = np.array(
        function_names + data.node_names_dict['input'] + data.node_names_dict['output']
    )
    # Index with numpy integer arrays rather than relying on numpy's handling
    # of torch-tensor indices.
    inp_edge_names = node_names[row[input_edges].numpy()]
    out_edge_names = node_names[col[output_edges].numpy()]

    # The hidden channel indices relevant to `node`.
    function_nodes = torch.arange(len(function_names))
    w1_indices, _, _, _, channel_groups = get_conv_indices(model.edge_index, model.channels, function_nodes)
    assert (w1_indices == block.lin_in.indices).all(), 'W1 indices do not match model W1 indices'
    channel_groups = np.array(channel_groups)
    hidden_idxs = torch.tensor(np.nonzero(channel_groups == node_idx)[0], dtype=torch.long)

    n_edges = model.edge_index.size(1)
    n_hidden = len(channel_groups)

    # Bipartite slices: incoming edges -> node hidden channels -> outgoing edges.
    lin_in = _sparse_block_to_linear(block.lin_in, input_edges, hidden_idxs, size=(n_edges, n_hidden))
    lin_out = _sparse_block_to_linear(block.lin_out, hidden_idxs, output_edges, size=(n_hidden, n_edges))

    func = dense_func_node(
        lin_in=lin_in,
        lin_out=lin_out,
        # deepcopy so func.eval() cannot touch the parent's module (a class
        # object is returned unchanged by deepcopy).
        nonlin=copy.deepcopy(block.nonlin),
        norm=getattr(model, 'norm', 'none'),
        node_mlp=_extract_node_mlp(block, node_idx),
    )
    func = func.eval()
    meta = {'input_edge_names': inp_edge_names, 'output_edge_names': out_edge_names}
    return func, meta