Source code for causaldag.structure_learning.dag.gsp

from typing import Dict, Optional, Any, List, Set, Union
from causaldag import DAG
import itertools as itr
from causaldag.utils.ci_tests import CI_Tester, partial_correlation_test
from causaldag.classes.custom_types import UndirectedEdge
from causaldag.utils.invariance_tests import InvarianceTester
from causaldag.utils.core_utils import powerset, iszero
import random
from causaldag.structure_learning.undirected import threshold_ug, partial_correlation_threshold
from causaldag import UndirectedGraph
import numpy as np
from tqdm import trange, tqdm
from causaldag.utils.core_utils import powerset
from math import factorial


def perm2dag_precision(perm, precision, alpha=.01, num_samples=None):
    perm = np.array(perm)
    current_precision = precision.copy()
    current_precision[:, :] = current_precision[perm, :]
    current_precision[:, :] = current_precision[:, perm]
    nnodes = precision.shape[0]
    arcs = set()
    for node in range(nnodes-1, -1, -1):
        if num_samples is not None:
            current_precision_thresholded = partial_correlation_threshold(current_precision, n=num_samples, alpha=alpha)
            parents = np.nonzero(current_precision_thresholded[-1])[0]
        else:
            parents = np.nonzero(~iszero(current_precision[-1]))[0]
        label = perm[node]
        new_arcs = set(itr.product(perm[parents], [label])) - {(label, label)}
        arcs.update(new_arcs)
        current_precision = current_precision[:-1, :-1] - current_precision[-1, -1]**-1 * np.outer(current_precision[:-1, -1], current_precision[:-1, -1])

    return DAG(nodes=set(perm), arcs=arcs)


