diff --git a/.vscode/setting.json b/.vscode/setting.json new file mode 100644 index 00000000..ac4dec91 --- /dev/null +++ b/.vscode/setting.json @@ -0,0 +1,11 @@ +{ + "editor.formatOnSave": true, + "files.watcherExclude": { + "**/.git/**": true, + "**/.cache/**": true, + "**/bazel-*/**": true, + "**/external/**": true + }, + "files.insertFinalNewline": true, + "files.trimTrailingWhitespace": true +} \ No newline at end of file diff --git a/psi/algorithm/BUILD.bazel b/psi/algorithm/BUILD.bazel index 49bf6a00..e32c25af 100644 --- a/psi/algorithm/BUILD.bazel +++ b/psi/algorithm/BUILD.bazel @@ -29,6 +29,7 @@ psi_cc_library( hdrs = ["psi_io.h"], deps = [ ":types", + "//psi/utils:hash_bucket_cache", ], ) diff --git a/psi/algorithm/psi_io.h b/psi/algorithm/psi_io.h index bfb0d82c..7c44926e 100644 --- a/psi/algorithm/psi_io.h +++ b/psi/algorithm/psi_io.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -21,6 +22,7 @@ #include "yacl/base/int128.h" #include "psi/algorithm/types.h" +#include "psi/utils/hash_bucket_cache.h" namespace psi { @@ -57,13 +59,27 @@ class IDataProvider { public: virtual std::vector ReadNext(size_t size) = 0; virtual std::vector ReadAll() = 0; + [[nodiscard]] virtual size_t Size() const = 0; }; class IDataStore { public: [[nodiscard]] virtual size_t GetBucketNum() const = 0; + [[nodiscard]] virtual size_t GetBucketDatasize(size_t tag) const = 0; virtual std::shared_ptr Load(size_t tag) = 0; virtual ~IDataStore() = default; }; +class IBucketDataStore { + public: + virtual std::vector GetBucketItems( + size_t bucket_idx) = 0; + virtual void WriteIntersetionItems( + size_t bucket_idx, const std::vector& items, + const std::vector& intersection_indices, + const std::vector& peer_dup_cnts) = 0; + virtual std::pair GetBucketDatasize(size_t bucket_idx) = 0; + virtual ~IBucketDataStore() = default; +}; + } // namespace psi diff --git a/psi/algorithm/rr22/BUILD.bazel b/psi/algorithm/rr22/BUILD.bazel index 71d54a16..d7e4eab5 100644 --- a/psi/algorithm/rr22/BUILD.bazel +++ b/psi/algorithm/rr22/BUILD.bazel @@ -102,6 +102,7 @@ psi_cc_library( ":rr22_utils", "//psi/proto:psi_v2_cc_proto", "//psi/utils:bucket", + "//psi/utils:simple_channel", "//psi/utils:sync", ], ) diff --git a/psi/algorithm/rr22/common.cc b/psi/algorithm/rr22/common.cc index 8f7fff6b..13f40c0d 100644 --- a/psi/algorithm/rr22/common.cc +++ b/psi/algorithm/rr22/common.cc @@ -16,6 +16,7 @@ #include "omp.h" +#include "psi/algorithm/rr22/rr22_psi.h" #include "psi/utils/bucket.h" namespace psi::rr22 { @@ -28,4 +29,52 @@ Rr22PsiOptions GenerateRr22PsiOptions(bool low_comm_mode) { return options; } +BucketDataStoreImpl::BucketDataStoreImpl( + std::shared_ptr lctx, + HashBucketCache* input_bucket_store, + IndexWriter* intersection_indices_writer, RecoveryManager* recovery_manager) + : IBucketDataStore(), + input_bucket_store_(input_bucket_store), + intersection_indices_writer_(intersection_indices_writer), + recovery_manager_(recovery_manager), + lctx_(std::move(lctx)) { + self_sizes_ = std::vector(input_bucket_store_->BucketNum()); + for (size_t i = 0; i < self_sizes_.size(); i++) { + self_sizes_[i] = input_bucket_store_->GetBucketSize(i); + } + peer_sizes_ = std::vector(input_bucket_store_->BucketNum()); + yacl::ByteContainerView buffer(self_sizes_.data(), + self_sizes_.size() * sizeof(uint32_t)); + auto data = yacl::link::AllGather(lctx_, buffer, "exchange size"); + std::memcpy(peer_sizes_.data(), data[lctx_->NextRank()].data(), + self_sizes_.size() * sizeof(uint32_t)); +} + +std::vector BucketDataStoreImpl::GetBucketItems( + size_t bucket_idx) { + if (bucket_idx >= input_bucket_store_->BucketNum()) { + return {}; + } + return input_bucket_store_->LoadBucketItems(bucket_idx); +} + +void BucketDataStoreImpl::WriteIntersetionItems( + size_t bucket_idx, const std::vector& items, + const std::vector& intersection_indices, + const std::vector& peer_dup_cnts) { + for (size_t i = 0; i != intersection_indices.size(); ++i) { + intersection_indices_writer_->WriteCache( + items[intersection_indices[i]].index, peer_dup_cnts[i]); + } + intersection_indices_writer_->Commit(); + if (recovery_manager_ != nullptr) { + recovery_manager_->UpdateParsedBucketCount(bucket_idx + 1); + } +} + +std::pair BucketDataStoreImpl::GetBucketDatasize( + size_t bucket_idx) { + return std::make_pair(self_sizes_[bucket_idx], peer_sizes_[bucket_idx]); +} + } // namespace psi::rr22 diff --git a/psi/algorithm/rr22/common.h b/psi/algorithm/rr22/common.h index 0a70d093..a14f5634 100644 --- a/psi/algorithm/rr22/common.h +++ b/psi/algorithm/rr22/common.h @@ -31,4 +31,30 @@ constexpr bool kDefaultCompress = true; Rr22PsiOptions GenerateRr22PsiOptions(bool low_comm_mode); +class BucketDataStoreImpl : public IBucketDataStore { + public: + BucketDataStoreImpl(std::shared_ptr lctx, + HashBucketCache* input_bucket_store, + IndexWriter* intersection_indices_writer, + RecoveryManager* recovery_manager); + + std::vector GetBucketItems( + size_t bucket_idx) override; + + void WriteIntersetionItems( + size_t bucket_idx, const std::vector& items, + const std::vector& intersection_indices, + const std::vector& peer_dup_cnts) override; + + std::pair GetBucketDatasize(size_t bucket_idx) override; + + private: + HashBucketCache* input_bucket_store_; + IndexWriter* intersection_indices_writer_; + RecoveryManager* recovery_manager_; + std::shared_ptr lctx_; + std::vector peer_sizes_; + std::vector self_sizes_; +}; + } // namespace psi::rr22 diff --git a/psi/algorithm/rr22/receiver.cc b/psi/algorithm/rr22/receiver.cc index 7b016037..dfa97e8d 100644 --- a/psi/algorithm/rr22/receiver.cc +++ b/psi/algorithm/rr22/receiver.cc @@ -112,32 +112,17 @@ void Rr22PsiReceiver::Online() { Rr22PsiOptions rr22_options = GenerateRr22PsiOptions( config_.protocol_config().rr22_config().low_comm_mode()); - PreProcessFunc pre_f = - [&](size_t idx) -> std::vector { - if (idx >= input_bucket_store_->BucketNum()) { - return {}; - } - return input_bucket_store_->LoadBucketItems(idx); - }; - PostProcessFunc post_f = - [&](size_t bucket_idx, - const std::vector& bucket_items, - const std::vector& indices, - const std::vector& peer_cnt) { - for (size_t i = 0; i != indices.size(); ++i) { - intersection_indices_writer_->WriteCache( - bucket_items[indices[i]].index, peer_cnt[i]); - } - intersection_indices_writer_->Commit(); - if (recovery_manager_) { - recovery_manager_->UpdateParsedBucketCount(bucket_idx + 1); - } - }; - + BucketDataStoreImpl data_processor(lctx_, input_bucket_store_.get(), + intersection_indices_writer_.get(), + recovery_manager_.get()); Rr22Runner runner(lctx_, rr22_options, input_bucket_store_->BucketNum(), - config_.protocol_config().broadcast_result(), pre_f, - post_f); - SyncWait(lctx_, [&] { runner.AsyncRun(bucket_idx, false); }); + config_.protocol_config().broadcast_result(), + &data_processor); + auto scoped_temp_dir = std::make_unique(); + scoped_temp_dir->CreateUniqueTempDirUnderPath(GetTaskDir()); + SyncWait(lctx_, [&] { + runner.AsyncRun(bucket_idx, false, scoped_temp_dir->path()); + }); SPDLOG_INFO("[Rr22PsiReceiver::Online] end"); } diff --git a/psi/algorithm/rr22/rr22_operator.cc b/psi/algorithm/rr22/rr22_operator.cc index 16ef3b03..b26e3251 100644 --- a/psi/algorithm/rr22/rr22_operator.cc +++ b/psi/algorithm/rr22/rr22_operator.cc @@ -14,26 +14,46 @@ #include "psi/algorithm/rr22/rr22_operator.h" +#include +#include +#include #include +#include +#include -namespace psi::rr22 { +#include "psi/algorithm/rr22/rr22_psi.h" -Rr22Operator::Rr22Operator(Options opts, - std::shared_ptr input_store, - std::shared_ptr output_store) - : PsiOperator(opts.lctx, std::move(input_store), std::move(output_store), - opts.recovery_manager, opts.broadcast_result), - opts_(std::move(opts)) {} +namespace psi::rr22 { -bool Rr22Operator::ReceiveResult() { - return opts_.receiver_rank == opts_.lctx->Rank() || opts_.broadcast_result; -} +namespace { -void Rr22Operator::OnInit() {} +// TODO: refactor to reduce duplicated code +class BucketDataStoreImpl : public IBucketDataStore { + public: + BucketDataStoreImpl(std::shared_ptr lctx, + IDataStore* input_store, IResultStore* output_store, + RecoveryManager* recovery_manager) + : input_store_(input_store), + output_store_(output_store), + recovery_manager_(recovery_manager), + lctx_(std::move(lctx)) { + self_sizes_ = std::vector(input_store_->GetBucketNum()); + for (size_t i = 0; i < self_sizes_.size(); i++) { + self_sizes_[i] = input_store_->GetBucketDatasize(i); + } + peer_sizes_ = std::vector(input_store_->GetBucketNum()); + yacl::ByteContainerView buffer(self_sizes_.data(), + self_sizes_.size() * sizeof(uint32_t)); + auto data = yacl::link::AllGather(lctx_, buffer, "exchange size"); + std::memcpy(peer_sizes_.data(), data[lctx_->NextRank()].data(), + self_sizes_.size() * sizeof(uint32_t)); + }; -void Rr22Operator::OnRun() { - PreProcessFunc pre_process_func = - [this](size_t bucket_idx) -> std::vector { + std::vector GetBucketItems( + size_t bucket_idx) override { + if (bucket_idx >= input_store_->GetBucketNum()) { + return {}; + } std::vector bucket_items; auto provider = input_store_->Load(bucket_idx); auto item_datas = provider->ReadAll(); @@ -46,22 +66,57 @@ void Rr22Operator::OnRun() { } return bucket_items; }; - PostProcessFunc post_process_func = - [this](size_t bucket_idx, - const std::vector& bucket_items, - const std::vector& intersection_indices, - const std::vector& peer_dup_cnts) { - std::vector indices; - indices.reserve(intersection_indices.size()); - - for (size_t i = 0; i < intersection_indices.size(); ++i) { - indices.emplace_back(PsiResultIndex{ - .data = bucket_items[intersection_indices[i]].index, - .peer_item_cnt = peer_dup_cnts[i] + 1}); - } - auto recevier = output_store_->GetReceiver(bucket_idx); - recevier->Add(std::move(indices)); - }; + + void WriteIntersetionItems( + size_t bucket_idx, const std::vector& items, + const std::vector& intersection_indices, + const std::vector& peer_dup_cnts) override { + std::vector indices; + indices.reserve(intersection_indices.size()); + + for (size_t i = 0; i < intersection_indices.size(); ++i) { + indices.emplace_back( + PsiResultIndex{.data = items[intersection_indices[i]].index, + .peer_item_cnt = peer_dup_cnts[i] + 1}); + } + auto recevier = output_store_->GetReceiver(bucket_idx); + recevier->Add(std::move(indices)); + if (recovery_manager_ != nullptr) { + recovery_manager_->UpdateParsedBucketCount(bucket_idx + 1); + } + }; + + std::pair GetBucketDatasize(size_t bucket_idx) override { + return std::make_pair(self_sizes_[bucket_idx], peer_sizes_[bucket_idx]); + }; + + private: + IDataStore* input_store_; + IResultStore* output_store_; + RecoveryManager* recovery_manager_; + std::shared_ptr lctx_; + std::vector peer_sizes_; + std::vector self_sizes_; +}; +} // namespace + +Rr22Operator::Rr22Operator(Options opts, + std::shared_ptr input_store, + std::shared_ptr output_store) + : PsiOperator(opts.lctx, std::move(input_store), std::move(output_store), + opts.recovery_manager, opts.broadcast_result), + opts_(std::move(opts)) {} + +bool Rr22Operator::ReceiveResult() { + return opts_.receiver_rank == opts_.lctx->Rank() || opts_.broadcast_result; +} + +void Rr22Operator::OnInit() {} + +void Rr22Operator::OnRun() { + BucketDataStoreImpl data_processor(link_ctx_, input_store_.get(), + output_store_.get(), + recovery_manager_.get()); size_t bucket_idx = recovery_manager_ @@ -70,11 +125,11 @@ void Rr22Operator::OnRun() { : 0; Rr22Runner runner(link_ctx_, opts_.rr22_opts, input_store_->GetBucketNum(), - opts_.broadcast_result, pre_process_func, - post_process_func); + opts_.broadcast_result, &data_processor); if (opts_.pipeline_mode) { - runner.AsyncRun(bucket_idx, opts_.lctx->Rank() != opts_.receiver_rank); + runner.AsyncRun(bucket_idx, opts_.lctx->Rank() != opts_.receiver_rank, + opts_.cache_dir); } else { runner.ParallelRun(bucket_idx, opts_.lctx->Rank() != opts_.receiver_rank, opts_.parallel_level); diff --git a/psi/algorithm/rr22/rr22_operator.h b/psi/algorithm/rr22/rr22_operator.h index 13249ee3..22d68565 100644 --- a/psi/algorithm/rr22/rr22_operator.h +++ b/psi/algorithm/rr22/rr22_operator.h @@ -31,6 +31,8 @@ class Rr22Operator : public PsiOperator { size_t parallel_level = 6; std::shared_ptr recovery_manager = nullptr; + std::string cache_dir = + std::filesystem::temp_directory_path() / GetRandomString(); }; public: diff --git a/psi/algorithm/rr22/rr22_oprf.cc b/psi/algorithm/rr22/rr22_oprf.cc index 9636447b..578fa36a 100644 --- a/psi/algorithm/rr22/rr22_oprf.cc +++ b/psi/algorithm/rr22/rr22_oprf.cc @@ -16,8 +16,11 @@ #include #include +#include #include +#include #include +#include #include #include "spdlog/spdlog.h" @@ -32,6 +35,7 @@ #include "psi/algorithm/rr22/davis_meyer_hash.h" #include "psi/algorithm/rr22/okvs/galois128.h" #include "psi/algorithm/rr22/rr22_utils.h" +#include "psi/utils/io.h" namespace psi::rr22 { @@ -132,8 +136,16 @@ std::vector Rr22OprfSender::Send( } void Rr22OprfSender::Init(const std::shared_ptr& lctx, - size_t peer_size, size_t num_threads) { + size_t peer_size, size_t num_threads, bool cache_vole, + const std::filesystem::path& cache_dir) { + cache_vole_ = cache_vole; num_threads_ = num_threads; + + if (cache_vole) { + scoped_temp_dir_ = std::make_unique(); + YACL_ENFORCE(scoped_temp_dir_->CreateUniqueTempDirUnderPath(cache_dir)); + } + if (mode_ == Rr22PsiMode::FastMode) { uint128_t baxos_seed; SPDLOG_INFO("recv baxos seed..."); @@ -160,7 +172,6 @@ void Rr22OprfSender::Init(const std::shared_ptr& lctx, vole_sender.Send(lctx, b128_span); delta_ = vole_sender.GetDelta(); - SPDLOG_INFO("end vole send"); } else if (mode_ == Rr22PsiMode::LowCommMode) { uint128_t paxos_seed; @@ -188,16 +199,24 @@ void Rr22OprfSender::Init(const std::shared_ptr& lctx, vole_sender.SfSend(lctx, b128_span); delta_ = vole_sender.GetDelta(); - SPDLOG_INFO("end vole send"); } else { YACL_THROW("unsupported mode:{}", int(mode_)); } + if (cache_vole_) { + v_b_ = std::make_shared(scoped_temp_dir_->path() / "v_b"); + v_b_->WriteVector(b_); + b_.clear(); + b_.shrink_to_fit(); + } } std::vector Rr22OprfSender::SendFast( const std::shared_ptr& lctx, const absl::Span& inputs) { + if (cache_vole_) { + b_ = v_b_->ReadVector(); + } uint128_t ws = 0; if (malicious_) { SPDLOG_INFO("malicious version"); @@ -226,7 +245,11 @@ std::vector Rr22OprfSender::SendFast( } SPDLOG_INFO("recv paxos solve ..."); - auto paxos_solve_v = RecvChunked(lctx, paxos_size_); + std::vector paxos_solve_v(paxos_size_, 0); + yacl::Buffer paxos_solve_buf = + lctx->Recv(lctx->NextRank(), fmt::format("recv paxos_solve")); + std::memcpy(paxos_solve_v.data(), paxos_solve_buf.data(), + paxos_solve_buf.size()); SPDLOG_INFO("recv paxos solve finished. bytes:{}", paxos_solve_v.size() * sizeof(uint128_t)); @@ -248,10 +271,17 @@ std::vector Rr22OprfSender::SendFast( std::vector Rr22OprfSender::SendLowComm( const std::shared_ptr& lctx, const absl::Span& inputs) { + if (cache_vole_) { + b_ = v_b_->ReadVector(); + } auto hash_outputs = HashInputMulDelta(inputs); SPDLOG_INFO("recv paxos solve ..."); - auto paxos_solve_v = RecvChunked(lctx, paxos_size_); + std::vector paxos_solve_v(paxos_size_, 0); + yacl::Buffer paxos_solve_buf = + lctx->Recv(lctx->NextRank(), fmt::format("recv paxos_solve")); + std::memcpy(paxos_solve_v.data(), paxos_solve_buf.data(), + paxos_solve_buf.size()); SPDLOG_INFO("recv paxos solve finished. bytes:{}", paxos_solve_v.size() * sizeof(uint64_t)); @@ -321,7 +351,7 @@ std::vector Rr22OprfSender::Eval( } SPDLOG_INFO("paxos decode finished"); b_.clear(); - + b_.shrink_to_fit(); okvs::AesCrHash aes_crhash(kAesHashSeed); okvs::Galois128 delta_gf128(delta_); @@ -376,6 +406,7 @@ std::vector Rr22OprfSender::Eval( } SPDLOG_INFO("paxos decode finished"); b_.clear(); + b_.shrink_to_fit(); yacl::parallel_for(0, inputs.size(), [&](int64_t begin, int64_t end) { for (int64_t idx = begin; idx < end; ++idx) { @@ -398,8 +429,16 @@ std::vector Rr22OprfSender::Eval( } void Rr22OprfReceiver::Init(const std::shared_ptr& lctx, - size_t self_size, size_t num_threads) { + size_t self_size, size_t num_threads, + bool cache_vole, + const std::filesystem::path& cache_dir) { + cache_vole_ = cache_vole; num_threads_ = num_threads; + if (cache_vole) { + scoped_temp_dir_ = std::make_unique(); + YACL_ENFORCE(scoped_temp_dir_->CreateUniqueTempDirUnderPath(cache_dir)); + } + if (mode_ == Rr22PsiMode::FastMode) { uint128_t baxos_seed = yacl::crypto::SecureRandU128(); yacl::ByteContainerView paxos_seed_buf(&baxos_seed, sizeof(uint128_t)); @@ -416,10 +455,15 @@ void Rr22OprfReceiver::Init(const std::shared_ptr& lctx, size_t v_size = std::max(256, baxos_.size()); a_ = std::vector(v_size, 0); c_ = std::vector(v_size, 0); - SPDLOG_INFO("begin vole recv"); vole_receiver.Recv(lctx, absl::MakeSpan(a_), absl::MakeSpan(c_)); SPDLOG_INFO("end vole recv"); + if (cache_vole_) { + v_a_ = std::make_shared(scoped_temp_dir_->path() / "v_a"); + v_a_->WriteVector(a_); + a_.clear(); + a_.shrink_to_fit(); + } } else if (mode_ == Rr22PsiMode::LowCommMode) { uint128_t paxos_seed = yacl::crypto::SecureRandU128(); yacl::ByteContainerView paxos_seed_buf(&paxos_seed, sizeof(uint128_t)); @@ -438,14 +482,25 @@ void Rr22OprfReceiver::Init(const std::shared_ptr& lctx, size_t v_size = std::max(256, paxos_.size()); a64_ = std::vector(v_size, 0); c_ = std::vector(v_size, 0); - SPDLOG_INFO("begin vole recv"); vole_receiver.SfRecv(lctx, absl::MakeSpan(a64_), absl::MakeSpan(c_)); SPDLOG_INFO("end vole recv"); - + if (cache_vole_) { + v_a64_ = + std::make_shared(scoped_temp_dir_->path() / "v_a64"); + v_a64_->WriteVector(a64_); + a64_.clear(); + a64_.shrink_to_fit(); + } } else { YACL_THROW("unsupported mode:{}", int(mode_)); } + if (cache_vole_) { + v_c_ = std::make_shared(scoped_temp_dir_->path() / "v_c"); + v_c_->WriteVector(c_); + c_.clear(); + c_.shrink_to_fit(); + } } std::vector Rr22OprfReceiver::Recv( @@ -461,6 +516,10 @@ std::vector Rr22OprfReceiver::Recv( std::vector Rr22OprfReceiver::RecvFast( const std::shared_ptr& lctx, const absl::Span& inputs) { + if (cache_vole_) { + a_ = v_a_->ReadVector(); + c_ = v_c_->ReadVector(); + } uint128_t w = 0; uint128_t wr = 0; yacl::Buffer ws_hash_buf; @@ -471,7 +530,6 @@ std::vector Rr22OprfReceiver::RecvFast( ws_hash_buf = lctx->Recv(lctx->NextRank(), fmt::format("recv ws_hash")); YACL_ENFORCE(ws_hash_buf.size() == 32); } - okvs::AesCrHash aes_crhash(kAesHashSeed); std::vector outputs(inputs.size(), 0); auto outputs_span = absl::MakeSpan(outputs); @@ -481,7 +539,6 @@ std::vector Rr22OprfReceiver::RecvFast( auto p128_span = absl::MakeSpan(p128_v); baxos_.Solve(inputs, outputs_span, p128_span, nullptr, num_threads_); SPDLOG_INFO("solve end"); - if (malicious_) { // send wr lctx->SendAsyncThrottled(lctx->NextRank(), @@ -509,6 +566,7 @@ std::vector Rr22OprfReceiver::RecvFast( baxos_.Decode(inputs, outputs_span, absl::MakeSpan(c_.data(), baxos_.size()), num_threads_); c_.clear(); + c_.shrink_to_fit(); if (malicious_) { for (size_t i = 0; i < outputs.size(); ++i) { outputs[i] = outputs[i] ^ w; @@ -533,10 +591,12 @@ std::vector Rr22OprfReceiver::RecvFast( } }); a_.clear(); - + a_.shrink_to_fit(); SPDLOG_INFO("end p xor a"); - - SendChunked(lctx, p128_span); + yacl::ByteContainerView buffer(p128_span.data(), + p128_span.size() * sizeof(uint128_t)); + lctx->SendAsyncThrottled(lctx->NextRank(), buffer, + "send paxos_solve_byteview"); oprf_eval_proc.get(); return outputs; } @@ -544,6 +604,10 @@ std::vector Rr22OprfReceiver::RecvFast( std::vector Rr22OprfReceiver::RecvLowComm( const std::shared_ptr& lctx, const absl::Span& inputs) { + if (cache_vole_) { + a64_ = v_a64_->ReadVector(); + c_ = v_c_->ReadVector(); + } okvs::AesCrHash aes_crhash(kAesHashSeed); std::vector outputs(inputs.size()); auto outputs_span = absl::MakeSpan(outputs); @@ -572,6 +636,7 @@ std::vector Rr22OprfReceiver::RecvLowComm( paxos_.Decode(inputs, outputs_span, absl::MakeSpan(c_.data(), paxos_.size())); c_.clear(); + c_.shrink_to_fit(); // oprf end output aes_crhash.Hash(outputs_span, outputs_span); SPDLOG_INFO("end receiver oprf"); @@ -583,9 +648,13 @@ std::vector Rr22OprfReceiver::RecvLowComm( p64_span[i] ^= a64_[i]; } a64_.clear(); + a64_.shrink_to_fit(); SPDLOG_INFO("end p xor a"); - SendChunked(lctx, p64_span); + yacl::ByteContainerView buffer(p64_span.data(), + p64_span.size() * sizeof(uint64_t)); + lctx->SendAsyncThrottled(lctx->NextRank(), buffer, + "send paxos_solve_byteview"); oprf_eval_proc.get(); return outputs; } diff --git a/psi/algorithm/rr22/rr22_oprf.h b/psi/algorithm/rr22/rr22_oprf.h index 1c766212..f908e064 100644 --- a/psi/algorithm/rr22/rr22_oprf.h +++ b/psi/algorithm/rr22/rr22_oprf.h @@ -15,7 +15,9 @@ #pragma once #include +#include #include +#include #include #include "yacl/base/int128.h" @@ -23,6 +25,9 @@ #include "yacl/link/context.h" #include "psi/algorithm/rr22/okvs/baxos.h" +#include "psi/utils/io.h" +#include "psi/utils/multiplex_disk_cache.h" +#include "psi/utils/random_str.h" // Reference: // Blazing Fast PSI from Improved OKVS and Subfield VOLE @@ -39,6 +44,33 @@ enum class Rr22PsiMode { LowCommMode, }; +// can only write one vector +class VectorCache { + public: + explicit VectorCache(const std::string& file_name) { + file_options_.file_name = file_name; + } + template + void WriteVector(const std::vector& v) { + auto output_stream = io::BuildOutputStream(file_options_); + size_in_bytes_ = v.size() * sizeof(T); + output_stream->Write(v.data(), size_in_bytes_); + output_stream->Close(); + } + template + std::vector ReadVector() { + auto input_stream = io::BuildInputStream(file_options_); + YACL_ENFORCE(size_in_bytes_ % sizeof(T) == 0, + "Size mismatch in VectorCache ReadVector"); + std::vector v(size_in_bytes_ / sizeof(T)); + input_stream->Read(v.data(), size_in_bytes_); + input_stream->Close(); + return v; + } + io::FileIoOptions file_options_; + size_t size_in_bytes_ = 0; +}; + class MocRr22VoleSender { public: explicit MocRr22VoleSender(uint128_t seed); @@ -116,6 +148,8 @@ class Rr22Oprf { bool debug_ = false; size_t paxos_size_ = 0; + + bool cache_vole_ = false; }; class Rr22OprfSender : public Rr22Oprf { @@ -129,8 +163,10 @@ class Rr22OprfSender : public Rr22Oprf { YACL_THROW("RR22 malicious psi not support LowCommMode"); } } - void Init(const std::shared_ptr& lctx, size_t init_size, - size_t num_threads = 0); + void Init(const std::shared_ptr& lctx, size_t peer_size, + size_t num_threads = 0, bool cache_vole = false, + const std::filesystem::path& cache_dir = + std::filesystem::temp_directory_path() / GetRandomString()); std::vector Send(const std::shared_ptr& lctx, const absl::Span& inputs); @@ -159,6 +195,8 @@ class Rr22OprfSender : public Rr22Oprf { // b = delta * a + c uint128_t delta_ = 0; std::vector b_; + std::shared_ptr v_b_; + std::unique_ptr scoped_temp_dir_; }; class Rr22OprfReceiver : public Rr22Oprf { @@ -173,8 +211,10 @@ class Rr22OprfReceiver : public Rr22Oprf { } } - void Init(const std::shared_ptr& lctx, size_t init_size, - size_t num_threads = 0); + void Init(const std::shared_ptr& lctx, size_t self_size, + size_t num_threads = 0, bool cache_vole = false, + const std::filesystem::path& cache_dir = + std::filesystem::temp_directory_path() / GetRandomString()); std::vector Recv(const std::shared_ptr& lctx, const absl::Span& inputs); @@ -196,6 +236,10 @@ class Rr22OprfReceiver : public Rr22Oprf { // low comm use int64 std::vector a64_; std::vector c_; + std::shared_ptr v_a_; + std::shared_ptr v_a64_; + std::shared_ptr v_c_; + std::unique_ptr scoped_temp_dir_; }; } // namespace psi::rr22 diff --git a/psi/algorithm/rr22/rr22_psi.cc b/psi/algorithm/rr22/rr22_psi.cc index 0b2ead04..efd6611f 100644 --- a/psi/algorithm/rr22/rr22_psi.cc +++ b/psi/algorithm/rr22/rr22_psi.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -25,6 +26,7 @@ #include #include +#include "sparsehash/dense_hash_map" #include "yacl/base/byte_container_view.h" #include "yacl/utils/parallel.h" @@ -32,12 +34,26 @@ #include "psi/algorithm/rr22/rr22_oprf.h" #include "psi/algorithm/rr22/rr22_utils.h" #include "psi/utils/bucket.h" +#include "psi/utils/serialize.h" #include "psi/utils/sync.h" namespace psi::rr22 { namespace { +struct NoHash { + inline size_t operator()(const uint128_t& v) const { + uint32_t v32; + std::memcpy(&v32, &v, sizeof(uint32_t)); + + return v32; + } +}; + +} // namespace + +namespace { + size_t ComputeTruncateSize(size_t self_size, size_t peer_size, size_t ssp, bool malicious) { size_t truncate_size = @@ -53,37 +69,38 @@ size_t ComputeTruncateSize(size_t self_size, size_t peer_size, size_t ssp, } // namespace -std::pair ExchangeTruncateSize( - const std::shared_ptr& lctx, size_t self_size, - const Rr22PsiOptions& options) { - // Gather Items Size - std::vector items_size = AllGatherItemsSize(lctx, self_size); - - YACL_ENFORCE(self_size == items_size[lctx->Rank()]); - size_t peer_size = items_size[lctx->NextRank()]; +size_t ComputeMaskSize(const Rr22PsiOptions& options, size_t self_size, + size_t peer_size) { size_t mask_size = sizeof(uint128_t); if (options.compress) { mask_size = ComputeTruncateSize(self_size, peer_size, options.ssp, options.malicious); } - return {mask_size, peer_size}; + return mask_size; } -void BucketRr22Sender::Prepare( - const std::shared_ptr& lctx) { - bucket_items_ = pre_f_(bucket_idx_); - self_size_ = bucket_items_.size(); - std::mt19937 g(yacl::crypto::SecureRandU64()); - std::shuffle(bucket_items_.begin(), bucket_items_.end(), g); - - self_size_ = bucket_items_.size(); - std::tie(mask_size_, peer_size_) = - ExchangeTruncateSize(lctx, bucket_items_.size(), rr22_options_); +void BucketRr22Sender::Vole(const std::shared_ptr& lctx, + bool cache_vole, + const std::filesystem::path& cache_dir) { + std::tie(self_size_, peer_size_) = + data_processor_->GetBucketDatasize(bucket_idx_); + mask_size_ = ComputeMaskSize(rr22_options_, self_size_, peer_size_); SPDLOG_INFO("mask size: {}", mask_size_); if ((peer_size_ == 0) || (self_size_ == 0)) { null_bucket_ = true; return; } + oprf_sender_.Init(lctx, peer_size_, rr22_options_.num_threads, cache_vole, + cache_dir); +} + +void BucketRr22Sender::RunOprf(const std::shared_ptr&) { + if (null_bucket_) { + return; + } + bucket_items_ = data_processor_->GetBucketItems(bucket_idx_); + std::mt19937 g(yacl::crypto::SecureRandU64()); + std::shuffle(bucket_items_.begin(), bucket_items_.end(), g); inputs_hash_ = std::vector(bucket_items_.size()); yacl::parallel_for(0, bucket_items_.size(), [&](int64_t begin, int64_t end) { @@ -91,54 +108,73 @@ void BucketRr22Sender::Prepare( inputs_hash_[i] = yacl::crypto::Blake3_128(bucket_items_[i].base64_data); } }); - oprf_sender_.Init(lctx, peer_size_, rr22_options_.num_threads); } -void BucketRr22Sender::RunOprf( +void BucketRr22Sender::Intersection( const std::shared_ptr& lctx) { if (null_bucket_) { + data_processor_->WriteIntersetionItems(bucket_idx_, bucket_items_, {}, {}); return; } auto inputs_hash = oprf_sender_.Send(lctx, inputs_hash_); oprfs_ = oprf_sender_.Eval(inputs_hash_, inputs_hash); + SPDLOG_INFO("get intersection begin"); + bool compress = mask_size_ != sizeof(uint128_t); + auto* data_ptr = reinterpret_cast(oprfs_.data()); + if (compress) { + for (size_t i = 0; i < oprfs_.size(); ++i) { + std::memmove(data_ptr + (i * mask_size_), &oprfs_[i], mask_size_); + } + } + yacl::ByteContainerView send_buffer(data_ptr, oprfs_.size() * mask_size_); + lctx->SendAsyncThrottled(lctx->NextRank(), send_buffer, + fmt::format("oprf_value")); + std::unordered_map self_cnt; + for (size_t i = 0; i != bucket_items_.size(); ++i) { + if (bucket_items_[i].extra_dup_cnt > 0) { + self_cnt[i] = bucket_items_[i].extra_dup_cnt; + } + } + lctx->SendAsyncThrottled(lctx->NextRank(), utils::SerializeItemsCnt(self_cnt), + fmt::format("items_cnt")); + SPDLOG_INFO("get intersection end"); } -void BucketRr22Sender::GetIntersection( +void BucketRr22Sender::BroadCastResult( const std::shared_ptr& lctx) { - if (null_bucket_) { - post_f_(bucket_idx_, bucket_items_, {}, {}); - return; + std::unordered_map items_cnt; + if (broadcast_result_) { + auto buffer = lctx->Recv(lctx->NextRank(), "broadcast_result"); + indices_.resize(buffer.size() / sizeof(uint32_t)); + std::memcpy(indices_.data(), buffer.data(), buffer.size()); + buffer = lctx->Recv(lctx->NextRank(), "broadcast_items_cnt"); + items_cnt = utils::DeserializeItemsCnt(buffer); + } + peer_cnt_.resize(indices_.size()); + for (auto& item : items_cnt) { + peer_cnt_[item.first] = item.second; } - SPDLOG_INFO("get intersection begin"); - std::vector indices; - std::vector peer_cnt; - std::tie(indices, peer_cnt) = GetIntersectionSender( - oprfs_, bucket_items_, lctx, mask_size_, broadcast_result_); - SPDLOG_INFO("get intersection end"); - post_f_(bucket_idx_, bucket_items_, indices, peer_cnt); - SPDLOG_INFO("get intersection post f"); } -void BucketRr22Receiver::Prepare( - const std::shared_ptr& lctx) { - bucket_items_ = pre_f_(bucket_idx_); +void BucketRr22Sender::WriteResult() { + data_processor_->WriteIntersetionItems(bucket_idx_, bucket_items_, indices_, + peer_cnt_); + SPDLOG_INFO("sender write bucket idx {} result", bucket_idx_); +} - self_size_ = bucket_items_.size(); - std::tie(mask_size_, peer_size_) = - ExchangeTruncateSize(lctx, bucket_items_.size(), rr22_options_); +void BucketRr22Receiver::Vole(const std::shared_ptr& lctx, + bool cache_vole, + const std::filesystem::path& cache_dir) { + std::tie(self_size_, peer_size_) = + data_processor_->GetBucketDatasize(bucket_idx_); + mask_size_ = ComputeMaskSize(rr22_options_, self_size_, peer_size_); SPDLOG_INFO("mask size: {}", mask_size_); if ((peer_size_ == 0) || (self_size_ == 0)) { null_bucket_ = true; return; } - - inputs_hash_ = std::vector(self_size_); - yacl::parallel_for(0, bucket_items_.size(), [&](int64_t begin, int64_t end) { - for (int64_t i = begin; i < end; ++i) { - inputs_hash_[i] = yacl::crypto::Blake3_128(bucket_items_[i].base64_data); - } - }); - oprf_receiver_.Init(lctx, self_size_, rr22_options_.num_threads); + oprf_receiver_.Init(lctx, self_size_, rr22_options_.num_threads, cache_vole, + cache_dir); } void BucketRr22Receiver::RunOprf( @@ -146,23 +182,125 @@ void BucketRr22Receiver::RunOprf( if (null_bucket_) { return; } + + bucket_items_ = data_processor_->GetBucketItems(bucket_idx_); + + inputs_hash_ = std::vector(bucket_items_.size()); + yacl::parallel_for(0, bucket_items_.size(), [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + inputs_hash_[i] = yacl::crypto::Blake3_128(bucket_items_[i].base64_data); + } + }); oprfs_ = oprf_receiver_.Recv(lctx, inputs_hash_); } -void BucketRr22Receiver::GetIntersection( +void BucketRr22Receiver::Intersection( const std::shared_ptr& lctx) { if (null_bucket_) { - post_f_(bucket_idx_, bucket_items_, {}, {}); + data_processor_->WriteIntersetionItems(bucket_idx_, bucket_items_, {}, {}); return; } SPDLOG_INFO("get intersection begin"); - std::vector indices; - std::vector peer_cnt; - std::tie(indices, peer_cnt) = GetIntersectionReceiver( - oprfs_, bucket_items_, peer_size_, lctx, rr22_options_.num_threads, - mask_size_, broadcast_result_); + bool compress = mask_size_ != sizeof(uint128_t); + google::dense_hash_map dense_map(oprfs_.size()); + dense_map.set_empty_key(yacl::MakeUint128(0, 0)); + auto map_f = std::async([&]() { + auto truncate_mask = yacl::MakeUint128(0, 0); + if (compress) { + for (size_t i = 0; i < mask_size_; ++i) { + truncate_mask = 0xff | (truncate_mask << 8); + SPDLOG_DEBUG( + "{}, truncate_mask:{}", i, + (std::ostringstream() << okvs::Galois128(truncate_mask)).str()); + } + } + for (size_t i = 0; i < oprfs_.size(); ++i) { + if (compress) { + dense_map.insert(std::make_pair(oprfs_[i] & truncate_mask, i)); + } else { + dense_map.insert(std::make_pair(oprfs_[i], i)); + } + } + }); + SPDLOG_INFO("recv rr22 oprf begin"); + auto peer_buffer = lctx->Recv(lctx->NextRank(), fmt::format("paxos_solve")); + YACL_ENFORCE(peer_size_ == peer_buffer.size() / mask_size_); + SPDLOG_INFO("recv rr22 oprf finished: {} vector:{}", peer_buffer.size(), + peer_size_); + auto cnt_buffer = lctx->Recv(lctx->NextRank(), fmt::format("items_cnt")); + auto peer_item_cnt_map = utils::DeserializeItemsCnt(cnt_buffer); + for (uint32_t i = 0; i < peer_size_; ++i) { + (void)peer_item_cnt_map[i]; + } + + map_f.get(); + auto* peer_data_ptr = peer_buffer.data(); + std::mutex merge_mtx; + size_t grain_size = + (peer_size_ + rr22_options_.num_threads - 1) / rr22_options_.num_threads; + yacl::parallel_for( + 0, peer_size_, grain_size, [&](int64_t begin, int64_t end) { + std::vector tmp_indexs; + std::vector tmp_peer_cnt; + + std::vector tmp_peer_indexs; + std::vector tmp_self_cnt; + + uint128_t data = yacl::MakeUint128(0, 0); + for (int64_t j = begin; j < end; j++) { + std::memcpy(&data, peer_data_ptr + (j * mask_size_), mask_size_); + auto iter = dense_map.find(data); + if (iter != dense_map.end()) { + tmp_indexs.push_back(iter->second); + tmp_peer_cnt.push_back(peer_item_cnt_map[j]); + if (broadcast_result_) { + tmp_peer_indexs.push_back(j); + YACL_ENFORCE( + iter->second < bucket_items_.size(), + "random str matched in result, which is not expected."); + tmp_self_cnt.push_back(bucket_items_[iter->second].extra_dup_cnt); + } + } + } + if (!tmp_indexs.empty()) { + std::lock_guard lock(merge_mtx); + self_indices_.insert(self_indices_.end(), tmp_indexs.begin(), + tmp_indexs.end()); + peer_cnt_.insert(peer_cnt_.end(), tmp_peer_cnt.begin(), + tmp_peer_cnt.end()); + if (broadcast_result_) { + peer_indices_.insert(peer_indices_.end(), tmp_peer_indexs.begin(), + tmp_peer_indexs.end()); + self_cnt_.insert(self_cnt_.end(), tmp_self_cnt.begin(), + tmp_self_cnt.end()); + } + } + }); SPDLOG_INFO("get intersection end"); - post_f_(bucket_idx_, bucket_items_, indices, peer_cnt); - SPDLOG_INFO("get intersection post f"); } + +void BucketRr22Receiver::BroadCastResult( + const std::shared_ptr& lctx) { + if (broadcast_result_) { + auto buffer = yacl::Buffer(peer_indices_.data(), + peer_indices_.size() * sizeof(uint32_t)); + lctx->SendAsyncThrottled(lctx->NextRank(), buffer, "broadcast_result"); + std::unordered_map self_cnt_map; + for (size_t i = 0; i < self_cnt_.size(); ++i) { + if (self_cnt_[i] > 0) { + self_cnt_map[i] = self_cnt_[i]; + } + } + lctx->SendAsyncThrottled(lctx->NextRank(), + utils::SerializeItemsCnt(self_cnt_map), + "broadcast_items_cnt"); + } +} + +void BucketRr22Receiver::WriteResult() { + data_processor_->WriteIntersetionItems(bucket_idx_, bucket_items_, + self_indices_, peer_cnt_); + SPDLOG_INFO("receiver write bucket idx {} result", bucket_idx_); +} + } // namespace psi::rr22 diff --git a/psi/algorithm/rr22/rr22_psi.h b/psi/algorithm/rr22/rr22_psi.h index d940ad77..8a2a832f 100644 --- a/psi/algorithm/rr22/rr22_psi.h +++ b/psi/algorithm/rr22/rr22_psi.h @@ -18,6 +18,8 @@ #include #include #include +#include +#include #include #include #include @@ -25,6 +27,7 @@ #include #include #include +#include #include #include "yacl/base/exception.h" @@ -34,6 +37,7 @@ #include "psi/algorithm/rr22/rr22_oprf.h" #include "psi/utils/bucket.h" #include "psi/utils/hash_bucket_cache.h" +#include "psi/utils/simple_channel.h" // [RR22] Blazing Fast PSI from Improved OKVS and Subfield VOLE, CCS 2022 // https://eprint.iacr.org/2022/320 @@ -75,28 +79,31 @@ struct Rr22PsiOptions { const size_t oprf_bin_size = 1 << 14; }; -using PreProcessFunc = - std::function(size_t)>; -// input: bucket index, bucket_items, intersection indices, dup cnt -using PostProcessFunc = std::function&, - const std::vector&, const std::vector&)>; - class BucketRr22Core { public: BucketRr22Core(const Rr22PsiOptions& rr22_options, size_t bucket_num, - size_t bucket_idx, bool broadcast_result) + size_t bucket_idx, bool broadcast_result, + IBucketDataStore* data_processor) : rr22_options_(rr22_options), bucket_num_(bucket_num), broadcast_result_(broadcast_result), - bucket_idx_(bucket_idx) {} + bucket_idx_(bucket_idx), + data_processor_(data_processor) {} virtual ~BucketRr22Core() = default; - - virtual void Prepare(const std::shared_ptr& lctx) = 0; + virtual void Vole(const std::shared_ptr& lctx, + bool cache_vole = false, + const std::filesystem::path& cache_dir = + std::filesystem::temp_directory_path() / + GetRandomString()) = 0; virtual void RunOprf(const std::shared_ptr& lctx) = 0; - virtual void GetIntersection( + virtual void Intersection( const std::shared_ptr& lctx) = 0; + virtual void BroadCastResult( + const std::shared_ptr& lctx) = 0; + virtual bool IsSender() = 0; + virtual void WriteResult() = 0; + [[nodiscard]] size_t BucketIdx() const { return bucket_idx_; } protected: Rr22PsiOptions rr22_options_; @@ -106,147 +113,206 @@ class BucketRr22Core { size_t self_size_ = 0; size_t peer_size_ = 0; + size_t vole_size_ = 0; size_t mask_size_ = sizeof(uint128_t); std::vector inputs_hash_; std::vector oprfs_; std::vector bucket_items_; bool null_bucket_ = false; + IBucketDataStore* data_processor_; }; class BucketRr22Sender : public BucketRr22Core { public: BucketRr22Sender(const Rr22PsiOptions& rr22_options, size_t bucket_num, size_t bucket_idx, bool broadcast_result, - PreProcessFunc& pre_f, PostProcessFunc& post_f) - : BucketRr22Core(rr22_options, bucket_num, bucket_idx, broadcast_result), - pre_f_(pre_f), - post_f_(post_f), + IBucketDataStore* data_processor) + : BucketRr22Core(rr22_options, bucket_num, bucket_idx, broadcast_result, + data_processor), oprf_sender_(rr22_options.oprf_bin_size, rr22_options_.ssp, rr22_options_.mode, rr22_options_.code_type, rr22_options_.malicious) {} - void Prepare(const std::shared_ptr& lctx) override; + void Vole(const std::shared_ptr& lctx, + bool cache_vole = false, + const std::filesystem::path& cache_dir = + std::filesystem::temp_directory_path() / + GetRandomString()) override; void RunOprf(const std::shared_ptr& lctx) override; - void GetIntersection( + void Intersection(const std::shared_ptr& lctx) override; + void BroadCastResult( const std::shared_ptr& lctx) override; + void WriteResult() override; + bool IsSender() override { return true; } private: - PreProcessFunc pre_f_; - PostProcessFunc post_f_; Rr22OprfSender oprf_sender_; + std::vector indices_; + std::vector peer_cnt_; }; class BucketRr22Receiver : public BucketRr22Core { public: BucketRr22Receiver(const Rr22PsiOptions& rr22_options, size_t bucket_num, size_t bucket_idx, bool broadcast_result, - PreProcessFunc& pre_f, PostProcessFunc& post_f) - : BucketRr22Core(rr22_options, bucket_num, bucket_idx, broadcast_result), - pre_f_(pre_f), - post_f_(post_f), + IBucketDataStore* data_processor) + : BucketRr22Core(rr22_options, bucket_num, bucket_idx, broadcast_result, + data_processor), oprf_receiver_(rr22_options.oprf_bin_size, rr22_options_.ssp, rr22_options_.mode, rr22_options_.code_type, rr22_options_.malicious) {} - void Prepare(const std::shared_ptr& lctx) override; + void Vole(const std::shared_ptr& lctx, + bool cache_vole = false, + const std::filesystem::path& cache_dir = + std::filesystem::temp_directory_path() / + GetRandomString()) override; void RunOprf(const std::shared_ptr& lctx) override; - void GetIntersection( + void Intersection(const std::shared_ptr& lctx) override; + void BroadCastResult( const std::shared_ptr& lctx) override; + void WriteResult() override; + bool IsSender() override { return false; } private: - PreProcessFunc pre_f_; - PostProcessFunc post_f_; Rr22OprfReceiver oprf_receiver_; + std::vector self_cnt_; + std::vector peer_cnt_; + std::vector self_indices_; + std::vector peer_indices_; }; -// return {mask_size, peer_size} -std::pair ExchangeTruncateSize( - const std::shared_ptr& lctx, size_t self_size, - const Rr22PsiOptions& options); - class Rr22Runner { public: Rr22Runner(const std::shared_ptr& lctx, const Rr22PsiOptions& rr22_options, size_t bucket_num, - bool broadcast_result, PreProcessFunc& pre_f, - PostProcessFunc& post_f) - : rr22_options_(rr22_options), + bool broadcast_result, IBucketDataStore* data_processor) + : lctx_(lctx), + rr22_options_(rr22_options), bucket_num_(bucket_num), broadcast_result_(broadcast_result), - pre_f_(pre_f), - post_f_(post_f) { - intersection_lctx_ = lctx->Spawn("intersection"); - read_lctx_ = lctx->Spawn("read"); - run_lctx_ = lctx->Spawn("run"); - } + data_processor_(data_processor) {} void Run(size_t start_idx, bool is_sender) { for (size_t idx = start_idx; idx < bucket_num_; ++idx) { auto bucket_runner = CreateBucketRunner(idx, is_sender); - bucket_runner->Prepare(read_lctx_); - bucket_runner->RunOprf(run_lctx_); - bucket_runner->GetIntersection(intersection_lctx_); + bucket_runner->Vole(lctx_); + bucket_runner->RunOprf(lctx_); + bucket_runner->Intersection(lctx_); + bucket_runner->BroadCastResult(lctx_); + bucket_runner->WriteResult(); } } - void AsyncRun(size_t start_idx, bool is_sender) { + void AsyncRun(size_t start_idx, bool is_sender, + const std::filesystem::path& cache_dir) { // cache size meaning the size you can prepare input data into queue // bigger cache size may run a little fast but consume more memory - constexpr size_t cache_size = 1; + constexpr size_t cache_size = 2; if (bucket_num_ <= cache_size) { Run(start_idx, is_sender); return; } - std::queue> prepared_runner_queue; - std::queue> oprf_runner_queue; - std::mutex prepare_mtx; - std::condition_variable prepare_cv; - std::mutex oprf_mtx; - std::condition_variable oprf_cv; - auto prepare_f = std::async(std::launch::async, [&]() { - for (size_t i = start_idx; i < bucket_num_; i++) { - auto runner = CreateBucketRunner(i, is_sender); - runner->Prepare(read_lctx_); - std::unique_lock lock(prepare_mtx); - prepare_cv.wait( - lock, [&] { return prepared_runner_queue.size() < cache_size; }); - prepared_runner_queue.push(runner); - prepare_cv.notify_all(); - } - }); - auto run_f = std::async(std::launch::async, [&]() { - for (int i = start_idx; i < static_cast(bucket_num_); ++i) { - std::shared_ptr runner; - { - std::unique_lock lock(prepare_mtx); - prepare_cv.wait(lock, [&] { return !prepared_runner_queue.empty(); }); - runner = prepared_runner_queue.front(); - prepared_runner_queue.pop(); - prepare_cv.notify_all(); - } - runner->RunOprf(run_lctx_); - { - std::unique_lock lock(oprf_mtx); - oprf_runner_queue.push(runner); - oprf_cv.notify_all(); - } - } - }); - auto intersection_f = std::async(std::launch::async, [&]() { - for (int i = start_idx; i < static_cast(bucket_num_); ++i) { - std::shared_ptr runner; - { - std::unique_lock lock(oprf_mtx); - oprf_cv.wait(lock, [&] { return !oprf_runner_queue.empty(); }); - runner = oprf_runner_queue.front(); - oprf_runner_queue.pop(); - oprf_cv.notify_all(); - } - runner->GetIntersection(intersection_lctx_); - } - }); - run_f.get(); - prepare_f.get(); - intersection_f.get(); + // create cache dir if not exist + if (!std::filesystem::exists(cache_dir)) { + std::filesystem::create_directory(cache_dir); + } + // main computation flow of rr22 - takes bucket data as input, returns + // intersection index + auto helper = + [&](SimpleChannel>* run_queue, + SimpleChannel>* result_queue, + size_t capacity) { + SimpleChannel> intersection_queue( + capacity); + SimpleChannel> broadcast_queue( + capacity); + auto run_f = std::async(std::launch::async, [&]() { + while (true) { + auto data = run_queue->Pop(); + if (!data.has_value()) { + break; + } + auto runner = data.value(); + std::shared_ptr lctx = + lctx_->Spawn("oprf-" + std::to_string(runner->BucketIdx())); + runner->RunOprf(lctx); + intersection_queue.Push(runner); + } + intersection_queue.Close(); + }); + auto intersection_f = std::async(std::launch::async, [&]() { + while (true) { + auto data = intersection_queue.Pop(); + if (!data.has_value()) { + break; + } + auto runner = data.value(); + std::shared_ptr lctx = lctx_->Spawn( + "intersection-" + std::to_string(runner->BucketIdx())); + runner->Intersection(lctx); + broadcast_queue.Push(runner); + } + broadcast_queue.Close(); + }); + auto broadcast_f = std::async(std::launch::async, [&]() { + while (true) { + auto data = broadcast_queue.Pop(); + if (!data.has_value()) { + break; + } + auto runner = data.value(); + std::shared_ptr lctx = lctx_->Spawn( + "broadcasrt-" + std::to_string(runner->BucketIdx())); + runner->BroadCastResult(lctx); + result_queue->Push(runner); + } + result_queue->Close(); + }); + intersection_f.get(); + broadcast_f.get(); + }; + // create vole in parallel + std::vector> runners(bucket_num_); + // selected based on test results + constexpr size_t VoleParallelSize = 6; + std::vector> futures(VoleParallelSize); + for (size_t i = 0; i < futures.size(); i++) { + futures[i] = std::async( + std::launch::async, + [&](size_t thread_idx) { + for (size_t j = 0; j < bucket_num_; j++) { + if (j % futures.size() == thread_idx) { + SPDLOG_INFO("idx: {}, is_sender: {}", j, is_sender); + auto runner = CreateBucketRunner(j, is_sender); + std::shared_ptr spawn_lctx = + lctx_->Spawn("vole-" + std::to_string(runner->BucketIdx())); + runner->Vole(spawn_lctx, true, cache_dir); + runners[runner->BucketIdx()] = runner; + } + } + }, + i); + } + // waiting for futures completed + for (auto& f : futures) { + f.get(); + } + SimpleChannel> run_queue(bucket_num_); + for (size_t idx = start_idx; idx < bucket_num_; idx++) { + run_queue.Push(runners[idx]); + } + runners.clear(); + run_queue.Close(); + SimpleChannel> result_queue(cache_size); + auto f = std::async(std::launch::async, helper, &run_queue, &result_queue, + cache_size); + + for (size_t idx = start_idx; idx < bucket_num_; idx++) { + auto data = result_queue.Pop(); + data.value()->WriteResult(); + } + f.get(); } + // deprecated void ParallelRun(size_t start_idx, bool is_sender, int parallel_num = 6) { if (static_cast(bucket_num_) <= parallel_num) { Run(start_idx, is_sender); @@ -257,18 +323,16 @@ class Rr22Runner { futures[i] = std::async( std::launch::async, [&](size_t thread_idx) { - std::shared_ptr spawn_read_lctx = - read_lctx_->Spawn(std::to_string(thread_idx)); - std::shared_ptr spawn_run_lctx = - run_lctx_->Spawn(std::to_string(thread_idx)); - std::shared_ptr spawn_intersection_lctx = - intersection_lctx_->Spawn(std::to_string(thread_idx)); + std::shared_ptr spawn_lctx = + lctx_->Spawn(std::to_string(thread_idx)); for (size_t j = 0; j < bucket_num_; j++) { if (j % parallel_num == thread_idx) { auto runner = CreateBucketRunner(j, is_sender); - runner->Prepare(spawn_read_lctx); - runner->RunOprf(spawn_run_lctx); - runner->GetIntersection(spawn_intersection_lctx); + runner->Vole(spawn_lctx); + runner->RunOprf(spawn_lctx); + runner->Intersection(spawn_lctx); + runner->BroadCastResult(spawn_lctx); + runner->WriteResult(); } } }, @@ -285,21 +349,21 @@ class Rr22Runner { std::shared_ptr bucker_runner; if (is_sender) { bucker_runner = std::make_shared( - rr22_options_, bucket_num_, idx, broadcast_result_, pre_f_, post_f_); + rr22_options_, bucket_num_, idx, broadcast_result_, data_processor_); } else { bucker_runner = std::make_shared( - rr22_options_, bucket_num_, idx, broadcast_result_, pre_f_, post_f_); + rr22_options_, bucket_num_, idx, broadcast_result_, data_processor_); } return bucker_runner; } - std::shared_ptr intersection_lctx_; - std::shared_ptr read_lctx_; - std::shared_ptr run_lctx_; + + std::shared_ptr lctx_; Rr22PsiOptions rr22_options_; size_t bucket_num_; bool broadcast_result_; - PreProcessFunc pre_f_; - PostProcessFunc post_f_; + IBucketDataStore* data_processor_; }; +size_t ComputeMaskSize(const Rr22PsiOptions& options, size_t self_size, + size_t peer_size); } // namespace psi::rr22 diff --git a/psi/algorithm/rr22/rr22_psi_benchmark.cc b/psi/algorithm/rr22/rr22_psi_benchmark.cc index 5b5c4933..f3f5d217 100644 --- a/psi/algorithm/rr22/rr22_psi_benchmark.cc +++ b/psi/algorithm/rr22/rr22_psi_benchmark.cc @@ -124,10 +124,9 @@ static void BM_Rr22FastPsi(benchmark::State& state) { } auto psi_receiver_proc = std::async([&] { - size_t mask_size; - size_t peer_size; - std::tie(mask_size, peer_size) = - ExchangeTruncateSize(lctxs[0], inputs_a.size(), psi_options); + size_t peer_size = inputs_b.size(); + size_t mask_size = + psi::rr22::ComputeMaskSize(psi_options, inputs_a.size(), peer_size); psi::rr22::Rr22OprfReceiver oprf_receiver( kRr22OprfBinSize, kRr22DefaultSsp, psi_options.mode); oprf_receiver.Init(lctxs[0], inputs_a.size(), psi_options.num_threads); @@ -138,14 +137,12 @@ static void BM_Rr22FastPsi(benchmark::State& state) { }); auto psi_sender_proc = std::async([&] { - size_t mask_size; - size_t peer_size; - std::tie(mask_size, peer_size) = - ExchangeTruncateSize(lctxs[1], inputs_b.size(), psi_options); + size_t peer_size = inputs_a.size(); + size_t mask_size = + psi::rr22::ComputeMaskSize(psi_options, inputs_b.size(), peer_size); psi::rr22::Rr22OprfSender oprf_sender(kRr22OprfBinSize, kRr22DefaultSsp, psi_options.mode); - oprf_sender.Init(lctxs[1], std::max(peer_size, inputs_b.size()), - psi_options.num_threads); + oprf_sender.Init(lctxs[1], peer_size, psi_options.num_threads); auto inputs_hash = oprf_sender.Send(lctxs[1], inputs_b); auto oprfs = oprf_sender.Eval(inputs_b, inputs_hash); return psi::rr22::GetIntersectionSender(std::move(oprfs), lctxs[1], diff --git a/psi/algorithm/rr22/rr22_psi_test.cc b/psi/algorithm/rr22/rr22_psi_test.cc index 1120ae15..461f9ef5 100644 --- a/psi/algorithm/rr22/rr22_psi_test.cc +++ b/psi/algorithm/rr22/rr22_psi_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "gtest/gtest.h" @@ -74,6 +75,45 @@ struct TestParams { class Rr22PsiTest : public testing::TestWithParam {}; +class BucketDataStoreImpl : public IBucketDataStore { + public: + BucketDataStoreImpl(const std::vector& inputs, + const uint32_t peer_size) + : inputs_(inputs), peer_size_(peer_size) {} + + std::vector GetBucketItems(size_t) override { + std::vector bucket_items(inputs_.size()); + for (size_t i = 0; i < bucket_items.size(); ++i) { + bucket_items[i] = {.index = i, + .base64_data = fmt::format("{}", inputs_[i])}; + } + return bucket_items; + } + + void WriteIntersetionItems( + size_t, const std::vector& items, + const std::vector& intersection_indices, + const std::vector& peer_dup_cnts) override { + for (size_t i = 0; i < intersection_indices.size(); ++i) { + indices_result_.push_back(items[intersection_indices[i]].index); + for (size_t j = 0; j < peer_dup_cnts[i]; ++j) { + indices_result_.push_back(items[intersection_indices[i]].index); + } + } + } + + std::pair GetBucketDatasize(size_t) override { + return std::make_pair(inputs_.size(), peer_size_); + } + + std::vector GetResult() { return indices_result_; } + + private: + std::vector inputs_; + uint32_t peer_size_; + std::vector indices_result_; +}; + TEST_P(Rr22PsiTest, CorrectTest) { auto params = GetParam(); @@ -93,57 +133,25 @@ TEST_P(Rr22PsiTest, CorrectTest) { psi_options.mode = params.mode; psi_options.malicious = params.malicious; - std::vector indices_psi; - PreProcessFunc receiver_pre_f = [&](size_t) { - std::vector bucket_items(inputs_a.size()); - for (size_t i = 0; i < inputs_a.size(); ++i) { - bucket_items[i] = {.index = i, - .base64_data = fmt::format("{}", inputs_a[i])}; - } - return bucket_items; - }; - std::mutex mtx; - PostProcessFunc receiver_post_f = - [&](size_t, const std::vector& bucket_items, - const std::vector& indices, - const std::vector& peer_dup_cnt) { - SPDLOG_INFO("receiver_post_f: {}, {}", indices.size(), - peer_dup_cnt.size()); - std::unique_lock lock(mtx); - for (size_t i = 0; i < indices.size(); ++i) { - indices_psi.push_back(bucket_items[indices[i]].index); - for (size_t j = 0; j < peer_dup_cnt[i]; ++j) { - indices_psi.push_back(bucket_items[indices[i]].index); - } - } - }; - PreProcessFunc sender_pre_f = [&](size_t) { - std::vector bucket_items(inputs_b.size()); - for (size_t i = 0; i < inputs_b.size(); ++i) { - bucket_items[i] = {.index = i, - .base64_data = fmt::format("{}", inputs_b[i])}; - } - return bucket_items; - }; - size_t bucket_num = 1; - PostProcessFunc sender_post_f = - [&](size_t, const std::vector&, - const std::vector&, - const std::vector&) { return; }; + BucketDataStoreImpl receiver_data(inputs_a, inputs_b.size()); + BucketDataStoreImpl sender_data(inputs_b, inputs_a.size()); + + constexpr size_t bucket_num = 1; auto psi_receiver_proc = std::async([&] { - Rr22Runner runner(lctxs[0], psi_options, bucket_num, false, receiver_pre_f, - receiver_post_f); - runner.AsyncRun(0, false); + Rr22Runner runner(lctxs[0], psi_options, bucket_num, false, &receiver_data); + runner.AsyncRun(0, false, + std::filesystem::temp_directory_path() / GetRandomString()); }); auto psi_sender_proc = std::async([&] { - Rr22Runner runner(lctxs[1], psi_options, bucket_num, false, sender_pre_f, - sender_post_f); - runner.AsyncRun(0, true); + Rr22Runner runner(lctxs[1], psi_options, bucket_num, false, &sender_data); + runner.AsyncRun(0, true, + std::filesystem::temp_directory_path() / GetRandomString()); }); psi_sender_proc.get(); psi_receiver_proc.get(); + auto indices_psi = receiver_data.GetResult(); std::sort(indices_psi.begin(), indices_psi.end()); std::vector indices_result; for (size_t i = 0; i < bucket_num; i++) { diff --git a/psi/algorithm/rr22/rr22_utils.cc b/psi/algorithm/rr22/rr22_utils.cc index f51af911..a6b0b214 100644 --- a/psi/algorithm/rr22/rr22_utils.cc +++ b/psi/algorithm/rr22/rr22_utils.cc @@ -233,17 +233,10 @@ std::pair, std::vector> GetIntersectionSender( bool broadcast_result) { std::vector result; bool compress = mask_size != sizeof(uint128_t); - auto truncate_mask = yacl::MakeUint128(0, 0); auto* data_ptr = reinterpret_cast(self_oprfs.data()); if (compress) { - for (size_t i = 0; i < mask_size; ++i) { - truncate_mask = 0xff | (truncate_mask << 8); - SPDLOG_DEBUG( - "{}, truncate_mask:{}", i, - (std::ostringstream() << okvs::Galois128(truncate_mask)).str()); - } for (size_t i = 0; i < self_oprfs.size(); ++i) { - std::memcpy(data_ptr + (i * mask_size), &self_oprfs[i], mask_size); + std::memmove(data_ptr + (i * mask_size), &self_oprfs[i], mask_size); } } yacl::ByteContainerView send_buffer(data_ptr, self_oprfs.size() * mask_size); @@ -278,17 +271,10 @@ std::vector GetIntersectionSender( bool broadcast_result) { std::vector result; bool compress = mask_size != sizeof(uint128_t); - auto truncate_mask = yacl::MakeUint128(0, 0); auto* data_ptr = reinterpret_cast(self_oprfs.data()); if (compress) { - for (size_t i = 0; i < mask_size; ++i) { - truncate_mask = 0xff | (truncate_mask << 8); - SPDLOG_DEBUG( - "{}, truncate_mask:{}", i, - (std::ostringstream() << okvs::Galois128(truncate_mask)).str()); - } for (size_t i = 0; i < self_oprfs.size(); ++i) { - std::memcpy(data_ptr + (i * mask_size), &self_oprfs[i], mask_size); + std::memmove(data_ptr + (i * mask_size), &self_oprfs[i], mask_size); } } for (size_t i = 0; i < self_oprfs.size(); i += kSendChunkSize) { diff --git a/psi/algorithm/rr22/sender.cc b/psi/algorithm/rr22/sender.cc index 8e619a7f..dc4c0f31 100644 --- a/psi/algorithm/rr22/sender.cc +++ b/psi/algorithm/rr22/sender.cc @@ -25,6 +25,7 @@ #include "psi/algorithm/rr22/rr22_utils.h" #include "psi/trace_categories.h" #include "psi/utils/bucket.h" +#include "psi/utils/multiplex_disk_cache.h" #include "psi/utils/sync.h" namespace psi::rr22 { @@ -105,33 +106,16 @@ void Rr22PsiSender::Online() { Rr22PsiOptions rr22_options = GenerateRr22PsiOptions( config_.protocol_config().rr22_config().low_comm_mode()); - - PreProcessFunc pre_f = - [&](size_t idx) -> std::vector { - if (idx >= input_bucket_store_->BucketNum()) { - return {}; - } - return input_bucket_store_->LoadBucketItems(idx); - }; - PostProcessFunc post_f = - [&](size_t bucket_idx, - const std::vector& bucket_items, - const std::vector& indices, - const std::vector& peer_cnt) { - for (size_t i = 0; i != indices.size(); ++i) { - intersection_indices_writer_->WriteCache( - bucket_items[indices[i]].index, peer_cnt[i]); - } - intersection_indices_writer_->Commit(); - if (recovery_manager_) { - recovery_manager_->UpdateParsedBucketCount(bucket_idx + 1); - } - }; - + BucketDataStoreImpl data_processor(lctx_, input_bucket_store_.get(), + intersection_indices_writer_.get(), + recovery_manager_.get()); Rr22Runner runner(lctx_, rr22_options, input_bucket_store_->BucketNum(), - config_.protocol_config().broadcast_result(), pre_f, - post_f); - SyncWait(lctx_, [&] { runner.AsyncRun(bucket_idx, true); }); + config_.protocol_config().broadcast_result(), + &data_processor); + auto scoped_temp_dir = std::make_unique(); + scoped_temp_dir->CreateUniqueTempDirUnderPath(GetTaskDir()); + SyncWait(lctx_, + [&] { runner.AsyncRun(bucket_idx, true, scoped_temp_dir->path()); }); SPDLOG_INFO("[Rr22PsiSender::Online] end"); } diff --git a/psi/utils/BUILD.bazel b/psi/utils/BUILD.bazel index 3074a404..0874c0db 100644 --- a/psi/utils/BUILD.bazel +++ b/psi/utils/BUILD.bazel @@ -427,3 +427,8 @@ psi_cc_test( "@com_google_googletest//:gtest", ], ) + +cc_library( + name = "simple_channel", + hdrs = ["simple_channel.h"], +) diff --git a/psi/utils/batch_provider_impl.h b/psi/utils/batch_provider_impl.h index 93e2556b..a1e7042d 100644 --- a/psi/utils/batch_provider_impl.h +++ b/psi/utils/batch_provider_impl.h @@ -48,6 +48,8 @@ class MemoryBatchProvider : public IBasicBatchProvider, [[nodiscard]] size_t batch_size() const override { return batch_size_; } + [[nodiscard]] size_t Size() const override { return items_.size(); } + [[nodiscard]] const std::vector& items() const; [[nodiscard]] const std::vector& labels() const; @@ -127,6 +129,10 @@ class MemoryDataStore : public IDataStore { return provider_; } + [[nodiscard]] size_t GetBucketDatasize(size_t) const override { + return provider_->Size(); + } + private: std::shared_ptr provider_; }; diff --git a/psi/utils/hash_bucket_cache.cc b/psi/utils/hash_bucket_cache.cc index 23679f21..876ce957 100644 --- a/psi/utils/hash_bucket_cache.cc +++ b/psi/utils/hash_bucket_cache.cc @@ -16,10 +16,12 @@ #include +#include #include #include #include #include +#include #include "absl/strings/escaping.h" @@ -39,6 +41,7 @@ HashBucketCache::HashBucketCache(const std::string& target_dir, std::filesystem::path(target_dir), use_scoped_tmp_dir); YACL_ENFORCE(disk_cache_, "cannot create disk cache from dir={}", target_dir); disk_cache_->CreateOutputStreams(bucket_num_, &bucket_os_vec_); + bucket_data_sizes_ = std::vector(bucket_num, 0); } HashBucketCache::~HashBucketCache() { @@ -52,12 +55,13 @@ void HashBucketCache::WriteItem(const std::string& data, bucket_item.index = item_index_; bucket_item.extra_dup_cnt = duplicate_cnt; bucket_item.base64_data = absl::Base64Escape(data); - - auto& out = bucket_os_vec_[std::hash()(bucket_item.base64_data) % - bucket_os_vec_.size()]; + size_t bucket_idx = + std::hash()(bucket_item.base64_data) % bucket_os_vec_.size(); + auto& out = bucket_os_vec_[bucket_idx]; out->Write(bucket_item.Serialize()); out->Write("\n"); item_index_++; + bucket_data_sizes_[bucket_idx]++; } void HashBucketCache::Flush() { @@ -80,6 +84,10 @@ std::vector HashBucketCache::LoadBucketItems( return ret; } +size_t HashBucketCache::GetBucketSize(uint32_t index) { + return bucket_data_sizes_[index]; +} + std::unique_ptr CreateCacheFromCsv( const std::string& csv_path, const std::vector& schema_names, const std::string& cache_dir, uint32_t bucket_num, uint32_t read_batch_size, diff --git a/psi/utils/hash_bucket_cache.h b/psi/utils/hash_bucket_cache.h index 7b58cfa0..4c1ad39b 100644 --- a/psi/utils/hash_bucket_cache.h +++ b/psi/utils/hash_bucket_cache.h @@ -87,6 +87,8 @@ class HashBucketCache { std::vector LoadBucketItems(uint32_t index); + size_t GetBucketSize(uint32_t index); + uint32_t BucketNum() const { return bucket_num_; } uint64_t ItemCount() const { return item_index_; } @@ -96,6 +98,8 @@ class HashBucketCache { std::vector> bucket_os_vec_; + std::vector bucket_data_sizes_; + uint32_t bucket_num_; uint64_t item_index_; diff --git a/psi/utils/simple_channel.h b/psi/utils/simple_channel.h new file mode 100644 index 00000000..bb9c67bb --- /dev/null +++ b/psi/utils/simple_channel.h @@ -0,0 +1,99 @@ +// Copyright 2025 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "yacl/base/exception.h" + +namespace psi { +template +class SimpleChannel { + public: + explicit SimpleChannel(size_t capacity) : capacity_(capacity) {} + ~SimpleChannel() {} + + void Push(T& item) { + { + std::unique_lock lock(mutex_); + if (closed_) { + cond_.notify_all(); + YACL_THROW("send data to a closed queue"); + } + while (queue_.size() >= capacity_) { + cond_.wait(lock, [&] { return queue_.size() < capacity_; }); + } + queue_.push(item); + } + cond_.notify_all(); + } + + void Push(T&& item) { + { + std::unique_lock lock(mutex_); + if (closed_) { + cond_.notify_all(); + YACL_THROW("send data to a closed queue"); + } + while (queue_.size() >= capacity_) { + cond_.wait(lock, [&] { return queue_.size() < capacity_; }); + } + queue_.push(std::forward(item)); + } + + cond_.notify_all(); + } + + std::optional Pop() { + std::optional item; + { + std::unique_lock lock(mutex_); + cond_.wait(lock, [&] { return !queue_.empty() || closed_; }); + + if (queue_.empty()) { + // queue is empty and closed + return std::nullopt; + } + item = std::move(queue_.front()); + queue_.pop(); + } + cond_.notify_all(); + return item; + } + + // close queue, if a queue is empty, all Pop() will return empty item. + void Close() { + if (closed_) { + YACL_THROW("close a closed queue"); + } + std::unique_lock lock(mutex_); + closed_ = true; + cond_.notify_all(); + } + bool IsClosed() { return closed_; } + + private: + std::atomic_bool closed_{false}; + std::mutex mutex_; + std::condition_variable cond_; + size_t capacity_; + std::queue queue_; +}; +} // namespace psi