forked from chuanqi129/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 4
/
clear_undefinedness.cpp
39 lines (33 loc) · 958 Bytes
/
clear_undefinedness.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <torch/csrc/jit/passes/clear_undefinedness.h>
#include <torch/csrc/jit/jit_log.h>
namespace torch {
namespace jit {
// Resets the type of a single value, discarding any refinement it carried:
//  - a value whose type kind is Tensor gets the plain TensorType::get();
//  - a value typed List[T] where T's kind is Tensor gets List[TensorType::get()].
// Values of any other type are left unchanged.
static void clearUndefinedness(Value* o) {
  const auto kind = o->type()->kind();
  if (kind == TensorType::Kind) {
    o->setType(TensorType::get());
    return;
  }
  if (kind != ListType::Kind) {
    return;
  }
  // Only lists whose element type is a Tensor are rewritten; nested lists or
  // lists of other element types keep their current type.
  const auto elem_kind =
      o->type()->expectRef<ListType>().getElementType()->kind();
  if (elem_kind == TensorType::Kind) {
    o->setType(ListType::create(TensorType::get()));
  }
}
// Visits every node of `block` in order. For each node it first clears the
// types of the node's outputs, then recurses into the node's nested blocks.
// The block's own inputs are not visited by this function.
static void clearUndefinedness(Block* block) {
  for (Node* node : block->nodes()) {
    for (Value* output : node->outputs()) {
      clearUndefinedness(output);
    }
    for (Block* nested : node->blocks()) {
      clearUndefinedness(nested);
    }
  }
}
// Entry point of the pass. Mutates `graph` in place: it clears refined
// Tensor / List[Tensor] types from the graph's inputs and from the outputs of
// every node in the graph's top-level block, recursing into nested blocks.
// Afterwards it emits a GRAPH_DUMP of the resulting graph.
void ClearUndefinedness(const std::shared_ptr<Graph>& graph) {
  for (auto i : graph->inputs()) {
    clearUndefinedness(i);
  }
  clearUndefinedness(graph->block());
  // Label the dump with this pass's own name so the output can be matched to
  // the pass when searching JIT logs.
  GRAPH_DUMP("After ClearUndefinedness: ", graph);
}
} // namespace jit
} // namespace torch