Improved library initialization code.

This commit is contained in:
LoRd_MuldeR 2020-10-26 19:56:45 +01:00
parent b577afba49
commit ae3318a12f
Signed by: mulder
GPG Key ID: 2B5913365F57E03F
4 changed files with 108 additions and 58 deletions

View File

@ -86,13 +86,13 @@ SLUNKCRYPT_API void slunkcrypt_cleanup(void);
/* /*
* Seed generator * Seed generator
*/ */
SLUNKCRYPT_API int slunkcrypt_generate_salt(uint64_t* const seed); SLUNKCRYPT_API int slunkcrypt_generate_nonce(uint64_t* const nonce);
/* /*
* Allocate, reset or free state * Allocate, reset or free state
*/ */
SLUNKCRYPT_API slunkcrypt_t slunkcrypt_alloc(const uint64_t salt, const uint8_t *const passwd, const size_t passwd_len); SLUNKCRYPT_API slunkcrypt_t slunkcrypt_alloc(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len);
SLUNKCRYPT_API int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t salt, const uint8_t *const passwd, const size_t passwd_len); SLUNKCRYPT_API int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len);
SLUNKCRYPT_API void slunkcrypt_free(const slunkcrypt_t context); SLUNKCRYPT_API void slunkcrypt_free(const slunkcrypt_t context);
/* /*

View File

@ -16,11 +16,11 @@ class SlunkCryptEncr
public: public:
SlunkCryptEncr(const std::string &passwd) SlunkCryptEncr(const std::string &passwd)
{ {
if (slunkcrypt_generate_salt(&m_salt) != SLUNKCRYPT_SUCCESS) if (slunkcrypt_generate_nonce(&m_nonce) != SLUNKCRYPT_SUCCESS)
{ {
throw std::runtime_error("Failed to generate the seed value!"); throw std::runtime_error("Failed to generate the seed value!");
} }
if ((m_instance = slunkcrypt_alloc(m_salt, (const uint8_t*)passwd.c_str(), passwd.length())) == SLUNKCRYPT_NULL) if ((m_instance = slunkcrypt_alloc(m_nonce, (const uint8_t*)passwd.c_str(), passwd.length())) == SLUNKCRYPT_NULL)
{ {
throw std::runtime_error("Failed to create encoder instance!"); throw std::runtime_error("Failed to create encoder instance!");
} }
@ -29,7 +29,7 @@ public:
SlunkCryptEncr(SlunkCryptEncr &&other) noexcept SlunkCryptEncr(SlunkCryptEncr &&other) noexcept
{ {
this->m_instance = other.m_instance; this->m_instance = other.m_instance;
this->m_salt = other.m_salt; this->m_nonce = other.m_nonce;
other.m_instance = SLUNKCRYPT_NULL; other.m_instance = SLUNKCRYPT_NULL;
} }
@ -65,24 +65,24 @@ public:
return (slunkcrypt_encrypt_inplace(m_instance, buffer.data(), buffer.size()) == SLUNKCRYPT_SUCCESS); return (slunkcrypt_encrypt_inplace(m_instance, buffer.data(), buffer.size()) == SLUNKCRYPT_SUCCESS);
} }
uint64_t salt_value(void) const uint64_t get_nonce(void) const
{ {
return m_salt; return m_nonce;
} }
private: private:
SlunkCryptEncr(const SlunkCryptEncr&); SlunkCryptEncr(const SlunkCryptEncr&);
SlunkCryptEncr& operator=(const SlunkCryptEncr&); SlunkCryptEncr& operator=(const SlunkCryptEncr&);
uint64_t m_salt; uint64_t m_nonce;
slunkcrypt_t m_instance; slunkcrypt_t m_instance;
}; };
class SlunkCryptDecr class SlunkCryptDecr
{ {
public: public:
SlunkCryptDecr(const uint64_t salt, const std::string& passwd) SlunkCryptDecr(const uint64_t nonce, const std::string& passwd)
{ {
if ((m_instance = slunkcrypt_alloc(salt, (const uint8_t*)passwd.c_str(), passwd.length())) == SLUNKCRYPT_NULL) if ((m_instance = slunkcrypt_alloc(nonce, (const uint8_t*)passwd.c_str(), passwd.length())) == SLUNKCRYPT_NULL)
{ {
throw std::runtime_error("Failed to create encoder instance!"); throw std::runtime_error("Failed to create encoder instance!");
} }

View File

@ -3,11 +3,21 @@
/* This work has been released under the CC0 1.0 Universal license! */ /* This work has been released under the CC0 1.0 Universal license! */
/******************************************************************************/ /******************************************************************************/
#include <slunkcrypt.h> /* Internal */
#include "../include/slunkcrypt.h"
/* CRT */
#include <string.h>
#include <fcntl.h>
#include <limits.h>
/* Platform compatibility */
#ifdef _WIN32 #ifdef _WIN32
# define WIN32_LEAN_AND_MEAN 1 # define WIN32_LEAN_AND_MEAN 1
# include <Windows.h> # include <Windows.h>
# define SCHED_YIELD() Sleep(1U)
# define COMPARE_AND_SWAP(PTR,OLD,NEW) InterlockedCompareExchange((PTR),(NEW),(OLD))
# define ATOMIC_STORE(PTR,VAL) InterlockedExchange((PTR),(VAL))
# if defined(SecureZeroMemory) # if defined(SecureZeroMemory)
# define HAVE_SECURE_ZERO_MEMORY 1 # define HAVE_SECURE_ZERO_MEMORY 1
# else # else
@ -17,8 +27,15 @@
# define HAVE_EXPLICIT_BZERO 0 # define HAVE_EXPLICIT_BZERO 0
#else #else
# include <unistd.h> # include <unistd.h>
# include <fcntl.h> # include <sched.h>
# include <string.h> # define SCHED_YIELD() sched_yield()
# if defined(__GNUC__) || defined(__clang__) || defined(__INTEL_COMPILER)
# define COMPARE_AND_SWAP(PTR,OLD,NEW) __sync_val_compare_and_swap((PTR),(OLD),(NEW))
# define ATOMIC_STORE(PTR,VAL) __atomic_store_n((PTR),(VAL),__ATOMIC_RELEASE)
# else
# define COMPARE_AND_SWAP(PTR,OLD,NEW) ((OLD))
# define ATOMIC_STORE(PTR,VAL) do { *(PTR) = (VAL); } while(0)
# endif
# if defined(__GLIBC__) && (__GLIBC__ >= 2) && (__GLIBC_MINOR__ >= 25) # if defined(__GLIBC__) && (__GLIBC__ >= 2) && (__GLIBC_MINOR__ >= 25)
# define HAVE_GETRANDOM 1 # define HAVE_GETRANDOM 1
# define HAVE_EXPLICIT_BZERO 1 # define HAVE_EXPLICIT_BZERO 1
@ -37,10 +54,32 @@
# endif # endif
#endif #endif
// ==========================================================================
// Critical sections
// ==========================================================================
static int enter_critsec(volatile long *const lock, const int flag)
{
const long expected = flag ? 0L : 1L;
long status;
while ((status = COMPARE_AND_SWAP(lock, expected, -1L)) < 0L)
{
SCHED_YIELD();
}
return (status == expected);
}
static void leave_critsec(volatile long *const lock, const int flag)
{
ATOMIC_STORE(lock, flag ? 1L : 0L);
}
// ========================================================================== // ==========================================================================
// (De)Initialization // (De)Initialization
// ========================================================================== // ==========================================================================
static volatile long s_initialized = 0L;
#if defined(_WIN32) #if defined(_WIN32)
typedef BOOLEAN(WINAPI *genrandom_t)(void*, ULONG); typedef BOOLEAN(WINAPI *genrandom_t)(void*, ULONG);
static HMODULE s_advapi32 = NULL; static HMODULE s_advapi32 = NULL;
@ -52,35 +91,43 @@ static int s_random_fd = -1;
void slunkcrypt_startup(void) void slunkcrypt_startup(void)
{ {
if (enter_critsec(&s_initialized, 1))
{
#if defined(_WIN32) #if defined(_WIN32)
if (s_advapi32 || (s_advapi32 = LoadLibraryW(L"advapi32.dll"))) if ((s_advapi32 = LoadLibraryW(L"advapi32.dll")))
{ {
s_genrandom = (genrandom_t)GetProcAddress(s_advapi32, "SystemFunction036"); s_genrandom = (genrandom_t)GetProcAddress(s_advapi32, "SystemFunction036");
} }
#elif !HAVE_GETRANDOM #elif !HAVE_GETRANDOM
for (size_t i = 0U; (s_random_fd < 0) && DEV_RANDOM[i]; ++i) for (size_t i = 0U; (s_random_fd < 0) && DEV_RANDOM[i]; ++i)
{ {
s_random_fd = open(DEV_RANDOM[i], O_RDONLY); s_random_fd = open(DEV_RANDOM[i], O_RDONLY);
} }
#endif #endif
leave_critsec(&s_initialized, 1);
}
} }
void slunkcrypt_cleanup(void) void slunkcrypt_cleanup(void)
{ {
if (enter_critsec(&s_initialized, 0))
{
#if defined(_WIN32) #if defined(_WIN32)
s_genrandom = NULL; s_genrandom = NULL;
if (s_advapi32) if (s_advapi32)
{ {
FreeLibrary(s_advapi32); FreeLibrary(s_advapi32);
s_advapi32 = NULL; s_advapi32 = NULL;
} }
#elif !HAVE_GETRANDOM #elif !HAVE_GETRANDOM
if (s_random_fd >= 0) if (s_random_fd >= 0)
{ {
close(s_random_fd); close(s_random_fd);
s_random_fd = -1; s_random_fd = -1;
} }
#endif #endif
leave_critsec(&s_initialized, 0);
}
} }
// ========================================================================== // ==========================================================================

