diff --git a/src/bqfc.c b/src/bqfc.c index dfd7458f..9286ceab 100644 --- a/src/bqfc.c +++ b/src/bqfc.c @@ -1,6 +1,5 @@ #include "bqfc.h" -#include #include #include #include @@ -120,18 +119,25 @@ int bqfc_decompr(mpz_t out_a, mpz_t out_b, const mpz_t D, const struct qfb_c *c) return ret; } -static void bqfc_export(uint8_t *out_str, size_t *offset, size_t size, +static int bqfc_export(uint8_t *out_str, size_t *offset, size_t size, const mpz_t n) { - size_t bytes; + size_t bytes = 0; + const size_t bits = (size_t)mpz_sizeinbase(n, 2); + const size_t needed_bytes = (bits + 7) / 8; + + if (needed_bytes > size) { + return -1; + } - // mpz_export can overflow out_str if reduction bug but this should never happen mpz_export(&out_str[*offset], &bytes, -1, 1, 0, 0, n); - if (bytes > size) - gmp_printf("bqfc_export overflow offset %d size %d n %Zd\n", *offset, size, n); + if (bytes > size) { + return -1; + } if (bytes < size) memset(&out_str[*offset + bytes], 0, size - bytes); *offset += size; + return 0; } enum BQFC_FLAG_BITS { @@ -171,20 +177,29 @@ int bqfc_serialize_only(uint8_t *out_str, const struct qfb_c *c, size_t d_bits) { size_t offset, g_size; + if (d_bits == 0 || d_bits > BQFC_MAX_D_BITS) + return -1; d_bits = (d_bits + 31) & ~(size_t)31; + if (d_bits > BQFC_MAX_D_BITS) + return -1; out_str[0] = (uint8_t)c->b_sign << BQFC_B_SIGN_BIT; out_str[0] |= (mpz_sgn(c->t) < 0 ? 1 : 0) << BQFC_T_SIGN_BIT; g_size = (mpz_sizeinbase(c->g, 2) + 7) / 8 - 1; - assert(g_size <= UCHAR_MAX); + if (g_size > UCHAR_MAX) + return -1; out_str[1] = (uint8_t)g_size; offset = 2; - bqfc_export(out_str, &offset, d_bits / 16 - g_size, c->a); - bqfc_export(out_str, &offset, d_bits / 32 - g_size, c->t); + if (bqfc_export(out_str, &offset, d_bits / 16 - g_size, c->a)) + return -1; + if (bqfc_export(out_str, &offset, d_bits / 32 - g_size, c->t)) + return -1; - bqfc_export(out_str, &offset, g_size + 1, c->g); - bqfc_export(out_str, &offset, g_size + 1, c->b0); + if (bqfc_export(out_str, &offset, g_size + 1, c->g)) + return -1; + if (bqfc_export(out_str, &offset, g_size + 1, c->b0)) + return -1; return 0; } @@ -193,7 +208,11 @@ int bqfc_deserialize_only(struct qfb_c *out_c, const uint8_t *str, size_t d_bits { size_t offset, bytes, g_size; + if (d_bits == 0 || d_bits > BQFC_MAX_D_BITS) + return -1; d_bits = (d_bits + 31) & ~(size_t)31; + if (d_bits > BQFC_MAX_D_BITS) + return -1; g_size = str[1]; if (g_size >= d_bits / 32) @@ -225,8 +244,11 @@ int bqfc_deserialize_only(struct qfb_c *out_c, const uint8_t *str, size_t d_bits int bqfc_get_compr_size(size_t d_bits) { + if (d_bits == 0 || d_bits > BQFC_MAX_D_BITS) + return -1; size_t size = (d_bits + 31) / 32 * 3 + 4; - assert(size <= INT_MAX); + if (size > INT_MAX) + return -1; return (int)size; } @@ -235,6 +257,8 @@ int bqfc_serialize(uint8_t *out_str, mpz_t a, mpz_t b, size_t d_bits) struct qfb_c f_c; int ret; int valid_size = bqfc_get_compr_size(d_bits); + if (valid_size <= 0 || valid_size > BQFC_FORM_SIZE) + return -1; if (!mpz_cmp_ui(b, 1) && mpz_cmp_ui(a, 2) <= 0) { out_str[0] = !mpz_cmp_ui(a, 2) ? BQFC_IS_GEN : BQFC_IS_1; @@ -271,6 +295,8 @@ int bqfc_deserialize(mpz_t out_a, mpz_t out_b, const mpz_t D, const uint8_t *str struct qfb_c f_c; int ret; + if (d_bits == 0 || d_bits > BQFC_MAX_D_BITS) + return -1; if (size != BQFC_FORM_SIZE) return -1; diff --git a/src/create_discriminant.h b/src/create_discriminant.h index 9a294bbd..63b90ac7 100644 --- a/src/create_discriminant.h +++ b/src/create_discriminant.h @@ -15,7 +15,7 @@ integer CreateDiscriminant(std::vector& seed, int length = 1024) { } // Check 2: Validate upper bound (optional but recommended) - const int MAX_DISCRIMINANT_SIZE_BITS = 16384; + const int MAX_DISCRIMINANT_SIZE_BITS = BQFC_MAX_D_BITS; if (length > MAX_DISCRIMINANT_SIZE_BITS) { throw std::invalid_argument( "discriminant_size_bits exceeds maximum allowed value" diff --git a/src/discriminant_bounds_regression_test.cpp b/src/discriminant_bounds_regression_test.cpp new file mode 100644 index 00000000..83fa3563 --- /dev/null +++ b/src/discriminant_bounds_regression_test.cpp @@ -0,0 +1,105 @@ +#include "verifier.h" + +#include + +#include +#include +#include + +namespace { + +std::vector db_hex_to_bytes(const std::string& hex) { + EXPECT_EQ(hex.size() % 2, 0U); + std::vector out; + out.reserve(hex.size() / 2); + for (size_t i = 0; i < hex.size(); i += 2) { + out.push_back(static_cast(std::stoul(hex.substr(i, 2), nullptr, 16))); + } + return out; +} + +std::vector db_get_fixture_challenge() { + return db_hex_to_bytes("9104c5b5e45d48f374efa0488fe6a617790e9aecb3c9cddec06809b09f45ce9b"); +} + +std::vector db_get_fixture_x() { + return db_hex_to_bytes("08000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"); +} + +std::vector db_get_fixture_proof_blob() { + return db_hex_to_bytes( + "0200553bf0f382fc65a94f20afad5dbce2c1ee8ba3bf93053559ac9960c8fd80ac2222e9b649701a4141a4d8999f0dbfe0c39ea744096598a7528328e5199f0aa30aec8aae8ab5018bf1245329a8272ddff1afbd87ad2eaba1b7fd57bd25edc62e0b010000003f0ffcd0dc307a2aa4678bafba661c77d176ef23afc86e7ea9f4f9eac52b8e1850748019245ecc96547da9b731dc72cded5582a9b0c63e13fd42446c7b28b41d3ded1d0b666d5ddb5b29719e4ebe70969e67e42ddd8591eae60d83dbe619f1250400"); +} + +} // namespace + +TEST(DiscriminantBoundsRegressionTest, VerifyRejectsDiscSizeBitsAboveMaximum) { + std::vector challenge = db_get_fixture_challenge(); + std::vector x = db_get_fixture_x(); + std::vector proof_blob = db_get_fixture_proof_blob(); + const integer D = CreateDiscriminant(challenge, BQFC_MAX_D_BITS); + + EXPECT_TRUE(CheckProofOfTimeNWesolowski( + D, + x.data(), + proof_blob.data(), + proof_blob.size(), + 129499136, + BQFC_MAX_D_BITS, + 0)); + + EXPECT_FALSE(CheckProofOfTimeNWesolowski( + D, + x.data(), + proof_blob.data(), + proof_blob.size(), + 129499136, + static_cast(BQFC_MAX_D_BITS) + 1, + 0)); +} + +TEST(DiscriminantBoundsRegressionTest, CreateDiscriminantAndVerifyRejectsDiscSizeBitsAboveMaximum) { + std::vector challenge = db_get_fixture_challenge(); + std::vector x = db_get_fixture_x(); + std::vector proof_blob = db_get_fixture_proof_blob(); + + EXPECT_FALSE(CreateDiscriminantAndCheckProofOfTimeNWesolowski( + challenge, + static_cast(BQFC_MAX_D_BITS + 1), + x.data(), + proof_blob.data(), + proof_blob.size(), + 129499136, + 0)); +} + +TEST(DiscriminantBoundsRegressionTest, BqfcSerializationRejectsOversizedDiscriminantBits) { + uint8_t serialized[BQFC_FORM_SIZE]; + mpz_t a; + mpz_t b; + mpz_init_set_ui(a, 1); + mpz_init_set_ui(b, 1); + + EXPECT_EQ(bqfc_serialize(serialized, a, b, static_cast(BQFC_MAX_D_BITS) + 1), -1); + + mpz_clear(a); + mpz_clear(b); +} + +TEST(DiscriminantBoundsRegressionTest, BqfcDeserializationRejectsOversizedDiscriminantBits) { + uint8_t serialized[BQFC_FORM_SIZE] = {0}; + mpz_t D; + mpz_t out_a; + mpz_t out_b; + mpz_init_set_si(D, -23); + mpz_init(out_a); + mpz_init(out_b); + + EXPECT_EQ( + bqfc_deserialize(out_a, out_b, D, serialized, BQFC_FORM_SIZE, static_cast(BQFC_MAX_D_BITS) + 1), + -1); + + mpz_clear(D); + mpz_clear(out_a); + mpz_clear(out_b); +} diff --git a/src/regression_unit_tests.cpp b/src/regression_unit_tests.cpp index 4f6516ec..44d79444 100644 --- a/src/regression_unit_tests.cpp +++ b/src/regression_unit_tests.cpp @@ -1,4 +1,5 @@ #include "checked_cast_test.cpp" +#include "discriminant_bounds_regression_test.cpp" #include "proof_deserialization_regression_test.cpp" #include "prover_slow_regression_test.cpp" #include "two_weso_callback_regression_test.cpp" diff --git a/src/vdf_client.cpp b/src/vdf_client.cpp index 8b1bd65a..58dc0aff 100644 --- a/src/vdf_client.cpp +++ b/src/vdf_client.cpp @@ -7,6 +7,10 @@ using boost::asio::ip::tcp; std::mutex socket_mutex; +namespace { +constexpr int kIterationHeaderDigits = 2; +constexpr int kMaxIterationDigits = 20; +} // namespace // Segments are 2^16, 2^18, ..., 2^30 // Best case it'll be able to proof for up to 2^36 due to 64-wesolowski restriction. @@ -101,7 +105,9 @@ void FinishSession(tcp::socket& sock) { char ack[5]; memset(ack, 0x00, sizeof(ack)); boost::asio::read(sock, boost::asio::buffer(ack, 3), error); - assert (strncmp(ack, "ACK", 3) == 0); + if (strncmp(ack, "ACK", 3) != 0) { + throw std::runtime_error("Invalid stop ACK"); + } } catch (std::exception& e) { PrintInfo("Exception in thread: " + to_string(e.what())); } @@ -109,15 +115,41 @@ void FinishSession(tcp::socket& sock) { uint64_t ReadIteration(tcp::socket& sock) { boost::system::error_code error; - char data[20]; - memset(data, 0, sizeof(data)); - boost::asio::read(sock, boost::asio::buffer(data, 2), error); - int size = (data[0] - '0') * 10 + (data[1] - '0'); + char size_buf[kIterationHeaderDigits]; + memset(size_buf, 0, sizeof(size_buf)); + boost::asio::read(sock, boost::asio::buffer(size_buf, kIterationHeaderDigits), error); + if (error) { + throw std::runtime_error("Failed to read iteration size header"); + } + if (size_buf[0] < '0' || size_buf[0] > '9' || size_buf[1] < '0' || size_buf[1] > '9') { + throw std::runtime_error("Iteration size header must be decimal digits"); + } + + int size = (size_buf[0] - '0') * 10 + (size_buf[1] - '0'); + if (size == 0) { + return 0; + } + if (size > kMaxIterationDigits) { + throw std::runtime_error("Invalid iteration size"); + } + + char data[kMaxIterationDigits]; memset(data, 0, sizeof(data)); boost::asio::read(sock, boost::asio::buffer(data, size), error); + if (error) { + throw std::runtime_error("Failed to read iteration body"); + } uint64_t iters = 0; - for (int i = 0; i < size; i++) - iters = iters * 10 + data[i] - '0'; + for (int i = 0; i < size; i++) { + if (data[i] < '0' || data[i] > '9') { + throw std::runtime_error("Iteration body must be decimal digits"); + } + const uint64_t digit = static_cast(data[i] - '0'); + if (iters > (std::numeric_limits::max() - digit) / 10) { + throw std::runtime_error("Iteration value overflow"); + } + iters = iters * 10 + digit; + } return iters; } diff --git a/src/verifier.h b/src/verifier.h index 61baf869..97690c91 100644 --- a/src/verifier.h +++ b/src/verifier.h @@ -11,6 +11,17 @@ const uint8_t DEFAULT_ELEMENT[] = { 0x08 }; +inline bool IsDiscSizeBitsInRange(const uint64_t disc_size_bits) +{ + return disc_size_bits > 0 && disc_size_bits <= static_cast(BQFC_MAX_D_BITS); +} + +inline bool IsDiscriminantInRange(const integer& D) +{ + const int d_bits = D.num_bits(); + return d_bits > 0 && static_cast(d_bits) <= static_cast(BQFC_MAX_D_BITS); +} + int VerifyWesoSegment(integer &D, form x, form proof, integer &B, uint64_t iters, form &out_y) { PulmarkReducer reducer; @@ -43,7 +54,8 @@ void VerifyWesolowskiProof(integer &D, form x, form y, form proof, uint64_t iter bool CheckProofOfTimeNWesolowski(integer D, const uint8_t* x_s, const uint8_t* proof_blob, size_t proof_blob_len, uint64_t iterations, uint64 disc_size_bits, uint64_t depth) { - (void)disc_size_bits; + if (!IsDiscSizeBitsInRange(disc_size_bits) || !IsDiscriminantInRange(D)) + return false; const size_t form_size = BQFC_FORM_SIZE; const size_t segment_len = 8 + B_bytes + form_size; const size_t base_len = 2 * form_size; @@ -125,6 +137,7 @@ bool CheckProofOfTimeNWesolowskiCommon(integer& D, form& x, const uint8_t* proof } std::pair> CheckProofOfTimeNWesolowskiWithB(integer D, integer B, const uint8_t* x_s, const uint8_t* proof_blob, size_t proof_blob_len, uint64_t iterations, uint64_t depth) { + if (!IsDiscriminantInRange(D)) return {false, {}}; const size_t form_size = BQFC_FORM_SIZE; const size_t segment_len = 8 + B_bytes + form_size; const uint64_t max_depth = static_cast((std::numeric_limits::max() - form_size) / segment_len); @@ -150,6 +163,7 @@ std::pair> CheckProofOfTimeNWesolowskiWithB(integer D } integer GetBFromProof(integer D, const uint8_t* x_s, const uint8_t* proof_blob, size_t proof_blob_len, uint64_t iterations, uint64_t depth) { + if (!IsDiscriminantInRange(D)) throw std::runtime_error("Invalid proof."); const size_t form_size = BQFC_FORM_SIZE; const size_t segment_len = 8 + B_bytes + form_size; const size_t base_len = 2 * form_size; @@ -170,10 +184,14 @@ integer GetBFromProof(integer D, const uint8_t* x_s, const uint8_t* proof_blob, bool CreateDiscriminantAndCheckProofOfTimeNWesolowski(std::vector seed, uint32 disc_size_bits, const uint8_t* x_s, const uint8_t* proof_blob, size_t proof_blob_len, uint64_t iterations, uint64_t depth) { + if (!IsDiscSizeBitsInRange(disc_size_bits)) + return false; integer D = CreateDiscriminant( seed, disc_size_bits ); + if (!IsDiscriminantInRange(D)) + return false; return CheckProofOfTimeNWesolowski(D, x_s, proof_blob, proof_blob_len, iterations, disc_size_bits, depth); }