from collections import defaultdict
from causaldag.utils import core_utils
import itertools as itr
import numpy as np
import random
from typing import List, Iterable, Set, Dict, Hashable, Tuple, FrozenSet, Union
from causaldag.classes.custom_types import Node, DirectedEdge, BidirectedEdge, UndirectedEdge, NodeSet, warn_untested
class CycleError(Exception):
    """Raised when adding a directed edge would create a directed cycle."""

    def __init__(self):
        # No payload: the exception is raised bare and carries an empty message.
        super(CycleError, self).__init__()
class SpouseError(Exception):
    """Raised when the endpoints of a bidirected edge are related by ancestry."""

    def __init__(self):
        # No payload: the exception is raised bare and carries an empty message.
        super(SpouseError, self).__init__()
class AdjacentError(Exception):
    """Raised when an edge is added between two nodes that already share an edge."""

    def __init__(self, node1, node2, arrow_type):
        self.node1, self.node2, self.arrow_type = node1, node2, arrow_type
        template = '%s %s %s cannot be added since %s and %s are already adjacent'
        super().__init__(template % (node1, arrow_type, node2, node1, node2))
class NeighborError(Exception):
    """
    Raised when an edge would give a node both undirected edges and arrowheads.

    In an ancestral graph, a node with undirected neighbors may not have parents or spouses.
    The message names the first non-empty collection among ``neighbors``, ``parents`` and
    ``spouses``.
    """

    def __init__(self, node, neighbors=None, parents=None, spouses=None):
        self.node = node
        self.neighbors = neighbors
        self.parents = parents
        self.spouses = spouses
        if self.neighbors:
            message = 'The node %s has neighbors %s. Nodes cannot have neighbors and parents/spouses.' % (
                node, ','.join(map(str, neighbors)))
        elif self.parents:
            message = 'The node %s has parents %s. Nodes cannot have neighbors and parents/spouses.' % (
                node, ','.join(map(str, parents)))
        elif self.spouses:
            message = 'The node %s has spouses %s. Nodes cannot have neighbors and parents/spouses.' % (
                node, ','.join(map(str, spouses)))
        else:
            # Previously ``message`` was unbound here, turning the intended error into an
            # UnboundLocalError.
            message = 'The node %s cannot have both neighbors and parents/spouses.' % (node,)
        super().__init__(message)
def path2str(path):
    """Render a sequence of nodes as ``a->b->c`` using ``str`` on each node."""
    return '->'.join(str(step) for step in path)
[docs]class AncestralGraph:
"""
Base class for ancestral graphs, used to represent causal models with latent variables.
"""
def __init__(
        self,
        nodes: Set = frozenset(),
        directed: Set = frozenset(),
        bidirected: Set = frozenset(),
        undirected: Set = frozenset()
):
    """
    Create an ancestral graph.

    Parameters
    ----------
    nodes:
        Initial nodes (any iterable of hashables). Endpoints of the edges below are added automatically.
    directed:
        Pairs ``(i, j)``, each giving a directed edge ``i -> j``.
    bidirected:
        Two-element pairs/sets, each giving a bidirected edge ``i <-> j``.
    undirected:
        Two-element pairs/sets, each giving an undirected edge ``i - j``.

    Raises
    ------
    AdjacentError, NeighborError
        Propagated from the edge-insertion helpers. No cycle check is performed here.
    """
    # Always store a mutable set. ``nodes.copy()`` kept the default ``frozenset()``, which
    # has no ``add``, so ``add_node`` and every edge insertion failed on graphs built
    # without an explicit node set.
    self._nodes = set(nodes)
    self._directed = set()
    self._bidirected = set()
    self._undirected = set()
    self._neighbors = defaultdict(set)
    self._spouses = defaultdict(set)
    self._parents = defaultdict(set)
    self._children = defaultdict(set)
    self._adjacent = defaultdict(set)
    for i, j in directed:
        self._add_directed(i, j)
    for i, j in bidirected:
        self._add_bidirected(i, j)
    for i, j in undirected:
        self._add_undirected(i, j)
def __eq__(self, other):
    """Two ancestral graphs are equal when their node sets and all three edge sets match."""
    if isinstance(other, AncestralGraph):
        return (
            self._nodes == other._nodes
            and self._directed == other._directed
            and self._bidirected == other._bidirected
            and self._undirected == other._undirected
        )
    return False
def copy(self):
    """
    Return a copy of this ancestral graph.

    Returns
    -------
    AncestralGraph:
        A new graph built from copies of this graph's node and edge sets.
    """
    return AncestralGraph(
        nodes=self.nodes,
        directed=self.directed,
        bidirected=self.bidirected,
        undirected=self.undirected,
    )
def induced_subgraph(self, nodes: Set[Node]):
    """
    Return the induced subgraph over only ``nodes``.

    Every node in ``nodes`` is kept, together with the edges whose endpoints both lie in ``nodes``.

    Parameters
    ----------
    nodes:
        Set of nodes for the induced subgraph.

    Returns
    -------
    AncestralGraph:
        Induced subgraph over ``nodes``.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(bidirected={(1, 2), (1, 4)}, directed={(1, 3), (2, 3)})
    >>> g.induced_subgraph({1, 2, 3})
    Directed edges: {(2, 3), (1, 3)}, Bidirected edges: {frozenset({1, 2})}, Undirected edges: set()
    """
    def _inside(a, b):
        return a in nodes and b in nodes

    return AncestralGraph(
        nodes,
        directed={(a, b) for a, b in self._directed if _inside(a, b)},
        bidirected={(a, b) for a, b in self._bidirected if _inside(a, b)},
        undirected={(a, b) for a, b in self._undirected if _inside(a, b)},
    )
def __str__(self):
    """Summarize the graph by listing its directed, bidirected and undirected edge sets."""
    return 'Directed edges: {}, Bidirected edges: {}, Undirected edges: {}'.format(
        self._directed, self._bidirected, self._undirected)
def __repr__(self):
    """Use the same edge-list summary as ``__str__``."""
    summary = str(self)
    return summary
# === MUTATORS
def add_node(self, node: Node):
    """
    Add a node to the ancestral graph (no-op if it is already present).

    Parameters
    ----------
    node:
        a hashable Python object

    See Also
    --------
    add_nodes_from

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph()
    >>> g.add_node(1)
    >>> g.add_node(2)
    >>> len(g.nodes)
    2
    """
    self._nodes.update((node,))
def add_nodes_from(self, nodes: Iterable[Node]):
    """
    Add every node yielded by ``nodes`` to the ancestral graph.

    Parameters
    ----------
    nodes:
        an iterable of hashable Python objects

    See Also
    --------
    add_node

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph()
    >>> g.add_nodes_from({1, 2})
    >>> len(g.nodes)
    2
    """
    self._nodes.update(nodes)
def _check_ancestral(self):
    """
    Raise if the directed/bidirected structure of the graph is not ancestral.

    Raises
    ------
    CycleError
        If the directed edges contain a cycle (raised by ``topological_sort``, which
        ``ancestor_dict`` runs first).
    SpouseError
        If the endpoints of some bidirected edge are related by ancestry.
    """
    # The DFS in topological_sort only reports a SpouseError when the spouse happens to lie
    # on the current DFS path, so it can miss violations depending on set iteration order.
    # An explicit ancestry check on every bidirected edge closes that gap.
    ancestors = self.ancestor_dict()
    for edge in self._bidirected:
        a, b = tuple(edge)
        if a in ancestors.get(b, ()) or b in ancestors.get(a, ()):
            raise SpouseError
def _mark_children_visited(self, node, any_visited, curr_path_visited, curr_path, stack):
    """
    Depth-first visit of ``node`` used by ``topological_sort``.

    ``node`` is appended to ``stack`` only after all of its unvisited descendants
    (post-order), so the reversed ``stack`` is a topological order.

    Parameters
    ----------
    node:
        Node to visit.
    any_visited:
        Dict node -> bool, True once a node has been entered.
    curr_path_visited:
        Dict node -> bool, True while a node is on the current DFS path.
    curr_path:
        The current DFS path (list, used as a stack).
    stack:
        Output list receiving nodes in post-order.

    Raises
    ------
    CycleError
        If a child of ``node`` is already on the current DFS path (directed cycle).
    SpouseError
        If a spouse of ``node`` is on the current DFS path, i.e. is an ancestor of ``node``.
    """
    any_visited[node] = True
    curr_path_visited[node] = True
    curr_path.append(node)
    for child in self._children[node]:
        if not any_visited[child]:
            self._mark_children_visited(child, any_visited, curr_path_visited, curr_path, stack)
        elif curr_path_visited[child]:
            # child is already an ancestor of node on this path: a directed cycle.
            raise CycleError
    for spouse in self._spouses[node]:
        if curr_path_visited[spouse]:
            raise SpouseError
    curr_path.pop()
    curr_path_visited[node] = False
    stack.append(node)
def topological_sort(self) -> list:
    """
    Return a linear order that is consistent with the partial order implied by ancestral relations of this graph.

    Raises
    ------
    CycleError, SpouseError
        Propagated from the depth-first search.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(bidirected={(1, 2), (1, 4)}, directed={(1, 3), (2, 3)})
    >>> g.topological_sort()
    [4, 2, 1, 3]
    """
    seen = dict.fromkeys(self._nodes, False)
    on_path = dict.fromkeys(self._nodes, False)
    path, finished = [], []
    for start in self._nodes:
        if seen[start]:
            continue
        self._mark_children_visited(start, seen, on_path, path, finished)
    # finished holds nodes in DFS post-order; reversing it yields the topological order.
    return finished[::-1]
def add_directed(self, i: Node, j: Node):
    """
    Add a directed edge from node ``i`` to node ``j``.

    Parameters
    ----------
    i:
        source of directed edge.
    j:
        target of directed edge.

    Raises
    ------
    NeighborError, AdjacentError
        Propagated from ``_add_directed`` before the graph is modified.
    CycleError, SpouseError
        If the new edge makes the graph non-ancestral; the edge is removed again first.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph()
    >>> g.add_directed(1, 2)
    >>> g.directed
    {(1, 2)}
    """
    already_present = self.has_directed(i, j)
    self._add_directed(i, j)
    try:
        self._check_ancestral()
    except (CycleError, SpouseError):
        # Roll back so a rejected insertion leaves the graph unchanged. SpouseError was
        # previously not caught, leaving a non-ancestral graph behind.
        if not already_present:
            self.remove_directed(i, j)
        raise
