Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft]: Calculate local markov blanket against true graph for parted nodes that passed AD Test #1757

Draft
wants to merge 6 commits into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
* @author bryanandrews, josephramsey
* @version $Id: $Id
*/
public class OrientationPrecision implements Statistic {
public class OrientationPrecision implements Statistic { // TODO VBC: is this one we want to use?
@Serial
private static final long serialVersionUID = 23L;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* compared to the true graph. It calculates the ratio of true positive orientations to the sum of true positive and
* false negative orientations.
*/
public class OrientationRecall implements Statistic {
public class OrientationRecall implements Statistic { // TODO VBC: use this?
@Serial
private static final long serialVersionUID = 23L;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class OrientationConfusion {
* @param truth a {@link edu.cmu.tetrad.graph.Graph} object
* @param est a {@link edu.cmu.tetrad.graph.Graph} object
*/
public OrientationConfusion(Graph truth, Graph est) {
public OrientationConfusion(Graph truth, Graph est) { // TODO VBC: is this one we want to use?
this.tp = 0;
this.fp = 0;
this.fn = 0;
Expand Down
4 changes: 2 additions & 2 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ public final String toString() {
_type = new StringBuilder("no edge");
break;
case ta:
_type = new StringBuilder("-->");
_type = new StringBuilder("-->"); // TODO VBC: or shall we just compare edge by strings
break;
case at:
_type = new StringBuilder("<--");
Expand Down Expand Up @@ -439,7 +439,7 @@ public final boolean equals(Object o) {
* @param _edge a {@link edu.cmu.tetrad.graph.Edge} object
* @return a int
*/
public int compareTo(Edge _edge) {
public int compareTo(Edge _edge) { // TODO VBC: seems only comparing the edpoint not the direction?
int comp1 = getNode1().compareTo(_edge.getNode1());

if (comp1 != 0) {
Expand Down
33 changes: 23 additions & 10 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -612,14 +612,15 @@ public List<Node> getAdjacentNodes(Node node) {
Set<Edge> edges = this.edgeLists.get(node);
Set<Node> adj = new HashSet<>();

for (Edge edge : edges) {
if (edge == null) {
continue;
}
if (edges != null) {
for (Edge edge : edges) {
if (edge == null) {
continue;
}

adj.add(edge.getDistalNode(node));
adj.add(edge.getDistalNode(node));
}
}

return new ArrayList<>(adj);
}

Expand Down Expand Up @@ -741,13 +742,25 @@ public boolean addEdge(Edge edge) {
// Someone may have changed the name of one of these variables, in which
// case we need to reconstitute the edgeLists map, since the name of a
// node is used as part of the definition of node equality.
if (!edgeLists.containsKey(edge.getNode1()) || !edgeLists.containsKey(edge.getNode2())) {
Node node1 = edge.getNode1();
Node node2 = edge.getNode2();
// System.out.println("Real Before: " + edgeLists);
if (!edgeLists.containsKey(node1) || !edgeLists.containsKey(node2)) {
this.edgeLists = new HashMap<>(this.edgeLists);
}

this.edgeLists.get(edge.getNode1()).add(edge);
this.edgeLists.get(edge.getNode2()).add(edge);
// System.out.println("Before Adding: " + edgeLists);
if (this.edgeLists.get(node1) == null ) {
// System.out.println("Missing node1 is not in edgeLists: " + node1);
this.edgeLists.put(node1, new HashSet<>());
}
if (this.edgeLists.get(node2) == null ) {
// System.out.println("Missing node2 is not in edgeLists: " + node2);
this.edgeLists.put(node2, new HashSet<>());
}
this.edgeLists.get(node1).add(edge);
this.edgeLists.get(node2).add(edge);
this.edgesSet.add(edge);
// System.out.println("After: " + edgeLists);

this.parentsHash.remove(edge.getNode1());
this.parentsHash.remove(edge.getNode2());
Expand Down
8 changes: 5 additions & 3 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public static boolean isClique(Collection<Node> set, Graph graph) {
* @param graph a DAG, CPDAG, MAG, or PAG.
* @return a {@link edu.cmu.tetrad.graph.Graph} object
*/
public static Graph markovBlanketSubgraph(Node target, Graph graph) {
public static Graph markovBlanketSubgraph(Node target, Graph graph) { // TODO VBC: @Joe is this the more general method you recommended?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jdramsey for confirmation :D

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the method I had in mind:

/**
 * Returns the subgraph of {@code graph} induced by the Markov blanket of {@code targetNode}.
 *
 * @param graph      the graph in which the Markov blanket is looked up
 * @param targetNode the node whose Markov blanket is wanted
 * @return the subgraph (of a copy of {@code graph}) over the nodes returned by
 *         {@code GraphUtils.markovBlanket}
 */
public Graph getMarkovBlanketSubgraph(Graph graph, Node targetNode) {
    // Work on a copy of the supplied graph. A freshly constructed, empty EdgeListGraph
    // contains neither the target node nor any edges, so no blanket could be found in it.
    EdgeListGraph g = new EdgeListGraph(graph);
    Set<Node> nodes = GraphUtils.markovBlanket(targetNode, g);
    return g.subgraph(new ArrayList<>(nodes));
}

You could put that method somewhere helpful; it may not already exist in the code.

Set<Node> mb = markovBlanket(target, graph);

Graph mbGraph = new EdgeListGraph();
Expand All @@ -117,6 +117,8 @@ public static Graph markovBlanketSubgraph(Node target, Graph graph) {

for (int i = 0; i < mbList.size(); i++) {
for (int j = i + 1; j < mbList.size(); j++) {
List<Edge> edges = graph.getEdges(mbList.get(i), mbList.get(j));
// System.out.println("Add edges between!!!! " + mbList.get(i) + " " + mbList.get(j));
for (Edge e : graph.getEdges(mbList.get(i), mbList.get(j))) {
mbGraph.addEdge(e);
}
Expand Down Expand Up @@ -2253,7 +2255,7 @@ public static Graph trimGraph(List<Node> targets, Graph graph, int trimmingStyle
graph = trimAdjacentToTarget(targets, graph);
break;
case 3:
graph = trimMarkovBlanketGraph(targets, graph);
graph = trimMarkovBlanketGraph(targets, graph); // TODO VBC currently using this
break;
case 4:
graph = trimSemidirected(targets, graph);
Expand Down Expand Up @@ -2298,7 +2300,7 @@ private static Graph trimAdjacentToTarget(List<Node> targets, Graph graph) {
* @param graph the original graph from which the Markov blanket graph is derived
* @return the trimmed Markov blanket graph
*/
private static Graph trimMarkovBlanketGraph(List<Node> targets, Graph graph) {
private static Graph trimMarkovBlanketGraph(List<Node> targets, Graph graph) { // TODO vbc this is
Graph mbDag = new EdgeListGraph(graph);

M:
Expand Down
118 changes: 114 additions & 4 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

import edu.cmu.tetrad.data.GeneralAndersonDarlingTest;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.test.*;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.TetradLogger;
Expand Down Expand Up @@ -251,6 +248,119 @@ public Double checkAgainstAndersonDarlingTest(List<Double> pValues) {
return generalAndersonDarlingTest.getP();
}

/**
 * Splits the nodes of {@code graph} into two groups according to a per-node Anderson-Darling check.
 * For each node, the local independence facts are collected, their p-values are computed with
 * {@code independenceTest}, and those p-values are passed to
 * {@link #checkAgainstAndersonDarlingTest(List)}.
 *
 * @param independenceTest the test used to compute the local p-values
 * @param graph            the graph whose nodes are classified
 * @param threshold        a node is rejected when its Anderson-Darling p-value is {@code <= threshold};
 *                         callers typically pass 0.05
 * @return a two-element list: index 0 holds the accepted nodes, index 1 the rejected nodes
 */
public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(IndependenceTest independenceTest, Graph graph, Double threshold) {
    List<Node> accepted = new ArrayList<>();
    List<Node> rejected = new ArrayList<>();

    for (Node node : graph.getNodes()) {
        List<IndependenceFact> facts = getLocalIndependenceFacts(node);
        List<Double> pValues = getLocalPValues(independenceTest, facts);
        Double andersonDarlingP = checkAgainstAndersonDarlingTest(pValues);

        // Anything not at or below the threshold counts as accepted.
        List<Node> bucket = andersonDarlingP <= threshold ? rejected : accepted;
        bucket.add(node);
    }

    List<List<Node>> acceptsThenRejects = new ArrayList<>();
    acceptsThenRejects.add(accepted);
    acceptsThenRejects.add(rejected);
    return acceptsThenRejects;
}

/**
 * Builds the subgraph over the Markov blanket of {@code targetNode} together with the target itself.
 * The blanket is computed on, and the subgraph taken from, a copy of {@code graph}.
 *
 * @param graph      the graph to copy and query
 * @param targetNode the node whose Markov blanket subgraph is wanted
 * @return the subgraph of the copy over the blanket nodes plus {@code targetNode}
 */
private Graph getMarkovBlanketSubgraph(Graph graph, Node targetNode) {
    EdgeListGraph copy = new EdgeListGraph(graph);

    Set<Node> blanketWithTarget = GraphUtils.markovBlanket(targetNode, copy);
    blanketWithTarget.add(targetNode);

    List<Node> subgraphNodes = new ArrayList<>(blanketWithTarget);
    return copy.subgraph(subgraphNodes);
}



public Double getPrecisionOrRecallOnMarkovBlanketGraph(Node x, Graph estimatedGraph, Graph trueGraph, boolean getPrecision) {
List<Node> singleNode = Arrays.asList(x);
// Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes.
Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes());
System.out.println("@@@@@@@@@@@@@@@@");
System.out.println("Node: " + x);
System.out.println("True Graph:" + trueGraph);
System.out.println("LookupGraph:" + lookupGraph); // print should look the same as the true graph

// TODO VBC: The Trim method is the most accurate in terms of all nodes and the edges
Graph RecommendedxMBLookupGraph = getMarkovBlanketSubgraph(lookupGraph, x); // Recommended, not working
Graph xMBLookupGraph = GraphUtils.markovBlanketSubgraph(x, lookupGraph); // TODO VBC: this one should include the target node
Graph TrimxMBLookupGraph = GraphUtils.trimGraph(singleNode, lookupGraph, 3); // Best
Set<Edge> xMBLookupGraphEdges = xMBLookupGraph.getEdges();

System.out.println("xMBLookupGraphEdges size: " + xMBLookupGraphEdges.size());
System.out.println("xMBLookupGraph Nodes size: " + xMBLookupGraph.getNodes().size());
System.out.println("xMBLookupGraph:" + xMBLookupGraph); // The MB trim of the lookup graph, so it should be a subset of the lookup graph print
System.out.println("RecommendedxMBLookupGraph:" + RecommendedxMBLookupGraph); // The MB trim of the lookup graph, so it should be a subset of the lookup graph print
System.out.println("TrimxMBLookupGraph:" + TrimxMBLookupGraph); // The MB trim of the lookup graph, so it should be a subset of the lookup graph print

// Get Markov Blanket Subgraph for this node x.
// Graph xMBEstimatedGraph = getMarkovBlanketSubgraph(estimatedGraph, x);
Graph xMBEstimatedGraph = GraphUtils.markovBlanketSubgraph(x, estimatedGraph);
// Graph xMBEstimatedGraph = GraphUtils.trimGraph(singleNode, estimatedGraph, 3);
Set<Edge> xMBEstimatedGraphEdges = xMBEstimatedGraph.getEdges();
System.out.println("xMBEstimatedGraphEdges size: " + xMBEstimatedGraphEdges.size());
System.out.println("xMBEstimatedGraph Nodes size: " + xMBEstimatedGraph.getNodes().size());
System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph); // This should be compared with the xMBLookupGraph
System.out.println("@@@@@@@@@@@@@@@@");

HashSet<Edge> truePositive = new HashSet<>();
HashSet<Edge> falsePositive = new HashSet<>();
HashSet<Edge> falseNegative = new HashSet<>();
Set<Edge> trueGraphEdgesEdges = trueGraph.getEdges();
Set<Edge> estimatedGraphEdgesEdges = estimatedGraph.getEdges();
if (trueGraphEdgesEdges != null && estimatedGraphEdgesEdges != null) {
for (Edge te: trueGraphEdgesEdges) {
for (Edge ee: estimatedGraphEdgesEdges) {
// True Graph's Edge info
Node teNode1 = te.getNode1();
Node teNode2 = te.getNode1();
Endpoint teEndpoint1 = te.getEndpoint1();
Endpoint teEndpoint2 = te.getEndpoint2();
// Estimated Graph's Edge info
Node eeNode1 = te.getNode1();
Node eeNode2 = te.getNode1();
Endpoint eeEndpoint1 = ee.getEndpoint1();
Endpoint eeEndpoint2 = ee.getEndpoint2();
boolean isSameNode1 = areSame(teNode1, eeNode1);
boolean isSameNode2 = areSame(teNode2, eeNode2);

// EdgeTypeProbability.EdgeType teType = te.getEdgeTypeProbabilities().getFirst().getEdgeType();

// If both n1 n2 are the same, compare the endpoint1 endpoint2
if (isSameNode1 && isSameNode2) {
// if (teEndpoint1.compareTo(eeEndpoint1))
// QUESTION: // TODO VBC: seems Edge#compareTo() only comparing the node itself not the endpoint?
// QUESTION: do we only care about edge type here?

}



}
}
}
// TODO VBC:
// Logic of comparing true graph with estimated graph

double precision = (double) truePositive.size() / (truePositive.size() + falsePositive.size());
double recall = (double) truePositive.size() / (truePositive.size() + falseNegative.size());
return getPrecision ? precision : recall;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These outprints seems correct to me, maybe you want a second check : )

}

/**
 * Decides whether two nodes denote the same variable by comparing their names.
 *
 * @param n1 the first node
 * @param n2 the second node
 * @return {@code true} when {@code n1.getName()} equals {@code n2.getName()}
 */
private boolean areSame(Node n1, Node n2) {
    // TODO VBC: confirm that name equality is the right notion of "same node" here.
    // QUESTION: Node#compareTo() is much more involved (lags etc.); should that be used instead?
    String firstName = n1.getName();
    String secondName = n2.getName();
    return firstName.equals(secondName);
}


/**
* Returns the variables of the independence test.
Expand Down
42 changes: 42 additions & 0 deletions tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.Parameters;
import org.junit.Test;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -110,4 +112,44 @@ public void test2() {

System.out.println(markovCheck.getMarkovCheckRecordString());
}

/**
 * Simulates data from a random DAG, estimates a graph with BOSS, splits the estimated graph's nodes
 * with the Anderson-Darling check, and prints the Markov blanket precision and recall for every
 * accepted node. The test only prints; it makes no assertions.
 */
@Test
public void testPrecissionRecallForLocal() {
    // TODO: also generate a random graph, convert it to a CPDAG (see the graph utils tests) and
    // write a separate test case for it.
    Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
    System.out.println("Test True Graph: " + trueGraph);
    System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

    // Simulate data from a SEM parameterization of the true graph.
    DataSet data = new SemIm(new SemPm(trueGraph), new Parameters()).simulateData(1000, false);

    SemBicScore score = new SemBicScore(data, false);
    score.setPenaltyDiscount(2);
    Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search(); // Estimated graph
    System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
    System.out.println("Test Estimated Graph size: " + estimatedCpdag.getNodes().size());
    System.out.println("=====================================");

    IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
    MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);

    // Index 0 holds the accepted nodes, index 1 the rejected ones.
    List<List<Node>> acceptsAndRejects =
            markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
    List<Node> accepts = acceptsAndRejects.get(0);
    List<Node> rejects = acceptsAndRejects.get(1);
    System.out.println("Accepts size: " + accepts.size());
    System.out.println("Rejects size: " + rejects.size());

    List<Double> acceptsPrecision = new ArrayList<>();
    List<Double> acceptsRecall = new ArrayList<>();
    for (Node accepted : accepts) {
        acceptsPrecision.add(markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(accepted, estimatedCpdag, trueGraph, true));
        acceptsRecall.add(markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(accepted, estimatedCpdag, trueGraph, false));
    }
    System.out.println("Accepts Precisions: " + acceptsPrecision);
    System.out.println("Accepts Recall: " + acceptsRecall);
    System.out.println("****************************************************");
}
}