by Neal R. Wagner
Copyright © 2001 by Neal R. Wagner. All rights reserved.
NOTE: This site is obsolete. See the book draft (in PDF):
The classes AEStables, GetBytes, Copy, and Print are the same as for encryption.
The class AESdecrypt provides all the principal functions of the AES decryption algorithm:
// AESdecrypt: AES decryption (the inverse cipher).
// Construct it with a key and that key's length in 32-bit words, then call
// InvCipher once per 16-byte ciphertext block.
// Relies on the companion classes AEStables (SBox, invSBox, Rcon, FFMul),
// Copy (byte[] <-> state array) and Print (prints intermediate states).
// NOTE(review): InvCipher keeps its read position into the expanded key in the
// instance field wCount, so one instance must not decrypt from two threads
// at the same time.
public class AESdecrypt {
    public final int Nb = 4; // words in a block, always 4 for now
    public int Nk;           // key length in words
    public int Nr;           // number of rounds, = Nk + 6
    private int wCount;      // position in w (reset to 4*Nb*(Nr+1) at each decrypt, counts down)
    AEStables tab;           // all the tables needed for AES
    byte[] w;                // the expanded key

    // AESdecrypt: constructor for class. Mainly expands key.
    //   key  - cipher key; only its first 4*NkIn bytes are used
    //   NkIn - key length in 32-bit words (AES uses 4, 6 or 8)
    //   throws NullPointerException     if key is null
    //   throws IllegalArgumentException if NkIn is outside 4..8, or if key
    //          is shorter than 4*NkIn bytes
    public AESdecrypt(byte[] key, int NkIn) {
        if (key == null)
            throw new NullPointerException("key");
        if (NkIn < 4 || NkIn > 8)
            throw new IllegalArgumentException(
                "key length must be 4..8 words (AES uses 4, 6 or 8), got " + NkIn);
        if (key.length < 4*NkIn)
            throw new IllegalArgumentException("key has " + key.length
                + " bytes, but Nk = " + NkIn + " needs " + 4*NkIn);
        Nk = NkIn;                 // words in a key, = 4, or 6, or 8
        Nr = Nk + 6;               // corresponding number of rounds
        tab = new AEStables();     // class to give values of various functions
        w = new byte[4*Nb*(Nr+1)]; // room for expanded key: Nr+1 round keys of 4*Nb bytes
        KeyExpansion(key, w);      // length of w depends on Nr
    }

    // InvCipher: actual AES decryption of a single block.
    //   in  - ciphertext; its first 4*Nb (= 16) bytes are decrypted
    //   out - receives the 4*Nb-byte plaintext in its first 4*Nb bytes
    //   throws IllegalArgumentException if in or out is shorter than 4*Nb bytes
    // Round keys are used last-to-first: InvAddRoundKey reads w backwards
    // starting at wCount, which is why wCount is reset to the end of w here.
    // Prints the state at the start of every round (numbered 1..Nr) via Print.
    public void InvCipher(byte[] in, byte[] out) {
        // Check the lengths up front so a bad buffer fails before any rounds run.
        if (in.length < 4*Nb || out.length < 4*Nb)
            throw new IllegalArgumentException("in and out need at least " + 4*Nb
                + " bytes, got " + in.length + " and " + out.length);
        wCount = 4*Nb*(Nr+1);             // start just past the last round key
        byte[][] state = new byte[4][Nb]; // the state array
        Copy.copy(state, in);             // actual component-wise copy
        InvAddRoundKey(state);            // xor with the last round key
        for (int round = Nr-1; round >= 1; round--) {
            Print.printArray("Start round " + (Nr - round) + ":", state);
            InvShiftRows(state);   // undo the row rotations
            InvSubBytes(state);    // inverse S-box substitution
            InvAddRoundKey(state); // xor with this round's key
            InvMixColumns(state);  // undo the column mixing
        }
        // final round has no InvMixColumns
        Print.printArray("Start round " + Nr + ":", state);
        InvShiftRows(state);
        InvSubBytes(state);
        InvAddRoundKey(state);     // xor with the first Nb words of w (the key itself)
        Copy.copy(out, state);
    }

    // KeyExpansion: expand key into w; byte-oriented code, but tracks words
    // (the same as for encryption). Word i of w occupies bytes w[4*i .. 4*i+3].
    //   key - cipher key, at least 4*Nk bytes (checked by the constructor)
    //   w   - destination, 4*Nb*(Nr+1) bytes
    private void KeyExpansion(byte[] key, byte[] w) {
        byte[] temp = new byte[4];
        // the first Nk words of w are the key itself
        System.arraycopy(key, 0, w, 0, 4*Nk);
        // every later word is built from the previous word and the word Nk back
        for (int j = 4*Nk; j < 4*Nb*(Nr+1); j += 4) {
            int i = j/4;                          // index of the word being built
            System.arraycopy(w, j-4, temp, 0, 4); // temp = word i-1
            if (i % Nk == 0) {
                // rotate temp left one byte, S-box each byte,
                // and xor Rcon(i/Nk) into the first byte
                byte first = temp[0];
                temp[0] = (byte)(tab.SBox(temp[1]) ^ tab.Rcon(i/Nk));
                temp[1] = tab.SBox(temp[2]);
                temp[2] = tab.SBox(temp[3]);
                temp[3] = tab.SBox(first);
            }
            else if (Nk > 6 && (i % Nk) == 4) {
                // extra S-box step, only for keys longer than 6 words
                for (int k = 0; k < 4; k++)
                    temp[k] = tab.SBox(temp[k]);
            }
            // word i = word (i - Nk) xor temp
            for (int k = 0; k < 4; k++)
                w[j+k] = (byte)(w[j - 4*Nk + k] ^ temp[k]);
        }
    }