def add_bidirected(self, i: Node, j: Node):
    """
    Add a bidirected edge between nodes ``i`` and ``j``.

    Parameters
    ----------
    i:
        first endpoint of bidirected edge.
    j:
        second endpoint of bidirected edge.

    Raises
    ------
    NeighborError, AdjacentError
        Propagated from ``_add_bidirected`` before the graph is modified.
    CycleError, SpouseError
        If the new edge makes the graph non-ancestral; the edge is removed again first.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph()
    >>> g.add_bidirected(1, 2)
    >>> g.bidirected
    {frozenset({1, 2})}
    """
    already_present = self.has_bidirected(i, j)
    self._add_bidirected(i, j)
    try:
        # Previously this called _add_bidirected a second time, so the ancestral check
        # never ran.
        self._check_ancestral()
    except (CycleError, SpouseError):
        if not already_present:
            self.remove_bidirected(i, j)
        raise
def add_undirected(self, i: Node, j: Node):
    """
    Add an undirected edge between nodes ``i`` and ``j``.

    Parameters
    ----------
    i:
        first endpoint of undirected edge.
    j:
        second endpoint of undirected edge.

    Raises
    ------
    NeighborError
        If ``i`` or ``j`` already has parents or spouses.
    AdjacentError
        If ``i`` and ``j`` are already adjacent.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph()
    >>> g.add_undirected(1, 2)
    >>> g.undirected
    {frozenset({1, 2})}
    """
    self._add_undirected(i, j)
def _add_directed(self, i: Node, j: Node, ignore_error=False):
    """
    Insert ``i -> j`` into the internal indexes without any cycle check.

    Parameters
    ----------
    i, j:
        Source and target of the edge.
    ignore_error:
        If True, skip the neighbor check and replace any existing edge between ``i`` and ``j``.

    Raises
    ------
    NeighborError
        If ``j`` has undirected neighbors (unless ``ignore_error``).
    AdjacentError
        If ``i`` and ``j`` are already adjacent (unless ``ignore_error``).
    """
    if self.has_directed(i, j):
        return
    # An arrowhead at j is only allowed when j has no undirected edges.
    if not ignore_error and self._neighbors[j]:
        raise NeighborError(j, self._neighbors[j])
    if i in self._adjacent[j]:
        if not ignore_error:
            raise AdjacentError(i, j, '->')
        # Replace whichever edge currently joins i and j.
        if self.has_directed(j, i):
            self.remove_directed(j, i)
        elif self.has_bidirected(i, j):
            self.remove_bidirected(i, j)
        else:
            self.remove_undirected(i, j)
    self._nodes.update((i, j))
    self._directed.add((i, j))
    self._parents[j].add(i)
    self._children[i].add(j)
    for a, b in ((i, j), (j, i)):
        self._adjacent[a].add(b)
def _add_bidirected(self, i: Node, j: Node, ignore_error=False):
    """
    Insert ``i <-> j`` into the internal indexes without any ancestry check.

    Parameters
    ----------
    i, j:
        Endpoints of the edge.
    ignore_error:
        If True, skip the neighbor checks and replace any existing edge between ``i`` and ``j``.

    Raises
    ------
    NeighborError
        If ``i`` or ``j`` has undirected neighbors (unless ``ignore_error``).
    AdjacentError
        If ``i`` and ``j`` are already adjacent (unless ``ignore_error``).
    """
    if self.has_bidirected(i, j):
        return
    # Both endpoints receive an arrowhead, so neither may have undirected edges.
    if not ignore_error:
        for endpoint in (i, j):
            if self._neighbors[endpoint]:
                raise NeighborError(endpoint, neighbors=self._neighbors[endpoint])
    if i in self._adjacent[j]:
        if not ignore_error:
            raise AdjacentError(i, j, '<->')
        if self.has_directed(i, j):
            self.remove_directed(i, j)
        elif self.has_directed(j, i):
            self.remove_directed(j, i)
        else:
            self.remove_undirected(i, j)
    self._nodes.update((i, j))
    self._bidirected.add(frozenset({i, j}))
    for a, b in ((j, i), (i, j)):
        self._spouses[a].add(b)
    for a, b in ((i, j), (j, i)):
        self._adjacent[a].add(b)
def _add_undirected(self, i: Node, j: Node, ignore_error=False):
    """
    Insert ``i - j`` into the internal indexes.

    Parameters
    ----------
    i, j:
        Endpoints of the edge.
    ignore_error:
        If True, skip the parent/spouse checks and replace any existing edge between
        ``i`` and ``j`` (consistent with ``_add_directed`` / ``_add_bidirected``).

    Raises
    ------
    NeighborError
        If ``i`` or ``j`` has parents or spouses (unless ``ignore_error``).
    AdjacentError
        If ``i`` and ``j`` are already adjacent (unless ``ignore_error``).
    """
    if self.has_undirected(i, j):
        return
    # === CHECK REMAINS ANCESTRAL
    # Gated on ignore_error like the sibling helpers. Previously these checks always ran,
    # so any existing directed/bidirected edge between i and j raised here and the
    # replacement branch below was unreachable.
    if not ignore_error:
        for endpoint in (i, j):
            if self._parents[endpoint]:
                raise NeighborError(endpoint, parents=self._parents[endpoint])
            if self._spouses[endpoint]:
                raise NeighborError(endpoint, spouses=self._spouses[endpoint])
    # === CHECK i AND j NOT ALREADY ADJACENT
    if i in self._adjacent[j]:
        if not ignore_error:
            raise AdjacentError(i, j, '-')
        if self.has_directed(i, j):
            self.remove_directed(i, j)
        elif self.has_directed(j, i):
            self.remove_directed(j, i)
        else:
            self.remove_bidirected(i, j)
    self._nodes.add(i)
    self._nodes.add(j)
    self._undirected.add(frozenset({i, j}))
    self._neighbors[j].add(i)
    self._neighbors[i].add(j)
    self._adjacent[i].add(j)
    self._adjacent[j].add(i)
def remove_node(self, node: Node, ignore_error=False):
    """
    Remove ``node`` and every edge incident to it.

    Parameters
    ----------
    node
        The node to be removed.
    ignore_error:
        If False, raises an error when the node does not belong to the graph.

    Raises
    ------
    KeyError
        If ``node`` is not in the graph and ``ignore_error`` is False.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(bidirected={(1, 2), (1, 4)}, directed={(1, 3), (2, 3)})
    >>> g.remove_node(4)
    >>> g
    Directed edges: {(2, 3), (1, 3)}, Bidirected edges: {frozenset({1, 2})}, Undirected edges: set()
    """
    # Validate up front so a missing node never leaves the graph half-modified.
    if node not in self._nodes:
        if ignore_error:
            return
        raise KeyError(node)
    self._nodes.remove(node)
    # pop() instead of ``del``: an isolated node may have no entry in these defaultdicts,
    # and ``del self._adjacent[node]`` used to raise KeyError for such nodes.
    for parent in self._parents.pop(node, set()):
        self._children[parent].discard(node)
        self._adjacent[parent].discard(node)
        self._directed.discard((parent, node))
    for child in self._children.pop(node, set()):
        self._parents[child].discard(node)
        self._adjacent[child].discard(node)
        self._directed.discard((node, child))
    for spouse in self._spouses.pop(node, set()):
        self._spouses[spouse].discard(node)
        self._adjacent[spouse].discard(node)
        self._bidirected.discard(frozenset({spouse, node}))
    for nbr in self._neighbors.pop(node, set()):
        self._neighbors[nbr].discard(node)
        self._adjacent[nbr].discard(node)
        self._undirected.discard(frozenset({nbr, node}))
    self._adjacent.pop(node, None)
def remove_directed(self, i: Node, j: Node, ignore_error=False):
    """
    Remove the directed edge from ``i`` to ``j``.

    Parameters
    ----------
    i:
        source of directed edge.
    j:
        target of directed edge.
    ignore_error:
        If False, raises an error when the directed edge does not belong to the graph.

    Raises
    ------
    KeyError
        If the edge is missing and ``ignore_error`` is False.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(bidirected={(1, 2), (1, 4)}, directed={(1, 3), (2, 3)})
    >>> g.remove_directed(1, 3)
    >>> g
    Directed edges: {(2, 3)}, Bidirected edges: {frozenset({1, 4}), frozenset({1, 2})}, Undirected edges: set()
    """
    try:
        self._directed.remove((i, j))
        self._children[i].remove(j)
        self._parents[j].remove(i)
        for a, b in ((i, j), (j, i)):
            self._adjacent[a].remove(b)
    except KeyError:
        if not ignore_error:
            raise
def remove_bidirected(self, i: Node, j: Node, ignore_error=False):
    """
    Remove the bidirected edge between ``i`` and ``j``.

    Parameters
    ----------
    i:
        first endpoint of bidirected edge.
    j:
        second endpoint of bidirected edge.
    ignore_error:
        If False, raises an error when the bidirected edge does not belong to the graph.

    Raises
    ------
    KeyError
        If the edge is missing and ``ignore_error`` is False.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(bidirected={(1, 2), (1, 4)}, directed={(1, 3), (2, 3)})
    >>> g.remove_bidirected(1, 2)
    >>> g
    Directed edges: {(2, 3), (1, 3)}, Bidirected edges: {frozenset({1, 4})}, Undirected edges: set()
    """
    endpoints = ((i, j), (j, i))
    try:
        self._bidirected.remove(frozenset({i, j}))
        for a, b in endpoints:
            self._spouses[a].remove(b)
        for a, b in endpoints:
            self._adjacent[a].remove(b)
    except KeyError:
        if not ignore_error:
            raise
def remove_undirected(self, i: Node, j: Node, ignore_error=False):
    """
    Remove the undirected edge between ``i`` and ``j``.

    Parameters
    ----------
    i:
        first endpoint of undirected edge.
    j:
        second endpoint of undirected edge.
    ignore_error:
        If False, raises an error when the undirected edge does not belong to the graph.

    Raises
    ------
    KeyError
        If the edge is missing and ``ignore_error`` is False.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(directed={(1, 2), (1, 3)}, undirected={(1, 4)})
    >>> g.remove_undirected(1, 4)
    >>> g
    Directed edges: {(1, 2), (1, 3)}, Bidirected edges: set(), Undirected edges: set()
    """
    endpoints = ((i, j), (j, i))
    try:
        self._undirected.remove(frozenset({i, j}))
        for a, b in endpoints:
            self._neighbors[a].remove(b)
        for a, b in endpoints:
            self._adjacent[a].remove(b)
    except KeyError:
        if not ignore_error:
            raise
