/******************************************************************************/
/* SlunkCrypt, by LoRd_MuldeR <MuldeR2@GMX.de>                                */
/* This work has been released under the CC0 1.0 Universal license!           */
/******************************************************************************/

#ifndef INC_SLUNKCRYPT_PLUSPLUS
#define INC_SLUNKCRYPT_PLUSPLUS

/*
 * Compiler check
 */
#if (!defined(__cplusplus) || (__cplusplus < 201103L)) && (!defined(_MSVC_LANG) || (_MSVC_LANG < 201103L))
#error This file requires compiler and library support for the ISO C++11 standard.
#endif

/*
 * Dependencies
 */
#include "slunkcrypt.h"
#include <string>
#include <vector>
#include <stdexcept>
#include <cstring>

/*
 * Namespace
 */
namespace slunkcrypt
{
	/*
	 * Base class for SlunkCrypt
	 */
	class SlunkBase
	{
	public:
		SlunkBase(const size_t thread_count, const bool legacy_compat, const bool debug_logging)
		{
			std::memset(&m_param, 0, sizeof(m_param));
			m_param.version = ::SLUNKCRYPT_PARAM_VERSION;
			m_param.thread_count = thread_count;
			m_param.legacy_compat = legacy_compat ? SLUNKCRYPT_TRUE : SLUNKCRYPT_FALSE;
			m_param.debug_logging = debug_logging ? SLUNKCRYPT_TRUE : SLUNKCRYPT_FALSE;
		}

		virtual bool process(const uint8_t *const input, uint8_t *const output, size_t length) = 0;
		virtual bool process(const std::vector<uint8_t> &input, std::vector<uint8_t> &output) = 0;
		virtual bool inplace(uint8_t *const buffer, size_t length) = 0;
		virtual bool inplace(std::vector<uint8_t> &buffer) = 0;

	protected:
		::slunkparam_t m_param;
		::slunkcrypt_t m_instance;
	};

	/*
	 * Class for encryption
	 */
	class Encryptor : public SlunkBase
	{
	public:
		Encryptor(const std::string &passwd, const size_t thread_count = 0U, const bool legacy_compat = false, const bool debug_logging = false)
			: SlunkBase(thread_count, legacy_compat, debug_logging)
		{
			if (::slunkcrypt_generate_nonce(&m_nonce) != SLUNKCRYPT_SUCCESS)
			{
				throw std::runtime_error("SlunkCryptEncr: Failed to generate the seed value!");
			}
			if ((m_instance = ::slunkcrypt_alloc_ext(m_nonce, (const uint8_t*)passwd.c_str(), passwd.length(), SLUNKCRYPT_ENCRYPT, &m_param)) == SLUNKCRYPT_NULL)
			{
				throw std::runtime_error("SlunkCryptEncr: Failed to create encoder instance!");
			}
		}

		~Encryptor(void)
		{
			if (m_instance != SLUNKCRYPT_NULL)
			{
				::slunkcrypt_free(m_instance);
			}
		}

		bool process(const uint8_t *const input, uint8_t *const output, size_t length)
		{
			return (::slunkcrypt_process(m_instance, input, output, length) == SLUNKCRYPT_SUCCESS);
		}

		bool process(const std::vector<uint8_t> &input, std::vector<uint8_t> &output)
		{
			if (output.size() >= input.size())
			{
				return (::slunkcrypt_process(m_instance, input.data(), output.data(), input.size()) == SLUNKCRYPT_SUCCESS);
			}
			return false;
		}

		bool inplace(uint8_t *const buffer, size_t length)
		{
			return (::slunkcrypt_inplace(m_instance, buffer, length) == SLUNKCRYPT_SUCCESS);
		}

		bool inplace(std::vector<uint8_t> &buffer)
		{
			return (::slunkcrypt_inplace(m_instance, buffer.data(), buffer.size()) == SLUNKCRYPT_SUCCESS);
		}

		uint64_t get_nonce(void) const
		{
			return m_nonce;
		}

	private:
		Encryptor(const Encryptor&) = delete;
		Encryptor& operator=(const Encryptor&) = delete;
		uint64_t m_nonce;
	};

	/*
	 * Class for decryption
	 */
	class Decryptor : public SlunkBase
	{
	public:
		Decryptor(const std::string &passwd, const uint64_t nonce, const size_t thread_count = 0U, const bool legacy_compat = false, const bool debug_logging = false)
			: SlunkBase(thread_count, legacy_compat, debug_logging)
		{
			if ((m_instance = ::slunkcrypt_alloc_ext(nonce, (const uint8_t*)passwd.c_str(), passwd.length(), SLUNKCRYPT_DECRYPT, &m_param)) == SLUNKCRYPT_NULL)
			{
				throw std::runtime_error("SlunkCryptDecr: Failed to create decoder instance!");
			}
		}

		~Decryptor(void)
		{
			if (m_instance != SLUNKCRYPT_NULL)
			{
				::slunkcrypt_free(m_instance);
			}
		}

		bool process(const uint8_t *const input, uint8_t *const output, size_t length)
		{
			return (::slunkcrypt_process(m_instance, input, output, length) == SLUNKCRYPT_SUCCESS);
		}

		bool process(const std::vector<uint8_t> &input, std::vector<uint8_t> &output)
		{
			if (output.size() >= input.size())
			{
				return (::slunkcrypt_process(m_instance, input.data(), output.data(), input.size()) == SLUNKCRYPT_SUCCESS);
			}
			return false;
		}

		bool inplace(uint8_t *const buffer, size_t length)
		{
			return (::slunkcrypt_inplace(m_instance, buffer, length) == SLUNKCRYPT_SUCCESS);
		}

		bool inplace(std::vector<uint8_t> &buffer)
		{
			return (::slunkcrypt_inplace(m_instance, buffer.data(), buffer.size()) == SLUNKCRYPT_SUCCESS);
		}

	private:
		Decryptor(const Decryptor&) = delete;
		Decryptor& operator=(const Decryptor&) = delete;
	};
}
#endif