diff --git a/frontend/src/crypt.c b/frontend/src/crypt.c index 3d89b78..063bab0 100644 --- a/frontend/src/crypt.c +++ b/frontend/src/crypt.c @@ -28,7 +28,7 @@ static const uint64_t MAGIC_NUMBER = 0x243F6A8885A308D3ull; -#define BUFFER_SIZE 4096U +#define BUFFER_SIZE 32768U // ========================================================================== // Auxiliary functions @@ -138,7 +138,7 @@ int encrypt(const char *const passphrase, const CHR *const input_path, const CHR { break; /*EOF*/ } - if (!(++refresh_cycles & 0x7)) + if (!(++refresh_cycles & 0x3)) { const uint64_t clk_now = clock_read(); if ((clk_now < clk_update) || (clk_now - clk_update > update_interval)) @@ -322,7 +322,7 @@ int decrypt(const char *const passphrase, const CHR *const input_path, const CHR { break; /*EOF*/ } - if (!(++refresh_cycles & 0x7)) + if (!(++refresh_cycles & 0x3)) { const uint64_t clk_now = clock_read(); if ((clk_now < clk_update) || (clk_now - clk_update > update_interval)) diff --git a/gui/Properties/AssemblyInfo.cs b/gui/Properties/AssemblyInfo.cs index c8c31b5..51a1876 100644 --- a/gui/Properties/AssemblyInfo.cs +++ b/gui/Properties/AssemblyInfo.cs @@ -19,5 +19,5 @@ using System.Windows; [assembly: ComVisible(false)] [assembly: ThemeInfo(ResourceDictionaryLocation.None, ResourceDictionaryLocation.SourceAssembly)] -[assembly: AssemblyVersion("1.1.*")] -[assembly: AssemblyFileVersion("1.1.0.0")] +[assembly: AssemblyVersion("1.2.*")] +[assembly: AssemblyFileVersion("1.2.0.0")] diff --git a/libslunkcrypt/include/slunkcrypt.h b/libslunkcrypt/include/slunkcrypt.h index 1831f91..135ceca 100644 --- a/libslunkcrypt/include/slunkcrypt.h +++ b/libslunkcrypt/include/slunkcrypt.h @@ -75,6 +75,17 @@ static const int SLUNKCRYPT_ABORTED = -2; static const size_t SLUNKCRYPT_PWDLEN_MIN = 8U; static const size_t SLUNKCRYPT_PWDLEN_MAX = 256U; +/* + * Optional parameters + */ +static const uint16_t SLUNKCRYPT_PARAM_VERSION = 1U; +typedef struct +{ + uint16_t version; /* Must set to SLUNKCRYPT_PARAM_VERSION */ + size_t thread_count; /* Number of threads, set to 0 for auto-detection */ +} +slunk_param_t; + /* * Version info */ @@ -101,6 +112,7 @@ SLUNKCRYPT_API int slunkcrypt_generate_nonce(uint64_t *const nonce); * Allocate, reset or free state */ SLUNKCRYPT_API slunkcrypt_t slunkcrypt_alloc(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode); +SLUNKCRYPT_API slunkcrypt_t slunkcrypt_alloc_ext(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode, const slunk_param_t *const param); SLUNKCRYPT_API int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode); SLUNKCRYPT_API void slunkcrypt_free(const slunkcrypt_t context); diff --git a/libslunkcrypt/src/slunkcrypt.c b/libslunkcrypt/src/slunkcrypt.c index 70cb984..edb685d 100644 --- a/libslunkcrypt/src/slunkcrypt.c +++ b/libslunkcrypt/src/slunkcrypt.c @@ -38,21 +38,24 @@ typedef struct int reverse_mode; const uint8_t (*wheel)[256U]; uint32_t counter; + size_t index_off; rand_state_t random; - uint8_t *data; } thread_state_t; typedef struct { uint8_t wheel[256U][256U]; - thrdpl_t thread_pool; - size_t thread_idx; thread_state_t thread_data[MAX_THREADS]; } -crypt_state_t; +crypt_data_t; -#define THREAD_COUNT 1U +typedef struct +{ + thrdpl_t thread_pool; + crypt_data_t data; +} +crypt_state_t; // ========================================================================== // Abort flag @@ -109,9 +112,17 @@ static INLINE uint32_t random_next(rand_state_t *const state) return (state->d += 0x000587C5) + state->v; } -static INLINE void random_seed(rand_state_t *const state, uint64_t salt, const uint16_t pepper, const uint8_t *const passwd, const size_t passwd_len) +static INLINE void random_skip(rand_state_t *const state, const size_t skip_count) { size_t i; + for (i = 0U; i < skip_count; ++i) + { + /* UNUSED volatile uint32_t q = */ random_next(state); + } +} + +static INLINE void random_seed(rand_state_t *const state, uint64_t salt, const uint16_t pepper, const uint8_t *const passwd, const size_t passwd_len) +{ keydata_t key; do { @@ -120,52 +131,38 @@ static INLINE void random_seed(rand_state_t *const state, uint64_t salt, const u slunkcrypt_bzero(&key, sizeof(keydata_t)); } while (!(state->x || state->y || state->z || state->w || state->v)); - for (i = 0U; i < 97U; ++i) - { - UNUSED volatile uint32_t q = random_next(state); - } + random_skip(state, 97U); } // ========================================================================== // Initialization // ========================================================================== -static int initialize_state(crypt_state_t *const state, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode, const int reset) +static int initialize_state(crypt_data_t *const data, const size_t thread_count, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode) { uint8_t temp[256U][256U]; size_t r, i; - rand_state_t random; - uint32_t counter; const int reverse_mode = BOOLIFY(mode); - /* backup previous value */ - const thrdpl_t thread_pool = reset ? state->thread_pool : THRDPL_NULL; - /* initialize state */ - slunkcrypt_bzero(state, sizeof(crypt_state_t)); - - /* create thread-pool */ - if ((state->thread_pool = reset ? thread_pool : thrdpl_create(THREAD_COUNT)) == THRDPL_NULL) - { - return SLUNKCRYPT_FAILURE; - } + slunkcrypt_bzero(data, sizeof(crypt_data_t)); /* initialize counter */ - random_seed(&random, nonce, (uint16_t)(-1), passwd, passwd_len); - counter = random_next(&random); + random_seed(&data->thread_data[0].random, nonce, (uint16_t)(-1), passwd, passwd_len); + data->thread_data[0].counter = random_next(&data->thread_data[0].random); /* set up the wheel permutations */ for (r = 0U; r < 256U; ++r) { - random_seed(&random, nonce, (uint16_t)r, passwd, passwd_len); + random_seed(&data->thread_data[0].random, nonce, (uint16_t)r, passwd, passwd_len); for (i = 0U; i < 256U; ++i) { - const size_t j = random_next(&random) % (i + 1U); + const size_t j = random_next(&data->thread_data[0].random) % (i + 1U); if (j != i) { - state->wheel[r][i] = state->wheel[r][j]; + data->wheel[r][i] = data->wheel[r][j]; } - state->wheel[r][j] = (uint8_t)i; + data->wheel[r][j] = (uint8_t)i; } CHECK_ABORTED(); } @@ -177,43 +174,38 @@ static int initialize_state(crypt_state_t *const state, const uint64_t nonce, co { for (i = 0U; i < 256U; ++i) { - temp[r][state->wheel[r][i]] = (uint8_t)i; + temp[r][data->wheel[r][i]] = (uint8_t)i; } } for (r = 0U; r < 256U; ++r) { - memcpy(state->wheel[255U - r], temp[r], 256U); + memcpy(data->wheel[255U - r], temp[r], 256U); } slunkcrypt_bzero(temp, sizeof(temp)); CHECK_ABORTED(); } - /* set up thread state */ - random_seed(&random, nonce, 256U, passwd, passwd_len); - for (i = 0U; i < THREAD_COUNT; ++i) + /* initialize up thread state */ + data->thread_data[0].reverse_mode = reverse_mode; + data->thread_data[0].wheel = data->wheel; + data->thread_data[0].index_off = 0U; + random_seed(&data->thread_data[0].random, nonce, 256U, passwd, passwd_len); + for (i = 1U; i < thread_count; ++i) { - state->thread_data[i].reverse_mode = reverse_mode; - state->thread_data[i].wheel = state->wheel; - state->thread_data[i].counter = counter + ((uint32_t)i); - memcpy(&state->thread_data[i].random, &random, sizeof(rand_state_t)); - for (r = 0U; r < i * 63U; ++r) - { - random_next(&state->thread_data[i].random); - } + data->thread_data[i].reverse_mode = data->thread_data[0].reverse_mode; + data->thread_data[i].wheel = data->thread_data[0].wheel; + data->thread_data[i].counter = data->thread_data[0].counter + ((uint32_t)i); + data->thread_data[i].index_off = data->thread_data[i - 1U].index_off + 1U; + memcpy(&data->thread_data[i].random, &data->thread_data[0].random, sizeof(rand_state_t)); + random_skip(&data->thread_data[i].random, i * 63U); CHECK_ABORTED(); } - slunkcrypt_bzero(&counter, sizeof(uint32_t)); - slunkcrypt_bzero(&random, sizeof(rand_state_t)); - return SLUNKCRYPT_SUCCESS; /* aborted */ aborted: - thrdpl_destroy(state->thread_pool); - slunkcrypt_bzero(state, sizeof(crypt_state_t)); - slunkcrypt_bzero(&counter, sizeof(uint32_t)); - slunkcrypt_bzero(&random, sizeof(rand_state_t)); + slunkcrypt_bzero(data, sizeof(crypt_data_t)); return SLUNKCRYPT_ABORTED; } @@ -232,22 +224,44 @@ static INLINE void update_offset(uint8_t *const offset, uint32_t seed, rand_stat } offset[reverse ? (255U - i) : i] = (uint8_t)seed; } - for (i = 0U; i < 63U * (THREAD_COUNT - 1U); ++i) - { - random_next(state); - } } -static INLINE void process_next_symbol(thread_state_t *const state) +static INLINE uint8_t process_next_symbol(thread_state_t *const state, uint8_t value) { uint8_t offset[256U]; size_t i; update_offset(offset, state->counter, &state->random, state->reverse_mode); for (i = 0U; i < 256U; ++i) { - *state->data = (state->wheel[i][(*state->data + offset[i]) & 0xFF] - offset[i]) & 0xFF; + value = (state->wheel[i][(value + offset[i]) & 0xFF] - offset[i]) & 0xFF; } - state->counter += THREAD_COUNT; + return value; +} + +// ========================================================================== +// Thread entry point +// ========================================================================== + +static INLINE void update_index(thread_state_t *const state, const size_t thread_count, const size_t length) +{ + const size_t remaining = thread_count - (length % thread_count); + if (remaining != thread_count) + { + state->index_off = (state->index_off + remaining) % thread_count; + } +} + +static void thread_worker(const size_t thread_count, void *const context, const uint8_t *const input, uint8_t *const output, const size_t length) +{ + thread_state_t *const state = (thread_state_t*) context; + size_t i; + for (i = state->index_off; i < length; i += thread_count) + { + output[i] = process_next_symbol(state, input[i]); + state->counter += (uint32_t)thread_count; + random_skip(&state->random, 63U * (thread_count - 1U)); + } + update_index(state, thread_count, length); } // ========================================================================== @@ -272,22 +286,39 @@ int slunkcrypt_generate_nonce(uint64_t *const nonce) } slunkcrypt_t slunkcrypt_alloc(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode) +{ + slunk_param_t param = { SLUNKCRYPT_PARAM_VERSION, 0U }; + return slunkcrypt_alloc_ext(nonce, passwd, passwd_len, mode, ¶m); +} + +slunkcrypt_t slunkcrypt_alloc_ext(const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode, const slunk_param_t *const param) { crypt_state_t* state = NULL; - if ((!passwd) || (passwd_len < SLUNKCRYPT_PWDLEN_MIN) || (passwd_len > SLUNKCRYPT_PWDLEN_MAX) || (mode < SLUNKCRYPT_ENCRYPT) || (mode > SLUNKCRYPT_DECRYPT)) + + if ((!passwd) || (passwd_len < SLUNKCRYPT_PWDLEN_MIN) || (passwd_len > SLUNKCRYPT_PWDLEN_MAX) || + (mode < SLUNKCRYPT_ENCRYPT) || (mode > SLUNKCRYPT_DECRYPT) || (!param) || (param->version == 0U) || (param->version > SLUNKCRYPT_PARAM_VERSION)) { return SLUNKCRYPT_NULL; } + if (!(state = (crypt_state_t*)malloc(sizeof(crypt_state_t)))) { return SLUNKCRYPT_NULL; } - if (initialize_state(state, nonce, passwd, passwd_len, mode, 0) == SLUNKCRYPT_SUCCESS) + + if ((state->thread_pool = slunkcrypt_thrdpl_create(param->thread_count)) == THRDPL_NULL) + { + free(state); + return SLUNKCRYPT_NULL; + } + + if (initialize_state(&state->data, slunkcrypt_thrdpl_count(state->thread_pool), nonce, passwd, passwd_len, mode) == SLUNKCRYPT_SUCCESS) { return ((slunkcrypt_t)state); } else { + slunkcrypt_thrdpl_destroy(state->thread_pool); slunkcrypt_bzero(state, sizeof(crypt_state_t)); return SLUNKCRYPT_NULL; } @@ -295,15 +326,16 @@ slunkcrypt_t slunkcrypt_alloc(const uint64_t nonce, const uint8_t *const passwd, int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t nonce, const uint8_t *const passwd, const size_t passwd_len, const int mode) { - crypt_state_t *const state = (crypt_state_t*)context; + crypt_state_t *const state = (crypt_state_t*) context; int result = SLUNKCRYPT_FAILURE; + if ((!state) || (!passwd) || (passwd_len < SLUNKCRYPT_PWDLEN_MIN) || (passwd_len > SLUNKCRYPT_PWDLEN_MAX) || (mode < SLUNKCRYPT_ENCRYPT) || (mode > SLUNKCRYPT_DECRYPT)) { return SLUNKCRYPT_FAILURE; } - if ((result = initialize_state(state, nonce, passwd, passwd_len, mode, 1)) != SLUNKCRYPT_SUCCESS) + if ((result = initialize_state(&state->data, slunkcrypt_thrdpl_count(state->thread_pool), nonce, passwd, passwd_len, mode)) != SLUNKCRYPT_SUCCESS) { - slunkcrypt_bzero(state, sizeof(crypt_state_t)); + slunkcrypt_bzero(&state->data, sizeof(crypt_data_t)); } return result; } @@ -311,7 +343,7 @@ int slunkcrypt_reset(const slunkcrypt_t context, const uint64_t nonce, const uin int slunkcrypt_process(const slunkcrypt_t context, const uint8_t *const input, uint8_t *const output, size_t length) { size_t i; - crypt_state_t *const state = (crypt_state_t*)context; + crypt_state_t *const state = (crypt_state_t*) context; if (!state) { return SLUNKCRYPT_FAILURE; @@ -319,27 +351,27 @@ int slunkcrypt_process(const slunkcrypt_t context, const uint8_t *const input, u if (length > 0U) { - memcpy(output, input, length * sizeof(uint8_t)); - for (i = 0; i < length; ++i) + const size_t thread_count = slunkcrypt_thrdpl_count(state->thread_pool); + for (i = 0; i < thread_count; ++i) { - abort(); //process_next_symbol(state, output + i); - CHECK_ABORTED(); + slunkcrypt_thrdpl_exec(state->thread_pool, i, thread_worker, &state->data.thread_data[i], input, output, length); } + slunkcrypt_thrdpl_await(state->thread_pool); } - thrdpl_await(state->thread_pool); + CHECK_ABORTED(); return SLUNKCRYPT_SUCCESS; aborted: - thrdpl_await(state->thread_pool); - slunkcrypt_bzero(state, sizeof(crypt_state_t)); + slunkcrypt_thrdpl_await(state->thread_pool); + slunkcrypt_bzero(&state->data, sizeof(crypt_data_t)); return SLUNKCRYPT_ABORTED; } int slunkcrypt_inplace(const slunkcrypt_t context, uint8_t *const buffer, size_t length) { size_t i; - crypt_state_t *const state = (crypt_state_t*)context; + crypt_state_t *const state = (crypt_state_t*) context; if (!state) { return SLUNKCRYPT_FAILURE; @@ -347,34 +379,29 @@ int slunkcrypt_inplace(const slunkcrypt_t context, uint8_t *const buffer, size_t if (length > 0U) { - for (i = 0; i < length; ++i) + const size_t thread_count = slunkcrypt_thrdpl_count(state->thread_pool); + for (i = 0; i < thread_count; ++i) { - state->thread_data[state->thread_idx].data = buffer + i; - //process_next_symbol(&state->thread_data[state->thread_idx]); - thrdpl_submit(state->thread_pool, process_next_symbol, &state->thread_data[state->thread_idx]); - if (++state->thread_idx >= THREAD_COUNT) - { - state->thread_idx = 0U; - } - CHECK_ABORTED(); + slunkcrypt_thrdpl_exec(state->thread_pool, i, thread_worker, &state->data.thread_data[i], buffer, buffer, length); } + slunkcrypt_thrdpl_await(state->thread_pool); } - thrdpl_await(state->thread_pool); + CHECK_ABORTED(); return SLUNKCRYPT_SUCCESS; aborted: - thrdpl_await(state->thread_pool); - slunkcrypt_bzero(state, sizeof(crypt_state_t)); + slunkcrypt_thrdpl_await(state->thread_pool); + slunkcrypt_bzero(&state->data, sizeof(crypt_data_t)); return SLUNKCRYPT_ABORTED; } void slunkcrypt_free(const slunkcrypt_t context) { - crypt_state_t *const state = (crypt_state_t*)context; + crypt_state_t *const state = (crypt_state_t*) context; if (state) { - thrdpl_destroy(state->thread_pool); + slunkcrypt_thrdpl_destroy(state->thread_pool); slunkcrypt_bzero(state, sizeof(crypt_state_t)); free(state); } diff --git a/libslunkcrypt/src/thread.c b/libslunkcrypt/src/thread.c index 054341e..620cbe0 100644 --- a/libslunkcrypt/src/thread.c +++ b/libslunkcrypt/src/thread.c @@ -3,7 +3,7 @@ /* This work has been released under the CC0 1.0 Universal license! */ /******************************************************************************/ -#ifdef _MSC_VER +#if defined(_MSC_VER) && !defined(_DLL) #define PTW32_STATIC_LIB 1 #endif @@ -16,22 +16,35 @@ /* PThread */ #include -#include +#ifdef __unix__ +#include +#endif + +/* States */ +#define THRD_STATE_IDLE 0 +#define THRD_STATE_WORK 1 +#define THRD_STATE_EXIT 2 + +// ========================================================================== +// Data types +// ========================================================================== typedef struct { thrdpl_worker_t worker; - void *args; + void *context; + size_t length; + const uint8_t *input; + uint8_t *output; } thrdpl_task_t; typedef struct { - size_t index; - pthread_mutex_t *mutex; - pthread_cond_t *ready; - size_t *queue; - sem_t sem_free, sem_used; + const size_t *count; + int state; + pthread_mutex_t mutex; + pthread_cond_t cond; pthread_t thread; thrdpl_task_t task; } @@ -39,65 +52,101 @@ thrdpl_thread_t; typedef struct { - size_t thread_count, index, queue; - pthread_mutex_t mutex; - pthread_cond_t ready; - thrdpl_thread_t threads[MAX_THREADS]; + size_t thread_count; + thrdpl_thread_t thread_data[MAX_THREADS]; } thrdpl_data_t; +// ========================================================================== +// Helper macros +// ========================================================================== + +#define BOUND(MIN, VAL, MAX) \ +{ \ + if ((VAL) > (MAX)) { VAL = (MAX); } \ + if ((VAL) < (MIN)) { VAL = (MIN); } \ +} \ +while (0) + +#define PTHRD_MUTEX_LOCK(X) do \ +{ \ + if (pthread_mutex_lock((X)) != 0) \ + { \ + abort(); \ + } \ +} \ +while(0) + +#define PTHRD_MUTEX_UNLOCK(X) do \ +{ \ + if (pthread_mutex_unlock((X)) != 0) \ + { \ + abort(); \ + } \ +} \ +while(0) + +#define PTHRD_COND_BROADCAST(X) do \ +{ \ + if (pthread_cond_broadcast((X)) != 0) \ + { \ + abort(); \ + } \ +} \ +while(0) + +#define PTHRD_COND_WAIT(X,Y) do \ +{ \ + if (pthread_cond_wait((X), (Y)) != 0) \ + { \ + abort(); \ + } \ +} \ +while(0) + +#define CHECK_IF_CANCELLED() do \ +{ \ + if (data->state == THRD_STATE_EXIT) \ + { \ + if (pthread_mutex_unlock(&data->mutex) != 0) \ + { \ + abort(); \ + } \ + return NULL; /* cancelled */ \ + } \ +} \ +while(0) + // ========================================================================== // Thread main // ========================================================================== -static void *thread_main(void *const arg) +static void *worker_thread_main(void *const arg) { - thrdpl_thread_t *const data = (thrdpl_thread_t*)arg; + thrdpl_thread_t *const data = (thrdpl_thread_t*) arg; + thrdpl_task_t *task; for (;;) { - if (sem_wait(&data->sem_used) != 0) + PTHRD_MUTEX_LOCK(&data->mutex); + CHECK_IF_CANCELLED(); + + while (data->state != THRD_STATE_WORK) { - abort(); + PTHRD_COND_WAIT(&data->cond, &data->mutex); + CHECK_IF_CANCELLED(); } - if (pthread_mutex_lock(data->mutex) != 0) - { - abort(); - } + task = &data->task; + PTHRD_MUTEX_UNLOCK(&data->mutex); - const thrdpl_worker_t worker = data->task.worker; - void *const args = data->task.args; + task->worker(*data->count, task->context, task->input, task->output, task->length); - if (pthread_mutex_unlock(data->mutex) != 0) - { - abort(); - } - - worker(args); - - if (pthread_mutex_lock(data->mutex) != 0) - { - abort(); - } - - if (!(*data->queue -= 1U)) - { - if (pthread_cond_broadcast(data->ready) != 0) - { - abort(); - } - } - - if (pthread_mutex_unlock(data->mutex) != 0) - { - abort(); - } - - if (sem_post(&data->sem_free) != 0) - { - abort(); - } + PTHRD_MUTEX_LOCK(&data->mutex); + CHECK_IF_CANCELLED(); + data->state = THRD_STATE_IDLE; + PTHRD_COND_BROADCAST(&data->cond); + PTHRD_MUTEX_UNLOCK(&data->mutex); } } @@ -105,40 +154,61 @@ static void *thread_main(void *const arg) // Manage threads // ========================================================================== -static int create_thread(const size_t index, thrdpl_thread_t *const thread_data, pthread_mutex_t *const mutex, pthread_cond_t *const ready, size_t *const queue) +#if defined(__unix__) +# define GET_NPROCS_FUNCTION() get_nprocs() +#elif defined(PTW32_VERSION) +# define GET_NPROCS_FUNCTION() pthread_num_processors_np() +#endif + +static size_t detect_cpu_count(void) { - thread_data->index = index; - thread_data->mutex = mutex; - thread_data->ready = ready; - thread_data->queue = queue; +#ifdef GET_NPROCS_FUNCTION + const int cpu_count = GET_NPROCS_FUNCTION(); + if (cpu_count > 0) + { + return (size_t) cpu_count; + } +#endif + return 1U; +} - if (sem_init(&thread_data->sem_free, 0, 1U) != 0) +static int create_worker_thread(thrdpl_thread_t *const thread_data, const size_t *const count) +{ + thread_data->count = count; + thread_data->state = THRD_STATE_IDLE; + + if (pthread_mutex_init(&thread_data->mutex, NULL) != 0) { return -1; } - if (sem_init(&thread_data->sem_used, 0, 0U) != 0) + if (pthread_cond_init(&thread_data->cond, NULL) != 0) { - sem_destroy(&thread_data->sem_free); + pthread_mutex_destroy(&thread_data->mutex); return -1; } - if (pthread_create(&thread_data->thread, NULL, thread_main, thread_data) != 0) + if (pthread_create(&thread_data->thread, NULL, worker_thread_main, thread_data) != 0) { - sem_destroy(&thread_data->sem_used); - sem_destroy(&thread_data->sem_free); + pthread_cond_destroy(&thread_data->cond); + pthread_mutex_destroy(&thread_data->mutex); return -1; } return 0; } -static int destroy_thread(thrdpl_thread_t *const thread_data) +static int destroy_worker_thread(thrdpl_thread_t *const thread_data) { - pthread_cancel(thread_data->thread); + PTHRD_MUTEX_LOCK(&thread_data->mutex); + thread_data->state = THRD_STATE_EXIT; + PTHRD_COND_BROADCAST(&thread_data->cond); + PTHRD_MUTEX_UNLOCK(&thread_data->mutex); + pthread_join(thread_data->thread, NULL); - sem_destroy(&thread_data->sem_used); - sem_destroy(&thread_data->sem_free); + pthread_mutex_destroy(&thread_data->mutex); + pthread_cond_destroy(&thread_data->cond); + return 0; } @@ -146,45 +216,28 @@ static int destroy_thread(thrdpl_thread_t *const thread_data) // Thread pool API // ========================================================================== -thrdpl_t thrdpl_create(const size_t count) +thrdpl_t slunkcrypt_thrdpl_create(const size_t count) { size_t i, j; thrdpl_data_t *pool = NULL; - if ((count < 1U) || (count > MAX_THREADS)) - { - return THRDPL_NULL; - } - if (!(pool = (thrdpl_data_t*) malloc(sizeof(thrdpl_data_t)))) { return THRDPL_NULL; } slunkcrypt_bzero(pool, sizeof(thrdpl_data_t)); - pool->thread_count = count; + pool->thread_count = (count > 0U) ? count : detect_cpu_count(); + BOUND(MIN_THREADS, pool->thread_count, MAX_THREADS); - if (pthread_mutex_init(&pool->mutex, NULL) != 0) + for (i = 0U; i < pool->thread_count; ++i) { - goto failure; - } - - if (pthread_cond_init(&pool->ready, NULL) != 0) - { - pthread_mutex_destroy(&pool->mutex); - goto failure; - } - - for (i = 0U; i < count; ++i) - { - if (create_thread(i, &pool->threads[i], &pool->mutex, &pool->ready, &pool->queue) != 0) + if (create_worker_thread(&pool->thread_data[i], &pool->thread_count) != 0) { for (j = 0U; j < i; ++j) { - destroy_thread(&pool->threads[j]); + destroy_worker_thread(&pool->thread_data[j]); } - pthread_cond_destroy(&pool->ready); - pthread_mutex_destroy(&pool->mutex); goto failure; } } @@ -196,78 +249,79 @@ failure: return (thrdpl_t)NULL; } -void thrdpl_submit(const thrdpl_t thrdpl, const thrdpl_worker_t worker, void *const args) +size_t slunkcrypt_thrdpl_count(const thrdpl_t thrdpl) { - thrdpl_data_t *const pool = (thrdpl_data_t*)thrdpl; - - if (pthread_mutex_lock(&pool->mutex) != 0) - { - abort(); - } - - thrdpl_thread_t *const thread = &pool->threads[pool->index]; - - if (++pool->index >= pool->thread_count) - { - pool->index = 0U; - } - - ++pool->queue; - - if (pthread_mutex_unlock(&pool->mutex) != 0) - { - abort(); - } - - if (sem_wait(&thread->sem_free) != 0) - { - abort(); - } - - thread->task.worker = worker; - thread->task.args = args; - - if (sem_post(&thread->sem_used) != 0) - { - abort(); - } + thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl; + return pool->thread_count; } -void thrdpl_await(const thrdpl_t thrdpl) +void slunkcrypt_thrdpl_exec(const thrdpl_t thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, const uint8_t *const input, uint8_t *const output, const size_t length) { - thrdpl_data_t *const pool = (thrdpl_data_t*)thrdpl; + thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl; + thrdpl_thread_t *const thread = &pool->thread_data[index]; - if (pthread_mutex_lock(&pool->mutex) != 0) - { - abort(); - } + PTHRD_MUTEX_LOCK(&thread->mutex); - while (pool->queue) + while ((thread->state != THRD_STATE_IDLE) && (thread->state != THRD_STATE_EXIT)) { - if (pthread_cond_wait(&pool->ready, &pool->mutex) != 0) + if (pthread_cond_wait(&thread->cond, &thread->mutex) != 0) { abort(); } } - if (pthread_mutex_unlock(&pool->mutex) != 0) + if (thread->state == THRD_STATE_EXIT) { - abort(); + abort(); /*this is not supposed to happen!*/ + } + + thread->task.worker = worker; + thread->task.context = context; + thread->task.input = input; + thread->task.output = output; + thread->task.length = length; + thread->state = THRD_STATE_WORK; + + PTHRD_COND_BROADCAST(&thread->cond); + PTHRD_MUTEX_UNLOCK(&thread->mutex); +} + +void slunkcrypt_thrdpl_await(const thrdpl_t thrdpl) +{ + size_t i; + thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl; + + for (i = 0; i < pool->thread_count; ++i) + { + if (pthread_mutex_lock(&pool->thread_data[i].mutex) != 0) + { + abort(); + } + while ((pool->thread_data[i].state != THRD_STATE_IDLE) && (pool->thread_data[i].state != THRD_STATE_EXIT)) + { + if (pthread_cond_wait(&pool->thread_data[i].cond, &pool->thread_data[i].mutex) != 0) + { + abort(); + } + } + if (pthread_mutex_unlock(&pool->thread_data[i].mutex) != 0) + { + abort(); + } } } -void thrdpl_destroy(const thrdpl_t thrdpl) +void slunkcrypt_thrdpl_destroy(const thrdpl_t thrdpl) { size_t i; - thrdpl_data_t *const pool = (thrdpl_data_t*)thrdpl; + thrdpl_data_t *const pool = (thrdpl_data_t*) thrdpl; if (pool) { for (i = 0U; i < pool->thread_count; ++i) { - destroy_thread(&pool->threads[i]); + destroy_worker_thread(&pool->thread_data[i]); } - pthread_cond_destroy(&pool->ready); - pthread_mutex_destroy(&pool->mutex); + slunkcrypt_bzero(pool, sizeof(thrdpl_data_t)); free(pool); } } diff --git a/libslunkcrypt/src/thread.h b/libslunkcrypt/src/thread.h index 1dcef2d..c065aa6 100644 --- a/libslunkcrypt/src/thread.h +++ b/libslunkcrypt/src/thread.h @@ -9,15 +9,17 @@ #include #include -#define MAX_THREADS 8U +#define MIN_THREADS 1U +#define MAX_THREADS 16U #define THRDPL_NULL ((thrdpl_t)NULL) -typedef void (*thrdpl_worker_t)(void *arguments); +typedef void (*thrdpl_worker_t)(const size_t thread_count, void *const context, const uint8_t *const input, uint8_t *const output, const size_t length); typedef uintptr_t thrdpl_t; -thrdpl_t thrdpl_create(const size_t count); -void thrdpl_submit(const thrdpl_t thrdpl, const thrdpl_worker_t worker, void *const arguments); -void thrdpl_await(const thrdpl_t thrdpl); -void thrdpl_destroy(const thrdpl_t thrdpl); +thrdpl_t slunkcrypt_thrdpl_create(const size_t count); +size_t slunkcrypt_thrdpl_count(const thrdpl_t thrdpl); +void slunkcrypt_thrdpl_exec(const thrdpl_t thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, const uint8_t *const input, uint8_t *const output, const size_t length); +void slunkcrypt_thrdpl_await(const thrdpl_t thrdpl); +void slunkcrypt_thrdpl_destroy(const thrdpl_t thrdpl); -#endif \ No newline at end of file +#endif