/******************************************************************************/
/* MCrypt, by LoRd_MuldeR <MuldeR2@GMX.de>                                    */
/* This work has been released under the CC0 1.0 Universal license!           */
/******************************************************************************/

#ifdef _WIN32
#define  _CRT_RAND_S 1
#endif

#include <mcrypt.h>
#include "utils.h"
#include <stdlib.h>
#include <string.h>

/* Library version, and a build identifier of the form "<__DATE__>, <__TIME__>"
   captured when this translation unit was compiled. */
const char *const LIBMCRYPT_VERSION = "1.0.0";
const char *const LIBMCRYPT_BUILDNO = __DATE__ ", " __TIME__;

/* Complete cipher state derived from (salt, passphrase): 256 substitution
   rounds, each with its own permutation, inverse permutation and offset. */
typedef struct {
	uint8_t box[256][256]; /* box[r]: forward byte permutation used by round r   */
	uint8_t inv[256][256]; /* inv[r]: inverse of box[r] (inv[r][box[r][i]] == i) */
	uint8_t off[256];      /* off[r]: additive offset applied before box[r]       */
	uint8_t pos;           /* index of the offset bumped after each byte (wraps) */
} crypt_state_t;

/* State of the internal pseudo-random generator: four 32-bit shift-register
   words plus an additive counter (see random_next for how they are updated). */
typedef struct {
	uint32_t a;
	uint32_t b;
	uint32_t c;
	uint32_t d;
	uint32_t counter;
} rand_state_t;

// ==========================================================================
// Hash function
// ==========================================================================

/*
 * Fold `data_len` bytes from `data` into the running 64-bit FNV-1a hash `*h`.
 * Each byte is XORed in, then the hash is multiplied by the 64-bit FNV prime.
 * `data` may be NULL when `data_len` is 0.
 */
static void hash_update(uint64_t* const h, const uint8_t* const data, const size_t data_len)
{
	for (size_t i = 0U; i < data_len; ++i)
	{
		*h ^= data[i];
		*h *= 0x00000100000001B3ull;
	}
}

/*
 * Compute the 64-bit FNV-1a hash of  salt || pepper || data.
 *
 * `salt` (8 bytes) and `pepper` (2 bytes) are serialized explicitly in
 * little-endian byte order before hashing. Hashing their in-memory
 * representation instead would make the derived key schedule depend on the
 * host's endianness, so data encrypted on a little-endian machine could not
 * be decrypted on a big-endian one. Little-endian order is identical to the
 * native layout on little-endian hosts, so results there are unchanged.
 */
static uint64_t hash_code(const uint64_t salt, const uint16_t pepper, const uint8_t* const data, const size_t data_len)
{
	uint8_t prefix[sizeof(uint64_t) + sizeof(uint16_t)];
	for (size_t i = 0U; i < sizeof(uint64_t); ++i)
	{
		prefix[i] = (uint8_t)(salt >> (8U * i));
	}
	prefix[sizeof(uint64_t) + 0U] = (uint8_t)(pepper & 0xFFU);
	prefix[sizeof(uint64_t) + 1U] = (uint8_t)(pepper >> 8);

	uint64_t h = 0xCBF29CE484222325ull; /* FNV-1a 64-bit offset basis */
	hash_update(&h, prefix, sizeof(prefix));
	hash_update(&h, data, data_len);
	return h;
}

// ==========================================================================
// PRNG
// ==========================================================================

/*
 * Load the generator state from two 64-bit seeds.
 * The low and high halves of seed_0 go into a and b. The low and high halves
 * of seed_1 go into c and d. The additive counter restarts at zero.
 */
static void random_init(rand_state_t* const state, const uint64_t seed_0, const uint64_t seed_1)
{
	state->a = (uint32_t)seed_0;
	state->b = (uint32_t)(seed_0 >> 32);
	state->c = (uint32_t)seed_1;
	state->d = (uint32_t)(seed_1 >> 32);
	state->counter = 0U;
}

/*
 * Advance the generator by one step and return the next 32-bit output.
 * The four words rotate (d <- c <- b <- old a). A new `a` is mixed from the
 * old `d` and old `a` with xorshift operations (shifts 2, 1, 4). The output
 * is that new word plus the counter, which grows by 362437 per call.
 * NOTE(review): this resembles Marsaglia's xorwow but with four state words
 * instead of five; its period properties have not been verified here.
 */
static uint32_t random_next(rand_state_t* const state)
{
	const uint32_t first = state->a;
	uint32_t mixed = state->d;

	state->d = state->c;
	state->c = state->b;
	state->b = first;

	mixed ^= mixed >> 2;
	mixed ^= mixed << 1;
	mixed ^= first ^ (first << 4);
	state->a = mixed;

	state->counter += 362437U;
	return mixed + state->counter;
}

