Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support a packed interface for multi-exponentiation (PROOF-877) #140

Merged
merged 51 commits into from
Jul 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
ed0e5ae
work on packed interface
rnburn Jun 14, 2024
998060a
work on packed interface
rnburn Jun 14, 2024
1785b32
rework partition product
rnburn Jun 14, 2024
e49127c
rework mx
rnburn Jun 14, 2024
8bf11be
fill in packed interface
rnburn Jun 14, 2024
29d4897
fill in packed interface
rnburn Jun 14, 2024
276ca0e
refactor if
rnburn Jun 17, 2024
2373494
rework mx
rnburn Jun 17, 2024
886e22a
add test case
rnburn Jun 17, 2024
730b733
rework mx
rnburn Jun 17, 2024
d2acbf9
add stub for new reduce
rnburn Jun 17, 2024
79c7428
fill in new mx
rnburn Jun 17, 2024
998bcf6
fill in reduction
rnburn Jun 17, 2024
1976b1d
rework reduction
rnburn Jun 18, 2024
2de4edc
fill in reduction
rnburn Jun 18, 2024
d95bd1b
fill in reduction
rnburn Jun 18, 2024
c9ef907
fill in reduction tests
rnburn Jun 18, 2024
54bd548
fill in reduce tests
rnburn Jun 18, 2024
ad3f478
fill in reduce tests
rnburn Jun 18, 2024
04064b6
fill in tests
rnburn Jun 18, 2024
288e822
add multi-output test case
rnburn Jun 19, 2024
ec8fa47
fill in test cases
rnburn Jun 19, 2024
7cec172
reformat
rnburn Jun 19, 2024
a55418b
add round up
rnburn Jun 19, 2024
a6f7782
refactor partition product
rnburn Jun 19, 2024
eeb2a6d
Merge branch 'main' of github.com:spaceandtimelabs/blitzar into if-PR…
rnburn Jun 19, 2024
b91f66a
reformat
rnburn Jun 19, 2024
ff4987a
add assertion
rnburn Jun 20, 2024
47a70c8
add test case
rnburn Jun 20, 2024
f4ac19f
Merge branch 'main' of github.com:spaceandtimelabs/blitzar into if-PR…
rnburn Jun 20, 2024
744bcbf
Merge branch 'main' of github.com:spaceandtimelabs/blitzar into if-PR…
rnburn Jun 24, 2024
0c23df4
make reduction more efficient
rnburn Jun 25, 2024
93a3c6e
add assertions
rnburn Jun 25, 2024
d67db83
fix test name
rnburn Jun 25, 2024
6d590cb
Merge branch 'main' of github.com:spaceandtimelabs/blitzar into if-PR…
rnburn Jun 26, 2024
0964abd
rework mx
rnburn Jun 26, 2024
a13c6b8
rework mx
rnburn Jun 26, 2024
6b04653
rework mx
rnburn Jun 26, 2024
1d10992
rework mx
rnburn Jun 26, 2024
4d76ec1
rework mx
rnburn Jun 26, 2024
dec6f98
rework mx
rnburn Jun 27, 2024
8c9b138
rework mx
rnburn Jun 27, 2024
c2e2139
rework mx
rnburn Jun 27, 2024
e4709a7
drop dead code
rnburn Jun 27, 2024
8aab43c
add more logging
rnburn Jun 27, 2024
e69f9f6
rework
rnburn Jun 27, 2024
75fcf1a
drop dead code
rnburn Jun 27, 2024
04547f6
reformat
rnburn Jun 27, 2024
296f17d
add log
rnburn Jun 27, 2024
4fae570
Merge branch 'main' of github.com:spaceandtimelabs/blitzar into if-PR…
rnburn Jun 27, 2024
28da659
clean ups
rnburn Jun 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 99 additions & 84 deletions sxt/multiexp/pippenger2/multiexponentiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,128 +53,132 @@ struct multiexponentiate_options {
};