    // InvSubBytes: apply the inverse S-box (tab.invSBox) to each byte of state
    private void InvSubBytes(byte[][] state) {
        for (int row = 0; row < 4; row++)
            for (int col = 0; col < Nb; col++)
                state[row][col] = tab.invSBox(state[row][col]);
    }

    // InvShiftRows: right circular shift of rows 1, 2, 3 by 1, 2, 3 positions
    // (row 0 is unchanged)
    private void InvShiftRows(byte[][] state) {
        byte[] t = new byte[Nb]; // one shifted row
        for (int r = 1; r < 4; r++) {
            for (int c = 0; c < Nb; c++)
                t[(c + r)%Nb] = state[r][c];
            for (int c = 0; c < Nb; c++)
                state[r][c] = t[c];
        }
    }

    // InvMixColumns: replace each column with the product of the matrix
    //   0e 0b 0d 09
    //   09 0e 0b 0d
    //   0d 09 0e 0b
    //   0b 0d 09 0e
    // and that column. tab.FFMul supplies the element products, which are
    // combined with xor.
    private void InvMixColumns(byte[][] s) {
        int[] sp = new int[4]; // new column, computed before overwriting s
        byte b0b = (byte)0x0b; byte b0d = (byte)0x0d;
        byte b09 = (byte)0x09; byte b0e = (byte)0x0e;
        for (int c = 0; c < Nb; c++) {
            sp[0] = tab.FFMul(b0e, s[0][c]) ^ tab.FFMul(b0b, s[1][c]) ^
                    tab.FFMul(b0d, s[2][c]) ^ tab.FFMul(b09, s[3][c]);
            sp[1] = tab.FFMul(b09, s[0][c]) ^ tab.FFMul(b0e, s[1][c]) ^
                    tab.FFMul(b0b, s[2][c]) ^ tab.FFMul(b0d, s[3][c]);
            sp[2] = tab.FFMul(b0d, s[0][c]) ^ tab.FFMul(b09, s[1][c]) ^
                    tab.FFMul(b0e, s[2][c]) ^ tab.FFMul(b0b, s[3][c]);
            sp[3] = tab.FFMul(b0b, s[0][c]) ^ tab.FFMul(b0d, s[1][c]) ^
                    tab.FFMul(b09, s[2][c]) ^ tab.FFMul(b0e, s[3][c]);
            for (int i = 0; i < 4; i++) s[i][c] = (byte)(sp[i]);
        }
    }

    // InvAddRoundKey: same as AddRoundKey, but backwards. Xors the 4*Nb bytes
    // of w just before wCount into state, visiting columns and rows in
    // reverse. This leaves wCount pointing at the start of the previous
    // round key, ready for the next call.
    private void InvAddRoundKey(byte[][] state) {
        for (int c = Nb - 1; c >= 0; c--)
            for (int r = 3; r >= 0; r--)
                state[r][c] = (byte)(state[r][c] ^ w[--wCount]);
    }
}
The class AESinvTest is a driver for testing decryption:
// AESinvTest: test AES decryption of a single 16-byte block.
// Usage: java AESinvTest [ciphertextFile [keyFile [keyBytes]]]
//   ciphertextFile - file holding the ciphertext block (default "ciphertext1.txt")
//   keyFile        - file holding the key (default "key1.txt")
//   keyBytes       - 16, 24 or 32 for a 128-, 192- or 256-bit key (default 16)
// The key length in words passed to AESdecrypt is derived as keyBytes/4.
// The key size therefore only needs to be given once, instead of being kept
// consistent by hand in three places.
public class AESinvTest {
    private static final int BLOCK_BYTES = 16; // one AES block

    public static void main(String[] args) {
        String cipherFile = args.length > 0 ? args[0] : "ciphertext1.txt";
        String keyFile = args.length > 1 ? args[1] : "key1.txt";
        int keyBytes = args.length > 2 ? Integer.parseInt(args[2]) : 16;
        if (keyBytes != 16 && keyBytes != 24 && keyBytes != 32)
            throw new IllegalArgumentException(
                "keyBytes must be 16, 24 or 32, got " + keyBytes);

        GetBytes getInput = new GetBytes(cipherFile, BLOCK_BYTES);
        byte[] in = getInput.getBytes();
        GetBytes getKey = new GetBytes(keyFile, keyBytes);
        byte[] key = getKey.getBytes();
        AESdecrypt aesDec = new AESdecrypt(key, keyBytes / 4); // Nk = 4, 6 or 8
        Print.printArray("Ciphertext: ", in);
        Print.printArray("Key: ", key);
        byte[] out = new byte[BLOCK_BYTES];
        aesDec.InvCipher(in, out);
        Print.printArray("Plaintext: ", out);
    }
}