Causal Graphical Models

In this section, we explore causal graphical models (CGMs), which are essentially Bayesian networks whose edges encode causal influence rather than mere probabilistic dependence.

CGMs are assumed to be acyclic: there is no directed path from any variable back to itself.
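The acyclicity assumption can be checked mechanically. Below is a minimal plain-Python sketch using Kahn's algorithm (repeatedly remove a node that has no incoming edges; the graph is acyclic exactly when every node can be removed). The helper name `is_acyclic` and the `(nodes, edges)` input format are hypothetical choices for illustration; in networkx, `nx.is_directed_acyclic_graph` performs the same check.

```python
from collections import deque


def is_acyclic(nodes, edges):
    """Return True if the directed graph (nodes, edges) has no directed cycle.

    Hypothetical helper: nodes is a list of hashable labels, edges is a list
    of (u, v) pairs meaning u -> v.
    """
    in_degree = {n: 0 for n in nodes}
    children = {n: [] for n in nodes}
    for u, v in edges:
        children[u].append(v)
        in_degree[v] += 1

    # Kahn's algorithm: start from the nodes that have no parents
    queue = deque(n for n in in_degree if in_degree[n] == 0)
    removed = 0
    while queue:
        n = queue.popleft()
        removed += 1
        for child in children[n]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                queue.append(child)

    # Any node left over lies on, or downstream of, a directed cycle
    return removed == len(in_degree)


# A -> B <- C is a valid CGM structure ...
print(is_acyclic(['A', 'B', 'C'], [('A', 'B'), ('C', 'B')]))  # True
# ... but adding B -> A creates the cycle A -> B -> A
print(is_acyclic(['A', 'B', 'C'], [('A', 'B'), ('C', 'B'), ('B', 'A')]))  # False
```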

[22]:
import networkx as nx
from causal_networkx import ADMG

import matplotlib.pyplot as plt

Causally Sufficient Models

Here, we assume there are no latent (unobserved) variables. We demonstrate how to build a CGM in code and what we can do with it.

We also demonstrate clustered DAGs (CDAGs), in which variables are grouped into clusters. Under the hood, a CDAG is represented with two graphs: one over all the individual variables, where each node stores its cluster ID in its metadata, and one over the clusters themselves. The variable-level graph may be incompletely specified, since the edges within a cluster are not required to be known.
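To make the two-graph representation concrete, here is a minimal plain-Python sketch. It assumes, for illustration only, that the variable-level graph is an edge list plus a `cluster_of` dict standing in for the cluster-ID metadata, and that the cluster-level graph has an edge C1 → C2 whenever some variable in C1 has an edge into some variable in C2. The function name `cluster_graph` and the example variables are hypothetical, not the library's API.

```python
def cluster_graph(variable_edges, cluster_of):
    """Collapse a variable-level edge list into a cluster-level edge set.

    Assumption (illustrative): a cluster edge cu -> cv exists whenever any
    variable edge u -> v crosses from cluster cu into a different cluster cv.

    variable_edges: iterable of (u, v) directed edges between variables
    cluster_of: dict mapping each variable to its cluster ID (the "metadata")
    """
    cluster_edges = set()
    for u, v in variable_edges:
        cu, cv = cluster_of[u], cluster_of[v]
        # Edges inside a single cluster never reach the cluster graph,
        # so they may be left unspecified at the variable level
        if cu != cv:
            cluster_edges.add((cu, cv))
    return cluster_edges


# Hypothetical example: X1 and X2 form cluster 'X'; Y and Z are singletons
cluster_of = {'X1': 'X', 'X2': 'X', 'Y': 'Y', 'Z': 'Z'}
variable_edges = [('X1', 'X2'), ('X1', 'Y'), ('X2', 'Y'), ('Z', 'X2')]
print(sorted(cluster_graph(variable_edges, cluster_of)))
# [('X', 'Y'), ('Z', 'X')]
```

Note that the within-cluster edge X1 → X2 has no effect on the result. For the collapsed graph to be a valid CDAG, it must itself be acyclic.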

Known results on CDAGs show that d-separation applied to the cluster-level graph is complete.
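d-separation itself can be sketched in plain Python. The version below uses the classical moral ancestral graph criterion: X and Y are d-separated by Z exactly when, in the moralized graph of the ancestors of X ∪ Y ∪ Z, every path from X to Y passes through Z. The name `is_d_separated` and the edge-list input are hypothetical choices; networkx answers the same query with `nx.d_separated` (superseded by `nx.is_d_separator` in newer releases), possibly via a different algorithm. Run on a cluster-level edge list, the same function answers d-separation queries between clusters.

```python
from collections import deque


def is_d_separated(edges, xs, ys, zs):
    """Return True if node sets xs and ys are d-separated by zs in a DAG.

    Hypothetical helper: edges is an iterable of (u, v) pairs meaning u -> v.
    Implements the moral ancestral graph criterion.
    """
    xs, ys, zs = set(xs), set(ys), set(zs)
    parents = {}
    for u, v in edges:
        parents.setdefault(v, set()).add(u)

    # 1. Keep only xs | ys | zs and their ancestors
    ancestral = set()
    stack = list(xs | ys | zs)
    while stack:
        n = stack.pop()
        if n not in ancestral:
            ancestral.add(n)
            stack.extend(parents.get(n, ()))

    # 2. Moralize: link each node to its parents, "marry" every pair of
    #    co-parents, and drop edge directions
    neighbors = {n: set() for n in ancestral}
    for child in ancestral:
        pas = list(parents.get(child, ()))
        for p in pas:
            neighbors[p].add(child)
            neighbors[child].add(p)
        for i, p in enumerate(pas):
            for q in pas[i + 1:]:
                neighbors[p].add(q)
                neighbors[q].add(p)

    # 3. Delete zs and check whether any node in ys is reachable from xs
    seen = set(xs - zs)
    queue = deque(seen)
    while queue:
        n = queue.popleft()
        if n in ys:
            return False
        for m in neighbors[n] - zs:
            if m not in seen:
                seen.add(m)
                queue.append(m)
    return True


collider = [('A', 'B'), ('C', 'B')]
print(is_d_separated(collider, {'A'}, {'C'}, set()))  # True: collider B blocks the path
print(is_d_separated(collider, {'A'}, {'C'}, {'B'}))  # False: conditioning on B opens it
```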

[23]:
dag = nx.MultiDiGraph()
[63]:
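# Causal (directed) edges are keyed 'direct'
dag.add_edge('A', 'B', key='direct')
# A bidirected edge A <-> B (unobserved confounding) is stored as a pair of
# opposing edges keyed 'bidirected'
dag.add_edge('A', 'B', key='bidirected')
dag.add_edge('B', 'A', key='bidirected')
dag.add_edge('C', 'B', key='direct')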
[63]:
'direct'
[64]:
print(dag.edges)
[('A', 'B', 'direct'), ('A', 'B', 'bidirected'), ('B', 'A', 'bidirected'), ('C', 'B', 'direct')]
[68]:
print(dag)
G = dag

# Random layout; node sizes grow with the node's position in G
pos = nx.random_layout(G)
node_sizes = [3 + 10 * i for i in range(len(G))]

# Draw the nodes
nx.draw_networkx_nodes(G, pos, node_size=node_sizes, node_color="indigo")

# Draw the directed (causal) edges as solid arrows
nx.draw_networkx_edges(
    G,
    pos,
    edgelist=[('A', 'B', 'direct'), ('C', 'B', 'direct')],
    node_size=node_sizes,
    arrowstyle="->",
    arrowsize=10,
    width=2,
)

# Draw the bidirected edge as two dotted arcs, A -> B and B -> A. Reversing the
# direction and flipping the sign of rad gives the same curve, so together they
# render as one dotted arc with an arrowhead at each end.
for (u, v), rad in [(('A', 'B'), 0.4), (('B', 'A'), -0.4)]:
    nx.draw_networkx_edges(
        G,
        pos,
        edgelist=[(u, v, 'bidirected')],
        node_size=node_sizes,
        style='dotted',
        arrowsize=10,
        width=2,
        connectionstyle=f"arc3,rad={rad}",
    )

# d_separated takes *sets* of nodes. Because the bidirected edge is stored as
# the pair A -> B, B -> A, G contains a directed cycle, so this call raises
# NetworkXError("graph should be directed acyclic").
print(nx.d_separated(G, {'C'}, {'A'}, set()))
MultiDiGraph with 3 nodes and 4 edges
NetworkXError: graph should be directed acyclic
[Figure: drawing of `dag` produced by the cell above]
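One way around the acyclicity error is to make each latent confounder explicit: replace every bidirected pair A ↔ B with a new latent node that points into both endpoints. The result is a genuine DAG, and d-separation between observed nodes in it corresponds to m-separation in the original mixed graph. The sketch below is a plain-Python illustration under our own assumptions: the helper name `to_canonical_dag` and the `U_<a>_<b>` naming of latent nodes are hypothetical, and the input is the `(u, v, key)` tuples as printed by `dag.edges`.

```python
def to_canonical_dag(keyed_edges):
    """Turn (u, v, key) edges into a sorted list of plain directed edges.

    'direct' edges are kept as u -> v. Each bidirected pair u <-> v is
    replaced by a hypothetical latent node U_<a>_<b> with edges into both
    endpoints (a, b are the endpoints in sorted order).
    """
    directed = set()
    for u, v, key in keyed_edges:
        if key == 'direct':
            directed.add((u, v))
        elif key == 'bidirected':
            # Sort the endpoints so ('A', 'B') and ('B', 'A') share one latent node
            a, b = sorted((u, v))
            latent = f'U_{a}_{b}'
            directed.add((latent, a))
            directed.add((latent, b))
        else:
            raise ValueError(f'unknown edge key: {key!r}')
    return sorted(directed)


edges = [('A', 'B', 'direct'), ('A', 'B', 'bidirected'),
         ('B', 'A', 'bidirected'), ('C', 'B', 'direct')]
print(to_canonical_dag(edges))
# [('A', 'B'), ('C', 'B'), ('U_A_B', 'A'), ('U_A_B', 'B')]
```

The returned list can be loaded with `nx.DiGraph(...)` and queried with set arguments, e.g. `nx.d_separated(nx.DiGraph(to_canonical_dag(dag.edges(keys=True))), {'C'}, {'A'}, set())`. This should return True, because both paths C → B ← A and C → B ← U_A_B → A pass through the unconditioned collider B.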
[1]:
nodes = [
    'a', 'b', 'c', 'd', 'e'
]
cgm = ADMG(ebunch=nodes)