# Source code for causaldag.utils.ci_tests.partial_correlation_test

from math import erf, erfc
from typing import Dict

import numba
from numpy import sqrt, log1p, abs, ix_, diag, corrcoef, errstate, cov, mean
from numpy.linalg import inv, pinv
# from . import MemoizedCI_Tester


# Names exported by ``from <this module> import *``.
# ``numba_inv`` is deliberately or accidentally not listed here.
__all__ = [
    "partial_correlation_suffstat",
    "partial_correlation_test",
    "compute_partial_correlation"
]


@numba.jit
def numba_inv(A):
    """
    Return the inverse of ``A`` using ``numpy.linalg.inv``, compiled with ``numba.jit``.

    NOTE(review): presumably ``A`` is a square 2-D floating-point array -- confirm against callers.
    This helper is not referenced by the other functions visible in this module.
    """
    return inv(A)


def partial_correlation_suffstat(samples, invert=True) -> Dict:
    """
    Return the sufficient statistics for partial correlation testing.

    Parameters
    ----------
    samples:
        (n x p) matrix, where n is the number of samples and p is the number of variables.
    invert:
        if True, compute the inverse correlation matrix, and normalize it into the partial correlation matrix.
        This will generally speed up the gauss_ci_test if large conditioning sets are used.

    Returns
    -------
    dict
        dictionary of sufficient statistics. Always contains:

        * ``S`` -- sample covariance matrix
        * ``C`` -- sample correlation matrix
        * ``n`` -- number of samples
        * ``mu`` -- per-variable sample mean

        and, when ``invert`` is True, additionally:

        * ``P`` -- pseudo-inverse of ``S`` (sample precision matrix)
        * ``K`` -- pseudo-inverse of ``C``
        * ``rho`` -- ``K`` rescaled so that its diagonal is 1
    """
    # Unpacking into two names also rejects input that is not 2-dimensional.
    n, _ = samples.shape
    S = cov(samples, rowvar=False)  # sample covariance matrix
    mu = mean(samples, axis=0)
    # TODO: NaN when variable is deterministic. Replace w/ 1 and 0?
    C = corrcoef(samples, rowvar=False)  # sample correlation matrix

    if not invert:
        return dict(S=S, C=C, n=n, mu=mu)

    K = pinv(C)
    P = pinv(S)  # sample precision (inverse covariance) matrix
    # Divide column-wise, then row-wise, by sqrt(K_ii) so that rho has a unit diagonal.
    scale = sqrt(diag(K))
    rho = K / scale / scale[:, None]  # sample partial correlation matrix
    return dict(P=P, S=S, C=C, n=n, K=K, rho=rho, mu=mu)
