#
# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
#
# SPDX-License-Identifier: BSD-2-Clause
#
import syntax
import solver
import problem
import rep_graph
import search
import logic
import check
from target_objects import functions, trace, pairings, pre_pairings, printout
import target_objects
from logic import azip
from syntax import mk_var, word32T, builtinTs, mk_eq, mk_less_eq
last_stuff = [0]
def default_n_vc (p, n):
head = p.loop_id (n)
general = [(n2, rep_graph.vc_options ([0], [1]))
for n2 in p.loop_heads ()
if n2 != head]
specific = [(head, rep_graph.vc_offs (1)) for _ in [1] if head]
return (n, tuple (general + specific))
def split_sum_s_expr (expr, solv, extra_defs, typ):
"""divides up a linear expression 'a - b - 1 + a'
into ({'a':2, 'b': -1}, -1) i.e. 'a' times 2 etc and constant
value of -1."""
def rec (expr):
return split_sum_s_expr (expr, solv, extra_defs, typ)
if expr[0] == 'bvadd':
var = {}
const = 0
for x in expr[1:]:
(var2, const2) = rec (x)
for (v, count) in var2.iteritems ():
var.setdefault (v, 0)
var[v] += count
const += const2
return (var, const)
elif expr[0] == 'bvsub':
(_, lhs, rhs) = expr
(lvar, lconst) = rec (lhs)
(rvar, rconst) = rec (rhs)
const = lconst - rconst
var = dict ([(v, lvar.get (v, 0) - rvar.get (v, 0))
for v in set.union (set (lvar), set (rvar))])
return (var, const)
elif expr in solv.defs:
return rec (solv.defs[expr])
elif expr in extra_defs:
return rec (extra_defs[expr])
elif expr[:2] in ['#x', '#b']:
val = solver.smt_to_val (expr)
assert val.kind == 'Num'
return ({}, val.val)
else:
return ({expr: 1}, 0)
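# Note (added comment, not in the original): split_merge_ite_sum_sexpr below
# refers to names (rec, typ, x, y, cond) that are not among its parameters and
# appears to rely on the enclosing split_sum_s_expr context; the only use
# visible in this part of the file is the commented-out call in
# offs_expr_const.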
def split_merge_ite_sum_sexpr (foo):
(s0, s1) = [solver.smt_num_t (n, typ) for n in [0, 1]]
if y != s0:
expr = ('bvadd', ('ite', cond, ('bvsub', x, y), s0), y)
return rec (expr)
(xvar, xconst) = rec (x)
var = dict ([(('ite', cond, v, s0), n)
for (v, n) in xvar.iteritems ()])
var.setdefault (('ite', cond, s1, s0), 0)
var[('ite', cond, s1, s0)] += xconst
return (var, 0)
def simplify_expr_whyps (sexpr, rep, hyps, cache = None, extra_defs = {},
bool_hyps = None):
if cache == None:
cache = {}
if bool_hyps == None:
bool_hyps = []
if sexpr in extra_defs:
sexpr = extra_defs[sexpr]
if sexpr in rep.solv.defs:
sexpr = rep.solv.defs[sexpr]
if sexpr[0] == 'ite':
(_, cond, x, y) = sexpr
cond_exp = solver.mk_smt_expr (solver.flat_s_expression (cond),
syntax.boolT)
(mk_nimp, mk_not) = (syntax.mk_n_implies, syntax.mk_not)
if rep.test_hyp_whyps (mk_nimp (bool_hyps, cond_exp),
hyps, cache = cache):
return x
elif rep.test_hyp_whyps (mk_nimp (bool_hyps, mk_not (cond_exp)),
hyps, cache = cache):
return y
x = simplify_expr_whyps (x, rep, hyps, cache = cache,
extra_defs = extra_defs,
bool_hyps = bool_hyps + [cond_exp])
y = simplify_expr_whyps (y, rep, hyps, cache = cache,
extra_defs = extra_defs,
bool_hyps = bool_hyps + [syntax.mk_not (cond_exp)])
if x == y:
return x
return ('ite', cond, x, y)
return sexpr
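# Illustration (added comment, hypothetical values): if the hypotheses decide
# an ite condition one way, the ite collapses to that branch, e.g.
#   simplify_expr_whyps (('ite', c, a, b), rep, hyps)  ==>  a
# when hyps imply c, and ==> b when hyps imply (not c); otherwise both
# branches are simplified recursively under the extra assumption c / not c,
# and the ite is rebuilt only if they remain distinct.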
last_10_non_const = []
def offs_expr_const (addr_expr, sp_expr, rep, hyps, extra_defs = {},
cache = None, typ = syntax.word32T):
"""if the offset between a stack addr and the initial stack pointer
is a constant offset, try to compute it."""
addr_x = solver.parse_s_expression (addr_expr)
sp_x = solver.parse_s_expression (sp_expr)
vs = [(addr_x, 1), (sp_x, -1)]
const = 0
while True:
start_vs = list (vs)
new_vs = {}
for (x, mult) in vs:
(var, c) = split_sum_s_expr (x, rep.solv, extra_defs,
typ = typ)
for v in var:
new_vs.setdefault (v, 0)
new_vs[v] += var[v] * mult
const += c * mult
vs = [(x, n) for (x, n) in new_vs.iteritems ()
if n % (2 ** typ.num) != 0]
if not vs:
return const
vs = [(simplify_expr_whyps (x, rep, hyps,
cache = cache, extra_defs = extra_defs), n)
for (x, n) in vs]
if sorted (vs) == sorted (start_vs):
pass # vs = split_merge_ite_sum_sexpr (vs)
if sorted (vs) == sorted (start_vs):
trace ('offs_expr_const: not const')
trace ('%s - %s' % (addr_expr, sp_expr))
trace (str (vs))
trace (str (hyps))
last_10_non_const.append ((addr_expr, sp_expr, vs, hyps))
del last_10_non_const[:-10]
return None
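# Worked example (added comment, hypothetical s-expressions): with
#   addr_expr = '(bvadd sp #xfffffff8)'   # i.e. sp - 8
#   sp_expr   = 'sp'
# the sp coefficients cancel (1 - 1 == 0) and offs_expr_const returns the raw
# constant 0xfffffff8, i.e. -8 mod 2 ** 32; callers normalise such values with
# norm_int or signed_offset.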
def has_stack_var (expr, stack_var):
while True:
if expr.is_op ('MemUpdate'):
[m, p, v] = expr.vals
expr = m
elif expr.kind == 'Var':
return expr == stack_var
else:
assert not 'has_stack_var: expr kind', expr
def mk_not_callable_hyps (p):
hyps = []
for n in p.nodes:
if p.nodes[n].kind != 'Call':
continue
if get_asm_callable (p.nodes[n].fname):
continue
tag = p.node_tags[n][0]
hyp = rep_graph.pc_false_hyp ((default_n_vc (p, n), tag))
hyps.append (hyp)
return hyps
last_get_ptr_offsets = [0]
last_get_ptr_offsets_setup = [0]
def get_ptr_offsets (p, n_ptrs, bases, hyps = [], cache = None,
fail_early = False):
"""detect which ptrs are guaranteed to be at constant offsets
from some set of basis ptrs"""
rep = rep_graph.mk_graph_slice (p, fast = True)
if cache == None:
cache = {}
last_get_ptr_offsets[0] = (p, n_ptrs, bases, hyps)
smt_bases = []
for (n, ptr, k) in bases:
n_vc = default_n_vc (p, n)
(_, env) = rep.get_node_pc_env (n_vc)
smt = solver.smt_expr (ptr, env, rep.solv)
smt_bases.append ((smt, k))
ptr_typ = ptr.typ
smt_ptrs = []
for (n, ptr) in n_ptrs:
n_vc = default_n_vc (p, n)
pc_env = rep.get_node_pc_env (n_vc)
if not pc_env:
continue
smt = solver.smt_expr (ptr, pc_env[1], rep.solv)
hyp = rep_graph.pc_true_hyp ((n_vc, p.node_tags[n][0]))
smt_ptrs.append (((n, ptr), smt, hyp))
hyps = hyps + mk_not_callable_hyps (p)
for tag in set ([p.node_tags[n][0] for (n, _) in n_ptrs]):
hyps = hyps + init_correctness_hyps (p, tag)
tags = set ([p.node_tags[n][0] for (n, ptr) in n_ptrs])
ex_defs = {}
for t in tags:
ex_defs.update (get_extra_sp_defs (rep, t))
offs = []
for (v, ptr, hyp) in smt_ptrs:
off = None
for (ptr2, k) in smt_bases:
off = offs_expr_const (ptr, ptr2, rep, [hyp] + hyps,
cache = cache, extra_defs = ex_defs,
typ = ptr_typ)
if off != None:
offs.append ((v, off, k))
break
if off == None:
trace ('get_ptr_offs fallthrough at %d: %s' % v)
trace (str ([hyp] + hyps))
assert not fail_early, (v, ptr)
return offs
def init_correctness_hyps (p, tag):
(_, fname, _) = p.get_entry_details (tag)
if fname not in pairings:
# conveniently handles bootstrap case
return []
# revise if multi-pairings for ASM an option
[pair] = pairings[fname]
true_tag = None
if tag in pair.funs:
true_tag = tag
elif p.hook_tag_hints.get (tag, tag) in pair.funs:
true_tag = p.hook_tag_hints.get (tag, tag)
if true_tag == None:
return []
(inp_eqs, _) = pair.eqs
in_tag = "%s_IN" % true_tag
eqs = [eq for eq in inp_eqs if eq[0][1] == in_tag
and eq[1][1] == in_tag]
return check.inst_eqs (p, (), eqs, {true_tag: tag})
extra_symbols = set ()
def preserves_sp (fname):
"""all functions will keep the stack pointer equal, whether they have
pairing partners or not."""
assume_sp_equal = bool (target_objects.hooks ('assume_sp_equal'))
if not extra_symbols:
for fname2 in target_objects.symbols:
extra_symbols.add(fname2)
extra_symbols.add('_'.join (fname2.split ('.')))
return (get_asm_calling_convention (fname)
or assume_sp_equal
or fname in extra_symbols)
def get_extra_sp_defs (rep, tag):
"""add extra defs/equalities about stack pointer for the
purposes of stack depth analysis."""
# FIXME how to parametrise this?
sp = mk_var ('r13', syntax.word32T)
defs = {}
fcalls = [n_vc for n_vc in rep.funcs
if logic.is_int (n_vc[0])
if rep.p.node_tags[n_vc[0]][0] == tag
if preserves_sp (rep.p.nodes[n_vc[0]].fname)]
for (n, vc) in fcalls:
(inputs, outputs, _) = rep.funcs[(n, vc)]
if (sp.name, sp.typ) not in outputs:
continue
inp_sp = solver.smt_expr (sp, inputs, rep.solv)
inp_sp = solver.parse_s_expression (inp_sp)
out_sp = solver.smt_expr (sp, outputs, rep.solv)
out_sp = solver.parse_s_expression (out_sp)
if inp_sp != out_sp:
defs[out_sp] = inp_sp
return defs
def get_stack_sp (p, tag):
"""get stack and stack-pointer variables"""
entry = p.get_entry (tag)
renames = p.entry_exit_renames (tags = [tag])
r = renames[tag + '_IN']
sp = syntax.rename_expr (mk_var ('r13', syntax.word32T), r)
stack = syntax.rename_expr (mk_var ('stack',
syntax.builtinTs['Mem']), r)
return (stack, sp)
def pseudo_node_lvals_rvals (node):
assert node.kind == 'Call'
cc = get_asm_calling_convention_at_node (node)
if not cc:
return None
arg_vars = set ([var for arg in cc['args']
for var in syntax.get_expr_var_set (arg)])
callee_saved_set = set (cc['callee_saved'])
rets = [(nm, typ) for (nm, typ) in node.rets
if mk_var (nm, typ) not in callee_saved_set]
return (rets, arg_vars)
def is_asm_node (p, n):
tag = p.node_tags[n][0]
return tag == 'ASM' or p.hook_tag_hints.get (tag, None) == 'ASM'
def all_pseudo_node_lvals_rvals (p):
pseudo = {}
for n in p.nodes:
if not is_asm_node (p, n):
continue
elif p.nodes[n].kind != 'Call':
continue
ps = pseudo_node_lvals_rvals (p.nodes[n])
if ps != None:
pseudo[n] = ps
return pseudo
def adjusted_var_dep_outputs_for_tag (p, tag):
(ent, fname, _) = p.get_entry_details (tag)
fun = functions[fname]
cc = get_asm_calling_convention (fname)
callee_saved_set = set (cc['callee_saved'])
ret_set = set ([(nm, typ) for ret in cc['rets']
for (nm, typ) in syntax.get_expr_var_set (ret)])
rets = [(nm2, typ) for ((nm, typ), (nm2, _))
in azip (fun.outputs, p.outputs[tag])
if (nm, typ) in ret_set
or mk_var (nm, typ) in callee_saved_set]
return rets
def adjusted_var_dep_outputs (p):
outputs = {}
for tag in p.outputs:
ent = p.get_entry (tag)
if is_asm_node (p, ent):
outputs[tag] = adjusted_var_dep_outputs_for_tag (p, tag)
else:
outputs[tag] = p.outputs[tag]
def output (n):
tag = p.node_tags[n][0]
return outputs[tag]
return output
def is_stack (expr):
return expr.kind == 'Var' and 'stack' in expr.name
class StackOffsMissing (Exception):
pass
def stack_virtualise_expr (expr, sp_offs):
if expr.is_op ('MemAcc') and is_stack (expr.vals[0]):
[m, p] = expr.vals
if expr.typ == syntax.word8T:
ps = [(syntax.mk_minus (p, syntax.mk_word32 (n)), n)
for n in [0, 1, 2, 3]]
elif expr.typ == syntax.word32T:
ps = [(p, 0)]
else:
assert expr.typ == syntax.word32T, expr
ptrs = [(p, 'MemAcc') for (p, _) in ps]
if sp_offs == None:
return (ptrs, None)
# FIXME: very 32-bit specific
ps = [(p, n) for (p, n) in ps if p in sp_offs
if sp_offs[p][1] % 4 == 0]
if not ps:
return (ptrs, expr)
[(p, n)] = ps
if p not in sp_offs:
raise StackOffsMissing ()
(k, offs) = sp_offs[p]
v = mk_var (('Fake', k, offs), syntax.word32T)
if n != 0:
v = syntax.mk_shiftr (v, n * 8)
v = syntax.mk_cast (v, expr.typ)
return (ptrs, v)
elif expr.kind == 'Op':
vs = [stack_virtualise_expr (v, sp_offs) for v in expr.vals]
return ([p for (ptrs, _) in vs for p in ptrs],
syntax.adjust_op_vals (expr, [v for (_, v) in vs]))
else:
return ([], expr)
def stack_virtualise_upd (((nm, typ), expr), sp_offs):
if 'stack' in nm:
upds = []
ptrs = []
while expr.is_op ('MemUpdate'):
[m, p, v] = expr.vals
ptrs.append ((p, 'MemUpdate'))
(ptrs2, v2) = stack_virtualise_expr (v, sp_offs)
ptrs.extend (ptrs2)
if sp_offs != None:
if p not in sp_offs:
raise StackOffsMissing ()
(k, offs) = sp_offs[p]
upds.append (((('Fake', k, offs),
syntax.word32T), v2))
expr = m
assert is_stack (expr), expr
return (ptrs, upds)
else:
(ptrs, expr2) = stack_virtualise_expr (expr, sp_offs)
return (ptrs, [((nm, typ), expr2)])
def stack_virtualise_ret (expr, sp_offs):
if expr.kind == 'Var':
return ([], (expr.name, expr.typ))
elif expr.is_op ('MemAcc'):
[m, p] = expr.vals
assert expr.typ == syntax.word32T, expr
assert is_stack (m), expr
if sp_offs != None:
(k, offs) = sp_offs[p]
r = (('Fake', k, offs), syntax.word32T)
else:
r = None
return ([(p, 'MemUpdate')], r)
else:
assert not 'ret expr understood', expr
def stack_virtualise_node (node, sp_offs):
if node.kind == 'Cond':
(ptrs, cond) = stack_virtualise_expr (node.cond, sp_offs)
if sp_offs == None:
return (ptrs, None)
else:
return (ptrs, syntax.Node ('Cond',
node.get_conts (), cond))
elif node.kind == 'Call':
if is_instruction (node.fname):
return ([], node)
cc = get_asm_calling_convention_at_node (node)
assert cc != None, node.fname
args = [arg for arg in cc['args'] if not is_stack (arg)]
args = [stack_virtualise_expr (arg, sp_offs) for arg in args]
rets = [ret for ret in cc['rets_inp'] if not is_stack (ret)]
rets = [stack_virtualise_ret (ret, sp_offs) for ret in rets]
ptrs = list (set ([p for (ps, _) in args for p in ps]
+ [p for (ps, _) in rets for p in ps]))
if sp_offs == None:
return (ptrs, None)
else:
return (ptrs, syntax.Node ('Call', node.cont,
(None, [v for (_, v) in args]
+ [p for (p, _) in ptrs],
[r for (_, r) in rets])))
elif node.kind == 'Basic':
upds = [stack_virtualise_upd (upd, sp_offs) for upd in node.upds]
ptrs = list (set ([p for (ps, _) in upds for p in ps]))
if sp_offs == None:
return (ptrs, None)
else:
ptr_upds = [(('unused#ptr#name%d' % i, syntax.word32T),
ptr) for (i, (ptr, _)) in enumerate (ptrs)]
return (ptrs, syntax.Node ('Basic', node.cont,
[upd for (_, us) in upds for upd in us]
+ ptr_upds))
else:
assert not "node kind understood", node.kind
def mk_get_local_offs (p, tag, sp_reps):
(stack, _) = get_stack_sp (p, tag)
def mk_local (n, kind, off, k):
(v, off2) = sp_reps[n][k]
ptr = syntax.mk_plus (v, syntax.mk_word32 (off + off2))
if kind == 'Ptr':
return ptr
elif kind == 'MemAcc':
return syntax.mk_memacc (stack, ptr, syntax.word32T)
return mk_local
def adjust_ret_ptr (ptr):
"""this is a bit of a hack.
the return slots are named based on r0_input, which will be unchanged,
which is handy, but we really want to be talking about r0, which will
produce meaningful offsets against the pointers actually used in the
program."""
return logic.var_subst (ptr, {('ret_addr_input', syntax.word32T):
syntax.mk_var ('r0', syntax.word32T)}, must_subst = False)
def get_loop_virtual_stack_analysis (p, tag):
"""computes variable liveness etc analyses with stack slots treated
as virtual variables."""
cache_key = ('loop_stack_analysis', tag)
if cache_key in p.cached_analysis:
return p.cached_analysis[cache_key]
(ent, fname, _) = p.get_entry_details (tag)
(_, sp) = get_stack_sp (p, tag)
cc = get_asm_calling_convention (fname)
rets = list (set ([ptr for arg in cc['rets']
for (ptr, _) in stack_virtualise_expr (arg, None)[0]]))
rets = [adjust_ret_ptr (ret) for ret in rets]
renames = p.entry_exit_renames (tags = [tag])
r = renames[tag + '_OUT']
rets = [syntax.rename_expr (ret, r) for ret in rets]
ns = [n for n in p.nodes if p.node_tags[n][0] == tag]
loop_ns = logic.minimal_loop_node_set (p)
ptrs = list (set ([(n, ptr) for n in ns
for ptr in (stack_virtualise_node (p.nodes[n], None))[0]]))
ptrs += [(n, (sp, 'StackPointer')) for n in ns if n in loop_ns]
offs = get_ptr_offsets (p, [(n, ptr) for (n, (ptr, _)) in ptrs],
[(ent, sp, 'stack')]
+ [(ent, ptr, 'indirect_ret') for ptr in rets[:1]])
ptr_offs = {}
rep_offs = {}
upd_offsets = {}
for ((n, ptr), off, k) in offs:
off = norm_int (off, 32)
ptr_offs.setdefault (n, {})
rep_offs.setdefault (n, {})
ptr_offs[n][ptr] = (k, off)
rep_offs[n][k] = (ptr, - off)
for (n, (ptr, kind)) in ptrs:
if kind == 'MemUpdate' and n in loop_ns:
loop = p.loop_id (n)
(k, off) = ptr_offs[n][ptr]
upd_offsets.setdefault (loop, set ())
upd_offsets[loop].add ((k, off))
loc_offs = mk_get_local_offs (p, tag, rep_offs)
adj_nodes = {}
for n in ns:
try:
(_, node) = stack_virtualise_node (p.nodes[n],
ptr_offs.get (n, {}))
except StackOffsMissing, e:
printout ("Stack analysis issue at (%d, %s)."
% (n, p.node_tags[n]))
node = p.nodes[n]
adj_nodes[n] = node
# finally do analysis on this collection of nodes
preds = dict (p.preds)
preds['Ret'] = [n for n in preds['Ret'] if p.node_tags[n][0] == tag]
preds['Err'] = [n for n in preds['Err'] if p.node_tags[n][0] == tag]
vds = logic.compute_var_deps (adj_nodes,
adjusted_var_dep_outputs (p), preds)
result = (vds, adj_nodes, loc_offs, upd_offsets, (ptrs, offs))
p.cached_analysis[cache_key] = result
return result
def norm_int (n, radix):
n = n & ((1 << radix) - 1)
n2 = n - (1 << radix)
if abs (n2) < abs (n):
return n2
else:
return n
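# Worked examples (added comment): norm_int picks whichever of the unsigned
# and the two's-complement reading of n has the smaller magnitude, e.g.
#   norm_int (0xfffffffc, 32) == -4
#   norm_int (16, 32) == 16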
def loop_var_analysis (p, split):
"""computes the same loop dataflow analysis as in the 'logic' module
but with stack slots treated as virtual variables."""
if not is_asm_node (p, split):
return None
head = p.loop_id (split)
tag = p.node_tags[split][0]
assert head
key = ('loop_stack_virtual_var_cycle_analysis', split)
if key in p.cached_analysis:
return p.cached_analysis[key]
(vds, adj_nodes, loc_offs,
upd_offsets, _) = get_loop_virtual_stack_analysis (p, tag)
loop = p.loop_body (head)
va = logic.compute_loop_var_analysis (p, vds, split,
override_nodes = adj_nodes)
(stack, _) = get_stack_sp (p, tag)
va2 = []
uoffs = upd_offsets.get (head, [])
for (v, data) in va:
if v.kind == 'Var' and v.name[0] == 'Fake':
(_, k, offs) = v.name
if (k, offs) not in uoffs:
continue
v2 = loc_offs (split, 'MemAcc', offs, k)
va2.append ((v2, data))
elif v.kind == 'Var' and v.name.startswith ('stack'):
assert v.typ == stack.typ
continue
else:
va2.append ((v, data))
stack_const = stack
for (k, off) in uoffs:
stack_const = syntax.mk_memupd (stack_const,
loc_offs (split, 'Ptr', off, k),
syntax.mk_word32 (0))
sp = asm_stack_rep_hook (p, (stack.name, stack.typ), 'Loop', split)
assert sp and sp[0] == 'SplitMem', (split, sp)
(_, st_split) = sp
stack_const = logic.mk_stack_wrapper (st_split, stack_const, [])
stack_const = logic.mk_eq_selective_wrapper (stack_const,
([], [0]))
va2.append ((stack_const, 'LoopConst'))
p.cached_analysis[key] = va2
return va2
def inline_no_pre_pairing (p):
# FIXME: handle code sharing with check.inline_completely_unmatched
while True:
ns = [n for n in p.nodes if p.nodes[n].kind == 'Call'
if p.nodes[n].fname not in pre_pairings
if not is_instruction (p.nodes[n].fname)]
for n in ns:
trace ('Inlining %s at %d.' % (p.nodes[n].fname, n))
problem.inline_at_point (p, n)
if not ns:
return
last_asm_stack_depth_fun = [0]
def check_before_guess_asm_stack_depth (fun):
from solver import smt_expr
if not fun.entry:
return None
p = fun.as_problem (problem.Problem, name = 'Target')
try:
p.do_analysis ()
p.check_no_inner_loops ()
inline_no_pre_pairing (p)
except problem.Abort, e:
return None
rep = rep_graph.mk_graph_slice (p, fast = True)
try:
rep.get_pc (default_n_vc (p, 'Ret'), 'Target')
err_pc = rep.get_pc (default_n_vc (p, 'Err'), 'Target')
except solver.EnvMiss, e:
return None
inlined_funs = set ([fn for (_, _, fn) in p.inline_scripts['Target']])
if inlined_funs:
printout (' (stack analysis also involves %s)'
% ', '.join(inlined_funs))
return p
def guess_asm_stack_depth (fun):
p = check_before_guess_asm_stack_depth (fun)
if not p:
return (0, {})
last_asm_stack_depth_fun[0] = fun.name
entry = p.get_entry ('Target')
(_, sp) = get_stack_sp (p, 'Target')
nodes = get_asm_reachable_nodes (p, tag_set = ['Target'])
offs = get_ptr_offsets (p, [(n, sp) for n in nodes],
[(entry, sp, 'InitSP')], fail_early = True)
assert len (offs) == len (nodes), map (hex, set (nodes)
- set ([n for ((n, _), _, _) in offs]))
all_offs = [(n, signed_offset (off, 32, 10 ** 6))
for ((n, ptr), off, _) in offs]
min_offs = min ([offs for (n, offs) in all_offs])
max_offs = max ([offs for (n, offs) in all_offs])
assert min_offs >= 0 or max_offs <= 0, all_offs
multiplier = 1
if min_offs < 0:
multiplier = -1
max_offs = - min_offs
fcall_offs = [(p.nodes[n].fname, offs * multiplier)
for (n, offs) in all_offs if p.nodes[n].kind == 'Call']
fun_offs = {}
for f in set ([f for (f, _) in fcall_offs]):
fun_offs[f] = max ([offs for (f2, offs) in fcall_offs
if f2 == f])
return (max_offs, fun_offs)
def signed_offset (n, bits, bound = 0):
n = n & ((1 << bits) - 1)
if n >= (1 << (bits - 1)):
n = n - (1 << bits)
if bound:
assert n <= bound, (n, bound)
assert n >= (- bound), (n, bound)
return n
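# Worked example (added comment): signed_offset reads n as a signed
# 'bits'-wide value and, if a bound is given, sanity-checks its magnitude,
# e.g.
#   signed_offset (0xffffff00, 32, 10 ** 6) == -256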
def ident_conds (fname, idents):
rolling = syntax.true_term
conds = []
for ident in idents.get (fname, [syntax.true_term]):
conds.append ((ident, syntax.mk_and (rolling, ident)))
rolling = syntax.mk_and (rolling, syntax.mk_not (ident))
return conds
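# Illustration (added comment, hypothetical idents): for
# idents.get (fname) == [c1, c2, true_term] the returned pairs are, in order,
#   (c1, c1), (c2, (not c1) & c2), (true_term, (not c1) & (not c2))
# (each built with mk_and/mk_not, up to the true_term conjuncts), so the
# conditions cover the cases in turn, with the true_term case catching the
# remainder.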
def ident_callables (fname, callees, idents):
from solver import to_smt_expr, smt_expr
from syntax import mk_not, mk_and, true_term
auto_callables = dict ([((ident, f, true_term), True)
for ident in idents.get (fname, [true_term])
for f in callees if f not in idents])
if not [f for f in callees if f in idents]:
return auto_callables
fun = functions[fname]
p = fun.as_problem (problem.Problem, name = 'Target')
check_ns = [(n, ident, cond) for n in p.nodes
if p.nodes[n].kind == 'Call'
if p.nodes[n].fname in idents
for (ident, cond) in ident_conds (p.nodes[n].fname, idents)]
p.do_analysis ()
assert check_ns
rep = rep_graph.mk_graph_slice (p, fast = True)
err_hyp = rep_graph.pc_false_hyp ((default_n_vc (p, 'Err'), 'Target'))
callables = auto_callables
nhyps = mk_not_callable_hyps (p)
for (ident, cond) in ident_conds (fname, idents):
renames = p.entry_exit_renames (tags = ['Target'])
cond = syntax.rename_expr (cond, renames['Target_IN'])
entry = p.get_entry ('Target')
e_vis = ((entry, ()), 'Target')
hyps = [err_hyp, rep_graph.eq_hyp ((cond, e_vis),
(true_term, e_vis))]
for (n, ident2, cond2) in check_ns:
k = (ident, p.nodes[n].fname, ident2)
(inp_env, _, _) = rep.get_func (default_n_vc (p, n))
pc = rep.get_pc (default_n_vc (p, n))
cond2 = to_smt_expr (cond2, inp_env, rep.solv)
if rep.test_hyp_whyps (mk_not (mk_and (pc, cond2)),
hyps + nhyps):
callables[k] = False
else:
callables[k] = True
return callables
def compute_immediate_stack_bounds (idents, names):
from syntax import true_term
immed = {}
names = sorted (names)
for (i, fname) in enumerate (names):
printout ('Doing stack analysis for %r. (%d of %d)' % (fname,
i + 1, len (names)))
fun = functions[fname]
(offs, fn_offs) = guess_asm_stack_depth (fun)
callables = ident_callables (fname, fn_offs.keys (), idents)
for ident in idents.get (fname, [true_term]):
calls = [((fname2, ident2), fn_offs[fname2])
for fname2 in fn_offs
for ident2 in idents.get (fname2, [true_term])
if callables[(ident, fname2, ident2)]]
immed[(fname, ident)] = (offs, dict (calls))
last_immediate_stack_bounds[0] = immed
return immed
last_immediate_stack_bounds = [0]
def immediate_stack_bounds_loop (immed):
graph = dict ([(k, immed[k][1].keys ()) for k in immed])
graph['ENTRY'] = list (immed)
comps = logic.tarjan (graph, ['ENTRY'])
rec_comps = [[x] + y for (x, y) in comps if y]
return rec_comps
def compute_recursive_stack_bounds (immed):
assert not immediate_stack_bounds_loop (immed)
bounds = {}
todo = immed.keys ()
report = 1000
while todo:
if len (todo) >= report:
trace ('todo length %d' % len (todo))
trace ('tail: %s' % todo[-20:])
report += 1000
(fname, ident) = todo.pop ()
if (fname, ident) in bounds:
continue
(static, calls) = immed[(fname, ident)]
if [1 for k in calls if k not in bounds]:
todo.append ((fname, ident))
todo.extend (calls.keys ())
continue
else:
bounds[(fname, ident)] = max ([static]
+ [bounds[k] + calls[k] for k in calls])
return bounds
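# Worked example (added comment, hypothetical immed table): with
#   immed = {('f', true_term): (16, {('g', true_term): 8}),
#       ('g', true_term): (32, {})}
# the bound for 'g' is 32 and the bound for 'f' is max (16, 32 + 8) == 40,
# i.e. a function's own maximum offset, or a callee's bound plus the stack
# offset at which that call is made, whichever is larger.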
def stack_bounds_to_closed_form (bounds, names, idents):
closed = {}
for fname in names:
res = syntax.mk_word32 (bounds[(fname, syntax.true_term)])
extras = []
if fname in idents:
assert idents[fname][-1] == syntax.true_term
extras = reversed (idents[fname][:-1])
for ident in extras:
alt = syntax.mk_word32 (bounds[(fname, ident)])
res = syntax.mk_if (ident, alt, res)
closed[fname] = res
return closed
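# Illustration (added comment, hypothetical bounds): for a function f with
# idents[f] == [c, true_term] and bounds
#   {('f', true_term): 64, ('f', c): 32}
# the closed form is mk_if (c, mk_word32 (32), mk_word32 (64)), i.e.
# "if c then 32 else 64".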
def compute_asm_stack_bounds (idents, names):
immed = compute_immediate_stack_bounds (idents, names)
bounds = compute_recursive_stack_bounds (immed)
closed = stack_bounds_to_closed_form (bounds, names, idents)
return closed
recursion_trace = []
recursion_last_assns = [[]]
def get_recursion_identifiers (funs, extra_unfolds = []):
idents = {}
del recursion_trace[:]
graph = dict ([(f, list (functions[f].function_calls ()))
for f in functions])
fs = funs
fs2 = set ()
while fs2 != fs:
fs2 = fs
fs = set.union (set ([f for f in graph if [f2 for f2 in graph[f]
if f2 in fs2]]),
set ([f2 for f in fs2 for f2 in graph[f]]), fs2)
graph = dict ([(f, graph[f]) for f in fs])
entries = list (fs - set ([f2 for f in graph for f2 in graph[f]]))
comps = logic.tarjan (graph, entries)
for (head, tail) in comps:
if tail or head in graph[head]:
group = [head] + list (tail)
idents2 = compute_recursion_idents (group,
extra_unfolds)
idents.update (idents2)
return idents
def compute_recursion_idents (group, extra_unfolds):
idents = {}
group = set (group)
recursion_trace.append ('Computing for group %s' % group)
printout ('Doing recursion analysis for function group:')
printout (' %s' % list(group))
prevs = set ([f for f in functions
if [f2 for f2 in functions[f].function_calls () if f2 in group]])
for f in prevs - group:
recursion_trace.append (' checking for %s' % f)
trace ('Checking idents for %s' % f)
while add_recursion_ident (f, group, idents, extra_unfolds):
pass
return idents
def function_link_assns (p, call_site, tag):
call_vis = (default_n_vc (p, call_site), p.node_tags[call_site][0])
return rep_graph.mk_function_link_hyps (p, call_vis, tag)
def add_recursion_ident (f, group, idents, extra_unfolds):
from syntax import mk_eq, mk_implies, mk_var
p = problem.Problem (None, name = 'Recursion Test')
chain = []
tag = 'fun0'
p.add_entry_function (functions[f], tag)
p.do_analysis ()
assns = []
recursion_last_assns[0] = assns
while True:
res = find_unknown_recursion (p, group, idents, tag, assns,
extra_unfolds)
if res == None:
break
if p.nodes[res].fname not in group:
problem.inline_at_point (p, res)
continue
fname = p.nodes[res].fname
chain.append (fname)
tag = 'fun%d' % len (chain)
(args, _, entry) = p.add_entry_function (functions[fname], tag)
p.do_analysis ()
assns += function_link_assns (p, res, tag)
if chain == []:
return None
recursion_trace.append (' created fun chain %s' % chain)
word_args = [(i, mk_var (s, typ))
for (i, (s, typ)) in enumerate (args)
if typ.kind == 'Word']
rep = rep_graph.mk_graph_slice (p, fast = True)
(_, env) = rep.get_node_pc_env ((entry, ()))
m = {}
res = rep.test_hyp_whyps (syntax.false_term, assns, model = m)
assert m
if find_unknown_recursion (p, group, idents, tag, [], []) == None:
idents.setdefault (fname, [])
idents[fname].append (syntax.true_term)
recursion_trace.append (' found final ident for %s' % fname)
return syntax.true_term
assert word_args
recursion_trace.append (' scanning for ident for %s' % fname)
for (i, arg) in word_args:
(nm, typ) = functions[fname].inputs[i]
arg_smt = solver.to_smt_expr (arg, env, rep.solv)
val = search.eval_model_expr (m, rep.solv, arg_smt)
if not rep.test_hyp_whyps (mk_eq (arg_smt, val), assns):
recursion_trace.append (' discarded %s = 0x%x, not stable' % (nm, val.val))
continue
entry_vis = ((entry, ()), tag)
ass = rep_graph.eq_hyp ((arg, entry_vis), (val, entry_vis))
res = find_unknown_recursion (p, group, idents, tag,
assns + [ass], [])
if res:
fname2 = p.nodes[res].fname
recursion_trace.append (' discarded %s, allows recursion to %s' % (nm, fname2))
continue
eq = syntax.mk_eq (mk_var (nm, typ), val)
idents.setdefault (fname, [])
idents[fname].append (eq)
recursion_trace.append (' found ident for %s: %s' % (fname, eq))
return eq
assert not "identifying assertion found"
def find_unknown_recursion (p, group, idents, tag, assns, extra_unfolds):
from syntax import mk_not, mk_and, foldr1
rep = rep_graph.mk_graph_slice (p, fast = True)
for n in p.nodes:
if p.nodes[n].kind != 'Call':
continue
if p.node_tags[n][0] != tag:
continue
fname = p.nodes[n].fname
if fname in extra_unfolds:
return n
if fname not in group:
continue
(inp_env, _, _) = rep.get_func (default_n_vc (p, n))
pc = rep.get_pc (default_n_vc (p, n))
new = foldr1 (mk_and, [pc] + [syntax.mk_not (
solver.to_smt_expr (ident, inp_env, rep.solv))
for ident in idents.get (fname, [])])
if rep.test_hyp_whyps (mk_not (new), assns):
continue
return n
return None
asm_cc_cache = {}
def is_instruction (fname):
bits = fname.split ("'")
return bits[1:] and bits[:1] in [["l_impl"], ["instruction"]]
def get_asm_calling_convention (fname):
if fname in asm_cc_cache:
return asm_cc_cache[fname]
if fname not in pre_pairings:
bits = fname.split ("'")
if not is_instruction (fname):
trace ("Warning: unusual unmatched function (%s, %s)."
% (fname, bits))
return None
pair = pre_pairings[fname]
assert pair['ASM'] == fname
c_fun = functions[pair['C']]
from logic import split_scalar_pairs
(var_c_args, c_imem, glob_c_args) = split_scalar_pairs (c_fun.inputs)
(var_c_rets, c_omem, glob_c_rets) = split_scalar_pairs (c_fun.outputs)
num_args = len (var_c_args)
num_rets = len (var_c_rets)
const_mem = not (c_omem)
cc = get_asm_calling_convention_inner (num_args, num_rets, const_mem)
asm_cc_cache[fname] = cc
return cc
def get_asm_calling_convention_inner (num_c_args, num_c_rets, const_mem):
key = ('Inner', num_c_args, num_c_rets, const_mem)
if key in asm_cc_cache:
return asm_cc_cache[key]