import itertools as itr
from causaldag.classes import UndirectedGraph
from causaldag.utils.ci_tests import CI_Tester, partial_correlation_test
from numpy import sqrt, log1p, ndenumerate, errstate, diagonal, fill_diagonal
from scipy.special import erf
def threshold_ug(nodes: set, ci_tester: CI_Tester) -> UndirectedGraph:
    """
    Estimate an undirected graph by testing whether each pair of nodes is independent given all others.

    An edge i - j is added whenever ``ci_tester.is_ci(i, j, nodes - {i, j})`` is
    falsy. If the tester exposes a ``ci_test`` attribute that is
    ``partial_correlation_test``, the work is delegated to
    ``threshold_ug_gauss`` instead of running the pairwise tests.

    Parameters
    ----------
    nodes:
        Nodes in the graph. Must be a set, since the conditioning set is
        computed with set difference. NOTE: this argument is not forwarded
        to ``threshold_ug_gauss`` when the Gaussian shortcut is taken.
    ci_tester:
        Conditional independence tester.

    Returns
    -------
    UndirectedGraph
        The estimated undirected graph.

    Examples
    --------
    TODO
    """
    # Gaussian shortcut: the tester wraps the partial correlation test, so
    # the graph can be read off its sufficient statistics in one pass.
    if getattr(ci_tester, 'ci_test', None) is partial_correlation_test:
        return threshold_ug_gauss(ci_tester)

    # Generic path: one conditional independence test per unordered pair,
    # conditioning on every other node.
    edges = {
        (i, j)
        for i, j in itr.combinations(nodes, 2)
        if not ci_tester.is_ci(i, j, nodes - {i, j})
    }
    return UndirectedGraph(nodes, edges)
def partial_correlation_threshold(precision, n=None, alpha=None):
    """
    Zero out precision-matrix entries whose normalized value is not significantly different from zero.

    Every entry is divided by the square roots of the two matching diagonal
    entries, Fisher z-transformed, and tested two-sided against a standard
    normal distribution. Off-diagonal entries whose p-value exceeds ``alpha``
    are set to zero; diagonal entries are never zeroed. Only the absolute
    value of the transform is used, so the sign of the normalized entries
    does not affect the outcome.

    Parameters
    ----------
    precision:
        2-D numpy array (presumably a square precision matrix). It is not
        modified when ``n`` is given.
    n:
        Number of samples. If None, no testing is done and ``precision``
        itself (the same object) is returned.
    alpha:
        Significance level of each test. If None, ``1/n`` is used.

    Returns
    -------
    numpy.ndarray
        A new array equal to ``precision`` with the non-significant
        off-diagonal entries replaced by zero (or ``precision`` itself when
        ``n`` is None).
    """
    if n is None:
        return precision
    assert len(precision.shape) == 2
    if alpha is None:
        # Previously a missing alpha crashed at the comparison below.
        # NOTE(review): 1/n is meant to mirror the fallback used by
        # partial_correlation_test -- confirm against that function.
        alpha = 1 / n

    # Normalize so that the diagonal becomes 1; `r` is a fresh array.
    scale = sqrt(diagonal(precision))
    r = precision / scale / scale[:, None]
    # Each pair is tested given all the remaining variables.
    n_cond = r.shape[0] - 2

    # note: log1p(2r/(1-r)) = log((1+r)/(1-r)) but is more numerically stable for r near 0
    # r = 1 causes warnings but gives the correct answer
    with errstate(divide='ignore', invalid='ignore'):
        statistic = sqrt(n - n_cond - 3) * abs(.5 * log1p(2*r/(1 - r)))

    # Two-sided p-value 2 * (1 - Phi(|z|)) = 1 - erf(|z| / sqrt(2)); the
    # statistic is an absolute value, so a one-sided tail would halve the
    # p-value and test at level 2 * alpha.
    p_values = 1 - erf(statistic / sqrt(2))

    zero_ixs = p_values > alpha
    fill_diagonal(zero_ixs, False)  # never drop the diagonal
    r[zero_ixs] = 0
    # Undo the normalization to get back to the scale of `precision`.
    return r * scale * scale[:, None]
def threshold_ug_gauss(ci_tester):
    """
    Estimate an undirected graph for multivariate Gaussian data.

    Testing whether each pair of nodes is independent given all others
    reduces, in the Gaussian case, to thresholding partial correlations
    (after the Fisher z-transform). The thresholding itself is done by
    ``partial_correlation_threshold``; every nonzero off-diagonal entry of
    its result becomes an edge.

    Parameters
    ----------
    ci_tester:
        Conditional independence tester. Its ``suffstat`` must provide the
        entries ``"P"`` (a 2-D array) and ``'n'``; the significance level is
        looked up as ``ci_tester.kwargs.get('alpha')``.

    Returns
    -------
    UndirectedGraph
        Graph over the nodes ``0, ..., p - 1``, where ``p`` is the number of
        rows of ``suffstat["P"]``.

    Examples
    --------
    TODO
    """
    suffstat = ci_tester.suffstat
    thresholded = partial_correlation_threshold(
        suffstat["P"],
        suffstat['n'],
        ci_tester.kwargs.get('alpha'),
    )

    # Walk the whole matrix, so each surviving pair is recorded in both
    # orientations, (row, col) and (col, row).
    edges = set()
    for (row, col), entry in ndenumerate(thresholded):
        if row != col and entry != 0:
            edges.add((row, col))

    num_nodes = suffstat["P"].shape[0]
    return UndirectedGraph(set(range(num_nodes)), edges)