I am trying to implement the AES key expansion in Python, but I am having some trouble. Here is the code I'm using:
def print_matrix(m):
    # Just to print the matrix
    for i in m:
        for j in i:
            print(j, end=" ")
        print(" ")

def xor_bytes(a, b):
    # Returns a new byte array with the elements xor'ed
    return bytes(i ^ j for i, j in zip(a, b))

def bytes2matrix(text):
    # Converts a 16-byte array into a 4x4 matrix
    return [list(text[i:i+4]) for i in range(0, len(text), 4)]

def expand_key(master_key):
    # Expands and returns a list of key matrices for the given master_key.
    # Initialize round keys with raw key material.
    rounds_by_key_size = {16: 10, 24: 12, 32: 14}
    n_rounds = rounds_by_key_size[len(master_key)]
    key_columns = bytes2matrix(master_key)
    iteration_size = len(master_key) // 4
    # Each iteration has exactly as many columns as the key material.
    columns_per_iteration = len(key_columns)
    i = 1
    while len(key_columns) < (n_rounds + 1) * 4:
        # Copy previous word.
        word = list(key_columns[-1])
        # Perform schedule_core once every "row".
        if len(key_columns) % iteration_size == 0:
            # Circular shift.
            word.append(word.pop(0))
            # Map to S-BOX.
            word = [sbox[b] for b in word]
            # XOR with first byte of R-CON, since the others bytes of R-CON are 0.
            word[0] ^= rcon[i]
            i += 1
        elif len(master_key) == 32 and len(key_columns) % iteration_size == 4:
            # Run word through S-box in the fourth iteration when using a
            # 256-bit key.
            word = [sbox[b] for b in word]
        # XOR with equivalent word from previous iteration.
        word = xor_bytes(word, key_columns[-iteration_size])
        key_columns.append(word)
    # Group key words in 4x4 byte matrices.
    return [key_columns[4*i : 4*(i+1)] for i in range(len(key_columns) // 4)]

# Testing
key = 'Thats my Kung Fu'
key = key.encode('utf-8')
print_matrix(expand_key(key))
This is what I should get as a result:
But this is what I get with my code:
NOTE: Round 0 (the first line) doesn't matter, as I transformed the key (key=key.encode('utf-8')) in order to get keys 1 to 10. I'm sure round 0 is correct.
Some bytes come out correct, so why are the others wrong? What mistake am I making?
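As a point of comparison, here is a minimal, self-contained sketch of AES-128 key expansion following FIPS-197. Because the `sbox` and `rcon` tables are not shown in the question, the sketch derives them itself (the S-box from the GF(2^8) multiplicative inverse plus the affine transform, the round constants by repeated doubling in GF(2^8)). The helper names `gf_mul`, `build_sbox`, `build_rcon` and `expand_key_128` are hypothetical, the sketch only handles 16-byte keys, and it returns each round key as a flat 16-byte `bytes` object rather than a 4x4 matrix. Note that here `RCON[0]` is 0x01, the constant for round 1; the code in the question reads `rcon[i]` starting at `i = 1`, which assumes `rcon[0]` is a placeholder, so comparing the two tables may be worth doing.

```python
def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo the AES polynomial x^8 + x^4 + x^3 + x + 1."""
    result = 0
    while b:
        if b & 1:
            result ^= a
        a <<= 1
        if a & 0x100:       # reduce when the product overflows 8 bits
            a ^= 0x11B
        b >>= 1
    return result


def build_sbox():
    """Build the AES S-box: GF(2^8) inverse followed by the affine transform."""
    sbox = []
    for x in range(256):
        # Brute-force inverse; 0 has no inverse and maps to 0 by convention.
        inv = next((y for y in range(1, 256) if gf_mul(x, y) == 1), 0)
        # Affine map: s = b ^ rotl(b,1) ^ rotl(b,2) ^ rotl(b,3) ^ rotl(b,4) ^ 0x63
        s = inv
        for shift in range(1, 5):
            s ^= ((inv << shift) | (inv >> (8 - shift))) & 0xFF
        sbox.append(s ^ 0x63)
    return sbox


def build_rcon(count):
    """Round constants 0x01, 0x02, 0x04, ...: each is the previous one times x in GF(2^8)."""
    rcon = [0x01]
    while len(rcon) < count:
        rcon.append(gf_mul(rcon[-1], 0x02))
    return rcon


SBOX = build_sbox()
RCON = build_rcon(10)   # AES-128 needs one constant per round, rounds 1..10


def expand_key_128(key):
    """Return the 11 round keys of AES-128 as 16-byte bytes objects."""
    if len(key) != 16:
        raise ValueError("this sketch only handles 128-bit (16-byte) keys")
    # The first 4 words are the key itself.
    words = [list(key[i:i + 4]) for i in range(0, 16, 4)]
    for i in range(4, 44):
        temp = list(words[i - 1])
        if i % 4 == 0:
            temp = temp[1:] + temp[:1]          # RotWord: circular left shift
            temp = [SBOX[b] for b in temp]      # SubWord: S-box on each byte
            temp[0] ^= RCON[i // 4 - 1]         # round 1 uses RCON[0] == 0x01
        words.append([a ^ b for a, b in zip(words[i - 4], temp)])
    # Group every 4 consecutive words into one 16-byte round key.
    return [bytes(b for w in words[4 * r:4 * r + 4] for b in w) for r in range(11)]


if __name__ == "__main__":
    for rnd, round_key in enumerate(expand_key_128(b"Thats my Kung Fu")):
        print(rnd, " ".join(f"{b:02x}" for b in round_key))
```

Printing both this output and the output of `expand_key()` side by side shows in which round and column the two first diverge, which narrows down whether the S-box, the round constants, or the XOR step differ.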
In addition, how can I convert the b'\something' notation to hexadecimal (0xsomething), and how can I extract each round key as a separate array from the result the code gives me?
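A minimal sketch of both conversions, assuming the shape that `expand_key()` returns: a list of round keys, each a list of 4 columns, where the round 0 columns are lists of ints (from `bytes2matrix`) and the later columns are `bytes` objects (from `xor_bytes`). The `b'\x..'` notation is only how Python displays a `bytes` object; iterating over it already yields ints, so they can be formatted as hex directly. The helper names `round_key_bytes`, `round_key_hex` and `round_key_0x` are hypothetical, and the sample `round_keys` data (the round 0 and round 1 keys for 'Thats my Kung Fu') only stands in for the real output.

```python
# Illustrative sample in the same shape expand_key() returns:
# round 0 columns are lists of ints, later columns are bytes objects.
round_keys = [
    [[0x54, 0x68, 0x61, 0x74], [0x73, 0x20, 0x6D, 0x79],
     [0x20, 0x4B, 0x75, 0x6E], [0x67, 0x20, 0x46, 0x75]],
    [b"\xe2\x32\xfc\xf1", b"\x91\x12\x91\x88",
     b"\xb1\x59\xe4\xe6", b"\xd6\x79\xa2\x93"],
]


def round_key_bytes(matrix):
    """Flatten one 4x4 round-key matrix (4 columns of 4 bytes) into 16 bytes."""
    # Iterating a list of ints or a bytes object both yield ints,
    # so this works for round 0 and for the later rounds alike.
    return bytes(b for column in matrix for b in column)


def round_key_hex(matrix):
    """Format one round key as space-separated two-digit hex values."""
    return " ".join(f"{b:02x}" for b in round_key_bytes(matrix))


def round_key_0x(matrix):
    """Format one round key as a list of '0x..' strings."""
    return [f"0x{b:02x}" for b in round_key_bytes(matrix)]


# One 16-byte array per round, ready to be used on its own.
flat_keys = [round_key_bytes(m) for m in round_keys]

for rnd, matrix in enumerate(round_keys):
    print(rnd, round_key_hex(matrix))
```

Keeping each round key as a flat `bytes` object makes later steps (such as AddRoundKey) a simple byte-wise XOR, while the formatting helpers are only needed for display.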
Thanks in advance for your help, and merry Christmas!