Slightly improved thread management code.

This commit is contained in:
LoRd_MuldeR 2022-03-25 00:42:44 +01:00
parent f70ccb6a14
commit 342562cf2e
Signed by: mulder
GPG Key ID: 2B5913365F57E03F
3 changed files with 87 additions and 133 deletions

View File

@ -262,10 +262,8 @@ static void thread_worker(const size_t thread_count, void *const context, uint8_
buffer[i] = process_next_symbol(state, buffer[i]);
state->counter += (uint32_t)thread_count;
random_skip(&state->random, 63U * (thread_count - 1U));
CHECK_ABORTED();
}
aborted:
update_index(state, thread_count, length);
}
@ -311,7 +309,14 @@ slunkcrypt_t slunkcrypt_alloc_ext(const uint64_t nonce, const uint8_t *const pas
return SLUNKCRYPT_NULL;
}
state->thread_pool = slunkcrypt_thrdpl_create(param->thread_count);
if ((state->thread_pool = slunkcrypt_thrdpl_create(param->thread_count, thread_worker)))
{
size_t i;
for (i = 0U; i < slunkcrypt_thrdpl_count(state->thread_pool); ++i)
{
slunkcrypt_thrdpl_init(state->thread_pool, i, &state->data.thread_data[i]);
}
}
if (initialize_state(&state->data, THREAD_COUNT(state), nonce, passwd, passwd_len, mode) == SLUNKCRYPT_SUCCESS)
{
@ -368,12 +373,7 @@ int slunkcrypt_inplace(const slunkcrypt_t context, uint8_t *const buffer, size_t
const size_t thread_count = THREAD_COUNT(state);
if (thread_count > 1U)
{
size_t i;
for (i = 0U; i < thread_count; ++i)
{
slunkcrypt_thrdpl_exec(state->thread_pool, i, thread_worker, &state->data.thread_data[i], buffer, length);
}
slunkcrypt_thrdpl_await(state->thread_pool);
slunkcrypt_thrdpl_exec(state->thread_pool, buffer, length);
}
else
{

View File

@ -34,27 +34,19 @@
typedef struct
{
thrdpl_worker_t worker;
void *context;
size_t length;
uint8_t *buffer;
}
thrdpl_task_t;
typedef struct
{
size_t thread_count, pending;
size_t thread_count, generation, remain, length;
int stop_flag;
pthread_mutex_t mutex;
pthread_cond_t cond_pending;
pthread_cond_t cond_0, cond_1;
}
thrdpl_shared_t;
typedef struct
{
thrdpl_shared_t *shared;
size_t state;
pthread_cond_t cond_state;
void *context;
pthread_t thread;
thrdpl_task_t task;
}
thrdpl_thread_t;
@ -68,12 +60,12 @@ struct thrdpl_data_t
// Utilities
// ==========================================================================
static INLINE size_t bound(const size_t min, const size_t value, const size_t max)
static INLINE size_t BOUND(const size_t min, const size_t value, const size_t max)
{
return (value < min) ? min : ((value > max) ? max : value);
}
#define PTHRD_MUTEX_LOCK(X) do \
#define PTHRD_MUTEX_ENTER(X) do \
{ \
if (pthread_mutex_lock((X)) != 0) \
{ \
@ -82,7 +74,7 @@ static INLINE size_t bound(const size_t min, const size_t value, const size_t ma
} \
while(0)
#define PTHRD_MUTEX_UNLOCK(X) do \
#define PTHRD_MUTEX_LEAVE(X) do \
{ \
if (pthread_mutex_unlock((X)) != 0) \
{ \
@ -100,7 +92,7 @@ while(0)
} \
while(0)
#define PTHRD_COND_BROADCAST(X) do \
#define PTHRD_COND_BRDCST(X) do \
{ \
if (pthread_cond_broadcast((X)) != 0) \
{ \
@ -118,11 +110,11 @@ while(0)
} \
while(0)
#define CHECK_IF_CANCELLED() do \
#define CHECK_IF_STOPPED() do \
{ \
if (data->state == TSTATE_EXIT) \
if (shared->stop_flag) \
{ \
PTHRD_MUTEX_UNLOCK(&shared->mutex); \
PTHRD_MUTEX_LEAVE(&shared->mutex); \
return NULL; \
} \
} \
@ -137,35 +129,31 @@ static void *worker_thread_main(void *const arg)
thrdpl_thread_t *const data = (thrdpl_thread_t*) arg;
thrdpl_shared_t *const shared = (thrdpl_shared_t*) data->shared;
thrdpl_task_t *task;
size_t previous = 0U;
PTHRD_MUTEX_ENTER(&shared->mutex);
CHECK_IF_STOPPED();
for (;;)
{
PTHRD_MUTEX_LOCK(&shared->mutex);
CHECK_IF_CANCELLED();
while (data->state != TSTATE_WORK)
while (shared->generation == previous)
{
PTHRD_COND_WAIT(&data->cond_state, &shared->mutex);
CHECK_IF_CANCELLED();
PTHRD_COND_WAIT(&shared->cond_0, &shared->mutex);
CHECK_IF_STOPPED();
}
task = &data->task;
PTHRD_MUTEX_UNLOCK(&shared->mutex);
previous = shared->generation;
PTHRD_MUTEX_LEAVE(&shared->mutex);
task->worker(shared->thread_count, task->context, task->buffer, task->length);
shared->worker(shared->thread_count, data->context, shared->buffer, shared->length);
PTHRD_MUTEX_LOCK(&shared->mutex);
CHECK_IF_CANCELLED();
PTHRD_MUTEX_ENTER(&shared->mutex);
CHECK_IF_STOPPED();
data->state = TSTATE_IDLE;
if (!(--shared->pending))
if (!(--shared->remain))
{
PTHRD_COND_BROADCAST(&shared->cond_pending);
PTHRD_COND_SIGNAL(&shared->cond_1);
}
PTHRD_MUTEX_UNLOCK(&shared->mutex);
PTHRD_COND_SIGNAL(&data->cond_state);
}
}
@ -191,52 +179,16 @@ static size_t detect_cpu_count(void)
return 1U;
}
// ==========================================================================
// Manage threads
// ==========================================================================
static int create_worker(thrdpl_shared_t *const shared, thrdpl_thread_t *const thread_data)
{
thread_data->state = TSTATE_IDLE;
thread_data->shared = shared;
if (pthread_cond_init(&thread_data->cond_state, NULL) != 0)
{
return -1;
}
if (pthread_create(&thread_data->thread, NULL, worker_thread_main, thread_data) != 0)
{
pthread_cond_destroy(&thread_data->cond_state);
return -1;
}
return 0;
}
static int destroy_worker(thrdpl_thread_t *const thread_data)
{
PTHRD_MUTEX_LOCK(&thread_data->shared->mutex);
thread_data->state = TSTATE_EXIT;
PTHRD_MUTEX_UNLOCK(&thread_data->shared->mutex);
PTHRD_COND_BROADCAST(&thread_data->cond_state);
pthread_join(thread_data->thread, NULL);
pthread_cond_destroy(&thread_data->cond_state);
return 0;
}
// ==========================================================================
// Thread pool API
// ==========================================================================
thrdpl_t *slunkcrypt_thrdpl_create(const size_t count)
thrdpl_t *slunkcrypt_thrdpl_create(const size_t count, const thrdpl_worker_t worker)
{
size_t i, j;
size_t i;
thrdpl_t *thrdpl = NULL;
const size_t cpu_count = bound(1U, (count > 0U) ? count : detect_cpu_count(), MAX_THREADS);
const size_t cpu_count = BOUND(1U, (count > 0U) ? count : detect_cpu_count(), MAX_THREADS);
if (cpu_count < 2U)
{
return NULL;
@ -248,31 +200,35 @@ thrdpl_t *slunkcrypt_thrdpl_create(const size_t count)
}
memset(thrdpl, 0, sizeof(thrdpl_t));
thrdpl->shared.thread_count = cpu_count;
thrdpl->shared.worker = worker;
if (pthread_mutex_init(&thrdpl->shared.mutex, NULL) != 0)
{
goto failure;
}
if (pthread_cond_init(&thrdpl->shared.cond_pending, NULL) != 0)
if (pthread_cond_init(&thrdpl->shared.cond_0, NULL) != 0)
{
pthread_mutex_destroy(&thrdpl->shared.mutex);
goto failure;
}
if (pthread_cond_init(&thrdpl->shared.cond_1, NULL) != 0)
{
pthread_cond_destroy(&thrdpl->shared.cond_0);
pthread_mutex_destroy(&thrdpl->shared.mutex);
goto failure;
}
for (i = 0U; i < cpu_count; ++i)
{
if (create_worker(&thrdpl->shared, &thrdpl->thread_data[i]) != 0)
thrdpl->thread_data[i].shared = &thrdpl->shared;
if (pthread_create(&thrdpl->thread_data[i].thread, NULL, worker_thread_main, &thrdpl->thread_data[i]) != 0)
{
for (j = 0U; j < i; ++j)
{
destroy_worker(&thrdpl->thread_data[j]);
}
pthread_cond_destroy(&thrdpl->shared.cond_pending);
pthread_mutex_destroy(&thrdpl->shared.mutex);
goto failure;
slunkcrypt_thrdpl_destroy(thrdpl);
return NULL;
}
++thrdpl->shared.thread_count;
}
return thrdpl;
@ -287,58 +243,56 @@ size_t slunkcrypt_thrdpl_count(const thrdpl_t *const thrdpl)
return thrdpl->shared.thread_count;
}
void slunkcrypt_thrdpl_exec(thrdpl_t *const thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, uint8_t *const buffer, const size_t length)
void slunkcrypt_thrdpl_init(thrdpl_t *const thrdpl, const size_t index, void *const context)
{
thrdpl_thread_t *const thread = &thrdpl->thread_data[index];
thrdpl->thread_data[index].context = context;
}
PTHRD_MUTEX_LOCK(&thrdpl->shared.mutex);
void slunkcrypt_thrdpl_exec(thrdpl_t *const thrdpl, uint8_t *const buffer, const size_t length)
{
PTHRD_MUTEX_ENTER(&thrdpl->shared.mutex);
while ((thread->state != TSTATE_IDLE) && (thread->state != TSTATE_EXIT))
{
PTHRD_COND_WAIT(&thread->cond_state, &thrdpl->shared.mutex);
}
if (thread->state == TSTATE_EXIT)
if (thrdpl->shared.stop_flag || (thrdpl->shared.remain != 0U))
{
abort(); /*this is not supposed to happen!*/
}
thread->state = TSTATE_WORK;
thread->task.worker = worker;
thread->task.context = context;
thread->task.buffer = buffer;
thread->task.length = length;
thrdpl->shared.buffer = buffer;
thrdpl->shared.length = length;
thrdpl->shared.remain = thrdpl->shared.thread_count;
++thrdpl->shared.generation;
PTHRD_COND_BRDCST(&thrdpl->shared.cond_0);
++thrdpl->shared.pending;
PTHRD_MUTEX_UNLOCK(&thrdpl->shared.mutex);
PTHRD_COND_SIGNAL(&thread->cond_state);
}
void slunkcrypt_thrdpl_await(thrdpl_t *const thrdpl)
{
PTHRD_MUTEX_LOCK(&thrdpl->shared.mutex);
while (thrdpl->shared.pending)
while (thrdpl->shared.remain)
{
PTHRD_COND_WAIT(&thrdpl->shared.cond_pending, &thrdpl->shared.mutex);
PTHRD_COND_WAIT(&thrdpl->shared.cond_1, &thrdpl->shared.mutex);
}
PTHRD_MUTEX_UNLOCK(&thrdpl->shared.mutex);
PTHRD_MUTEX_LEAVE(&thrdpl->shared.mutex);
}
void slunkcrypt_thrdpl_destroy(thrdpl_t *const thrdpl)
{
size_t i;
PTHRD_MUTEX_ENTER(&thrdpl->shared.mutex);
if (thrdpl)
if (!thrdpl->shared.stop_flag)
{
for (i = 0U; i < thrdpl->shared.thread_count; ++i)
{
destroy_worker(&thrdpl->thread_data[i]);
}
pthread_cond_destroy(&thrdpl->shared.cond_pending);
pthread_mutex_destroy(&thrdpl->shared.mutex);
free(thrdpl);
thrdpl->shared.stop_flag = 1;
PTHRD_COND_BRDCST(&thrdpl->shared.cond_0);
}
PTHRD_MUTEX_LEAVE(&thrdpl->shared.mutex);
for (i = 0U; i < thrdpl->shared.thread_count; ++i)
{
pthread_join(thrdpl->thread_data[i].thread, NULL);
}
pthread_cond_destroy(&thrdpl->shared.cond_0);
pthread_cond_destroy(&thrdpl->shared.cond_1);
pthread_mutex_destroy(&thrdpl->shared.mutex);
free(thrdpl);
}

View File

@ -14,10 +14,10 @@
typedef void (*thrdpl_worker_t)(const size_t thread_count, void *const context, uint8_t *const buffer, const size_t length);
typedef struct thrdpl_data_t thrdpl_t;
thrdpl_t *slunkcrypt_thrdpl_create(const size_t count);
thrdpl_t *slunkcrypt_thrdpl_create(const size_t count, const thrdpl_worker_t worker);
size_t slunkcrypt_thrdpl_count(const thrdpl_t *const thrdpl);
void slunkcrypt_thrdpl_exec(thrdpl_t *const thrdpl, const size_t index, const thrdpl_worker_t worker, void *const context, uint8_t *const buffer, const size_t length);
void slunkcrypt_thrdpl_await(thrdpl_t *const thrdpl);
void slunkcrypt_thrdpl_init(thrdpl_t *const thrdpl, const size_t index, void *const context);
void slunkcrypt_thrdpl_exec(thrdpl_t *const thrdpl, uint8_t *const buffer, const size_t length);
void slunkcrypt_thrdpl_destroy(thrdpl_t *const thrdpl);
#endif