I am trying to implement the AES key expansion in Python, but I am having some trouble. Here is the code I'm using:
```python
# sbox (the 256-entry AES S-box) and rcon (the round constants) are defined
# elsewhere in my script; they are not shown here.

def print_matrix(m):
    # Just to print the matrix
    for i in m:
        for j in i:
            print(j, end=" ")
        print(" ")


def xor_bytes(a, b):
    # Returns a new byte array with the elements xor'ed
    return bytes(i ^ j for i, j in zip(a, b))


def bytes2matrix(text):
    # Converts a 16-byte array into a 4x4 matrix
    return [list(text[i:i+4]) for i in range(0, len(text), 4)]


def expand_key(master_key):
    # Expands and returns a list of key matrices for the given master_key.

    # Initialize round keys with raw key material.
    rounds_by_key_size = {16: 10, 24: 12, 32: 14}
    n_rounds = rounds_by_key_size[len(master_key)]

    key_columns = bytes2matrix(master_key)
    iteration_size = len(master_key) // 4

    # Each iteration has exactly as many columns as the key material.
    columns_per_iteration = len(key_columns)
    i = 1
    while len(key_columns) < (n_rounds + 1) * 4:
        # Copy previous word.
        word = list(key_columns[-1])

        # Perform schedule_core once every "row".
        if len(key_columns) % iteration_size == 0:
            # Circular shift.
            word.append(word.pop(0))
            # Map to S-BOX.
            word = [sbox[b] for b in word]
            # XOR with first byte of R-CON, since the other bytes of R-CON are 0.
            word[0] ^= rcon[i]
            i += 1
        elif len(master_key) == 32 and len(key_columns) % iteration_size == 4:
            # Run word through S-box in the fourth iteration when using a
            # 256-bit key.
            word = [sbox[b] for b in word]

        # XOR with equivalent word from previous iteration.
        word = xor_bytes(word, key_columns[-iteration_size])
        key_columns.append(word)

    # Group key words in 4x4 byte matrices.
    return [key_columns[4*i : 4*(i+1)] for i in range(len(key_columns) // 4)]


# Testing
key = 'Thats my Kung Fu'
key = key.encode('utf-8')
print_matrix(expand_key(key))
```
This is the result I should get:
But this is what I get from the code:
NOTE: Round 0 (the first line) doesn't matter, since I converted the key with key = key.encode('utf-8') in order to get round keys 1 to 10. I'm sure round 0 is correct.
Some bytes come out correct, but others are wrong. Why? What mistake am I making?
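The schedule logic in `expand_key` appears to follow the FIPS-197 key expansion (RotWord, SubWord, XOR with Rcon every `Nk` words, plus the extra SubWord for 256-bit keys), so the `sbox` and `rcon` tables, which are not shown, are worth checking. One detail: because `i` starts at 1, `rcon[1]` must be 0x01; a table that stores 0x01 at index 0 would shift every round constant. The sketch below is an assumption-laden check, not the original script: it builds both tables from their mathematical definitions (GF(2^8) inverse plus affine transform for the S-box, repeated doubling for Rcon) instead of typing them by hand, reruns the same schedule, and compares against the FIPS-197 Appendix A.1 test vector. The helper names (`gf_mul`, `gf_inv`, `rotl8`, `build_sbox`, `build_rcon`) are hypothetical.

```python
# Hypothetical check: rebuild the AES S-box and Rcon from their definitions
# and run the same key-schedule logic as in the question.

def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1 (0x11B)."""
    p = 0
    while b:
        if b & 1:
            p ^= a
        a <<= 1
        if a & 0x100:
            a ^= 0x11B
        b >>= 1
    return p


def gf_inv(a):
    """Multiplicative inverse in GF(2^8) by brute force; 0 maps to 0."""
    if a == 0:
        return 0
    return next(b for b in range(1, 256) if gf_mul(a, b) == 1)


def rotl8(x, n):
    """Rotate an 8-bit value left by n bits."""
    return ((x << n) | (x >> (8 - n))) & 0xFF


def build_sbox():
    """S-box entry = affine transform of the GF(2^8) inverse, XOR 0x63."""
    table = []
    for x in range(256):
        b = gf_inv(x)
        table.append(b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63)
    return table


def build_rcon(count=11):
    """Round constants with a dummy 0x00 at index 0, so rcon[1] == 0x01."""
    table = [0x00, 0x01]
    while len(table) < count:
        table.append(gf_mul(table[-1], 2))
    return table


sbox = build_sbox()
rcon = build_rcon()


def expand_key(master_key):
    """Same schedule as the question's code, returning 16-byte round keys."""
    n_rounds = {16: 10, 24: 12, 32: 14}[len(master_key)]
    nk = len(master_key) // 4
    words = [list(master_key[i:i + 4]) for i in range(0, len(master_key), 4)]
    i = 1
    while len(words) < 4 * (n_rounds + 1):
        word = list(words[-1])
        if len(words) % nk == 0:
            word = word[1:] + word[:1]      # RotWord: circular left shift
            word = [sbox[b] for b in word]  # SubWord
            word[0] ^= rcon[i]              # only byte 0 of Rcon is non-zero
            i += 1
        elif nk > 6 and len(words) % nk == 4:
            word = [sbox[b] for b in word]  # extra SubWord for 256-bit keys
        words.append([a ^ b for a, b in zip(word, words[-nk])])
    # Flatten every group of 4 words into one 16-byte round key.
    return [bytes(b for w in words[4 * r:4 * r + 4] for b in w)
            for r in range(n_rounds + 1)]


print(hex(sbox[0x00]), hex(sbox[0x53]))  # FIPS-197 gives 0x63 and 0xed
print([hex(r) for r in rcon[1:]])
keys = expand_key(bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c"))
print(keys[1].hex(), keys[10].hex())
```

Comparing the generated `sbox` and `rcon` lists entry by entry with the hand-typed ones (for example `[k for k in range(256) if my_sbox[k] != sbox[k]]`, where `my_sbox` stands for the script's own table) would show any mistyped values.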
In addition, how can I convert the b'\something' notation to hexadecimal (0xsomething), and how can I extract each round key as a separate array from the result the code gives me?
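A minimal sketch for both parts, assuming the output shape of the `expand_key` in the question: a plain Python list with one entry per round, where each entry holds four columns that are either lists of ints (round 0) or `bytes` objects (later rounds). Iterating over either kind yields ints, so one helper can flatten both into a single `bytes` value, and `bytes.hex()` or an `0x{:02x}` format string then produces hexadecimal. The helper names `round_key_bytes` and `round_key_hex` are hypothetical; the round 1 value below is what the standard schedule gives for `'Thats my Kung Fu'` (first word e2 32 fc f1).

```python
def round_key_bytes(round_key):
    """Flatten one round key (four 4-byte columns) into a 16-byte bytes object.

    Works for columns that are lists of ints or bytes objects, because
    iterating over either one yields ints.
    """
    return bytes(b for column in round_key for b in column)


def round_key_hex(round_key):
    """Format each byte of a round key as a '0x..' string."""
    return [f"0x{b:02x}" for b in round_key_bytes(round_key)]


# Two round keys in the same shape that expand_key() returns for 'Thats my Kung Fu'.
round_keys = [
    [list(b"That"), list(b"s my"), list(b" Kun"), list(b"g Fu")],
    [b"\xe2\x32\xfc\xf1", b"\x91\x12\x91\x88", b"\xb1\x59\xe4\xe6", b"\xd6\x79\xa2\x93"],
]

# Each round key is just an element of the list, so indexing or enumerate works.
for r, rk in enumerate(round_keys):
    print(r, round_key_bytes(rk).hex(), round_key_hex(rk))
```

With the real output, `keys = expand_key(key)` followed by `keys[r]` gives round key `r`, and `round_key_bytes(keys[r]).hex()` gives it as one hex string.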
Thanks in advance for your help, and merry Christmas!