/*
 * Seed `state` from (salt, pepper, key).
 * Two 64-bit hashes of the key are taken, distinguished by the top bit of the
 * pepper (cleared, then set). They seed the generator, which is then stepped
 * 13 times with the outputs discarded. The pepper masking and the step count
 * are part of the key schedule; changing either changes every ciphertext.
 */
static void random_seed(rand_state_t* const state, const uint64_t salt, const uint16_t pepper, const uint8_t* const key, const size_t key_len)
{
	const uint64_t low_seed = hash_code(salt, (uint16_t)(pepper & 0x7FFFU), key, key_len);
	const uint64_t high_seed = hash_code(salt, (uint16_t)(pepper | 0x8000U), key, key_len);
	random_init(state, low_seed, high_seed);

	/* Warm-up: presumably to move away from the raw hash words before use. */
	for (unsigned int remaining = 13U; remaining > 0U; --remaining)
	{
		(void)random_next(state);
	}
}

// ==========================================================================
// Initialization
// ==========================================================================

/*
 * Derive the complete cipher state from `salt` and the key bytes.
 *
 * For each round r in 0..255 a generator is seeded with pepper r. Its first
 * output becomes off[r]. The next 256 outputs drive an inside-out
 * Fisher-Yates shuffle that builds the permutation box[r], and inv[r] is
 * filled as its inverse. Finally a generator seeded with pepper 0x0100
 * yields the starting `pos`. The generator state is erased before returning.
 * The order of the random draws defines the key schedule and must not change.
 * NOTE(review): `random_next() % (n + 1)` has a tiny modulo bias for
 * non-power-of-two n + 1. Fixing it would change all existing ciphertexts.
 */
static void initialize_state(crypt_state_t* const crypt_state, const uint64_t salt, const uint8_t* const key, const size_t key_len)
{
	rand_state_t rng;

	for (size_t round = 0U; round < 256U; ++round)
	{
		uint8_t* const forward = crypt_state->box[round];
		uint8_t* const inverse = crypt_state->inv[round];

		random_seed(&rng, salt, (uint16_t)round, key, key_len);
		crypt_state->off[round] = (uint8_t)random_next(&rng);

		/* Inside-out shuffle: slot n receives the value displaced from slot k. */
		for (size_t n = 0U; n < 256U; ++n)
		{
			const size_t k = random_next(&rng) % (n + 1U);
			if (k < n)
			{
				forward[n] = forward[k];
			}
			forward[k] = (uint8_t)n;
		}

		for (size_t n = 0U; n < 256U; ++n)
		{
			inverse[forward[n]] = (uint8_t)n;
		}
	}

	random_seed(&rng, salt, 0x0100, key, key_len);
	crypt_state->pos = (uint8_t)random_next(&rng);
	mcrypt_erase(&rng, sizeof(rng));
}

// ==========================================================================
// Encrypt / Decrypt
// ==========================================================================

/*
 * Encrypt a single byte.
 * The byte passes through rounds 0..255 in order. Each round adds that
 * round's offset (mod 256) and substitutes through the round's permutation.
 * Afterwards the offset at index `pos` is incremented and `pos` advances,
 * wrapping at 256 because it is a uint8_t.
 */
static uint8_t process_enc(crypt_state_t* const crypt_state, uint8_t value)
{
	for (size_t round = 0U; round < 256U; round++)
	{
		const uint8_t index = (uint8_t)(value + crypt_state->off[round]);
		value = crypt_state->box[round][index];
	}
	++crypt_state->off[crypt_state->pos];
	++crypt_state->pos;
	return value;
}

/*
 * Decrypt a single byte, undoing process_enc.
 * Rounds are walked from 255 down to 0. Each round maps the byte through the
 * inverse permutation, then subtracts the round's offset (mod 256). The state
 * then advances exactly as in process_enc so both sides stay in step.
 */
static uint8_t process_dec(crypt_state_t* const crypt_state, uint8_t value)
{
	for (size_t step = 0U; step < 256U; ++step)
	{
		const size_t round = 255U - step;
		value = (uint8_t)(crypt_state->inv[round][value] - crypt_state->off[round]);
	}
	++crypt_state->off[crypt_state->pos];
	++crypt_state->pos;
	return value;
}

// ==========================================================================
// Public API
// ==========================================================================

/*
 * Fill `*seed` with 8 random bytes obtained from mcrypt_random_bytes().
 * Returns -1 if `seed` is NULL. Otherwise it returns whatever
 * mcrypt_random_bytes() returns (presumably 0 on success; confirm in utils).
 */
