diff --git a/mtapi_plugins_c/mtapi_network_c/src/embb_mtapi_network.c b/mtapi_plugins_c/mtapi_network_c/src/embb_mtapi_network.c index 33bf0f9..83638a9 100644 --- a/mtapi_plugins_c/mtapi_network_c/src/embb_mtapi_network.c +++ b/mtapi_plugins_c/mtapi_network_c/src/embb_mtapi_network.c @@ -73,7 +73,8 @@ void embb_mtapi_network_finalize() { enum embb_mtapi_network_operation_enum { EMBB_MTAPI_NETWORK_START_TASK = 0x01AFFE01, EMBB_MTAPI_NETWORK_RETURN_RESULT = 0x02AFFE02, - EMBB_MTAPI_NETWORK_RETURN_FAILURE = 0x03AFFE03 + EMBB_MTAPI_NETWORK_RETURN_FAILURE = 0x03AFFE03, + EMBB_MTAPI_NETWORK_CANCEL_TASK = 0x04AFFE04 }; struct embb_mtapi_network_plugin_struct { @@ -115,9 +116,35 @@ struct embb_mtapi_network_task_struct { typedef struct embb_mtapi_network_task_struct embb_mtapi_network_task_t; -static void embb_mtapi_network_task_failure( - ) { +static void embb_mtapi_network_return_failure( + int32_t remote_task_id, + int32_t remote_task_tag, + mtapi_status_t status, + embb_mtapi_network_socket_t * socket, + embb_mtapi_network_buffer_t * buffer) +{ + embb_mtapi_network_buffer_clear(buffer); + + // packet size + embb_mtapi_network_buffer_push_back_int32( + buffer, 16); + + // operation + embb_mtapi_network_buffer_push_back_int32( + buffer, EMBB_MTAPI_NETWORK_RETURN_FAILURE); + + // task handle + embb_mtapi_network_buffer_push_back_int32( + buffer, remote_task_id); + embb_mtapi_network_buffer_push_back_int32( + buffer, remote_task_tag); + // status + embb_mtapi_network_buffer_push_back_int32( + buffer, (int32_t)status); + + embb_mtapi_network_socket_sendbuffer( + socket, buffer); } static void embb_mtapi_network_task_complete( @@ -144,49 +171,68 @@ static void embb_mtapi_network_task_complete( (embb_mtapi_network_task_t*)local_task->attributes.user_data; embb_mtapi_network_buffer_t * send_buf = &plugin->send_buffer; + embb_atomic_memory_barrier(); + local_task->attributes.complete_func = NULL; + embb_atomic_memory_barrier(); + // serialize sending of results embb_mutex_lock(&plugin->send_mutex); embb_mtapi_network_buffer_clear(send_buf); - // actual counts bytes actually put into the buffer - int actual = 0; - // expected counts bytes we intended to put into the buffer - int expected = - 4 + // operation - 4 + 4 + // remote task handle - 4 + // status - 4 + (int)local_task->result_size; // result buffer - - // packet size - actual += embb_mtapi_network_buffer_push_back_int32( + if (local_task->error_code == MTAPI_SUCCESS) { + // actual counts bytes actually put into the buffer + int actual = 0; + // expected counts bytes we intended to put into the buffer + int expected = + 4 + // operation + 4 + 4 + // remote task handle + 4 + // status + 4 + (int)local_task->result_size; // result buffer + + // packet size + actual += embb_mtapi_network_buffer_push_back_int32( send_buf, expected); - expected += 4; - - // operation is "return result" - actual += embb_mtapi_network_buffer_push_back_int32( - send_buf, EMBB_MTAPI_NETWORK_RETURN_RESULT); - - // remote task id - actual += embb_mtapi_network_buffer_push_back_int32( - send_buf, network_task->remote_task_id); - actual += embb_mtapi_network_buffer_push_back_int32( - send_buf, network_task->remote_task_tag); - - // status - actual += embb_mtapi_network_buffer_push_back_int32( - send_buf, local_task->error_code); - - // result size - actual += embb_mtapi_network_buffer_push_back_int32( - send_buf, (int32_t)local_task->result_size); - actual += embb_mtapi_network_buffer_push_back_rawdata( - send_buf, (int32_t)local_task->result_size, - local_task->result_buffer); - - if (expected == actual) { - int sent = embb_mtapi_network_socket_sendbuffer( + expected += 4; + + // operation is "return result" + actual += embb_mtapi_network_buffer_push_back_int32( + send_buf, EMBB_MTAPI_NETWORK_RETURN_RESULT); + + // remote task id + actual += embb_mtapi_network_buffer_push_back_int32( + send_buf, network_task->remote_task_id); + actual += embb_mtapi_network_buffer_push_back_int32( + send_buf, network_task->remote_task_tag); + + // status + actual += embb_mtapi_network_buffer_push_back_int32( + send_buf, local_task->error_code); + + // result size + actual += embb_mtapi_network_buffer_push_back_int32( + send_buf, (int32_t)local_task->result_size); + actual += embb_mtapi_network_buffer_push_back_rawdata( + send_buf, (int32_t)local_task->result_size, + local_task->result_buffer); + + if (expected == actual) { + int sent = embb_mtapi_network_socket_sendbuffer( + &network_task->socket, send_buf); + assert(sent == send_buf->size); + } + else { + embb_mtapi_network_return_failure( + network_task->remote_task_id, + network_task->remote_task_tag, + MTAPI_ERR_UNKNOWN, + &network_task->socket, send_buf); + } + } else { + embb_mtapi_network_return_failure( + network_task->remote_task_id, + network_task->remote_task_tag, + local_task->error_code, &network_task->socket, send_buf); - assert(sent == send_buf->size); } // sending done @@ -196,6 +242,14 @@ static void embb_mtapi_network_task_complete( embb_free((void*)local_task->arguments); embb_free(local_task->result_buffer); + void * data = local_task->attributes.user_data; + + embb_atomic_memory_barrier(); + local_task->attributes.user_data = NULL; + embb_atomic_memory_barrier(); + + embb_free(data); + local_status = MTAPI_SUCCESS; } } @@ -204,37 +258,6 @@ static void embb_mtapi_network_task_complete( mtapi_status_set(status, local_status); } -static void embb_mtapi_network_return_failure( - int32_t remote_task_id, - int32_t remote_task_tag, - mtapi_status_t status, - embb_mtapi_network_socket_t * socket, - embb_mtapi_network_buffer_t * buffer) -{ - embb_mtapi_network_buffer_clear(buffer); - - // packet size - embb_mtapi_network_buffer_push_back_int32( - buffer, 16); - - // operation - embb_mtapi_network_buffer_push_back_int32( - buffer, EMBB_MTAPI_NETWORK_RETURN_FAILURE); - - // task handle - embb_mtapi_network_buffer_push_back_int32( - buffer, remote_task_id); - embb_mtapi_network_buffer_push_back_int32( - buffer, remote_task_tag); - - // status - embb_mtapi_network_buffer_push_back_int32( - buffer, (int32_t)status); - - embb_mtapi_network_socket_sendbuffer( - socket, buffer); -} - static mtapi_status_t embb_mtapi_network_handle_start_task( embb_mtapi_network_socket_t * socket, embb_mtapi_network_buffer_t * buffer, @@ -365,9 +388,9 @@ static mtapi_status_t embb_mtapi_network_handle_return_result( embb_mtapi_network_buffer_t * buffer, int packet_size) { - int task_status; - int task_id; - int task_tag; + int32_t task_status; + int32_t task_id; + int32_t task_tag; int32_t results_size; int err; @@ -447,9 +470,9 @@ static mtapi_status_t embb_mtapi_network_handle_return_failure( embb_mtapi_network_buffer_t * buffer, int packet_size) { - int task_status; - int task_id; - int task_tag; + int32_t task_status; + int32_t task_id; + int32_t task_tag; int err; mtapi_status_t local_status = MTAPI_ERR_UNKNOWN; @@ -485,9 +508,13 @@ static mtapi_status_t embb_mtapi_network_handle_return_failure( embb_mtapi_action_pool_get_storage_for_handle( node->action_pool, local_task->action); - local_task->error_code = (mtapi_status_t)task_status; - embb_atomic_store_int(&local_task->state, MTAPI_TASK_ERROR); embb_atomic_fetch_and_add_int(&local_action->num_tasks, -1); + local_task->error_code = (mtapi_status_t)task_status; + if (MTAPI_ERR_ACTION_CANCELLED == task_status) { + embb_atomic_store_int(&local_task->state, MTAPI_TASK_CANCELLED); + } else { + embb_atomic_store_int(&local_task->state, MTAPI_TASK_ERROR); + } /* is task associated with a group? */ if (embb_mtapi_group_pool_is_handle_valid( @@ -509,6 +536,48 @@ static mtapi_status_t embb_mtapi_network_handle_return_failure( return local_status; } +static mtapi_status_t embb_mtapi_network_handle_cancel_task( + embb_mtapi_network_buffer_t * buffer, + int packet_size) { + + mtapi_status_t local_status = MTAPI_ERR_UNKNOWN; + int32_t remote_task_id; + int32_t remote_task_tag; + int err; + EMBB_UNUSED_IN_RELEASE(err); + + // do we have 8 bytes? + if (packet_size == 8) { + // get task handle + err = embb_mtapi_network_buffer_pop_front_int32(buffer, &remote_task_id); + assert(err == 4); + err = embb_mtapi_network_buffer_pop_front_int32(buffer, &remote_task_tag); + assert(err == 4); + + if (embb_mtapi_node_is_initialized()) { + embb_mtapi_node_t * node = embb_mtapi_node_get_instance(); + + // search for task to cancel + for (mtapi_uint_t ii = 0; ii < node->attributes.max_tasks; ii++) { + embb_mtapi_task_t * task = &node->task_pool->storage[ii]; + // is this our task? + if (embb_mtapi_network_task_complete == task->attributes.complete_func) { + embb_mtapi_network_task_t * network_task = + (embb_mtapi_network_task_t*)task->attributes.user_data; + // is this task the one matching the given remote task? + if (remote_task_id == network_task->remote_task_id && + remote_task_tag == network_task->remote_task_tag) { + mtapi_task_cancel(task->handle, &local_status); + break; + } + } + } + } + } + + return local_status; +} + static int embb_mtapi_network_thread(void * args) { embb_mtapi_network_plugin_t * plugin = &embb_mtapi_network_plugin; embb_mtapi_network_buffer_t * buffer = &plugin->recv_buffer; @@ -562,6 +631,9 @@ static int embb_mtapi_network_thread(void * args) { case EMBB_MTAPI_NETWORK_RETURN_FAILURE: embb_mtapi_network_handle_return_failure(buffer, packet_size); break; + case EMBB_MTAPI_NETWORK_CANCEL_TASK: + embb_mtapi_network_handle_cancel_task(buffer, packet_size); + break; default: // invalid, ignore break; @@ -785,17 +857,18 @@ static void network_task_start( // check if everything fit into the buffer if (actual == expected) { + embb_atomic_fetch_and_add_int(&local_action->num_tasks, 1); + embb_atomic_store_int(&local_task->state, MTAPI_TASK_RUNNING); int sent = embb_mtapi_network_socket_sendbuffer( &network_action->socket, send_buf); // was everything sent? if (sent == send_buf->size) { - embb_atomic_fetch_and_add_int(&local_action->num_tasks, 1); - embb_atomic_store_int(&local_task->state, MTAPI_TASK_RUNNING); // we've done it, success! mtapi_status_set(status, MTAPI_SUCCESS); } else { // could not send the whole task, this will fail on the remote side, // so we can safely assume that the task is in error + embb_atomic_fetch_and_add_int(&local_action->num_tasks, -1); embb_atomic_store_int(&local_task->state, MTAPI_TASK_ERROR); } } @@ -810,11 +883,73 @@ static void network_task_start( static void network_task_cancel( MTAPI_IN mtapi_task_hndl_t task, MTAPI_OUT mtapi_status_t* status) { - mtapi_status_t local_status = MTAPI_ERR_UNKNOWN; - EMBB_UNUSED(task); + // assume failure + mtapi_status_set(status, MTAPI_ERR_UNKNOWN); - mtapi_status_set(status, local_status); + if (embb_mtapi_node_is_initialized()) { + embb_mtapi_node_t * node = embb_mtapi_node_get_instance(); + + if (embb_mtapi_task_pool_is_handle_valid(node->task_pool, task)) { + embb_mtapi_task_t * local_task = + embb_mtapi_task_pool_get_storage_for_handle(node->task_pool, task); + + if (embb_mtapi_action_pool_is_handle_valid( + node->action_pool, local_task->action)) { + embb_mtapi_action_t * local_action = + embb_mtapi_action_pool_get_storage_for_handle( + node->action_pool, local_task->action); + + embb_mtapi_network_action_t * network_action = + (embb_mtapi_network_action_t*)local_action->plugin_data; + embb_mtapi_network_buffer_t * send_buf = &network_action->send_buffer; + + // serialize sending + embb_mutex_lock(&network_action->send_mutex); + embb_mtapi_network_buffer_clear(send_buf); + + // actual counts bytes actually put into the buffer + int actual = 0; + // expected counts bytes we intended to put into the buffer + int expected = + 4 + // operation + 4 + 4; // task handle + + // packet size + actual += embb_mtapi_network_buffer_push_back_int32( + send_buf, (int32_t)expected); + expected += 4; + + // operation is "cancel task" + actual += embb_mtapi_network_buffer_push_back_int32( + send_buf, EMBB_MTAPI_NETWORK_CANCEL_TASK); + + // task handle + actual += embb_mtapi_network_buffer_push_back_int32( + send_buf, (int32_t)local_task->handle.id); + actual += embb_mtapi_network_buffer_push_back_int32( + send_buf, (int32_t)local_task->handle.tag); + + // check if everything fit into the buffer + if (actual == expected) { + int sent = embb_mtapi_network_socket_sendbuffer( + &network_action->socket, send_buf); + // was everything sent? + if (sent == send_buf->size) { + // we've done it, success! + mtapi_status_set(status, MTAPI_SUCCESS); + } else { + embb_atomic_store_int(&local_task->state, MTAPI_TASK_ERROR); + } + } else { + embb_atomic_store_int(&local_task->state, MTAPI_TASK_ERROR); + } + + embb_mtapi_network_buffer_clear(send_buf); + embb_mutex_unlock(&network_action->send_mutex); + } + } + } } static void network_action_finalize( diff --git a/mtapi_plugins_c/mtapi_network_c/test/embb_mtapi_network_test_task.cc b/mtapi_plugins_c/mtapi_network_c/test/embb_mtapi_network_test_task.cc index cec09a2..bfc27d0 100644 --- a/mtapi_plugins_c/mtapi_network_c/test/embb_mtapi_network_test_task.cc +++ b/mtapi_plugins_c/mtapi_network_c/test/embb_mtapi_network_test_task.cc @@ -61,13 +61,52 @@ static void test( } } +static void cancel_test( + void const * /*arguments*/, + mtapi_size_t /*arguments_size*/, + void * /*result_buffer*/, + mtapi_size_t /*result_buffer_size*/, + void const * /*node_local_data*/, + mtapi_size_t /*node_local_data_size*/, + mtapi_task_context_t * context) { + mtapi_status_t status; + while (true) { + mtapi_task_state_t state = mtapi_context_taskstate_get(context, &status); + if (status != MTAPI_SUCCESS) { + break; + } else { + if (state == MTAPI_TASK_CANCELLED) { + break; + } + } + } +} NetworkTaskTest::NetworkTaskTest() { - CreateUnit("mtapi network task test").Add(&NetworkTaskTest::TestBasic, this); + CreateUnit("mtapi network task test") + .Add(&NetworkTaskTest::TestBasic, this); } void NetworkTaskTest::TestBasic() { mtapi_status_t status; + + mtapi_initialize( + NETWORK_DOMAIN, + NETWORK_LOCAL_NODE, + MTAPI_NULL, + MTAPI_NULL, + &status); + MTAPI_CHECK_STATUS(status); + + TestSimple(); + TestCancel(); + + mtapi_finalize(&status); + MTAPI_CHECK_STATUS(status); +} + +void NetworkTaskTest::TestSimple() { + mtapi_status_t status; mtapi_job_hndl_t job; mtapi_task_hndl_t task; mtapi_action_hndl_t network_action, local_action; @@ -81,14 +120,6 @@ void NetworkTaskTest::TestBasic() { arguments[ii + kElements] = static_cast(ii); } - mtapi_initialize( - NETWORK_DOMAIN, - NETWORK_LOCAL_NODE, - MTAPI_NULL, - MTAPI_NULL, - &status); - MTAPI_CHECK_STATUS(status); - mtapi_network_plugin_initialize("127.0.0.1", 12345, 5, kElements * 4 * 3 + 32, &status); MTAPI_CHECK_STATUS(status); @@ -139,7 +170,68 @@ void NetworkTaskTest::TestBasic() { mtapi_network_plugin_finalize(&status); MTAPI_CHECK_STATUS(status); +} - mtapi_finalize(&status); +void NetworkTaskTest::TestCancel() { + mtapi_status_t status; + mtapi_job_hndl_t job; + mtapi_task_hndl_t task; + mtapi_action_hndl_t network_action, local_action; + + float argument = 1.0f; + float result; + + mtapi_network_plugin_initialize("127.0.0.1", 12345, 5, + 4 * 3 + 32, &status); + MTAPI_CHECK_STATUS(status); + + float node_remote = 1.0f; + local_action = mtapi_action_create( + NETWORK_REMOTE_JOB, + cancel_test, + &node_remote, sizeof(float), + MTAPI_DEFAULT_ACTION_ATTRIBUTES, + &status); + MTAPI_CHECK_STATUS(status); + + network_action = mtapi_network_action_create( + NETWORK_DOMAIN, + NETWORK_LOCAL_JOB, + NETWORK_REMOTE_JOB, + "127.0.0.1", 12345, + &status); + MTAPI_CHECK_STATUS(status); + + status = MTAPI_ERR_UNKNOWN; + job = mtapi_job_get(NETWORK_LOCAL_JOB, NETWORK_DOMAIN, &status); + MTAPI_CHECK_STATUS(status); + + task = mtapi_task_start( + MTAPI_TASK_ID_NONE, + job, + &argument, sizeof(float), + &result, sizeof(float), + MTAPI_DEFAULT_TASK_ATTRIBUTES, + MTAPI_GROUP_NONE, + &status); + MTAPI_CHECK_STATUS(status); + + mtapi_task_wait(task, 1, &status); + PT_ASSERT_EQ(status, MTAPI_TIMEOUT); + + mtapi_task_cancel(task, &status); + MTAPI_CHECK_STATUS(status); + + mtapi_task_wait(task, MTAPI_INFINITE, &status); + PT_ASSERT_NE(status, MTAPI_TIMEOUT); + PT_ASSERT_EQ(status, MTAPI_ERR_ACTION_CANCELLED); + + mtapi_action_delete(network_action, MTAPI_INFINITE, &status); + MTAPI_CHECK_STATUS(status); + + mtapi_action_delete(local_action, MTAPI_INFINITE, &status); + MTAPI_CHECK_STATUS(status); + + mtapi_network_plugin_finalize(&status); MTAPI_CHECK_STATUS(status); } diff --git a/mtapi_plugins_c/mtapi_network_c/test/embb_mtapi_network_test_task.h b/mtapi_plugins_c/mtapi_network_c/test/embb_mtapi_network_test_task.h index 8162679..aa2573d 100644 --- a/mtapi_plugins_c/mtapi_network_c/test/embb_mtapi_network_test_task.h +++ b/mtapi_plugins_c/mtapi_network_c/test/embb_mtapi_network_test_task.h @@ -35,6 +35,9 @@ class NetworkTaskTest : public partest::TestCase { private: void TestBasic(); + + void TestSimple(); + void TestCancel(); }; #endif // MTAPI_PLUGINS_C_MTAPI_NETWORK_C_TEST_EMBB_MTAPI_NETWORK_TEST_TASK_H_