-
Notifications
You must be signed in to change notification settings - Fork 3
/
blake3.py
365 lines (284 loc) · 13.5 KB
/
blake3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
from cryptopals_lib import *
class Blake3Chunk(object):
"""docstring for Blake3Node"""
def __init__(self, buffers, flags, blocks_compressed=0, node_number=0):
#Flags:
# CHUNK_START = 0x01
# CHUNK_END = 0x02
# PARENT = 0x04
# ROOT = 0x08
# KEYED_HASH = 0x10
# DERIVE_KEY_CONTEXT = 0x20
# DERIVE_KEY_MATERIAL = 0x40
self.flags = flags
self.chaining_values = buffers
self.input_data = b""
self.blocks_compressed = blocks_compressed
self.node_number = node_number
self.max_chunk_size = 1024
self.max_block_length = 64
#Compression Settings
self.permutations = [
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15],
[14,10, 4, 8, 9,15,13, 6, 1,12, 0, 2,11, 7, 5, 3],
[11, 8,12, 0, 5, 2,15,13,10,14, 3, 6, 7, 1, 9, 4],
[ 7, 9, 3, 1,13,12,11,14, 2, 6, 5,10, 4, 0,15, 8],
[ 9, 0, 5, 7, 2, 4,10,15,14, 1,11,12, 6, 8, 3,13],
[ 2,12, 6,10, 0,11, 8, 3, 4,13, 7, 5,15,14, 1, 9],
[12, 5, 1,15,14,13, 4,10, 0, 7, 6, 3, 9, 2, 8,11],
[13,11, 7,14,12, 1, 3, 9, 5, 0,15, 4, 8, 6, 2,10],
[ 6,15,14, 9,11, 3, 0, 8,12, 2,13, 7, 1, 4,10, 5],
[10, 2, 8, 4, 7, 6, 1, 5,15,11, 9,14, 3,12,13, 0],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15],
[14,10, 4, 8, 9,15,13, 6, 1,12, 0, 2,11, 7, 5, 3],
[11, 8,12, 0, 5, 2,15,13,10,14, 3, 6, 7, 1, 9, 4],
[ 7, 9, 3, 1,13,12,11,14, 2, 6, 5,10, 4, 0,15, 8],
[ 9, 0, 5, 7, 2, 4,10,15,14, 1,11,12, 6, 8, 3,13],
[ 2,12, 6,10, 0,11, 8, 3, 4,13, 7, 5,15,14, 1, 9],
[12, 5, 1,15,14,13, 4,10, 0, 7, 6, 3, 9, 2, 8,11],
[13,11, 7,14,12, 1, 3, 9, 5, 0,15, 4, 8, 6, 2,10],
[ 6,15,14, 9,11, 3, 0, 8,12, 2,13, 7, 1, 4,10, 5],
[10, 2, 8, 4, 7, 6, 1, 5,15,11, 9,14, 3,12,13, 0],
]
self.blake3_permutations = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8]
self.rounds = 7
self.rotations = [16,12,8,7]
self.blocksize = 32
#
self.iv = [0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A,
0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,]
def len(self):
return self.max_block_length * self.blocks_compressed + len(self.input_data)
def _chacha_quarter_round(self, a, b, c, d, message, round_num, index):
#Calculate indexes from Permuation table and round_index and offset
message_index = self.permutations[index]
constant_index = self.permutations[index+1]
#Modified first part to include message and round xor
a = asint((a + b) + message[index], self.blocksize)
d = asint(d ^ a, self.blocksize)
d = asint(shift_rotate_right(d, self.rotations[0], self.blocksize), self.blocksize)
c = asint(c + d, self.blocksize)
b = asint(b ^ c, self.blocksize)
b = asint(shift_rotate_right(b, self.rotations[1], self.blocksize), self.blocksize)
#Modified first part to include message and round xor
a = asint((a + b) + message[index+1], self.blocksize)
d = asint(d ^ a, self.blocksize)
d = asint(shift_rotate_right(d, self.rotations[2], self.blocksize), self.blocksize)
c = asint(d + c, self.blocksize)
b = asint(b ^ c, self.blocksize)
b = asint(shift_rotate_right(b, self.rotations[3], self.blocksize), self.blocksize)
return [a,b,c,d]
def _permutation(self, block):
temp_buffers = block[:]
for index in range(len(block)):
#Use the permutation lookup table to get new index
new_index = self.blake3_permutations[index]
temp_buffers[index] = block[new_index]
return temp_buffers
def _compress_chunk_manual(self, chaining_values, counter, flags, block_length, input_data):
#Extend inputdata
if type(input_data) == bytes:
input_data = input_data.ljust(self.max_block_length, b"\x00")
input_data = bytes_to_intarray(input_data, (self.blocksize//8), byte_order="little")
#Check input length
assert len(input_data) == 16
'''
|chainedValue |chainedValue |chainedValue |chainedValue |
|chainedValue |chainedValue |chainedValue |chainedValue |
|IV |IV |IV |IV |
|blockcounter[0] |blockcounter[0] |blocklen |flags |
'''
#Start setting up the temp buffers
temp_buffers = chaining_values[:8] + self.iv[:4] + [0,0,0,0]
#Add the Number of blocks that have been processed
temp_buffers[12] ^= asint(counter, self.blocksize)
temp_buffers[13] ^= asint(counter >> self.blocksize, self.blocksize)
#Add the number of bytes in the current block to be hashed
temp_buffers[14] = block_length
temp_buffers[15] = flags
#print(f"compress: {chaining_values[0]}, {counter}, {flags}, {block_length}, {input_data}")
#print(f"before: {[hex(x) for x in temp_buffers]}")
#Do ChaCha rounds with modifications
for index in range(self.rounds):
#Do Each Column
temp_buffers[0], temp_buffers[4], temp_buffers[8], temp_buffers[12] = self._chacha_quarter_round(temp_buffers[0], temp_buffers[4], temp_buffers[8], temp_buffers[12], input_data, index, 0)
temp_buffers[1], temp_buffers[5], temp_buffers[9], temp_buffers[13] = self._chacha_quarter_round(temp_buffers[1], temp_buffers[5], temp_buffers[9], temp_buffers[13], input_data, index, 2)
temp_buffers[2], temp_buffers[6], temp_buffers[10], temp_buffers[14] = self._chacha_quarter_round(temp_buffers[2], temp_buffers[6], temp_buffers[10], temp_buffers[14], input_data, index, 4)
temp_buffers[3], temp_buffers[7], temp_buffers[11], temp_buffers[15] = self._chacha_quarter_round(temp_buffers[3], temp_buffers[7], temp_buffers[11], temp_buffers[15], input_data, index, 6)
#Do Each Diagonal
temp_buffers[0], temp_buffers[5], temp_buffers[10], temp_buffers[15] = self._chacha_quarter_round(temp_buffers[0], temp_buffers[5], temp_buffers[10], temp_buffers[15], input_data, index, 8)
temp_buffers[1], temp_buffers[6], temp_buffers[11], temp_buffers[12] = self._chacha_quarter_round(temp_buffers[1], temp_buffers[6], temp_buffers[11], temp_buffers[12], input_data, index, 10)
temp_buffers[2], temp_buffers[7], temp_buffers[8], temp_buffers[13] = self._chacha_quarter_round(temp_buffers[2], temp_buffers[7], temp_buffers[8], temp_buffers[13], input_data, index, 12)
temp_buffers[3], temp_buffers[4], temp_buffers[9], temp_buffers[14] = self._chacha_quarter_round(temp_buffers[3], temp_buffers[4], temp_buffers[9], temp_buffers[14], input_data, index, 14)
#Black3 only permuste the input data
if index != self.rounds - 1:
input_data = self._permutation(input_data)
#print(f"after: {[hex(x) for x in temp_buffers]}")
#Update Buffers
for x in range(8):
temp_buffers[x] ^= temp_buffers[x+8]
temp_buffers[x+8] ^= chaining_values[x]
#print(f"done: {[hex(x) for x in temp_buffers]}")
return temp_buffers
def _compress_chunk(self, **kwargs):
#Set defaults
chaining_values = self.chaining_values
node_number = self.node_number
flags = self.flags
block_length = len(self.input_data)
input_data = self.input_data
#Add the flags to the end
if self.blocks_compressed == 0:
#Set CHUNK_START flag
flags |= 0x01
elif self.blocks_compressed == 16:
#Set CHUNK_END
flags |= 0x02
#Overwride defaults if needed
for arg in kwargs:
if arg == "chaining_values":
chaining_values = kwargs[arg]
elif arg == "counter":
node_number = kwargs[arg]
elif arg == "block_length":
block_length = kwargs[arg]
elif arg == "input_data":
input_data = kwargs[arg]
elif arg == "flags":
flags |= kwargs[arg]
return self._compress_chunk_manual(chaining_values, node_number, flags, block_length, input_data)
def update(self, byte_input):
while len(byte_input) > 0:
#Check if block is currently full
if len(self.input_data) == self.max_block_length:
self.chaining_values = self._compress_chunk(flags=self.flags)[:8]
#Update Compressed
self.blocks_compressed +=1
self.input_data = b""
#Add up to the max_block_length
input_length = min(self.max_block_length, self.max_block_length - len(self.input_data))
self.input_data += byte_input[:input_length]
byte_input = byte_input[input_length:]
def output(self):
#If less than 64 bytes pad data
data = self.input_data.rjust(self.blocksize * 2, b"\x00")
#Add the END_CHUNK Flag
return self._compress_chunk(flags = (self.flags | 0x02))
class Blake3(object):
def __init__(self, output_size=256, key=None, personalization=None):
#Blake3 Constants
#Chunk State Varables
self.output_size = output_size
self.blocksize = 32
self.iv = [0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A,
0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,]
self.cv_stack = []
#Blake3 with a custom IV For keyed hashing
if key != None:
#If specifying a key it must be 32 bytes to fit into the buffers
assert len(key) == 32
self.iv = bytes_to_intarray(key, (self.blocksize//8), byte_order="little")
self.flags = 0x10
#Since the key is set the flag for Keyed Hash (16 = 0x10)
self.chunk = Blake3Chunk(self.iv, self.flags)
#Blake3 to derive key from personalization
elif personalization != None:
#Get the Blake3 Hash of rhe personaliztion message to use for the key.
#This will be set with the DERIVE_KEY_CONTEXT flag (32 = 0x20)
derived_key = Blake3()
derived_key.flags |= 0x20
derived_key.update(personalization)
#Set the key to the derived key
self.iv = bytes_to_intarray(derived_key.finalize(), (self.blocksize//8), byte_order="little")
#Set the DERIVE_KEY_MATERIAL flag (64 = 0x40)
self.flags = 0x40
self.chunk = Blake3Chunk(self.iv, self.flags)
else:
self.flags = 0x00
self.chunk = Blake3Chunk(self.iv, self.flags)
def _set_message(self, message):
#Convert to bytes if not already
byte_message = bytearray(message)
#Set Final Length
self.final_length = len(message)
#Pad the data to a multable of the block size
while len(byte_message) == 0 or len(byte_message) % (self.blocksize * 2) != 0:
byte_message.append(0x00)
return byte_message
def append_chunk_cv(self, right_node_cv, chunk_num):
#Check If new chunk is the first one in the next level
while chunk_num & 1 == 0:
#Get the Left Node
left_node_cv = self.cv_stack.pop()
#Compress the left and right node with the parrent flag
right_node_cv = self.chunk._compress_chunk(chaining_values=self.iv, counter=0, block_length=self.chunk.max_block_length, flags=(self.flags | 0x04), input_data=(left_node_cv + right_node_cv))[:8]
#Move Chunk to the next level to compress
chunk_num >>= 1
self.cv_stack.append(right_node_cv)
def update(self, byte_input):
#Add Data to Chunks
while len(byte_input) > 0:
#Test if chunk reaches max_size then add a new chunk node
if self.chunk.max_chunk_size == self.chunk.len():
#Get Chaining Value
chunk_chaining_value = self.chunk.output()
#Update and Reset Data
self.chunk.node_number += 1
self.chunk.input_data = b""
#Update Chunk and Check if needs to compress
self.append_chunk_cv(chunk_chaining_value[:8], self.chunk.node_number)
#Create New Chunk
self.chunk = Blake3Chunk(self.iv, self.flags, 0, self.chunk.node_number)
#Add data to chunk up to the chunk_length
max_read_bytes = min(self.chunk.max_chunk_size - len(self.chunk.input_data), len(byte_input))
#Send Buffer to the chunk
self.chunk.update(byte_input[:max_read_bytes])
#Remove the Data that was sent to the chunk
byte_input = byte_input[max_read_bytes:]
def finalize(self, output_size=32):
right_data = []
left_data = bytes_to_intarray(self.chunk.input_data.ljust(self.chunk.max_block_length, b"\x00"), (self.blocksize//8), byte_order="little")
cv_stack_remaining = len(self.cv_stack)
#Set the End Flag for the next compress
self.chunk.flags |= 0x02
##Compress all Parent Values to a single Value
while cv_stack_remaining > 0:
#Decrease Stack Number
cv_stack_remaining -= 1
#Set the Parent Flag globaly until the end
self.flags |= 0x04
#Get Current Chaining Value
if right_data == []:
#If is the first time get the output
right_data = self.chunk.output()[:8]
else:
right_data = self.chunk._compress_chunk(chaining_values=self.iv, counter=0, block_length=self.chunk.max_block_length, flags=self.flags, input_data=(left_data + right_data))[:8]
#Setup the next Chain
left_data = self.cv_stack[cv_stack_remaining]
self.chunk = Blake3Chunk(self.iv, self.flags, 1, 0)
#Do Final Compress from the root
i = 0
ret = []
while (len(ret) * 4) < output_size:
if right_data == []:
#Set the ROOT Flag
ret += self.chunk._compress_chunk(counter=i, flags=(self.flags | 0x08), block_length=len(self.chunk.input_data), input_data=(left_data + right_data)
)
else:
ret += self.chunk._compress_chunk(counter=i, flags=(self.flags | 0x08), block_length=self.chunk.max_block_length, input_data=(left_data[:8] + right_data)
)
i += 1
return intarray_to_bytes(ret, (self.blocksize//8), byte_order="little")[:output_size]
def hash_digest(self, message, output_size=32):
return self.hash(message, output_size).hex()
if __name__ == '__main__':
    # Demo: print the BLAKE3 hex digest of 8000 bytes of repeated sample text.
    # Other inputs and modes that can be swapped in here:
    #   messages: b"TESTDATA", b"TESTDATA" * 10, b"TESTDATA" * 200
    #   Blake3(key=b"\xBB\x67\xAE\x85" * 8)
    #   Blake3(personalization=b"pure_blake3 2021-10-29 18:37:44 example context")
    for message in (b"TESTDATA" * 1000,):
        hasher = Blake3()
        hasher.update(message)
        output = hasher.finalize()
        print(f"{message}: {output.hex()}")