gsnn.interpret.extract_entity_function
Functions
- extract_entity_function(node, model, data, layer=0) – Extract the stand-alone MLP that implements a single GSNN function node.
- Compute indexing structures for convolutional (sparse linear) layers.
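As a rough illustration of what indexing structures for a sparse linear layer can look like, the following pure-Python sketch lists the nonzero (row, col) positions of the two masked weight matrices of one layer. It assumes, without confirmation from this page, that each function node owns a fixed number of hidden units (channels), that an edge feeds every hidden unit of its target node, and that every hidden unit of a node writes to each of that node's outgoing edges. The helper name build_layer_indices is hypothetical and is not part of the gsnn API.

```python
from typing import List, Tuple

Index = List[Tuple[int, int]]  # (row, col) positions of nonzero weights


def build_layer_indices(
    func_nodes: List[str],
    edges: List[Tuple[str, str]],
    channels: int,
) -> Tuple[Index, Index]:
    """Hypothetical helper (not part of gsnn).

    in_idx  indexes a weight of shape (len(func_nodes) * channels, len(edges))
    out_idx indexes a weight of shape (len(edges), len(func_nodes) * channels)
    """
    node_pos = {name: i for i, name in enumerate(func_nodes)}
    in_idx: Index = []
    out_idx: Index = []
    for e, (src, dst) in enumerate(edges):
        if dst in node_pos:
            # Edge e feeds every hidden unit of its target function node.
            base = node_pos[dst] * channels
            in_idx.extend((base + c, e) for c in range(channels))
        if src in node_pos:
            # Every hidden unit of the source function node writes to edge e.
            base = node_pos[src] * channels
            out_idx.extend((e, base + c) for c in range(channels))
    return in_idx, out_idx


# Toy graph: input node in0 -> f1 -> f2 -> output node out0.
in_idx, out_idx = build_layer_indices(
    ["f1", "f2"], [("in0", "f1"), ("f1", "f2"), ("f2", "out0")], channels=2
)
print(in_idx)   # [(0, 0), (1, 0), (2, 1), (3, 1)]
print(out_idx)  # [(1, 0), (1, 1), (2, 2), (2, 3)]
```

Storing (row, col) pairs instead of dense masks keeps memory proportional to the number of edges times channels; in a torch implementation such pairs would typically be packed into a 2 x nnz index tensor.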
Classes
- class gsnn.interpret.extract_entity_function.dense_func_node(*args: Any, **kwargs: Any)
Bases: Module
- gsnn.interpret.extract_entity_function.extract_entity_function(node, model, data, layer=0)
Extract the stand-alone MLP that implements a single GSNN function node.
Given a trained GSNN model and the graph that was used to train it, this helper rebuilds the exact linear-nonlinear sequence that corresponds to a single function node at a particular layer. The returned module consumes the latent representations of the node's input edges and produces the hidden activations that are sent to its outgoing edges, replicating the behaviour inside the parent GSNN.
- Parameters:
node (str) – Name of the function node to extract (must exist in data.node_names_dict['function']).
model (gsnn.models.GSNN.GSNN) – Reference GSNN model (weights are copied; the original model remains unchanged).
data (torch_geometric.data.HeteroData) – Heterogeneous graph object used for training.
layer (int, optional (default=0)) – Index of the GSNN layer (ResBlocks[layer]) from which to extract the node-specific sub-network.
- Returns:
func (torch.nn.Module) – A dense two-layer network func(x_in) -> x_out that is numerically equivalent to the chosen node inside the GSNN.
meta (dict) – Dictionary with the following keys:
  - 'input_edge_names' – list[str] of incoming edge names
  - 'output_edge_names' – list[str] of outgoing edge names
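The claim that the extracted network is numerically equivalent to the node inside the GSNN can be checked on a toy layer. This pure-Python sketch assumes a simplified layer (a masked linear map from edges to per-node hidden units, tanh, then a masked linear map back to edges); the real ResBlocks may use a different nonlinearity, normalization, or residual terms, all of which are omitted here. full_layer and extract_node are hypothetical names used only for illustration.

```python
import math
from typing import Callable, List, Sequence

Matrix = List[List[float]]


def matvec(W: Matrix, x: Sequence[float]) -> List[float]:
    # One dot product per row of W.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def full_layer(W_in: Matrix, b_in: List[float], W_out: Matrix,
               b_out: List[float], x: Sequence[float]) -> List[float]:
    # Hypothetical whole layer: edge latents -> per-node hidden units -> edge latents.
    hidden = [math.tanh(v + b) for v, b in zip(matvec(W_in, x), b_in)]
    return [v + b for v, b in zip(matvec(W_out, hidden), b_out)]


def extract_node(W_in: Matrix, b_in: List[float], W_out: Matrix, b_out: List[float],
                 hidden_rows: List[int], in_edges: List[int],
                 out_edges: List[int]) -> Callable[[Sequence[float]], List[float]]:
    # Hypothetical extractor: copy the dense sub-blocks that belong to one node.
    W_in_sub = [[W_in[h][e] for e in in_edges] for h in hidden_rows]
    b_in_sub = [b_in[h] for h in hidden_rows]
    W_out_sub = [[W_out[e][h] for h in hidden_rows] for e in out_edges]
    b_out_sub = [b_out[e] for e in out_edges]

    def func(x_in: Sequence[float]) -> List[float]:
        # Same linear -> tanh -> linear sequence, restricted to this node.
        hidden = [math.tanh(v + b) for v, b in zip(matvec(W_in_sub, x_in), b_in_sub)]
        return [v + b for v, b in zip(matvec(W_out_sub, hidden), b_out_sub)]

    return func


# Toy layer for edges e0: in0->f1, e1: f1->f2, e2: f2->out0 with two hidden
# units per node (rows 0-1 belong to f1, rows 2-3 to f2); zeros are masked weights.
W_in = [[0.5, 0.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.3, 0.0]]
b_in = [0.1, 0.2, -0.1, 0.0]
W_out = [[0.0, 0.0, 0.0, 0.0], [0.7, -0.4, 0.0, 0.0], [0.0, 0.0, 1.5, 0.2]]
b_out = [0.0, 0.05, -0.02]
x = [0.3, -1.2, 0.8]

y = full_layer(W_in, b_in, W_out, b_out, x)
f2 = extract_node(W_in, b_in, W_out, b_out, hidden_rows=[2, 3], in_edges=[1], out_edges=[2])
print(math.isclose(f2([x[1]])[0], y[2]))  # True
```

Slicing is exact in this setting because the mask zeroes every input weight outside the node's incoming edges and each outgoing edge has a single source node, so no other node contributes to those outputs.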
Example
>>> import torch
>>> from gsnn.interpret.extract_entity_function import extract_entity_function
>>> func_node, meta = extract_entity_function('func3', model, data, layer=1)
>>> y = func_node(torch.randn(len(meta['input_edge_names'])))
>>> print(meta['output_edge_names'])
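The meta dictionary ties vector positions to edge names. The helper sketched below assumes that the i-th entry of the input vector corresponds to meta['input_edge_names'][i] (and likewise for the outputs), which the example implies but this page does not state explicitly. run_named and the stand-in toy_func are hypothetical; the stand-in replaces a real extracted module so the sketch runs without a trained model.

```python
from typing import Callable, Dict, Mapping, Sequence


def run_named(
    func: Callable,
    meta: Mapping[str, Sequence[str]],
    edge_values: Mapping[str, float],
    to_input: Callable = list,
) -> Dict[str, float]:
    # Hypothetical helper: order inputs exactly as listed in meta['input_edge_names'].
    x_in = to_input([edge_values[name] for name in meta["input_edge_names"]])
    y = func(x_in)
    # Pair each output value with its outgoing edge name (assumes a 1-D output).
    return {name: float(v) for name, v in zip(meta["output_edge_names"], y)}


def toy_func(x: Sequence[float]) -> list:
    # Stand-in for an extracted node: maps a 2-vector to a 2-vector.
    return [x[0] + x[1], x[0] - x[1]]


toy_meta = {"input_edge_names": ["a", "b"], "output_edge_names": ["c", "d"]}
print(run_named(toy_func, toy_meta, {"a": 1.0, "b": 2.0, "z": 9.0}))
# {'c': 3.0, 'd': -1.0}
```

With a real extracted node, passing to_input=torch.tensor would build a 1-D tensor like the one in the example above, and float() converts each 0-d output element; batched inputs are not covered by this sketch.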