def remove_edge(self, i: Node, j: Node, ignore_error=False):
    """
    Remove the edge between ``i`` and ``j``, regardless of edge type.

    Edge types are checked in the order bidirected, ``i -> j``, ``j -> i``, undirected.

    Parameters
    ----------
    i:
        first endpoint of edge.
    j:
        second endpoint of edge.
    ignore_error:
        If False, raises an error when the edge does not belong to the graph.

    Raises
    ------
    KeyError
        With ``(i, j)`` as its argument, if no edge joins ``i`` and ``j`` and ``ignore_error`` is False.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(directed={(1, 2), (1, 3)}, undirected={(1, 4)})
    >>> g.remove_edge(1, 4)
    >>> g
    Directed edges: {(1, 2), (1, 3)}, Bidirected edges: set(), Undirected edges: set()
    """
    if self.has_bidirected(i, j):
        self.remove_bidirected(i, j)
    elif self.has_directed(i, j):
        self.remove_directed(i, j)
    elif self.has_directed(j, i):
        self.remove_directed(j, i)
    elif self.has_undirected(i, j):
        self.remove_undirected(i, j)
    elif not ignore_error:
        # Name the missing edge instead of raising an empty KeyError.
        raise KeyError((i, j))
def remove_edges(self, edges: Iterable, ignore_error=False):
    """
    Remove all edges in ``edges`` from the graph, regardless of edge type.

    Each element is unpacked into two endpoints and passed to ``remove_edge``, in iteration order.

    Parameters
    ----------
    edges
        The edges to be removed from the graph.
    ignore_error:
        If False, raises an error when any edge does not belong to the graph.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(directed={(1, 2), (1, 3)}, undirected={(1, 4)})
    >>> g.remove_edges([(1, 4), (1, 2)])
    >>> g
    Directed edges: {(1, 3)}, Bidirected edges: set(), Undirected edges: set()
    """
    for edge in edges:
        first, second = edge
        self.remove_edge(first, second, ignore_error=ignore_error)
# === PROPERTIES
@property
def nodes(self) -> Set[Node]:
    """A copy of the node set."""
    return self._nodes.copy()

@property
def nnodes(self) -> int:
    """Number of nodes."""
    return len(self._nodes)

@property
def directed(self) -> Set[DirectedEdge]:
    """A copy of the set of directed edges, as ``(source, target)`` tuples."""
    return set(self._directed)

@property
def num_directed(self) -> int:
    """Number of directed edges."""
    return len(self._directed)

@property
def bidirected(self) -> Set[BidirectedEdge]:
    """A copy of the set of bidirected edges, as two-element frozensets."""
    return set(self._bidirected)

@property
def num_bidirected(self) -> int:
    """Number of bidirected edges."""
    return len(self._bidirected)

@property
def undirected(self) -> Set[UndirectedEdge]:
    """A copy of the set of undirected edges, as two-element frozensets."""
    return set(self._undirected)

@property
def num_undirected(self) -> int:
    """Number of undirected edges."""
    return len(self._undirected)

@property
def num_edges(self) -> int:
    """Total number of edges of all three types."""
    return sum((self.num_directed, self.num_bidirected, self.num_undirected))

@property
def skeleton(self) -> Set[UndirectedEdge]:
    """The adjacencies of the graph, each as a two-element frozenset, ignoring edge marks."""
    every_edge = self._bidirected | self._undirected | self._directed
    return {frozenset({a, b}) for a, b in every_edge}
def children_of(self, i: NodeSet) -> Set[Node]:
    """
    Return the children of the node or set of nodes ``i``.

    Parameters
    ----------
    i
        A node, or a ``set`` of nodes (in which case the union of their children is returned).

    Returns
    -------
    Set[node]
        A new set; an empty input set yields an empty set.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(directed={(1, 2), (2, 3)}, undirected={(1, 4)})
    >>> g.children_of(1)
    {2}
    >>> g.children_of({1, 2})
    {2, 3}
    """
    if isinstance(i, set):
        # set().union(...) also handles an empty input; set.union(*[]) raises TypeError.
        return set().union(*(self._children[n] for n in i))
    return self._children[i].copy()
def parents_of(self, nodes: NodeSet) -> Set[Node]:
    """
    Return the parents of the node or set of nodes ``nodes``.

    Parameters
    ----------
    nodes
        A node, or a ``set`` of nodes (in which case the union of their parents is returned).

    Returns
    -------
    Set[node]
        A new set; an empty input set yields an empty set.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(directed={(1, 2), (2, 3)}, undirected={(1, 4)})
    >>> g.parents_of(2)
    {1}
    >>> g.parents_of({2, 3})
    {1, 2}
    """
    if isinstance(nodes, set):
        # set().union(...) also handles an empty input; set.union(*[]) raises TypeError.
        return set().union(*(self._parents[n] for n in nodes))
    return self._parents[nodes].copy()
def spouses_of(self, nodes: NodeSet) -> Set[Node]:
    """
    Return the spouses of the node or set of nodes ``nodes``.

    Parameters
    ----------
    nodes
        A node, or a ``set`` of nodes (in which case the union of their spouses is returned).

    Returns
    -------
    Set[node]
        A new set; an empty input set yields an empty set.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(directed={(1, 2), (2, 3)}, bidirected={(1, 4), (2, 5)})
    >>> g.spouses_of(1)
    {4}
    >>> g.spouses_of({1, 2})
    {4, 5}
    """
    if isinstance(nodes, set):
        # set().union(...) also handles an empty input; set.union(*[]) raises TypeError.
        return set().union(*(self._spouses[n] for n in nodes))
    return self._spouses[nodes].copy()
def neighbors_of(self, nodes: NodeSet) -> Set[Node]:
    """
    Return the neighbors (undirected adjacencies) of the node or set of nodes ``nodes``.

    Parameters
    ----------
    nodes
        A node, or a ``set`` of nodes (in which case the union of their neighbors is returned).

    Returns
    -------
    Set[node]
        A new set; an empty input set yields an empty set.

    Examples
    --------
    >>> import causaldag as cd
    >>> g = cd.AncestralGraph(directed={(1, 3), (2, 3)}, undirected={(1, 4), (2, 5)})
    >>> g.neighbors_of(1)
    {4}
    >>> g.neighbors_of({1, 2})
    {4, 5}
    """
    if isinstance(nodes, set):
        # set().union(...) also handles an empty input; set.union(*[]) raises TypeError.
        return set().union(*(self._neighbors[n] for n in nodes))
    return self._neighbors[nodes].copy()
def _add_ancestors(self, ancestors, node, exclude_arcs=frozenset()):
    """
    Add to ``ancestors`` every node with a directed path into ``node`` that avoids ``exclude_arcs``.

    Parameters
    ----------
    ancestors:
        Set updated in place. Nodes already in it are not expanded further.
    node:
        Node whose ancestors are collected.
    exclude_arcs:
        Directed edges ``(parent, child)`` that may not be traversed.
    """
    # Iterative rather than recursive, so long directed chains cannot exceed Python's
    # recursion limit. Immutable default instead of a shared ``set()``.
    stack = [node]
    while stack:
        current = stack.pop()
        for parent in self._parents[current]:
            if parent not in ancestors and (parent, current) not in exclude_arcs:
                ancestors.add(parent)
                stack.append(parent)
