gsnn.interpret.extract_entity_function

Functions

extract_entity_function(node, model, data[, ...])

Extract the stand-alone MLP that implements a single GSNN function node.

get_conv_indices(edge_index, channels, ...)

Compute indexing structures for convolutional (sparse linear) layers.
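One common way to index a sparse linear layer is to list the (hidden unit, input edge) coordinates of every nonzero weight. The minimal sketch below builds COO indices for the edge-to-hidden projection. It assumes that each function node owns `channels` consecutive hidden units and that every edge targets a function node. `sparse_in_indices` is a hypothetical helper and does not reproduce the signature or return value of get_conv_indices.

```python
import torch

def sparse_in_indices(edge_dst: torch.Tensor, channels: int) -> torch.Tensor:
    """Hypothetical helper (not the gsnn API).

    Returns (2, num_edges * channels) COO indices for a sparse weight of shape
    (num_nodes * channels, num_edges), in which edge e feeds every hidden
    channel of its target function node edge_dst[e].
    """
    c = torch.arange(channels)
    rows = (edge_dst.unsqueeze(1) * channels + c).reshape(-1)           # hidden-unit index
    cols = torch.arange(edge_dst.numel()).repeat_interleave(channels)  # input-edge index
    return torch.stack([rows, cols])

# Toy graph: edges 0 and 1 enter function node 0, and edge 2 enters node 1.
edge_dst = torch.tensor([0, 0, 1])
channels, num_nodes = 2, 2
idx = sparse_in_indices(edge_dst, channels)

# Sparse weight with random values at the allowed positions only.
W = torch.sparse_coo_tensor(idx, torch.randn(idx.shape[1]),
                            size=(num_nodes * channels, edge_dst.numel()))
x = torch.randn(4, edge_dst.numel())   # batch of 4 edge-activation vectors
h = torch.sparse.mm(W, x.t()).t()      # (4, num_nodes * channels) hidden activations
print(idx.tolist())
```

Storing only these index pairs keeps memory proportional to the number of graph connections rather than to the full dense weight matrix.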

Classes

dense_func_node(*args, **kwargs)

class gsnn.interpret.extract_entity_function.dense_func_node(*args: Any, **kwargs: Any)

Bases: Module

forward(x)

Forward pass that reproduces the ordering of normalization and nonlinearity used in the GSNN ResBlock.
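For orientation, a dense function node of this kind can be pictured as the minimal module below. The class name, the LayerNorm and ELU choices, and the ordering (linear, then norm, then nonlinearity, then linear) are all assumptions made for illustration. The actual ordering is whatever the GSNN ResBlock uses.

```python
import torch
from torch import nn

class DenseFuncNodeSketch(nn.Module):
    """Hypothetical stand-in for dense_func_node: inputs -> hidden -> outputs."""

    def __init__(self, n_in: int, channels: int, n_out: int):
        super().__init__()
        self.lin_in = nn.Linear(n_in, channels)
        self.norm = nn.LayerNorm(channels)  # assumed normalization type
        self.nonlin = nn.ELU()              # assumed nonlinearity
        self.lin_out = nn.Linear(channels, n_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed ordering: linear -> norm -> nonlinearity -> linear.
        h = self.nonlin(self.norm(self.lin_in(x)))
        return self.lin_out(h)

f = DenseFuncNodeSketch(n_in=3, channels=4, n_out=2)
y = f(torch.randn(5, 3))  # batch of 5 samples, 3 input edges each
print(y.shape)
```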

gsnn.interpret.extract_entity_function.extract_entity_function(node, model, data, layer=0)

Extract the stand-alone MLP that implements a single GSNN function node.

Given a trained GSNN model and the graph that was used to train it, this helper rebuilds the exact linear-nonlinear sequence that corresponds to a single function node at a particular layer. The returned module consumes the latent representations of its input edges and produces the hidden activations that are sent to its outgoing edges, replicating the behaviour inside the parent GSNN.
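The extraction idea can be illustrated on a toy masked layer. If each function node owns a block of hidden units, and the layer weights are masked so that edges feed only their target node and hidden units write only to edges leaving their own node, then slicing out that node's rows and columns yields a small dense network with identical outputs. The sketch below makes these simplifying assumptions: the toy graph, the masks, and `extract_node` are hypothetical, tanh stands in for the real nonlinearity, and normalization and residual connections are omitted. It is not the gsnn implementation.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy graph: 5 edges and 2 function nodes (0 and 1) with C hidden channels each.
# -1 marks an endpoint that is not a function node (an input or output node).
edge_src = torch.tensor([-1, -1, 0, 1, 0])
edge_dst = torch.tensor([0, 0, 1, -1, -1])
N, C, E = 2, 3, edge_src.numel()
owner = torch.arange(N).repeat_interleave(C)  # hidden unit -> owning function node

# Masked "full layer" weights: edges feed only their target node's hidden units,
# and each hidden unit writes only to edges leaving its own node.
W1 = torch.randn(N * C, E) * (owner[:, None] == edge_dst[None, :])
b1 = torch.randn(N * C)
W2 = torch.randn(E, N * C) * (edge_src[:, None] == owner[None, :])
b2 = torch.randn(E)

def full_layer(x: torch.Tensor) -> torch.Tensor:
    return torch.tanh(x @ W1.T + b1) @ W2.T + b2

def extract_node(v: int):
    in_e = (edge_dst == v).nonzero(as_tuple=True)[0]   # incoming edges of node v
    out_e = (edge_src == v).nonzero(as_tuple=True)[0]  # outgoing edges of node v
    hid = (owner == v).nonzero(as_tuple=True)[0]       # hidden units owned by node v
    lin1 = nn.Linear(len(in_e), len(hid))
    lin2 = nn.Linear(len(hid), len(out_e))
    with torch.no_grad():  # copy weights; the full layer is left untouched
        lin1.weight.copy_(W1[hid][:, in_e])
        lin1.bias.copy_(b1[hid])
        lin2.weight.copy_(W2[out_e][:, hid])
        lin2.bias.copy_(b2[out_e])
    return nn.Sequential(lin1, nn.Tanh(), lin2), in_e, out_e

func, in_e, out_e = extract_node(0)
x = torch.randn(8, E)
same = torch.allclose(func(x[:, in_e]), full_layer(x)[:, out_e], atol=1e-5)
print(same)
```

Because the masks zero every weight outside the node's block, the dropped terms contribute exactly zero. For that reason, the sliced network agrees with the full layer up to floating-point summation order.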

Parameters:
  • node (str) – Name of the function node to extract (must exist in data.node_names_dict['function']).

  • model (gsnn.models.GSNN.GSNN) – Reference GSNN model (weights are copied; the original model remains unchanged).

  • data (torch_geometric.data.HeteroData) – Heterogeneous graph object used for training.

  • layer (int, optional (default=0)) – Index of the GSNN layer (ResBlocks[layer]) from which to extract the node-specific sub-network.

Returns:

  • func (torch.nn.Module) – A dense two-layer network func(x_in) -> x_out that is numerically equivalent to the chosen node inside the GSNN.

  • meta (dict) – Dictionary with the following keys:

    • 'input_edge_names' – list[str] of incoming edge names

    • 'output_edge_names' – list[str] of outgoing edge names
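Since meta lists the incoming and outgoing edge names, presumably in the column order that func consumes and produces, it can help to assemble inputs from a name-keyed dictionary and label the outputs the same way. The minimal sketch below uses an nn.Linear stand-in for the extracted module, and the edge names and `run_by_name` helper are hypothetical.

```python
import torch
from torch import nn

# Hypothetical metadata in the documented shape; real names come from `data`.
meta = {
    "input_edge_names": ["edge_a", "edge_b"],
    "output_edge_names": ["edge_c"],
}
# Stand-in for the extracted module (assumption: any module mapping inputs to outputs).
func = nn.Linear(len(meta["input_edge_names"]), len(meta["output_edge_names"]))

def run_by_name(func: nn.Module, meta: dict, values: dict) -> dict:
    # Order the inputs exactly as meta['input_edge_names'] lists them (batch of 1).
    x = torch.tensor([[values[name] for name in meta["input_edge_names"]]])
    with torch.no_grad():
        y = func(x)[0]
    # Label each output column with its outgoing edge name.
    return dict(zip(meta["output_edge_names"], y.tolist()))

out = run_by_name(func, meta, {"edge_b": 0.5, "edge_a": -1.0})
print(out)
```

Keying by name rather than by position guards against silently feeding values to the wrong input edge.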

Example

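>>> import torch
>>> from gsnn.interpret.extract_entity_function import extract_entity_function
>>> # model: a trained GSNN; data: the HeteroData graph it was trained on
>>> func_node, meta = extract_entity_function('func3', model, data, layer=1)
>>> x = torch.randn(8, len(meta['input_edge_names']))  # batch of 8 input vectors
>>> y = func_node(x)
>>> print(meta['output_edge_names'])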