forked from AcademySoftwareFoundation/MaterialX
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ShaderGraph.h
270 lines (209 loc) · 10.4 KB
/
ShaderGraph.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
//
// TM & (c) 2017 Lucasfilm Entertainment Company Ltd. and Lucasfilm Ltd.
// All rights reserved. See LICENSE.txt for license.
//
#ifndef MATERIALX_SHADERGRAPH_H
#define MATERIALX_SHADERGRAPH_H
/// @file
/// Shader graph class
#include <MaterialXGenShader/Library.h>
#include <MaterialXGenShader/ColorManagementSystem.h>
#include <MaterialXGenShader/ShaderNode.h>
#include <MaterialXGenShader/TypeDesc.h>
#include <MaterialXGenShader/Syntax.h>
#include <MaterialXGenShader/UnitSystem.h>
#include <MaterialXCore/Document.h>
#include <MaterialXCore/Node.h>
namespace MaterialX
{
// Forward declarations of types referenced by this header.
class GenOptions;
class ShaderGraph;
class ShaderGraphEdge;
class ShaderGraphEdgeIterator;
class Syntax;

/// Shared pointer type used to hold a ShaderGraph.
using ShaderGraphPtr = shared_ptr<ShaderGraph>;

/// A graph-internal input socket, used to connect internal nodes to the
/// outside of the graph. Input sockets are represented by ShaderOutput ports.
using ShaderGraphInputSocket = ShaderOutput;

/// A graph-internal output socket, used to connect internal nodes to the
/// outside of the graph. Output sockets are represented by ShaderInput ports.
using ShaderGraphOutputSocket = ShaderInput;
/// @class ShaderGraph
/// Class representing a graph (DAG) for shader generation.
///
/// A ShaderGraph is itself a ShaderNode. Its input sockets are stored as the
/// node's outputs, and its output sockets are stored as the node's inputs.
/// This lets internal nodes connect to the graph boundary like any other port.
class ShaderGraph : public ShaderNode
{
  public:
    /// Constructor.
    ShaderGraph(const ShaderGraph* parent, const string& name, ConstDocumentPtr document, const StringSet& reservedWords);

    /// Destructor.
    virtual ~ShaderGraph() { }

    /// Create a new shader graph from an element.
    /// Supported elements are outputs and shader nodes.
    static ShaderGraphPtr create(const ShaderGraph* parent, const string& name, ElementPtr element,
                                 GenContext& context);

    /// Create a new shader graph from a nodegraph.
    static ShaderGraphPtr create(const ShaderGraph* parent, const NodeGraph& nodeGraph,
                                 GenContext& context);

    /// Return true if this node is a graph.
    bool isAGraph() const override { return true; }

    /// Get an internal node by name.
    ShaderNode* getNode(const string& name);

    /// Get an internal node by name (const access).
    const ShaderNode* getNode(const string& name) const;

    /// Get a vector of all internal nodes, in the graph's current node order.
    /// NOTE(review): presumably this is topological order once topologicalSort()
    /// has run. Confirm against the implementation.
    const vector<ShaderNode*>& getNodes() const { return _nodeOrder; }

    /// Get the number of input sockets (stored as this node's outputs).
    size_t numInputSockets() const { return numOutputs(); }

    /// Get the number of output sockets (stored as this node's inputs).
    size_t numOutputSockets() const { return numInputs(); }

    /// Get an input socket by index.
    ShaderGraphInputSocket* getInputSocket(size_t index) { return getOutput(index); }
    /// Get an output socket by index; defaults to the first output socket.
    ShaderGraphOutputSocket* getOutputSocket(size_t index = 0) { return getInput(index); }
    /// Get an input socket by index (const access).
    const ShaderGraphInputSocket* getInputSocket(size_t index) const { return getOutput(index); }
    /// Get an output socket by index; defaults to the first output socket (const access).
    const ShaderGraphOutputSocket* getOutputSocket(size_t index = 0) const { return getInput(index); }

    /// Get an input socket by name.
    ShaderGraphInputSocket* getInputSocket(const string& name) { return getOutput(name); }
    /// Get an output socket by name.
    ShaderGraphOutputSocket* getOutputSocket(const string& name) { return getInput(name); }
    /// Get an input socket by name (const access).
    const ShaderGraphInputSocket* getInputSocket(const string& name) const { return getOutput(name); }
    /// Get an output socket by name (const access).
    const ShaderGraphOutputSocket* getOutputSocket(const string& name) const { return getInput(name); }

