Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft]: Calculate local markov blanket against true graph for parted nodes that passed AD Test #1757

Draft
wants to merge 6 commits into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public static boolean isClique(Collection<Node> set, Graph graph) {
* @param graph a DAG, CPDAG, MAG, or PAG.
* @return a {@link edu.cmu.tetrad.graph.Graph} object
*/
public static Graph markovBlanketSubgraph(Node target, Graph graph) {
public static Graph markovBlanketSubgraph(Node target, Graph graph) { // TODO VBC: @Joe is this the more general method you recommended?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jdramsey for confirmation :D

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the method I had in mind:

public Graph getMarkovBlanketSubgraph(Graph graph, Node targetNode) {
    EdgeListGraph g = new EdgeListGraph();
    Set<Node> nodes = GraphUtils.markovBlanket(targetNode, g);
    return g.subgraph(new ArrayList<>(nodes));
}

You could put that method somewhere helpful; it may not already exist in the code.

Set<Node> mb = markovBlanket(target, graph);

Graph mbGraph = new EdgeListGraph();
Expand Down Expand Up @@ -2253,7 +2253,7 @@ public static Graph trimGraph(List<Node> targets, Graph graph, int trimmingStyle
graph = trimAdjacentToTarget(targets, graph);
break;
case 3:
graph = trimMarkovBlanketGraph(targets, graph);
graph = trimMarkovBlanketGraph(targets, graph); // TODO VBC currently using this
break;
case 4:
graph = trimSemidirected(targets, graph);
Expand Down Expand Up @@ -2298,7 +2298,7 @@ private static Graph trimAdjacentToTarget(List<Node> targets, Graph graph) {
* @param graph the original graph from which the Markov blanket graph is derived
* @return the trimmed Markov blanket graph
*/
private static Graph trimMarkovBlanketGraph(List<Node> targets, Graph graph) {
private static Graph trimMarkovBlanketGraph(List<Node> targets, Graph graph) { // TODO vbc this is
Graph mbDag = new EdgeListGraph(graph);

M:
Expand Down
57 changes: 57 additions & 0 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.search.test.*;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.TetradLogger;
Expand Down Expand Up @@ -251,6 +252,62 @@ public Double checkAgainstAndersonDarlingTest(List<Double> pValues) {
return generalAndersonDarlingTest.getP();
}

public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(IndependenceTest independenceTest, Graph graph, Double threshold) {
// when calling, default reject null as <=0.05
List<List<Node>> accepts_rejects = new ArrayList<>();
List<Node> accepts = new ArrayList<>();
List<Node> rejects = new ArrayList<>();
List<Node> allNodes = graph.getNodes();
for (Node x : allNodes) {
List<IndependenceFact> localIndependenceFacts = getLocalIndependenceFacts(x);
List<Double> localPValues = getLocalPValues(independenceTest, localIndependenceFacts);
Double ADTest = checkAgainstAndersonDarlingTest(localPValues);
if (ADTest <= threshold) {
rejects.add(x);
} else {
accepts.add(x);
}
}
accepts_rejects.add(accepts);
accepts_rejects.add(rejects);
return accepts_rejects;
}

public Double getPrecisionOrRecallOnMarkovBlanketGraph(Node x, Graph estimatedGraph, Graph trueGraph, boolean getPrecision) {
List<Node> singleNode = Arrays.asList(x);
// Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes.
Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes());

// TODO VBC use the recurssion method
Graph xMBLookupGraph = GraphUtils.trimGraph(singleNode, lookupGraph, 3);
Set<Edge> xMBLookupGraphEdges = xMBLookupGraph.getEdges();
System.out.println("@@@@@@@@@@@@@@@@");
System.out.println("True Graph:" + trueGraph);
System.out.println("LookupGraph:" + lookupGraph); // this print should be the same as the true graph

System.out.println("xMBLookupGraphEdges size: " + xMBLookupGraphEdges.size());
System.out.println("xMBLookupGraph Nodes size: " + xMBLookupGraph.getNodes().size());
System.out.println("xMBLookupGraph:" + xMBLookupGraph); // The MB trim of the lookup graph, so it should be a subset of the lookup graph print

// Get Markov Blanket Subgraph for this node x.
Graph xMBEstimatedGraph = GraphUtils.trimGraph(singleNode, estimatedGraph, 3);
Set<Edge> xMBEstimatedGraphEdges = xMBEstimatedGraph.getEdges();
System.out.println("xMBEstimatedGraphEdges size: " + xMBEstimatedGraphEdges.size());
System.out.println("xMBEstimatedGraph Nodes size: " + xMBEstimatedGraph.getNodes().size());
System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph); // This should be compared with the xMBLookupGraph
System.out.println("@@@@@@@@@@@@@@@@");

HashSet<Edge> truePositive = new HashSet<>(xMBEstimatedGraphEdges);
// TODO VBC: QUESTION FOr DISCUSSION
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jdramsey We need to discuss this tho

// Here it would only be retained if the points and direction of an edge is exactly the same.
// Do we want to only check for points? If not, a lot wrong/no direction edges would be filtered out at this step.
truePositive.retainAll(xMBLookupGraphEdges);

double precision = (double) truePositive.size() / xMBLookupGraphEdges.size();
double recall = (double) truePositive.size() / xMBEstimatedGraphEdges.size();
return getPrecision ? precision : recall;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These outprints seems correct to me, maybe you want a second check : )

}


/**
* Returns the variables of the independence test.
Expand Down
41 changes: 41 additions & 0 deletions tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.Parameters;
import org.junit.Test;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -110,4 +112,43 @@ public void test2() {

System.out.println(markovCheck.getMarkovCheckRecordString());
}

@Test
public void testPrecissionRecallForLocal() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

SemPm pm = new SemPm(trueGraph);
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
score.setPenaltyDiscount(2);
Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search(); // Estimated graph
System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
System.out.println("Test Estimated Graph size: " + estimatedCpdag.getNodes().size());
System.out.println("=====================================");

IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.LOCAL_MARKOV); // TODO Also try MB for settype
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
System.out.println("Rejects size: " + rejects.size());

List<Double> acceptsPrecision = new ArrayList<>();
List<Double> acceptsRecall = new ArrayList<>();
for(Node a: accepts) {
double precision = markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph, true);
double recall = markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph, false);
acceptsPrecision.add(precision);
acceptsRecall.add(recall);
}
System.out.println("Accepts Precisions: " + acceptsPrecision);
System.out.println("Accepts Recall: " + acceptsRecall);
System.out.println("****************************************************");


}
}