def _add_descendants(self, descendants, node, exclude_arcs=frozenset()):
    """
    Add to ``descendants`` every node reachable from ``node`` by a directed path avoiding ``exclude_arcs``.

    Parameters
    ----------
    descendants:
        Set updated in place. Nodes already in it are not expanded further.
    node:
        Node whose descendants are collected.
    exclude_arcs:
        Directed edges ``(parent, child)`` that may not be traversed.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        for child in self._children[current]:
            # The traversed arc is current -> child, written (current, child) like in
            # _add_ancestors. The previous check used the reversed pair (child, current),
            # so excluded arcs were never actually excluded.
            if child not in descendants and (current, child) not in exclude_arcs:
                descendants.add(child)
                stack.append(child)
def ancestors_of(self, nodes: NodeSet, exclude_arcs=frozenset()) -> Set[Node]:
    """
    Return the ancestors of the node or set of nodes ``nodes``.

    Parameters
    ----------
    nodes:
        A node, or a ``set`` of nodes (the union of their ancestors is returned).
    exclude_arcs:
        Directed edges ``(i, j)`` that may not be traversed when following paths.

    See Also
    --------
    descendants_of

    Return
    ------
    Set[node]
        All nodes j such that there is a directed path from j to some node in ``nodes``.
        An empty input set yields an empty set.
    """
    ancestors = set()
    targets = nodes if isinstance(nodes, set) else (nodes,)
    for target in targets:
        # Accumulating into one shared set gives the same union as separate per-node
        # searches. Unlike before, ``exclude_arcs`` is also honored for set inputs.
        self._add_ancestors(ancestors, target, exclude_arcs=exclude_arcs)
    return ancestors
def ancestor_dict(self) -> dict:
    """
    Return a dictionary from each node to its ancestors (excluding the node itself).

    Ancestor sets are propagated from parents to children along a topological order.

    See Also
    --------
    ancestors_of

    Return
    ------
    Dict[node,Set]
        Mapping node to ancestors

    Raises
    ------
    CycleError, SpouseError
        Propagated from ``topological_sort``.
    """
    order = self.topological_sort()
    closure = defaultdict(set)  # node -> ancestors plus the node itself
    for current in order:
        closure[current].add(current)
        inherited = closure[current]
        for child in self._children[current]:
            closure[child].update(inherited)
    for current in self._nodes:
        closure[current].discard(current)
    return core_utils.defdict2dict(closure, self._nodes)
def descendant_dict(self) -> dict:
    """
    Return a dictionary from each node to its descendants (excluding the node itself).

    Descendant sets are propagated from children to parents along the reversed topological order.

    See Also
    --------
    descendants_of
    ancestor_dict

    Return
    ------
    Dict[node,Set]
        Mapping node to descendants

    Raises
    ------
    CycleError, SpouseError
        Propagated from ``topological_sort``.
    """
    top_sort = self.topological_sort()
    node2descendants_plus_self = defaultdict(set)
    for node in reversed(top_sort):
        node2descendants_plus_self[node].add(node)
        for parent in self._parents[node]:
            node2descendants_plus_self[parent].update(node2descendants_plus_self[node])
    for node in self._nodes:
        node2descendants_plus_self[node] -= {node}
    return core_utils.defdict2dict(node2descendants_plus_self, self._nodes)
def descendants_of(self, nodes: NodeSet, exclude_arcs=frozenset()) -> Set[Node]:
    """
    Return the descendants of the node or set of nodes ``nodes``.

    Parameters
    ----------
    nodes:
        A node, or a ``set`` of nodes (the union of their descendants is returned).
    exclude_arcs:
        Directed edges ``(i, j)`` that may not be traversed when following paths.

    See Also
    --------
    ancestors_of

    Return
    ------
    Set[node]
        All nodes j such that there is a directed path from some node in ``nodes`` to j.
        An empty input set yields an empty set.
    """
    descendants = set()
    sources = nodes if isinstance(nodes, set) else (nodes,)
    for source in sources:
        # One shared accumulator equals the union of per-node searches; ``exclude_arcs``
        # is now also honored for set inputs.
        self._add_descendants(descendants, source, exclude_arcs=exclude_arcs)
    return descendants
def has_directed(self, i: Node, j: Node) -> bool:
    """
    Check if this graph has the directed edge ``i``->``j``.

    Parameters
    ----------
    i:
        Source node.
    j:
        Target node.

    See Also
    --------
    has_bidirected
    has_undirected
    has_any_edge
    """
    edge = (i, j)
    return edge in self._directed
def has_bidirected(self, i: Node, j: Node) -> bool:
    """
    Check if this graph has a bidirected edge between ``i`` and ``j`` (order does not matter).

    Parameters
    ----------
    i:
        Node.
    j:
        Node.

    See Also
    --------
    has_directed
    has_undirected
    has_any_edge
    """
    edge = frozenset((i, j))
    return edge in self._bidirected
def has_undirected(self, i: Node, j: Node) -> bool:
    """
    Check if this graph has an undirected edge between ``i`` and ``j`` (order does not matter).

    Parameters
    ----------
    i:
        Node.
    j:
        Node.

    See Also
    --------
    has_directed
    has_bidirected
    has_any_edge
    """
    edge = frozenset((i, j))
    return edge in self._undirected
def has_any_edge(self, i: Node, j: Node) -> bool:
    """
    Check if ``i`` and ``j`` are adjacent in this graph, by an edge of any type or direction.

    Parameters
    ----------
    i:
        Node.
    j:
        Node.

    See Also
    --------
    has_directed
    has_bidirected
    has_undirected
    """
    if self.has_directed(i, j) or self.has_directed(j, i):
        return True
    return self.has_bidirected(i, j) or self.has_undirected(i, j)
def vstructures(self) -> Set[Tuple]:
    """
    Return the unshielded colliders of the graph.

    A triple ``(a, c, b)`` is reported when ``a`` and ``b`` both have an arrowhead into ``c``
    (each is a parent or spouse of ``c``) and ``a`` and ``b`` are not adjacent. The outer
    nodes are sorted, so nodes must be mutually orderable.

    Returns
    -------
    Set[Tuple]
        Triples ``(a, c, b)`` with ``a <= b``.
    """
    found = set()
    for center in self._nodes:
        arrowheads_in = self._parents[center] | self._spouses[center]
        for a, b in itr.combinations(arrowheads_in, 2):
            if self.has_any_edge(a, b):
                continue
            low, high = sorted((a, b))
            found.add((low, center, high))
    return found
def colliders(self) -> set:
    """
    Return the nodes with at least two incoming arrowheads (parents and spouses combined).
    """
    result = set()
    for candidate in self._nodes:
        arrowheads = self._parents[candidate] | self._spouses[candidate]
        if len(arrowheads) > 1:
            result.add(candidate)
    return result
def _bidirected_reachable(self, node, tmp: Set[Node], visited: Set[Node], node_subset=None) -> Set[Node]:
    """
    Collect the nodes reachable from ``node`` through bidirected edges.

    Parameters
    ----------
    node:
        Start node. It is always added, even if it is outside ``node_subset``.
    tmp:
        Set that receives the reached nodes; it is updated in place and returned.
    visited:
        Set of already-explored nodes, updated in place (shared across calls by ``c_components``).
    node_subset:
        If given, only spouses inside this set are followed; defaults to all nodes.

    Returns
    -------
    Set[node]
        ``tmp``, after adding the reached nodes.
    """
    node_subset = self._nodes if node_subset is None else node_subset
    visited.add(node)
    tmp.add(node)
    # Explicit stack instead of recursion: long bidirected chains would otherwise
    # exceed Python's recursion limit.
    stack = [node]
    while stack:
        current = stack.pop()
        for spouse in self._spouses[current] & node_subset:
            if spouse not in visited:
                visited.add(spouse)
                tmp.add(spouse)
                stack.append(spouse)
    return tmp
def c_components(self) -> List[set]:
    """
    Return the c-components of this graph.

    Return
    ------
    List[Set[node]]
        Return the partition of nodes coming from the relation of reachability by bidirected edges.
    """
    remaining = self._nodes.copy()
    seen = set()
    components = []
    while remaining:
        start = remaining.pop()
        if start in seen:
            continue
        # ``seen`` is shared, so every node lands in exactly one component.
        components.append(self._bidirected_reachable(start, set(), seen))
    return components
def district_of(self, node: Node, node_subset=None) -> Set[Node]:
    """
    Return the district of a node, i.e., the set of nodes reachable by bidirected edges. If ``node_subset`` is
    provided, do this on the induced subgraph on that subset of nodes.

    Return
    ------
    Set[node]
        The district of node (always containing ``node`` itself).
    """
    district = set()
    return self._bidirected_reachable(node, district, set(), node_subset=node_subset)
def discriminating_paths(self, verbose=False) -> Dict[Tuple, str]:
    """
    Search for discriminating paths ending at each node that has parents.

    For every node ``j`` with at least one parent and every start node ``i`` that is not a
    parent, child or spouse of ``j`` (nor ``j`` itself), paths are grown breadth-first from
    ``i``. Every interior node appended along the way is in ``self.colliders()`` and is a
    parent of ``j``. A path is recorded once its last interior node connects to a further
    node ``k`` that is adjacent to ``j``.

    NOTE(review): undirected neighbors of ``j`` are not removed from the candidate start
    nodes -- confirm this is intended.

    Parameters
    ----------
    verbose:
        If True, print progress information.

    Returns
    -------
    Dict[Tuple, str]
        Maps each found path (a tuple of nodes ending in ``j``) to ``'c'`` when ``k <-> j``
        and to ``'n'`` when ``k -> j``; presumably "collider" / "non-collider" at ``k`` --
        TODO confirm against the intended definition.
    """
    colliders = self.colliders()
    discriminating_paths = {}
    if verbose: print("Checking discriminating paths")
    # NOTE(review): self._parents[...] is read below while iterating self._parents.items().
    # That does not grow the defaultdict only because self.colliders() above has already
    # read self._parents[node] for every node in self._nodes.
    for j, parents in self._parents.items():  # potential endpoints of discriminating paths
        if verbose: print(j)
        if not parents:
            continue
        nonadjacent = self._nodes - parents - self._children[j] - self._spouses[j] - {j}
        if verbose: print(f"Checking node {j} and non-adjacent nodes {nonadjacent}")
        for i in nonadjacent:  # potential start points of discriminating paths
            # search all paths that satisfy discriminating path criteria
            # (first hop: a spouse or child of i that is a collider and a parent of j)
            path_queue = [
                [i, k]
                for k in self._spouses[i] | self._children[i]
                if k in colliders and j in self._children[k]
            ]
            while path_queue:
                # NOTE(review): list.pop(0) is O(len(path_queue)); collections.deque would be O(1).
                path = path_queue.pop(0)
                final_node = path[-1]
                # check if path is discriminating for the next node
                # k is a spouse of final_node; label by the edge between k and j.
                for k in filter(lambda k: k not in path, self._spouses[final_node]):
                    if j in self._spouses[k]:
                        full_path = path.copy()
                        full_path.extend([k, j])
                        discriminating_paths[tuple(full_path)] = 'c'
                    elif j in self._children[k]:
                        full_path = path.copy()
                        full_path.extend([k, j])
                        discriminating_paths[tuple(full_path)] = 'n'
                # k is a parent of final_node (k -> final_node); only k -> j is recorded here.
                for k in filter(lambda k: k not in path, self._parents[final_node]):
                    if j in self._children[k]:
                        full_path = path.copy()
                        full_path.extend([k, j])
                        discriminating_paths[tuple(full_path)] = 'n'
                # extend path
                # Only through spouses that are themselves colliders and parents of j.
                for k in self._spouses[final_node]:
                    if k not in path and k in colliders and j in self._children[k]:
                        new_path = path.copy()
                        new_path.append(k)
                        path_queue.append(new_path)
    return discriminating_paths
def _reachable(self, start_node, end_node, visited=None, allowed_edges=frozenset({'b', 'u', 'c', 'p'}),
               predicate=lambda node: True, verbose=False):
    """
    Depth-first check whether ``end_node`` can be reached from ``start_node``.

    Parameters
    ----------
    start_node:
        Node to start from (it is not itself marked as visited).
    end_node:
        Target node; it must also satisfy ``predicate`` to count as reached.
    visited:
        Set of nodes already explored, updated in place. A fresh set is used when omitted.
    allowed_edges:
        Edge kinds that may be followed: ``'b'`` spouses, ``'u'`` neighbors,
        ``'c'`` children, ``'p'`` parents.
    predicate:
        Only neighbors for which ``predicate(node)`` is true are followed.
    verbose:
        If True, print progress information.

    Returns
    -------
    bool
        True if ``end_node`` was reached.
    """
    if visited is None:
        # A fresh set per call. The previous ``visited=set()`` default was created once and
        # mutated below, so state leaked between calls that relied on the default.
        visited = set()
    allowed_nbrs = set()
    if 'b' in allowed_edges:
        allowed_nbrs.update(self._spouses[start_node])
    if 'u' in allowed_edges:
        allowed_nbrs.update(self._neighbors[start_node])
    if 'c' in allowed_edges:
        allowed_nbrs.update(self._children[start_node])
    if 'p' in allowed_edges:
        allowed_nbrs.update(self._parents[start_node])
    allowed_nbrs = {nbr for nbr in allowed_nbrs if predicate(nbr)}
    if verbose: print(f"Allowed neighbors of {start_node}: {allowed_nbrs}")
    if verbose: print(f"Visited: {visited}")
    results = []
    for nbr in allowed_nbrs:
        if nbr in visited:
            continue
        visited.add(nbr)
        if nbr == end_node:
            if verbose: print("Reached end node")
            return True
        # Stop as soon as a branch succeeds; the overall answer is the same as any(results).
        if self._reachable(nbr, end_node, visited=visited, allowed_edges=allowed_edges, predicate=predicate,
                           verbose=verbose):
            return True
        results.append(False)
    if verbose: print("reachability results:", results)
    return False
# === MARKOV PROPERTIES, I-MAP AND MAXIMALITY CHECKS
def pairwise_markov_statements(self) -> Set[Tuple[Node, Node, FrozenSet[Node]]]:
    """
    Return the pairwise Markov statements of this graph.

    For every non-adjacent pair ``(i, j)``, the statement ``(i, j, S)`` is produced with
    ``S`` the ancestors of ``i`` or ``j``, excluding ``i`` and ``j`` themselves.

    Returns
    -------
    Set[Tuple[node, node, FrozenSet[node]]]
    """
    statements = set()
    for i, j in itr.combinations(self._nodes, 2):
        if self.has_any_edge(i, j):
            continue
        # Parenthesized: ``-`` binds tighter than ``|``. Without the parentheses, j stayed
        # in its own conditioning set whenever j is an ancestor of i.
        conditioning = (self.ancestors_of(i) | self.ancestors_of(j)) - {i, j}
        statements.add((i, j, frozenset(conditioning)))
    return statements
def is_imap(self, other, certify: bool = False) -> Union[bool, Tuple]:
    """
    Check if this graph is an IMAP of the graph ``other``, i.e., all m-separation statements in this graph
    are also m-separation statements in ``other``.

    Only the pairwise Markov statements of this graph are checked against ``other.msep``.

    Parameters
    ----------
    other:
        Another graph exposing ``msep(i, j, S)``.
    certify:
        If True, also return a certificate: the first pairwise statement ``(i, j, S)``
        that does not hold in ``other``, or None.

    Returns
    -------
    bool, or (bool, certificate) when ``certify`` is True.

    Raises
    ------
    Exception
        If this graph is not maximal.

    See Also
    --------
    is_minimal_imap
    """
    if not self.is_maximal():
        raise Exception("Your graph is not maximal")
    certificate = next(
        ((i, j, S) for i, j, S in self.pairwise_markov_statements() if not other.msep(i, j, S)),
        None,
    )
    is_imap_ = certificate is None
    return (is_imap_, certificate) if certify else is_imap_
# def is_minimal_imap(self, other, certify=False):
# print("THIS HAS NOT BEEN TESTED")
# certificate = next((
# i, j for i, j in self._directed | self._bidirected
# if other.msep(i, j, self.ancestors_of(i) | self.ancestors_of(j) - {i, j})
# ), False)
# res = not certificate and self.is_imap(other)
# if not certify:
# return res
# else:
# return res, certificate
def is_minimal_imap(self, other, certify: bool = False, check_imap=True) -> Union[bool, Tuple]:
    """
    Check whether this graph is a minimal I-map of ``other``.

    The directed and bidirected edges are visited in random order. The graph is not minimal
    if removing one of them still yields a maximal graph that is an I-map of ``other``.

    Parameters
    ----------
    other:
        Graph providing ``msep``.
    certify:
        If True, return ``(result, edge)``, where ``edge`` is a removable edge or None.
    check_imap:
        If True, first verify that this graph is an I-map of ``other``.

    Returns
    -------
    bool, or (bool, edge) when ``certify`` is True.
    """
    if check_imap and not self.is_imap(other):
        # Previously this returned ``(False, None)`` even when certify=False. A non-empty
        # tuple is truthy, so ``if g.is_minimal_imap(h):`` passed for non-I-maps.
        return (False, None) if certify else False
    edges = list(self._directed) + list(self._bidirected)
    for i, j in random.sample(edges, len(edges)):
        new_mag = self.copy()
        if self.has_bidirected(i, j):
            new_mag.remove_bidirected(i, j)
        if self.has_directed(i, j):
            new_mag.remove_directed(i, j)
        if new_mag.is_maximal() and new_mag.is_imap(other):
            return (False, (i, j)) if certify else False
    return (True, None) if certify else True
def is_minimal_imap2(self, other, certify=False, check_imap=True, validate=False):
    """
    Minimal-I-map check that first tests each edge with one CI query.

    An edge ``(i, j)`` is a removal candidate when ``other`` m-separates ``i`` and ``j``
    given the ancestors of ``{i, j}`` in this graph, excluding ``i`` and ``j``. The removal
    is accepted if the resulting graph is maximal.

    Parameters
    ----------
    other:
        Graph providing ``msep``.
    certify:
        If True, return ``(result, edge)``, where ``edge`` is a removable edge or None.
    check_imap:
        If True, first verify that this graph is an I-map of ``other``.
    validate:
        If True, confirm with ``is_imap`` that an accepted removal really yields an I-map.

    Raises
    ------
    Exception
        If ``validate`` is set and the accepted removal is not an I-map.
    """
    if check_imap and not self.is_imap(other):
        # Consistent with ``certify``; ``(False, None)`` is truthy.
        return (False, None) if certify else False
    edges = list(self._directed) + list(self._bidirected)
    for i, j in random.sample(edges, len(edges)):
        # Parenthesized: without it ``-`` bound tighter than ``|`` and j could stay in the
        # conditioning set when j is an ancestor of i.
        conditioning = (self.ancestors_of(i) | self.ancestors_of(j)) - {i, j}
        if other.msep(i, j, conditioning):
            new_mag = self.copy()
            if self.has_bidirected(i, j):
                new_mag.remove_bidirected(i, j)
            else:
                new_mag.remove_directed(i, j)
            if new_mag.is_maximal():
                if validate and not new_mag.is_imap(other):
                    raise Exception
                return (False, (i, j)) if certify else False
    return (True, None) if certify else True
def is_minimal_imap3(self, other, certify=False, check_imap=True, validate=False, verbose=False):
    """
    Minimal-I-map check based on the change in ``j``'s Markov blanket after removing an edge.

    For each directed/bidirected edge ``(i, j)`` (in random order), the edge is removed. The
    removal is accepted when ``other`` m-separates ``i`` (together with the nodes that
    dropped out of ``j``'s blanket) from ``j``'s remaining blanket, and the result is maximal.

    Parameters
    ----------
    other:
        Graph providing ``msep``.
    certify:
        If True, return ``(result, edge)``, where ``edge`` is a removable edge or None.
    check_imap:
        If True, first verify that this graph is an I-map of ``other``.
    validate:
        If True, confirm with ``is_imap`` that an accepted removal really yields an I-map.
    verbose:
        If True, print progress information.

    Raises
    ------
    Exception
        If ``validate`` is set and the accepted removal is not an I-map.
    """
    if check_imap and not self.is_imap(other):
        # Consistent with ``certify``; ``(False, None)`` is truthy.
        return (False, None) if certify else False
    edges = list(self._directed) + list(self._bidirected)
    for i, j in random.sample(edges, len(edges)):
        new_mag = self.copy()
        if self.has_bidirected(i, j):
            new_mag.remove_bidirected(i, j)
        else:
            new_mag.remove_directed(i, j)
        current_markov_blanket = set.union(*(set(v) for v in self.markov_blanket(j).values())) | self.district_of(j)
        new_markov_blanket = set.union(*(set(v) for v in new_mag.markov_blanket(j).values())) | new_mag.district_of(
            j)
        mb_difference = (current_markov_blanket - new_markov_blanket - {j}) | {i}
        rest = new_markov_blanket - {i, j}
        if verbose: print(f'i={i}, j={j}, mb_diff={mb_difference}, rest={rest}')
        if verbose: print("H", self)
        if verbose: print("G", other)
        # The unconditional debug ``print('here')`` that used to sit here has been removed.
        if other.msep(i, mb_difference, rest) and new_mag.is_maximal():
            if validate and not new_mag.is_imap(other):
                raise Exception
            return (False, (i, j)) if certify else False
    return (True, None) if certify else True
def is_minimal_imap4(self, other, certify=False, check_imap=True, validate=False, extra_validate=False,
                     verbose=False):
    """
    Minimal-I-map check using one local CI test per edge.

    For each directed/bidirected edge ``(i, j)`` (in random order), the edge is removed. The
    removal is accepted when ``other`` m-separates ``i`` and ``j`` given the Markov blanket
    of ``j`` (minus ``i`` and ``j``) in the subgraph induced by ``{i, j}`` and their
    ancestors, and the resulting graph is maximal.

    Parameters
    ----------
    other:
        Graph providing ``msep``.
    certify:
        If True, return ``(result, edge)``, where ``edge`` is a removable edge or None.
    check_imap:
        If True, first verify that this graph is an I-map of ``other``.
    validate:
        If True, confirm with ``is_imap`` that an accepted removal really yields an I-map.
    extra_validate:
        If True, cross-check the local test against ``is_imap`` for every edge first.
    verbose:
        If True, print progress information.

    Raises
    ------
    Exception
        If ``check_imap`` is set and this graph is not an I-map of ``other``, or if a
        validation check fails.
    """
    if check_imap and not self.is_imap(other):
        # The print/return that followed this raise were unreachable and have been removed.
        raise Exception("Not an IMAP")
    if extra_validate:
        for i, j in self._directed | self._bidirected:
            new_mag = self.copy()
            new_mag.remove_edge(i, j)
            s = new_mag.induced_subgraph(new_mag.ancestors_of({i, j}) | {i, j}).markov_blanket(j, flat=True) - {j}
            if other.msep(i, j, s) and new_mag.is_maximal():
                if not new_mag.is_imap(other):
                    raise Exception("CI test not sufficient", new_mag, other, i, j, s)
        # Debug output now respects ``verbose``.
        if verbose:
            print('extra validated')
    edges = list(self._directed) + list(self._bidirected)
    for i, j in random.sample(edges, len(edges)):
        new_mag = self.copy()
        new_mag.remove_edge(i, j)
        # The unused alternative conditioning sets (and the extra msep call they triggered)
        # have been dropped. Only this one decides the removal.
        ancestral_subgraph = new_mag.induced_subgraph(new_mag.ancestors_of({i, j}) | {i, j})
        conditioning_set = ancestral_subgraph.markov_blanket(j, flat=True) - {i, j}
        if other.msep(i, j, conditioning_set) and new_mag.is_maximal():
            if validate and not new_mag.is_imap(other):
                raise Exception("CI test isn't sufficient: new MAG is not an IMAP")
            return (False, (i, j)) if certify else False
    return (True, None) if certify else True
def markov_blanket(self, node, flat: bool = False) -> Union[Set[Node], Dict]:
    """
    Return the Markov blanket of a node with respect to the whole graph.

    Parameters
    ----------
    node:
        The node whose Markov blanket to find.
    flat:
        if ``True``, return the Markov blanket as a set, otherwise return a dictionary mapping nodes in the district
        of node to their parents.

    Returns
    -------
    If ``flat``: the set made of the district of ``node``, the parents of every district
    member, and ``node`` itself. Otherwise: a dict from each district member (``node``
    included) to a copy of its parent set.
    """
    district = self.district_of(node)
    if not flat:
        # Return copies so callers cannot mutate the graph's internal parent sets.
        return {d: set(self._parents[d]) for d in district}
    return district | set.union(*(self._parents[d] for d in district)) | {node}
def resolved_quasisinks(self, other):
    """
    Return the set of "resolved quasi-sinks" of this graph with respect to ``other``.

    The set is grown until it stops changing. A node is added once two things hold:

    - all of its children, in this graph and in ``other``, are already in the set;
    - its Markov blanket (in the dictionary form returned by ``markov_blanket``) is the same in both graphs.

    Parameters
    ----------
    other:
        The graph to compare against. It must provide ``_children`` and ``markov_blanket`` like this graph.

    Returns
    -------
    set
        The resolved nodes.
    """
    res_qsinks = set()
    while True:
        new_resolved = {
            node for node in self._nodes - res_qsinks
            if not (self._children[node] - res_qsinks)
            and not (other._children[node] - res_qsinks)
            # Query both graphs through the same method so both blankets are built the same way.
            and self.markov_blanket(node) == other.markov_blanket(node)
        }
        if not new_resolved:
            break
        res_qsinks.update(new_resolved)
    return res_qsinks
def is_maximal(self, new=True, verbose=False) -> bool:
    """
    Check whether this graph is maximal.

    A copy of the graph is passed through ``to_maximal``. The graph counts as maximal iff that copy still
    compares equal to this graph afterwards, i.e. no bidirected edge had to be added. This graph is not modified.

    Parameters
    ----------
    new:
        Forwarded to ``to_maximal``: ``True`` uses its inducing-path search, ``False`` its brute-force
        m-separation search.
    verbose:
        Forwarded to ``to_maximal``.

    Returns
    -------
    bool
        Whether the graph is maximal.
    """
    maximalized = self.copy()
    maximalized.to_maximal(verbose=verbose, new=new)
    is_unchanged = maximalized == self
    return is_unchanged
def to_maximal(self, new=True, verbose=False):
    """
    Make this graph maximal, in place, by adding bidirected edges between non-adjacent pairs.

    Parameters
    ----------
    new:
        If ``True`` (default), search for inducing paths. For each non-adjacent pair ``(node1, node2)``, look at
        every pair of children/spouses ``(nbr1, nbr2)`` of the two endpoints that satisfies all of:

        - ``nbr1`` and ``nbr2`` lie in the same c-component;
        - ``nbr1`` is an ancestor of ``node2`` and ``nbr2`` is an ancestor of ``node1``.

        For each such pair, call ``_reachable`` with ``allowed_edges={'b'}`` and an ancestry predicate
        (membership in the ancestors of either endpoint). If it succeeds, a bidirected edge is added between the
        endpoints. Rounds repeat until one adds no edge.
        If ``False``, use brute force instead: add a bidirected edge between every non-adjacent pair that no
        subset of the remaining nodes m-separates. This is exponential in the number of nodes.
    verbose:
        If ``True``, print the intermediate quantities of the ``new=True`` search.
    """
    if new:
        converged = False
        while not converged:
            # === NEED DICTIONARY OF ANCESTORS AND C-COMPONENTS TO CHECK INDUCING PATHS
            # Recomputed every round, since edges added in the previous round may change the c-components.
            ancestor_dict = self.ancestor_dict()
            c_components = self.c_components()
            node2component = dict()
            for ix, component in enumerate(c_components):
                for node in component:
                    node2component[node] = ix
            if verbose:
                print('==========')
                print('Ancestor dict:', ancestor_dict)
                print('C components', c_components)

            # === FIND INDUCING PATHS BETWEEN PAIRS OF NODES
            induced_pairs = []
            non_adjacent_pairs = (
                (i, j) for i, j in itr.combinations(self._nodes, 2) if not self.has_any_edge(i, j)
            )
            for node1, node2 in non_adjacent_pairs:
                check_ancestry = lambda node: node in ancestor_dict[node1] or node in ancestor_dict[node2]
                nbrs1 = self._children[node1] | self._spouses[node1]
                nbrs2 = self._children[node2] | self._spouses[node2]
                if verbose:
                    print(f"-------------\nChecking {node1} and {node2}")
                # ONLY CHECK PATHS BETWEEN SPOUSES/CHILDREN THAT ARE IN THE SAME C-COMPONENT
                for nbr1, nbr2 in itr.product(nbrs1, nbrs2):
                    same_component = node2component[nbr1] == node2component[nbr2]
                    if same_component and nbr1 in ancestor_dict[node2] and nbr2 in ancestor_dict[node1]:
                        if verbose:
                            print(f"Checking neighbors {nbr1} (for {node1}) and {nbr2} (for {node2})")
                        if self._reachable(nbr1, nbr2, visited=set(), allowed_edges={'b'}, predicate=check_ancestry,
                                           verbose=verbose):
                            if verbose:
                                print("Reachable")
                            induced_pairs.append((node1, node2))
                            # One inducing path is enough. Stop so the pair is recorded, and its bidirected
                            # edge added, only once.
                            break
                        elif verbose:
                            print("No path")
            if verbose:
                print(f"found induced pairs: {induced_pairs}")
            for node1, node2 in induced_pairs:
                self.add_bidirected(node1, node2)
            converged = not induced_pairs
    else:
        for i, j in itr.combinations(self._nodes, r=2):
            if not self.has_any_edge(i, j):
                # i and j must be joined if no conditioning set among the other nodes m-separates them.
                never_msep = not any(self.msep(i, j, S) for S in core_utils.powerset(self._nodes - {i, j}))
                if never_msep:
                    self.add_bidirected(i, j)
def to_pag(self):
    """
    Placeholder for converting this graph to a partial ancestral graph (PAG).

    Not implemented yet.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError()
# === CONVERTERS
def to_amat(self) -> np.ndarray:
    """
    Convert the graph into an adjacency matrix.

    Entry ``[i, j]`` holds the mark at node ``j`` on the edge between ``i`` and ``j``:

    - ``2`` for an arrowhead;
    - ``3`` for a tail;
    - ``0`` when ``i`` and ``j`` are not adjacent.

    So ``i -> j`` gives ``[i, j] = 2`` and ``[j, i] = 3``. ``i <-> j`` gives 2 in both entries, and ``i - j``
    gives 3 in both entries. This is the same encoding that ``from_amat`` reads.

    Nodes are used directly as row/column indices, so they are assumed to be the integers
    ``0, ..., nnodes - 1``.

    Returns
    -------
    amat
        A float array of shape ``(nnodes, nnodes)``.
    """
    amat = np.zeros([self.nnodes, self.nnodes])
    arrowhead, tail = 2, 3
    # (edges, mark written at the second endpoint, mark written at the first endpoint)
    edge_marks = (
        (self.directed, arrowhead, tail),
        (self.bidirected, arrowhead, arrowhead),
        (self.undirected, tail, tail),
    )
    for edges, mark_at_second, mark_at_first in edge_marks:
        for first, second in edges:
            amat[first, second] = mark_at_second
            amat[second, first] = mark_at_first
    return amat
@staticmethod
def from_amat(amat: np.ndarray):
    """
    Create a graph from an adjacency matrix.

    Entry ``amat[i, j]`` encodes the mark at node ``j`` on the edge between ``i`` and ``j``: ``2`` is an
    arrowhead and ``3`` is a tail. For each pair ``i < j``, the marks ``(amat[i, j], amat[j, i])`` map to an
    edge as follows:

    - ``(2, 3)`` gives ``i -> j``;
    - ``(3, 2)`` gives ``j -> i``;
    - ``(2, 2)`` gives ``i <-> j``;
    - ``(3, 3)`` gives the undirected edge ``i - j``.

    Any other combination (e.g. 0 for "no edge") produces no edge.

    Parameters
    ----------
    amat
        A square matrix. The nodes of the new graph are ``0, ..., amat.shape[0] - 1``.

    Returns
    -------
    AncestralGraph
        The graph described by ``amat``.
    """
    num_nodes = amat.shape[0]
    directed, bidirected, undirected = set(), set(), set()
    # (mark at j, mark at i) -> (edge set to add to, whether the edge is stored as (j, i) instead of (i, j))
    mark_patterns = {
        (2, 3): (directed, False),
        (3, 2): (directed, True),
        (2, 2): (bidirected, False),
        (3, 3): (undirected, False),
    }
    for i, j in itr.combinations(range(num_nodes), 2):
        pattern = mark_patterns.get((amat[i, j], amat[j, i]))
        if pattern is None:
            continue
        edge_set, flipped = pattern
        edge_set.add((j, i) if flipped else (i, j))
    return AncestralGraph(set(range(num_nodes)), directed, bidirected, undirected)
# === COMPARISON
def markov_equivalent(self, other) -> bool:
    """
    Check if this graph is Markov equivalent to the graph ``other``.

    Two graphs are Markov equivalent iff all three of these hold:

    - they have the same skeleton;
    - they have the same v-structures;
    - whenever both graphs have the same discriminating path for some node, that node is a collider on the
      path in one graph iff it is a collider on it in the other.

    The checks run from cheapest to most expensive and stop at the first mismatch. Discriminating paths are
    therefore only computed when skeletons and v-structures already agree.

    Parameters
    ----------
    other:
        Another AncestralGraph.

    Returns
    -------
    bool
        ``True`` iff the two graphs are Markov equivalent.
    """
    if self.skeleton != other.skeleton:
        return False
    if self.vstructures() != other.vstructures():
        return False
    self_disc_paths = self.discriminating_paths()
    other_disc_paths = other.discriminating_paths()
    shared_disc_paths = set(self_disc_paths) & set(other_disc_paths)
    # Only paths present in both graphs are compared; their values must agree.
    return all(self_disc_paths[path] == other_disc_paths[path] for path in shared_disc_paths)
def fast_markov_equivalent(self, other) -> bool:
    """
    Check Markov equivalence with Algorithm 1 of "Faster algorithms for Markov equivalence" (Hu and Evans, 2020).

    The two graphs are compared on three invariants in order: skeleton, v-structures, and discriminating
    triples. The comparison stops at the first one that differs.

    Parameters
    ----------
    other:
        Another AncestralGraph.

    Returns
    -------
    bool
        ``True`` iff all three invariants agree.
    """
    return (
        self.skeleton == other.skeleton
        and self.vstructures() == other.vstructures()
        and self.discriminating_triples() == other.discriminating_triples()
    )
def _tail_of(self, v: Node, w: Node, ancestor_dict: dict):
    """
    Return the "tail" used for the pair ``(v, w)`` in ``discriminating_triples``.

    The tail is built in three steps:

    1. Take the district of ``v`` computed by ``district_of(v, ancestors)``. Here ``ancestors`` is the
       ancestors of ``v`` and ``w`` plus ``v`` and ``w`` themselves; presumably the second argument restricts
       the district to that set.
    2. Add the parents of that district.
    3. Remove ``v`` and ``w``.

    Parameters
    ----------
    v, w:
        The endpoints of the pair.
    ancestor_dict:
        Maps each node to its set of ancestors (as produced by ``ancestor_dict()``).

    Returns
    -------
    set
        The tail, never containing ``v`` or ``w``.
    """
    ancestors = ancestor_dict[v] | ancestor_dict[w] | {v, w}
    district = self.district_of(v, ancestors)
    parents = self.parents_of(district)
    # Parenthesized explicitly: `parents | district - {v, w}` would parse as `parents | (district - {v, w})`
    # and keep v or w whenever they appear among the parents.
    return (parents | district) - {v, w}
def discriminating_triples(self, verbose=False):
"""
Return the discriminating triples of the graph, which are triples of nodes that determine the discriminating
paths.

For every bidirected edge v <-> w, a node z that is not adjacent to both v and w yields the triple {v, w, z}
in either of two cases:

- z is in the tail of (v, w), as computed by ``_tail_of``;
- z satisfies all of the following:
    - z is in ``spouses_of(a)`` and in the district of v;
    - z is neither in ``a`` (v, w and their ancestors) nor in ``d`` (v, w and their descendants);
    - z is in ``district_of(v, a | ancestors of z | {z})``.

The "LINES ..." comments appear to refer to Algorithm 1 of Hu and Evans (2020), "Faster algorithms for Markov
equivalence" (NOTE(review): confirm against the paper).

Parameters
----------
verbose:
    If ``True``, print each triple found and which rule produced it.

Returns
-------
set of frozenset
    The triples found above, each as a frozenset of nodes, together with ``frozenset(vstruct)`` for every
    v-structure returned by ``vstructures()``.
"""
d_triples = set()
ancestor_dict = self.ancestor_dict()
desc_dict = self.descendant_dict()
for v, w in self._bidirected:
tail_vw = self._tail_of(v, w, ancestor_dict)
# LINES 10-12
for z in tail_vw:
# keep z only if it is not adjacent to both v and w
if not (self.has_any_edge(v, z) and self.has_any_edge(w, z)):
if verbose: print(f"{z} in tail of ({v, w})")
d_triples.add(frozenset({v, w, z}))
# LINES 13-17
# a: v, w and their ancestors; d: v, w and their descendants
a = ancestor_dict[v] | ancestor_dict[w] | {v, w}
d = desc_dict[v] | desc_dict[w] | {v, w}
# `-` binds tighter than `&`: this is spouses_of(a) & (district_of(v) - (a | d))
for z in self.spouses_of(a) & self.district_of(v) - (a | d):
if not (self.has_any_edge(v, z) and self.has_any_edge(w, z)):
# presumably the district of v restricted to a, z and z's ancestors -- confirm district_of's 2nd argument
dis = self.district_of(v, a | ancestor_dict[z] | {z})
if z in dis:
if verbose: print(f"{z} in district of {v} restricted to spouses, district")
d_triples.add(frozenset({v, w, z}))
return d_triples | {frozenset(vstruct) for vstruct in self.vstructures()}
def get_all_mec(self):
    """
    Explore all graphs reachable from this one by legitimate mark changes.

    The search starts from this graph and repeatedly applies every change returned by
    ``legitimate_mark_changes`` until no unseen graph remains. That method returns two kinds of change:

    - directed ``i -> j`` becomes bidirected ``i <-> j``;
    - bidirected ``i <-> j`` becomes directed ``i -> j``.

    Each graph is identified by its directed edges and its (unordered) bidirected edges, and is returned only
    once. This is intended to enumerate the Markov equivalence class of the graph.

    Returns
    -------
    list
        The graphs found. The first element is this graph itself (not a copy).
    """
    visited = set()
    stack = [self]
    mags = []
    while stack:
        mag = stack.pop()
        curr_dir = frozenset(mag._directed)
        curr_bidir = frozenset(frozenset(edge) for edge in mag._bidirected)
        # A graph can be pushed several times before it is first expanded (e.g. reached via two different
        # neighbors), so it is expanded and reported only the first time it is popped.
        if (curr_dir, curr_bidir) in visited:
            continue
        visited.add((curr_dir, curr_bidir))
        mags.append(mag)

        lmcs_dir, lmcs_bidir = mag.legitimate_mark_changes()
        # directed i -> j becomes bidirected i <-> j
        for i, j in lmcs_dir:
            new_dir = curr_dir - {(i, j)}
            new_bidir = curr_bidir | {frozenset({i, j})}
            if (new_dir, new_bidir) not in visited:
                new_mag = mag.copy()
                new_mag.remove_directed(i, j)
                new_mag.add_bidirected(i, j)
                stack.append(new_mag)
        # bidirected i <-> j becomes directed i -> j
        for i, j in lmcs_bidir:
            new_dir = curr_dir | {(i, j)}
            new_bidir = curr_bidir - {frozenset({i, j})}
            if (new_dir, new_bidir) not in visited:
                new_mag = mag.copy()
                new_mag.remove_bidirected(i, j)
                new_mag.add_directed(i, j)
                stack.append(new_mag)
    return mags
def shd_skeleton(self, other) -> int:
    """
    Compute the structural Hamming distance between the skeleton of this graph and the skeleton of ``other``.

    This is the size of the symmetric difference of the two ``skeleton`` sets. If ``skeleton`` is the set of
    adjacent node pairs, it counts the pairs adjacent in exactly one graph, i.e. the number of edge additions
    and deletions needed to turn one skeleton into the other. Edge orientation plays no role.

    Parameters
    ----------
    other:
        The graph whose skeleton is compared with this one.

    Returns
    -------
    int
        The structural Hamming distance between the two skeletons.
    """
    differing_adjacencies = self.skeleton.symmetric_difference(other.skeleton)
    return len(differing_adjacencies)
def as_hashed(self):
    """
    Return a hashable snapshot of the edges of this graph.

    Returns
    -------
    tuple
        ``(directed, bidirected, undirected)``, each a frozenset copy of the corresponding edge set.
    """
    edge_sets = (self._directed, self._bidirected, self._undirected)
    return tuple(frozenset(edges) for edges in edge_sets)
# === Algorithms
def _add_upstream(self, upstream: set, node: Node):
    """
    Add the ancestors of ``node`` to ``upstream``, in place.

    Parents are followed transitively. A node that is already in ``upstream`` is neither re-added nor expanded
    further, so its own ancestors are only added if they are reached some other way.

    The traversal uses an explicit stack instead of recursion, so long chains of ancestors cannot exceed
    Python's recursion limit.

    Parameters
    ----------
    upstream:
        The set to extend.
    node:
        The node whose ancestors to add.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        for parent in self._parents[current]:
            if parent not in upstream:
                upstream.add(parent)
                stack.append(parent)
def _is_collider(self, u: Node, v: Node, w: Node) -> bool:
    """
    Return ``True`` iff ``v`` is a collider on the path ``u - v - w``.

    That is the case when both edges point into ``v``. Each edge must be either ``x -> v`` (``v`` is a child of
    ``x``) or ``x <-> v`` (``v`` is a spouse of ``x``), for both ``x = u`` and ``x = w``.
    """
    def _points_into_v(endpoint):
        # arrowhead at v on the edge between endpoint and v
        return v in self._children[endpoint] or v in self._spouses[endpoint]

    return _points_into_v(u) and _points_into_v(w)
def _no_other_path(self, i: Node, j: Node, ancestor_dict: dict) -> bool:
    """
    Return ``True`` iff no child of ``i`` (other than ``i`` itself) is in ``ancestor_dict[j]``.

    Assuming ``ancestor_dict[j]`` holds the proper ancestors of ``j``, this means there is no directed path
    from ``i`` to ``j`` other than, possibly, the direct edge ``i -> j``. Any such path would have to leave
    ``i`` through a child that is itself an ancestor of ``j``.
    """
    return self._children[i].isdisjoint(ancestor_dict[j] - {i})
def legitimate_mark_changes(self, verbose=False, strict=True):
"""
Return directed edges that can be changed to bidirected edges, and bidirected edges that can be changed to
directed edges.

A directed edge i -> j qualifies if all of the following hold:

- every parent of i is a parent of j;
- every spouse of i is a parent or a spouse of j;
- there is no directed path from i to j other than the edge itself (``_no_other_path``).

A bidirected edge i <-> j qualifies as the directed edge i -> j under the same conditions, except that j is
ignored among the spouses of i. Both orientations of every bidirected edge are tried.

When ``strict`` is ``True``, an edge is additionally rejected if ``discriminating_paths()`` contains a path
whose last two nodes are i and j.

Parameters
----------
verbose:
    If True, print each possible mark change and which condition it fails, if any. Only used when
    ``strict`` is True.
strict:
    If True, also check the discriminating path condition. If False, skip it; in that mode the condition on
    other directed paths is only checked for directed edges.

Raises
------
ValueError
    If the graph has undirected edges.

Return
------
(mark_changes_dir, mark_changes_bidir)
    Directed edges that can be changed to bidirected edges, and bidirected edges that can be changed to directed
    edges (which will be the new directed edge).

Example
-------
>>> import causaldag as cd
>>> g = cd.AncestralGraph(directed={(0, 1)}, bidirected={(1, 2)})
>>> g.legitimate_mark_changes()
({(0, 1)}, {(2, 1)})
"""
if self._undirected:
raise ValueError('Only defined for DMAGs')
if not strict:
# print("TODO: CHECK")
ancestor_dict = self.ancestor_dict()
# i -> j: pa(i) within pa(j), sp(i) within pa(j) | sp(j), and no other directed path from i to j
mark_changes_dir = {
(i, j) for i, j in self._directed
if self._parents[i] - self._parents[j] == set() and
self._spouses[i] - self._parents[j] - self._spouses[j] == set()
and self._no_other_path(i, j, ancestor_dict)
}
# try each bidirected edge in both orientations
bidirected = [tuple(e) for e in self._bidirected]
bidirected_reversed = [tuple(reversed(e)) for e in self._bidirected]
mark_changes_bidir = {
(i, j) for i, j in bidirected + bidirected_reversed
if self._parents[i] - self._parents[j] == set() and
self._spouses[i] - {j} - self._parents[j] - self._spouses[j] == set()
}
return mark_changes_dir, mark_changes_bidir
if strict:
disc_paths = self.discriminating_paths()
ancestor_dict = self.ancestor_dict()
mark_changes_dir = set()
for i, j in self._directed:
if verbose: print(f'{i}->{j} => {i}<->{j} ?')
# every parent of i must also be a parent of j
parents_condition = self._parents[i] - self._parents[j]
if parents_condition != set():
if verbose: print(f'Failed parents condition on {parents_condition}')
continue
# every spouse of i must be a spouse or a parent of j
spouses_condition = self._spouses[i] - self._spouses[j] - self._parents[j]
if spouses_condition != set():
if verbose: print(f'Failed spouses condition on {spouses_condition}')
continue
# no directed path from i to j other than the edge i -> j itself
ancestral_condition = self._no_other_path(i, j, ancestor_dict)
# ancestral_condition2 = i not in self.ancestors_of(j, exclude_arcs={(i, j)})
# print(ancestral_condition == ancestral_condition2)
# if ancestral_condition != ancestral_condition2:
# print(self, i, j, (ancestor_dict[j] - {i}) & self._children[i], ancestor_dict[j], self._children[i])
if not ancestral_condition:
if verbose: print(f'Failed ancestral condition')
continue
# SECOND CONDITION
# reject if some discriminating path ends with the two nodes i, j
disc_paths_for_i = [path for path in disc_paths.keys() if path[-2] == i]
disc_paths_condition = next((path for path in disc_paths_for_i if path[-1] == j),
None) if disc_paths_for_i else None
if disc_paths_condition is not None:
if verbose: print(f'Failed discriminating path condition on {disc_paths_condition}')
continue
if verbose: print('Passed')
mark_changes_dir.add((i, j))
mark_changes_bidir = set()
forward_edges = {(i, j) for i, j in self._bidirected}
# map(reversed, ...) yields reversed iterators, each unpacked below as (j, i), so both orientations are tried
for i, j in forward_edges | set(map(reversed, forward_edges)):
if verbose: print(f'{i}<->{j} => {i}->{j} ?')
parents_condition = self._parents[i] - self._parents[j]
if parents_condition != set():
if verbose: print(f'Failed parents condition on {parents_condition}')
continue
# j itself is ignored among the spouses of i, since i <-> j is the edge being changed
spouses_condition = self._spouses[i] - {j} - self._spouses[j] - self._parents[j]
if spouses_condition != set():
if verbose: print(f'Failed spouses condition on {spouses_condition}')
continue
ancestral_condition = self._no_other_path(i, j, ancestor_dict)
if not ancestral_condition:
if verbose: print('failed ancestral condition')
continue
# SECOND CONDITION
disc_paths_for_i = [path for path in disc_paths.keys() if path[-2] == i]
disc_paths_condition = next((path for path in disc_paths_for_i if path[-1] == j),
None) if disc_paths_for_i else None
if disc_paths_condition is not None:
if verbose: print(f'Failed discriminating path condition on {disc_paths_condition}')
continue
if verbose: print('Passed')
mark_changes_bidir.add((i, j))
return mark_changes_dir, mark_changes_bidir
def msep(self, A: Set[Node], B: Set[Node], C: Set[Node] = set()) -> bool:
    """
    Check whether ``A`` and ``B`` are m-separated given ``C``, using the Bayes ball algorithm.

    A "ball" starts at every node of ``A`` and moves along edges according to the m-connection rules. ``A``
    and ``B`` are m-separated iff no node of ``B`` is ever reached.

    Parameters
    ----------
    A:
        A node or a set of nodes (coerced with ``core_utils.to_set``).
    B:
        A node or a set of nodes (coerced with ``core_utils.to_set``).
    C:
        The conditioning node or set of nodes. Defaults to the empty set; it is only read, never modified.

    Returns
    -------
    bool
        ``True`` if ``A`` and ``B`` are m-separated given ``C``, ``False`` otherwise.

    See Also
    --------
    msep_from_given
    """
    # type coercion
    A = core_utils.to_set(A)
    B = core_utils.to_set(B)
    C = core_utils.to_set(C)

    # "Shaded" nodes are C and the ancestors of C: a collider only lets the ball through if it is shaded.
    shaded_nodes = set(C)
    for node in C:
        self._add_upstream(shaded_nodes, node)

    # Each state is (node, mark at `node` on the edge the ball arrived through).
    _t = 'tail'
    _a = 'arrowhead'
    visited = set()
    schedule = {(node, _t) for node in A}
    while schedule:
        node, _dir = schedule.pop()
        if node in B:
            return False
        if (node, _dir) in visited:
            continue
        visited.add((node, _dir))

        if _dir == _t and node not in C:
            # Arrived through a tail: node cannot be a collider, so pass on along every edge.
            schedule.update({(parent, _t) for parent in self._parents[node]})
            schedule.update({(child, _a) for child in self._children[node]})
            schedule.update({(spouse, _a) for spouse in self._spouses[node]})
            schedule.update({(nbr, _t) for nbr in self._neighbors[node]})
        if _dir == _a:
            if node in shaded_nodes:
                # Collider (arrowhead in, arrowhead out) that is shaded: passable.
                schedule.update({(parent, _t) for parent in self._parents[node]})
                schedule.update({(spouse, _a) for spouse in self._spouses[node]})
            if node not in C:
                # Non-collider (arrowhead in, tail out) that is not conditioned on: passable.
                schedule.update({(child, _a) for child in self._children[node]})
                # An undirected edge has a tail at the neighbor too, so the ball arrives there via a tail.
                schedule.update({(nbr, _t) for nbr in self._neighbors[node]})
    return True
def msep_from_given(self, A: Set[Node], C: Set[Node] = set()) -> Set[Node]:
    """
    Find all nodes m-separated from ``A`` given ``C``.

    Uses an algorithm similar to that in Geiger, D., Verma, T., & Pearl, J. (1990). Identifying independence in
    Bayesian networks. Networks, 20(5), 507-534.

    How it works:

    1. Start from a virtual link ``(None, a)`` into each node ``a`` of ``A``.
    2. Label links ``(v, w)`` in rounds. Each new link is reachable from a labeled link ``(u, v)`` that forms a
       legal pair with it: if ``v`` is a collider on ``u - v - w``, then ``v`` must be in ``C`` or be an
       ancestor of ``C``; otherwise ``v`` must not be in ``C``.
    3. Every endpoint ``w`` reached this way is m-connected to ``A``.

    Parameters
    ----------
    A:
        A node or set of nodes (coerced with ``core_utils.to_set``).
    C:
        The conditioning node or set of nodes. Defaults to the empty set; it is only read, never modified.

    Returns
    -------
    set
        The nodes outside ``A`` and ``C`` that are not reachable from ``A`` given ``C``.

    See Also
    --------
    msep
    """
    warn_untested()
    A = core_utils.to_set(A)
    C = core_utils.to_set(C)

    determined = set()
    # C together with its ancestors, i.e. the nodes having a descendant in C (named as in Geiger et al.)
    descendants = set()
    for c in C:
        determined.add(c)
        descendants.add(c)
        self._add_upstream(descendants, c)

    reachable = set(A)
    i_links = {(None, a) for a in A}
    labeled_links = set()
    while True:
        i_p_1_links = set()
        # Find all unlabeled links v->w adjacent to at least one link u->v labeled i, such that (u->v, v->w)
        # is a legal pair.
        for u, v in i_links:
            for w in self._adjacent[v]:
                if u == w or (v, w) in labeled_links:
                    continue
                # The virtual start link (None, v) has no edge mark at v, so it never forms a collider.
                # Checking this explicitly keeps `None` from being looked up in the adjacency dictionaries
                # inside _is_collider.
                if u is not None and self._is_collider(u, v, w):
                    legal = v in descendants
                else:
                    legal = v not in determined
                if legal:
                    i_p_1_links.add((v, w))
                    reachable.add(w)
        if not i_p_1_links:
            break
        labeled_links |= i_links
        i_links = i_p_1_links
    return self._nodes.difference(A).difference(C).difference(reachable)
if __name__ == '__main__':
# Ad-hoc demo: build a 4-node graph with only directed edges and compute its discriminating paths.
g = AncestralGraph(nodes=set(range(1, 5)), directed={(1, 2), (2, 4), (3, 2), (3, 4)})
disc_paths = g.discriminating_paths()