    /// Get the vector of input sockets.
    const vector<ShaderGraphInputSocket*>& getInputSockets() const { return _outputOrder; }
    /// Get the vector of output sockets.
    const vector<ShaderGraphOutputSocket*>& getOutputSockets() const { return _inputOrder; }

    /// Create a new node in the graph.
    ShaderNode* createNode(const Node& node, GenContext& context);

    /// Add an input socket with the given name and type.
    ShaderGraphInputSocket* addInputSocket(const string& name, const TypeDesc* type);
    /// Add an output socket with the given name and type.
    ShaderGraphOutputSocket* addOutputSocket(const string& name, const TypeDesc* type);

    /// Return an iterator for traversal upstream from the given output.
    static ShaderGraphEdgeIterator traverseUpstream(ShaderOutput* output);

    /// Return the map of unique identifiers used in the scope of this graph.
    IdentifierMap& getIdentifierMap() { return _identifiers; }

    /// Return the map of unique identifiers used in the scope of this graph (const access).
    const IdentifierMap& getIdentifierMap() const { return _identifiers; }

  protected:
    /// Create a graph from the given shader node.
    /// NOTE(review): the name suggests this is specific to surface shader nodes,
    /// and `root` is taken by non-const reference, so it may be updated by the
    /// call. Confirm both against the implementation.
    static ShaderGraphPtr createSurfaceShader(
        const string& name,
        const ShaderGraph* parent,
        NodePtr node,
        GenContext& context,
        ElementPtr& root);

    /// Create node connections corresponding to the connection between a pair of elements.
    /// @param downstreamElement Element representing the node to connect to.
    /// @param upstreamElement Element representing the node to connect from.
    /// @param connectingElement If non-null, specifies the element on the downstream node to connect to.
    /// @param context Context for generation.
    void createConnectedNodes(const ElementPtr& downstreamElement,
                              const ElementPtr& upstreamElement,
                              ElementPtr connectingElement,
                              GenContext& context);

    /// Add a node to the graph.
    void addNode(ShaderNodePtr node);

    /// Add input sockets from an interface element (nodedef, nodegraph or node).
    void addInputSockets(const InterfaceElement& elem, GenContext& context);

    /// Add output sockets from an interface element (nodedef, nodegraph or node).
    void addOutputSockets(const InterfaceElement& elem);

    /// Traverse from the given root element and add all dependencies upstream.
    void addUpstreamDependencies(const Element& root, GenContext& context);

    /// Add a default geometric node and connect to the given input.
    void addDefaultGeomNode(ShaderInput* input, const GeomPropDef& geomprop, GenContext& context);

    /// Add a color transform node and connect to the given input.
    void addColorTransformNode(ShaderInput* input, const ColorSpaceTransform& transform, GenContext& context);

    /// Add a color transform node and connect to the given output.
    void addColorTransformNode(ShaderOutput* output, const ColorSpaceTransform& transform, GenContext& context);

    /// Add a unit transform node and connect to the given input.
    void addUnitTransformNode(ShaderInput* input, const UnitTransform& transform, GenContext& context);

    /// Add a unit transform node and connect to the given output.
    void addUnitTransformNode(ShaderOutput* output, const UnitTransform& transform, GenContext& context);

    /// Perform all post-build operations on the graph.
    void finalize(GenContext& context);

    /// Optimize the graph, removing redundant paths.
    void optimize(GenContext& context);

    /// Bypass a node for a particular input and output,
    /// effectively connecting the input's upstream connection
    /// with the output's downstream connections.
    void bypass(GenContext& context, ShaderNode* node, size_t inputIndex, size_t outputIndex = 0);

    /// Sort the nodes in topological order.
    /// @throws ExceptionFoundCycle if a cycle is encountered.
    void topologicalSort();

    /// Calculate scopes for all nodes in the graph.
    void calculateScopes();

    /// Set the variable names used in generated code for the graph's inputs and outputs.
    /// Names are made valid and unique to avoid conflicts during shader generation.
    void setVariableNames(GenContext& context);

    /// Populate the input or output color transform map if the provided input/parameter
    /// has a color space attribute and has a type of color3 or color4.
    /// NOTE(review): `asInput` presumably selects the input map (true) or the
    /// output map (false). Confirm against the implementation.
    void populateColorTransformMap(ColorManagementSystemPtr colorManagementSystem, ShaderPort* shaderPort, ValueElementPtr element, const string& targetColorSpace, bool asInput);