int mcrypt_generate_seed(uint64_t* const seed)
{
	if (!seed)
	{
		return -1;
	}
	return mcrypt_random_bytes((uint8_t*)seed, sizeof *seed);
}

/*
 * Allocate a cipher context keyed by `salt` and the NUL-terminated
 * `passphrase`. The passphrase bytes, without the terminator, are the key.
 * An empty passphrase is accepted, in which case only `salt` varies the state.
 * Returns a NULL handle if `passphrase` is NULL or allocation fails.
 * Release the context with mcrypt_free(), which also wipes it.
 */
mcrypt_t mcrypt_alloc(const uint64_t salt, const char* const passphrase)
{
	if (!passphrase)
	{
		return ((mcrypt_t)NULL);
	}
	crypt_state_t* const state = malloc(sizeof *state);
	if (!state)
	{
		return ((mcrypt_t)NULL);
	}
	/* Treat the passphrase as raw key bytes; keep the const qualifier intact. */
	initialize_state(state, salt, (const uint8_t*)passphrase, strlen(passphrase));
	return ((mcrypt_t)state);
}

/*
 * Encrypt `length` bytes from `input` into `output`, advancing the context.
 * Returns 0 on success.
 * Returns -1, leaving the context untouched, if `context` is NULL or if
 * `length` is non-zero and either buffer is NULL.
 * `input` and `output` may be the same buffer, because each byte is read
 * before its own slot is written. Partially overlapping buffers give wrong
 * results.
 */
int mcrypt_enc_process(const mcrypt_t context, const uint8_t* const input, uint8_t* const output, size_t length)
{
	crypt_state_t* const state = (crypt_state_t*)context;
	if ((!context) || ((length > 0U) && ((!input) || (!output))))
	{
		return -1;
	}
	for (size_t i = 0U; i < length; ++i)
	{
		output[i] = process_enc(state, input[i]);
	}
	return 0;
}

/*
 * Encrypt `length` bytes of `buffer` in place, advancing the context.
 * Returns 0 on success.
 * Returns -1, leaving the context untouched, if `context` is NULL or if
 * `length` is non-zero and `buffer` is NULL.
 */
int mcrypt_enc_process_inplace(const mcrypt_t context, uint8_t* const buffer, size_t length)
{
	crypt_state_t* const state = (crypt_state_t*)context;
	if ((!context) || ((length > 0U) && (!buffer)))
	{
		return -1;
	}
	for (size_t i = 0U; i < length; ++i)
	{
		buffer[i] = process_enc(state, buffer[i]);
	}
	return 0;
}


/*
 * Decrypt `length` bytes from `input` into `output`, advancing the context.
 * Returns 0 on success.
 * Returns -1, leaving the context untouched, if `context` is NULL or if
 * `length` is non-zero and either buffer is NULL.
 * `input` and `output` may be the same buffer, because each byte is read
 * before its own slot is written. Partially overlapping buffers give wrong
 * results.
 */
int mcrypt_dec_process(const mcrypt_t context, const uint8_t* const input, uint8_t* const output, size_t length)
{
	crypt_state_t* const state = (crypt_state_t*)context;
	if ((!context) || ((length > 0U) && ((!input) || (!output))))
	{
		return -1;
	}
	for (size_t i = 0U; i < length; ++i)
	{
		output[i] = process_dec(state, input[i]);
	}
	return 0;
}

/*
 * Decrypt `length` bytes of `buffer` in place, advancing the context.
 * Returns 0 on success.
 * Returns -1, leaving the context untouched, if `context` is NULL or if
 * `length` is non-zero and `buffer` is NULL.
 */
int mcrypt_dec_process_inplace(const mcrypt_t context, uint8_t* const buffer, size_t length)
{
	crypt_state_t* const state = (crypt_state_t*)context;
	if ((!context) || ((length > 0U) && (!buffer)))
	{
		return -1;
	}
	for (size_t i = 0U; i < length; ++i)
	{
		buffer[i] = process_dec(state, buffer[i]);
	}
	return 0;
}

/*
 * Wipe and release a context returned by mcrypt_alloc(). A NULL handle is
 * ignored. The typed state pointer is used for both the erase and free(),
 * rather than handing the opaque handle type to free() directly.
 */
void mcrypt_free(const mcrypt_t context)
{
	crypt_state_t* const state = (crypt_state_t*)context;
	if (state)
	{
		/* Scrub the key-derived tables before the memory is returned. */
		mcrypt_erase(state, sizeof(crypt_state_t));
		free(state);
	}
}