def compute_partial_correlation(suffstat, i, j, cond_set=None):
    """
    Compute the partial correlation between i and j given ``cond_set``.

    Parameters
    ----------
    suffstat:
        dictionary containing:

        * ``n`` -- number of samples
        * ``C`` -- correlation matrix
        * ``K`` (optional) -- inverse correlation matrix
        * ``rho`` (optional) -- partial correlation matrix (K, normalized so diagonals are 1).
    i:
        position of first variable in correlation matrix.
    j:
        position of second variable in correlation matrix.
    cond_set:
        positions of conditioning set in correlation matrix.

    Returns
    -------
    float
        partial correlation
    """
    C = suffstat.get('C')
    p = C.shape[0]
    rho = suffstat.get('rho')
    K = suffstat.get('K')
    n_cond = 0 if cond_set is None else len(cond_set)

    # partial correlation is correlation if there is no conditioning
    if n_cond == 0:
        return C[i, j]

    # a single conditioning variable has a closed form
    if n_cond == 1:
        k = list(cond_set)[0]
        return (C[i, j] - C[i, k]*C[j, k]) / sqrt((1 - C[j, k]**2) * (1 - C[i, k]**2))

    # when conditioning on everything, partial correlation comes from normalized precision matrix
    if n_cond == p - 2 and rho is not None:
        return -rho[i, j]

    # faster to use Schur complement if conditioning set is large and precision matrix is pre-computed:
    # marginalizing the few remaining variables out of K is cheaper than inverting a large block of C
    if n_cond >= p/2 and K is not None:
        pair = [i, j]
        rest = list(set(range(p)) - {i, j, *cond_set})
        if len(rest) == 1:
            # scalar pivot: a plain division replaces the pseudo-inverse
            theta_ij = K[ix_(pair, pair)] - K[ix_(pair, rest)] @ K[ix_(rest, pair)] / K[rest[0], rest[0]]
        else:
            # TODO: what to do if not invertible?
            theta_ij = K[ix_(pair, pair)] - K[ix_(pair, rest)] @ pinv(K[ix_(rest, rest)]) @ K[ix_(rest, pair)]
        return -theta_ij[0, 1] / sqrt(theta_ij[0, 0] * theta_ij[1, 1])

    # general case: invert the correlation sub-matrix over {i, j} and the conditioning set
    idx = [i, j, *cond_set]
    theta = pinv(C[ix_(idx, idx)])  # TODO: what to do if not invertible?
    return -theta[0, 1]/sqrt(theta[0, 0] * theta[1, 1])
def partial_correlation_test(suffstat: Dict, i, j, cond_set=None, alpha=None):
    """
    Test the null hypothesis that i and j are conditionally independent given ``cond_set``.

    Uses Fisher's z-transform.

    Parameters
    ----------
    suffstat:
        dictionary containing:

        * ``n`` -- number of samples
        * ``C`` -- correlation matrix
        * ``K`` (optional) -- inverse correlation matrix
        * ``rho`` (optional) -- partial correlation matrix (K, normalized so diagonals are 1).
    i:
        position of first variable in correlation matrix.
    j:
        position of second variable in correlation matrix.
    cond_set:
        positions of conditioning set in correlation matrix.
    alpha:
        Significance level. Defaults to ``1/n`` when not given.

    Returns
    -------
    dict
        dictionary containing:

        * ``statistic``
        * ``p_value``
        * ``reject``
    """
    n = suffstat['n']
    n_cond = 0 if cond_set is None else len(cond_set)
    alpha = 1/n if alpha is None else alpha

    r = compute_partial_correlation(suffstat, i, j, cond_set=cond_set)

    # === COMPUTE STATISTIC AND P-VALUE
    # note: log1p(2r/(1-r)) = log((1+r)/(1-r)) but is more numerically stable for r near 0
    # r = 1 causes warnings but gives the correct answer
    # NOTE: if n - n_cond - 3 is negative the square root is NaN (silenced here), so p_value is NaN
    # and ``reject`` is False.
    with errstate(divide='ignore', invalid='ignore'):
        statistic = sqrt(n - n_cond - 3) * abs(.5 * log1p(2*r/(1 - r)))

    # Two-sided normal p-value: 2*(1 - Phi(z)) == erfc(z/sqrt(2)).
    # erfc avoids the cancellation in 1 - erf(.) that would round very small p-values to exactly 0.
    # note: erf/erfc from math are much faster than norm.cdf
    p_value = erfc(statistic/sqrt(2))

    return dict(statistic=statistic, p_value=p_value, reject=p_value < alpha)
# class MemoizedGaussCI_Tester(MemoizedCI_Tester):
#     def __init__(self, suffstat: Dict, track_times=False, detailed=False, **kwargs):
#         MemoizedCI_Tester.__init__(self, partial_correlation_test, suffstat, track_times=track_times, detailed=detailed)


# Smoke test: run an unconditional test between the first two of three
# independently drawn normal variables and print the result dictionary.
if __name__ == '__main__':
    import numpy as np

    x = np.random.normal(size=(100, 3))
    s = partial_correlation_suffstat(x)
    res = partial_correlation_test(s, 0, 1)
    print(res)