Much improved threading support.

This commit is contained in:
LoRd_MuldeR 2022-03-21 22:45:29 +01:00
parent ddefc8c142
commit c32c85d8c9
Signed by: mulder
GPG Key ID: 2B5913365F57E03F
6 changed files with 329 additions and 234 deletions

View File

@ -28,7 +28,7 @@
static const uint64_t MAGIC_NUMBER = 0x243F6A8885A308D3ull; static const uint64_t MAGIC_NUMBER = 0x243F6A8885A308D3ull;
#define BUFFER_SIZE 4096U #define BUFFER_SIZE 32768U
// ========================================================================== // ==========================================================================
// Auxiliary functions // Auxiliary functions
@ -138,7 +138,7 @@ int encrypt(const char *const passphrase, const CHR *const input_path, const CHR
{ {
break; /*EOF*/ break; /*EOF*/
} }
if (!(++refresh_cycles & 0x7)) if (!(++refresh_cycles & 0x3))
{ {
const uint64_t clk_now = clock_read(); const uint64_t clk_now = clock_read();
if ((clk_now < clk_update) || (clk_now - clk_update > update_interval)) if ((clk_now < clk_update) || (clk_now - clk_update > update_interval))
@ -322,7 +322,7 @@ int decrypt(const char *const passphrase, const CHR *const input_path, const CHR
{ {
break; /*EOF*/ break; /*EOF*/
} }
if (!(++refresh_cycles & 0x7)) if (!(++refresh_cycles & 0x3))
{ {
const uint64_t clk_now = clock_read(); const uint64_t clk_now = clock_read();
if ((clk_now < clk_update) || (clk_now - clk_update > update_interval)) if ((clk_now < clk_update) || (clk_now - clk_update > update_interval))

View File

@ -19,5 +19,5 @@ using System.Windows;
[assembly: ComVisible(false)] [assembly: ComVisible(false)]
[assembly: ThemeInfo(ResourceDictionaryLocation.None, ResourceDictionaryLocation.SourceAssembly)] [assembly: ThemeInfo(ResourceDictionaryLocation.None, ResourceDictionaryLocation.SourceAssembly)]
[assembly: AssemblyVersion("1.1.*")] [assembly: AssemblyVersion("1.2.*")]
[assembly: AssemblyFileVersion("1.1.0.0")] [assembly: AssemblyFileVersion("1.2.0.0")]

View File

@ -75,6 +75,17 @@ static const int SLUNKCRYPT_ABORTED = -2;
static const size_t SLUNKCRYPT_PWDLEN_MIN = 8U; static const size_t SLUNKCRYPT_PWDLEN_MIN = 8U;
static const size_t SLUNKCRYPT_PWDLEN_MAX = 256U; static const size_t SLUNKCRYPT_PWDLEN_MAX = 256U;
/*
 * Optional parameters
 *
 * Passed to slunkcrypt_alloc_ext() to configure the context being allocated.
 * slunkcrypt_alloc_ext() rejects a NULL parameter block, as well as a version
 * of 0 or a version greater than SLUNKCRYPT_PARAM_VERSION.
 */
static const uint16_t SLUNKCRYPT_PARAM_VERSION = 1U;
typedef struct
{
uint16_t version; /* Must be set to SLUNKCRYPT_PARAM_VERSION */
size_t thread_count; /* Number of threads, set to 0 for auto-detection */
}
slunk_param_t;
/* /*
* Version info * Version info
*/ */
@ -101,6 +112,7 @@ SLUNKCRYPT_API int slunkcrypt_generate_nonce(uint64_t *const nonce);
* Allocate, reset or free state * Allocate, reset or free state
*/ */
SLUNKCRYPT_API slunkcrypt_t slunkcrypt_alloc(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode); SLUNKCRYPT_API slunkcrypt_t slunkcrypt_alloc(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode);
SLUNKCRYPT_API slunkcrypt_t slunkcrypt_alloc_ext(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode, const slunk_param_t *const param);
SLUNKCRYPT_API int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode); SLUNKCRYPT_API int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode);
SLUNKCRYPT_API void slunkcrypt_free(const slunkcrypt_t context); SLUNKCRYPT_API void slunkcrypt_free(const slunkcrypt_t context);

View File

@ -38,21 +38,24 @@ typedef struct
int reverse_mode; int reverse_mode;
const uint8_t (*wheel)[256U]; const uint8_t (*wheel)[256U];
uint32_t counter; uint32_t counter;
size_t index_off;
rand_state_t random; rand_state_t random;
uint8_t *data;
} }
thread_state_t; thread_state_t;
typedef struct typedef struct
{ {
uint8_t wheel[256U][256U]; uint8_t wheel[256U][256U];
thrdpl_t thread_pool;
size_t thread_idx;
thread_state_t thread_data[MAX_THREADS]; thread_state_t thread_data[MAX_THREADS];
} }
crypt_state_t; crypt_data_t;
#define THREAD_COUNT 1U typedef struct
{
thrdpl_t thread_pool;
crypt_data_t data;
}
crypt_state_t;
// ========================================================================== // ==========================================================================
// Abort flag // Abort flag
@ -109,9 +112,17 @@ static INLINE uint32_t random_next(rand_state_t *const state)
return (state->d += 0x000587C5) + state->v; return (state->d += 0x000587C5) + state->v;
} }
static INLINE void random_seed(rand_state_t *const state, uint64_t salt, const uint16_t pepper, const uint8_t *const passwd, const size_t passwd_len) static INLINE void random_skip(rand_state_t *const state, const size_t skip_count)
{ {
size_t i; size_t i;
for (i = 0U; i < skip_count; ++i)
{
/* UNUSED volatile uint32_t q = */ random_next(state);
}
}
static INLINE void random_seed(rand_state_t *const state, uint64_t salt, const uint16_t pepper, const uint8_t *const passwd, const size_t passwd_len)
{
keydata_t key; keydata_t key;
do do
{ {
@ -120,52 +131,38 @@ static INLINE void random_seed(rand_state_t *const state, uint64_t salt, const u
slunkcrypt_bzero(&key, sizeof(keydata_t)); slunkcrypt_bzero(&key, sizeof(keydata_t));
} }
while (!(state->x || state->y || state->z || state->w || state->v)); while (!(state->x || state->y || state->z || state->w || state->v));
for (i = 0U; i < 97U; ++i) random_skip(state, 97U);
{
UNUSED volatile uint32_t q = random_next(state);
}
} }
// ========================================================================== // ==========================================================================
// Initialization // Initialization
// ========================================================================== // ==========================================================================
static int initialize_state(crypt_state_t *const state, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode, const int reset) static int initialize_state(crypt_data_t *const data, const size_t thread_count, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode)
{ {
uint8_t temp[256U][256U]; uint8_t temp[256U][256U];
size_t r, i; size_t r, i;
rand_state_t random;
uint32_t counter;
const int reverse_mode = BOOLIFY(mode); const int reverse_mode = BOOLIFY(mode);
/* backup previous value */
const thrdpl_t thread_pool = reset ? state->thread_pool : THRDPL_NULL;
/* initialize state */ /* initialize state */
slunkcrypt_bzero(state, sizeof(crypt_state_t)); slunkcrypt_bzero(data, sizeof(crypt_data_t));
/* create thread-pool */
if ((state->thread_pool = reset ? thread_pool : thrdpl_create(THREAD_COUNT)) == THRDPL_NULL)
{
return SLUNKCRYPT_FAILURE;
}
/* initialize counter */ /* initialize counter */
random_seed(&random, nonce, (uint16_t)(-1), passwd, passwd_len); random_seed(&data->thread_data[0].random, nonce, (uint16_t)(-1), passwd, passwd_len);
counter = random_next(&random); data->thread_data[0].counter = random_next(&data->thread_data[0].random);
/* set up the wheel permutations */ /* set up the wheel permutations */
for (r = 0U; r < 256U; ++r) for (r = 0U; r < 256U; ++r)
{ {
random_seed(&random, nonce, (uint16_t)r, passwd, passwd_len); random_seed(&data->thread_data[0].random, nonce, (uint16_t)r, passwd, passwd_len);
for (i = 0U; i < 256U; ++i) for (i = 0U; i < 256U; ++i)
{ {
const size_t j = random_next(&random) % (i + 1U); const size_t j = random_next(&data->thread_data[0].random) % (i + 1U);
if (j != i) if (j != i)
{ {
state->wheel[r][i] = state->wheel[r][j]; data->wheel[r][i] = data->wheel[r][j];
} }
state->wheel[r][j] = (uint8_t)i; data->wheel[r][j] = (uint8_t)i;
} }
CHECK_ABORTED(); CHECK_ABORTED();
} }
@ -177,43 +174,38 @@ static int initialize_state(crypt_state_t *const state, const uint64_t nonce, co
{ {
for (i = 0U; i < 256U; ++i) for (i = 0U; i < 256U; ++i)
{ {
temp[r][state->wheel[r][i]] = (uint8_t)i; temp[r][data->wheel[r][i]] = (uint8_t)i;
} }
} }
for (r = 0U; r < 256U; ++r) for (r = 0U; r < 256U; ++r)
{ {
memcpy(state->wheel[255U - r], temp[r], 256U); memcpy(data->wheel[255U - r], temp[r], 256U);
} }
slunkcrypt_bzero(temp, sizeof(temp)); slunkcrypt_bzero(temp, sizeof(temp));
CHECK_ABORTED(); CHECK_ABORTED();
} }
/* set up thread state */ /* initialize thread state */
random_seed(&random, nonce, 256U, passwd, passwd_len); data->thread_data[0].reverse_mode = reverse_mode;
for (i = 0U; i < THREAD_COUNT; ++i) data->thread_data[0].wheel = data->wheel;
data->thread_data[0].index_off = 0U;
random_seed(&data->thread_data[0].random, nonce, 256U, passwd, passwd_len);
for (i = 1U; i < thread_count; ++i)
{ {
state->thread_data[i].reverse_mode = reverse_mode; data->thread_data[i].reverse_mode = data->thread_data[0].reverse_mode;
state->thread_data[i].wheel = state->wheel; data->thread_data[i].wheel = data->thread_data[0].wheel;
state->thread_data[i].counter = counter + ((uint32_t)i); data->thread_data[i].counter = data->thread_data[0].counter + ((uint32_t)i);
memcpy(&state->thread_data[i].random, &random, sizeof(rand_state_t)); data->thread_data[i].index_off = data->thread_data[i - 1U].index_off + 1U;
for (r = 0U; r < i * 63U; ++r) memcpy(&data->thread_data[i].random, &data->thread_data[0].random, sizeof(rand_state_t));
{ random_skip(&data->thread_data[i].random, i * 63U);
random_next(&state->thread_data[i].random);
}
CHECK_ABORTED(); CHECK_ABORTED();
} }
slunkcrypt_bzero(&counter, sizeof(uint32_t));
slunkcrypt_bzero(&random, sizeof(rand_state_t));
return SLUNKCRYPT_SUCCESS; return SLUNKCRYPT_SUCCESS;
/* aborted */ /* aborted */
aborted: aborted:
thrdpl_destroy(state->thread_pool); slunkcrypt_bzero(data, sizeof(crypt_data_t));
slunkcrypt_bzero(state, sizeof(crypt_state_t));
slunkcrypt_bzero(&counter, sizeof(uint32_t));
slunkcrypt_bzero(&random, sizeof(rand_state_t));
return SLUNKCRYPT_ABORTED; return SLUNKCRYPT_ABORTED;
} }
@ -232,22 +224,44 @@ static INLINE void update_offset(uint8_t *const offset, uint32_t seed, rand_stat
} }
offset[reverse ? (255U - i) : i] = (uint8_t)seed; offset[reverse ? (255U - i) : i] = (uint8_t)seed;
} }
for (i = 0U; i < 63U * (THREAD_COUNT - 1U); ++i)
{
random_next(state);
}
} }
static INLINE void process_next_symbol(thread_state_t *const state) static INLINE uint8_t process_next_symbol(thread_state_t *const state, uint8_t value)
{ {
uint8_t offset[256U]; uint8_t offset[256U];
size_t i; size_t i;
update_offset(offset, state->counter, &state->random, state->reverse_mode); update_offset(offset, state->counter, &state->random, state->reverse_mode);
for (i = 0U; i < 256U; ++i) for (i = 0U; i < 256U; ++i)
{ {
*state->data = (state->wheel[i][(*state->data + offset[i]) & 0xFF] - offset[i]) & 0xFF; value = (state->wheel[i][(value + offset[i]) & 0xFF] - offset[i]) & 0xFF;
} }
state->counter += THREAD_COUNT; return value;
}
// ==========================================================================
// Thread entry point
// ==========================================================================
/*
 * Advance the start offset of one worker after a buffer has been processed.
 *
 * A worker handles the positions index_off, index_off + thread_count, ... of
 * a buffer. When the buffer length is not a multiple of thread_count, the
 * worker's first position in the *next* buffer is shifted; this function
 * rotates index_off accordingly. When the length is an exact multiple of
 * thread_count, the offset stays as it is.
 *
 * state        - per-thread state whose index_off is updated
 * thread_count - number of workers sharing the buffer (must not be zero,
 *                it is used as a divisor)
 * length       - length of the buffer that was just processed, in bytes
 */
static INLINE void update_index(thread_state_t *const state, const size_t thread_count, const size_t length)
{
	const size_t tail = length % thread_count;

	if (tail == 0U)
	{
		return; /* whole number of rounds, nothing to rotate */
	}

	state->index_off = (state->index_off + (thread_count - tail)) % thread_count;
}
/*
 * Worker routine executed by each pool thread for one buffer.
 *
 * Transforms every thread_count-th byte of the buffer, starting at the
 * worker's own index_off, so that the workers together cover every position.
 * After each byte the worker's counter is advanced by thread_count and its
 * random generator is fast-forwarded by 63 * (thread_count - 1) values
 * (NOTE(review): presumably this skips the values consumed by the other
 * workers for their symbols - confirm against update_offset()).
 *
 * thread_count - number of workers sharing the buffer
 * context      - pointer to this worker's thread_state_t
 * input        - source bytes; may be the same buffer as output
 * output       - destination bytes
 * length       - number of bytes in input/output
 */
static void thread_worker(const size_t thread_count, void *const context, const uint8_t *const input, uint8_t *const output, const size_t length)
{
	thread_state_t *const state = (thread_state_t*) context;
	const size_t skip_count = 63U * (thread_count - 1U);
	size_t pos = state->index_off;

	while (pos < length)
	{
		output[pos] = process_next_symbol(state, input[pos]);
		state->counter += (uint32_t)thread_count;
		random_skip(&state->random, skip_count);
		pos += thread_count;
	}

	/* remember where this worker has to continue in the next buffer */
	update_index(state, thread_count, length);
}
// ========================================================================== // ==========================================================================
@ -272,22 +286,39 @@ int slunkcrypt_generate_nonce(uint64_t *const nonce)
} }
slunkcrypt_t slunkcrypt_alloc(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode) slunkcrypt_t slunkcrypt_alloc(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode)
{
slunk_param_t param = { SLUNKCRYPT_PARAM_VERSION, 0U };
return slunkcrypt_alloc_ext(nonce, passwd, passwd_len, mode, &param);
}
slunkcrypt_t slunkcrypt_alloc_ext(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode, const slunk_param_t *const param)
{ {
crypt_state_t* state = NULL; crypt_state_t* state = NULL;
if ((!passwd) || (passwd_len < SLUNKCRYPT_PWDLEN_MIN) || (passwd_len > SLUNKCRYPT_PWDLEN_MAX) || (mode < SLUNKCRYPT_ENCRYPT) || (mode > SLUNKCRYPT_DECRYPT))
if ((!passwd) || (passwd_len < SLUNKCRYPT_PWDLEN_MIN) || (passwd_len > SLUNKCRYPT_PWDLEN_MAX) ||
(mode < SLUNKCRYPT_ENCRYPT) || (mode > SLUNKCRYPT_DECRYPT) || (!param) || (param->version == 0U) || (param->version > SLUNKCRYPT_PARAM_VERSION))
{ {
return SLUNKCRYPT_NULL; return SLUNKCRYPT_NULL;
} }
if (!(state = (crypt_state_t*)malloc(sizeof(crypt_state_t)))) if (!(state = (crypt_state_t*)malloc(sizeof(crypt_state_t))))
{ {
return SLUNKCRYPT_NULL; return SLUNKCRYPT_NULL;
} }
if (initialize_state(state, nonce, passwd, passwd_len, mode, 0) == SLUNKCRYPT_SUCCESS)
if ((state->thread_pool = slunkcrypt_thrdpl_create(param->thread_count)) == THRDPL_NULL)
{
free(state);
return SLUNKCRYPT_NULL;
}
if (initialize_state(&state->data, slunkcrypt_thrdpl_count(state->thread_pool), nonce, passwd, passwd_len, mode) == SLUNKCRYPT_SUCCESS)
{ {
return ((slunkcrypt_t)state); return ((slunkcrypt_t)state);
} }
else else
{ {
slunkcrypt_thrdpl_destroy(state->thread_pool);
slunkcrypt_bzero(state, sizeof(crypt_state_t)); slunkcrypt_bzero(state, sizeof(crypt_state_t));
return SLUNKCRYPT_NULL; return SLUNKCRYPT_NULL;
} }
@ -295,15 +326,16 @@ slunkcrypt_t slunkcrypt_alloc(const uint64_t nonce, const uint8_t *const passwd,
int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode) int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode)
{ {
crypt_state_t *const state = (crypt_state_t*)context; crypt_state_t *const state = (crypt_state_t*) context;
int result = SLUNKCRYPT_FAILURE; int result = SLUNKCRYPT_FAILURE;
if ((!state) || (!passwd) || (passwd_len < SLUNKCRYPT_PWDLEN_MIN) || (passwd_len > SLUNKCRYPT_PWDLEN_MAX) || (mode < SLUNKCRYPT_ENCRYPT) || (mode > SLUNKCRYPT_DECRYPT)) if ((!state) || (!passwd) || (passwd_len < SLUNKCRYPT_PWDLEN_MIN) || (passwd_len > SLUNKCRYPT_PWDLEN_MAX) || (mode < SLUNKCRYPT_ENCRYPT) || (mode > SLUNKCRYPT_DECRYPT))
{ {
return SLUNKCRYPT_FAILURE; return SLUNKCRYPT_FAILURE;
} }
if ((result = initialize_state(state, nonce, passwd, passwd_len, mode, 1)) != SLUNKCRYPT_SUCCESS) if ((result = initialize_state(&state->data, slunkcrypt_thrdpl_count(state->thread_pool), nonce, passwd, passwd_len, mode)) != SLUNKCRYPT_SUCCESS)
{ {
slunkcrypt_bzero(state, sizeof(crypt_state_t)); slunkcrypt_bzero(&state->data, sizeof(crypt_data_t));
} }
return result; return result;
} }
@ -311,7 +343,7 @@ int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t nonce, const uin
int slunkcrypt_process(const slunkcrypt_t context, const uint8_t *const input, uint8_t *const output, size_t length) int slunkcrypt_process(const slunkcrypt_t context, const uint8_t *const input, uint8_t *const output, size_t length)
{ {
size_t i; size_t i;
crypt_state_t *const state = (crypt_state_t*)context; crypt_state_t *const state = (crypt_state_t*) context;
if (!state) if (!state)
{ {
return SLUNKCRYPT_FAILURE; return SLUNKCRYPT_FAILURE;
@ -319,27 +351,27 @@ int slunkcrypt_process(const slunkcrypt_t context, const uint8_t *const input, u
if (length > 0U) if (length > 0U)
{ {
memcpy(output, input, length * sizeof(uint8_t)); const size_t thread_count = slunkcrypt_thrdpl_count(state->thread_pool);
for (i = 0; i < length; ++i) for (i = 0; i < thread_count; ++i)
{ {
abort(); //process_next_symbol(state, output + i); slunkcrypt_thrdpl_exec(state->thread_pool, i, thread_worker, &state->data.thread_data[i], input, output, length);
CHECK_ABORTED();
} }
slunkcrypt_thrdpl_await(state->thread_pool);
} }
thrdpl_await(state->thread_pool); CHECK_ABORTED();
return SLUNKCRYPT_SUCCESS; return SLUNKCRYPT_SUCCESS;
aborted: aborted:
thrdpl_await(state->thread_pool); slunkcrypt_thrdpl_await(state->thread_pool);
slunkcrypt_bzero(state, sizeof(crypt_state_t)); slunkcrypt_bzero(&state->data, sizeof(crypt_data_t));
return SLUNKCRYPT_ABORTED; return SLUNKCRYPT_ABORTED;
} }
int slunkcrypt_inplace(const slunkcrypt_t context, uint8_t *const buffer, size_t length) int slunkcrypt_inplace(const slunkcrypt_t context, uint8_t *const buffer, size_t length)
{ {
size_t i; size_t i;
crypt_state_t *const state = (crypt_state_t*)context; crypt_state_t *const state = (crypt_state_t*) context;
if (!state) if (!state)
{ {
return SLUNKCRYPT_FAILURE; return SLUNKCRYPT_FAILURE;
@ -347,34 +379,29 @@ int slunkcrypt_inplace(const slunkcrypt_t context, uint8_t *const buffer, size_t
if (length > 0U) if (length > 0U)
{ {
for (i = 0; i < length; ++i) const size_t thread_count = slunkcrypt_thrdpl_count(state->thread_pool);
for (i = 0; i < thread_count; ++i)
{ {
state->thread_data[state->thread_idx].data = buffer + i; slunkcrypt_thrdpl_exec(state->thread_pool, i, thread_worker, &state->data.thread_data[i], buffer, buffer, length);
//process_next_symbol(&state->thread_data[state->thread_idx]);
thrdpl_submit(state->thread_pool, process_next_symbol, &state->thread_data[state->thread_idx]);
if (++state->thread_idx >= THREAD_COUNT)
{
state->thread_idx = 0U;
}
CHECK_ABORTED();
} }
slunkcrypt_thrdpl_await(state->thread_pool);
} }
thrdpl_await(state->thread_pool); CHECK_ABORTED();
return SLUNKCRYPT_SUCCESS; return SLUNKCRYPT_SUCCESS;
aborted: aborted:
thrdpl_await(state->thread_pool); slunkcrypt_thrdpl_await(state->thread_pool);
slunkcrypt_bzero(state, sizeof(crypt_state_t)); slunkcrypt_bzero(&state->data, sizeof(crypt_data_t));
return SLUNKCRYPT_ABORTED; return SLUNKCRYPT_ABORTED;
} }
void slunkcrypt_free(const slunkcrypt_t context) void slunkcrypt_free(const slunkcrypt_t context)
{ {
crypt_state_t *const state = (crypt_state_t*)context; crypt_state_t *const state = (crypt_state_t*) context;
if (state) if (state)
{ {
thrdpl_destroy(state->thread_pool); slunkcrypt_thrdpl_destroy(state->thread_pool);
slunkcrypt_bzero(state, sizeof(crypt_state_t)); slunkcrypt_bzero(state, sizeof(crypt_state_t));
free(state); free(state);
} }

View File

@ -3,7 +3,7 @@
/* This work has been released under the CC0 1.0 Universal license! */ /* This work has been released under the CC0 1.0 Universal license! */
/******************************************************************************/ /******************************************************************************/
#ifdef _MSC_VER #if defined(_MSC_VER) && !defined(_DLL)
#define PTW32_STATIC_LIB 1 #define PTW32_STATIC_LIB 1
#endif #endif
@ -16,22 +16,35 @@
/* PThread */ /* PThread */
#include <pthread.h> #include <pthread.h>
#include <semaphore.h> #ifdef __unix__
#include <sys/sysinfo.h>
#endif
/* States */
#define THRD_STATE_IDLE 0
#define THRD_STATE_WORK 1
#define THRD_STATE_EXIT 2
// ==========================================================================
// Data types
// ==========================================================================
typedef struct typedef struct
{ {
thrdpl_worker_t worker; thrdpl_worker_t worker;
void *args; void *context;
size_t length;
const uint8_t *input;
uint8_t *output;
} }
thrdpl_task_t; thrdpl_task_t;
typedef struct typedef struct
{ {
size_t index; const size_t *count;
pthread_mutex_t *mutex; int state;
pthread_cond_t *ready; pthread_mutex_t mutex;
size_t *queue; pthread_cond_t cond;
sem_t sem_free, sem_used;
pthread_t thread; pthread_t thread;
thrdpl_task_t task; thrdpl_task_t task;
} }
@ -39,65 +52,101 @@ thrdpl_thread_t;
typedef struct typedef struct
{ {
size_t thread_count, index, queue; size_t thread_count;
pthread_mutex_t mutex; thrdpl_thread_t thread_data[MAX_THREADS];
pthread_cond_t ready;
thrdpl_thread_t threads[MAX_THREADS];
} }
thrdpl_data_t; thrdpl_data_t;
// ==========================================================================
// Helper macros
// ==========================================================================
/*
 * Clamp the modifiable lvalue VAL into the range [MIN, MAX].
 *
 * The upper bound is applied first, then the lower bound. Wrapped in
 * do { ... } while (0) so that the macro expands to exactly one statement
 * and can be used safely in an unbraced if/else. VAL is evaluated more than
 * once, so it must not have side effects.
 */
#define BOUND(MIN, VAL, MAX) do \
{ \
	if ((VAL) > (MAX)) { (VAL) = (MAX); } \
	if ((VAL) < (MIN)) { (VAL) = (MIN); } \
} \
while (0)
/* Lock the mutex pointed to by X; abort the process if pthread_mutex_lock() reports an error. */
#define PTHRD_MUTEX_LOCK(X) \
	do { if (pthread_mutex_lock((X)) != 0) { abort(); } } while (0)
/* Unlock the mutex pointed to by X; abort the process if pthread_mutex_unlock() reports an error. */
#define PTHRD_MUTEX_UNLOCK(X) \
	do { if (pthread_mutex_unlock((X)) != 0) { abort(); } } while (0)
/* Wake all waiters of the condition variable pointed to by X; abort the process on error. */
#define PTHRD_COND_BROADCAST(X) \
	do { if (pthread_cond_broadcast((X)) != 0) { abort(); } } while (0)
/*
 * Wait on the condition variable pointed to by X, using the mutex pointed to
 * by Y (which the caller must hold); abort the process on error.
 */
#define PTHRD_COND_WAIT(X,Y) \
	do { if (pthread_cond_wait((X), (Y)) != 0) { abort(); } } while (0)
/*
 * Leave the enclosing thread function if the worker has been told to exit.
 *
 * Expects a local variable `data` (pointer to the worker's thrdpl_thread_t)
 * in scope and data->mutex to be held by the caller. When data->state is
 * THRD_STATE_EXIT, the mutex is released (aborting on failure) and the
 * enclosing function returns NULL.
 */
#define CHECK_IF_CANCELLED() do \
{ \
	if (data->state == THRD_STATE_EXIT) \
	{ \
		PTHRD_MUTEX_UNLOCK(&data->mutex); \
		return NULL; /* cancelled */ \
	} \
} \
while(0)
// ========================================================================== // ==========================================================================
// Thread main // Thread main
// ========================================================================== // ==========================================================================
static void *thread_main(void *const arg) static void *worker_thread_main(void *const arg)
{ {
thrdpl_thread_t *const data = (thrdpl_thread_t*)arg; thrdpl_thread_t *const data = (thrdpl_thread_t*) arg;
thrdpl_task_t *task;
for (;;) for (;;)
{ {
if (sem_wait(&data->sem_used) != 0) PTHRD_MUTEX_LOCK(&data->mutex);
CHECK_IF_CANCELLED();
while (data->state != THRD_STATE_WORK)
{ {
abort(); PTHRD_COND_WAIT(&data->cond, &data->mutex);
CHECK_IF_CANCELLED();
} }
if (pthread_mutex_lock(data->mutex) != 0) task = &data->task;
{ PTHRD_MUTEX_UNLOCK(&data->mutex);
abort();
}
const thrdpl_worker_t worker = data->task.worker; task->worker(*data->count, task->context, task->input, task->output, task->length);
void *const args = data->task.args;
if (pthread_mutex_unlock(data->mutex) != 0) PTHRD_MUTEX_LOCK(&data->mutex);
{ CHECK_IF_CANCELLED();
abort(); data->state = THRD_STATE_IDLE;
} PTHRD_COND_BROADCAST(&data->cond);
PTHRD_MUTEX_UNLOCK(&data->mutex);
worker(args);
if (pthread_mutex_lock(data->mutex) != 0)
{
abort();
}
if (!(*data->queue -= 1U))
{
if (pthread_cond_broadcast(data->ready) != 0)
{
abort();
}
}
if (pthread_mutex_unlock(data->mutex) != 0)
{
abort();
}
if (sem_post(&data->sem_free) != 0)
{
abort();
}
} }
} }
@ -105,40 +154,61 @@ static void *thread_main(void *const arg)
// Manage threads // Manage threads
// ========================================================================== // ==========================================================================
static int create_thread(const size_t index, thrdpl_thread_t *const thread_data, pthread_mutex_t *const mutex, pthread_cond_t *const ready, size_t *const queue) #if defined(__unix__)
# define GET_NPROCS_FUNCTION() get_nprocs()
#elif defined(PTW32_VERSION)
# define GET_NPROCS_FUNCTION() pthread_num_processors_np()
#endif
static size_t detect_cpu_count(void)
{ {
thread_data->index = index; #ifdef GET_NPROCS_FUNCTION
thread_data->mutex = mutex; const int cpu_count = GET_NPROCS_FUNCTION();
thread_data->ready = ready; if (cpu_count > 0)
thread_data->queue = queue; {
return (size_t) cpu_count;
}
#endif
return 1U;
}
if (sem_init(&thread_data->sem_free, 0, 1U) != 0) static int create_worker_thread(thrdpl_thread_t *const thread_data, const size_t *const count)
{
thread_data->count = count;
thread_data->state = THRD_STATE_IDLE;
if (pthread_mutex_init(&thread_data->mutex, NULL) != 0)
{ {
return -1; return -1;
} }
if (sem_init(&thread_data->sem_used, 0, 0U) != 0) if (pthread_cond_init(&thread_data->cond, NULL) != 0)
{ {
sem_destroy(&thread_data->sem_free); pthread_mutex_destroy(&thread_data->mutex);
return -1; return -1;
} }
if (pthread_create(&thread_data->thread, NULL, thread_main, thread_data) != 0) if (pthread_create(&thread_data->thread, NULL, worker_thread_main, thread_data) != 0)
{ {
sem_destroy(&thread_data->sem_used); pthread_cond_destroy(&thread_data->cond);
sem_destroy(&thread_data->sem_free); pthread_mutex_destroy(&thread_data->mutex);
return -1; return -1;
} }
return 0; return 0;
} }
static int destroy_thread(thrdpl_thread_t *const thread_data) static int destroy_worker_thread(thrdpl_thread_t *const thread_data)
{ {
pthread_cancel(thread_data->thread); PTHRD_MUTEX_LOCK(&thread_data->mutex);
thread_data->state = THRD_STATE_EXIT;
PTHRD_COND_BROADCAST(&thread_data->cond);
PTHRD_MUTEX_UNLOCK(&thread_data->mutex);
pthread_join(thread_data->thread, NULL); pthread_join(thread_data->thread, NULL);
sem_destroy(&thread_data->sem_used); pthread_mutex_destroy(&thread_data->mutex);
sem_destroy(&thread_data->sem_free); pthread_cond_destroy(&thread_data->cond);
return 0; return 0;
} }
@ -146,45 +216,28 @@ static int destroy_thread(thrdpl_thread_t *const thread_data)
// Thread pool API // Thread pool API
// ========================================================================== // ==========================================================================
thrdpl_t thrdpl_create(const size_t count) thrdpl_t slunkcrypt_thrdpl_create(const size_t count)
{ {
size_t i, j; size_t i, j;
thrdpl_data_t *pool = NULL; thrdpl_data_t *pool = NULL;
if ((count < 1U) || (count > MAX_THREADS))
{
return THRDPL_NULL;
}
if (!(pool = (thrdpl_data_t*) malloc(sizeof(thrdpl_data_t)))) if (!(pool = (thrdpl_data_t*) malloc(sizeof(thrdpl_data_t))))
{ {
return THRDPL_NULL; return THRDPL_NULL;
} }
slunkcrypt_bzero(pool, sizeof(thrdpl_data_t)); slunkcrypt_bzero(pool, sizeof(thrdpl_data_t));
pool->thread_count = count; pool->thread_count = (count > 0U) ? count : detect_cpu_count();
BOUND(MIN_THREADS, pool->thread_count, MAX_THREADS);
if (pthread_mutex_init(&pool->mutex, NULL) != 0) for (i = 0U; i < pool->thread_count; ++i)
{ {
goto failure; if (create_worker_thread(&pool->thread_data[i], &pool->thread_count) != 0)
}
if (pthread_cond_init(&pool->ready, NULL) != 0)
{
pthread_mutex_destroy(&pool->mutex);
goto failure;
}
for (i = 0U; i < count; ++i)
{
if (create_thread(i, &pool->threads[i], &pool->mutex, &pool->ready, &pool->queue) != 0)
{ {
for (j = 0U; j < i; ++j) for (j = 0U; j < i; ++j)
{ {
destroy_thread(&pool->threads[j]); destroy_worker_thread(&pool->thread_data[j]);
} }
pthread_cond_destroy(&pool->ready);
pthread_mutex_destroy(&pool->mutex);
goto failure; goto failure;
} }
} }
@ -196,78 +249,79 @@ failure:
return (thrdpl_t)NULL; return (thrdpl_t)NULL;
} }
void thrdpl_submit(const thrdpl_t thrdpl, const thrdpl_worker_t worker, void *const args) size_t slunkcrypt_thrdpl_count(const thrdpl_t thrdpl)
{ {
thrdpl_data_t *const pool = (thrdpl_data_t*)thrdpl; thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl;
return pool->thread_count;
if (pthread_mutex_lock(&pool->mutex) != 0)
{
abort();
}
thrdpl_thread_t *const thread = &pool->threads[pool->index];
if (++pool->index >= pool->thread_count)
{
pool->index = 0U;
}
++pool->queue;
if (pthread_mutex_unlock(&pool->mutex) != 0)
{
abort();
}
if (sem_wait(&thread->sem_free) != 0)
{
abort();
}
thread->task.worker = worker;
thread->task.args = args;
if (sem_post(&thread->sem_used) != 0)
{
abort();
}
} }
void thrdpl_await(const thrdpl_t thrdpl) void slunkcrypt_thrdpl_exec(const thrdpl_t thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, const uint8_t *const input, uint8_t *const output, const size_t length)
{ {
thrdpl_data_t *const pool = (thrdpl_data_t*)thrdpl; thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl;
thrdpl_thread_t *const thread = &pool->thread_data[index];
if (pthread_mutex_lock(&pool->mutex) != 0) PTHRD_MUTEX_LOCK(&thread->mutex);
{
abort();
}
while (pool->queue) while ((thread->state != THRD_STATE_IDLE) && (thread->state != THRD_STATE_EXIT))
{ {
if (pthread_cond_wait(&pool->ready, &pool->mutex) != 0) if (pthread_cond_wait(&thread->cond, &thread->mutex) != 0)
{ {
abort(); abort();
} }
} }
if (pthread_mutex_unlock(&pool->mutex) != 0) if (thread->state == THRD_STATE_EXIT)
{ {
abort(); abort(); /*this is not supposed to happen!*/
}
thread->task.worker = worker;
thread->task.context = context;
thread->task.input = input;
thread->task.output = output;
thread->task.length = length;
thread->state = THRD_STATE_WORK;
PTHRD_COND_BROADCAST(&thread->cond);
PTHRD_MUTEX_UNLOCK(&thread->mutex);
}
void slunkcrypt_thrdpl_await(const thrdpl_t thrdpl)
{
size_t i;
thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl;
for (i = 0; i < pool->thread_count; ++i)
{
if (pthread_mutex_lock(&pool->thread_data[i].mutex) != 0)
{
abort();
}
while ((pool->thread_data[i].state != THRD_STATE_IDLE) && (pool->thread_data[i].state != THRD_STATE_EXIT))
{
if (pthread_cond_wait(&pool->thread_data[i].cond, &pool->thread_data[i].mutex) != 0)
{
abort();
}
}
if (pthread_mutex_unlock(&pool->thread_data[i].mutex) != 0)
{
abort();
}
} }
} }
void thrdpl_destroy(const thrdpl_t thrdpl) void slunkcrypt_thrdpl_destroy(const thrdpl_t thrdpl)
{ {
size_t i; size_t i;
thrdpl_data_t *const pool = (thrdpl_data_t*)thrdpl; thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl;
if (pool) if (pool)
{ {
for (i = 0U; i < pool->thread_count; ++i) for (i = 0U; i < pool->thread_count; ++i)
{ {
destroy_thread(&pool->threads[i]); destroy_worker_thread(&pool->thread_data[i]);
} }
pthread_cond_destroy(&pool->ready); slunkcrypt_bzero(pool, sizeof(thrdpl_data_t));
pthread_mutex_destroy(&pool->mutex);
free(pool); free(pool);
} }
} }

View File

@ -9,15 +9,17 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdint.h> #include <stdint.h>
#define MAX_THREADS 8U #define MIN_THREADS 1U
#define MAX_THREADS 16U
#define THRDPL_NULL ((thrdpl_t)NULL) #define THRDPL_NULL ((thrdpl_t)NULL)
typedef void (*thrdpl_worker_t)(void *arguments); typedef void (*thrdpl_worker_t)(const size_t thread_count, void *const context, const uint8_t *const input, uint8_t *const output, const size_t length);
typedef uintptr_t thrdpl_t; typedef uintptr_t thrdpl_t;
thrdpl_t thrdpl_create(const size_t count); thrdpl_t slunkcrypt_thrdpl_create(const size_t count);
void thrdpl_submit(const thrdpl_t thrdpl, const thrdpl_worker_t worker, void *const arguments); size_t slunkcrypt_thrdpl_count(const thrdpl_t thrdpl);
void thrdpl_await(const thrdpl_t thrdpl); void slunkcrypt_thrdpl_exec(const thrdpl_t thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, const uint8_t *const input, uint8_t *const output, const size_t length);
void thrdpl_destroy(const thrdpl_t thrdpl); void slunkcrypt_thrdpl_await(const thrdpl_t thrdpl);
void slunkcrypt_thrdpl_destroy(const thrdpl_t thrdpl);
#endif #endif