# Author: Chandler Squires
"""
Base class for partially directed acyclic graphs
"""
from collections import defaultdict
from causaldag.utils import core_utils
import itertools as itr
import numpy as np
from typing import Set
from collections import namedtuple
from scipy.special import factorial
import networkx as nx
from typing import Set, FrozenSet, Iterable
import csv
# Lightweight record used by PDAG.all_dags while enumerating the DAGs of an MEC.
SmallDag = namedtuple('SmallDag', 'arcs reversible_arcs parents_dict children_dict level')
class PDAG:
    """Partially directed acyclic graph: a graph containing both directed arcs and undirected edges.

    Arcs are stored as ``(source, target)`` tuples in ``_arcs``; undirected
    edges are stored as two-element ``frozenset`` objects in ``_edges``.
    Adjacency is cached in ``_parents``, ``_children``, ``_neighbors`` (any
    adjacency) and ``_undirected_neighbors`` (undirected edges only).
    """

    def __init__(
            self,
            nodes: Set = frozenset(),
            arcs: Set = frozenset(),
            edges: Set = frozenset(),
            known_arcs=frozenset(),
            new=False
    ):
        """Build a PDAG.

        Parameters
        ----------
        nodes:
            Nodes of the graph. Endpoints of ``arcs``/``edges`` are added automatically.
        arcs:
            Iterable of ``(i, j)`` pairs, each meaning a directed arc i->j.
        edges:
            Iterable of pairs, each meaning an undirected edge i--j.
        known_arcs:
            Arcs whose orientation is fixed a priori. These are treated as
            protected by ``remove_unprotected_orientations``.
        new:
            If True, insert arcs/edges in bulk rather than one at a time.
        """
        # Defaults are immutable and every argument is copied, so caller-owned
        # containers are never shared with (or mutated through) the graph.
        self._nodes = set(nodes)
        self._arcs = set()
        self._edges = set()
        self._parents = defaultdict(set)
        self._children = defaultdict(set)
        self._neighbors = defaultdict(set)
        self._undirected_neighbors = defaultdict(set)
        if new:  # for some reason this is slower than the old way. memory?
            self._add_arcs_from(arcs)
            self._add_edges_from(edges)
        else:
            for arc in arcs:
                self._add_arc(*arc)
            for edge in edges:
                self._add_edge(*edge)
        self._known_arcs = set(known_arcs)

    @staticmethod
    def _arcs_edges_from_amat(amat, source_axis=0):
        """Split a square adjacency matrix into (arcs, edges) over integer indices.

        Entries nonzero in both directions become undirected edges. Entries
        nonzero in one direction only become arcs: rows are sources when
        ``source_axis == 0``, and columns are sources otherwise.
        """
        arcs = set()
        edges = set()
        for i, j in zip(*np.triu_indices_from(amat, k=1)):
            i, j = int(i), int(j)
            forward = amat[i, j] != 0
            backward = amat[j, i] != 0
            if forward and backward:
                edges.add((i, j))
            elif forward:
                arcs.add((i, j) if source_axis == 0 else (j, i))
            elif backward:
                arcs.add((j, i) if source_axis == 0 else (i, j))
        return arcs, edges

    @classmethod
    def from_df(cls, df, source_axis=0):
        """Return a PDAG built from a square adjacency DataFrame.

        The row labels (``df.index``) are used as node names. See
        ``from_amat`` for how entries are interpreted.
        """
        arcs, edges = cls._arcs_edges_from_amat(df.values, source_axis)
        name_map = dict(enumerate(df.index))
        arcs = {(name_map[i], name_map[j]) for i, j in arcs}
        edges = {(name_map[i], name_map[j]) for i, j in edges}
        return PDAG(set(df.index), arcs, edges)

    @classmethod
    def from_sparse(cls, sparse_amat, source_axis=0):
        raise NotImplementedError

    @classmethod
    def from_csv(cls, filename):
        raise NotImplementedError

    @classmethod
    def from_amat(cls, amat: np.ndarray, source_axis=0):
        """Return a PDAG with arcs/edges given by ``amat``, on nodes ``0..nrows-1``.

        If both ``amat[i, j]`` and ``amat[j, i]`` are nonzero, the result has an
        undirected edge i--j. If only one is nonzero, it has an arc. With
        ``source_axis=0``, ``amat[i, j] != 0`` means i->j. With
        ``source_axis=1``, it means j->i.
        """
        nrows, _ = amat.shape
        arcs, edges = cls._arcs_edges_from_amat(amat, source_axis)
        return PDAG(set(range(nrows)), arcs, edges)

    @classmethod
    def from_nx(cls, nx_graph):
        """Return a PDAG whose undirected edges are the edges of ``nx_graph``."""
        return PDAG(nodes=nx_graph.nodes, edges=nx_graph.edges)

    # CONVERTERS
    def to_nx(self):
        """Return the graph as an undirected ``networkx.Graph``, including isolated nodes.

        Raises NotImplementedError if the PDAG has any arcs.
        """
        if self._arcs:
            raise NotImplementedError
        g = nx.Graph()
        g.add_nodes_from(self._nodes)
        g.add_edges_from(self._edges)
        return g

    def to_csv(self, filename):
        """Write one ``source,target`` row per arc and two rows (both directions) per edge."""
        # newline='' is what the csv module requires; the writer emits its own terminators.
        with open(filename, 'w', newline='') as file:
            writer = csv.writer(file)
            for source, target in self._arcs:
                writer.writerow([source, target])
            for node1, node2 in self._edges:
                writer.writerow([node1, node2])
                writer.writerow([node2, node1])

    def _fill_amat(self, amat, node_list, source_axis=0):
        """Write 1s into ``amat`` (dense or sparse) for every arc and both directions of every edge."""
        node2ix = {node: ix for ix, node in enumerate(node_list)}
        for source, target in self._arcs:
            if source_axis == 0:
                amat[node2ix[source], node2ix[target]] = 1
            else:
                amat[node2ix[target], node2ix[source]] = 1
        for i, j in self._edges:
            amat[node2ix[i], node2ix[j]] = 1
            amat[node2ix[j], node2ix[i]] = 1
        return amat

    def to_df(self, node_list=None, source_axis=0):
        """Return the adjacency matrix as a pandas DataFrame indexed by ``node_list`` (sorted nodes by default)."""
        if node_list is None:
            node_list = sorted(self._nodes)
        amat = np.zeros((len(node_list), len(node_list)), dtype=int)
        self._fill_amat(amat, node_list, source_axis)
        from pandas import DataFrame
        return DataFrame(amat, index=node_list, columns=node_list)

    def to_sparse(self, node_list: list = None, source_axis=0):
        """Return ``(scipy.sparse.lil_matrix, node_list)`` for the adjacency matrix."""
        from scipy.sparse import lil_matrix
        if node_list is None:
            node_list = sorted(self._nodes)
        amat = lil_matrix((len(node_list), len(node_list)), dtype=int)
        self._fill_amat(amat, node_list, source_axis)
        return amat, node_list

    def to_amat(self, node_list: list = None, source_axis=0) -> (np.ndarray, list):
        """Return ``(adjacency matrix, node_list)`` for the graph.

        Edges set both symmetric entries. Arcs set ``amat[source, target]`` when
        ``source_axis == 0``, and ``amat[target, source]`` otherwise.
        """
        if node_list is None:
            node_list = sorted(self._nodes)
        amat = np.zeros((len(node_list), len(node_list)), dtype=int)
        self._fill_amat(amat, node_list, source_axis)
        return amat, node_list

    def __eq__(self, other):
        try:
            same_nodes = self._nodes == other._nodes
            same_arcs = self._arcs == other._arcs
            same_edges = self._edges == other._edges
        except AttributeError:
            # Not graph-like: let Python fall back to the other operand / identity.
            return NotImplemented
        return same_nodes and same_arcs and same_edges

    def __str__(self):
        substrings = []
        for node in self._nodes:
            parents = self._parents[node]
            nbrs = self._undirected_neighbors[node]
            parents_str = ','.join(map(str, parents)) if len(parents) != 0 else ''
            nbrs_str = ','.join(map(str, nbrs)) if len(nbrs) != 0 else ''
            if len(parents) == 0 and len(nbrs) == 0:
                substrings.append('[{node}]'.format(node=node))
            else:
                substrings.append('[{node}|{parents}:{nbrs}]'.format(node=node, parents=parents_str, nbrs=nbrs_str))
        return ''.join(substrings)

    def copy(self):
        """Return a copy of the graph (nodes, arcs, edges and known arcs)."""
        return PDAG(nodes=self._nodes, arcs=self._arcs, edges=self._edges, known_arcs=self._known_arcs)

    def rename_nodes(self, name_map):
        """Return a new PDAG with every node ``n`` replaced by ``name_map[n]``."""
        return PDAG(
            nodes={name_map[n] for n in self._nodes},
            arcs={(name_map[i], name_map[j]) for i, j in self._arcs},
            edges={(name_map[i], name_map[j]) for i, j in self._edges},
            known_arcs={(name_map[i], name_map[j]) for i, j in self._known_arcs},
        )

    # === PROPERTIES
    @property
    def nodes(self):
        return set(self._nodes)

    @property
    def nnodes(self):
        return len(self._nodes)

    @property
    def num_arcs(self):
        return len(self._arcs)

    @property
    def num_edges(self):
        return len(self._edges)

    @property
    def num_adjacencies(self):
        return self.num_arcs + self.num_edges

    @property
    def arcs(self):
        return set(self._arcs)

    @property
    def edges(self):
        return set(self._edges)

    @property
    def parents(self):
        return core_utils.defdict2dict(self._parents, self._nodes)

    @property
    def children(self):
        return core_utils.defdict2dict(self._children, self._nodes)

    @property
    def neighbors(self):
        return core_utils.defdict2dict(self._neighbors, self._nodes)

    @property
    def undirected_neighbors(self):
        return core_utils.defdict2dict(self._undirected_neighbors, self._nodes)

    @property
    def skeleton(self):
        """Set of unordered adjacencies (arcs and edges) as frozensets."""
        return {frozenset({i, j}) for i, j in self._arcs | self._edges}

    @property
    def dominated_nodes(self):
        """Nodes with no undirected neighbours, or with exactly one neighbour that itself has more than one."""
        dominated_nodes = set()
        for node in self._nodes:
            num_nbrs = self.undirected_degree_of(node)
            if num_nbrs == 0:
                dominated_nodes.add(node)
            elif num_nbrs == 1:
                max_nbrs_of_nbrs = max(len(self._undirected_neighbors[nbr]) for nbr in self._undirected_neighbors[node])
                if max_nbrs_of_nbrs > 1:
                    dominated_nodes.add(node)
        return dominated_nodes

    def clique_size(self):
        """Size of the largest clique of undirected edges (requires a chordal undirected part).

        A graph whose chain components are all single nodes has clique size 1.
        """
        if not self._arcs:
            return nx.chordal_graph_treewidth(self.to_nx()) + 1
        return max((cc.clique_size() for cc in self.chain_components()), default=1)

    def max_cliques(self):
        """Set of maximal cliques (frozensets) of the undirected part of the graph."""
        if not self._arcs:
            return {frozenset(c) for c in nx.chordal_graph_cliques(self.to_nx())}
        # set().union(...) stays valid when there are no chain components.
        return set().union(*(cc.max_cliques() for cc in self.chain_components()))

    # === PROPERTIES W/ ARGUMENTS
    def indegree_of(self, node):
        return len(self._parents[node])

    def outdegree_of(self, node):
        return len(self._children[node])

    def undirected_degree_of(self, node):
        return len(self._undirected_neighbors[node])

    def total_degree_of(self, node):
        return len(self._neighbors[node])

    def parents_of(self, node):
        return set(self._parents[node])

    def children_of(self, node):
        return set(self._children[node])

    def neighbors_of(self, node):
        return set(self._neighbors[node])

    def undirected_neighbors_of(self, node):
        return set(self._undirected_neighbors[node])

    def has_edge(self, i, j):
        """Return True if the graph contains the edge i--j."""
        return frozenset({i, j}) in self._edges

    def has_arc(self, i, j):
        """Return True if the graph contains the arc i->j."""
        return (i, j) in self._arcs

    def has_edge_or_arc(self, i, j):
        """Return True if the graph contains the edge i--j or an arc i->j or i<-j."""
        return (i, j) in self._arcs or (j, i) in self._arcs or self.has_edge(i, j)

    def vstructs(self):
        """Return the arcs that participate in a v-structure p1 -> node <- p2 with p1, p2 non-adjacent."""
        vstructs = set()
        for node in self._nodes:
            for p1, p2 in itr.combinations(self._parents[node], 2):
                # Non-adjacency must exclude undirected edges too, not just arcs.
                if p2 not in self._neighbors[p1]:
                    vstructs.add((p1, node))
                    vstructs.add((p2, node))
        return vstructs

    def _undirected_reachable(self, node, tmp, visited):
        """Add to ``tmp`` and ``visited`` every node reachable from ``node`` via undirected edges; return ``tmp``.

        Iterative, so long undirected paths do not hit the recursion limit.
        """
        visited.add(node)
        tmp.add(node)
        stack = [node]
        while stack:
            current = stack.pop()
            for nbr in self._undirected_neighbors[current]:
                if nbr not in visited:
                    visited.add(nbr)
                    tmp.add(nbr)
                    stack.append(nbr)
        return tmp

    def chain_components(self, rename=False):
        """Return the chain components of this graph.

        Return
        ------
        List[PDAG]
            One induced subgraph per class of the "reachable by undirected
            edges" relation. Only classes with more than one node are returned.
            If ``rename`` is True, each subgraph's nodes are relabelled
            ``0..k-1``.
        """
        node_queue = self._nodes.copy()
        components = []
        visited_nodes = set()
        while node_queue:
            node = node_queue.pop()
            if node not in visited_nodes:
                reachable = self._undirected_reachable(node, set(), visited_nodes)
                if len(reachable) > 1:
                    components.append(reachable)
        return [self.induced_subgraph(c, rename=rename) for c in components]

    def induced_subgraph(self, nodes, rename=False):
        """Return the PDAG induced on ``nodes``, optionally relabelling them ``0..k-1`` in iteration order."""
        if rename:
            ixs = dict(map(reversed, enumerate(nodes)))
            new_nodes = set(range(len(nodes)))
            arcs = {(ixs[i], ixs[j]) for i, j in self._arcs if i in nodes and j in nodes}
            edges = {(ixs[i], ixs[j]) for i, j in self._edges if i in nodes and j in nodes}
        else:
            new_nodes = nodes
            arcs = {(i, j) for i, j in self._arcs if i in nodes and j in nodes}
            edges = {(i, j) for i, j in self._edges if i in nodes and j in nodes}
        return PDAG(nodes=new_nodes, arcs=arcs, edges=edges)

    def interventional_cpdag(self, dag, intervened_nodes):
        """Return a new PDAG in which arcs incident to ``intervened_nodes`` in ``dag`` are oriented, then Meek-closed.

        Assumes ``dag.incident_arcs(node)`` returns ``(i, j)`` arc tuples (not visible here).
        """
        cut_edges = set()
        for node in intervened_nodes:
            cut_edges.update(dag.incident_arcs(node))
        p = PDAG(self._nodes, self._arcs | cut_edges, self._edges - {frozenset({i, j}) for i, j in cut_edges})
        p.to_complete_pdag()
        return p

    # === MUTATORS
    def _add_arc(self, i, j):
        self._nodes.add(i)
        self._nodes.add(j)
        self._arcs.add((i, j))
        self._neighbors[i].add(j)
        self._neighbors[j].add(i)
        self._children[i].add(j)
        self._parents[j].add(i)

    def _add_arcs_from(self, arcs):
        # Materialize first: the input is traversed more than once below.
        arcs = list(arcs)
        if not arcs:
            return
        sources, sinks = zip(*arcs)
        self._nodes.update(sources)
        self._nodes.update(sinks)
        self._arcs.update(arcs)
        for i, j in arcs:
            self._neighbors[i].add(j)
            self._neighbors[j].add(i)
            self._children[i].add(j)
            self._parents[j].add(i)

    def _add_edges_from(self, edges):
        # Materialize first: the input is traversed more than once below.
        edges = list(edges)
        if not edges:
            return
        s1, s2 = zip(*edges)
        self._nodes.update(s1)
        self._nodes.update(s2)
        self._edges.update(map(frozenset, edges))
        for i, j in edges:
            self._undirected_neighbors[i].add(j)
            self._undirected_neighbors[j].add(i)
            self._neighbors[i].add(j)
            self._neighbors[j].add(i)

    def _add_edge(self, i, j):
        self._nodes.add(i)
        self._nodes.add(j)
        self._edges.add(frozenset({i, j}))
        self._neighbors[i].add(j)
        self._neighbors[j].add(i)
        self._undirected_neighbors[i].add(j)
        self._undirected_neighbors[j].add(i)

    def remove_edge(self, i, j, ignore_error=False):
        """Remove the undirected edge i--j; raise KeyError if absent unless ``ignore_error``."""
        try:
            self._edges.remove(frozenset({i, j}))
            self._neighbors[i].remove(j)
            self._neighbors[j].remove(i)
            self._undirected_neighbors[i].remove(j)
            self._undirected_neighbors[j].remove(i)
        except KeyError as e:
            if ignore_error:
                pass
            else:
                raise e

    def remove_edges_from(self, edges):
        for i, j in edges:
            self.remove_edge(i, j)

    def remove_arc(self, i, j, ignore_error=False):
        """Remove the arc i->j; raise KeyError if absent unless ``ignore_error``."""
        try:
            self._arcs.remove((i, j))
            self._children[i].remove(j)
            self._parents[j].remove(i)
            self._neighbors[i].remove(j)
            self._neighbors[j].remove(i)
        except KeyError as e:
            if ignore_error:
                pass
            else:
                raise e

    def remove_arcs_from(self, arcs):
        for i, j in arcs:
            self.remove_arc(i, j)

    def remove_node(self, node):
        """Remove a node and every arc/edge incident to it."""
        self._nodes.remove(node)
        self._arcs = {(i, j) for i, j in self._arcs if i != node and j != node}
        self._edges = {frozenset({i, j}) for i, j in self._edges if i != node and j != node}
        for child in self._children[node]:
            self._parents[child].remove(node)
            self._neighbors[child].remove(node)
        for parent in self._parents[node]:
            self._children[parent].remove(node)
            self._neighbors[parent].remove(node)
        for u_nbr in self._undirected_neighbors[node]:
            self._undirected_neighbors[u_nbr].remove(node)
            self._neighbors[u_nbr].remove(node)
        del self._parents[node]
        del self._children[node]
        del self._neighbors[node]
        del self._undirected_neighbors[node]

    def remove_nodes_from(self, nodes):
        for node in nodes:
            self.remove_node(node)

    def remove_all_arcs(self):
        self.remove_arcs_from(set(self._arcs))

    def replace_edge_with_arc(self, arc, ignore_error=False):
        """Orient the edge ``{i, j}`` as ``arc == (i, j)``; KeyError if absent unless ``ignore_error``."""
        try:
            self._replace_edge_with_arc(arc)
        except KeyError as e:
            if ignore_error:
                pass
            else:
                raise e

    def _replace_arc_with_edge(self, arc):
        self._arcs.remove(arc)
        self._edges.add(frozenset({*arc}))
        i, j = arc
        self._parents[j].remove(i)
        self._children[i].remove(j)
        self._undirected_neighbors[i].add(j)
        self._undirected_neighbors[j].add(i)

    def _replace_edge_with_arc(self, arc):
        self._edges.remove(frozenset({*arc}))
        self._arcs.add(arc)
        i, j = arc
        self._parents[j].add(i)
        self._children[i].add(j)
        self._undirected_neighbors[i].remove(j)
        self._undirected_neighbors[j].remove(i)

    def assign_parents(self, node, parents, verbose=False):
        """Orient p -> node for each p in ``parents``, orient all other undirected edges away from ``node``, then Meek-close."""
        parents = set(parents)
        for p in parents:
            self._replace_edge_with_arc((p, node))
        for c in self._undirected_neighbors[node] - parents:
            self._replace_edge_with_arc((node, c))
        self.to_complete_pdag(verbose=verbose)

    def to_complete_pdag_new(self, verbose=False):
        """Experimental set-based version of ``to_complete_pdag``: orient edges by Meek rules 1-3 in rounds.

        Existing arcs are treated as protected. Each round orients every
        undirected edge that matches:
        - a chain k->i--j with k, j non-adjacent,
        - a cycle i->k->j with i--j,
        - or Meek's rule 3 pattern.
        Rounds repeat until nothing changes, and the orientations are then
        applied to ``self``. ``verbose`` is currently unused.
        """
        protected_parents = defaultdict(set)
        protected_children = defaultdict(set)
        for i, j in self._arcs:
            protected_parents[j].add(i)
            protected_children[i].add(j)
        undecided_edges = {tuple(edge) for edge in self._edges}
        # Deep-copy the neighbour sets: mutating them below must not touch self.
        neighbors = defaultdict(set)
        for node, nbrs in self._undirected_neighbors.items():
            neighbors[node] = set(nbrs)
        oriented = set()

        def flipped(pairs):
            # Map each oriented (j, i) back to the undecided tuple (i, j) it came from.
            return {(b, a) for a, b in pairs}

        while True:
            chain_arcs1 = {(i, j) for i, j in undecided_edges if protected_parents[i] - self._neighbors[j]}
            undecided_edges -= chain_arcs1
            chain_arcs2 = {(j, i) for i, j in undecided_edges if protected_parents[j] - self._neighbors[i]}
            undecided_edges -= flipped(chain_arcs2)
            cycle_arcs1 = {(i, j) for i, j in undecided_edges if protected_children[i] & protected_parents[j]}
            undecided_edges -= cycle_arcs1
            cycle_arcs2 = {(j, i) for i, j in undecided_edges if protected_children[j] & protected_parents[i]}
            undecided_edges -= flipped(cycle_arcs2)
            a1 = {
                (i, j) for i, j in undecided_edges
                if any((not self.has_edge_or_arc(k1, k2)) for k1, k2 in
                       itr.combinations(neighbors[i] & protected_parents[j], 2))
            }
            undecided_edges -= a1
            a2 = {
                (j, i) for i, j in undecided_edges
                if any((not self.has_edge_or_arc(k1, k2)) for k1, k2 in
                       itr.combinations(neighbors[j] & protected_parents[i], 2))
            }
            undecided_edges -= flipped(a2)
            new_arcs = chain_arcs1 | chain_arcs2 | cycle_arcs1 | cycle_arcs2 | a1 | a2
            if not new_arcs:
                break
            for i, j in new_arcs:
                protected_children[i].add(j)
                protected_parents[j].add(i)
                neighbors[i].discard(j)
                neighbors[j].discard(i)
            oriented |= new_arcs
        for arc in oriented:
            self._replace_edge_with_arc(arc)

    def to_complete_pdag(self, verbose=False, solve_conflict=False):
        """
        Replace with arcs those edges whose orientations can be determined by Meek rules:
        =====
        See Koller & Friedman, Algorithm 3.5

        Every existing arc is treated as PROTECTED. Both directions of each
        undirected edge are tested against configurations (a) causal chain,
        (b) acyclicity and (d) in repeated passes. The loop stops when a pass
        marks every remaining candidate NOT_PROTECTED.
        ``solve_conflict=True`` is not implemented.
        """
        if solve_conflict:
            raise NotImplementedError
        PROTECTED = 'P'  # indicates that some configuration definitely exists to protect the edge
        UNDECIDED = 'U'  # indicates that some configuration exists that could protect the edge
        NOT_PROTECTED = 'N'  # indicates no possible configuration that could protect the edge
        edges1 = {(i, j) for i, j in self._edges}
        undecided_arcs = edges1 | {(j, i) for i, j in edges1}
        arc_flags = {arc: PROTECTED for arc in self._arcs}
        arc_flags.update({arc: UNDECIDED for arc in undecided_arcs})
        while undecided_arcs:
            for arc in undecided_arcs:
                i, j = arc
                flag = NOT_PROTECTED
                # check configuration (a) -- causal chain
                s = ''
                for k in self._parents[i]:
                    if not self.has_edge_or_arc(k, j):
                        if arc_flags[(k, i)] == PROTECTED:
                            flag = PROTECTED
                            s = f': {k}->{i}-{j}'
                            break
                        else:
                            flag = UNDECIDED
                if verbose: print(f'{arc} marked {flag} by (a){s}')
                # check configuration (b) -- acyclicity
                s = ''
                if flag != PROTECTED:
                    for k in self._parents[j]:
                        if i in self._parents[k]:
                            if arc_flags[(i, k)] == PROTECTED and arc_flags[(k, j)] == PROTECTED:
                                flag = PROTECTED
                                s = f': {k}->{j}-{i}->{k}'
                                break
                            else:
                                flag = UNDECIDED
                    if verbose: print(f'{arc} marked {flag} by (b){s}')
                # check configuration (d)
                s = ''
                if flag != PROTECTED:
                    for k1, k2 in itr.combinations(self._parents[j], 2):
                        if self.has_edge(i, k1) and self.has_edge(i, k2) and not self.has_edge_or_arc(k1, k2):
                            if arc_flags[(k1, j)] == PROTECTED and arc_flags[(k2, j)] == PROTECTED:
                                flag = PROTECTED
                                s = f': {i}-{k1}->{j}<-{k2}-{i}'
                                break
                            else:
                                flag = UNDECIDED
                    if verbose: print(f'{arc} marked {flag} by (c){s}')
                arc_flags[arc] = flag
            if all(arc_flags[arc] == NOT_PROTECTED for arc in undecided_arcs): break
            for arc in undecided_arcs.copy():
                if arc_flags[arc] == PROTECTED:
                    if not solve_conflict:
                        if self.has_arc(arc[1], arc[0]):  # arc has already been oriented the opposite way
                            continue
                    undecided_arcs.remove(arc)
                    undecided_arcs.remove((arc[1], arc[0]))
                    self._replace_edge_with_arc(arc)

    def remove_unprotected_orientations(self, verbose=False):
        """
        Replace with edges those arcs whose orientations cannot be determined by either:
        - prior knowledge, or
        - Meek rules
        =====
        See Koller & Friedman, Algorithm 3.5
        """
        PROTECTED = 'P'  # indicates that some configuration definitely exists to protect the edge
        UNDECIDED = 'U'  # indicates that some configuration exists that could protect the edge
        NOT_PROTECTED = 'N'  # indicates no possible configuration that could protect the edge
        undecided_arcs = self._arcs - self._known_arcs
        arc_flags = {arc: PROTECTED for arc in self._known_arcs}
        arc_flags.update({arc: UNDECIDED for arc in undecided_arcs})
        while undecided_arcs:
            for arc in undecided_arcs:
                i, j = arc
                flag = NOT_PROTECTED
                # check configuration (a) -- causal chain
                for k in self._parents[i]:
                    if not self.has_edge_or_arc(k, j):
                        if arc_flags[(k, i)] == PROTECTED:
                            flag = PROTECTED
                            break
                        else:
                            flag = UNDECIDED
                if verbose: print('{edge} marked {flag} by (a)'.format(edge=arc, flag=flag))
                # check configuration (b) -- acyclicity
                if flag != PROTECTED:
                    for k in self._parents[j]:
                        if i in self._parents[k]:
                            if arc_flags[(i, k)] == PROTECTED and arc_flags[(k, j)] == PROTECTED:
                                flag = PROTECTED
                                break
                            else:
                                flag = UNDECIDED
                    if verbose: print('{edge} marked {flag} by (b)'.format(edge=arc, flag=flag))
                # check configuration (d)
                if flag != PROTECTED:
                    for k1, k2 in itr.combinations(self._parents[j], 2):
                        if self.has_edge(i, k1) and self.has_edge(i, k2) and not self.has_edge_or_arc(k1, k2):
                            if arc_flags[(k1, j)] == PROTECTED and arc_flags[(k2, j)] == PROTECTED:
                                flag = PROTECTED
                                # Stop here so a later combination cannot downgrade PROTECTED.
                                break
                            else:
                                flag = UNDECIDED
                    if verbose: print('{edge} marked {flag} by (c)'.format(edge=arc, flag=flag))
                arc_flags[arc] = flag
            resolved = {arc for arc in undecided_arcs if arc_flags[arc] != UNDECIDED}
            if not resolved:
                # A pass that resolves nothing would be repeated identically forever;
                # leave the remaining (still undecided) arcs oriented.
                break
            for arc in resolved:
                undecided_arcs.remove(arc)
                if arc_flags[arc] == NOT_PROTECTED:
                    self._replace_arc_with_edge(arc)

    def add_known_arc(self, i, j):
        """Record i->j as known a priori, orient it, and drop orientations that are no longer protected.

        Raises KeyError if the graph has neither the arc i->j nor the edge i--j.
        """
        if (i, j) in self._known_arcs:
            return
        if not self.has_arc(i, j):
            # Keeps arcs/parents/children/neighbour caches consistent.
            self._replace_edge_with_arc((i, j))
        self._known_arcs.add((i, j))
        self.remove_unprotected_orientations()

    def add_known_arcs(self, arcs):
        raise NotImplementedError

    # === MUTATORS
    def _possible_sinks(self):
        return {node for node in self._nodes if len(self._children[node]) == 0}

    def _neighbors_covered(self, node):
        """Map each node to whether its neighbours, minus ``node``, equal the neighbours of ``node``."""
        node_nbrs = self._neighbors[node]
        return {node2: self._neighbors[node2] - {node} == node_nbrs for node2 in self._nodes}

    def to_dag(self):
        """
        Return a DAG that is consistent with this CPDAG.

        It repeatedly removes a node with no children whose undirected
        neighbours are adjacent to all of its other neighbours, orienting
        every remaining adjacency into it. Isolated nodes are kept. If no such
        node exists before all adjacencies are consumed, the arcs found so far
        are returned.

        Returns
        -------
        causaldag.DAG
        """
        from causaldag import DAG
        pdag2 = self.copy()
        arcs = set()
        while pdag2._edges or pdag2._arcs:
            is_sink = lambda n: len(pdag2._children[n]) == 0
            no_vstructs = lambda n: all(
                (pdag2._neighbors[n] - {u_nbr}).issubset(pdag2._neighbors[u_nbr])
                for u_nbr in pdag2._undirected_neighbors[n]
            )
            sink = next((n for n in pdag2._nodes if is_sink(n) and no_vstructs(n)), None)
            if sink is None:
                break
            arcs.update((nbr, sink) for nbr in pdag2._neighbors[sink])
            pdag2.remove_node(sink)
        return DAG(nodes=set(self._nodes), arcs=arcs)

    # === MEC
    def mec_size(self):
        """Return the number of DAGs in the MEC represented by this PDAG.

        With arcs present, this enumerates ``all_dags``. Otherwise it multiplies
        the counts of the connected undirected components. A component with
        n nodes counts as:
        - n if it is a tree,
        - 2n if it is unicyclic,
        - n! if it is complete,
        - the ``all_dags`` count otherwise.
        """
        if self.num_arcs > 0:
            return len(self.all_dags())
        total = 1
        for component in self.chain_components():
            n, m = component.nnodes, component.num_edges
            if m == n - 1:
                total *= n
            elif m == n:
                total *= 2 * n
            elif m == n * (n - 1) // 2:
                total *= int(factorial(n, exact=True))
            else:
                total *= len(component.all_dags())
        return total

    def exact_sample(self, save_sampler=True, nsamples=1):
        """Return a DAG sampled uniformly at random from the MEC represented by this PDAG."""
        raise NotImplementedError

    def all_dags(self, verbose=False):
        """Return all DAGs consistent with this PDAG, each as a frozenset of arcs.

        Starts from ``to_dag()`` and explores by reversing covered
        ("reversible") arcs that are not arcs of this PDAG. ``verbose`` is unused.
        """
        dag = self.to_dag()
        arcs = dag._arcs
        all_arcs = set()
        orig_reversible_arcs = dag.reversible_arcs() - self._arcs
        orig_parents_dict = dag.parents
        orig_children_dict = dag.children
        level = 0
        q = [SmallDag(arcs, orig_reversible_arcs, orig_parents_dict, orig_children_dict, level)]
        while q:
            dag = q.pop()
            all_arcs.add(frozenset(dag.arcs))
            for i, j in dag.reversible_arcs:
                new_arcs = frozenset({arc for arc in dag.arcs if arc != (i, j)} | {(j, i)})
                if new_arcs not in all_arcs:
                    new_parents_dict = {}
                    new_children_dict = {}
                    for node in dag.parents_dict.keys():
                        parents = set(dag.parents_dict[node])
                        children = set(dag.children_dict[node])
                        if node == i:
                            new_parents_dict[node] = parents | {j}
                            new_children_dict[node] = children - {j}
                        elif node == j:
                            new_parents_dict[node] = parents - {i}
                            new_children_dict[node] = children | {i}
                        else:
                            new_parents_dict[node] = parents
                            new_children_dict[node] = children
                    # Only arcs touching i or j can change covered status after the reversal.
                    new_reversible_arcs = dag.reversible_arcs.copy()
                    for k in dag.parents_dict[j]:
                        if (new_parents_dict[j] - {k}) == new_parents_dict[k] and (k, j) not in self._arcs:
                            new_reversible_arcs.add((k, j))
                        else:
                            new_reversible_arcs.discard((k, j))
                    for k in dag.children_dict[j]:
                        if new_parents_dict[j] == (new_parents_dict[k] - {j}) and (j, k) not in self._arcs:
                            new_reversible_arcs.add((j, k))
                        else:
                            new_reversible_arcs.discard((j, k))
                    for k in dag.parents_dict[i]:
                        if (new_parents_dict[i] - {k}) == new_parents_dict[k] and (k, i) not in self._arcs:
                            new_reversible_arcs.add((k, i))
                        else:
                            new_reversible_arcs.discard((k, i))
                    for k in dag.children_dict[i]:
                        if new_parents_dict[i] == (new_parents_dict[k] - {i}) and (i, k) not in self._arcs:
                            new_reversible_arcs.add((i, k))
                        else:
                            new_reversible_arcs.discard((i, k))
                    q.append(
                        SmallDag(new_arcs, new_reversible_arcs, new_parents_dict, new_children_dict, dag.level + 1))
        return all_arcs

    def is_edge_clique(self, s):
        """
        Check if every pair of nodes in s is joined by an undirected edge.
        """
        return all(self.has_edge(i, j) for i, j in itr.combinations(s, 2))

    def possible_parents(self, node) -> Iterable:
        """Subsets of the undirected neighbours of ``node`` that form undirected cliques (via core_utils.powerset_predicate)."""
        return core_utils.powerset_predicate(self._undirected_neighbors[node], self.is_edge_clique)

    # === COMPARISON
    def shd(self, other):
        """Return the structural Hamming distance between this PDAG and another.

        For each pair of nodes, the SHD is incremented by 1 if the edge type/presence between the two nodes is different.
        """
        self_undirected = {frozenset({*arc}) for arc in self._arcs} | self._edges
        other_undirected = {frozenset({*arc}) for arc in other._arcs} | other._edges
        num_additions = len(self_undirected - other_undirected)
        num_deletions = len(other_undirected - self_undirected)
        diff_type = {
            (i, j) for i, j in self_undirected & other_undirected
            if ((i, j) in self._arcs and (i, j) not in other._arcs) or
               ((j, i) in self._arcs and (j, i) not in other._arcs) or
               (frozenset({i, j}) in self._edges and frozenset({i, j}) not in other._edges)
        }
        return num_additions + num_deletions + len(diff_type)

    def shd_skeleton(self, other) -> int:
        """Number of adjacencies present in exactly one of the two skeletons."""
        return len(self.skeleton.symmetric_difference(other.skeleton))
if __name__ == '__main__':
    # Smoke check: the dense and sparse adjacency exports of a random CPDAG agree.
    # (to_amat has no `mode` argument; sparse output comes from to_sparse.)
    from causaldag.rand import directed_erdos
    g = directed_erdos(10, .5)
    c = g.cpdag()
    a1, node_list = c.to_amat()
    a2, _ = c.to_amat(node_list=node_list)
    a3, _ = c.to_sparse(node_list=node_list)
    assert (a1 == a2).all() and (a3.toarray() == a1).all()