diff --git a/common/BUILD b/common/BUILD index 8dcf803a3..80240926a 100644 --- a/common/BUILD +++ b/common/BUILD @@ -1138,3 +1138,36 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "ipaddress_oss", + srcs = ["ipaddress_oss.cc"], + hdrs = ["ipaddress_oss.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "ipaddress_oss_test", + srcs = ["ipaddress_oss_test.cc"], + deps = [ + ":ipaddress_oss", + "//internal:testing", + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) diff --git a/common/ipaddress_oss.cc b/common/ipaddress_oss.cc new file mode 100644 index 000000000..7981d8f68 --- /dev/null +++ b/common/ipaddress_oss.cc @@ -0,0 +1,646 @@ +// Copyright 2008 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ipaddress_oss.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/numeric/int128.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +namespace cel { + +// Sanity check: be sure INET_ADDRSTRLEN fits into INET6_ADDRSTRLEN. +// ToCharBuf() (below) depends on this. +static_assert(INET_ADDRSTRLEN <= INET6_ADDRSTRLEN, "ipv6_larger_than_ipv4"); + +#define UNALIGNED_LOAD32(p) cel::common_internal::UnalignedLoad(p) +#define UNALIGNED_LOAD64(p) cel::common_internal::UnalignedLoad(p) + +namespace { + +const int kMaxNetmaskIPv4 = 32; +const int kMaxNetmaskIPv6 = 128; + +IPAddress MakeIPAddressWithOptionalScopeId(absl::uint128 ip6uint, + uint32_t scope_id) { + const IPAddress ip6(UInt128ToIPAddress(ip6uint)); + const auto rval = MakeIPAddressWithScopeId(ip6.ipv6_address(), scope_id); + return rval.ok() ? rval.value() : ip6; +} + +} // namespace + +IPAddress IPAddress::Any4() { return HostUInt32ToIPAddress(INADDR_ANY); } + +IPAddress IPAddress::Loopback4() { return HostUInt32ToIPAddress(0x7f000001); } + +IPAddress IPAddress::Any6() { return IPAddress(in6addr_any); } + +IPAddress IPAddress::Loopback6() { return IPAddress(in6addr_loopback); } + +in6_addr IPAddress::ipv6_address_slowpath() const { + ABSL_CHECK_EQ(AF_INET6, address_family_); + if (ABSL_PREDICT_FALSE(HasCompactScopeId(addr_.addr6))) { + in6_addr copy = addr_.addr6; + copy.s6_addr32[1] = 0; // clear the scope_id (interface index) + return copy; + } + return addr_.addr6; +} + +IPAddress HostUInt32ToIPAddress(uint32_t address) { + in_addr addr; + addr.s_addr = htonl(address); + return IPAddress(addr); +} + +IPAddress UInt128ToIPAddress(const absl::uint128 bigint) { + in6_addr addr6; + addr6.s6_addr32[0] = + htonl(static_cast(absl::Uint128High64(bigint) >> 32)); + addr6.s6_addr32[1] = + htonl(static_cast(absl::Uint128High64(bigint) & 0xFFFFFFFFULL)); + addr6.s6_addr32[2] = + htonl(static_cast(absl::Uint128Low64(bigint) >> 32)); + addr6.s6_addr32[3] = + htonl(static_cast(absl::Uint128Low64(bigint) & 0xFFFFFFFFULL)); + return IPAddress(addr6); +} + +namespace { + +void AppendIPv4ToString(const uint8_t* octets, std::string* out) { + absl::StrAppendFormat(out, "%d.%d.%d.%d", octets[0], octets[1], octets[2], + octets[3]); +} + +// Returns start of the longest sequence of zero words of length at least 2. +// Returns -1 if no such sequence exists. Returns first such sequence in the +// event of multiple such sequences. +int FindLongestZeroWordSequence(const uint16_t* addr) { + int cnt = 0; + int best_len = 1; + int best_start = -1; + for (int i = 0; i < 8; ++i) { + if (addr[i] == 0) { + ++cnt; + if (cnt > best_len) { + best_len = cnt; + best_start = i + 1 - cnt; + } + } else { + cnt = 0; + } + } + return best_start; +} + +void AppendIPv6ToString(const in6_addr& addr, std::string* out) { + if (addr.s6_addr32[0] == 0 && addr.s6_addr32[1] == 0) { + // If lower half of address is zero, it starts with :: and it may be + // embedded IPv4 address. + out->push_back(':'); + // Check for IPv6 embedded IPv4 address. + if (addr.s6_addr16[4] == 0 && + (addr.s6_addr16[5] == 0xffff || + (addr.s6_addr16[5] == 0 && addr.s6_addr16[6] != 0))) { + if (addr.s6_addr16[5] != 0) { + absl::StrAppend(out, ":ffff"); + } + out->push_back(':'); + AppendIPv4ToString(&addr.s6_addr[12], out); + return; + } + int i = 4; + // Skip remaining zero words. + while (i < 8 && addr.s6_addr16[i] == 0) { + ++i; + } + if (i < 8) { + for (; i < 8; ++i) { + absl::StrAppend(out, ":", absl::Hex(ntohs(addr.s6_addr16[i]))); + } + } else { + out->push_back(':'); + } + } else { + const int start = FindLongestZeroWordSequence(addr.s6_addr16); + for (int i = 0; i < 8; ++i) { + if (i == start) { + // At least two words are guaranteed to be zero. + i += 2; + while (i < 8 && addr.s6_addr16[i] == 0) { + ++i; + } + out->push_back(':'); + if (i == 8) { + out->push_back(':'); + break; + } + } + if (i) { + out->push_back(':'); + } + absl::StrAppend(out, absl::Hex(ntohs(addr.s6_addr16[i]))); + } + } +} + +} // namespace + +std::string IPAddress::ToString() const { + std::string out; + out.reserve(INET6_ADDRSTRLEN + 1); + switch (address_family_) { + case AF_INET: + AppendIPv4ToString(reinterpret_cast(&addr_.addr4.s_addr), + &out); + break; + case AF_INET6: + if (ABSL_PREDICT_FALSE(HasCompactScopeId(addr_.addr6))) { + AppendIPv6ToString(ipv6_address(), &out); + } else { + AppendIPv6ToString(addr_.addr6, &out); + } + break; + case AF_UNSPEC: + ABSL_LOG(ERROR) << "Calling ToCharBuf() on an empty IPAddress"; + return ""; + break; + default: + ABSL_LOG(FATAL) << "Unknown address family " << address_family_; + } + return out; +} + +std::string IPAddress::ToPackedString() const { + switch (address_family_) { + case AF_INET: + return std::string(reinterpret_cast(&addr_.addr4), + sizeof(addr_.addr4)); + case AF_INET6: + if (ABSL_PREDICT_FALSE(HasCompactScopeId(addr_.addr6))) { + // Calling ToPackedString() on an IPv6 link-local address is somewhat + // suspect. When later de-serialized, even on the same machine, there + // is no inherent guarantee that a given interface index remains + // valid. For now, output them the same way as their unscoped cousins + // -- what to do with the interface index and/or name is likely to be + // an application-dependent matter. + ABSL_VLOG(2) << "ToPackedString() dropping scope ID"; + const auto addr6 = ipv6_address(); + return std::string(reinterpret_cast(&addr6), + sizeof(addr6)); + } else { + return std::string(reinterpret_cast(&addr_.addr6), + sizeof(addr_.addr6)); + } + case AF_UNSPEC: + ABSL_LOG(ERROR) << "Calling ToPackedString() on an empty IPAddress"; + return ""; + default: + ABSL_LOG(FATAL) << "Unknown address family " << address_family_; + } +} + +absl::StatusOr MakeIPAddressWithScopeId(const in6_addr& addr, + uint32_t scope_id) { + if (scope_id == 0) return IPAddress(addr); + + if (!IPAddress::MayUseScopeIds(addr)) { + return absl::InvalidArgumentError("address does not use scope_ids"); + } else if (!IPAddress::MayUseCompactScopeIds(addr)) { + return absl::InvalidArgumentError("address cannot use compact scope_ids"); + } else if (!IPAddress::MayStoreCompactScopeId(addr)) { + return absl::InvalidArgumentError("address cannot safely compact scope_id"); + } + + return IPAddress(addr, scope_id); +} + +bool StringToIPAddress(const char* str, IPAddress* out) { + // Try to parse the string as an IPv4 address first. (glibc does not + // yet recognize IPv6 addresses when given an address family of + // AF_INET, but at some point it will, and at that point the other way + // around would break.) + in_addr addr4; + if (str && inet_pton(AF_INET, str, &addr4) > 0) { + if (out) { + *out = IPAddress(addr4); + } + return true; + } + + in6_addr addr6; + if (str && inet_pton(AF_INET6, str, &addr6) > 0) { + if (out) { + *out = IPAddress(addr6); + } + return true; + } + + return false; +} + +bool StringToIPAddress(const absl::string_view str, IPAddress* out) { + // We spend a lot of time in this routine, so make a zero-terminated + // copy of the piece on the stack if it's short, rather than constructing + // a temporary string. + if (str.size() <= INET6_ADDRSTRLEN) { + char buf[INET6_ADDRSTRLEN + 1]; + buf[str.size()] = '\0'; + memcpy(buf, str.data(), str.size()); + return StringToIPAddress(buf, out); + } else { + return StringToIPAddress(std::string(str).c_str(), out); + } +} + +namespace { + +// Maps error values from getaddrinfo(3) to canonical Status codes. +absl::Status InternalGetaddrinfoErrorToStatus(int rval, int copied_errno) { + if (rval == 0) return absl::OkStatus(); + + const char* error_str = gai_strerror(rval); + // Note that getaddrinfo is only guaranteed to set errno when the return + // value is EAI_SYSTEM. Otherwise, errno is unreliable. + // + // Plausible error values are from + // https://tools.ietf.org/html/rfc3493#section-6.1 + switch (rval) { + case EAI_AGAIN: + return absl::UnavailableError(absl::StrCat("EAI_AGAIN: ", error_str)); + case EAI_BADFLAGS: + return absl::InvalidArgumentError( + absl::StrCat("EAI_BADFLAGS: ", error_str)); + case EAI_FAIL: + return absl::NotFoundError(absl::StrCat("EAI_FAIL: ", error_str)); + case EAI_FAMILY: + return absl::InvalidArgumentError( + absl::StrCat("EAI_FAMILY: ", error_str)); + case EAI_MEMORY: + return absl::ResourceExhaustedError( + absl::StrCat("EAI_MEMORY: ", error_str)); + case EAI_NONAME: + return absl::NotFoundError(absl::StrCat("EAI_NONAME: ", error_str)); + case EAI_SERVICE: + return absl::InvalidArgumentError( + absl::StrCat("EAI_SERVICE: ", error_str)); + case EAI_SOCKTYPE: + return absl::InvalidArgumentError( + absl::StrCat("EAI_SOCKTYPE: ", error_str)); + default: + return absl::UnknownError( + absl::StrCat("getaddrinfo returned ", rval, " (", error_str, ")")); + } +} + +} // namespace + +absl::StatusOr StringToIPAddressWithOptionalScope( + const absl::string_view str) { + const auto scope_delimiter = str.rfind('%'); + if (scope_delimiter == absl::string_view::npos) { + IPAddress ip{}; + if (StringToIPAddress(str, &ip)) { + return ip; + } else { + return absl::InvalidArgumentError("bad IP string literal"); + } + } + + // Addresses with a scope delimiter ('%') but without a following zone_id + // does not seem to comport with any of this text: + // + // https://tools.ietf.org/html/rfc4007#section-11.2 + // https://tools.ietf.org/html/rfc4007#section-11.6 + // https://tools.ietf.org/html/rfc6874#section-2 + // + // However, it seems at least one getaddrinfo() implementation accepts this + // syntax. Until further review of text and use cases comes to a different + // conclusion, check for this case and return an error. + if (str.substr(scope_delimiter).size() == 1) { // EndsWith('%') + return absl::InvalidArgumentError("missing zone_id"); + } + + const std::string str_null_terminated(str); + + // Trust getaddrinfo()'s ability to parse scope_ids and interface names. + struct addrinfo hints{}; + hints.ai_flags = AI_NUMERICHOST | AI_NUMERICSERV; + hints.ai_family = AF_INET6; + // Hint that getaddrinfo() need not return a linked list of answers. + hints.ai_socktype = SOCK_DGRAM; + hints.ai_protocol = IPPROTO_UDP; + + struct addrinfo* res{nullptr}; + const int rval = + getaddrinfo(str_null_terminated.c_str(), nullptr, &hints, &res); + std::unique_ptr cleanup( + res, freeaddrinfo); + if (rval != 0) { + return InternalGetaddrinfoErrorToStatus(rval, errno); + } + if (res == nullptr || res->ai_addr == nullptr || + res->ai_addrlen < sizeof(struct sockaddr_in6)) { + return absl::InternalError("getaddrinfo returned nonsensical response"); + } + const auto* sin6 = reinterpret_cast(res->ai_addr); + return MakeIPAddressWithScopeId(sin6->sin6_addr, sin6->sin6_scope_id); +} + +bool PackedStringToIPAddress(absl::string_view str, IPAddress* out) { + if (str.length() == sizeof(in_addr)) { + if (out) { + in_addr addr; + memcpy(&addr, str.data(), sizeof(addr)); + *out = IPAddress(addr); + } + return true; + } else if (str.length() == sizeof(in6_addr)) { + if (out) { + in6_addr addr; + memcpy(&addr, str.data(), sizeof(addr)); + *out = IPAddress(addr); + } + return true; + } + + return false; +} + +namespace { + +bool InternalStringToNetmaskLength(absl::string_view str, + int host_address_family, int* out) { + ABSL_DCHECK(out); + + // Explicitly check that the first and last characters are digits, because + // SimpleAtoi will accept whitespace, +, -, etc. + if (str.empty() || !absl::ascii_isdigit(*str.begin()) || + !absl::ascii_isdigit(*str.rbegin())) { + return false; + } + + // Check for a decimal number. + if (absl::SimpleAtoi(str, out)) { + ABSL_DCHECK_GE(*out, 0); + const int max_length = + host_address_family == AF_INET6 ? kMaxNetmaskIPv6 : kMaxNetmaskIPv4; + return *out <= max_length; + } + + // Check for a netmask in dotted quad form, e.g. "255.255.0.0". + in_addr mask; + if (host_address_family == AF_INET && + inet_pton(AF_INET, std::string(str).c_str(), &mask) > 0) { + if (mask.s_addr == 0) { + *out = 0; + } else { + // Now we check to make sure we have a sane netmask. + // The inverted mask in native byte order (+1) will have to be a + // power of two, if it's valid. + uint32_t inv_mask = (~ntohl(mask.s_addr)) + 1; + // Power of two iff x & (x - 1) == 0. + if ((inv_mask & (inv_mask - 1)) != 0) { + return false; + } + *out = 32 - __builtin_ffs(ntohl(mask.s_addr)) + 1; + } + return true; + } + + return false; +} + +// +// The "meat" of StringToIPRange{,AndTruncate}. Does no checking of correct +// prefix length, nor any automatic truncation. +// +bool InternalStringToIPRange(absl::string_view str, + std::pair* out) { + ABSL_DCHECK(out); + + // Try to parse everything before the slash as an IP address. + // If there is no slash, then substr(0, npos) yields the full string. + const size_t slash_pos = str.find('/'); + if (!StringToIPAddress(str.substr(0, slash_pos), &out->first)) { + return false; + } + + // Try to parse everything after the slash as a prefix length. + if (slash_pos != absl::string_view::npos) { + int length; + if (!InternalStringToNetmaskLength(absl::ClippedSubstr(str, slash_pos + 1), + out->first.address_family(), &length)) { + return false; + } + out->second = length; + return true; + } + + // There was no slash, so the range covers a single address. + out->second = IPAddressLength(out->first); + return true; +} + +} // namespace + +bool StringToIPRange(absl::string_view str, IPRange* out) { + std::pair parsed; + if (!InternalStringToIPRange(str, &parsed)) { + return false; + } + const IPRange result(parsed.first, parsed.second); + if (result.host() != parsed.first) { + // Some bits were truncated. + return false; + } + if (out) { + *out = result; + } + return true; +} + +bool StringToIPRangeAndTruncate(absl::string_view str, IPRange* out) { + std::pair parsed; + if (!InternalStringToIPRange(str, &parsed)) { + return false; + } + if (out) { + *out = IPRange(parsed.first, parsed.second); + } + return true; +} + +namespace ipaddress_internal { +IPAddress TruncateIPAndLength(const IPAddress& addr, int* length_io) { + const int length = *length_io; + switch (addr.address_family()) { + case AF_INET: { + if (length >= kMaxNetmaskIPv4) { + *length_io = kMaxNetmaskIPv4; + return addr; + } else if (length > 0) { + uint32_t ip4 = IPAddressToHostUInt32(addr); + ip4 &= ~0U << (32 - length); + return HostUInt32ToIPAddress(ip4); + } else if (length == 0) { + return IPAddress::Any4(); + } + break; + } + case AF_INET6: { + if (length >= kMaxNetmaskIPv6) { + *length_io = kMaxNetmaskIPv6; + return addr; + } else if (length > 0) { + absl::uint128 ip6 = IPAddressToUInt128(addr); + ip6 &= ~absl::uint128(0) << (128 - length); + return MakeIPAddressWithOptionalScopeId(ip6, addr.scope_id()); + } else if (length == 0) { + return IPAddress::Any6(); + } + break; + } + case AF_UNSPEC: + *length_io = -1; + return addr; + } + ABSL_LOG(ERROR) << "Invalid truncation: " << addr << "/" << length; + *length_io = -1; + return IPAddress(); +} + +} // namespace ipaddress_internal + +//----------------------------------------------------------------------------- +// Carry over required functions from ipaddress.cc +//------------------------------------------------------------------------------ + +bool IsAnyIPAddress(const IPAddress& ip) { + switch (ip.address_family()) { + case AF_INET: + return ip.ipv4_address().s_addr == INADDR_ANY; + case AF_INET6: + return ip == IPAddress(in6addr_any); + default: + return false; + } +} + +bool IsLoopbackIPAddress(const IPAddress& ip) { + switch (ip.address_family()) { + case AF_INET: + return (IPAddressToHostUInt32(ip) & 0xff000000U) == 0x7f000000U; + case AF_INET6: + return ip == IPAddress::Loopback6(); + default: + return false; + } +} + +bool IsLinkLocalIP(const IPAddress& ip) { + if (ip.is_ipv4()) { + // 169.254.0.0/16 + return (IPAddressToHostUInt32(ip) & 0xffff0000U) == 0xa9fe0000U; + } + if (ip.is_ipv6()) { + // Store the address in a variable (address of temporary object) + const in6_addr addr6 = ip.ipv6_address(); + return IN6_IS_ADDR_LINKLOCAL(&addr6); + } + return false; +} + +bool IsV4MulticastIPAddress(const IPAddress& ip) { + if (!ip.is_ipv4()) { + return false; + } + return (IPAddressToHostUInt32(ip) & 0xf0000000U) == 0xe0000000U; +} + +namespace { +bool IsUniqueLocalIP(const IPAddress& ip) { + return ip.is_ipv6() && (ip.ipv6_address().s6_addr[0] & 0xfe) == 0xfc; +} +} // namespace + +bool IsPrivateIP(const IPAddress& ip) { + if (ip.is_ipv4()) { + uint32_t h = IPAddressToHostUInt32(ip); + return (h & 0xff000000U) == 0x0a000000U || // 10.0.0.0/8 + (h & 0xfff00000U) == 0xac100000U || // 172.16.0.0/12 + (h & 0xffff0000U) == 0xc0a80000U; // 192.168.0.0/16 + } else if (ip.is_ipv6()) { + return IsUniqueLocalIP(ip); + } + return false; +} + +bool IsNonRoutableIP(const IPAddress& ip) { + return !ip.is_ipv4() + ? IsPrivateIP(ip) + : IsPrivateIP(ip) || IsLinkLocalIP(ip) || + IsLoopbackIPAddress(ip) || + // 0.0.0.0/8 + (IPAddressToHostUInt32(ip) & 0xff000000U) == 0x00000000U || + // 224.0.0.0/3 + (IPAddressToHostUInt32(ip) & 0xe0000000U) == 0xe0000000U; +} + +bool GetMappedIPv4Address(const IPAddress& ip6, IPAddress* ip4) { + if (ip6.address_family() != AF_INET6) { + ABSL_DCHECK_NE(AF_UNSPEC, ip6.address_family()); + return false; + } + + in6_addr addr6 = ip6.ipv6_address(); + if (UNALIGNED_LOAD64(addr6.s6_addr16) != 0 || addr6.s6_addr16[4] != 0 || + addr6.s6_addr16[5] != 0xffff) { + return false; + } + + if (ip4) { + in_addr ipv4; + ipv4.s_addr = UNALIGNED_LOAD32(addr6.s6_addr16 + 6); + *ip4 = IPAddress(ipv4); + } + + return true; +} + +} // namespace cel diff --git a/common/ipaddress_oss.h b/common/ipaddress_oss.h new file mode 100644 index 000000000..e585e19ca --- /dev/null +++ b/common/ipaddress_oss.h @@ -0,0 +1,641 @@ +// Copyright 2008 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_IPADDRESS_OSS_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_IPADDRESS_OSS_H_ + +#include +#include +#include + +#include +#include +#include + +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/numeric/int128.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +#ifdef __APPLE__ +#define s6_addr16 __u6_addr.__u6_addr16 +#define s6_addr32 __u6_addr.__u6_addr32 +#endif + +namespace cel { + +namespace common_internal { +template +inline T UnalignedLoad(const void* p) { + T result; + std::memcpy(&result, p, sizeof(T)); + return result; +} +} // namespace common_internal + +// Forward declaration for IPAddress ostream operator, so that DCHECK +// macros know that there is an appropriate overload. +class IPAddress; +std::ostream& operator<<(std::ostream& stream, const IPAddress& address); +absl::StatusOr MakeIPAddressWithScopeId(const in6_addr&, uint32_t); + +class IPAddress { + public: + // Default constructor. Leaves the object in an empty state. + // The empty state is analogous to a NULL pointer; the only operations + // that are allowed on the object are: + // + // * Assignment and copy construction. + // * Checking for the empty state (address_family() will return AF_UNSPEC, + // or equivalently, the helper function IsInitializedAddress() will + // return false). + // * Comparison (operator== and operator!=). + // * Logging (operator<<) + // + // In particular, no guarantees are made about the behavior of the + // IPv4/IPv6 conversion accessors, of string conversions and + // serialization, of IsAnyIPAddress() and friends. + IPAddress() : address_family_(AF_UNSPEC) {} + + // Constructors from standard BSD socket address structures. + constexpr explicit IPAddress(const in_addr& addr) + : addr_(addr), address_family_(AF_INET) {} + explicit IPAddress(const in6_addr& addr) : IPAddress(addr, 0U) {} + + // The address family; either AF_UNSPEC, AF_INET or AF_INET6. + int address_family() const { return address_family_; } + + // The address as an in_addr structure; CHECK-fails if address_family() is + // not AF_INET (ie. the held address is not an IPv4 address). + in_addr ipv4_address() const { + ABSL_CHECK_EQ(AF_INET, address_family_); + return addr_.addr4; + } + + // The address as an in6_addr structure; CHECK-fails if address_family() is + // not AF_INET6 (ie. the held address is not an IPv6 address). + in6_addr ipv6_address() const { + ABSL_CHECK_EQ(AF_INET6, address_family_); + if (ABSL_PREDICT_FALSE(HasCompactScopeId(addr_.addr6))) { + return ipv6_address_slowpath(); + } + return addr_.addr6; + } + + // Convenient helpers that return true if IP address is v4 or v6 respectively. + bool is_ipv4() const { return address_family_ == AF_INET; } + bool is_ipv6() const { return address_family_ == AF_INET6; } + + // Returns the scope_id if this is an IPv6 link-local address with a + // compactly stored scope_id; 0U otherwise. An IPv6 link-local address + // may not have had a scope_id assigned; in this case 0U is also returned. + uint32_t scope_id() const { + if (is_ipv6() && HasCompactScopeId(addr_.addr6)) { + return ntohl(addr_.addr6.s6_addr32[1]); + } + return 0U; + } + + // The address in string form, as returned by inet_ntop(). In particular, + // this means that IPv6 addresses may be subject to zero compression + // (e.g. "2001:700:300:1800::f" instead of "2001:700:300:1800:0:0:0:f"). + // + std::string ToString() const; + + // Returns the same as ToString(). + // must have room for at least INET6_ADDRSTRLEN bytes, + // including the final NUL. + // Returns the pointer to the terminating NUL. + char* ToCharBuf(char* buffer) const; + + // Returns the address as a sequence of bytes in network-byte-order. + // This is suitable for writing onto the wire or into a protocol buffer + // as a string (proto1 syntax) or bytes (proto2 syntax). + // IPv4 will be 4 bytes. IPv6 will be 16 bytes. + // Can be parsed using PackedStringToIPAddress(). + std::string ToPackedString() const; + + // Static, constant IPAddresses for convenience. + static IPAddress Any4(); // 0.0.0.0 + static IPAddress Loopback4(); // 127.0.0.1 + static IPAddress Any6(); // :: + static IPAddress Loopback6(); // ::1 + + // IP addresses have no natural ordering, so only equality operators + // are defined. + bool operator==(const IPAddress& other) const { + if (address_family_ != other.address_family_) { + return false; + } + + switch (address_family_) { + case AF_INET: + return addr_.addr4.s_addr == other.addr_.addr4.s_addr; + case AF_INET6: + return Equals6(other); + default: + // We've already verified that they've got the same address family, and + // the only possibility at this point is AF_UNSPEC, which are all equal. + return true; + } + } + + bool operator!=(const IPAddress& other) const { return !(*this == other); } + + IPAddress(const IPAddress&) = default; + IPAddress(IPAddress&&) = default; + IPAddress& operator=(const IPAddress&) = default; + IPAddress& operator=(IPAddress&&) = default; + + template + friend H AbslHashValue(H h, const IPAddress& ip) { + switch (ip.address_family()) { + case AF_INET: + return H::combine(std::move(h), ip.address_family(), + ip.addr_.addr4.s_addr); + case AF_INET6: + return H::combine( + std::move(h), ip.address_family(), ip.addr_.addr6.s6_addr32[0], + ip.addr_.addr6.s6_addr32[1], ip.addr_.addr6.s6_addr32[2], + ip.addr_.addr6.s6_addr32[3]); + default: + return H::combine(std::move(h), ip.address_family()); + } + } + + friend absl::StatusOr MakeIPAddressWithScopeId(const in6_addr&, + uint32_t); + + private: + // Returns true if the given address is one that can make use of scope_ids; + // false otherwise. + // + // Currently only IPv6 link-local unicast and link-local multicast addresses + // fit this description. In the future, though, the IPv4 link-local unicast + // range (169.254.0.0/16; RFC 3927) could be considered to use scope_ids. + static bool MayUseScopeIds(const in6_addr& in6) { + return (IN6_IS_ADDR_LINKLOCAL(&in6) || IN6_IS_ADDR_MC_LINKLOCAL(&in6)); + } + + // A much stricter test for whether in6 is a candidate for the kind of + // scope_id compaction implemented here (cf. IPAddressMayUseScopeIds()). + static bool MayUseCompactScopeIds(const in6_addr& in6) { + return ((in6.s6_addr32[0] == htonl(0xfe800000U)) || + (in6.s6_addr32[0] == htonl(0xff020000U))); + } + + // Test for whether in6 may safely, compactly store a scope_id. + static bool MayStoreCompactScopeId(const in6_addr& in6) { + return (MayUseCompactScopeIds(in6) && (in6.s6_addr32[1] == 0x0U)); + } + + // Test for whether in6 appears to have a compact scope_id stored. + static bool HasCompactScopeId(const in6_addr& in6) { + return (MayUseCompactScopeIds(in6) && (in6.s6_addr32[1] != 0x0U)); + } + + // Constructor that also supports an IPv6 link-local address with a scope_id. + IPAddress(const in6_addr& addr, uint32_t scope_id) + : addr_(addr), address_family_(AF_INET6) { + if (ABSL_PREDICT_FALSE(MayUseScopeIds(addr_.addr6))) { + if (MayUseCompactScopeIds(addr_.addr6)) { + // May have been asked to explicitly overwrite one scope with another. + addr_.addr6.s6_addr32[1] = htonl(scope_id); + } else if (scope_id != 0) { + ABSL_LOG(WARNING) << "Discarding scope_id; cannot be compactly stored."; + } + } + } + + bool Equals6(const IPAddress& other) const { + ABSL_DCHECK_EQ(address_family_, AF_INET6); +#if defined(__x86_64__) || defined(__powerpc64__) + // These 64-bit CPUs have efficient implementations of + // UnalignedLoad(). + uint64_t a1 = + common_internal::UnalignedLoad(&addr_.addr6.s6_addr32[0]); + uint64_t a2 = + common_internal::UnalignedLoad(&addr_.addr6.s6_addr32[2]); + uint64_t b1 = common_internal::UnalignedLoad( + &other.addr_.addr6.s6_addr32[0]); + uint64_t b2 = common_internal::UnalignedLoad( + &other.addr_.addr6.s6_addr32[2]); + return ((a1 ^ b1) | (a2 ^ b2)) == 0; +#else + return addr_.addr6.s6_addr32[0] == other.addr_.addr6.s6_addr32[0] && + addr_.addr6.s6_addr32[1] == other.addr_.addr6.s6_addr32[1] && + addr_.addr6.s6_addr32[2] == other.addr_.addr6.s6_addr32[2] && + addr_.addr6.s6_addr32[3] == other.addr_.addr6.s6_addr32[3]; +#endif + } + + in6_addr ipv6_address_slowpath() const; + + // In order to conserve space, a separate scope_id field has not been + // added to this class. Instead, IPv6 link-local addresses have their + // scope_id stored within in6_addr.s6_addr32[1] + // + // If IETF recommendations for fe80::/10 and/or ff02::/16 or standard usage + // of these prefixes ever change to include use of addresses with a non-zero + // second 32 bits, this compaction scheme MUST be revisited. + union Addr { + Addr() {} + constexpr explicit Addr(const in_addr& a4) : addr4(a4) {} + constexpr explicit Addr(const in6_addr& a6) : addr6(a6) {} + in_addr addr4; + in6_addr addr6; + } addr_; + + // Not all platforms define sa_family_t, so use a uint16. + // On Windows, Linux, sa_family_t is uint16. + // On OSX, sa_family_t is uint8 + uint16_t address_family_; +}; + +namespace ipaddress_internal { + +// Truncate any IPv4, IPv6, or empty IPAddress to the specified length. +// If *length_io exceeds the number of bits in the address family, then it +// will be overwritten with the correct value. Normal addresses will +// CHECK-fail if the length is negative, but empty addresses ignore the +// length and write -1. +// +IPAddress TruncateIPAndLength(const IPAddress& addr, int* length_io); + +// A templated Formatter for use with the strings::Join API to print +// collections of IPAddresses, SocketAddresses, or IPRanges (or anything +// with a suitable ToString() method). See also //strings/join.h. +template +struct ToStringJoinFormatter { + void operator()(std::string* out, const T& t) const { + out->append(t.ToString()); + } +}; + +} // namespace ipaddress_internal + +// Forward declaration. See definition below. +int IPAddressLength(const IPAddress& ip); + +class IPRange { + private: + // IPRange is a tuple of (host, length). + // Using inheritance for the host allows the compiler to pack + // length into the unused bytes of IPAddress. + // + // The data is stored in a separate struct so that overload resolution + // will remain the same as before. + struct Data : public IPAddress { + int16_t length; + + Data() : length(-1) {} + Data(const Data& d) = default; + Data& operator=(const Data&) = default; + + template + explicit Data(const T& host) : IPAddress(host) {} + + template + Data(const T& host, int16_t length_arg) + : IPAddress(host), length(length_arg) {} + + bool operator==(const Data& other) const { + // Compare length first, since it is cheaper. + return length == other.length && IPAddress::operator==(other); + } + }; + + public: + // Default constructor. Leaves the object in an empty state. + // The empty state is analogous to a NULL pointer; the only operations + // that are allowed on the object are: + // + // * Assignment and copy construction. + // * Checking for the empty state (IsInitializedRange() will return false). + // * Comparison (operator== and operator!=). + // * Logging (operator<<) + // + // In particular, no guarantees are made about the behavior of the + // of string conversions and serialization, or any other accessors. + IPRange() {} + + // Constructs an IPRange from an address and a length. Properly zeroes out + // bits and adjusts length as required, but CHECK-fails on negative lengths + // (since that is inherently nonsensical). Typical examples: + // + // 129.240.2.3/10 => 129.192.0.0/10 + // 2001:700:300:1800::/48 => 2001:700:300::/48 + // + // 127.0.0.1/33 => 127.0.0.1/32 + // ::1/129 => ::1/128 + // + // IPAddress()/* => empty IPRange() + // + // 127.0.0.1/-1 => undefined (currently CHECK-fail) + // ::1/-1 => undefined (currently CHECK-fail) + // + IPRange(const IPAddress& host, int length) + : data_(ipaddress_internal::TruncateIPAndLength(host, &length)) { + data_.length = static_cast(length); + } + + // Unsafe constructor from a host and prefix length. + // + // This is the fastest way to construct an IPRange, but the caller must + // ensure that all inputs are strictly validated: + // - IPv4 host must have length 0..32 + // - IPv6 host must have length 0..128 + // - The host must be cleanly truncated, i.e. there must not be any bits + // set beyond the prefix length. + // - Uninitialized IPAddress() must have length -1 + // + // For performance reasons, these constraints are only checked in debug mode. + // Any violations will result in undefined behavior. Callers who cannot + // guarantee correctness should use IPRange(host, length) instead. + static IPRange UnsafeConstruct(const IPAddress& host, int length) { + return IPRange(host, length, /* dummy = */ 0); + } + + // Construct an IPRange from just an IPAddress, applying the + // address-family-specific maximum netmask length. + explicit IPRange(const IPAddress& host) + : data_(host, IPAddressLength(host)) {} + + // The individual parts of the subnet. + IPAddress host() const { return data_; } + int length() const { return data_.length; } + + // Subnets have no natural ordering, so only equality operators + // are defined. + bool operator==(const IPRange& other) const { return data_ == other.data_; } + bool operator!=(const IPRange& other) const { + return !(data_ == other.data_); + } + + // A string representation of the subnet, in "host/length" format. + // Examples would be "127.0.0.0/8" or "2001:700:300:1800::/64". + std::string ToString() const { + return absl::StrCat(data_.ToString(), "/", length()); + } + + // Convenience ranges, representing every IPv4 or IPv6 address. + static IPRange Any4() { + return IPRange::UnsafeConstruct(IPAddress::Any4(), 0); // 0.0.0.0/0 + } + static IPRange Any6() { + return IPRange::UnsafeConstruct(IPAddress::Any6(), 0); // ::/0 + } + + IPRange(const IPRange&) = default; + IPRange(IPRange&&) = default; + IPRange& operator=(const IPRange&) = default; + IPRange& operator=(IPRange&&) = default; + + private: + // Internal implementation of UnsafeConstruct(). + IPRange(const IPAddress& host, int length, int dummy) + : data_(host, static_cast(length)) { + ABSL_DCHECK_EQ(this->host(), + ipaddress_internal::TruncateIPAndLength(host, &length)) + << "Host has bits set beyond the prefix length."; + ABSL_DCHECK_EQ(this->length(), length) + << "Length is inconsistent with address family."; + } + + Data data_; +}; + +// Convert a host byte order uint32 into an IPv4 IPAddress. +// +// This is the less-evil cousin of UInt32ToIPAddress. It can be used with +// protobufs, mathematical/bitwise operations, or any other case where the +// address is represented as an ordinary number. +// +// Example usage: +// HostUInt32ToIPAddress(0x01020304).ToString(); // Yields "1.2.3.4" +// +IPAddress HostUInt32ToIPAddress(uint32_t address); + +// Convert an IPv4 IPAddress to a uint32 in host byte order. +// This is the inverse of HostUInt32ToIPAddress(). +// +// Example usage: +// const IPAddress addr(...); // 1.2.3.4 +// IPAddressToHostUInt32(addr); // Yields 0x01020304 +// +// Will CHECK-fail if addr does not contain an IPv4 address. +inline uint32_t IPAddressToHostUInt32(const IPAddress& addr) { + return ntohl(addr.ipv4_address().s_addr); +} + +// Convert a uint128 in host byte order to an IPv6 IPAddress +// (e.g., uint128(0, 1) will become "::1"). +// Not a constructor, to make it easier to grep for the ugliness later. +IPAddress UInt128ToIPAddress(absl::uint128 bigint); + +// Convert an IPv6 IPAddress to a uint128 in host byte order +// (e.g., "::1" will become uint128(0, 1)). +// Will CHECK-fail if addr does not contain an IPv6 address, +// so use with care, and only in low-level code. +inline absl::uint128 IPAddressToUInt128(const IPAddress& addr) { + struct in6_addr addr6 = addr.ipv6_address(); + return absl::MakeUint128( + static_cast(ntohl(addr6.s6_addr32[0])) << 32 | + static_cast(ntohl(addr6.s6_addr32[1])), + static_cast(ntohl(addr6.s6_addr32[2])) << 32 | + static_cast(ntohl(addr6.s6_addr32[3]))); +} + +// Parse an IPv4 or IPv6 address in textual form to an IPAddress. +// Not a constructor since it can fail (in which case it returns false, +// and the contents of "out" is undefined). If only validation is required, +// "out" can be set to nullptr. +// +// The input argument can be in whatever form inet_pton(AF_INET, ...) or +// inet_pton(AF_INET6, ...) accepts (ie. typically something like "127.0.0.1" +// or "2001:700:300:1800::f"). +// +// Note that in particular, this function does not do DNS lookup. +// +ABSL_MUST_USE_RESULT bool StringToIPAddress(absl::string_view str, + IPAddress* out); + +// Parse an IPv4 or IPv6 address in textual form to an IPAddress. +// +// This difference between this function and others is that this function +// additionally understands IPv6 addresses with scope identifiers (either +// numerical interface indexes or interface names) and can return properly +// scoped IP addresses (see MakeIPAddressWithScopeId() above). +absl::StatusOr StringToIPAddressWithOptionalScope( + absl::string_view str); + +// StringToIPAddress conversion methods that CHECK()-fail on invalid input. +// Not a good idea to use on user-provided input. +inline IPAddress StringToIPAddressOrDie(absl::string_view str) { + IPAddress ip; + ABSL_CHECK(StringToIPAddress(str, &ip)) << "Invalid IP " << str; + return ip; +} +// Parse a "binary" or packed string containing an IPv4 or IPv6 address in +// non-textual, network-byte-order form to an IPAddress. Not a constructor +// since it can fail (in which case it returns false, and the contents of +// "out" is undefined). If only validation is required, "out" can be set to +// nullptr. +ABSL_MUST_USE_RESULT bool PackedStringToIPAddress(absl::string_view str, + IPAddress* out); +// Binary packed string conversion methods that CHECK()-fail on invalid input. +inline IPAddress PackedStringToIPAddressOrDie(absl::string_view str) { + IPAddress ip; + ABSL_CHECK(PackedStringToIPAddress(str, &ip)) + << "Invalid packed IP address of length " << str.length(); + return ip; +} + +// For debugging/logging. Note that as a special case, you can log an +// uninitialized IP address, although you cannot use ToString() on it. +inline std::ostream& operator<<(std::ostream& stream, + const IPAddress& address) { + switch (address.address_family()) { + case AF_INET: + case AF_INET6: + return stream << address.ToString(); + case AF_UNSPEC: + return stream << ""; + default: + return stream << ""; + } +} + +// Return the family-dependent length (in bits) of an IP address given an +// IPAddress object. A debug-fatal error is logged if the address family +// is not of the Internet variety, i.e. not one of set(AF_INET, AF_INET6); +// the caller is responsible for verifying IsInitializedAddress(ip). +inline int IPAddressLength(const IPAddress& ip) { + switch (ip.address_family()) { + case AF_INET: + return 32; + case AF_INET6: + return 128; + default: + ABSL_LOG(ERROR) + << "IPAddressLength() of object with invalid address family: " + << ip.address_family(); + return -1; + } +} + +// +// Parse an IPv4 or IPv6 subnet mask in textual form into an IPRange. +// Not a constructor since it can fail (in which case it returns false, +// and the contents of "out" is undefined). If only validation is required, +// "out" can be set to nullptr. +// +// Note that an improperly zeroed out mask (say, 192.168.0.0/8) will be +// rejected as invalid by this function. If you instead want the excess bits to +// be zeroed out silently, see StringToIPRangeAndTruncate(), below. +// +// The format accepted is the same that is output by IPRange::ToString(). +// Any IP addresses without a "/netmask" will be given an implicit +// CIDR netmask length equal to the number of bits in the address +// family (e.g. /32 or /128). Additionally, IPv4 ranges may have a netmask +// specifier in the older dotted quad format, e.g. "/255.255.0.0". +// +ABSL_MUST_USE_RESULT bool StringToIPRange(absl::string_view str, IPRange* out); + +// StringToIPRange conversion methods that CHECK()-fail on invalid input. +// Not a good idea to use on user-provided input. +inline IPRange StringToIPRangeOrDie(absl::string_view str) { + IPRange ipr; + ABSL_CHECK(StringToIPRange(str, &ipr)) << "Invalid IP range " << str; + return ipr; +} + +// +// The same as StringToIPRange and StringToIPRangeOrDie, but truncating instead +// of returning an error in the event of an improperly zeroed out mask (ie., +// 192.168.0.0/8 will automatically be changed to 192.0.0.0/8). +// +ABSL_MUST_USE_RESULT bool StringToIPRangeAndTruncate(absl::string_view str, + IPRange* out); +inline IPRange StringToIPRangeAndTruncateOrDie(absl::string_view str) { + IPRange ipr; + ABSL_CHECK(StringToIPRangeAndTruncate(str, &ipr)) + << "Invalid IP range " << str; + return ipr; +} +// For debugging/logging. +inline std::ostream& operator<<(std::ostream& stream, const IPRange& range) { + if (range.host().address_family() == AF_UNSPEC) { + return stream << ""; + } else { + return stream << range.ToString(); + } +} + +// Checks whether the given IP address "needle" is within the IP range +// "haystack". Note that an IPv4 address is never considered to be within an +// IPv6 range, and vice versa. +inline bool IsWithinSubnet(const IPRange& haystack, const IPAddress& needle) { + return haystack.host().address_family() == needle.address_family() && + haystack == IPRange(needle, haystack.length()); +} + +// Checks whether the given IP range "needle" is properly contained within +// the IP range "haystack", i.e. whether "needle" is a more specific of +// "haystack". Note that an IPv4 range is never considered to be contained +// within an IPv6 range, and vice versa. +inline bool IsProperSubRange(const IPRange& haystack, const IPRange& needle) { + return haystack.length() < needle.length() && + IsWithinSubnet(haystack, needle.host()); +} + +static_assert(sizeof(IPAddress) == 20, "IPAddress should be 20 bytes"); +static_assert(sizeof(IPRange) == 20, "IPRange should be 20 bytes"); + +// Returns true if ip is initialized. +inline bool IsInitializedAddress(const IPAddress& ip) { + return ip.address_family() != AF_UNSPEC; +} + +// Returns true if ip is :: or 0.0.0.0. +bool IsAnyIPAddress(const IPAddress& ip); + +// Returns true if ip is in 127.0.0.0/8 or is ::1. +bool IsLoopbackIPAddress(const IPAddress& ip); + +// Returns true if ip is in 169.254.0.0/16 or fe80::/10. +bool IsLinkLocalIP(const IPAddress& ip); + +// Returns true if ip is in 224.0.0.0/4. +bool IsV4MulticastIPAddress(const IPAddress& ip); + +// Returns true if ip is a private address (RFC1918) or IPv6 ULA (fc00::/7). +bool IsPrivateIP(const IPAddress& ip); + +// Returns true if IP is not globally routable. +bool IsNonRoutableIP(const IPAddress& ip); + +// If ip6 is an IPv4-mapped IPv6 address, stores the IPv4 address in *ip4 and +// returns true. Otherwise returns false. +bool GetMappedIPv4Address(const IPAddress& ip6, IPAddress* ip4); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_IPADDRESS_OSS_H_ diff --git a/common/ipaddress_oss_test.cc b/common/ipaddress_oss_test.cc new file mode 100644 index 000000000..52f882f66 --- /dev/null +++ b/common/ipaddress_oss_test.cc @@ -0,0 +1,1214 @@ +// Copyright 2008 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/ipaddress_oss.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/log_severity.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/int128.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +class ScopedMockLogVerifier { + public: + explicit ScopedMockLogVerifier(const std::string& substr) {} +}; + +IPAddress MakeScopedIP(const IPAddress& addr, uint32_t scope_id) { + if (scope_id == 0) return addr; + + ABSL_CHECK_EQ(AF_INET6, addr.address_family()); + return MakeIPAddressWithScopeId(addr.ipv6_address(), scope_id).value(); +} + +// Tests for IPAddress. +TEST(IPAddressTest, BasicTests) { + in_addr addr4; + in6_addr addr6; + + inet_pton(AF_INET, "1.2.3.4", &addr4); + inet_pton(AF_INET6, "2001:700:300:1800::f", &addr6); + + IPAddress addr(addr4); + in_addr returned_addr4 = addr.ipv4_address(); + ASSERT_EQ(AF_INET, addr.address_family()); + EXPECT_TRUE(addr.is_ipv4()); + EXPECT_FALSE(addr.is_ipv6()); + EXPECT_EQ(0, memcmp(&addr4, &returned_addr4, sizeof(addr4))); + + addr = IPAddress(addr6); + in6_addr returned_addr6 = addr.ipv6_address(); + ASSERT_EQ(AF_INET6, addr.address_family()); + EXPECT_FALSE(addr.is_ipv4()); + EXPECT_TRUE(addr.is_ipv6()); + EXPECT_EQ(0, memcmp(&addr6, &returned_addr6, sizeof(addr6))); + + addr = IPAddress(); + ASSERT_EQ(AF_UNSPEC, addr.address_family()); +} + +TEST(IPAddressTest, ConstexprIPv4) { + constexpr IPAddress addr4(in_addr{0x12345678}); + ASSERT_EQ(addr4, IPAddress(in_addr{0x12345678})); +} + +TEST(IPAddressTest, ToAndFromString4) { + const std::string kIPString = "1.2.3.4"; + const std::string kBogusIPString = "1.2.3.256"; + in_addr addr4; + ABSL_CHECK_GT(inet_pton(AF_INET, kIPString.c_str(), &addr4), 0); + + IPAddress addr; + EXPECT_FALSE(StringToIPAddress(kBogusIPString, nullptr)); + EXPECT_FALSE(StringToIPAddress(kBogusIPString, &addr)); + ASSERT_TRUE(StringToIPAddress(kIPString, nullptr)); + ASSERT_TRUE(StringToIPAddress(kIPString, &addr)); + + in_addr returned_addr4 = addr.ipv4_address(); + EXPECT_EQ(AF_INET, addr.address_family()); + EXPECT_EQ(0, memcmp(&addr4, &returned_addr4, sizeof(addr4))); + + std::string packed = addr.ToPackedString(); + EXPECT_EQ(sizeof(addr4), packed.length()); + EXPECT_EQ(0, memcmp(packed.data(), &addr4, sizeof(addr4))); + + EXPECT_TRUE(PackedStringToIPAddress(packed, nullptr)); + IPAddress unpacked; + EXPECT_TRUE(PackedStringToIPAddress(packed, &unpacked)); + EXPECT_EQ(addr, unpacked); + + EXPECT_EQ(kIPString, addr.ToString()); +} + +TEST(IPAddressTest, ThoroughToString4) { + // This thoroughly tests the internal function FastOctetToBuffer by + // evaluating every possible number in every possible position. + for (int i = 0; i < 256; ++i) { + const std::string expected = absl::StrCat(i, ".0.0.0"); + EXPECT_EQ(expected, StringToIPAddressOrDie(expected).ToString()); + } + for (int i = 0; i < 256; ++i) { + const std::string expected = absl::StrCat("0.", i, ".0.0"); + EXPECT_EQ(expected, StringToIPAddressOrDie(expected).ToString()); + } + for (int i = 0; i < 256; ++i) { + const std::string expected = absl::StrCat("0.0.", i, ".0"); + EXPECT_EQ(expected, StringToIPAddressOrDie(expected).ToString()); + } + for (int i = 0; i < 256; ++i) { + const std::string expected = absl::StrCat("0.0.0.", i); + EXPECT_EQ(expected, StringToIPAddressOrDie(expected).ToString()); + } +} + +TEST(IPAddressTest, UnsafeIPv4Strings) { + // These IPv4 string literal formats are supported by inet_aton(3). + // They are one source of "spoofed" addresses in URLs and generally + // considered unsafe. We explicitly do not support them + // (thankfully inet_pton(3) is significantly more sane). + std::vector kUnsafeIPv4Strings = { + "016.016.016", // 14.14.0.14 + "016.016", // 14.0.0.14 + "016", // 0.0.0.14 + "0x0a.0x0a.0x0a.0x0a", // 10.10.10.10 + "0x0a.0x0a.0x0a", // 10.10.0.10 + "0x0a.0x0a", // 10.0.0.10 + "0x0a", // 0.0.0.10 + "42.42.42", // 42.42.0.42 + "42.42", // 42.0.0.42 + "42", // 0.0.0.42 + // On Darwin inet_pton ignores leading zeros so this would be a valid + // 16.16.16.16 address. +#if !defined(__APPLE__) + "016.016.016.016", // 14.14.14.14 +#endif + }; + + IPAddress ip; + for (const std::string& unsafe : kUnsafeIPv4Strings) { + EXPECT_FALSE(StringToIPAddress(unsafe, &ip)); + } +} + +TEST(IPAddressTest, ToAndFromString6) { + const std::string kIPString = "2001:db8:300:1800::f"; + const std::string kBogusIPString = "2001:db8:300:1800:1:2:3:4:5"; + const std::string kBogusIPString2 = "2001:db8::g"; + + in6_addr addr6; + ABSL_CHECK_GT(inet_pton(AF_INET6, kIPString.c_str(), &addr6), 0); + + IPAddress addr; + EXPECT_FALSE(StringToIPAddress(kBogusIPString, nullptr)); + EXPECT_FALSE(StringToIPAddress(kBogusIPString, &addr)); + EXPECT_FALSE(StringToIPAddress(kBogusIPString2, nullptr)); + EXPECT_FALSE(StringToIPAddress(kBogusIPString2, &addr)); + ASSERT_TRUE(StringToIPAddress(kIPString, nullptr)); + ASSERT_TRUE(StringToIPAddress(kIPString, &addr)); + + in6_addr returned_addr6 = addr.ipv6_address(); + EXPECT_EQ(AF_INET6, addr.address_family()); + EXPECT_EQ(0, memcmp(&addr6, &returned_addr6, sizeof(addr6))); + + std::string packed = addr.ToPackedString(); + EXPECT_EQ(sizeof(addr6), packed.length()); + EXPECT_EQ(0, memcmp(packed.data(), &addr6, sizeof(addr6))); + + EXPECT_TRUE(PackedStringToIPAddress(packed, nullptr)); + IPAddress unpacked; + EXPECT_TRUE(PackedStringToIPAddress(packed, &unpacked)); + EXPECT_EQ(addr, unpacked); + + EXPECT_EQ(kIPString, addr.ToString()); +} + +// The main purpose of this test is to validate that +// StringToIPAddressWithOptionalScope has feature parity with StringToIPAddress. +TEST(IPAddressTest, ToAndFromString6WithOptionalScope) { + constexpr char kIPString[] = "2001:db8:300:1800::f"; + constexpr char const* kBogusIPStrings[] = { + "2001:db8:300:1800:1:2:3:4:5", "2001:db8::g", + "2001:db8:300:1800:1:2:3:4:5%ifacename", "2001:db8::g%ifacename"}; + + in6_addr addr6; + ABSL_CHECK_GT(inet_pton(AF_INET6, kIPString, &addr6), 0); + + for (const auto& bogus_ip_string : kBogusIPStrings) { + // This sets the environment for an error-handling bug which would only + // trigger when errno == 0. The bug has since been fixed, and this allows + // us to detect regressions. + errno = 0; + EXPECT_FALSE(StringToIPAddressWithOptionalScope(bogus_ip_string).ok()) + << "failed to reject bogus IP string \"" << bogus_ip_string << "\""; + } + + auto addr_or = StringToIPAddressWithOptionalScope(kIPString); + ASSERT_TRUE(addr_or.status().ok()); + IPAddress addr = addr_or.value(); + + in6_addr returned_addr6 = addr.ipv6_address(); + EXPECT_EQ(AF_INET6, addr.address_family()); + EXPECT_EQ(0, memcmp(&addr6, &returned_addr6, sizeof(addr6))); + + std::string packed = addr.ToPackedString(); + EXPECT_EQ(sizeof(addr6), packed.length()); + EXPECT_EQ(0, memcmp(packed.data(), &addr6, sizeof(addr6))); + + EXPECT_TRUE(PackedStringToIPAddress(packed, nullptr)); + IPAddress unpacked; + EXPECT_TRUE(PackedStringToIPAddress(packed, &unpacked)); + EXPECT_EQ(addr, unpacked); + + EXPECT_EQ(kIPString, addr.ToString()); +} + +TEST(IPAddressTest, ToAndFromString6EightColons) { + IPAddress addr; + IPAddress expected; + + EXPECT_TRUE(StringToIPAddress("::7:6:5:4:3:2:1", &addr)); + EXPECT_TRUE(StringToIPAddress("0:7:6:5:4:3:2:1", &expected)); + EXPECT_EQ(expected, addr); + + EXPECT_TRUE(StringToIPAddress("7:6:5:4:3:2:1::", &addr)); + EXPECT_TRUE(StringToIPAddress("7:6:5:4:3:2:1:0", &expected)); + EXPECT_EQ(expected, addr); +} + +TEST(IPAddressTest, EmptyStrings) { + IPAddress ip; + EXPECT_FALSE(StringToIPAddress("", &ip)); + std::string empty; + EXPECT_FALSE(StringToIPAddress(empty, &ip)); +} + +TEST(IPAddressTest, SameAsInetNToP6) { + // Test that for various classes of IP addresses IPAddress::ToString generates + // the same result as inet_ntop. + const std::string cases[] = { + "ffff:ffff:100::808:808", + "::1", + "::", + "1:2:3:4:5:6:7:8", + "2001:0:0:4::8", + "2001::4:5:6:7:8", + "2001:2:3:4:5:6:7:8", + "0:0:3::ffff", + "::4:0:0:0:ffff", + "::5:0:0:ffff", + "1::4:0:0:7:8", + "2001:658:22a:cafe::", + "::1.2.3.4", + "::ffff:1.2.3.4", + "::ffff:ffff:1:1", + "::0.1.0.0", + "1234:abcd::", + "1234::abcd:0:0:5678", + "1234:0:0:abcd::5678", + "::192.168.90.1", + "::ffff:192.168.90.1", + "1234:0:0:abcd::5678", + "1234:5678:2:9abc:def0:3:1234:5678", +// Darwin's inet_ntop does not follow RFC 5952 so we skip following tests on +// __APPLE__. +#if !defined(__APPLE__) + "::ffff", + "2001:0:3:4:5:6:7:8", + "::abcd", + "1234:5678:0:9abc:def0:0:1234:5678", +#endif + }; + char buf[INET6_ADDRSTRLEN]; + IPAddress addr; + + for (const auto& c : cases) { + EXPECT_TRUE(StringToIPAddress(c, &addr)); + EXPECT_EQ(addr.ToString(), c); + std::string packed = addr.ToPackedString(); + inet_ntop(AF_INET6, packed.data(), buf, INET6_ADDRSTRLEN); + EXPECT_EQ(addr.ToString(), std::string(buf)); + } +} + +TEST(IPAddressTest, Equality) { + const std::string kIPv4String1 = "1.2.3.4"; + const std::string kIPv4String2 = "2.3.4.5"; + const std::string kIPv6String1 = "2001:700:300:1800::f"; + const std::string kIPv6String2 = "2001:700:300:1800:0:0:0:f"; + const std::string kIPv6String3 = "::1"; + + IPAddress empty; + IPAddress addr4_1, addr4_2; + IPAddress addr6_1, addr6_2, addr6_3; + + ASSERT_TRUE(StringToIPAddress(kIPv4String1, &addr4_1)); + ASSERT_TRUE(StringToIPAddress(kIPv4String2, &addr4_2)); + ASSERT_TRUE(StringToIPAddress(kIPv6String1, &addr6_1)); + ASSERT_TRUE(StringToIPAddress(kIPv6String2, &addr6_2)); + ASSERT_TRUE(StringToIPAddress(kIPv6String3, &addr6_3)); + + // operator== + EXPECT_TRUE(empty == empty); + EXPECT_FALSE(empty == addr4_1); + EXPECT_FALSE(empty == addr4_2); + EXPECT_FALSE(empty == addr6_1); + EXPECT_FALSE(empty == addr6_2); + EXPECT_FALSE(empty == addr6_3); + + EXPECT_FALSE(addr4_1 == empty); + EXPECT_TRUE(addr4_1 == addr4_1); + EXPECT_FALSE(addr4_1 == addr4_2); + EXPECT_FALSE(addr4_1 == addr6_1); + EXPECT_FALSE(addr4_1 == addr6_2); + EXPECT_FALSE(addr4_1 == addr6_3); + + EXPECT_FALSE(addr4_2 == empty); + EXPECT_FALSE(addr4_2 == addr4_1); + EXPECT_TRUE(addr4_2 == addr4_2); + EXPECT_FALSE(addr4_2 == addr6_1); + EXPECT_FALSE(addr4_2 == addr6_2); + EXPECT_FALSE(addr4_2 == addr6_3); + + EXPECT_FALSE(addr6_1 == empty); + EXPECT_FALSE(addr6_1 == addr4_1); + EXPECT_FALSE(addr6_1 == addr4_2); + EXPECT_TRUE(addr6_1 == addr6_1); + EXPECT_TRUE(addr6_1 == addr6_2); + EXPECT_FALSE(addr6_1 == addr6_3); + + EXPECT_FALSE(addr6_2 == empty); + EXPECT_FALSE(addr6_2 == addr4_1); + EXPECT_FALSE(addr6_2 == addr4_2); + EXPECT_TRUE(addr6_2 == addr6_1); + EXPECT_TRUE(addr6_2 == addr6_2); + EXPECT_FALSE(addr6_2 == addr6_3); + + EXPECT_FALSE(addr6_3 == empty); + EXPECT_FALSE(addr6_3 == addr4_1); + EXPECT_FALSE(addr6_3 == addr4_2); + EXPECT_FALSE(addr6_3 == addr6_1); + EXPECT_FALSE(addr6_3 == addr6_2); + EXPECT_TRUE(addr6_3 == addr6_3); + + // operator!= (same tests, just inverted) + EXPECT_FALSE(empty != empty); + EXPECT_TRUE(empty != addr4_1); + EXPECT_TRUE(empty != addr4_2); + EXPECT_TRUE(empty != addr6_1); + EXPECT_TRUE(empty != addr6_2); + EXPECT_TRUE(empty != addr6_3); + + EXPECT_TRUE(addr4_1 != empty); + EXPECT_FALSE(addr4_1 != addr4_1); + EXPECT_TRUE(addr4_1 != addr4_2); + EXPECT_TRUE(addr4_1 != addr6_1); + EXPECT_TRUE(addr4_1 != addr6_2); + EXPECT_TRUE(addr4_1 != addr6_3); + + EXPECT_TRUE(addr4_2 != empty); + EXPECT_TRUE(addr4_2 != addr4_1); + EXPECT_FALSE(addr4_2 != addr4_2); + EXPECT_TRUE(addr4_2 != addr6_1); + EXPECT_TRUE(addr4_2 != addr6_2); + EXPECT_TRUE(addr4_2 != addr6_3); + + EXPECT_TRUE(addr6_1 != empty); + EXPECT_TRUE(addr6_1 != addr4_1); + EXPECT_TRUE(addr6_1 != addr4_2); + EXPECT_FALSE(addr6_1 != addr6_1); + EXPECT_FALSE(addr6_1 != addr6_2); + EXPECT_TRUE(addr6_1 != addr6_3); + + EXPECT_TRUE(addr6_2 != empty); + EXPECT_TRUE(addr6_2 != addr4_1); + EXPECT_TRUE(addr6_2 != addr4_2); + EXPECT_FALSE(addr6_2 != addr6_1); + EXPECT_FALSE(addr6_2 != addr6_2); + EXPECT_TRUE(addr6_2 != addr6_3); + + EXPECT_TRUE(addr6_3 != empty); + EXPECT_TRUE(addr6_3 != addr4_1); + EXPECT_TRUE(addr6_3 != addr4_2); + EXPECT_TRUE(addr6_3 != addr6_1); + EXPECT_TRUE(addr6_3 != addr6_2); + EXPECT_FALSE(addr6_3 != addr6_3); +} + +TEST(IPAddressTest, HostUInt32ToIPAddress) { + uint32_t addr1 = 0; + uint32_t addr2 = 0x7f000001; + uint32_t addr3 = 0xffffffff; + + EXPECT_EQ("0.0.0.0", HostUInt32ToIPAddress(addr1).ToString()); + EXPECT_EQ("127.0.0.1", HostUInt32ToIPAddress(addr2).ToString()); + EXPECT_EQ("255.255.255.255", HostUInt32ToIPAddress(addr3).ToString()); +} + +TEST(IPAddressTest, IPAddressToHostUInt32) { + IPAddress addr = StringToIPAddressOrDie("1.2.3.4"); + EXPECT_EQ(0x01020304, IPAddressToHostUInt32(addr)); +} + +TEST(IPAddressTest, UInt128ToIPAddress) { + absl::uint128 addr1(0); + absl::uint128 addr2(1); + absl::uint128 addr3 = absl::MakeUint128(std::numeric_limits::max(), + std::numeric_limits::max()); + + EXPECT_EQ("::", UInt128ToIPAddress(addr1).ToString()); + EXPECT_EQ("::1", UInt128ToIPAddress(addr2).ToString()); + EXPECT_EQ("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + UInt128ToIPAddress(addr3).ToString()); +} + +TEST(IPAddressTest, Constants) { + EXPECT_EQ("0.0.0.0", IPAddress::Any4().ToString()); + EXPECT_EQ("::", IPAddress::Any6().ToString()); +} + +TEST(IPAddressTest, Logging) { + const std::string kIPv4String = "1.2.3.4"; + const std::string kIPv6String = "2001:700:300:1800::f"; + IPAddress addr4, addr6; + + ASSERT_TRUE(StringToIPAddress(kIPv4String, &addr4)); + ASSERT_TRUE(StringToIPAddress(kIPv6String, &addr6)); + + std::ostringstream out; + out << addr4 << " " << addr6; + EXPECT_EQ("1.2.3.4 2001:700:300:1800::f", out.str()); +} + +TEST(IPAddressTest, LoggingUninitialized) { + std::ostringstream out; + out << IPAddress(); + EXPECT_EQ("", out.str()); +} + +TEST(IPAddressTest, LoggingCorrupt) { + IPAddress corrupt_ip; + memset(&corrupt_ip, 0x55, sizeof(corrupt_ip)); + std::ostringstream out; + out << corrupt_ip; + EXPECT_EQ("", out.str()); +} + +TEST(IPAddressTest, IPv6LinkLocal) { + const IPAddress fe80_1 = StringToIPAddressOrDie("fe80::1"); + const IPAddress fe80_2 = StringToIPAddressOrDie("fe80::2"); + const IPAddress fe80_1_if17(MakeScopedIP(fe80_1, 17)); + const IPAddress fe80_2_if17(MakeScopedIP(fe80_2, 17)); + const IPAddress fe80_2_if22(MakeScopedIP(fe80_2, 22)); + const IPAddress ff02_2 = StringToIPAddressOrDie("ff02::2"); // all-routers + const IPAddress ff02_2_if17(MakeScopedIP(ff02_2, 17)); + + // IPAddress::scope_id() + EXPECT_EQ(0, fe80_1.scope_id()); + EXPECT_EQ(0, fe80_2.scope_id()); + EXPECT_EQ(17, fe80_1_if17.scope_id()); + EXPECT_EQ(17, fe80_2_if17.scope_id()); + EXPECT_EQ(22, fe80_2_if22.scope_id()); + EXPECT_EQ(0, ff02_2.scope_id()); + EXPECT_EQ(17, ff02_2_if17.scope_id()); + + for (const auto& ip : {fe80_1, fe80_2, fe80_1_if17, fe80_2_if17, fe80_2_if22, + ff02_2, ff02_2_if17}) { + EXPECT_EQ(AF_INET6, ip.address_family()); + EXPECT_EQ(128, IPAddressLength(ip)); + } + + // != + EXPECT_TRUE(fe80_1 != fe80_2); + EXPECT_TRUE(fe80_1 != fe80_1_if17); + EXPECT_TRUE(fe80_2 != fe80_2_if17); + EXPECT_TRUE(fe80_2 != fe80_2_if22); + EXPECT_TRUE(fe80_1_if17 != fe80_2_if17); + EXPECT_TRUE(fe80_2_if17 != fe80_2_if22); + + // == + EXPECT_TRUE(fe80_1_if17 == fe80_1_if17); + EXPECT_TRUE(fe80_2_if17 == fe80_2_if17); + EXPECT_TRUE(fe80_2_if22 == fe80_2_if22); + EXPECT_TRUE(fe80_1 == IPAddress(fe80_1_if17.ipv6_address())); + EXPECT_TRUE(fe80_2 == IPAddress(fe80_2_if17.ipv6_address())); + EXPECT_TRUE(fe80_2 == IPAddress(fe80_2_if22.ipv6_address())); + + // Check that we can't be super tricksy with the implementation to create + // a collision. + in6_addr addr6 = fe80_2_if17.ipv6_address(); + EXPECT_NE(fe80_2_if17, IPAddress(addr6)); + addr6.s6_addr32[1] = 17; + EXPECT_NE(fe80_2_if17, IPAddress(addr6)); + addr6.s6_addr32[1] = htonl(17); + EXPECT_NE(fe80_2_if17, IPAddress(addr6)); + + // Double-check that absence of a scope descriptor is not evidence of the + // absence of a working IP string literal parser. + EXPECT_EQ(ff02_2, + StringToIPAddressWithOptionalScope(ff02_2.ToString()).value()); + + // Appending a scope delimiter ('%') without a following zone_id does not + // seem to comport with any of this text: + // + // https://tools.ietf.org/html/rfc4007#section-11.2 + // https://tools.ietf.org/html/rfc4007#section-11.6 + // https://tools.ietf.org/html/rfc6874#section-2 + // + // so check that it's treated as an error. + EXPECT_FALSE(StringToIPAddressWithOptionalScope("fe80::%").ok()); + // Appending a default zone_id of "0" is, however, perfectly fine. + EXPECT_EQ(StringToIPAddressOrDie("fe80::"), + StringToIPAddressWithOptionalScope("fe80::%0").value()); + + // For now verify that PackedString representation is identical to the + // the un-scoped address implementation. How to meaningfully serialize and + // de-serialize an interface index or name is left as an exercise for the + // application. + EXPECT_EQ(fe80_2.ToPackedString(), fe80_2_if17.ToPackedString()); + EXPECT_EQ(fe80_2.ToPackedString(), fe80_2_if22.ToPackedString()); +} + +TEST(IPAddressTest, IPAddressLength) { + IPAddress ip; + ASSERT_TRUE(StringToIPAddress("1.2.3.4", &ip)); + EXPECT_EQ(32, IPAddressLength(ip)); + ASSERT_TRUE(StringToIPAddress("2001:db8::1", &ip)); + EXPECT_EQ(128, IPAddressLength(ip)); +} + +TEST(IPAddressDeathTest, IPAddressLength) { + IPAddress ip; + int bitlength = 0; + + ScopedMockLogVerifier log( + "IPAddressLength() of object with invalid address family"); + bitlength = IPAddressLength(ip); + EXPECT_EQ(-1, bitlength); +} + +TEST(IPAddressTest, IPAddressToUInt128) { + IPAddress addr; + ASSERT_TRUE(StringToIPAddress("2001:700:300:1803:b0ff::12", &addr)); + EXPECT_EQ(absl::MakeUint128(0x2001070003001803ULL, 0xb0ff000000000012ULL), + IPAddressToUInt128(addr)); +} + +TEST(IPAddressTest, IsAnyIPAddress) { + EXPECT_TRUE(IsAnyIPAddress(IPAddress::Any4())); + EXPECT_TRUE(IsAnyIPAddress(IPAddress::Any6())); + EXPECT_FALSE(IsAnyIPAddress(StringToIPAddressOrDie("1.2.3.4"))); + EXPECT_FALSE(IsAnyIPAddress(StringToIPAddressOrDie("::1"))); +} + +TEST(IPAddressTest, IsLoopbackIPAddress) { + EXPECT_TRUE(IsLoopbackIPAddress(StringToIPAddressOrDie("127.0.0.1"))); + EXPECT_TRUE(IsLoopbackIPAddress(StringToIPAddressOrDie("127.1.2.3"))); + EXPECT_FALSE(IsLoopbackIPAddress(StringToIPAddressOrDie("128.0.0.1"))); + EXPECT_TRUE(IsLoopbackIPAddress(IPAddress::Loopback6())); + EXPECT_FALSE(IsLoopbackIPAddress(StringToIPAddressOrDie("::2"))); +} + +TEST(IPAddressTest, IsLinkLocalIP) { + EXPECT_TRUE(IsLinkLocalIP(StringToIPAddressOrDie("169.254.0.1"))); + EXPECT_TRUE(IsLinkLocalIP(StringToIPAddressOrDie("169.254.100.200"))); + EXPECT_FALSE(IsLinkLocalIP(StringToIPAddressOrDie("169.253.0.1"))); + EXPECT_TRUE(IsLinkLocalIP(StringToIPAddressOrDie("fe80::1"))); + EXPECT_FALSE(IsLinkLocalIP(StringToIPAddressOrDie("fec0::1"))); +} + +TEST(IPAddressTest, GetMappedIPv4Address) { + IPAddress ipv4_mapped_ipv6 = StringToIPAddressOrDie("::ffff:192.168.0.1"); + IPAddress ipv4; + EXPECT_TRUE(GetMappedIPv4Address(ipv4_mapped_ipv6, &ipv4)); + EXPECT_EQ(ipv4, StringToIPAddressOrDie("192.168.0.1")); + + IPAddress ipv6 = StringToIPAddressOrDie("::1"); + EXPECT_FALSE(GetMappedIPv4Address(ipv6, &ipv4)); +} + +TEST(IPAddressTest, IsV4MulticastIPAddress) { + EXPECT_TRUE(IsV4MulticastIPAddress(StringToIPAddressOrDie("224.0.0.1"))); + EXPECT_TRUE( + IsV4MulticastIPAddress(StringToIPAddressOrDie("239.255.255.255"))); + EXPECT_FALSE( + IsV4MulticastIPAddress(StringToIPAddressOrDie("223.255.255.255"))); + EXPECT_FALSE(IsV4MulticastIPAddress(StringToIPAddressOrDie("::1"))); +} + +TEST(IPAddressTest, IsPrivateIP) { + EXPECT_TRUE(IsPrivateIP(StringToIPAddressOrDie("10.0.0.1"))); + EXPECT_TRUE(IsPrivateIP(StringToIPAddressOrDie("172.16.0.1"))); + EXPECT_TRUE(IsPrivateIP(StringToIPAddressOrDie("172.31.255.254"))); + EXPECT_TRUE(IsPrivateIP(StringToIPAddressOrDie("192.168.0.1"))); + EXPECT_FALSE(IsPrivateIP(StringToIPAddressOrDie("11.0.0.1"))); + EXPECT_FALSE(IsPrivateIP(StringToIPAddressOrDie("172.15.0.1"))); + EXPECT_FALSE(IsPrivateIP(StringToIPAddressOrDie("172.32.0.1"))); + EXPECT_FALSE(IsPrivateIP(StringToIPAddressOrDie("192.169.0.1"))); + EXPECT_TRUE(IsPrivateIP(StringToIPAddressOrDie("fc00::1"))); + EXPECT_FALSE(IsPrivateIP(StringToIPAddressOrDie("fe00::1"))); +} + +TEST(IPAddressTest, IsNonRoutableIP) { + EXPECT_TRUE(IsNonRoutableIP(StringToIPAddressOrDie("10.0.0.1"))); + EXPECT_TRUE(IsNonRoutableIP(StringToIPAddressOrDie("127.0.0.1"))); + EXPECT_TRUE(IsNonRoutableIP(StringToIPAddressOrDie("169.254.0.1"))); + EXPECT_TRUE(IsNonRoutableIP(StringToIPAddressOrDie("0.0.0.0"))); + EXPECT_TRUE(IsNonRoutableIP(StringToIPAddressOrDie("224.0.0.1"))); + EXPECT_FALSE(IsNonRoutableIP(StringToIPAddressOrDie("8.8.8.8"))); + EXPECT_TRUE(IsNonRoutableIP(StringToIPAddressOrDie("fc00::1"))); + // Current implementation for IPv6 only checks IsPrivateIP. + EXPECT_FALSE(IsNonRoutableIP(StringToIPAddressOrDie("::1"))); + EXPECT_FALSE(IsNonRoutableIP(StringToIPAddressOrDie("fe80::1"))); +} + +// Various death tests for IPAddress emergency behavior in production that +// should simply result in CHECK failures in debug mode. + +TEST(IPAddressDeathTest, EmergencyCoercion) { + const std::string kIPv6Address = "2001:700:300:1803::1"; + IPAddress addr; + in_addr addr4; + + ABSL_CHECK(StringToIPAddress(kIPv6Address, &addr)); + + if (DEBUG_MODE) { + EXPECT_DEBUG_DEATH(addr4 = addr.ipv4_address(), "Check failed"); + } else { + ScopedMockLogVerifier log("returning IPv4-coerced address"); + addr4 = addr.ipv4_address(); + } +} + +TEST(IPAddressDeathTest, EmergencyCompatibility) { + const std::string kIPv4Address = "129.240.2.40"; + IPAddress addr; + in6_addr addr6; + + ABSL_CHECK(StringToIPAddress(kIPv4Address, &addr)); + + if (DEBUG_MODE) { + EXPECT_DEBUG_DEATH(addr6 = addr.ipv6_address(), "Check failed"); + } else { + ScopedMockLogVerifier log("returning IPv6 mapped address"); + addr6 = addr.ipv6_address(); + EXPECT_EQ("::ffff:129.240.2.40", IPAddress(addr6).ToString()); + } +} + +// Invalid conversion in *OrDie() functions. +TEST(IPAddressDeathTest, InvalidStringConversion) { + // Valid conversion. + EXPECT_EQ(StringToIPAddressOrDie("1.2.3.4").ToString(), "1.2.3.4"); + EXPECT_EQ(StringToIPAddressOrDie("1.2.3.4").ToString(), "1.2.3.4"); + EXPECT_EQ(StringToIPAddressOrDie(absl::string_view("1.2.3.4")).ToString(), + "1.2.3.4"); + EXPECT_EQ(StringToIPAddressOrDie("2001:700:300:1803::1").ToString(), + "2001:700:300:1803::1"); + EXPECT_EQ(StringToIPAddressOrDie("2001:700:300:1803::1").ToString(), + "2001:700:300:1803::1"); + EXPECT_EQ(StringToIPAddressOrDie(absl::string_view("2001:700:300:1803::1")) + .ToString(), + "2001:700:300:1803::1"); +} + +// Tests for IPRange. +TEST(IPRangeTest, BasicTest4) { + IPAddress addr; + const uint16_t kPrefixLength = 16; + ASSERT_TRUE(StringToIPAddress("192.168.0.0", &addr)); + IPRange subnet(addr, kPrefixLength); + EXPECT_EQ(addr, subnet.host()); + EXPECT_EQ(kPrefixLength, subnet.length()); + + // Test copy construction. + IPRange another_subnet = subnet; + EXPECT_EQ(addr, another_subnet.host()); + EXPECT_EQ(kPrefixLength, another_subnet.length()); + + // Test IPAddress constructor. + EXPECT_EQ(addr, IPRange(addr).host()); + EXPECT_EQ(32, IPRange(addr).length()); +} + +TEST(IPRangeTest, BasicTest6) { + IPAddress addr; + const uint16_t kPrefixLength = 64; + ASSERT_TRUE(StringToIPAddress("2001:700:300:1800::", &addr)); + IPRange subnet(addr, kPrefixLength); + EXPECT_EQ(addr, subnet.host()); + EXPECT_EQ(kPrefixLength, subnet.length()); + + // Test copy construction. + IPRange another_subnet = subnet; + EXPECT_EQ(addr, another_subnet.host()); + EXPECT_EQ(kPrefixLength, another_subnet.length()); + + // Test IPAddress constructor. + EXPECT_EQ(addr, IPRange(addr).host()); + EXPECT_EQ(128, IPRange(addr).length()); +} + +TEST(IPRangeTest, AnyRanges) { + EXPECT_EQ("0.0.0.0/0", IPRange::Any4().ToString()); + EXPECT_EQ("::/0", IPRange::Any6().ToString()); +} + +TEST(IPRangeTest, ToAndFromString4) { + const std::string kIPString = "192.168.0.0"; + const int kLength = 16; + const std::string kSubnetString = kIPString + absl::StrFormat("/%u", kLength); + const std::string kBogusSubnetString1 = "192.168.0.0/8"; + const std::string kBogusSubnetString2 = "192.256.0.0/16"; + const std::string kBogusSubnetString3 = "192.168.0.0/34"; + const std::string kBogusSubnetString4 = "0.0.0.0/-1"; + const std::string kBogusSubnetString5 = "0.0.0.0/+1"; + const std::string kBogusSubnetString6 = "0.0.0.0/"; + const std::string kBogusSubnetString7 = "192.168.0.0/16/16"; + const std::string kBogusSubnetString8 = "192.168.0.0/16 "; + const std::string kBogusSubnetString9 = " 192.168.0.0/16"; + const std::string kBogusSubnetString10 = "192.168.0.0 /16"; + + IPRange subnet; + EXPECT_FALSE(StringToIPRange(kBogusSubnetString1, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString2, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString3, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString4, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString5, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString6, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString7, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString8, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString9, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString10, &subnet)); + ASSERT_TRUE(StringToIPRange(kSubnetString, nullptr)); + ASSERT_TRUE(StringToIPRange(kSubnetString, &subnet)); + + IPAddress addr4; + ASSERT_TRUE(StringToIPAddress(kIPString, &addr4)); + EXPECT_EQ(addr4, subnet.host()); + EXPECT_EQ(kLength, subnet.length()); + + EXPECT_EQ(kSubnetString, subnet.ToString()); + + EXPECT_TRUE(StringToIPRangeAndTruncate(kBogusSubnetString1, &subnet)); + EXPECT_EQ("192.0.0.0/8", subnet.ToString()); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString2, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString3, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString4, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString5, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString6, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString7, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString8, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString9, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString10, &subnet)); +} + +TEST(IPRangeTest, DottedQuadNetmasks) { + const std::string kIPString = "192.168.0.0"; + const std::string kDottedQuadNetmaskString = "255.255.0.0"; + const int kLength = 16; + const std::string kSubnetString = kIPString + absl::StrFormat("/%u", kLength); + const std::string kDottedQuadSubnetString = + kIPString + "/" + kDottedQuadNetmaskString; + + const std::vector kBogusDottedQuadStrings = { + "192.168.0.0/128.255.0.0", + "3ffe::1/255.255.0.0", + "1.2.3.4/255", + "1.2.3.4/255.", + "1.2.3.4/255.255", + "1.2.3.4/255.255.", + "1.2.3.4/255.255.255", + "1.2.3.4/255.255.255.", + "1.2.3.4/255.255.255.256", + "1.2.3.4/255.255.255.-255", + "1.2.3.4/255.255.255.+255", + "1.2.3.4/255.255.255.garbage", + // On Darwin inet_pton ignores leading zeros so these would be valid. +#if !defined(__APPLE__) + "1.2.3.4/0255.255.255.255", + "1.2.3.4/255.255.255.000255", +#endif + }; + + // Check bogus strings. + for (const std::string& bogus : kBogusDottedQuadStrings) { + EXPECT_FALSE(StringToIPRangeAndTruncate(bogus, nullptr)) + << "Apparently '" << bogus << "' is actually valid?"; + } + + // Check valid strings. + IPRange cidr; + IPRange dotted_quad; + ASSERT_TRUE(StringToIPRangeAndTruncate(kSubnetString, &cidr)); + ASSERT_TRUE( + StringToIPRangeAndTruncate(kDottedQuadSubnetString, &dotted_quad)); + ASSERT_TRUE(cidr == dotted_quad); + + // Check some corner cases. + EXPECT_TRUE(StringToIPRange("0.0.0.0/0.0.0.0", &cidr)); + EXPECT_EQ(0, cidr.length()); + EXPECT_EQ(IPAddress::Any4(), cidr.host()); + + // If .expected_host_string is empty then .dotted_quad_string is + // expected to FAIL StringToIPRangeAndTruncate(). + struct DottedQuadExpecations { + std::string dotted_quad_string; + std::string expected_host_string; + int expected_length; + }; + const std::vector dotted_quad_tests = { + {"1.2.3.4/0.0.0.1", "", -1}, + {"1.2.3.4/1.0.0.0", "", -1}, + {"1.2.3.4/127.255.255.255", "", -1}, + {"1.2.3.4/254.255.255.255", "", -1}, + {"1.2.3.4/255.255.255.254", "1.2.3.4", 31}, + {"1.2.3.4/0.0.0.0", "0.0.0.0", 0}, + }; + + for (const DottedQuadExpecations& entry : dotted_quad_tests) { + IPRange range; + IPAddress host; + + if (entry.expected_host_string.empty()) { + // The dotted quad string should be rejected as invalid. + ASSERT_FALSE( + StringToIPRangeAndTruncate(entry.dotted_quad_string, &range)); + continue; + } + ASSERT_TRUE(StringToIPRangeAndTruncate(entry.dotted_quad_string, &range)); + ASSERT_TRUE(StringToIPAddress(entry.expected_host_string, &host)); + EXPECT_EQ(host, range.host()) + << entry.dotted_quad_string << " host equality expectation failed"; + EXPECT_EQ(entry.expected_length, range.length()) + << entry.dotted_quad_string << " length equality expectation failed"; + } +} + +TEST(IPRangeTest, FromAddressString4) { + const std::string kIPString = "192.168.0.0"; + IPAddress addr4; + ASSERT_TRUE(StringToIPAddress(kIPString, &addr4)); + + IPRange subnet; + EXPECT_TRUE(StringToIPRange(kIPString, &subnet)); + EXPECT_EQ(addr4, subnet.host()); + EXPECT_EQ(32, subnet.length()); + + EXPECT_TRUE(StringToIPRangeAndTruncate(kIPString, &subnet)); + EXPECT_EQ(addr4, subnet.host()); + EXPECT_EQ(32, subnet.length()); +} + +TEST(IPRangeTest, ToAndFromString6) { + const std::string kIPString = "2001:700:300:1800::"; + const int kLength = 64; + const std::string kSubnetString = kIPString + absl::StrFormat("/%u", kLength); + const std::string kBogusSubnetString1 = "2001:700:300:1800::/48"; + const std::string kBogusSubnetString2 = "2001:700:300:180g::/64"; + const std::string kBogusSubnetString3 = "2001:700:300:1800::/129"; + const std::string kBogusSubnetString4 = "::/-1"; + const std::string kBogusSubnetString5 = "::/+1"; + const std::string kBogusSubnetString6 = "::/"; + const std::string kBogusSubnetString7 = "2001:700:300:1800::/64/64"; + const std::string kBogusSubnetString8 = "2001:700:300:1800::/64 "; + const std::string kBogusSubnetString9 = " 2001:700:300:1800::/64"; + const std::string kBogusSubnetString10 = "2001:700:300:1800:: /64"; + + IPRange subnet; + EXPECT_FALSE(StringToIPRange(kBogusSubnetString1, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString2, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString3, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString4, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString5, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString6, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString7, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString8, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString9, &subnet)); + EXPECT_FALSE(StringToIPRange(kBogusSubnetString10, &subnet)); + ASSERT_TRUE(StringToIPRange(kSubnetString, nullptr)); + ASSERT_TRUE(StringToIPRange(kSubnetString, &subnet)); + + IPAddress addr6; + ASSERT_TRUE(StringToIPAddress(kIPString, &addr6)); + EXPECT_EQ(addr6, subnet.host()); + EXPECT_EQ(kLength, subnet.length()); + + EXPECT_EQ(kSubnetString, subnet.ToString()); + + EXPECT_TRUE(StringToIPRangeAndTruncate(kBogusSubnetString1, &subnet)); + EXPECT_EQ("2001:700:300::/48", subnet.ToString()); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString2, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString3, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString4, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString5, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString6, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString7, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString8, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString9, &subnet)); + EXPECT_FALSE(StringToIPRangeAndTruncate(kBogusSubnetString10, &subnet)); +} + +TEST(IPRangeTest, FromAddressString6) { + const std::string kIPString = "2001:700:300:1800::"; + IPAddress addr6; + ASSERT_TRUE(StringToIPAddress(kIPString, &addr6)); + + IPRange subnet; + EXPECT_TRUE(StringToIPRange(kIPString, &subnet)); + EXPECT_EQ(addr6, subnet.host()); + EXPECT_EQ(128, subnet.length()); + + EXPECT_TRUE(StringToIPRangeAndTruncate(kIPString, &subnet)); + EXPECT_EQ(addr6, subnet.host()); + EXPECT_EQ(128, subnet.length()); +} + +TEST(IPRangeTest, Equality) { + const std::string kIPv4String1 = "192.168.0.0/16"; + const std::string kIPv4String2 = "192.168.0.0/24"; + const std::string kIPv6String1 = "2001:700:300:1800::/64"; + const std::string kIPv6String2 = "2001:700:300:1800:0:0::/64"; + const std::string kIPv6String3 = "2001:700:300:dc0f::/64"; + + IPRange subnet4_1, subnet4_2; + IPRange subnet6_1, subnet6_2, subnet6_3; + + ASSERT_TRUE(StringToIPRange(kIPv4String1, &subnet4_1)); + ASSERT_TRUE(StringToIPRange(kIPv4String2, &subnet4_2)); + ASSERT_TRUE(StringToIPRange(kIPv6String1, &subnet6_1)); + ASSERT_TRUE(StringToIPRange(kIPv6String2, &subnet6_2)); + ASSERT_TRUE(StringToIPRange(kIPv6String3, &subnet6_3)); + + // operator== + EXPECT_TRUE(subnet4_1 == subnet4_1); + EXPECT_FALSE(subnet4_1 == subnet4_2); + EXPECT_FALSE(subnet4_1 == subnet6_1); + EXPECT_FALSE(subnet4_1 == subnet6_2); + EXPECT_FALSE(subnet4_1 == subnet6_3); + + EXPECT_FALSE(subnet4_2 == subnet4_1); + EXPECT_TRUE(subnet4_2 == subnet4_2); + EXPECT_FALSE(subnet4_2 == subnet6_1); + EXPECT_FALSE(subnet4_2 == subnet6_2); + EXPECT_FALSE(subnet4_2 == subnet6_3); + + EXPECT_FALSE(subnet6_1 == subnet4_1); + EXPECT_FALSE(subnet6_1 == subnet4_2); + EXPECT_TRUE(subnet6_1 == subnet6_1); + EXPECT_TRUE(subnet6_1 == subnet6_2); + EXPECT_FALSE(subnet6_1 == subnet6_3); + + EXPECT_FALSE(subnet6_2 == subnet4_1); + EXPECT_FALSE(subnet6_2 == subnet4_2); + EXPECT_TRUE(subnet6_2 == subnet6_1); + EXPECT_TRUE(subnet6_2 == subnet6_2); + EXPECT_FALSE(subnet6_2 == subnet6_3); + + EXPECT_FALSE(subnet6_3 == subnet4_1); + EXPECT_FALSE(subnet6_3 == subnet4_2); + EXPECT_FALSE(subnet6_3 == subnet6_1); + EXPECT_FALSE(subnet6_3 == subnet6_2); + EXPECT_TRUE(subnet6_3 == subnet6_3); + + // operator!= (same tests, just inverted) + EXPECT_FALSE(subnet4_1 != subnet4_1); + EXPECT_TRUE(subnet4_1 != subnet4_2); + EXPECT_TRUE(subnet4_1 != subnet6_1); + EXPECT_TRUE(subnet4_1 != subnet6_2); + EXPECT_TRUE(subnet4_1 != subnet6_3); + + EXPECT_TRUE(subnet4_2 != subnet4_1); + EXPECT_FALSE(subnet4_2 != subnet4_2); + EXPECT_TRUE(subnet4_2 != subnet6_1); + EXPECT_TRUE(subnet4_2 != subnet6_2); + EXPECT_TRUE(subnet4_2 != subnet6_3); + + EXPECT_TRUE(subnet6_1 != subnet4_1); + EXPECT_TRUE(subnet6_1 != subnet4_2); + EXPECT_FALSE(subnet6_1 != subnet6_1); + EXPECT_FALSE(subnet6_1 != subnet6_2); + EXPECT_TRUE(subnet6_1 != subnet6_3); + + EXPECT_TRUE(subnet6_2 != subnet4_1); + EXPECT_TRUE(subnet6_2 != subnet4_2); + EXPECT_FALSE(subnet6_2 != subnet6_1); + EXPECT_FALSE(subnet6_2 != subnet6_2); + EXPECT_TRUE(subnet6_2 != subnet6_3); + + EXPECT_TRUE(subnet6_3 != subnet4_1); + EXPECT_TRUE(subnet6_3 != subnet4_2); + EXPECT_TRUE(subnet6_3 != subnet6_1); + EXPECT_TRUE(subnet6_3 != subnet6_2); + EXPECT_FALSE(subnet6_3 != subnet6_3); +} + +TEST(IPRangeTest, LowerAndUpper4) { + IPAddress expected, ip; + IPRange range; + + ASSERT_TRUE(StringToIPAddress("1.2.3.4", &ip)); + + // 1.2.3.4/0 + range = IPRange(ip, 0); + ASSERT_TRUE(StringToIPAddress("0.0.0.0", &expected)); + EXPECT_EQ(expected, range.host()); + ASSERT_TRUE(StringToIPAddress("255.255.255.255", &expected)); + + // 1.2.3.4/25 + range = IPRange(ip, 25); + ASSERT_TRUE(StringToIPAddress("1.2.3.0", &expected)); + EXPECT_EQ(expected, range.host()); + ASSERT_TRUE(StringToIPAddress("1.2.3.127", &expected)); + + // 1.2.3.4/31 + range = IPRange(ip, 31); + EXPECT_EQ(ip, range.host()); + ASSERT_TRUE(StringToIPAddress("1.2.3.5", &expected)); + + // 1.2.3.4/32 + range = IPRange(ip, 32); + EXPECT_EQ(ip, range.host()); +} + +TEST(IPRangeTest, LowerAndUpper6) { + IPAddress expected, ip; + IPRange range; + + ASSERT_TRUE(StringToIPAddress("1:2:3:4:5:6:7:8", &ip)); + + // 1:2:3:4:5:6:7:8/0 + range = IPRange(ip, 0); + ASSERT_TRUE(StringToIPAddress("::", &expected)); + EXPECT_EQ(expected, range.host()); + ASSERT_TRUE( + StringToIPAddress("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", &expected)); + + // 1:2:3:4:5:6:7:8/113 + range = IPRange(ip, 113); + ASSERT_TRUE(StringToIPAddress("1:2:3:4:5:6:7:0", &expected)); + EXPECT_EQ(expected, range.host()); + ASSERT_TRUE(StringToIPAddress("1:2:3:4:5:6:7:7fff", &expected)); + + // 1:2:3:4:5:6:7:8/127 + range = IPRange(ip, 127); + EXPECT_EQ(ip, range.host()); + ASSERT_TRUE(StringToIPAddress("1:2:3:4:5:6:7:9", &expected)); + + // 1:2:3:4:5:6:7:8/128 + range = IPRange(ip, 128); + EXPECT_EQ(ip, range.host()); +} + +TEST(IPRangeTest, IsWithinSubnet) { + const IPRange subnet1 = StringToIPRangeOrDie("192.168.0.0/16"); + const IPRange subnet2 = StringToIPRangeOrDie("192.168.0.0/24"); + const IPRange subnet3 = StringToIPRangeOrDie("2001:700:300:1800::/64"); + const IPRange subnet4 = StringToIPRangeOrDie("::/0"); + + const IPAddress addr1 = StringToIPAddressOrDie("192.168.1.5"); + const IPAddress addr2 = StringToIPAddressOrDie("2001:700:300:1800::1"); + const IPAddress addr3 = StringToIPAddressOrDie("2001:700:300:1801::1"); + + EXPECT_TRUE(IsWithinSubnet(subnet1, addr1)); + EXPECT_FALSE(IsWithinSubnet(subnet2, addr1)); + EXPECT_FALSE(IsWithinSubnet(subnet3, addr1)); + EXPECT_FALSE(IsWithinSubnet(subnet4, addr1)); + + EXPECT_FALSE(IsWithinSubnet(subnet1, addr2)); + EXPECT_FALSE(IsWithinSubnet(subnet2, addr2)); + EXPECT_TRUE(IsWithinSubnet(subnet3, addr2)); + EXPECT_TRUE(IsWithinSubnet(subnet4, addr2)); + + EXPECT_FALSE(IsWithinSubnet(subnet1, addr3)); + EXPECT_FALSE(IsWithinSubnet(subnet2, addr3)); + EXPECT_FALSE(IsWithinSubnet(subnet3, addr3)); + EXPECT_TRUE(IsWithinSubnet(subnet4, addr3)); + + EXPECT_FALSE(IsWithinSubnet(subnet1, IPAddress())); + EXPECT_FALSE(IsWithinSubnet(IPRange(), addr1)); +} + +TEST(IPRangeTest, IsProperSubRange) { + std::vector kRangeString = { + "192.168.0.0/15", "192.169.0.0/16", "192.168.0.0/24", + "192.168.0.80/28", "::/0", "2001:700:300:1800::/64", + }; + + std::vector ranges; + for (const std::string& range_str : kRangeString) { + IPRange range; + ASSERT_TRUE(StringToIPRange(range_str, &range)); + EXPECT_FALSE(IsProperSubRange(range, range)); + ranges.push_back(range); + } + + EXPECT_TRUE(IsProperSubRange(ranges[0], ranges[1])); + EXPECT_TRUE(IsProperSubRange(ranges[0], ranges[2])); + EXPECT_TRUE(IsProperSubRange(ranges[0], ranges[3])); + EXPECT_FALSE(IsProperSubRange(ranges[0], ranges[4])); + EXPECT_FALSE(IsProperSubRange(ranges[0], ranges[5])); + + EXPECT_FALSE(IsProperSubRange(ranges[1], ranges[0])); + EXPECT_FALSE(IsProperSubRange(ranges[1], ranges[2])); + EXPECT_FALSE(IsProperSubRange(ranges[1], ranges[3])); + EXPECT_FALSE(IsProperSubRange(ranges[1], ranges[4])); + EXPECT_FALSE(IsProperSubRange(ranges[1], ranges[5])); + + EXPECT_FALSE(IsProperSubRange(ranges[2], ranges[0])); + EXPECT_FALSE(IsProperSubRange(ranges[2], ranges[1])); + EXPECT_TRUE(IsProperSubRange(ranges[2], ranges[3])); + EXPECT_FALSE(IsProperSubRange(ranges[2], ranges[4])); + EXPECT_FALSE(IsProperSubRange(ranges[2], ranges[5])); + + EXPECT_FALSE(IsProperSubRange(ranges[3], ranges[0])); + EXPECT_FALSE(IsProperSubRange(ranges[3], ranges[1])); + EXPECT_FALSE(IsProperSubRange(ranges[3], ranges[2])); + EXPECT_FALSE(IsProperSubRange(ranges[3], ranges[4])); + EXPECT_FALSE(IsProperSubRange(ranges[3], ranges[5])); + + EXPECT_FALSE(IsProperSubRange(ranges[4], ranges[0])); + EXPECT_FALSE(IsProperSubRange(ranges[4], ranges[1])); + EXPECT_FALSE(IsProperSubRange(ranges[4], ranges[2])); + EXPECT_FALSE(IsProperSubRange(ranges[4], ranges[3])); + EXPECT_TRUE(IsProperSubRange(ranges[4], ranges[5])); + + EXPECT_FALSE(IsProperSubRange(ranges[5], ranges[0])); + EXPECT_FALSE(IsProperSubRange(ranges[5], ranges[1])); + EXPECT_FALSE(IsProperSubRange(ranges[5], ranges[2])); + EXPECT_FALSE(IsProperSubRange(ranges[5], ranges[3])); + EXPECT_FALSE(IsProperSubRange(ranges[5], ranges[4])); + + for (const IPRange& r : ranges) { + EXPECT_FALSE(IsProperSubRange(IPRange(), r)); + EXPECT_FALSE(IsProperSubRange(r, IPRange())); + } +} + +TEST(IPRangeTest, IPv6LinkLocal) { + const IPAddress fe80_1 = StringToIPAddressOrDie("fe80::1"); + const IPAddress fe80_1_if17(MakeScopedIP(fe80_1, 17)); + + const IPRange linklocal64(IPRange(fe80_1, 64)); + const IPRange linklocal64_if17(IPRange(fe80_1_if17, 64)); + + EXPECT_NE(linklocal64, linklocal64_if17); + + // Truncation beyond the boundary of what qualifies a prefix as being + // scope_id-applicable doesn't really make sense. In order to prevent the + // creation of some kind of ::%eth0/0 IPRange, truncation beyond scope_id + // qualification discards the scope_id. + // + // IPv6 unicast link-local prefix is fe80::/10 (but this is [presently] + // indistinguishable from fe80::/9). + EXPECT_EQ(17, IPRange(fe80_1_if17, 10).host().scope_id()); + EXPECT_EQ(0, IPRange(fe80_1_if17, 8).host().scope_id()); + // IPv6 multicast link-local prefix is ff02::/16 (but this is [presently] + // indistinguishable from ff02::/15). + const IPAddress ff02_2_if17( + StringToIPAddressWithOptionalScope("ff02::2%17").value()); + EXPECT_EQ(17, IPRange(ff02_2_if17, 16).host().scope_id()); + EXPECT_EQ(0, IPRange(ff02_2_if17, 14).host().scope_id()); + + // IsWithinSubnet follows the truncation logic above. + EXPECT_FALSE(IsWithinSubnet(linklocal64, fe80_1_if17)); + EXPECT_FALSE(IsWithinSubnet(linklocal64_if17, fe80_1)); + EXPECT_TRUE(IsWithinSubnet(linklocal64_if17, fe80_1_if17)); + EXPECT_TRUE(IsWithinSubnet(IPRange(fe80_1_if17, 10), fe80_1_if17)); + EXPECT_FALSE(IsWithinSubnet(IPRange(fe80_1, 10), fe80_1_if17)); + EXPECT_TRUE(IsWithinSubnet(IPRange(fe80_1, 8), fe80_1_if17)); + EXPECT_TRUE(IsWithinSubnet(IPRange::Any6(), fe80_1_if17)); + + // IsProperSubRange also follows the truncation logic above. + EXPECT_FALSE(IsProperSubRange(linklocal64, IPRange(fe80_1_if17))); + EXPECT_FALSE(IsProperSubRange(linklocal64_if17, IPRange(fe80_1))); + EXPECT_TRUE(IsProperSubRange(linklocal64_if17, IPRange(fe80_1_if17))); + EXPECT_TRUE(IsProperSubRange(IPRange(fe80_1_if17, 10), linklocal64_if17)); + EXPECT_FALSE(IsProperSubRange(IPRange(fe80_1, 10), linklocal64_if17)); + EXPECT_TRUE(IsProperSubRange(IPRange(fe80_1, 8), linklocal64_if17)); + EXPECT_TRUE(IsProperSubRange(IPRange::Any6(), linklocal64_if17)); +} + +TEST(IPRangeTest, UnsafeConstruct) { + // Valid inputs. + IPRange::UnsafeConstruct(IPAddress(), -1); + IPRange::UnsafeConstruct(StringToIPAddressOrDie("192.0.2.0"), 24); + IPRange::UnsafeConstruct(StringToIPAddressOrDie("2001:db8::"), 32); +} + +TEST(IPRangeTest, LoggingUninitialized) { + std::ostringstream out; + out << IPRange(); + EXPECT_EQ("", out.str()); +} + +} // namespace +} // namespace cel diff --git a/extensions/BUILD b/extensions/BUILD index d84881716..2ceb39293 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -173,6 +173,78 @@ cc_test( ], ) +cc_library( + name = "network_ext_functions", + srcs = ["network_ext_functions.cc"], + hdrs = ["network_ext_functions.h"], + deps = [ + "//common:ipaddress_oss", + "//common:native_type", + "//common:typeinfo", + "//common:value", + "//runtime:function", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "network_ext", + srcs = ["network_ext.cc"], + hdrs = ["network_ext.h"], + deps = [ + ":network_ext_functions", + "//base:builtins", + "//checker:type_checker_builder", + "//common:decl", + "//common:ipaddress_oss", + "//common:native_type", + "//common:type", + "//common:value", + "//compiler", + "//internal:status_macros", + "//runtime:function", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "//runtime:type_registry", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "network_ext_test", + srcs = ["network_ext_test.cc"], + deps = [ + ":network_ext", + "//checker:validation_result", + "//common:ast", + "//common:minimal_descriptor_pool", + "//common:value", + "//compiler", + "//compiler:compiler_factory", + "//internal:status_macros", + "//internal:testing", + "//runtime", + "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "//runtime:standard_functions", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + # New users should use ":regex_ext" instead. cc_library( name = "regex_functions", diff --git a/extensions/network_ext.cc b/extensions/network_ext.cc new file mode 100644 index 000000000..95486b1bc --- /dev/null +++ b/extensions/network_ext.cc @@ -0,0 +1,592 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/network_ext.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "base/builtins.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/ipaddress_oss.h" +#include "common/native_type.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "extensions/network_ext_functions.h" +#include "internal/status_macros.h" +#include "runtime/function.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions { +namespace { + +using ::cel::BinaryFunctionAdapter; +using ::cel::BoolValue; +using ::cel::IPAddress; +using ::cel::MakeFunctionDecl; +using ::cel::MakeMemberOverloadDecl; +using ::cel::MakeOverloadDecl; +using ::cel::NativeTypeId; +using ::cel::OpaqueType; +using ::cel::OpaqueValue; +using ::cel::OpaqueValueContent; +using ::cel::OpaqueValueDispatcher; +using ::cel::StringValue; +using ::cel::Type; +using ::cel::TypeType; +using ::cel::UnaryFunctionAdapter; +using ::cel::UnsafeOpaqueValue; +using ::cel::Value; +using ::cel::builtin::kEqual; +using ::cel::builtin::kString; +using ::google::protobuf::Arena; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::MessageFactory; + +// Arena for static type instances +Arena* absl_nonnull BuiltinsArena() { + static absl::NoDestructor arena; + return arena.get(); +} + +// CEL Type Declarations +OpaqueType IpType() { + static const absl::NoDestructor kInstance( + BuiltinsArena(), "net.IP", std::vector{}); + return *kInstance; +} + +Type TypeOfIpType() { + static const absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), IpType())); + return *kInstance; +} + +OpaqueType CidrType() { + static const absl::NoDestructor kInstance( + BuiltinsArena(), "net.CIDR", std::vector{}); + return *kInstance; +} + +Type TypeOfCidrType() { + static const absl::NoDestructor kInstance( + TypeType(BuiltinsArena(), CidrType())); + return *kInstance; +} + +// ----------------------------------------------------------------------------- +// Dispatcher for IpAddrRep (net.IP) +// ----------------------------------------------------------------------------- + +NativeTypeId IpAddrRep_GetTypeId(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return IpAddrRep::GetTypeId(); +} + +absl::string_view IpAddrRep_GetTypeName(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return "net.IP"; +} + +std::string IpAddrRep_DebugString(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return content.To()->DebugString(); +} + +absl::Status IpAddrRep_Equal(const OpaqueValueDispatcher*, + OpaqueValueContent content, + const OpaqueValue& other, const DescriptorPool*, + MessageFactory*, Arena*, Value* result) { + const IpAddrRep* self = content.To(); + const IpAddrRep* other_rep = IpAddrRep::Unwrap(other); + if (!other_rep) { + *result = BoolValue(false); + return absl::OkStatus(); + } + *result = BoolValue(self->Equals(*other_rep)); + return absl::OkStatus(); +} + +OpaqueValue IpAddrRep_Clone(const OpaqueValueDispatcher*, + OpaqueValueContent content, Arena* arena) { + const IpAddrRep* self = content.To(); + return IpAddrRep::Create(arena, self->addr()).GetOpaque(); +} + +OpaqueType IpAddrRep_GetRuntimeType(const OpaqueValueDispatcher*, + OpaqueValueContent) { + return IpType(); +} + +static const OpaqueValueDispatcher kIpAddrRepDispatcher = { + .get_type_id = IpAddrRep_GetTypeId, + .get_arena = nullptr, + .get_type_name = IpAddrRep_GetTypeName, + .debug_string = IpAddrRep_DebugString, + .get_runtime_type = IpAddrRep_GetRuntimeType, + .equal = IpAddrRep_Equal, + .clone = IpAddrRep_Clone, +}; + +// ----------------------------------------------------------------------------- +// Dispatcher for CidrRangeRep (net.CIDR) +// ----------------------------------------------------------------------------- + +NativeTypeId CidrRangeRep_GetTypeId(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return CidrRangeRep::GetTypeId(); +} + +absl::string_view CidrRangeRep_GetTypeName(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return "net.CIDR"; +} + +std::string CidrRangeRep_DebugString(const OpaqueValueDispatcher*, + OpaqueValueContent content) { + return content.To()->DebugString(); +} + +absl::Status CidrRangeRep_Equal(const OpaqueValueDispatcher*, + OpaqueValueContent content, + const OpaqueValue& other, const DescriptorPool*, + MessageFactory*, Arena*, Value* result) { + const CidrRangeRep* self = content.To(); + const CidrRangeRep* other_rep = CidrRangeRep::Unwrap(other); + if (!other_rep) { + *result = BoolValue(false); + return absl::OkStatus(); + } + *result = BoolValue(self->Equals(*other_rep)); + return absl::OkStatus(); +} + +OpaqueValue CidrRangeRep_Clone(const OpaqueValueDispatcher*, + OpaqueValueContent content, Arena* arena) { + const CidrRangeRep* self = content.To(); + return CidrRangeRep::Create(arena, self->host(), self->length()).GetOpaque(); +} + +OpaqueType CidrRangeRep_GetRuntimeType(const OpaqueValueDispatcher*, + OpaqueValueContent) { + return CidrType(); +} + +static const OpaqueValueDispatcher kCidrRangeRepDispatcher = { + .get_type_id = CidrRangeRep_GetTypeId, + .get_arena = nullptr, + .get_type_name = CidrRangeRep_GetTypeName, + .debug_string = CidrRangeRep_DebugString, + .get_runtime_type = CidrRangeRep_GetRuntimeType, + .equal = CidrRangeRep_Equal, + .clone = CidrRangeRep_Clone, +}; + +} // namespace + +// ----------------------------------------------------------------------------- +// IpAddrRep Method Implementations +// ----------------------------------------------------------------------------- +Value IpAddrRep::Create(Arena* arena, const IPAddress& addr) { + IpAddrRep* rep = Arena::Create(arena, addr); + return UnsafeOpaqueValue(&kIpAddrRepDispatcher, + OpaqueValueContent::From(rep)); +} + +const IpAddrRep* IpAddrRep::Unwrap(const Value& value) { + auto opaque = value.AsOpaque(); + if (!opaque.has_value() || opaque->GetTypeId() != IpAddrRep::GetTypeId()) { + return nullptr; + } + return opaque->content().To(); +} + +std::string IpAddrRep::DebugString() const { + if (!IsInitializedAddress(addr_)) { + return "ip()"; + } + return absl::StrCat("ip('", addr_.ToString(), "')"); +} + +// ----------------------------------------------------------------------------- +// CidrRangeRep Method Implementations +// ----------------------------------------------------------------------------- +Value CidrRangeRep::Create(Arena* arena, const IPAddress& host, + int length) { // Changed signature + CidrRangeRep* rep = Arena::Create( + arena, host, length); // Changed constructor call + return UnsafeOpaqueValue(&kCidrRangeRepDispatcher, + OpaqueValueContent::From(rep)); +} + +const CidrRangeRep* CidrRangeRep::Unwrap(const Value& value) { + auto opaque = value.AsOpaque(); + if (!opaque.has_value() || opaque->GetTypeId() != CidrRangeRep::GetTypeId()) { + return nullptr; + } + return opaque->content().To(); +} + +std::string CidrRangeRep::DebugString() const { + if (!IsInitializedAddress(host_) || + length_ < 0) { // Changed to use host_ and length_ + return "cidr()"; + } + return absl::StrCat("cidr('", host_.ToString(), "/", length_, + "')"); // Changed to use host_ and length_ +} + +// ----------------------------------------------------------------------------- +// CEL Extension Registration +// ----------------------------------------------------------------------------- + +absl::Status ConfigureNetworkFunctions(cel::TypeCheckerBuilder& builder) { + // Register Type Identifiers + CEL_RETURN_IF_ERROR( + builder.AddVariable(cel::MakeVariableDecl("net.IP", TypeOfIpType()))); + CEL_RETURN_IF_ERROR( + builder.AddVariable(cel::MakeVariableDecl("net.CIDR", TypeOfCidrType()))); + + CEL_ASSIGN_OR_RETURN( + auto decl_is_ip, + MakeFunctionDecl("isIP", MakeOverloadDecl("is_ip_string", cel::BoolType(), + cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_is_ip)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip, + MakeFunctionDecl( + "ip", MakeOverloadDecl("string_to_ip", IpType(), cel::StringType()), + MakeMemberOverloadDecl("cidr_ip", IpType(), CidrType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip)); + + CEL_ASSIGN_OR_RETURN( + auto decl_is_cidr, + MakeFunctionDecl("isCIDR", + MakeOverloadDecl("is_cidr_string", cel::BoolType(), + cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_is_cidr)); + + CEL_ASSIGN_OR_RETURN( + auto decl_cidr, + MakeFunctionDecl("cidr", MakeOverloadDecl("string_to_cidr", CidrType(), + cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_cidr)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_canonical, + MakeFunctionDecl("ip.isCanonical", + MakeOverloadDecl("ip_is_canonical_string", + cel::BoolType(), cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_canonical)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_family, + MakeFunctionDecl("family", MakeMemberOverloadDecl( + "ip_family", cel::IntType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_family)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_loopback, + MakeFunctionDecl( + "isLoopback", + MakeMemberOverloadDecl("ip_is_loopback", cel::BoolType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_loopback)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_global_unicast, + MakeFunctionDecl("isGlobalUnicast", + MakeMemberOverloadDecl("ip_is_global_unicast", + cel::BoolType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_global_unicast)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_link_local_multicast, + MakeFunctionDecl("isLinkLocalMulticast", + MakeMemberOverloadDecl("ip_is_link_local_multicast", + cel::BoolType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_link_local_multicast)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_link_local_unicast, + MakeFunctionDecl("isLinkLocalUnicast", + MakeMemberOverloadDecl("ip_is_link_local_unicast", + cel::BoolType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_link_local_unicast)); + + CEL_ASSIGN_OR_RETURN( + auto decl_ip_is_unspecified, + MakeFunctionDecl("isUnspecified", + MakeMemberOverloadDecl("ip_is_unspecified", + cel::BoolType(), IpType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_ip_is_unspecified)); + + CEL_ASSIGN_OR_RETURN( + auto decl_cidr_contains_ip, + MakeFunctionDecl( + "containsIP", + MakeMemberOverloadDecl("cidr_contains_ip_ip", cel::BoolType(), + CidrType(), IpType()), + MakeMemberOverloadDecl("cidr_contains_ip_string", cel::BoolType(), + CidrType(), cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_cidr_contains_ip)); + + CEL_ASSIGN_OR_RETURN( + auto decl_cidr_contains_cidr, + MakeFunctionDecl( + "containsCIDR", + MakeMemberOverloadDecl("cidr_contains_cidr_cidr", cel::BoolType(), + CidrType(), CidrType()), + MakeMemberOverloadDecl("cidr_contains_cidr_string", cel::BoolType(), + CidrType(), cel::StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_cidr_contains_cidr)); + + CEL_ASSIGN_OR_RETURN( + auto decl_cidr_masked, + MakeFunctionDecl("masked", MakeMemberOverloadDecl( + "cidr_masked", CidrType(), CidrType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_cidr_masked)); + + CEL_ASSIGN_OR_RETURN( + auto decl_cidr_prefix_length, + MakeFunctionDecl("prefixLength", + MakeMemberOverloadDecl("cidr_prefix_length", + cel::IntType(), CidrType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_cidr_prefix_length)); + + CEL_ASSIGN_OR_RETURN( + auto decl_string, + MakeFunctionDecl( + kString, + MakeMemberOverloadDecl("ip_to_string", cel::StringType(), IpType()), + MakeMemberOverloadDecl("cidr_to_string", cel::StringType(), + CidrType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_string)); + + // Add Equality Operator overloads for net.IP and net.CIDR + CEL_ASSIGN_OR_RETURN( + auto decl_equals, + MakeFunctionDecl( + kEqual, + MakeOverloadDecl("ip_equal_ip", cel::BoolType(), IpType(), IpType()), + MakeOverloadDecl("cidr_equal_cidr", cel::BoolType(), CidrType(), + CidrType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(decl_equals)); + + return absl::OkStatus(); +} + +cel::CompilerLibrary NetworkCompilerLibrary() { + return cel::CompilerLibrary("cel.extensions.network", + ConfigureNetworkFunctions); +} + +absl::Status RegisterNetworkTypes(cel::TypeRegistry& registry, + const cel::RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(registry.RegisterType(IpType())); + CEL_RETURN_IF_ERROR(registry.RegisterType(CidrType())); + return absl::OkStatus(); +} +// Implementation for Opaque type equality +Value OpaqueEq(const Value& v1, const Value& v2, + const Function::InvokeContext& context) { + Value result; + absl::Status status = + v1.Equal(v2, context.descriptor_pool(), context.message_factory(), + context.arena(), &result); + if (!status.ok()) { + // This shouldn't happen if the types are supported by the dispatcher + return ErrorValue(status); + } + return result; +} + +// Declarations +cel::Value NetIsIP(const cel::StringValue& str_val, + const cel::Function::InvokeContext& context); +cel::Value NetIPString(const cel::StringValue& str_val, + const cel::Function::InvokeContext& context); +cel::Value NetIsCIDR(const cel::StringValue& str_val, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRString(const cel::StringValue& str_val, + const cel::Function::InvokeContext& context); +cel::Value NetIPFamily(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetIPIsLoopback(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetIPIsGlobalUnicast(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetIPIsLinkLocalMulticast( + const cel::OpaqueValue& self, const cel::Function::InvokeContext& context); +cel::Value NetIPIsLinkLocalUnicast(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetIPIsUnspecified(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetIPIsCanonical(const cel::StringValue& str_val, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRContainsIP(const cel::OpaqueValue& self, + const cel::OpaqueValue& other, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRContainsIPString(const cel::OpaqueValue& self, + const cel::StringValue& other_str, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRContainsCIDR(const cel::OpaqueValue& self, + const cel::OpaqueValue& other, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRContainsCIDRString( + const cel::OpaqueValue& self, const cel::StringValue& other_str, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRIP(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRMasked(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetCIDRPrefixLength(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); +cel::Value NetToString(const cel::OpaqueValue& self, + const cel::Function::InvokeContext& context); + +// Registers network functions with the function registry. +absl::Status RegisterNetworkFunctions(cel::FunctionRegistry& registry, + const cel::RuntimeOptions& options) { + // ... other function registrations ... + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor("isIP", + false), + UnaryFunctionAdapter::WrapFunction(&NetIsIP))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor("ip", + false), + UnaryFunctionAdapter::WrapFunction( + &NetIPString))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isCIDR", false), + UnaryFunctionAdapter::WrapFunction( + &NetIsCIDR))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor("cidr", + false), + UnaryFunctionAdapter::WrapFunction( + &NetCIDRString))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "ip.isCanonical", false), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsCanonical))); + + // Register Member Functions for net.IP + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "family", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPFamily))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isLoopback", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsLoopback))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isGlobalUnicast", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsGlobalUnicast))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isLinkLocalMulticast", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsLinkLocalMulticast))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isLinkLocalUnicast", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsLinkLocalUnicast))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "isUnspecified", true), + UnaryFunctionAdapter::WrapFunction( + &NetIPIsUnspecified))); + + // Register Member Functions for net.CIDR + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor("containsIP", + true), + BinaryFunctionAdapter:: + WrapFunction(&NetCIDRContainsIP))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor("containsIP", + true), + BinaryFunctionAdapter:: + WrapFunction(&NetCIDRContainsIPString))); + + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor("containsCIDR", true), + BinaryFunctionAdapter:: + WrapFunction(&NetCIDRContainsCIDR))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter:: + CreateDescriptor("containsCIDR", true), + BinaryFunctionAdapter:: + WrapFunction(&NetCIDRContainsCIDRString))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor("ip", + true), + UnaryFunctionAdapter::WrapFunction( + &NetCIDRIP))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "masked", true), + UnaryFunctionAdapter::WrapFunction( + &NetCIDRMasked))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + "prefixLength", true), + UnaryFunctionAdapter::WrapFunction( + &NetCIDRPrefixLength))); + + // Register the combined string function + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor(kString, + true), + UnaryFunctionAdapter::WrapFunction( + &NetToString))); + + // Register equality for IP and CIDR + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter::CreateDescriptor(kEqual, + false), + BinaryFunctionAdapter::WrapFunction(&OpaqueEq))); + + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/extensions/network_ext.h b/extensions/network_ext.h new file mode 100644 index 000000000..1da10eb16 --- /dev/null +++ b/extensions/network_ext.h @@ -0,0 +1,38 @@ +// Copyright 2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_H_ + +#include "absl/status/status.h" +#include "compiler/compiler.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "runtime/type_registry.h" + +namespace cel::extensions { + +// Provides a CEL compiler library for network functions. +cel::CompilerLibrary NetworkCompilerLibrary(); + +// Registers network function overloads with the function registry. +absl::Status RegisterNetworkFunctions(cel::FunctionRegistry& registry, + const cel::RuntimeOptions& options); + +// Registers network types with the type registry. +absl::Status RegisterNetworkTypes(cel::TypeRegistry& registry, + const cel::RuntimeOptions& options); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_H_ diff --git a/extensions/network_ext_functions.cc b/extensions/network_ext_functions.cc new file mode 100644 index 000000000..5c72987a6 --- /dev/null +++ b/extensions/network_ext_functions.cc @@ -0,0 +1,368 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/network_ext_functions.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "common/ipaddress_oss.h" +#include "common/value.h" +#include "common/values/error_value.h" +#include "runtime/function.h" + +namespace cel::extensions { + +namespace { + +using ::cel::BoolValue; +using ::cel::ErrorValue; +using ::cel::Function; +using ::cel::GetMappedIPv4Address; +using ::cel::IntValue; +using ::cel::IPAddress; +using ::cel::IPRange; +using ::cel::IsAnyIPAddress; +using ::cel::IsInitializedAddress; +using ::cel::IsLinkLocalIP; +using ::cel::IsLoopbackIPAddress; +using ::cel::IsNonRoutableIP; +using ::cel::IsProperSubRange; +using ::cel::IsWithinSubnet; +using ::cel::StringToIPAddress; +using ::cel::StringValue; +using ::cel::Value; + +// ----------------------------------------------------------------------------- +// Strict Parsing Helpers +// ----------------------------------------------------------------------------- + +bool IsStrictIP(const IPAddress& addr) { + if (!IsInitializedAddress(addr)) return false; + IPAddress unused; + // Check for IPv4-mapped IPv6 addresses. + if (GetMappedIPv4Address(addr, &unused)) { + return false; + } + // zone() is not a member of net_base::IPAddress + return true; +} + +// Helper to parse CIDR string into host and length without truncation +bool ParseCIDRWithoutTruncation(absl::string_view s, IPAddress* host, + int* length) { + std::vector parts = absl::StrSplit(s, '/'); + if (parts.size() != 2) { + return false; + } + if (!StringToIPAddress(parts[0], host)) { + return false; + } + if (!absl::SimpleAtoi(parts[1], length)) { + return false; + } + if ((*length < 0 || (*host).is_ipv4()) && + (*length > 32 || (*host).is_ipv6()) && *length > 128) { + return false; + } + return true; +} + +bool IsStrictCIDR(const IPAddress& host, int length) { + if (!IsInitializedAddress(host)) return false; + return IsStrictIP(host); +} + +} // namespace + +// ----------------------------------------------------------------------------- +// CEL Function Implementations +// ----------------------------------------------------------------------------- + +// isIP(string) -> bool +Value NetIsIP(const StringValue& str_val, + const Function::InvokeContext& context) { + IPAddress addr; + if (!StringToIPAddress(std::string(str_val.ToString()), &addr)) { + return BoolValue(false); + } + return BoolValue(IsStrictIP(addr)); +} + +// ip(string) -> net.IP +Value NetIPString(const StringValue& str_val, + const Function::InvokeContext& context) { + std::string str = std::string(str_val.ToString()); + IPAddress addr; + if (!StringToIPAddress(str, &addr)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("IP Address '", str, "' parse error"))); + } + if (!IsStrictIP(addr)) { + return ErrorValue(absl::InvalidArgumentError(absl::StrCat( + "IP Address '", str, "' is not a strict IP (e.g., mapped IPv4)"))); + } + return IpAddrRep::Create(context.arena(), addr); +} + +// isCIDR(string) -> bool +Value NetIsCIDR(const StringValue& str_val, + const Function::InvokeContext& context) { + std::string str = std::string(str_val.ToString()); + IPAddress host; + int length; + if (!ParseCIDRWithoutTruncation(str, &host, &length)) { + return BoolValue(false); + } + return BoolValue(IsStrictCIDR(host, length)); +} + +// cidr(string) -> net.CIDR +Value NetCIDRString(const StringValue& str_val, + const Function::InvokeContext& context) { + std::string str = std::string(str_val.ToString()); + IPAddress host; + int length; + if (!ParseCIDRWithoutTruncation(str, &host, &length)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("CIDR '", str, "' parse error"))); + } + + if (!IsStrictCIDR(host, length)) { + return ErrorValue(absl::InvalidArgumentError(absl::StrCat( + "CIDR '", str, "' is not a strict CIDR (e.g., mapped IPv4)"))); + } + return CidrRangeRep::Create(context.arena(), host, length); +} + +// .family() -> int +Value NetIPFamily(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized IPAddress")); + } + switch (rep->addr().address_family()) { + case AF_INET: + return IntValue(4); + case AF_INET6: + return IntValue(6); + default: + return ErrorValue(absl::InvalidArgumentError("Unknown address family")); + } +} + +// .isLoopback() -> bool +Value NetIPIsLoopback(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized IPAddress")); + } + return BoolValue(IsLoopbackIPAddress(rep->addr())); +} + +// .isGlobalUnicast() -> bool +Value NetIPIsGlobalUnicast(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized IPAddress")); + } + const IPAddress& addr = rep->addr(); + + if (IsAnyIPAddress(addr) || addr == IPAddress::Loopback4() || + addr == IPAddress::Loopback6() || IsLinkLocalIP(addr) || + IsNonRoutableIP(addr)) { + return BoolValue(false); + } + + if (addr.is_ipv4()) { + return BoolValue(!IsV4MulticastIPAddress(addr)); + } + + if (addr.is_ipv6()) { + in6_addr addr6 = addr.ipv6_address(); + if (IN6_IS_ADDR_MULTICAST(&addr6)) { + return BoolValue(false); + } + return BoolValue(true); + } + return BoolValue(false); +} + +Value NetIPIsLinkLocalMulticast(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep || !rep->addr().is_ipv6()) { + return BoolValue(false); + } + in6_addr addr6 = rep->addr().ipv6_address(); + return BoolValue(IN6_IS_ADDR_MC_LINKLOCAL(&addr6)); +} + +Value NetIPIsLinkLocalUnicast(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized IPAddress")); + } + return BoolValue(IsLinkLocalIP(rep->addr())); +} + +Value NetIPIsUnspecified(const OpaqueValue& self, + const Function::InvokeContext& context) { + const IpAddrRep* rep = IpAddrRep::Unwrap(self); + if (!rep) { + return ErrorValue(absl::InvalidArgumentError("Invalid IP object")); + } + return BoolValue(IsAnyIPAddress(rep->addr())); +} + +Value NetIPIsCanonical(const StringValue& str_val, + const Function::InvokeContext& context) { + std::string str = std::string(str_val.ToString()); + IPAddress addr; + if (!StringToIPAddress(str, &addr)) { + return BoolValue(false); + } + if (!IsStrictIP(addr)) { + return BoolValue(false); + } + return BoolValue(addr.ToString() == str); +} + +Value NetCIDRContainsIP(const OpaqueValue& self, const OpaqueValue& other, + const Function::InvokeContext& context) { + const CidrRangeRep* self_rep = CidrRangeRep::Unwrap(self); + const IpAddrRep* other_rep = IpAddrRep::Unwrap(other); + if (!self_rep || !IsInitializedAddress(self_rep->host()) || + self_rep->length() < 0 || !other_rep || + !IsInitializedAddress(other_rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR or IP")); + } + return BoolValue(IsWithinSubnet(self_rep->ToIPRange(), other_rep->addr())); +} + +Value NetCIDRContainsIPString(const OpaqueValue& self, + const StringValue& other_str, + const Function::InvokeContext& context) { + const CidrRangeRep* self_rep = CidrRangeRep::Unwrap(self); + if (!self_rep || !IsInitializedAddress(self_rep->host()) || + self_rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + + std::string str = std::string(other_str.ToString()); + IPAddress other_addr; + if (!StringToIPAddress(str, &other_addr) || !IsStrictIP(other_addr)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("Invalid or non-strict IP string: ", str))); + } + return BoolValue(IsWithinSubnet(self_rep->ToIPRange(), other_addr)); +} + +Value NetCIDRContainsCIDR(const OpaqueValue& self, const OpaqueValue& other, + const Function::InvokeContext& context) { + const CidrRangeRep* self_rep = CidrRangeRep::Unwrap(self); + const CidrRangeRep* other_rep = CidrRangeRep::Unwrap(other); + if (!self_rep || !IsInitializedAddress(self_rep->host()) || + self_rep->length() < 0 || !other_rep || + !IsInitializedAddress(other_rep->host()) || other_rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + IPRange self_range = self_rep->ToIPRange(); + IPRange other_range = other_rep->ToIPRange(); + return BoolValue(self_range == other_range || + IsProperSubRange(self_range, other_range)); +} + +Value NetCIDRContainsCIDRString(const OpaqueValue& self, + const StringValue& other_str, + const Function::InvokeContext& context) { + const CidrRangeRep* self_rep = CidrRangeRep::Unwrap(self); + if (!self_rep || !IsInitializedAddress(self_rep->host()) || + self_rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + + std::string str = std::string(other_str.ToString()); + IPAddress other_host; + int other_length; + if (!ParseCIDRWithoutTruncation(str, &other_host, &other_length) || + !IsStrictCIDR(other_host, other_length)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrCat("Invalid or non-strict CIDR string: ", str))); + } + IPRange self_range = self_rep->ToIPRange(); + IPRange other_range(other_host, other_length); + return BoolValue(self_range == other_range || + IsProperSubRange(self_range, other_range)); +} + +Value NetCIDRIP(const OpaqueValue& self, + const Function::InvokeContext& context) { + const CidrRangeRep* rep = CidrRangeRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->host()) || rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + return IpAddrRep::Create(context.arena(), rep->host()); +} + +Value NetCIDRMasked(const OpaqueValue& self, + const Function::InvokeContext& context) { + const CidrRangeRep* rep = CidrRangeRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->host()) || rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + IPRange masked_range = rep->ToIPRange(); + return CidrRangeRep::Create(context.arena(), masked_range.host(), + masked_range.length()); +} + +Value NetCIDRPrefixLength(const OpaqueValue& self, + const Function::InvokeContext& context) { + const CidrRangeRep* rep = CidrRangeRep::Unwrap(self); + if (!rep || !IsInitializedAddress(rep->host()) || rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + return IntValue(rep->length()); +} + +Value NetToString(const OpaqueValue& self, + const Function::InvokeContext& context) { + if (const IpAddrRep* rep = IpAddrRep::Unwrap(self)) { + if (!IsInitializedAddress(rep->addr())) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized IPAddress")); + } + return StringValue::From(rep->addr().ToString(), context.arena()); + } + if (const CidrRangeRep* rep = CidrRangeRep::Unwrap(self)) { + if (!IsInitializedAddress(rep->host()) || rep->length() < 0) { + return ErrorValue(absl::InvalidArgumentError("Uninitialized CIDR")); + } + return StringValue::From( + absl::StrCat(rep->host().ToString(), "/", rep->length()), + context.arena()); + } + return ErrorValue( + absl::InvalidArgumentError("Unsupported type for string()")); +} + +} // namespace cel::extensions diff --git a/extensions/network_ext_functions.h b/extensions/network_ext_functions.h new file mode 100644 index 000000000..3e376ddce --- /dev/null +++ b/extensions/network_ext_functions.h @@ -0,0 +1,80 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// extensions/network_ext_functions.h + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_FUNCTIONS_H_ + +#include + +#include "common/ipaddress_oss.h" +#include "common/native_type.h" +#include "common/typeinfo.h" +#include "common/value.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { + +// ... IpAddrRep and CidrRangeRep classes ... +class IpAddrRep { + public: + static cel::Value Create(google::protobuf::Arena* arena, const IPAddress& addr); + static const IpAddrRep* Unwrap(const cel::Value& value); + IpAddrRep() = default; + explicit IpAddrRep(const IPAddress& addr) : addr_(addr) {} + const IPAddress& addr() const { return addr_; } + bool Equals(const IpAddrRep& other) const { return addr_ == other.addr_; } + std::string DebugString() const; + static cel::NativeTypeId GetTypeId() { return cel::TypeId(); } + + private: + IPAddress addr_; +}; + +class CidrRangeRep { + public: + static cel::Value Create(google::protobuf::Arena* arena, const IPAddress& host, + int length); + static const CidrRangeRep* Unwrap(const cel::Value& value); + + CidrRangeRep() = default; + explicit CidrRangeRep(const IPAddress& host, int length) + : host_(host), length_(length) {} + + const IPAddress& host() const { return host_; } + int length() const { return length_; } + + // Utility to get the IPRange (which will be truncated) + IPRange ToIPRange() const { return IPRange(host_, length_); } + + bool Equals(const CidrRangeRep& other) const { + return length_ == other.length_ && host_ == other.host_; + } + std::string DebugString() const; + + static cel::NativeTypeId GetTypeId() { return cel::TypeId(); } + + template + friend H AbslHashValue(H h, const CidrRangeRep& c) { + return H::combine(std::move(h), c.host_, c.length_); + } + + private: + IPAddress host_; + int length_ = -1; +}; + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_NETWORK_EXT_FUNCTIONS_H_ diff --git a/extensions/network_ext_test.cc b/extensions/network_ext_test.cc new file mode 100644 index 000000000..73cf0fb2f --- /dev/null +++ b/extensions/network_ext_test.cc @@ -0,0 +1,394 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/network_ext.h" + +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/minimal_descriptor_pool.h" +#include "common/value.h" +#include "common/values/bool_value.h" +#include "common/values/error_value.h" +#include "common/values/int_value.h" +#include "common/values/string_value.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/standard_functions.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +// Includes for Compiler +#include "checker/validation_result.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::Activation; +using ::testing::Eq; +using ::testing::HasSubstr; + +class NetworkExtTest : public ::testing::Test { + protected: + NetworkExtTest() = default; + + void SetUp() override { + // 1. Configure the Compiler + auto compiler_builder = + cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool()); + ASSERT_THAT(compiler_builder.status(), IsOk()); + ASSERT_THAT((*compiler_builder)->AddLibrary(NetworkCompilerLibrary()), + IsOk()); + ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder)->Build()); + + // 2. Configure the Modern Runtime + cel::RuntimeOptions runtime_options; + // Wrap the raw pointer in a std::shared_ptr with a NO-OP DELETER + std::shared_ptr descriptor_pool( + cel::GetMinimalDescriptorPool(), [](const google::protobuf::DescriptorPool*) { + // Do nothing, as the pool is static. + }); + + auto runtime_builder = + cel::CreateRuntimeBuilder(descriptor_pool, runtime_options); + ASSERT_THAT(runtime_builder.status(), + IsOk()); // Check if CreateRuntimeBuilder succeeded + + ASSERT_THAT( + RegisterNetworkTypes(runtime_builder->type_registry(), runtime_options), + IsOk()); + ASSERT_THAT(RegisterNetworkFunctions(runtime_builder->function_registry(), + runtime_options), + IsOk()); + + ASSERT_THAT(cel::RegisterStandardFunctions( + runtime_builder->function_registry(), runtime_options), + IsOk()); + + // Build the runtime + ASSERT_OK_AND_ASSIGN(runtime_, std::move(*runtime_builder).Build()); + } + + // ... Evaluate() function and member variables ... + absl::StatusOr Evaluate(absl::string_view expr) { + auto validation_result = compiler_->Compile(expr); + CEL_RETURN_IF_ERROR(validation_result.status()); + + if (!validation_result->GetIssues().empty()) { + return absl::InvalidArgumentError( + validation_result->GetIssues()[0].message()); + } + + if (!validation_result->IsValid()) { + return absl::InternalError( + "Compilation produced an invalid AST without issues."); + } + + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, + validation_result->ReleaseAst()); + + if (ast == nullptr) { + return absl::InternalError("ValidationResult returned a null AST."); + } + + CEL_ASSIGN_OR_RETURN(auto program, runtime_->CreateProgram(std::move(ast))); + + Activation activation; + return program->Evaluate(&arena_, activation); + } + + std::unique_ptr compiler_; + std::unique_ptr runtime_; + google::protobuf::Arena arena_; +}; + +// --- Global Checks (isIP, isCIDR) --- +TEST_F(NetworkExtTest, IsIPValidIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isIP('1.2.3.4')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsIPValidIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isIP('2001:db8::1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsIPInvalid) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isIP('not.an.ip')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +TEST_F(NetworkExtTest, IsIPWithPort) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isIP('127.0.0.1:80')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +TEST_F(NetworkExtTest, IsCIDRValid) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isCIDR('10.0.0.0/8')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsCIDRInvalidMask) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("isCIDR('10.0.0.0/999')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +// --- IP Constructors & Equality --- +TEST_F(NetworkExtTest, IPEqualityIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip('127.0.0.1') == ip('127.0.0.1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IPInequality) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip('127.0.0.1') == ip('1.2.3.4')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +TEST_F(NetworkExtTest, IPEqualityIPv6MixedCase) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip('2001:db8::1') == ip('2001:DB8::1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +// --- String Conversion --- +TEST_F(NetworkExtTest, IPToStringIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('1.2.3.4').string()")); + ASSERT_TRUE(value.IsString()); + EXPECT_THAT(value.As()->ToString(), Eq("1.2.3.4")); +} + +TEST_F(NetworkExtTest, IPToStringIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip('2001:db8::1').string()")); // .string() + ASSERT_TRUE(value.IsString()); + EXPECT_THAT(value.As()->ToString(), Eq("2001:db8::1")); +} + +TEST_F(NetworkExtTest, CIDRToStringIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("cidr('10.0.0.0/8').string()")); // .string() + ASSERT_TRUE(value.IsString()); + EXPECT_THAT(value.As()->ToString(), Eq("10.0.0.0/8")); +} + +TEST_F(NetworkExtTest, CIDRToStringIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("cidr('::1/128').string()")); // .string() + ASSERT_TRUE(value.IsString()); + EXPECT_THAT(value.As()->ToString(), Eq("::1/128")); +} + +// --- Family --- +TEST_F(NetworkExtTest, FamilyIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('127.0.0.1').family()")); + ASSERT_TRUE(value.IsInt()); + EXPECT_THAT(value.As()->NativeValue(), Eq(4)); +} + +TEST_F(NetworkExtTest, FamilyIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('::1').family()")); + ASSERT_TRUE(value.IsInt()); + EXPECT_THAT(value.As()->NativeValue(), Eq(6)); +} + +// --- Canonicalization --- +TEST_F(NetworkExtTest, IsCanonicalIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip.isCanonical('127.0.0.1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsCanonicalIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip.isCanonical('2001:db8::1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsCanonicalIPv6Uppercase) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip.isCanonical('2001:DB8::1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +TEST_F(NetworkExtTest, IsCanonicalIPv6Expanded) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip.isCanonical('2001:db8:0:0:0:0:0:1')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(false)); +} + +// --- IP Types (Loopback, Unspecified, etc) --- +TEST_F(NetworkExtTest, IsLoopbackIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('127.0.0.1').isLoopback()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsLoopbackIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('::1').isLoopback()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsUnspecifiedIPv4) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('0.0.0.0').isUnspecified()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsUnspecifiedIPv6) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('::').isUnspecified()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsGlobalUnicast) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('8.8.8.8').isGlobalUnicast()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, IsLinkLocalMulticast) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("ip('ff02::1').isLinkLocalMulticast()")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +// --- CIDR Accessors --- +TEST_F(NetworkExtTest, CIDRPrefixLength) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("cidr('192.168.0.0/24').prefixLength()")); + ASSERT_TRUE(value.IsInt()); + EXPECT_THAT(value.As()->NativeValue(), Eq(24)); +} + +TEST_F(NetworkExtTest, CIDRIPExtraction) { + ASSERT_OK_AND_ASSIGN( + auto value, Evaluate("cidr('192.168.0.0/24').ip() == ip('192.168.0.0')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, CIDRIPExtractionHostBitsSet) { + ASSERT_OK_AND_ASSIGN( + auto value, Evaluate("cidr('192.168.1.5/24').ip() == ip('192.168.1.5')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, CIDRMasked) { + ASSERT_OK_AND_ASSIGN( + auto value, + Evaluate("cidr('192.168.1.5/24').masked() == cidr('192.168.1.0/24')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, CIDRMaskedIdentity) { + ASSERT_OK_AND_ASSIGN( + auto value, + Evaluate("cidr('192.168.1.0/24').masked() == cidr('192.168.1.0/24')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +// --- Containment (IP in CIDR) --- +TEST_F(NetworkExtTest, ContainsIPSimple) { + ASSERT_OK_AND_ASSIGN( + auto value, Evaluate("cidr('10.0.0.0/8').containsIP(ip('10.1.2.3'))")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +TEST_F(NetworkExtTest, ContainsIPStringOverload) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("cidr('10.0.0.0/8').containsIP('10.1.2.3')")); + ASSERT_TRUE(value.IsBool()); + EXPECT_THAT(value.As()->NativeValue(), Eq(true)); +} + +// ... other Contains tests ... + +// --- Runtime Errors --- +TEST_F(NetworkExtTest, ErrIPConstructorInvalid) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("ip('999.999.999.999')")); + ASSERT_TRUE(value.IsError()); + EXPECT_THAT( + value.As()->ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("parse error"))); +} + +TEST_F(NetworkExtTest, ErrCIDRConstructorInvalid) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("cidr('1.2.3.4')")); + ASSERT_TRUE(value.IsError()); + EXPECT_THAT( + value.As()->ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("parse error"))); +} + +TEST_F(NetworkExtTest, ErrCIDRConstructorInvalidMask) { + ASSERT_OK_AND_ASSIGN(auto value, Evaluate("cidr('10.0.0.0/999')")); + ASSERT_TRUE(value.IsError()); + EXPECT_THAT( + value.As()->ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("parse error"))); +} + +TEST_F(NetworkExtTest, ErrContainsIPStringInvalid) { + ASSERT_OK_AND_ASSIGN(auto value, + Evaluate("cidr('10.0.0.0/8').containsIP('not-an-ip')")); + ASSERT_TRUE(value.IsError()); + EXPECT_THAT(value.As()->ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid or non-strict IP string"))); +} + +TEST_F(NetworkExtTest, ErrContainsCIDRStringInvalid) { + ASSERT_OK_AND_ASSIGN( + auto value, Evaluate("cidr('10.0.0.0/8').containsCIDR('not-a-cidr')")); + ASSERT_TRUE(value.IsError()); + EXPECT_THAT(value.As()->ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid or non-strict CIDR string"))); +} + +} // namespace +} // namespace cel::extensions