diff --git a/sxt/multiexp/pippenger2/multiexponentiation.h b/sxt/multiexp/pippenger2/multiexponentiation.h index b4735417..f563e37c 100644 --- a/sxt/multiexp/pippenger2/multiexponentiation.h +++ b/sxt/multiexp/pippenger2/multiexponentiation.h @@ -53,128 +53,132 @@ struct multiexponentiate_options { }; //-------------------------------------------------------------------------------------------------- -// multiexponentiate_no_chunks +// multiexponentiate_product_step //-------------------------------------------------------------------------------------------------- template requires std::constructible_from xena::future<> -multiexponentiate_no_chunks(basct::span res, const partition_table_accessor& accessor, - unsigned element_num_bytes, basct::cspan scalars) noexcept { - auto num_outputs = res.size(); - auto n = scalars.size() / (num_outputs * element_num_bytes); - auto num_products = num_outputs * element_num_bytes * 8u; - SXT_DEBUG_ASSERT( - // clang-format off - scalars.size() % (num_outputs * element_num_bytes) == 0 - // clang-format on - ); +multiexponentiate_product_step(basct::span products, basdv::stream& reduction_stream, + const partition_table_accessor& accessor, unsigned num_products, + unsigned num_output_bytes, basct::cspan scalars, + const multiexponentiate_options& options) noexcept { + auto n = scalars.size() / num_output_bytes; // compute bitwise products - basl::info("computing {} bitwise multiexponentiation products of length {}", num_products, n); - memmg::managed_array products(num_products, memr::get_device_resource()); - co_await async_partition_product(products, accessor, scalars, 0); + // + // We split the work by groups of generators so that a single chunk will process + // all the outputs for those generators. This minimizes the amount of host->device + // copying we need to do for the table of precomputed sums. + auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, n} + .chunk_multiple(16) + .min_chunk_size(options.min_chunk_size) + .max_chunk_size(options.max_chunk_size), + options.split_factor); + auto num_chunks = static_cast(std::distance(chunk_first, chunk_last)); + basl::info("computing {} bitwise multiexponentiation products of length {} using {} chunks", + num_products, n, num_chunks); - // reduce products - basl::info("reducing {} products to {} outputs", num_products, num_products); - basdv::stream stream; - memr::async_device_resource resource{stream}; - memmg::managed_array res_dev{num_outputs, &resource}; - reduce_products(res_dev, stream, products); - products.reset(); + // handle no chunk case + if (num_chunks == 1) { + co_await async_partition_product(products, accessor, scalars, 0); + co_return; + } - // copy result - basdv::async_copy_device_to_host(res, res_dev, stream); - co_await xendv::await_stream(stream); - basl::info("completed {} reductions", num_outputs); + // handle multiple chunks + memmg::managed_array partial_products{num_products * num_chunks, memr::get_pinned_resource()}; + size_t chunk_index = 0; + co_await xendv::concurrent_for_each( + chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> { + basl::info("computing {} multiproducts for generators [{}, {}] on device {}", num_products, + rng.a(), rng.b(), basdv::get_device()); + memmg::managed_array partial_products_dev{num_products, memr::get_device_resource()}; + auto scalars_slice = + scalars.subspan(num_output_bytes * rng.a(), rng.size() * num_output_bytes); + co_await async_partition_product(partial_products_dev, accessor, scalars_slice, rng.a()); + basdv::stream stream; + basdv::async_copy_device_to_host( + basct::subspan(partial_products, num_products * chunk_index, num_products), + partial_products_dev, stream); + ++chunk_index; + co_await xendv::await_stream(stream); + }); + + // combine the partial products + basl::info("combining {} partial product chunks", num_chunks); + memr::async_device_resource resource{reduction_stream}; + memmg::managed_array partial_products_dev{partial_products.size(), &resource}; + basdv::async_copy_host_to_device(partial_products_dev, partial_products, reduction_stream); + combine(products, reduction_stream, partial_products_dev); + co_await xendv::await_stream(reduction_stream); } //-------------------------------------------------------------------------------------------------- -// complete_multiexponentiation +// multiexponentiate_impl //-------------------------------------------------------------------------------------------------- -template -xena::future<> complete_multiexponentiation(basct::span res, unsigned element_num_bytes, - basct::cspan partial_products) noexcept { +template + requires std::constructible_from +xena::future<> multiexponentiate_impl(basct::span res, + const partition_table_accessor& accessor, + unsigned element_num_bytes, basct::cspan scalars, + const multiexponentiate_options& options) noexcept { auto num_outputs = res.size(); auto num_products = num_outputs * element_num_bytes * 8u; + SXT_DEBUG_ASSERT( + // clang-format off + scalars.size() % (num_outputs * element_num_bytes) == 0 + // clang-format on + ); basdv::stream stream; memr::async_device_resource resource{stream}; - - // combine the partial results - memmg::managed_array partial_products_dev{partial_products.size(), &resource}; - basdv::async_copy_host_to_device(partial_products_dev, partial_products, stream); memmg::managed_array products{num_products, &resource}; - combine(products, stream, partial_products_dev); - partial_products_dev.reset(); + co_await multiexponentiate_product_step(products, stream, accessor, num_products, + num_outputs * element_num_bytes, scalars, options); // reduce the products + basl::info("reducing products for {} outputs", num_outputs); memmg::managed_array res_dev{num_outputs, &resource}; reduce_products(res_dev, stream, products); products.reset(); + basl::info("completed {} reductions", num_outputs); // copy result basdv::async_copy_device_to_host(res, res_dev, stream); co_await xendv::await_stream(stream); + basl::info("complete multiexponentiation"); } -//-------------------------------------------------------------------------------------------------- -// multiexponentiate_impl -//-------------------------------------------------------------------------------------------------- template requires std::constructible_from -xena::future<> multiexponentiate_impl(basct::span res, - const partition_table_accessor& accessor, - unsigned element_num_bytes, basct::cspan scalars, - const multiexponentiate_options& options) noexcept { +xena::future<> +multiexponentiate_impl(basct::span res, const partition_table_accessor& accessor, + basct::cspan output_bit_table, basct::cspan scalars, + const multiexponentiate_options& options) noexcept { auto num_outputs = res.size(); - auto n = scalars.size() / (num_outputs * element_num_bytes); - auto num_products = num_outputs * element_num_bytes * 8u; + auto num_products = std::accumulate(output_bit_table.begin(), output_bit_table.end(), 0u); + auto num_output_bytes = basn::divide_up(num_products, 8); SXT_DEBUG_ASSERT( // clang-format off - scalars.size() % (num_outputs * element_num_bytes) == 0 + scalars.size() % num_output_bytes == 0 // clang-format on ); + basdv::stream stream; + memr::async_device_resource resource{stream}; + memmg::managed_array products{num_products, &resource}; + co_await multiexponentiate_product_step(products, stream, accessor, num_products, + num_output_bytes, scalars, options); - // compute bitwise products - // - // We split the work by groups of generators so that a single chunk will process - // all the outputs for those generators. This minimizes the amount of host->device - // copying we need to do for the table of precomputed sums. - auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, n} - .chunk_multiple(16) - .min_chunk_size(options.min_chunk_size) - .max_chunk_size(options.max_chunk_size), - options.split_factor); - auto num_chunks = std::distance(chunk_first, chunk_last); - if (num_chunks == 1) { - multiexponentiate_no_chunks(res, accessor, element_num_bytes, scalars); - co_return; - } - - memmg::managed_array products{num_products * num_chunks, memr::get_pinned_resource()}; - size_t chunk_index = 0; - basl::info("computing {} bitwise multiexponentiation products of length {} using {} chunks", - num_products, n, num_chunks); - co_await xendv::concurrent_for_each( - chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> { - basl::info("computing {} multiproducts for generators [{}, {}] on device {}", num_products, - rng.a(), rng.b(), basdv::get_device()); - memmg::managed_array products_dev{num_products, memr::get_device_resource()}; - auto scalars_slice = scalars.subspan(num_outputs * element_num_bytes * rng.a(), - rng.size() * num_outputs * element_num_bytes); - co_await async_partition_product(products_dev, accessor, scalars_slice, rng.a()); - basdv::stream stream; - basdv::async_copy_device_to_host( - basct::subspan(products, num_products * chunk_index, num_products), products_dev, - stream); - ++chunk_index; - co_await xendv::await_stream(stream); - }); - - // complete the multi-exponentiation - basl::info("reducing products for {} outputs", num_outputs); - co_await complete_multiexponentiation(res, element_num_bytes, products); + // reduce products + basl::info("reducing {} products to {} outputs", num_products, num_products); + memmg::managed_array res_dev{num_outputs, &resource}; + reduce_products(res_dev, stream, output_bit_table, products); + products.reset(); basl::info("completed {} reductions", num_outputs); + + // copy result + basdv::async_copy_device_to_host(res, res_dev, stream); + co_await xendv::await_stream(stream); + basl::info("complete multiexponentiation"); } //-------------------------------------------------------------------------------------------------- @@ -196,6 +200,17 @@ async_multiexponentiate(basct::span res, const partition_table_accessor& a return multiexponentiate_impl(res, accessor, element_num_bytes, scalars, options); } +template + requires std::constructible_from +xena::future<> async_multiexponentiate(basct::span res, + const partition_table_accessor& accessor, + basct::cspan output_bit_table, + basct::cspan scalars) noexcept { + multiexponentiate_options options; + options.split_factor = static_cast(basdv::get_num_devices()); + return multiexponentiate_impl(res, accessor, output_bit_table, scalars, options); +} + //-------------------------------------------------------------------------------------------------- // multiexponentiate //-------------------------------------------------------------------------------------------------- diff --git a/sxt/multiexp/pippenger2/multiexponentiation.t.cc b/sxt/multiexp/pippenger2/multiexponentiation.t.cc index 707cf6ad..40181f90 100644 --- a/sxt/multiexp/pippenger2/multiexponentiation.t.cc +++ b/sxt/multiexp/pippenger2/multiexponentiation.t.cc @@ -155,6 +155,64 @@ TEST_CASE("we can compute multiexponentiations using a precomputed table of part } } +TEST_CASE("we can compute multiexponentiations with packed scalars") { + using E = bascrv::element97; + + std::vector generators(32); + std::mt19937 rng{0}; + for (auto& g : generators) { + g = std::uniform_int_distribution{0, 96}(rng); + } + + auto accessor = make_in_memory_partition_table_accessor(generators); + + std::vector scalars(1); + std::vector res(1); + std::vector output_bit_table(1); + + SECTION("we can compute a multiexponentiation for a single bit scalar") { + output_bit_table[0] = 1; + auto fut = async_multiexponentiate(res, *accessor, output_bit_table, scalars); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(res[0] == E::identity()); + } + + SECTION("we can compute a multiexponentiation for a two bit scalar") { + output_bit_table[0] = 2; + scalars[0] = 2u; + auto fut = async_multiexponentiate(res, *accessor, output_bit_table, scalars); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(res[0] == 2u * generators[0].value); + } + + SECTION("we can compute a multiexponentiation with multiple outputs of varying bit sizes") { + output_bit_table = {2, 1, 3}; + scalars = {0b110011}; + res.resize(3); + auto fut = async_multiexponentiate(res, *accessor, output_bit_table, scalars); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(res[0] == 3u * generators[0].value); + REQUIRE(res[1] == 0u); + REQUIRE(res[2] == 6u * generators[0].value); + } + + SECTION("we can compute a multiexponentiation with multiple outputs of varying bit sizes and " + "length 2") { + output_bit_table = {2, 1, 3}; + scalars = {0b110011, 0b101101}; + res.resize(3); + auto fut = async_multiexponentiate(res, *accessor, output_bit_table, scalars); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(res[0] == 3u * generators[0].value + generators[1].value); + REQUIRE(res[1] == generators[1].value); + REQUIRE(res[2] == 6u * generators[0].value + 5u * generators[1].value); + } +} + TEST_CASE("we can compute multiexponentiations with curve-21") { using E = c21t::element_p3; using Ep = c21t::compact_element;