diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index f0011c655..d878c4765 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -19,6 +19,7 @@ "dag_to_mag", "is_maximal", "all_vstructures", + "check_visibility" ] @@ -826,7 +827,6 @@ def is_maximal(G, L: Optional[Set] = None, S: Optional[Set] = None): continue return True - def all_vstructures(G: nx.DiGraph, as_edges: bool = False): """Generate all v-structures in the graph. @@ -855,3 +855,58 @@ def all_vstructures(G: nx.DiGraph, as_edges: bool = False): else: vstructs.add((p1, node, p2)) # type: ignore return vstructs + +def get_all_collider_paths(G : PAG, X, Y): + + out = [] + + # find all the possible paths from X to Y with only bi-directed edges + + bidirected_edge_graph = G.sub_bidirected_graph + + X_descendants = set(G.sub_directed_graph.neigbors(X)) + + candidate_collider_path_nodes = set(bidirected_edge_graph.nodes).intersection(X_descendants) + + if candidate_collider_path_nodes is None: + return out + + for elem in candidate_collider_path_nodes: + out.extend(nx.all_simple_paths(G, elem, Y)) + + # for path in out: + # path.insert(0,X) + + return out + +def check_visibility(G: PAG, X: str, Y: str): + + X_neighbors = set(G.neighbors(X)) + Y_neighbors = set(G.neighbors(Y)) + + only_x_neighbors = X_neighbors - Y_neighbors + + + for elem in only_x_neighbors: + if G.has_edge(elem, X, G.directed_edge_name): + return True + + all_nodes = set(G.nodes) + + all_nodes.remove(X) + + + candidates = all_nodes - Y_neighbors + + for elem in candidates: + collider_paths = get_all_collider_paths(G,elem,X) + for path in collider_paths: + for node in path: + if node in G.neighbors(Y): + continue + else: + return True + + return False + + diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py index 09218a334..80b70eb6d 100644 --- a/pywhy_graphs/algorithms/tests/test_generic.py +++ b/pywhy_graphs/algorithms/tests/test_generic.py @@ -2,8 +2,8 @@ import pytest import pywhy_graphs -from pywhy_graphs import ADMG -from pywhy_graphs.algorithms import all_vstructures +from pywhy_graphs import ADMG, PAG +from pywhy_graphs.algorithms import all_vstructures, check_visibility def test_convert_to_latent_confounder_errors(): @@ -496,3 +496,32 @@ def test_all_vstructures(): # Assert that the returned values are as expected assert len(v_structs_edges) == 0 assert len(v_structs_tuples) == 0 + + + +def test_check_visibility(): + + # H <-> K <-> Z <-> X <- Y + + pag = PAG() + pag.add_edge("Y", "X", pag.directed_edge_name) + pag.add_edge("Z", "X", pag.bidirected_edge_name) + pag.add_edge("Z", "K", pag.bidirected_edge_name) + pag.add_edge("K", "H", pag.bidirected_edge_name) + + assert True == check_visibility(pag, "X", "Y") + + pag = PAG() + pag.add_edge("Y", "X", pag.directed_edge_name) + pag.add_edge("Z", "X", pag.bidirected_edge_name) + + assert True == check_visibility(pag, "X", "Y") + + pag = PAG() + pag.add_edge("Y", "X", pag.directed_edge_name) + pag.add_edge("Z", "Y", pag.bidirected_edge_name) + pag.add_edge("Z", "K", pag.bidirected_edge_name) + pag.add_edge("K", "H", pag.bidirected_edge_name) + + assert False == check_visibility(pag, "X", "Y") +