//--------------------------------------------------------------------------------------------------
// multiexponentiate_no_chunks
// multiexponentiate_product_step
//--------------------------------------------------------------------------------------------------
template <bascrv::element T, class U>
requires std::constructible_from<T, U>
xena::future<>
multiexponentiate_no_chunks(basct::span<T> res, const partition_table_accessor<U>& accessor,
unsigned element_num_bytes, basct::cspan<uint8_t> scalars) noexcept {
auto num_outputs = res.size();
auto n = scalars.size() / (num_outputs * element_num_bytes);
auto num_products = num_outputs * element_num_bytes * 8u;
SXT_DEBUG_ASSERT(
// clang-format off
scalars.size() % (num_outputs * element_num_bytes) == 0
// clang-format on
);
multiexponentiate_product_step(basct::span<T> products, basdv::stream& reduction_stream,
const partition_table_accessor<U>& accessor, unsigned num_products,
unsigned num_output_bytes, basct::cspan<uint8_t> scalars,
const multiexponentiate_options& options) noexcept {
auto n = scalars.size() / num_output_bytes;

// compute bitwise products
basl::info("computing {} bitwise multiexponentiation products of length {}", num_products, n);
memmg::managed_array<T> products(num_products, memr::get_device_resource());
co_await async_partition_product<T>(products, accessor, scalars, 0);
//
// We split the work by groups of generators so that a single chunk will process
// all the outputs for those generators. This minimizes the amount of host->device
// copying we need to do for the table of precomputed sums.
auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, n}
.chunk_multiple(16)
.min_chunk_size(options.min_chunk_size)
.max_chunk_size(options.max_chunk_size),
options.split_factor);
auto num_chunks = static_cast<size_t>(std::distance(chunk_first, chunk_last));
basl::info("computing {} bitwise multiexponentiation products of length {} using {} chunks",
num_products, n, num_chunks);

// reduce products
basl::info("reducing {} products to {} outputs", num_products, num_products);
basdv::stream stream;
memr::async_device_resource resource{stream};
memmg::managed_array<T> res_dev{num_outputs, &resource};
reduce_products<T>(res_dev, stream, products);
products.reset();
// handle no chunk case
if (num_chunks == 1) {
co_await async_partition_product<T>(products, accessor, scalars, 0);
co_return;
}

// copy result
basdv::async_copy_device_to_host(res, res_dev, stream);
co_await xendv::await_stream(stream);
basl::info("completed {} reductions", num_outputs);
// handle multiple chunks
memmg::managed_array<T> partial_products{num_products * num_chunks, memr::get_pinned_resource()};
size_t chunk_index = 0;
co_await xendv::concurrent_for_each(
chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> {
basl::info("computing {} multiproducts for generators [{}, {}] on device {}", num_products,
rng.a(), rng.b(), basdv::get_device());
memmg::managed_array<T> partial_products_dev{num_products, memr::get_device_resource()};
auto scalars_slice =
scalars.subspan(num_output_bytes * rng.a(), rng.size() * num_output_bytes);
co_await async_partition_product<T>(partial_products_dev, accessor, scalars_slice, rng.a());
basdv::stream stream;
basdv::async_copy_device_to_host(
basct::subspan(partial_products, num_products * chunk_index, num_products),
partial_products_dev, stream);
++chunk_index;
co_await xendv::await_stream(stream);
});

// combine the partial products
basl::info("combining {} partial product chunks", num_chunks);
memr::async_device_resource resource{reduction_stream};
memmg::managed_array<T> partial_products_dev{partial_products.size(), &resource};
basdv::async_copy_host_to_device(partial_products_dev, partial_products, reduction_stream);
combine<T>(products, reduction_stream, partial_products_dev);
co_await xendv::await_stream(reduction_stream);
}

