Source code for causal_networkx.ci.oracle
from typing import Union
import numpy as np
from causal_networkx import ADMG, DAG
from causal_networkx.algorithms.d_separation import d_separated
from .base import BaseConditionalIndependenceTest
class Oracle(BaseConditionalIndependenceTest):
    """Oracle conditional independence testing.

    Used for unit testing and checking intuition. Instead of estimating
    independence from data, the oracle answers every query by checking
    d-separation in a known causal graph.

    Parameters
    ----------
    graph : DAG | ADMG
        The ground-truth causal graph.
    """

    def __init__(self, graph: Union[ADMG, DAG]) -> None:
        self.graph = graph

    def test(self, df, x_var, y_var, z_covariates):
        """Conditional independence test given an oracle.

        Checks conditional independence between 'x_var' and 'y_var'
        given 'z_covariates' using the causal graph as an oracle.

        Parameters
        ----------
        df : pd.DataFrame of shape (n_samples, n_variables)
            The data matrix. Passed in for API consistency; it is handed to
            ``self._check_test_input`` but plays no part in the
            independence decision itself.
        x_var : node
            A node in the dataset.
        y_var : node
            A node in the dataset.
        z_covariates : set
            The set of variables to check that separates x_var and y_var.

        Returns
        -------
        statistic : float
            ``0.0`` if 'x_var' and 'y_var' are d-separated given
            'z_covariates', otherwise ``np.inf``.
        pvalue : float
            ``1.0`` if 'x_var' and 'y_var' are d-separated given
            'z_covariates' (i.e. conditionally independent), and ``0.0``
            if they are not.
        """
        self._check_test_input(df, x_var, y_var, z_covariates)

        # The oracle's answer is exactly d-separation of x and y given
        # z_covariates in the ground-truth graph.
        is_sep = d_separated(self.graph, x_var, y_var, z_covariates)

        # Follow the usual p-value convention: a large p-value means the
        # null hypothesis of independence is not rejected. Floats are
        # returned so the types match the documented contract.
        if is_sep:
            pvalue = 1.0
            test_stat = 0.0
        else:
            pvalue = 0.0
            test_stat = np.inf
        return test_stat, pvalue
class ParentChildOracle(Oracle):
    """Conditional independence oracle that also exposes parents and children.

    In addition to answering independence queries, this oracle can report
    the definite parents and children of every node in its graph.
    """

    def get_children(self, x):
        """Return the definite children of node 'x'.

        The lookup is delegated to the graph's ``successors`` method.
        """
        graph = self.graph
        return graph.successors(x)

    def get_parents(self, x):
        """Return the definite parents of node 'x'.

        The lookup is delegated to the graph's ``predecessors`` method.
        """
        graph = self.graph
        return graph.predecessors(x)
class MarkovBlanketOracle(ParentChildOracle):
    """Conditional independence oracle that also exposes Markov blankets.

    An oracle that knows the definite Markov blanket of every node.
    """

    def __init__(self, graph: Union[ADMG, DAG]) -> None:
        # No extra state; construction is handled entirely by the parent.
        super().__init__(graph)

    def get_markov_blanket(self, x):
        """Return the Markov blanket of node 'x'.

        The lookup is delegated to the graph's ``markov_blanket_of`` method.
        """
        graph = self.graph
        return graph.markov_blanket_of(x)
class AncestralOracle(ParentChildOracle):
    """Oracle with access to the ancestors of any specific node."""

    def get_ancestors(self, x):
        """Return the ancestors of node 'x'.

        The lookup is delegated to the graph's ``ancestors`` method.
        """
        graph = self.graph
        return graph.ancestors(x)