brief architecture overview for attention policy head v5
---------
OVERVIEW:
---------
B = batch size; C = number of channels/filters in residual stack; E = embedding size; D = query & key size; N = number of policy heads
input IN:= [BxCx8x8]
reshape and transpose input -> [Bx64xC]
embed tokens (dense layer) -> [Bx64xE]
*optional encoder layers -> [Bx64xE]
generate Queries and Keys (two dense layers) -> Q:= [Bx64xD], K:= [Bx64xD]
extract last rank from K (for pawn promo) -> PK:= [Bx8xD]
create promo offsets from PK (dense layer size 3) -> PO:= [Bx8x3]
transpose PO -> PO:= [Bx3x8]
multi-head attention if N > 1:
    depth:= D / N
    reshape and transpose Q and K /-> [BxNx64xdepth]
matmul Q with T(K) to create 4096 move logits -> L:= if N > 1 [BxNx64x64] else [Bx64x64]
summarize policy heads if N > 1:
    reduce mean on L, axis=1 /-> L:= [Bx64x64]
**create pawn promotion logits -> P:= [Bx8x24]
divide P and L by a constant to stabilize gradients
apply policy map to P and L -> OUT:= [Bx1858]
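
The steps above, from IN through L and PO, can be sketched in NumPy to check the shapes. This is only an illustration, not the actual implementation: the sizes, the random-weight dense() stand-in, and the function name policy_logits are all hypothetical, and it assumes the last 8 tokens correspond to the promotion rank. It stops before the pawn promotion logits, the scaling constant, and the policy map, since these notes do not spell those out.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, chosen only for illustration (D must be divisible by N).
B, C, E, D, N = 2, 16, 32, 24, 4

def dense(x, out_dim):
    """Stand-in dense layer: random weights, no bias, no activation."""
    w = rng.standard_normal((x.shape[-1], out_dim)) / np.sqrt(x.shape[-1])
    return x @ w

def policy_logits(inp, n_heads=N):
    b, c = inp.shape[0], inp.shape[1]
    # [B, C, 8, 8] -> [B, 64, C]: one token per square
    tokens = inp.reshape(b, c, 64).transpose(0, 2, 1)
    emb = dense(tokens, E)                    # [B, 64, E]
    q = dense(emb, D)                         # [B, 64, D]
    k = dense(emb, D)                         # [B, 64, D]
    # Assumption: the last 8 tokens are the promotion (last) rank.
    pk = k[:, -8:, :]                         # [B, 8, D]
    po = dense(pk, 3).transpose(0, 2, 1)      # [B, 8, 3] -> [B, 3, 8]
    if n_heads > 1:
        depth = D // n_heads
        # [B, 64, D] -> [B, 64, N, depth] -> [B, N, 64, depth]
        qh = q.reshape(b, 64, n_heads, depth).transpose(0, 2, 1, 3)
        kh = k.reshape(b, 64, n_heads, depth).transpose(0, 2, 1, 3)
        logits = qh @ kh.transpose(0, 1, 3, 2)   # [B, N, 64, 64]
        logits = logits.mean(axis=1)             # [B, 64, 64]
    else:
        logits = q @ k.transpose(0, 2, 1)        # [B, 64, 64]
    return logits, po

logits, po = policy_logits(rng.standard_normal((B, C, 8, 8)))
print(logits.shape, po.shape)
```

Each of the 64 x 64 entries of L pairs a query square with a key square, which is where the 4096 move logits come from.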
---------------
*ENCODER LAYER:
---------------
not written out in full yet, but mostly the same operations: concat, reshape, transpose, dense layers, matmul, division, softmax, more dense layers, etc.
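
As a rough guide to how those operations could fit together, here is a generic multi-head self-attention layer followed by two dense layers, in NumPy. This is an assumption based on the usual Transformer encoder layout, not the layer actually used here: the head count, ffn_dim, the ReLU, the sqrt(depth) divisor, and the random-weight dense() stand-in are all hypothetical. Residual connections and normalization are left out because the notes do not mention them.

```python
import numpy as np

rng = np.random.default_rng(0)

def dense(x, out_dim):
    """Stand-in dense layer: random weights, no bias, no activation."""
    w = rng.standard_normal((x.shape[-1], out_dim)) / np.sqrt(x.shape[-1])
    return x @ w

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def split_heads(x, heads):
    # [B, 64, E] -> [B, 64, H, E/H] -> [B, H, 64, E/H]
    b, t, d = x.shape
    return x.reshape(b, t, heads, d // heads).transpose(0, 2, 1, 3)

def encoder_layer(x, heads=4, ffn_dim=64):
    b, t, e = x.shape
    depth = e // heads
    # dense layers, then reshape and transpose into heads
    q = split_heads(dense(x, e), heads)
    k = split_heads(dense(x, e), heads)
    v = split_heads(dense(x, e), heads)
    # matmul, division, softmax
    weights = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(depth))  # [B, H, 64, 64]
    # weighted sum of values, then merge the heads back together
    attended = (weights @ v).transpose(0, 2, 1, 3).reshape(b, t, e)  # [B, 64, E]
    out = dense(attended, e)
    # more dense layers (hypothetical feed-forward block with ReLU)
    hidden = np.maximum(dense(out, ffn_dim), 0.0)
    return dense(hidden, e)                                          # [B, 64, E]

y = encoder_layer(rng.standard_normal((2, 64, 32)))
print(y.shape)
```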
-----------------
**PAWN PROMOTION:
-----------------
input:
    promo offsets (PO) [Bx3x8]
    move logits (L) [Bx64x64]
knight promo logits:
    extract R7 to R8 moves from L -> NP:= [Bx8x8]
q, r, b promo logits:
    add each row (axis=1) of PO to NP -> QP, RP, BP:= [Bx8x8]
    concat and reshape QP, RP, BP -> P:= [Bx8x24]
return P
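
A minimal NumPy sketch of the promotion steps, under some labeled assumptions: squares are indexed rank by rank so that R7 is 48..55 and R8 is 56..63, the rows of PO are in q, r, b order, each offset applies to a destination file (PO comes from the last-rank keys), and the three results are joined on a new last axis before the reshape. The function name promotion_logits is hypothetical.

```python
import numpy as np

def promotion_logits(po, logits):
    """po: promo offsets [B, 3, 8]; logits: move logits [B, 64, 64]."""
    b = logits.shape[0]
    # Assumption: rank 7 squares are 48..55 (from) and rank 8 squares are 56..63 (to).
    knight = logits[:, 48:56, 56:64]            # NP: [B, 8, 8], from-file x to-file
    # Assumption: each row of PO holds one offset per destination file,
    # so it broadcasts across the from-file axis.
    qp = knight + po[:, 0:1, :]                 # [B, 8, 8]
    rp = knight + po[:, 1:2, :]
    bp = knight + po[:, 2:3, :]
    # Assumption: join on a new last axis, then flatten to [B, 8, 24],
    # which interleaves q, r, b for each destination file.
    return np.stack([qp, rp, bp], axis=3).reshape(b, 8, 24)

move_logits = np.arange(2 * 64 * 64, dtype=float).reshape(2, 64, 64)
offsets = np.arange(2 * 3 * 8, dtype=float).reshape(2, 3, 8)
print(promotion_logits(offsets, move_logits).shape)
```

With this layout, entry [b, from_file, 3 * to_file + piece] of P is the knight logit for that R7 to R8 move plus the offset for that piece and destination file.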