[docs]def permutation2dag( perm: list, ci_tester: CI_Tester, verbose=False, fixed_adjacencies: Set[UndirectedEdge]=set(), fixed_gaps: Set[UndirectedEdge]=set(), progress=False ): """ Estimate the minimal IMAP of a DAG which is consistent with the given permutation. Parameters ---------- perm: list of nodes representing the permutation. ci_tester: object for testing conditional independence. verbose: if True, log each CI test. fixed_adjacencies: set of nodes known to be adjacent. fixed_gaps: set of nodes known not to be adjacent. Examples -------- >>> from causaldag.utils.ci_tests import MemoizedCI_Tester, partial_correlation_test, partial_correlation_suffstat >>> perm = [0,1,2] >>> suffstat = partial_correlation_suffstat(samples) >>> ci_tester = MemoizedCI_Tester(partial_correlation_test, suffstat) >>> permutation2dag(perm, ci_tester, fixed_gaps={frozenset({1, 2})}) """ if hasattr(ci_tester, "ci_test") and ci_tester.ci_test == partial_correlation_test and "P" in ci_tester.suffstat: return perm2dag_precision(perm, ci_tester.suffstat["P"], ci_tester.kwargs.get('alpha'), ci_tester.suffstat['n']) if fixed_adjacencies: adj = next(iter(fixed_adjacencies)) if not isinstance(adj, frozenset): raise ValueError('fixed_adjacencies should contain frozensets') if fixed_gaps: adj = next(iter(fixed_gaps)) if not isinstance(adj, frozenset): raise ValueError('fixed_gaps should contain frozensets') d = DAG(nodes=set(perm)) ixs = list(itr.chain.from_iterable(((f, s) for f in range(s)) for s in range(len(perm)))) ixs = ixs if not progress else tqdm(ixs) for i, j in ixs: pi_i, pi_j = perm[i], perm[j] # === IF FIXED, DON'T TEST if frozenset({pi_i, pi_j}) in fixed_adjacencies: d.add_arc(pi_i, pi_j) continue if frozenset({pi_i, pi_j}) in fixed_gaps: continue # === TEST MARKOV BLANKET mb = d.markov_blanket_of(pi_i) is_ci = ci_tester.is_ci(pi_i, pi_j, mb) if not is_ci: d.add_arc(pi_i, pi_j, check_acyclic=True) if verbose: print(f"{pi_i} is independent of {pi_j} given {mb}: {is_ci}") return d
[docs]def sparsest_permutation(nodes, ci_tester, progress=False): """ Estimate the Markov equivalence class of a DAG using the Sparsest Permutations (SP) algorithm. Parameters ---------- nodes: list of nodes. ci_tester: object for testing conditional independence. progress: if True, show a progress bar over the enumeration of permutations. Examples -------- >>> from causaldag.utils.ci_tests import MemoizedCI_Tester, partial_correlation_test, partial_correlation_suffstat >>> import causaldag as cd >>> import random >>> import numpy as np >>> random.seed(1212) >>> np.random.seed(12131) >>> nnodes = 7 >>> d = cd.rand.directed_erdos(nnodes, exp_nbrs=2) >>> g = cd.rand.rand_weights(d) >>> samples = g.sample(1000) >>> suffstat = partial_correlation_suffstat(samples) >>> ci_tester = MemoizedCI_Tester(partial_correlation_test, suffstat, alpha=1e-3) >>> est_dag = cd.sparsest_permutation(set(range(nnodes)), ci_tester, progress=True) >>> true_cpdag = d.cpdag() >>> est_cpdag = est_dag.cpdag() >>> print(true_cpdag.shd(est_cpdag)) >>> 0 """ permutations = itr.permutations(nodes, len(nodes)) permutations = tqdm(permutations, total=factorial(len(nodes))) if progress else permutations min_dag, min_num_arcs = None, float('inf') for perm in permutations: dag = permutation2dag(list(perm), ci_tester) if dag.num_arcs < min_num_arcs: min_dag, min_num_arcs = dag, dag.num_arcs return min_dag
def perm2dag_subsets(perm, ci_tester, max_subset_size=None): """ Not recommended unless max_subset_size set very small. Not thoroughly tested. """ arcs = set() nodes = set(perm) for i, pi_i in enumerate(perm): for candidate_parent_set in powerset(perm[:i], r_max=max_subset_size): print(candidate_parent_set) if all(ci_tester.is_ci(i, j, candidate_parent_set) for j in nodes - {i} - candidate_parent_set): # if ci_tester.is_ci(i, nodes - {i} - candidate_parent_set, candidate_parent_set): arcs.update({(parent, i) for parent in candidate_parent_set}) break return DAG(nodes=nodes, arcs=arcs) def perm2dag2(perm, ci_tester, node2nbrs=None): arcs = set() for (i, pi_i), (j, pi_j) in itr.combinations(enumerate(perm), 2): c = set(perm[:j]) - {pi_i} c = c if node2nbrs is None else c & (node2nbrs[pi_i] | node2nbrs[pi_j]) print(pi_i, pi_j, c) if not ci_tester.is_ci(pi_i, pi_j, c): arcs.add((pi_i, pi_j)) return DAG(nodes=set(perm), arcs=arcs) def update_minimal_imap(dag, i, j, ci_tester, fixed_adjacencies=set(), fixed_gaps=set()): """ TODO Parameters ---------- TODO Examples -------- TODO """ removed_arcs = set() parents = dag.parents_of(i) for parent in parents: rest = parents - {parent} if (i, parent) not in fixed_adjacencies | fixed_gaps and (parent, i) not in fixed_adjacencies | fixed_gaps: if ci_tester.is_ci(i, parent, rest): removed_arcs.add((parent, i)) if (j, parent) not in fixed_adjacencies | fixed_gaps and (parent, j) not in fixed_adjacencies | fixed_gaps: if ci_tester.is_ci(j, parent, rest | {i}): removed_arcs.add((parent, j)) return removed_arcs # def min_degree_alg(undirected_graph, ci_tester: CI_Tester, delete=False): # permutation = [] # curr_undirected_graph = undirected_graph.copy() # while curr_undirected_graph._nodes: # min_degree = min(curr_undirected_graph.degrees.values()) # min_degree_nodes = {node for node, degree in curr_undirected_graph.degrees.items() if degree == min_degree} # k = random.choice(list(min_degree_nodes)) # nbrs_k = curr_undirected_graph._neighbors[k] # # curr_undirected_graph.delete_node(k) # curr_undirected_graph.add_edges_from(itr.combinations(nbrs_k, 2)) # # for nbr1, nbr2 in itr.combinations(nbrs_k, 2): # # # if not curr_undirected_graph.has_edge(nbr1, nbr2): # # curr_undirected_graph.add_edge(nbr1, nbr2) # # elif delete and ci_tester.is_ci(nbr1, nbr2, curr_undirected_graph._nodes - {nbr1, nbr2, k}): # # curr_undirected_graph.delete_edge(nbr1, nbr2) # # permutation.append(k) # # return list(reversed(permutation)) def min_degree_alg_amat(amat, rnd=True): """ TODO Parameters ---------- TODO Examples -------- TODO """ amat = amat.copy() remaining_nodes = list(range(amat.shape[0])) permutation = [] while remaining_nodes: # === PICK A NODE OF MINIMUM DEGREE curr_amat = amat[np.ix_(remaining_nodes, remaining_nodes)] degrees = curr_amat.sum(axis=0) min_degree = degrees.min() min_degree_ixs = np.where(degrees == min_degree)[0] min_degree_ix = random.choice(min_degree_ixs) # === ATTACH ITS NEIGHBORS nbrs = {remaining_nodes[ix] for ix in curr_amat[min_degree_ix].nonzero()[0]} for i in nbrs: amat[i, list(nbrs - {i})] = 1 # === REMOVE IT permutation.append(remaining_nodes[min_degree_ix]) del remaining_nodes[min_degree_ix] return list(reversed(permutation)) # def min_degree_alg2(undirected_graph, ci_tester: CI_Tester, delete=False): # permutation = [] # curr_undirected_graph = undirected_graph.copy() # while curr_undirected_graph._nodes: # nodes2added = {node: set() for node in curr_undirected_graph._nodes} # nodes2removed = {node: set() for node in curr_undirected_graph._nodes} # for k in curr_undirected_graph._nodes: # nbrs_k = curr_undirected_graph._neighbors[k] # for nbr1, nbr2 in itr.combinations(nbrs_k, 2): # if not curr_undirected_graph.has_edge(nbr1, nbr2): # nodes2added[k].add((nbr1, nbr2)) # elif delete and ci_tester.is_ci(nbr1, nbr2, curr_undirected_graph._nodes - {nbr1, nbr2, k}): # nodes2removed[k].add((nbr1, nbr2)) # # # === PICK A NODE # min_added = min(map(len, nodes2added.values())) # min_added_nodes = {node for node, added in nodes2added.items() if len(added) == min_added} # removed_node = random.choice(list(min_added_nodes)) # # # === UPDATE GRAPH # curr_undirected_graph.delete_node(removed_node) # curr_undirected_graph.add_edges_from(nodes2added[removed_node]) # if delete: # curr_undirected_graph.delete_edges_from(nodes2removed[removed_node]) # # permutation.append(removed_node) # # return list(reversed(permutation)) # def min_degree_alg2(undirected_graph): # amat = undirected_graph.to_amat(sparse=True) # return list(reversed(list(amd.order(amat)))) def jci_gsp( setting_list: List[Dict], nodes: set, combined_ci_tester: CI_Tester, depth: int = 4, nruns: int = 5, verbose: bool = False, initial_undirected: Optional[Union[str, UndirectedGraph]] = 'threshold', ): """ TODO Parameters ---------- TODO Examples -------- TODO """ # CREATE NEW NODES AND OTHER INPUT TO ALGORITHM context_nodes = ['c%d' % i for i in range(len(setting_list))] context_adjacencies = set(itr.permutations(context_nodes, r=2)) known_iv_adjacencies = set.union(*( {('c%s' % i, node) for node in setting['known_interventions']} for i, setting in enumerate(setting_list) )) fixed_orders = set(itr.combinations(context_nodes, 2)) | set(itr.product(context_nodes, nodes)) # === DO SMART INITIALIZATION if isinstance(initial_undirected, str): if initial_undirected == 'threshold': initial_undirected = threshold_ug(set(nodes), combined_ci_tester) else: raise ValueError("initial_undirected must be one of 'threshold', or an UndirectedGraph") if initial_undirected: amat = initial_undirected.to_amat() initial_permutations = [context_nodes + min_degree_alg_amat(amat) for _ in range(nruns)] else: initial_permutations = [context_nodes + random.sample(list(nodes), len(nodes)) for _ in range(nruns)] # === RUN GSP ON FULL DAG est_meta_dag, _ = gsp( nodes | set(context_nodes), combined_ci_tester, depth=depth, nruns=nruns, initial_permutations=initial_permutations, fixed_orders=fixed_orders, fixed_adjacencies=context_adjacencies | known_iv_adjacencies, verbose=verbose ) # === PROCESS OUTPUT learned_intervention_targets = { int(node[1:]): {child for child in est_meta_dag.children_of(node) if not isinstance(child, str)} for node in est_meta_dag.nodes if isinstance(node, str) } learned_intervention_targets = [learned_intervention_targets[i] for i in range(len(setting_list))] est_dag = est_meta_dag.induced_subgraph({node for node in est_meta_dag.nodes if not isinstance(node, str)}) return est_dag, learned_intervention_targets
[docs]def gsp( nodes: set, ci_tester: CI_Tester, depth: Optional[int] = 4, nruns: int = 5, verbose: bool = False, initial_undirected: Optional[Union[str, UndirectedGraph]] = 'threshold', initial_permutations: Optional[List] = None, fixed_orders=set(), fixed_adjacencies=set(), fixed_gaps=set(), use_lowest=True, max_iters=float('inf'), factor=2, progress_bar=False, summarize=False ) -> (DAG, List[List[Dict]]): """ Estimate the Markov equivalence class of a DAG using the Greedy Sparsest Permutations (GSP) algorithm. Parameters ---------- nodes: Labels of nodes in the graph. ci_tester: A conditional independence tester, which has a method is_ci taking two sets A and B, and a conditioning set C, and returns True/False. depth: Maximum depth in depth-first search. Use None for infinite search depth. nruns: Number of runs of the algorithm. Each run starts at a random permutation and the sparsest DAG from all runs is returned. verbose: TODO initial_undirected: Option to find the starting permutation by using the minimum degree algorithm on an undirected graph that is Markov to the data. You can provide the undirected graph yourself, use the default 'threshold' to do simple thresholding on the partial correlation matrix, or select 'None' to start at a random permutation. initial_permutations: A list of initial permutations with which to start the algorithm. This option is helpful when there is background knowledge on orders. This option is mutually exclusive with initial_undirected. fixed_orders: Tuples (i, j) where i is known to come before j. fixed_adjacencies: Tuples (i, j) where i and j are known to be adjacent. fixed_gaps: Tuples (i, j) where i and j are known to be non-adjacent. See Also -------- pcalg, igsp, unknown_target_igsp Return ------ (est_dag, summaries) """ if initial_permutations is not None: nruns = len(initial_permutations) if initial_permutations is None and isinstance(initial_undirected, str): if initial_undirected == 'threshold': initial_undirected = threshold_ug(nodes, ci_tester) else: raise ValueError("initial_undirected must be one of 'threshold', or an UndirectedGraph") # === GENERATE CANDIDATE STARTING PERMUTATIONS if initial_permutations is None: if initial_undirected: amat = initial_undirected.to_amat() initial_permutations = [min_degree_alg_amat(amat) for _ in range(factor * nruns)] else: initial_permutations = [random.sample(nodes, len(nodes)) for _ in range(nruns)] # === FIND CANDIDATE STARTING DAGS starting_dags = [] for perm in initial_permutations: d = permutation2dag(perm, ci_tester, fixed_adjacencies=fixed_adjacencies, fixed_gaps=fixed_gaps) starting_dags.append(d) starting_dags = sorted(starting_dags, key=lambda d: d.num_arcs) summaries = [] min_dag = None # all_kept_dags = set() range_fn = range if not progress_bar else trange for r in range_fn(nruns): summary = [] current_dag = starting_dags[r] if verbose: print("=== STARTING DAG:", current_dag) # === FIND NEXT POSSIBLE MOVES current_covered_arcs = current_dag.reversible_arcs() - fixed_orders if verbose: print(f"Current covered arcs: {current_covered_arcs}") covered_arcs2removed_arcs = [ (i, j, update_minimal_imap(current_dag, i, j, ci_tester)) for i, j in current_covered_arcs ] covered_arcs2removed_arcs = sorted(covered_arcs2removed_arcs, key=lambda c: len(c[2])) # === RECORDS FOR DEPTH-FIRST SEARCH all_visited_dags = set() trace = [] graph_counter = 0 # === SEARCH! iters_since_improvement = 0 it_count = 0 while True: it_count += 1 if iters_since_improvement > max_iters: break if summarize: summary.append({'dag': current_dag, 'depth': len(trace), 'num_arcs': len(current_dag.arcs)}) all_visited_dags.add(frozenset(current_dag.arcs)) max_arcs_removed = len(covered_arcs2removed_arcs[-1][2]) if len(covered_arcs2removed_arcs) > 0 else 0 if (len(covered_arcs2removed_arcs) > 0 and len(trace) != depth) or max_arcs_removed > 0: graph_counter += 1 if max_arcs_removed > 0: # start over at sparser DAG iters_since_improvement = 0 # all_visited_dags = set() trace = [] # === CHOOSE A SPARSER I-MAP if use_lowest: candidate_ixs = [ ix for ix, (i, j, rem) in enumerate(covered_arcs2removed_arcs) if len(rem) == max_arcs_removed ] else: candidate_ixs = [ix for ix, (i, j, rem) in enumerate(covered_arcs2removed_arcs) if len(rem) > 0] selected_ix = random.choice(candidate_ixs) # === FIND THE DAG CORRESPONDING TO THE SPARSER IMAP i, j, rem_arcs = covered_arcs2removed_arcs.pop(selected_ix) current_dag.reverse_arc(i, j, check_acyclic=True) current_dag.remove_arcs_from(rem_arcs) current_covered_arcs = current_dag.reversible_arcs() - fixed_orders # if frozenset(current_dag.arcs) in all_kept_dags: # CHECK IF THIS MAKES SENSE # print('Break') # break # all_kept_dags.add(frozenset(current_dag.arcs)) if verbose: print("=== FOUND DAG WITH FEWER ARCS:", current_dag) else: iters_since_improvement += 1 trace.append((current_dag.copy(), current_covered_arcs, covered_arcs2removed_arcs)) i, j, _ = covered_arcs2removed_arcs.pop(random.randrange(len(covered_arcs2removed_arcs))) current_dag.reverse_arc(i, j, check_acyclic=True) current_covered_arcs = current_dag.reversible_arcs() - fixed_orders # === FIND NEXT POSSIBLE MOVES covered_arcs2removed_arcs = [ (i, j, update_minimal_imap(current_dag, i, j, ci_tester)) for i, j in current_covered_arcs ] covered_arcs2removed_arcs = sorted(covered_arcs2removed_arcs, key=lambda c: len(c[2])) # === REMOVE ANY MOVES WHICH LEAD TO ALREADY-EXPLORED DAGS current_arcs = frozenset(current_dag.arcs) covered_arcs2removed_arcs = [ (i, j, rem_arcs) for i, j, rem_arcs in covered_arcs2removed_arcs if current_arcs - {(i, j)} | {(j, i)} - rem_arcs not in all_visited_dags ] else: if len(trace) == 0: # reached minimum within search depth break else: # backtrack current_dag, current_covered_arcs, covered_arcs2removed_arcs = trace.pop() # === END OF RUN if summarize: summaries.append(summary) if min_dag is None or len(current_dag.arcs) < len(min_dag.arcs): min_dag = current_dag if summarize: return min_dag, summaries else: return min_dag
def igsp( setting_list: List[Dict], nodes: set, ci_tester: CI_Tester, invariance_tester: InvarianceTester, depth: Optional[int] = 4, nruns: int = 5, initial_undirected: Optional[Union[str, UndirectedGraph]] = 'threshold', initial_permutations: Optional[List] = None, verbose: bool = False, ): """ TODO Parameters ---------- TODO Examples -------- TODO """ only_single_node = all(len(setting['interventions']) <= 1 for setting in setting_list) interventions2setting_nums = { frozenset(setting['interventions']): setting_num for setting_num, setting in enumerate(setting_list) } def _is_icovered(i, j): """ i -> j is I-covered if: 1) if {i} is an intervention, then f^{i}(j) = f(j) """ setting_num = interventions2setting_nums.get(frozenset({i})) if setting_num is not None and not invariance_tester.is_invariant(j, context=setting_num): return False # for iv_nodes in samples.keys(): # if j in iv_nodes and i not in iv_nodes: # if not _get_is_variant(iv_nodes, i, None): # return False return True def _reverse_arc(dag, i, j): new_dag = dag.copy() parents = dag.parents_of(i) new_dag.reverse_arc(i, j) if parents: for parent in parents: rest = parents - {parent} if ci_tester.is_ci(i, parent, [*rest, j]): new_dag.remove_arc(parent, i) if ci_tester.is_ci(j, parent, cond_set=[*rest]): new_dag.remove_arc(parent, j) new_covered_arcs = new_dag.reversible_arcs() new_icovered_arcs = [(i, j) for i, j in new_covered_arcs if _is_icovered(i, j)] new_contradicting = _get_contradicting_arcs(new_dag) return new_dag, new_icovered_arcs, new_contradicting def _is_i_contradicting(i, j, dag): """ i -> j is I-contradicting if either: 1) there exists S, a subset of the neighbors of j besides i, s.t. f^I(j|S) = f(j|S) for all I containing i but not j 2) there exists I with j \in I but i \not\in I, s.t. f^I(i|S) \not\eq f(i|S) for all subsets S of the neighbors of i besides j If there are only single node interventions, this condition becomes: 1) {i} \in I and f^{i}(j) = f(j) or 2) {j} \in I and f^{j}(i) \neq f(i) """ if only_single_node: setting_num_i = interventions2setting_nums.get(frozenset({i})) if setting_num_i is not None and invariance_tester.is_invariant(j, context=setting_num_i): return True setting_num_j = interventions2setting_nums.get(frozenset({j})) if setting_num_j is not None and not invariance_tester.is_invariant(i, context=setting_num_j): return True return False else: # === TEST CONDITION 1 neighbors_j = dag.neighbors_of(j) - {i} for s in powerset(neighbors_j): for setting_num, setting in enumerate(setting_list): if i in setting['interventions'] and j not in setting['interventions']: if not invariance_tester.is_invariant(j, context=setting_num, cond_set=s): return True neighbors_i = dag.neighbors_of(i) - {j} for setting_num, setting in enumerate(setting_list): if j in setting['interventions'] and i not in setting['interventions']: i_always_varies = all( invariance_tester.is_invariant(i, context=setting_num, cond_set=s) for s in powerset(neighbors_i) ) if i_always_varies: return True return False def _get_contradicting_arcs(dag): """ Count the number of I-contradicting arcs in the DAG dag """ contradicting_arcs = {(i, j) for i, j in dag.arcs if _is_icovered(i, j) and _is_i_contradicting(i, j, dag)} return contradicting_arcs summaries = [] # === LIST OF DAGS FOUND BY EACH RUN finishing_dags = [] if initial_permutations is None and isinstance(initial_undirected, str): if initial_undirected == 'threshold': initial_undirected = threshold_ug(nodes, ci_tester) else: raise ValueError("initial_undirected must be one of 'threshold', or an UndirectedGraph") # === DO MULTIPLE RUNS for r in range(nruns): summary = [] # === STARTING VALUES if initial_permutations is not None: starting_perm = initial_permutations[r] elif initial_undirected: starting_perm = min_degree_alg_amat(initial_undirected.to_amat()) else: starting_perm = random.sample(nodes, len(nodes)) current_dag = permutation2dag(starting_perm, ci_tester) if verbose: print("=== STARTING RUN %s/%s" % (r + 1, nruns)) current_covered_arcs = current_dag.reversible_arcs() current_icovered_arcs = [(i, j) for i, j in current_covered_arcs if _is_icovered(i, j)] current_contradicting = _get_contradicting_arcs(current_dag) next_dags = [_reverse_arc(current_dag, i, j) for i, j in current_icovered_arcs] random.shuffle(next_dags) # === RECORDS FOR DEPTH-FIRST SEARCH all_visited_dags = set() trace = [] min_dag_run = (current_dag, current_contradicting) # === SEARCH while True: summary.append({ 'dag': current_dag, 'num_arcs': len(current_dag.arcs), 'num_contradicting': len(current_contradicting) }) all_visited_dags.add(frozenset(current_dag.arcs)) lower_dags = [ (d, icovered_arcs, contradicting_arcs) for d, icovered_arcs, contradicting_arcs in next_dags if len(d.arcs) < len(current_dag.arcs) ] if verbose: desc = f'({len(current_dag.arcs)} arcs' desc += f', I-covered: {current_icovered_arcs}' desc += f', I-contradicting: {current_contradicting})' print('-' * len(trace), current_dag, desc) if (len(next_dags) > 0 and len(trace) != depth) or len(lower_dags) > 0: if len(lower_dags) > 0: # restart at a lower DAG all_visited_dags = set() trace = [] current_dag, current_icovered_arcs, current_contradicting = lower_dags.pop() min_dag_run = (current_dag, current_contradicting) if verbose: print(f"FOUND DAG WITH {len(current_dag.arcs)}) ARCS: {current_dag}") else: trace.append((current_dag, current_icovered_arcs, current_contradicting)) current_dag, current_icovered_arcs, current_contradicting = next_dags.pop() if len(current_contradicting) < len(min_dag_run[1]): min_dag_run = (current_dag, current_contradicting) if verbose: print(f"FOUND DAG WITH {current_contradicting} CONTRADICTING ARCS: {current_dag}") next_dags = [_reverse_arc(current_dag, i, j) for i, j in current_icovered_arcs] next_dags = [ (d, icovered_arcs, contradicting_arcs) for d, icovered_arcs, contradicting_arcs in next_dags if frozenset(d.arcs) not in all_visited_dags ] random.shuffle(next_dags) # === DEAD END else: if len(trace) == 0: break else: # len(lower_dags) == 0, len(next_dags) > 0, len(trace) == depth current_dag, current_icovered_arcs, current_contradicting = trace.pop() # === END OF RUN summaries.append(summary) finishing_dags.append(min_dag_run) min_dag = min(finishing_dags, key=lambda dag_n: (len(dag_n[0].arcs), len(dag_n[1]))) # print(min_dag) return min_dag[0] def is_icovered( setting_list: List[Dict], i: int, j: int, dag: DAG, invariance_tester: InvarianceTester, ): """ Tell if an edge i->j is I-covered with respect to the invariance tests. True if, for all I s.t. i \in I, the distribution of j given its parents varies between the observational and interventional data. setting_list: A list of dictionaries that provide meta-information about each setting. The first setting must be observational. i: Source of the edge being tested. j: Target of the edge being tested. """ parents_j = list(dag.parents_of(j)) for setting_num, setting in enumerate(setting_list): if i in setting['interventions']: if invariance_tester.is_invariant(j, context=setting_num, cond_set=parents_j): return False return True def unknown_target_igsp( setting_list: List[Dict], nodes: set, ci_tester: CI_Tester, invariance_tester: InvarianceTester, depth: Optional[int] = 4, nruns: int = 5, initial_undirected: Optional[Union[str, UndirectedGraph]] = 'threshold', initial_permutations: Optional[List] = None, verbose: bool = False, use_lowest=True, tup_score=True, no_targets=False ) -> (DAG, List[Set[int]]): """ Use the Unknown Target Interventional Greedy Sparsest Permutation algorithm to estimate a DAG in the I-MEC of the data-generating DAG. Parameters ---------- setting_list: A list of dictionaries that provide meta-information about each non-observational setting. nodes: Nodes in the graph. ci_tester: A conditional independence tester object, which has a method is_ci taking two sets A and B, and a conditioning set C, and returns True/False. invariance_tester: An invariance tester object, which has a method is_invariant taking a node, two settings, and a conditioning set C, and returns True/False. depth: Maximum depth in depth-first search. Use None for infinite search depth. nruns: Number of runs of the algorithm. Each run starts at a random permutation and the sparsest DAG from all runs is returned. initial_undirected: Option to find the starting permutation by using the minimum degree algorithm on an undirected graph that is Markov to the data. You can provide the undirected graph yourself, use the default 'threshold' to do simple thresholding on the partial correlation matrix, or select 'None' to start at a random permutation. initial_permutations: A list of initial permutations with which to start the algorithm. This option is helpful when there is background knowledge on orders. This option is mutually exclusive with initial_undirected. no_targets: if True, leave out information on known intervention targets. """ if no_targets: setting_list = [{'known_interventions': []} for _ in setting_list] def _is_icovered(i, j, dag): """ Check if the edge i->j is I-covered in the DAG dag """ parents_j = frozenset(dag.parents_of(j)) for setting_num, setting in enumerate(setting_list): if i in setting['known_interventions']: if invariance_tester.is_invariant(j, context=setting_num, cond_set=parents_j): return False return True def _get_variants(dag): """ Count the number of variances for the DAG dag """ variants = set() for i in dag.nodes: parents_i = frozenset(dag.parents_of(i)) for setting_num, setting in enumerate(setting_list): if not invariance_tester.is_invariant(i, context=setting_num, cond_set=parents_i): variants.add((setting_num, i, parents_i)) return variants def _reverse_arc_igsp(dag, i_covered_arcs, i, j): """ Return the DAG that comes from reversing the arc i->j, as well as its I-covered arcs and its score """ new_dag = dag.copy() parents = dag.parents_of(i) new_dag.reverse_arc(i, j) if parents: for parent in parents: rest = parents - {parent} if ci_tester.is_ci(i, parent, [*rest, j]): new_dag.remove_arc(parent, i) if ci_tester.is_ci(j, parent, cond_set=[*rest]): new_dag.remove_arc(parent, j) # new_i_covered_arcs = i_covered_arcs.copy() - dag.incident_arcs(i) - dag.incident_arcs(j) # for k, l in new_dag.incident_arcs(i) | new_dag.incident_arcs(j): # if new_dag.parents_of(k) == new_dag.parents_of(l) - {k} and _is_icovered(i, j, dag): # new_i_covered_arcs.add((k, l)) new_covered_arcs = new_dag.reversible_arcs() new_i_covered_arcs = [(i, j) for i, j in new_covered_arcs if _is_icovered(i, j, new_dag)] variants = _get_variants(new_dag) new_score = len(new_dag.arcs) + len(variants) if not tup_score else (len(new_dag.arcs), len(variants)) intervention_targets = [set() for _ in range(len(setting_list))] for setting_num, i, parents_i in variants: intervention_targets[setting_num].add(i) return new_dag, new_i_covered_arcs, new_score, intervention_targets # === MINIMUM DAG AND SCORE FOUND BY ANY RUN min_dag = None min_score = float('inf') if not tup_score else (float('inf'), float('inf')) learned_intervention_targets = None if initial_permutations is None and isinstance(initial_undirected, str): if initial_undirected == 'threshold': initial_undirected = threshold_ug(nodes, ci_tester) else: raise ValueError("initial_undirected must be one of 'threshold', or an UndirectedGraph") # === MULTIPLE RUNS for r in range(nruns): # === STARTING VALUES if initial_permutations is not None: starting_perm = initial_permutations[r] elif initial_undirected: starting_perm = min_degree_alg_amat(initial_undirected.to_amat()) else: starting_perm = random.sample(nodes, len(nodes)) current_dag = permutation2dag(starting_perm, ci_tester) variants = _get_variants(current_dag) current_intervention_targets = [set() for _ in range(len(setting_list))] for setting_num, i, parents_i in variants: current_intervention_targets[setting_num].add(i) current_score = len(current_dag.arcs) + len(variants) if not tup_score else ( len(current_dag.arcs), len(variants)) if verbose: print("=== STARTING DAG:", current_dag, "== SCORE:", current_score) current_covered_arcs = current_dag.reversible_arcs() current_i_covered_arcs = [(i, j) for i, j in current_covered_arcs if _is_icovered(i, j, current_dag)] if verbose: print("=== STARTING I-COVERED ARCS:", current_i_covered_arcs) next_dags = [_reverse_arc_igsp(current_dag, current_i_covered_arcs, i, j) for i, j in current_i_covered_arcs] next_dags = [ (d, i_cov_arcs, score, iv_targets) for d, i_cov_arcs, score, iv_targets in next_dags if score <= current_score ] random.shuffle(next_dags) # === RECORDS FOR DEPTH-FIRST SEARCH all_visited_dags = set() trace = [] # === SEARCH! while True: if verbose: print('-' * len(trace), current_dag, '(%d arcs)' % len(current_dag.arcs), 'I-covered arcs:', current_i_covered_arcs, 'score:', current_score) all_visited_dags.add(frozenset(current_dag.arcs)) lower_dags = [ (d, i_cov_arcs, score, iv_targets) for d, i_cov_arcs, score, iv_targets in next_dags if score < current_score ] if (len(next_dags) > 0 and len(trace) != depth) or len(lower_dags) > 0: if len(lower_dags) > 0: # restart at a lower DAG all_visited_dags = set() trace = [] lowest_ix = min(enumerate(lower_dags), key=lambda x: x[1][2])[0] if use_lowest else 0 current_dag, current_i_covered_arcs, current_score, current_intervention_targets = lower_dags.pop( lowest_ix) if verbose: print("FOUND DAG WITH LOWER SCORE:", current_dag, "== SCORE:", current_score) if verbose: print(f"Current intervention targets: {current_intervention_targets}") else: trace.append((current_dag, current_i_covered_arcs, next_dags, current_intervention_targets)) current_dag, current_i_covered_arcs, current_score, current_intervention_targets = next_dags.pop() next_dags = [ _reverse_arc_igsp(current_dag, current_i_covered_arcs, i, j) for i, j in current_i_covered_arcs ] next_dags = [ (d, i_cov_arcs, score, iv_targets) for d, i_cov_arcs, score, iv_targets in next_dags if score <= current_score ] next_dags = [ (d, i_cov_arcs, score, iv_targets) for d, i_cov_arcs, score, iv_targets in next_dags if frozenset(d.arcs) not in all_visited_dags ] random.shuffle(next_dags) # === DEAD END === else: if len(trace) == 0: # reached minimum within search depth break else: # backtrack current_dag, current_i_covered_arcs, next_dags, current_intervention_targets = trace.pop() if min_dag is None or current_score < min_score: min_dag = current_dag min_score = current_score learned_intervention_targets = current_intervention_targets if verbose: print("=== FINISHED RUN %s/%s ===" % (r + 1, nruns)) return min_dag, learned_intervention_targets if __name__ == '__main__': import causaldag as cd from causaldag.utils.ci_tests.ci_tester import MemoizedCI_Tester from causaldag.utils.ci_tests.oracle import dsep_test p = 10 d = cd.rand.directed_erdos(p, .2) ci_tester = MemoizedCI_Tester(dsep_test, d) est_dag = gsp(set(range(p)), ci_tester, nruns=1, depth=float('inf')) print(est_dag.shd_skeleton(d))