View File

@ -7,22 +7,25 @@
#define _CRT_RAND_S 1 #define _CRT_RAND_S 1
#endif #endif
#include <slunkcrypt.h> /* Internal */
#include <string.h> #include "../include/slunkcrypt.h"
#include <assert.h>
#include "version.h" #include "version.h"
#ifdef _MSC_VER /* CRT */
#define FORCE_INLINE __forceinline #include <string.h>
#define UNUSED __pragma(warning(suppress: 4189)) #include <limits.h>
#include <assert.h>
/* Compiler compatibility */
#if defined(_MSC_VER)
# define FORCE_INLINE __forceinline
# define UNUSED __pragma(warning(suppress: 4189))
#elif defined(__GNUC__)
# define FORCE_INLINE __attribute__((always_inline)) inline
# define UNUSED __attribute__((unused))
#else #else
#ifdef __GNUC__ # define FORCE_INLINE inline
#define FORCE_INLINE __attribute__((always_inline)) inline # define UNUSED
#define UNUSED __attribute__((unused))
#else
#define FORCE_INLINE inline
#define UNUSED
#endif
#endif #endif
/* Version info */ /* Version info */
@ -195,7 +198,7 @@ static void random_seed(rand_state_t* const state, const uint64_t salt, const ui
slunkcrypt_bzero(&key, sizeof(key_data_t)); slunkcrypt_bzero(&key, sizeof(key_data_t));
for (size_t i = 0U; i < 97U; ++i) for (size_t i = 0U; i < 97U; ++i)
{ {
volatile UNUSED uint32_t u = random_next(state); UNUSED volatile uint32_t u = random_next(state);
} }
} }
@ -203,7 +206,7 @@ static void random_seed(rand_state_t* const state, const uint64_t salt, const ui
// Initialization // Initialization
// ========================================================================== // ==========================================================================
static int initialize_state(crypt_state_t* const crypt_state, const uint64_t salt, const uint8_t* const passwd, const size_t passwd_len) static int initialize_state(crypt_state_t* const crypt_state, const uint64_t nonce, const uint8_t* const passwd, const size_t passwd_len)
{ {
slunkcrypt_bzero(crypt_state, sizeof(crypt_state_t)); slunkcrypt_bzero(crypt_state, sizeof(crypt_state_t));
@ -211,7 +214,7 @@ static int initialize_state(crypt_state_t* const crypt_state, const uint64_t sal
rand_state_t rand_state; rand_state_t rand_state;
for (size_t r = 0U; r < 256U; ++r) for (size_t r = 0U; r < 256U; ++r)
{ {
random_seed(&rand_state, salt, (uint16_t)r, passwd, passwd_len); random_seed(&rand_state, nonce, (uint16_t)r, passwd, passwd_len);
crypt_state->rotation_bwd[0U][255U - r] = crypt_state->rotation_fwd[0U][r] = (uint8_t)random_next(&rand_state); crypt_state->rotation_bwd[0U][255U - r] = crypt_state->rotation_fwd[0U][r] = (uint8_t)random_next(&rand_state);
crypt_state->rotation_bwd[1U][255U - r] = crypt_state->rotation_fwd[1U][r] = 0U; crypt_state->rotation_bwd[1U][255U - r] = crypt_state->rotation_fwd[1U][r] = 0U;
for (size_t i = 0U; i < 256U; ++i) for (size_t i = 0U; i < 256U; ++i)
@ -232,7 +235,7 @@ static int initialize_state(crypt_state_t* const crypt_state, const uint64_t sal
} }
/* set up stepping */ /* set up stepping */
random_seed(&rand_state, salt, 256U, passwd, passwd_len); random_seed(&rand_state, nonce, 256U, passwd, passwd_len);
for (size_t i = 0U; i < 256U; ++i) for (size_t i = 0U; i < 256U; ++i)
{ {
const size_t j = random_next(&rand_state) % (i + 1U); const size_t j = random_next(&rand_state) % (i + 1U);
@ -292,24 +295,24 @@ static FORCE_INLINE uint8_t process_dec(crypt_state_t* const crypt_state, uint8_
// Public API // Public API
// ========================================================================== // ==========================================================================
int slunkcrypt_generate_salt(uint64_t* const seed) int slunkcrypt_generate_nonce(uint64_t* const nonce)
{ {
if (!seed) if (!nonce)
{ {
return SLUNKCRYPT_FAILURE; return SLUNKCRYPT_FAILURE;
} }
do do
{ {
if (slunkcrypt_random_bytes((uint8_t*)seed, sizeof(uint64_t)) != 0) if (slunkcrypt_random_bytes((uint8_t*)nonce, sizeof(uint64_t)) != 0)
{ {
return SLUNKCRYPT_FAILURE; return SLUNKCRYPT_FAILURE;
} }
} }
while (!(*seed)); while (!(*nonce));
return SLUNKCRYPT_SUCCESS; return SLUNKCRYPT_SUCCESS;
} }
slunkcrypt_t slunkcrypt_alloc(const uint64_t salt, const uint8_t *const passwd, const size_t passwd_len) slunkcrypt_t slunkcrypt_alloc(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len)
{ {
if ((!passwd) || (passwd_len < SLUNKCRYPT_PWDLEN_MIN) || (passwd_len > SLUNKCRYPT_PWDLEN_MAX)) if ((!passwd) || (passwd_len < SLUNKCRYPT_PWDLEN_MIN) || (passwd_len > SLUNKCRYPT_PWDLEN_MAX))
{ {
@ -320,7 +323,7 @@ slunkcrypt_t slunkcrypt_alloc(const uint64_t salt, const uint8_t *const passwd,
{ {
return SLUNKCRYPT_NULL; return SLUNKCRYPT_NULL;
} }
if (initialize_state(state, salt, passwd, passwd_len) == SLUNKCRYPT_SUCCESS) if (initialize_state(state, nonce, passwd, passwd_len) == SLUNKCRYPT_SUCCESS)
{ {
return ((slunkcrypt_t)state); return ((slunkcrypt_t)state);
} }
@ -331,14 +334,14 @@ slunkcrypt_t slunkcrypt_alloc(const uint64_t salt, const uint8_t *const passwd,
} }
} }
int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t salt, const uint8_t *const passwd, const size_t passwd_len) int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len)
{ {
crypt_state_t* const state = (crypt_state_t*)context; crypt_state_t* const state = (crypt_state_t*)context;
if ((!state) || (!passwd) || (passwd_len < SLUNKCRYPT_PWDLEN_MIN) || (passwd_len > SLUNKCRYPT_PWDLEN_MAX)) if ((!state) || (!passwd) || (passwd_len < SLUNKCRYPT_PWDLEN_MIN) || (passwd_len > SLUNKCRYPT_PWDLEN_MAX))
{ {
return SLUNKCRYPT_FAILURE; return SLUNKCRYPT_FAILURE;
} }
const int result = initialize_state(state, salt, passwd, passwd_len); const int result = initialize_state(state, nonce, passwd, passwd_len);
if (result != SLUNKCRYPT_SUCCESS) if (result != SLUNKCRYPT_SUCCESS)
{ {
slunkcrypt_bzero(state, sizeof(crypt_state_t)); slunkcrypt_bzero(state, sizeof(crypt_state_t));