//--------------------------------------------------------------------------------------------------
// complete_multiexponentiation
// multiexponentiate_impl
//--------------------------------------------------------------------------------------------------
template <bascrv::element T>
xena::future<> complete_multiexponentiation(basct::span<T> res, unsigned element_num_bytes,
basct::cspan<T> partial_products) noexcept {
template <bascrv::element T, class U>
requires std::constructible_from<T, U>
xena::future<> multiexponentiate_impl(basct::span<T> res,
const partition_table_accessor<U>& accessor,
unsigned element_num_bytes, basct::cspan<uint8_t> scalars,
const multiexponentiate_options& options) noexcept {
auto num_outputs = res.size();
auto num_products = num_outputs * element_num_bytes * 8u;
SXT_DEBUG_ASSERT(
// clang-format off
scalars.size() % (num_outputs * element_num_bytes) == 0
// clang-format on
);

basdv::stream stream;
memr::async_device_resource resource{stream};

// combine the partial results
memmg::managed_array<T> partial_products_dev{partial_products.size(), &resource};
basdv::async_copy_host_to_device(partial_products_dev, partial_products, stream);
memmg::managed_array<T> products{num_products, &resource};
combine<T>(products, stream, partial_products_dev);
partial_products_dev.reset();
co_await multiexponentiate_product_step<T>(products, stream, accessor, num_products,
num_outputs * element_num_bytes, scalars, options);

// reduce the products
basl::info("reducing products for {} outputs", num_outputs);
memmg::managed_array<T> res_dev{num_outputs, &resource};
reduce_products<T>(res_dev, stream, products);
products.reset();
basl::info("completed {} reductions", num_outputs);

// copy result
basdv::async_copy_device_to_host(res, res_dev, stream);
co_await xendv::await_stream(stream);
basl::info("complete multiexponentiation");
}

//--------------------------------------------------------------------------------------------------
// multiexponentiate_impl
//--------------------------------------------------------------------------------------------------
template <bascrv::element T, class U>
requires std::constructible_from<T, U>
xena::future<> multiexponentiate_impl(basct::span<T> res,
const partition_table_accessor<U>& accessor,
unsigned element_num_bytes, basct::cspan<uint8_t> scalars,
const multiexponentiate_options& options) noexcept {
xena::future<>
multiexponentiate_impl(basct::span<T> res, const partition_table_accessor<U>& accessor,
basct::cspan<unsigned> output_bit_table, basct::cspan<uint8_t> scalars,
const multiexponentiate_options& options) noexcept {
auto num_outputs = res.size();
auto n = scalars.size() / (num_outputs * element_num_bytes);
auto num_products = num_outputs * element_num_bytes * 8u;
auto num_products = std::accumulate(output_bit_table.begin(), output_bit_table.end(), 0u);
auto num_output_bytes = basn::divide_up<size_t>(num_products, 8);
SXT_DEBUG_ASSERT(
// clang-format off
scalars.size() % (num_outputs * element_num_bytes) == 0
scalars.size() % num_output_bytes == 0
// clang-format on
);
basdv::stream stream;
memr::async_device_resource resource{stream};
memmg::managed_array<T> products{num_products, &resource};
co_await multiexponentiate_product_step<T>(products, stream, accessor, num_products,
num_output_bytes, scalars, options);

// compute bitwise products
//
// We split the work by groups of generators so that a single chunk will process
// all the outputs for those generators. This minimizes the amount of host->device
// copying we need to do for the table of precomputed sums.
auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, n}
.chunk_multiple(16)
.min_chunk_size(options.min_chunk_size)
.max_chunk_size(options.max_chunk_size),
options.split_factor);
auto num_chunks = std::distance(chunk_first, chunk_last);
if (num_chunks == 1) {
multiexponentiate_no_chunks(res, accessor, element_num_bytes, scalars);
co_return;
}

memmg::managed_array<T> products{num_products * num_chunks, memr::get_pinned_resource()};
size_t chunk_index = 0;
basl::info("computing {} bitwise multiexponentiation products of length {} using {} chunks",
num_products, n, num_chunks);
co_await xendv::concurrent_for_each(
chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> {
basl::info("computing {} multiproducts for generators [{}, {}] on device {}", num_products,
rng.a(), rng.b(), basdv::get_device());
memmg::managed_array<T> products_dev{num_products, memr::get_device_resource()};
auto scalars_slice = scalars.subspan(num_outputs * element_num_bytes * rng.a(),
rng.size() * num_outputs * element_num_bytes);
co_await async_partition_product<T>(products_dev, accessor, scalars_slice, rng.a());
basdv::stream stream;
basdv::async_copy_device_to_host(
basct::subspan(products, num_products * chunk_index, num_products), products_dev,
stream);
++chunk_index;
co_await xendv::await_stream(stream);
});

// complete the multi-exponentiation
basl::info("reducing products for {} outputs", num_outputs);
co_await complete_multiexponentiation<T>(res, element_num_bytes, products);
// reduce products
basl::info("reducing {} products to {} outputs", num_products, num_products);
memmg::managed_array<T> res_dev{num_outputs, &resource};
reduce_products<T>(res_dev, stream, output_bit_table, products);
products.reset();
basl::info("completed {} reductions", num_outputs);

// copy result
basdv::async_copy_device_to_host(res, res_dev, stream);
co_await xendv::await_stream(stream);
basl::info("complete multiexponentiation");
}

//--------------------------------------------------------------------------------------------------
Expand All @@ -196,6 +200,17 @@ async_multiexponentiate(basct::span<T> res, const partition_table_accessor<U>& a
return multiexponentiate_impl(res, accessor, element_num_bytes, scalars, options);
}

