--- /opt/pytorch/pytorch/torch/_functorch/aot_autograd.py 2023-01-17 00:39:30.000000000 -0800
+++ aot_autograd.py 2023-01-23 20:32:14.164081000 -0800
@@ -86,6 +86,7 @@
 # one counter is allocated per entire compiled block (but this block
 # may involve compiling multiple subgraphs; e.g., for forwards/backwards)
 AOT_COUNTER = itertools.count()
+FX_CONVERT_COUNTER = itertools.count()
 
 KNOWN_TYPES = tuple(
     [torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types)
@@ -1656,6 +1657,278 @@
log.debug(f"====== Joint graph {aot_config.aot_id} ======")
log.debug(fx_g.print_readable(print_output=False))
+    fx_convert_id = next(FX_CONVERT_COUNTER)
+    print("[FX Graph to NetworkX Exporter]: Starting Graph Export {}".format(fx_convert_id))
+    from torch.fx.passes.graph_drawer import FxGraphDrawer as fgd
+    g = fgd(fx_g, 'fx_graph_extraction')
+    x = g.get_main_dot_graph()
+    import networkx as nx
+    nx_g = nx.nx_pydot.from_pydot(x)
+    source_nodes = [x.name for x in fx_g.graph.nodes if 'primals' in x.name]
+    source_nodes += [x.name for x in fx_g.graph.nodes if 'tangents' in x.name]
+    nx_edges = [x for x in nx.edge_bfs(nx_g, source=source_nodes)]
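+    #Note: from_pydot() gives a MultiDiGraph here (the exported dot graph is a digraph), so the edges
+    #carry (source, target, key) triples below. nx_edges is collected for inspection only; nothing
+    #downstream consumes it.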
+
+    edge_attr_dict = dict()
+    node_attr_dict = dict()
+    call_sigs = []
+    for source in nx_g.nodes:
+        label = nx_g.nodes[source]['label']
+        label = label.replace('{', '', 1)
+        label = label.rsplit('}', 1)[0]
+        label = label.replace('\n', '').replace('\\n', '').replace('\l', '').replace('\\l', '').replace('\\', '')
+        labels = label.split('|')
+        label_dict = {'shape': [], 'dtype': None}
+        for x in labels:
+            #Split on the first '=' only, so values that themselves contain '=' are preserved.
+            new_x = x.split('=', 1)
+            #new_x[1] = new_x[1].replace(':', '=')
+            if len(new_x) == 2:
+                label_dict.update({new_x[0]: new_x[1]})
+
+        #label_str = label.replace(':','=') #This is a workaround to make to_pydot work for drawing. See https://github.com/pydot/pydot/issues/258
+
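+        #Illustrative example (the exact label fields vary with the PyTorch version): a pydot label such as
+        #'{name=mul_1|op_code=call_function|target=aten.mul.Tensor|args=(primals_2, tangents_1)}'
+        #parses into label_dict == {'shape': [], 'dtype': None, 'name': 'mul_1', 'op_code': 'call_function', ...}.
+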
+        fx_graph_node = [x for x in fx_g.graph.nodes if x.name == source][0]
+        arg_dict = dict()
+        #if 'args' in label_dict.keys():
+        if getattr(fx_graph_node, 'args', None):
+            #The block below attempts to map the args received for the node from the FX Graph exactly onto the op
+            #signature resolved from the target. This is still experimental: covering all cases needs a much deeper
+            #understanding of the PyTorch codebase and of the individual ops.
+            #if 'target' in label_dict.keys():
+            if getattr(fx_graph_node, 'target', None):
+                try:
+                    #node_call_sig = eval(label_dict['target'])._schema.__str__()
+                    node_call_sig = str(fx_graph_node.target._schema)
+                    call_sigs.append(node_call_sig)
+                    arg_dict = dict()
+
+                    '''
+                    #Simple code block that just adds the args as received.
+                    #Use it instead of the experimental block below if processing the args against the signature is not working.
+                    #=========
+                    arg_dict = {'target_signature': node_call_sig, 'args': label_dict['args']}
+                    if 'kwargs' in label_dict.keys():
+                        arg_dict.update({'kwargs': label_dict['kwargs']})
+                    #=========
+                    '''
+
+                    #Experimental code block -->
+                    import re
+                    #Match everything between the outermost parentheses of the argument list, up to the '->'.
+                    #This protects signatures whose return annotation also contains parentheses, such as
+                    #'aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)'.
+                    #sig_match = re.search(r'(\([^\)]+\))', node_call_sig).group()
+                    #sig_args = re.split(',', sig_match.replace('(', '').replace(')', ''))
+                    sig_match = re.search(r'(\(.+\)\s->)', node_call_sig).group()
+                    sig_match = sig_match.replace(' ->', '')
+                    sig_match = sig_match.replace('(', '', 1)
+                    sig_match = sig_match.rsplit(')', 1)[0]
+                    sig_args = re.split(',', sig_match)
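+                    #Worked example: for 'aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)',
+                    #sig_match ends up as 'Tensor input, float p, bool? train' and
+                    #sig_args as ['Tensor input', ' float p', ' bool? train'].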
+
+                    req_s_args = []
+                    opt_s_args = []
+                    for s_arg in sig_args:
+                        #Strip the leading whitespace, if any, from the arg
+                        s_arg = re.sub(r'^\s', '', s_arg)
+                        if '=' in s_arg:
+                            opt_s_args.append(s_arg)
+                        else:
+                            req_s_args.append(s_arg)
+
+                    opt_s_arg_dict = dict()
+                    for opt_s_arg in opt_s_args:
+                        opt_s_arg_type, opt_s_arg_def = opt_s_arg.split(' ')
+                        opt_s_arg_key, opt_s_arg_val = opt_s_arg_def.split('=')
+                        opt_s_arg_dict.update({opt_s_arg_key: opt_s_arg_val})
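+                    #Worked example: for 'aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor',
+                    #req_s_args == ['Tensor self', 'Tensor other', '*'] and opt_s_arg_dict == {'alpha': '1'}.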
+
+                    #node_args = eval(label_dict['args'])
+                    node_args = fx_graph_node.args
+
+                    #FIXME: Update kwargs to use FX Graph information as well instead of label data
+                    #The block below quotes the keys in the kwargs string so that it can be passed to eval().
+                    if 'kwargs' in label_dict.keys():
+                        kwargs_str = label_dict['kwargs']
+                        for opt_s_arg_key in opt_s_arg_dict.keys():
+                            if opt_s_arg_key in kwargs_str:
+                                kwargs_str = kwargs_str.replace(str(opt_s_arg_key), '"' + str(opt_s_arg_key) + '"')
+
+                        node_kwargs = eval(kwargs_str)
+                        for node_kwarg in node_kwargs.keys():
+                            if node_kwarg in opt_s_arg_dict.keys():
+                                arg_dict.update({node_kwarg: node_kwargs[node_kwarg]})
+                                opt_s_arg_dict.pop(node_kwarg)
+                    else:
+                        node_kwargs = None
+
+                    '''
+                    #This section was originally written to process any definite arguments such as Tensor other=mul_1.
+                    #It is commented out because no case has been observed where a definite argument is present in the args,
+                    #and kwargs are handled separately above.
+                    #FIXME: Add a check for whether args ever contain a '=' definite argument
+
+                    # definite_node_args = [n_arg for n_arg in node_args if '=' in n_arg]
+
+                    # for def_n_arg in definite_node_args:
+                    #     arg_key, arg_val = def_n_arg.split('=')
+                    #     assert arg_key in sig_args,\
+                    #         "Found definite arg {} with value {} in FX Graph but it does not match the found signature {}".format(arg_key, arg_val, node_call_sig)
+                    #     arg_dict.update({arg_key: arg_val})
+                    #     sig_args.remove(arg_key)
+                    #     node_args.remove(def_n_arg)
+                    '''
+
+                    #Case where the op accepts more optional positional inputs (the '*' separator) than were given.
+                    #Handle this before the self/input heuristic so that an explicitly passed self is not dropped by mistake.
+                    if len(req_s_args) > len(node_args) and '*' in req_s_args:
+                        req_s_args.remove('*')
+                    #Remove a leading self/input arg from the signature where the node args do not carry it,
+                    #so that the remaining args can be mapped positionally.
+                    #TODO: Is it always true that self/input shows up as the first arg of the signature when the node args lack it?
+                    #TODO contd: Check whether these cases still occur now that the FX Graph args are used instead of label_dict.
+                    if len(req_s_args) > len(node_args) and ('self' in req_s_args[0] or 'input' in req_s_args[0]):
+                        del req_s_args[0]
+                    #Case where multiple optional positional inputs were supplied: pad with '*' placeholders.
+                    #FIXME: Repeated '*' placeholders collapse into a single arg_dict key during the mapping below.
+                    if len(node_args) > len(req_s_args) and '*' in req_s_args:
+                        req_s_args.extend(['*' for x in range(len(node_args) - len(req_s_args))])
+                    assert len(req_s_args) == len(node_args), "Node args {} received do not match the evaluated signature {}".format(node_args, req_s_args)
+
+                    #Assign node args to required args until they run out. In most cases all node args also get consumed here.
+                    for idx, sig_arg in enumerate(req_s_args):
+                        arg_dict.update({sig_arg: node_args[idx]})
+                    node_args = node_args[len(req_s_args):]
+
+                    #Continue assigning to optional args until the node args run out. Remove each opt_s_arg_dict key once assigned.
+                    opt_s_args_to_assign = list(opt_s_arg_dict.keys())[:len(node_args)]
+                    for sig_arg, n_arg in zip(opt_s_args_to_assign, node_args):
+                        arg_dict.update({sig_arg: n_arg})
+                        opt_s_arg_dict.pop(sig_arg)
+
+                    #Add the remaining default optional args here
+                    for opt_s_arg_key in opt_s_arg_dict.keys():
+                        if opt_s_arg_key not in arg_dict.keys():
+                            arg_dict.update({opt_s_arg_key: opt_s_arg_dict[opt_s_arg_key]})
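+
+                    #Continuing the add.Tensor example with node_args == (primals_1, primals_2): the '*' placeholder is
+                    #dropped, and arg_dict ends up as {'Tensor self': primals_1, 'Tensor other': primals_2, 'alpha': '1'}.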
+
+                except Exception:
+                    print("[WARNING - FX Graph to NetworkX] Failed to map args. Target signature {} for op {}".format(fx_graph_node.target, source))
+                    arg_dict = dict((k, label_dict[k]) for k in ['args', 'kwargs'] if k in label_dict.keys())
+            else:
+                #Unclear whether this case can occur: args provided but no target function call.
+                arg_dict = dict((k, label_dict[k]) for k in ['args', 'kwargs'] if k in label_dict.keys())
+
+        arg_str = str(arg_dict).replace(':', '=')  #Workaround to make to_pydot work for drawing. See https://github.com/pydot/pydot/issues/258
+        arg_str = arg_str.replace('{', '')
+        arg_str = arg_str.replace('}', '')
+        lbl = 'name=' + str(label_dict['name']) + '\ncall_args_info=' + arg_str
+        #node_dict = {'node_call_info': arg_str, 'name': label_dict['name'], 'label': lbl}
+        node_dict = {'node_call_info': arg_str, 'label': lbl}
+        node_attr_dict.update({source: node_dict})
+
+        def _get_tensor_info(fxg_node):
+            tensor_info = None
+            if hasattr(fxg_node, "meta") and "tensor_meta" in fxg_node.meta:
+                tensor_info = fxg_node.meta['tensor_meta']
+            elif hasattr(fxg_node, "meta") and "val" in fxg_node.meta:
+                tensor_info = fxg_node.meta['val']
+            return tensor_info
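+        #Note: 'tensor_meta' is populated by FX shape propagation, while 'val' (typically a FakeTensor) is recorded
+        #by more recent tracing paths; both expose the .shape and .dtype that the exporter reads below.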
+
+        edge_names = [(source, t, k) for t in nx_g[source].keys() for k in nx_g[source][t].keys()]
+
+        tensor_info = _get_tensor_info(fx_graph_node)
+        if tensor_info is None:
+            #Tensor metadata was not present for this node. This is undesirable, and it may be worth filing PyTorch
+            #issues about it. For now, the cases where the output is consumed by getitem nodes can still be handled.
+            #It may be desirable to ALWAYS match tensor metadata when multiple getitem nodes are users, but that is a
+            #future enhancement.
+            if all(['getitem' in x.name for x in fx_graph_node.users.keys()]):
+                for dest_node in fx_graph_node.users.keys():
+                    dest_tensor_info = _get_tensor_info(dest_node)
+                    assert dest_tensor_info is not None, "getitem node {} does not have tensor info. This fails assumptions made in the exporter.".format(dest_node)
+                    tensor_info = tensor_info + [dest_tensor_info] if tensor_info else [dest_tensor_info]
+
+        if tensor_info is not None:
+            #There are underlying assumptions here, based on a conversation with Horace at Meta:
+            #1. When an op returns multiple tensors, they are unpacked by subsequent getitem operator calls
+            #   (see the len(edge_names) > 1 case below).
+            #2. When an op output tensor is connected to multiple users (other nodes), the same tensor is on each edge
+            #   (see the case below where len(tensor_info) == 1 and len(tensor_info) < len(edge_names)).
+            #Remember: the assumption is that multiple tensors (len(tensor_info) > 1) always go to individual getitem calls.
+            if not isinstance(tensor_info, list):
+                tensor_info = [tensor_info]
+            if len(tensor_info) < len(edge_names) and len(tensor_info) == 1:
+                tensor_info = [tensor_info[0] for idx in range(len(edge_names))]
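+            #For example, a single output consumed by three users yields three parallel edges that all carry the same
+            #shape/dtype, while native_dropout's two outputs were already gathered from its two getitem users above.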
+
+            if len(edge_names) == 0:
+                assert len(fx_graph_node.users) == 0, "There are no edges in the exported graph for op {} but the FX Graph has users {}".format(source, fx_graph_node.users)
+            else:
+                assert len(tensor_info) == len(edge_names), "The number of tensors available in the FX Graph does not match the edges present in the exported NetworkX graph for op {}".format(source)
+                if len(edge_names) == 1:
+                    edge_dict = dict()
+                    edge_dict.update({'shape': tensor_info[0].shape})
+                    edge_dict.update({'dtype': tensor_info[0].dtype})
+                    edge_dict.update({'label': 'shape=' + str(tensor_info[0].shape) + '\ndtype=' + str(tensor_info[0].dtype)})
+                    edge_attr_dict.update({edge_names[0]: edge_dict})
+                else:
+                    #This relies on the order of tensor_info from the FX Graph lining up with the order of the
+                    #corresponding NetworkX edges. A more robust approach could match the FX Graph users against the
+                    #targets in the (source, target, key) edge tuples.
+                    #TODO: Explore the above matching
+                    for edge_idx in range(len(edge_names)):
+                        edge_dict = dict()
+                        edge_dict.update({'shape': tensor_info[edge_idx].shape})
+                        edge_dict.update({'dtype': tensor_info[edge_idx].dtype})
+                        edge_dict.update({'label': 'shape=' + str(tensor_info[edge_idx].shape) + '\ndtype=' + str(tensor_info[edge_idx].dtype)})
+                        edge_attr_dict.update({edge_names[edge_idx]: edge_dict})
+        else:
+            if len(edge_names) == 0:
+                assert fx_graph_node.op == 'output', "No edges found for non-output op {}".format(source)
+            else:
+                #This case should not be reached if the FX Graph was constructed properly.
+                print("[WARNING: FX Graph to NetworkX] Edge tensor information was not found in the FX Graph for op {}. Using the label from the pydot export.".format(source))
+                edge_dict = dict((k, label_dict[k]) for k in ('shape', 'dtype'))
+                edge_dict.update({'label': 'shape=' + str(label_dict['shape']) + '\ndtype=' + str(label_dict['dtype'])})
+                edge_attr_dict.update({edge_names[0]: edge_dict})
+
+    nx.set_edge_attributes(nx_g, edge_attr_dict)
+    nx.set_node_attributes(nx_g, node_attr_dict)
+
+    import pickle
+    gen_name = 'fx_graph_extracted_id_' + str(fx_convert_id)
+    with open(gen_name + '.pickle', 'wb') as pickle_file:
+        pickle.dump(nx_g, pickle_file)
+
+    #Note: To read this pickle dump, use
+    #fx_g_extracted = pickle.load(open(gen_name + '.pickle', 'rb'))
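+    #and then, for example, inspect it with something like this (illustrative only; nx_g is a MultiDiGraph,
+    #so edges carry keys):
+    #for n, data in fx_g_extracted.nodes(data=True):
+    #    print(n, data.get('node_call_info'))
+    #for u, v, k, data in fx_g_extracted.edges(keys=True, data=True):
+    #    print(u, '->', v, data.get('shape'), data.get('dtype'))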
+
+    #Drawing
+    #print("[FX Graph to NetworkX Exporter]: Starting to draw extracted graph {}".format(fx_convert_id))
+    #edited_dot = nx.nx_pydot.to_pydot(nx_g)
+    #edited_dot.write_png(gen_name + '.png')
+    #print("[FX Graph to NetworkX Exporter]: Finished drawing extracted graph {}".format(fx_convert_id))
+
+    '''
+    new_nx_g = nx.Graph(nx_g)
+
+    import matplotlib.pyplot as plt
+
+    #Draw Nodes
+    pos = nx.spring_layout(new_nx_g)
+    nx.draw_networkx_nodes(new_nx_g, pos, node_size=900)
+
+    #Create minimal labels for drawing
+    name_lbls = nx.get_node_attributes(new_nx_g, 'name')
+    call_lbls = nx.get_node_attributes(new_nx_g, 'node_call_info')
+    for n in new_nx_g.nodes:
+        lbl = 'name: ' + str(name_lbls[n]) + '\ncall_args_info: ' + str(call_lbls[n])
+        node_draw_lbl = {n: lbl}
+        nx.draw_networkx_labels(new_nx_g, pos, labels=node_draw_lbl)
+
+    #Draw Edges
+    nx.draw_networkx_edges(new_nx_g, pos)
+    shape_lbls = nx.get_edge_attributes(new_nx_g, 'shape')
+    dtype_lbls = nx.get_edge_attributes(new_nx_g, 'dtype')
+    for e in new_nx_g.edges():
+        lbl = 'shape: ' + str(shape_lbls[e]) + '\ndtype: ' + str(dtype_lbls[e])
+        edge_draw_lbl = {e: lbl}
+        nx.draw_networkx_edge_labels(new_nx_g, pos, edge_labels=edge_draw_lbl)
+
+    plt.savefig("fx_extracted_nx.png", format="PNG")
+    '''
+
     with torch.no_grad():
         with track_graph_compiling(aot_config, "joint"):
             num_inner_fwd_outputs = _num_mutated_inputs + _num_outputs + _fw_metadata.num_intermediate_bases