gsnn.simulate.utils

Functions

nx_to_pyro_model(G, input_nodes, output_nodes)


gsnn.simulate.utils.nx_to_pyro_model(G, input_nodes, output_nodes, special_functions=None, noise_scale=1, *, signed_edges: dict | None = None)

Converts a NetworkX directed graph into a Pyro Bayesian network model in which each node is a Gaussian random variable. User-specified transformations (e.g., squaring inputs, logic gates) can define how a node combines its parents' values.
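This page does not spell out the conditional distribution of each node. One plausible reading is sketched below. It is an assumption, not the documented behaviour: the signed-sum default, the use of signed parent values inside special functions, and sigma = noise_scale are all unconfirmed. Here s_uv in {+1, -1} is the sign of edge (u, v), pa(v) is the set of parents of v, and input nodes are fixed to the supplied values.

```latex
x_v \sim \mathcal{N}\Big(\textstyle\sum_{u \in \mathrm{pa}(v)} s_{uv}\, x_u,\ \sigma^2\Big)
  \quad \text{(default)}

x_v \sim \mathcal{N}\Big(f_v\big(\{s_{uv}\, x_u\}_{u \in \mathrm{pa}(v)}\big),\ \sigma^2\Big)
  \quad \text{(if } v \text{ has a special function } f_v\text{)}
```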

Parameters:
    G (networkx.DiGraph): A directed graph whose nodes represent random variables and whose edges represent dependencies.
    input_nodes (list): Names of the input nodes.
    output_nodes (list): Names of the output nodes.
    special_functions (dict, optional): Maps node names to functions (e.g., lambdas) that define how to process the parent nodes' values.
    signed_edges (dict, optional): Maps (parent, child) edge tuples to a sign (+1 or -1). If not provided, the function looks for a 'sign' attribute on each edge in G. An edge whose sign is not specified defaults to +1, which reproduces the original (unsigned) behaviour.
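The sign lookup described for signed_edges can be sketched as a small helper. The helper is hypothetical: resolve_edge_signs is not part of gsnn. It follows the literal reading of the text. When signed_edges is given, an edge missing from it defaults to +1 instead of falling back to the 'sign' attribute; the real function may instead fall back per edge. It takes (u, v, attrs) triples, which is what NetworkX's G.edges(data=True) yields, so the sketch runs without NetworkX installed.

```python
from __future__ import annotations

from typing import Hashable, Iterable, Mapping

Edge = tuple[Hashable, Hashable]


def resolve_edge_signs(
    edges: Iterable[tuple[Hashable, Hashable, Mapping]],
    signed_edges: Mapping[Edge, int] | None = None,
) -> dict[Edge, int]:
    """Hypothetical helper: assign a sign (+1 or -1) to every (parent, child) edge."""
    signs: dict[Edge, int] = {}
    for u, v, attrs in edges:
        if signed_edges is not None:
            # Explicit mapping provided: use it; unlisted edges default to +1 (assumed).
            sign = signed_edges.get((u, v), 1)
        else:
            # No mapping: read the edge's 'sign' attribute, defaulting to +1.
            sign = attrs.get("sign", 1)
        if sign not in (1, -1):
            raise ValueError(f"edge {(u, v)!r} has invalid sign {sign!r}")
        signs[(u, v)] = sign
    return signs
```

With a NetworkX graph this would be called as resolve_edge_signs(G.edges(data=True), signed_edges).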

Returns:
    model (function): A Pyro model function that takes input values and returns output values.
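The returned model can be pictured as ancestral sampling. Nodes are visited in topological order, input nodes are fixed to the supplied values, and every other node is drawn from a Gaussian around its combined parent values. The plain-Python sketch below rests on these assumptions: a signed-sum default, special functions applied to signed parent values, and noise_scale used as the standard deviation. make_sampler and its parents-dict argument are hypothetical. The actual function returns a Pyro model, in which each draw would presumably be a pyro.sample site over pyro.distributions.Normal rather than a call to random.gauss.

```python
from __future__ import annotations

import random
from graphlib import TopologicalSorter
from typing import Callable, Hashable, Mapping, Sequence


def make_sampler(
    parents: Mapping[Hashable, Sequence[Hashable]],
    input_nodes: Sequence[Hashable],
    output_nodes: Sequence[Hashable],
    signs: Mapping[tuple, int] | None = None,
    special_functions: Mapping[Hashable, Callable[[list[float]], float]] | None = None,
    noise_scale: float = 1.0,
    seed: int | None = None,
) -> Callable[[Mapping[Hashable, float]], dict]:
    """Hypothetical stand-in for the returned model (plain Python, no Pyro)."""
    signs = signs or {}
    special_functions = special_functions or {}
    inputs_set = set(input_nodes)
    # Parents before children; raises graphlib.CycleError if the graph has a cycle.
    order = list(TopologicalSorter(parents).static_order())
    rng = random.Random(seed)

    def model(inputs: Mapping[Hashable, float]) -> dict:
        values: dict = {}
        for node in order:
            if node in inputs_set:
                # Input nodes are clamped to the values supplied by the caller.
                values[node] = float(inputs[node])
                continue
            # Assumed: each parent value is multiplied by its edge sign (+1 by default).
            signed = [signs.get((p, node), 1) * values[p] for p in parents.get(node, ())]
            fn = special_functions.get(node)
            # Assumed default: signed sum of parents unless a special function is given.
            mean = fn(signed) if fn is not None else sum(signed)
            # Gaussian draw around the mean; noise_scale is treated as the std. deviation.
            values[node] = rng.gauss(mean, noise_scale)
        return {node: values[node] for node in output_nodes}

    return model
```

Setting noise_scale=0.0 makes the sketch deterministic, which is a convenient way to check the wiring (signs and special functions) before adding noise. Because a Bayesian network must be acyclic, a cyclic graph fails at the topological sort.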