forked from jpata/hepaccelerate-cms
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Disco_tf.py
59 lines (45 loc) · 2.56 KB
/
Disco_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import tensorflow as tf
def _centered_distance_matrix(values, weights):
    """Return the weighted, double-centred pairwise distance matrix of ``values``.

    values: 1D tensor with N entries.
    weights: per-example weights, broadcast along the last axis of the
        N x N distance matrix (so column j is scaled by ``weights[j]``).
    """
    # pairwise[i, j] = |values[i] - values[j]|; broadcasting a column against
    # a row yields the N x N matrix without tiling the input N times.
    pairwise = tf.math.abs(tf.expand_dims(values, 1) - tf.expand_dims(values, 0))
    # Weighted mean over j for every row i.
    row_avg = tf.reduce_mean(pairwise * weights, axis=1)
    # Weighted mean of the row means (a scalar).
    grand_avg = tf.reduce_mean(row_avg * weights)
    # Subtract the mean indexed by column, then the mean indexed by row,
    # and add the grand mean back.
    return pairwise - tf.expand_dims(row_avg, 0) - tf.expand_dims(row_avg, 1) + grand_avg


def distance_corr(var_1, var_2, normedweight, power=1):
    """Compute the weighted distance correlation between two variables.

    var_1: First variable to decorrelate (eg mass)
    var_2: Second variable to decorrelate (eg classifier output)
    normedweight: Per-example weight. Sum of weights should add up to N
        (where N is the number of examples)
    power: Exponent used in calculating the distance correlation

    var_1 and var_2 are flattened to 1D before use, so a trailing unit
    dimension (shape [N, 1]) is accepted; they must contain the same number
    of entries. normedweight is used as given and should be a 1D tf tensor
    with that same number of entries.

    Returns a scalar tensor.
    NOTE(review): if either variable is constant across the batch the
    denominator is zero and the result is presumably NaN -- confirm whether
    callers need a guard.

    Usage: Add to your loss function. total_loss = BCE_loss + lambda * distance_corr
    """
    # Flattening is a no-op for 1D input and lets [N, 1] tensors through.
    var_1 = tf.reshape(var_1, [-1])
    var_2 = tf.reshape(var_2, [-1])

    a_centered = _centered_distance_matrix(var_1, normedweight)
    b_centered = _centered_distance_matrix(var_2, normedweight)

    # Weighted row means of the element-wise products.
    ab_rows = tf.reduce_mean(a_centered * b_centered * normedweight, axis=1)
    aa_rows = tf.reduce_mean(a_centered * a_centered * normedweight, axis=1)
    bb_rows = tf.reduce_mean(b_centered * b_centered * normedweight, axis=1)

    # Weighted means of the row means: the covariance-like numerator and the
    # two variance-like terms of the denominator.
    cov_ab = tf.reduce_mean(ab_rows * normedweight)
    var_a = tf.reduce_mean(aa_rows * normedweight)
    var_b = tf.reduce_mean(bb_rows * normedweight)

    if power == 2:
        # Squared form is computed directly, without taking a square root.
        return cov_ab ** 2 / (var_a * var_b)

    d_corr = cov_ab / tf.math.sqrt(var_a * var_b)
    if power == 1:
        return d_corr
    return d_corr ** power