diff --git a/libslunkcrypt/src/slunkcrypt.c b/libslunkcrypt/src/slunkcrypt.c index 635ffa5..abd4462 100644 --- a/libslunkcrypt/src/slunkcrypt.c +++ b/libslunkcrypt/src/slunkcrypt.c @@ -262,10 +262,8 @@ static void thread_worker(const size_t thread_count, void *const context, uint8_ buffer[i] = process_next_symbol(state, buffer[i]); state->counter += (uint32_t)thread_count; random_skip(&state->random, 63U * (thread_count - 1U)); - CHECK_ABORTED(); } -aborted: update_index(state, thread_count, length); } @@ -311,7 +309,14 @@ slunkcrypt_t slunkcrypt_alloc_ext(const uint64_t nonce, const uint8_t *const pas return SLUNKCRYPT_NULL; } - state->thread_pool = slunkcrypt_thrdpl_create(param->thread_count); + if ((state->thread_pool = slunkcrypt_thrdpl_create(param->thread_count, thread_worker))) + { + size_t i; + for (i = 0U; i < slunkcrypt_thrdpl_count(state->thread_pool); ++i) + { + slunkcrypt_thrdpl_init(state->thread_pool, i, &state->data.thread_data[i]); + } + } if (initialize_state(&state->data, THREAD_COUNT(state), nonce, passwd, passwd_len, mode) == SLUNKCRYPT_SUCCESS) { @@ -368,12 +373,7 @@ int slunkcrypt_inplace(const slunkcrypt_t context, uint8_t *const buffer, size_t const size_t thread_count = THREAD_COUNT(state); if (thread_count > 1U) { - size_t i; - for (i = 0U; i < thread_count; ++i) - { - slunkcrypt_thrdpl_exec(state->thread_pool, i, thread_worker, &state->data.thread_data[i], buffer, length); - } - slunkcrypt_thrdpl_await(state->thread_pool); + slunkcrypt_thrdpl_exec(state->thread_pool, buffer, length); } else { diff --git a/libslunkcrypt/src/thread.c b/libslunkcrypt/src/thread.c index 0351101..d4d1d70 100644 --- a/libslunkcrypt/src/thread.c +++ b/libslunkcrypt/src/thread.c @@ -34,27 +34,19 @@ typedef struct { thrdpl_worker_t worker; - void *context; - size_t length; uint8_t *buffer; -} -thrdpl_task_t; - -typedef struct -{ - size_t thread_count, pending; + size_t thread_count, generation, remain, length; + int stop_flag; pthread_mutex_t mutex; - pthread_cond_t cond_pending; + pthread_cond_t cond_0, cond_1; } thrdpl_shared_t; typedef struct { thrdpl_shared_t *shared; - size_t state; - pthread_cond_t cond_state; + void *context; pthread_t thread; - thrdpl_task_t task; } thrdpl_thread_t; @@ -68,12 +60,12 @@ struct thrdpl_data_t // Utilities // ========================================================================== -static INLINE size_t bound(const size_t min, const size_t value, const size_t max) +static INLINE size_t BOUND(const size_t min, const size_t value, const size_t max) { return (value < min) ? min : ((value > max) ? max : value); } -#define PTHRD_MUTEX_LOCK(X) do \ +#define PTHRD_MUTEX_ENTER(X) do \ { \ if (pthread_mutex_lock((X)) != 0) \ { \ @@ -82,7 +74,7 @@ static INLINE size_t bound(const size_t min, const size_t value, const size_t ma } \ while(0) -#define PTHRD_MUTEX_UNLOCK(X) do \ +#define PTHRD_MUTEX_LEAVE(X) do \ { \ if (pthread_mutex_unlock((X)) != 0) \ { \ @@ -100,7 +92,7 @@ while(0) } \ while(0) -#define PTHRD_COND_BROADCAST(X) do \ +#define PTHRD_COND_BRDCST(X) do \ { \ if (pthread_cond_broadcast((X)) != 0) \ { \ @@ -118,11 +110,11 @@ while(0) } \ while(0) -#define CHECK_IF_CANCELLED() do \ +#define CHECK_IF_STOPPED() do \ { \ - if (data->state == TSTATE_EXIT) \ + if (shared->stop_flag) \ { \ - PTHRD_MUTEX_UNLOCK(&shared->mutex); \ + PTHRD_MUTEX_LEAVE(&shared->mutex); \ return NULL; \ } \ } \ @@ -137,35 +129,31 @@ static void *worker_thread_main(void *const arg) thrdpl_thread_t *const data = (thrdpl_thread_t*) arg; thrdpl_shared_t *const shared = (thrdpl_shared_t*) data->shared; - thrdpl_task_t *task; + size_t previous = 0U; + + PTHRD_MUTEX_ENTER(&shared->mutex); + CHECK_IF_STOPPED(); for (;;) { - PTHRD_MUTEX_LOCK(&shared->mutex); - CHECK_IF_CANCELLED(); - - while (data->state != TSTATE_WORK) + while (shared->generation == previous) { - PTHRD_COND_WAIT(&data->cond_state, &shared->mutex); - CHECK_IF_CANCELLED(); + PTHRD_COND_WAIT(&shared->cond_0, &shared->mutex); + CHECK_IF_STOPPED(); } - task = &data->task; - PTHRD_MUTEX_UNLOCK(&shared->mutex); + previous = shared->generation; + PTHRD_MUTEX_LEAVE(&shared->mutex); - task->worker(shared->thread_count, task->context, task->buffer, task->length); + shared->worker(shared->thread_count, data->context, shared->buffer, shared->length); - PTHRD_MUTEX_LOCK(&shared->mutex); - CHECK_IF_CANCELLED(); + PTHRD_MUTEX_ENTER(&shared->mutex); + CHECK_IF_STOPPED(); - data->state = TSTATE_IDLE; - if (!(--shared->pending)) + if (!(--shared->remain)) { - PTHRD_COND_BROADCAST(&shared->cond_pending); + PTHRD_COND_SIGNAL(&shared->cond_1); } - - PTHRD_MUTEX_UNLOCK(&shared->mutex); - PTHRD_COND_SIGNAL(&data->cond_state); } } @@ -191,52 +179,16 @@ static size_t detect_cpu_count(void) return 1U; } -// ========================================================================== -// Manage threads -// ========================================================================== - -static int create_worker(thrdpl_shared_t *const shared, thrdpl_thread_t *const thread_data) -{ - thread_data->state = TSTATE_IDLE; - thread_data->shared = shared; - - if (pthread_cond_init(&thread_data->cond_state, NULL) != 0) - { - return -1; - } - - if (pthread_create(&thread_data->thread, NULL, worker_thread_main, thread_data) != 0) - { - pthread_cond_destroy(&thread_data->cond_state); - return -1; - } - - return 0; -} - -static int destroy_worker(thrdpl_thread_t *const thread_data) -{ - PTHRD_MUTEX_LOCK(&thread_data->shared->mutex); - thread_data->state = TSTATE_EXIT; - PTHRD_MUTEX_UNLOCK(&thread_data->shared->mutex); - - PTHRD_COND_BROADCAST(&thread_data->cond_state); - pthread_join(thread_data->thread, NULL); - pthread_cond_destroy(&thread_data->cond_state); - - return 0; -} - // ========================================================================== // Thread pool API // ========================================================================== -thrdpl_t *slunkcrypt_thrdpl_create(const size_t count) +thrdpl_t *slunkcrypt_thrdpl_create(const size_t count, const thrdpl_worker_t worker) { - size_t i, j; + size_t i; thrdpl_t *thrdpl = NULL; - const size_t cpu_count = bound(1U, (count > 0U) ? count : detect_cpu_count(), MAX_THREADS); + const size_t cpu_count = BOUND(1U, (count > 0U) ? count : detect_cpu_count(), MAX_THREADS); if (cpu_count < 2U) { return NULL; @@ -248,31 +200,35 @@ thrdpl_t *slunkcrypt_thrdpl_create(const size_t count) } memset(thrdpl, 0, sizeof(thrdpl_t)); - thrdpl->shared.thread_count = cpu_count; + thrdpl->shared.worker = worker; if (pthread_mutex_init(&thrdpl->shared.mutex, NULL) != 0) { goto failure; } - if (pthread_cond_init(&thrdpl->shared.cond_pending, NULL) != 0) + if (pthread_cond_init(&thrdpl->shared.cond_0, NULL) != 0) { pthread_mutex_destroy(&thrdpl->shared.mutex); goto failure; } + if (pthread_cond_init(&thrdpl->shared.cond_1, NULL) != 0) + { + pthread_cond_destroy(&thrdpl->shared.cond_0); + pthread_mutex_destroy(&thrdpl->shared.mutex); + goto failure; + } + for (i = 0U; i < cpu_count; ++i) { - if (create_worker(&thrdpl->shared, &thrdpl->thread_data[i]) != 0) + thrdpl->thread_data[i].shared = &thrdpl->shared; + if (pthread_create(&thrdpl->thread_data[i].thread, NULL, worker_thread_main, &thrdpl->thread_data[i]) != 0) { - for (j = 0U; j < i; ++j) - { - destroy_worker(&thrdpl->thread_data[j]); - } - pthread_cond_destroy(&thrdpl->shared.cond_pending); - pthread_mutex_destroy(&thrdpl->shared.mutex); - goto failure; + slunkcrypt_thrdpl_destroy(thrdpl); + return NULL; } + ++thrdpl->shared.thread_count; } return thrdpl; @@ -287,58 +243,56 @@ size_t slunkcrypt_thrdpl_count(const thrdpl_t *const thrdpl) return thrdpl->shared.thread_count; } -void slunkcrypt_thrdpl_exec(thrdpl_t *const thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, uint8_t *const buffer, const size_t length) +void slunkcrypt_thrdpl_init(thrdpl_t *const thrdpl, const size_t index, void *const context) { - thrdpl_thread_t *const thread = &thrdpl->thread_data[index]; + thrdpl->thread_data[index].context = context; +} - PTHRD_MUTEX_LOCK(&thrdpl->shared.mutex); +void slunkcrypt_thrdpl_exec(thrdpl_t *const thrdpl, uint8_t *const buffer, const size_t length) +{ + PTHRD_MUTEX_ENTER(&thrdpl->shared.mutex); - while ((thread->state != TSTATE_IDLE) && (thread->state != TSTATE_EXIT)) - { - PTHRD_COND_WAIT(&thread->cond_state, &thrdpl->shared.mutex); - } - - if (thread->state == TSTATE_EXIT) + if (thrdpl->shared.stop_flag || (thrdpl->shared.remain != 0U)) { abort(); /*this is not supposed to happen!*/ } - thread->state = TSTATE_WORK; - thread->task.worker = worker; - thread->task.context = context; - thread->task.buffer = buffer; - thread->task.length = length; + thrdpl->shared.buffer = buffer; + thrdpl->shared.length = length; + thrdpl->shared.remain = thrdpl->shared.thread_count; + + ++thrdpl->shared.generation; + PTHRD_COND_BRDCST(&thrdpl->shared.cond_0); - ++thrdpl->shared.pending; - - PTHRD_MUTEX_UNLOCK(&thrdpl->shared.mutex); - PTHRD_COND_SIGNAL(&thread->cond_state); -} - -void slunkcrypt_thrdpl_await(thrdpl_t *const thrdpl) -{ - PTHRD_MUTEX_LOCK(&thrdpl->shared.mutex); - - while (thrdpl->shared.pending) + while (thrdpl->shared.remain) { - PTHRD_COND_WAIT(&thrdpl->shared.cond_pending, &thrdpl->shared.mutex); + PTHRD_COND_WAIT(&thrdpl->shared.cond_1, &thrdpl->shared.mutex); } - PTHRD_MUTEX_UNLOCK(&thrdpl->shared.mutex); + PTHRD_MUTEX_LEAVE(&thrdpl->shared.mutex); } void slunkcrypt_thrdpl_destroy(thrdpl_t *const thrdpl) { size_t i; + PTHRD_MUTEX_ENTER(&thrdpl->shared.mutex); - if (thrdpl) + if (!thrdpl->shared.stop_flag) { - for (i = 0U; i < thrdpl->shared.thread_count; ++i) - { - destroy_worker(&thrdpl->thread_data[i]); - } - pthread_cond_destroy(&thrdpl->shared.cond_pending); - pthread_mutex_destroy(&thrdpl->shared.mutex); - free(thrdpl); + thrdpl->shared.stop_flag = 1; + PTHRD_COND_BRDCST(&thrdpl->shared.cond_0); } + + PTHRD_MUTEX_LEAVE(&thrdpl->shared.mutex); + + for (i = 0U; i < thrdpl->shared.thread_count; ++i) + { + pthread_join(thrdpl->thread_data[i].thread, NULL); + } + + pthread_cond_destroy(&thrdpl->shared.cond_0); + pthread_cond_destroy(&thrdpl->shared.cond_1); + pthread_mutex_destroy(&thrdpl->shared.mutex); + + free(thrdpl); } diff --git a/libslunkcrypt/src/thread.h b/libslunkcrypt/src/thread.h index b21c3a5..a2ae84c 100644 --- a/libslunkcrypt/src/thread.h +++ b/libslunkcrypt/src/thread.h @@ -14,10 +14,10 @@ typedef void (*thrdpl_worker_t)(const size_t thread_count, void *const context, uint8_t *const buffer, const size_t length); typedef struct thrdpl_data_t thrdpl_t; -thrdpl_t *slunkcrypt_thrdpl_create(const size_t count); +thrdpl_t *slunkcrypt_thrdpl_create(const size_t count, const thrdpl_worker_t worker); size_t slunkcrypt_thrdpl_count(const thrdpl_t *const thrdpl); -void slunkcrypt_thrdpl_exec(thrdpl_t *const thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, uint8_t *const buffer, const size_t length); -void slunkcrypt_thrdpl_await(thrdpl_t *const thrdpl); +void slunkcrypt_thrdpl_init(thrdpl_t *const thrdpl, const size_t index, void *const context); +void slunkcrypt_thrdpl_exec(thrdpl_t *const thrdpl, uint8_t *const buffer, const size_t length); void slunkcrypt_thrdpl_destroy(thrdpl_t *const thrdpl); #endif