    /// Populate the appropriate unit transform map if the provided input/parameter or output
    /// has a unit attribute and is of a supported type.
    /// NOTE(review): `asInput` presumably selects the input map (true) or the
    /// output map (false). Confirm against the implementation.
    void populateUnitTransformMap(UnitSystemPtr unitSystem, ShaderPort* shaderPort, ValueElementPtr element, const string& targetUnitSpace, bool asInput);

    /// Break all connections on a node.
    void disconnect(ShaderNode* node) const;

    // Document associated with this graph (presumably the one passed to the constructor).
    ConstDocumentPtr _document;
    // Nodes in the graph; presumably keyed by node name (see getNode()).
    std::unordered_map<string, ShaderNodePtr> _nodeMap;
    // Node order exposed through getNodes().
    std::vector<ShaderNode*> _nodeOrder;
    // Unique identifiers in the scope of this graph (see getIdentifierMap()).
    IdentifierMap _identifiers;

    // Temporary storage for inputs that require color transformations.
    std::unordered_map<ShaderInput*, ColorSpaceTransform> _inputColorTransformMap;
    // Temporary storage for inputs that require unit transformations.
    std::unordered_map<ShaderInput*, UnitTransform> _inputUnitTransformMap;
    // Temporary storage for outputs that require color transformations.
    std::unordered_map<ShaderOutput*, ColorSpaceTransform> _outputColorTransformMap;
    // Temporary storage for outputs that require unit transformations.
    std::unordered_map<ShaderOutput*, UnitTransform> _outputUnitTransformMap;
};
/// @class ShaderGraphEdge
/// An edge returned during shader graph traversal, pairing an upstream
/// output with a downstream input.
class ShaderGraphEdge
{
  public:
    /// Construct an edge from its upstream output and downstream input.
    ShaderGraphEdge(ShaderOutput* upstreamOutput, ShaderInput* downstreamInput)
        : upstream(upstreamOutput)
        , downstream(downstreamInput)
    {
    }

    /// Output at the upstream end of the edge.
    ShaderOutput* upstream;

    /// Input at the downstream end of the edge.
    ShaderInput* downstream;
};
/// @class ShaderGraphEdgeIterator
/// Iterator class for traversing edges between nodes in a shader graph.
///
/// The iterator acts as its own range: begin() returns a reference to this
/// iterator, and end() returns a shared sentinel iterator.
class ShaderGraphEdgeIterator
{
  public:
    /// Construct an iterator starting from the given output.
    ShaderGraphEdgeIterator(ShaderOutput* output);

    // No user-declared destructor (Rule of Zero). An empty user-declared
    // destructor would suppress the implicit move constructor and move
    // assignment. Every move of this iterator would then deep-copy _stack
    // and _path.

    /// Return true if both iterators are at the same traversal position:
    /// same current edge and same traversal stack.
    bool operator==(const ShaderGraphEdgeIterator& rhs) const
    {
        return _upstream == rhs._upstream &&
               _downstream == rhs._downstream &&
               _stack == rhs._stack;
    }

    /// Return true if the iterators are at different traversal positions.
    bool operator!=(const ShaderGraphEdgeIterator& rhs) const
    {
        return !(*this == rhs);
    }

    /// Dereference this iterator, returning the current edge in the traversal.
    ShaderGraphEdge operator*() const
    {
        return ShaderGraphEdge(_upstream, _downstream);
    }

    /// Iterate to the next edge in the traversal.
    /// @throws ExceptionFoundCycle if a cycle is encountered.
    ShaderGraphEdgeIterator& operator++();

    /// Return a reference to this iterator to begin traversal.
    ShaderGraphEdgeIterator& begin()
    {
        return *this;
    }

    /// Return the end iterator.
    static const ShaderGraphEdgeIterator& end();

  private:
    void extendPathUpstream(ShaderOutput* upstream, ShaderInput* downstream);
    void returnPathDownstream(ShaderOutput* upstream);

    // Current edge endpoints. They default to null so an iterator is never
    // left with indeterminate pointers.
    ShaderOutput* _upstream = nullptr;
    ShaderInput* _downstream = nullptr;

    // Traversal stack; each frame pairs an output with an index into it.
    using StackFrame = std::pair<ShaderOutput*, size_t>;
    std::vector<StackFrame> _stack;

    // Outputs on the current traversal path.
    // NOTE(review): presumably used for cycle detection (operator++ can
    // throw ExceptionFoundCycle). Confirm against the implementation.
    std::set<ShaderOutput*> _path;
};
} // namespace MaterialX
#endif