// Packed-scalar entry point for asynchronous multiexponentiation.
//
// Builds a default multiexponentiate_options, sets its split factor to the
// number of devices reported by basdv::get_num_devices(), and forwards
// everything to the output_bit_table overload of multiexponentiate_impl.
//
// res              receives one result per output
// accessor         table of precomputed partition sums
// output_bit_table forwarded as-is; the accompanying tests use entry i as the
//                  bit width of output i
// scalars          packed scalar bytes, forwarded as-is
//
// NOTE(review): opts is a local handed to the implementation by const
// reference; presumably it is only read before the implementation's first
// suspension point -- confirm.
template <bascrv::element T, class U>
  requires std::constructible_from<T, U>
xena::future<> async_multiexponentiate(basct::span<T> res,
                                       const partition_table_accessor<U>& accessor,
                                       basct::cspan<unsigned> output_bit_table,
                                       basct::cspan<uint8_t> scalars) noexcept {
  multiexponentiate_options opts;
  const auto num_devices = basdv::get_num_devices();
  opts.split_factor = static_cast<unsigned>(num_devices);
  return multiexponentiate_impl(res, accessor, output_bit_table, scalars, opts);
}

//--------------------------------------------------------------------------------------------------
// multiexponentiate
//--------------------------------------------------------------------------------------------------
Expand Down
58 changes: 58 additions & 0 deletions sxt/multiexp/pippenger2/multiexponentiation.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,64 @@ TEST_CASE("we can compute multiexponentiations using a precomputed table of part
}
}

// Covers the output_bit_table overload of async_multiexponentiate using
// element97 generators and a handful of small packed scalar inputs.
TEST_CASE("we can compute multiexponentiations with packed scalars") {
  using E = bascrv::element97;

  // Fixed seed so the generator values (each drawn from [0, 96]) are
  // reproducible from run to run.
  std::mt19937 prng{0};
  std::vector<E> generators(32);
  for (auto& generator : generators) {
    generator = std::uniform_int_distribution<unsigned>{0, 96}(prng);
  }

  auto accessor = make_in_memory_partition_table_accessor<E>(generators);

  // Defaults: a single output, one scalar byte (zero), one table entry.
  std::vector<unsigned> output_bit_table(1);
  std::vector<uint8_t> scalars(1);
  std::vector<E> res(1);

  // Starts the computation with whatever the current section has set up.
  auto launch = [&] {
    return async_multiexponentiate<E>(res, *accessor, output_bit_table, scalars);
  };

  SECTION("we can compute a multiexponentiation for a single bit scalar") {
    output_bit_table[0] = 1;
    auto fut = launch();
    xens::get_scheduler().run();
    REQUIRE(fut.ready());
    REQUIRE(res[0] == E::identity());
  }

  SECTION("we can compute a multiexponentiation for a two bit scalar") {
    scalars[0] = 2u;
    output_bit_table[0] = 2;
    auto fut = launch();
    xens::get_scheduler().run();
    REQUIRE(fut.ready());
    REQUIRE(res[0] == 2u * generators[0].value);
  }

  SECTION("we can compute a multiexponentiation with multiple outputs of varying bit sizes") {
    res.resize(3);
    output_bit_table = {2, 1, 3};
    scalars = {0b110011};
    auto fut = launch();
    xens::get_scheduler().run();
    REQUIRE(fut.ready());
    const auto g0 = generators[0].value;
    REQUIRE(res[0] == 3u * g0);
    REQUIRE(res[1] == 0u);
    REQUIRE(res[2] == 6u * g0);
  }

  SECTION("we can compute a multiexponentiation with multiple outputs of varying bit sizes and "
          "length 2") {
    res.resize(3);
    output_bit_table = {2, 1, 3};
    scalars = {0b110011, 0b101101};
    auto fut = launch();
    xens::get_scheduler().run();
    REQUIRE(fut.ready());
    const auto g0 = generators[0].value;
    const auto g1 = generators[1].value;
    REQUIRE(res[0] == 3u * g0 + g1);
    REQUIRE(res[1] == g1);
    REQUIRE(res[2] == 6u * g0 + 5u * g1);
  }
}

TEST_CASE("we can compute multiexponentiations with curve-21") {
using E = c21t::element_p3;
using Ep = c21t::compact_element;
Expand Down
Loading