diff --git a/src/main/java/org/juv25d/config/IpFilterConfig.java b/src/main/java/org/juv25d/config/IpFilterConfig.java index ec970cdd..bc69d936 100644 --- a/src/main/java/org/juv25d/config/IpFilterConfig.java +++ b/src/main/java/org/juv25d/config/IpFilterConfig.java @@ -20,5 +20,7 @@ public boolean allowByDefault() { return allowByDefault; } - public boolean trustProxyHeaders() {return trustProxyHeaders;} + public boolean trustProxyHeaders() { + return trustProxyHeaders; + } } diff --git a/src/main/java/org/juv25d/filter/IpFilter.java b/src/main/java/org/juv25d/filter/IpFilter.java index c1a7280b..9c0ff43d 100644 --- a/src/main/java/org/juv25d/filter/IpFilter.java +++ b/src/main/java/org/juv25d/filter/IpFilter.java @@ -33,6 +33,20 @@ public class IpFilter implements Filter { private final boolean allowByDefault; private final boolean trustProxyHeaders; + /** + * Constructs an IP filter with specified whitelist, blacklist, and default policy. + * This constructor sets {@code trustProxyHeaders = false} + *

+ * To specify proxy trusting use {@link #IpFilter(Set, Set, boolean, boolean)} + * + * @param whitelist set of IPs/CIDR ranges to always allow (can be null) + * @param blacklist set of IPs/CIDR ranges to always block (can be null) + * @param allowByDefault whether to allow IPs not in either list + * */ + public IpFilter(@Nullable Set whitelist, @Nullable Set blacklist, boolean allowByDefault) { + this(whitelist, blacklist, allowByDefault, false); + } + /** * Constructs an IP filter with specified whitelist, blacklist, and default policy. * @@ -261,12 +275,12 @@ private String getClientIp(HttpRequest req) { Map headers = req.headers(); - String ip = headers.get("X-Forwarded-For"); + String ip = getHeaderIgnoreCase(headers, "X-Forwarded-For"); if (ip != null && !ip.isBlank()) { return ip.split(",")[0].trim(); } - ip = headers.get("X-Real-IP"); + ip = getHeaderIgnoreCase(headers, "X-Real-IP"); if (ip != null && !ip.isBlank()) { return ip.trim(); } @@ -274,6 +288,22 @@ private String getClientIp(HttpRequest req) { return req.remoteIp(); } + /** + * Retrieves a header value using case-insensitive name matching. + * + * @param headers the header map to search + * @param name the header name + * @return the header value, or {@code null} if not found + */ + private @Nullable String getHeaderIgnoreCase(Map headers, String name) { + for (Map.Entry entry : headers.entrySet()) { + if (entry.getKey() != null && entry.getKey().equalsIgnoreCase(name)) { + return entry.getValue(); + } + } + return null; + } + /** * Sends a 403 Forbidden response to the client. * diff --git a/src/test/java/org/juv25d/filter/IpFilterTest.java b/src/test/java/org/juv25d/filter/IpFilterTest.java index 345c7bb0..628f50db 100644 --- a/src/test/java/org/juv25d/filter/IpFilterTest.java +++ b/src/test/java/org/juv25d/filter/IpFilterTest.java @@ -34,7 +34,7 @@ void setUp() { @Test @DisplayName("Allow ip only in whitelist") void whitelist_allowsIp() throws IOException { - IpFilter filter = new IpFilter(Set.of("127.0.0.1"), null, false, false); + IpFilter filter = new IpFilter(Set.of("127.0.0.1"), null, false); filter.doFilter(req, res, chain); @@ -45,7 +45,7 @@ void whitelist_allowsIp() throws IOException { @Test @DisplayName("Allow ip from CIDR range only in whitelist") void whitelist_allowsIpInRange() throws IOException { - IpFilter filter = new IpFilter(Set.of("127.0.0.0/24"), null, false, false); + IpFilter filter = new IpFilter(Set.of("127.0.0.0/24"), null, false); filter.doFilter(req, res, chain); @@ -56,7 +56,7 @@ void whitelist_allowsIpInRange() throws IOException { @Test @DisplayName("Block ip only in blacklist") void blacklist_blocksIp() throws IOException { - IpFilter filter = new IpFilter(null, Set.of("127.0.0.1"), true, false); + IpFilter filter = new IpFilter(null, Set.of("127.0.0.1"), true); filter.doFilter(req, res, chain); verify(chain, never()).doFilter(req, res); @@ -68,7 +68,7 @@ void blacklist_blocksIp() throws IOException { @Test @DisplayName("Block ip from CIDR range only in blacklist") void blacklist_blocksIpInRange() throws IOException { - IpFilter filter = new IpFilter(null, Set.of("127.0.0.0/24"), true, false); + IpFilter filter = new IpFilter(null, Set.of("127.0.0.0/24"), true); filter.doFilter(req, res, chain); verify(chain, never()).doFilter(req, res); @@ -80,7 +80,7 @@ void blacklist_blocksIpInRange() throws IOException { @Test @DisplayName("Allow ip in both list (whitelist prio)") void whitelist_overrides_blacklist() throws IOException { - IpFilter filter = new IpFilter(Set.of("127.0.0.1"), Set.of("127.0.0.0/24"), false, false); + IpFilter filter = new IpFilter(Set.of("127.0.0.1"), Set.of("127.0.0.0/24"), false); filter.doFilter(req, res, chain); @@ -91,8 +91,8 @@ void whitelist_overrides_blacklist() throws IOException { @ParameterizedTest @ValueSource(booleans = {true, false}) @DisplayName("Follow default when ip in neither list") - void Ip_inNeitherList_followsDefault(boolean allowByDefault) throws IOException { - IpFilter filter = new IpFilter(null, null, allowByDefault, false); + void ip_inNeitherList_followsDefault(boolean allowByDefault) throws IOException { + IpFilter filter = new IpFilter(null, null, allowByDefault); filter.doFilter(req, res, chain); @@ -110,7 +110,7 @@ void Ip_inNeitherList_followsDefault(boolean allowByDefault) throws IOException @ValueSource(strings = {"127.0.0.1", "127.0.0.0/24"}) @DisplayName("Allow IP or CIDR range added in existing filter") void addIpOrRange_whitelist(String ipOrCidr) throws IOException { - IpFilter filter = new IpFilter(null, null, false, false); + IpFilter filter = new IpFilter(null, null, false); filter.addToWhitelist(ipOrCidr); filter.doFilter(req, res, chain); @@ -125,7 +125,7 @@ void addIpOrRange_whitelist(String ipOrCidr) throws IOException { @ValueSource(strings = {"127.0.0.1", "127.0.0.0/24"}) @DisplayName("Block IP or CIDR range added in existing filter") void addIpOrRange_blacklist(String ipOrCidr) throws IOException { - IpFilter filter = new IpFilter(null, null, false, false); + IpFilter filter = new IpFilter(null, null, false); filter.addToBlacklist(ipOrCidr); filter.doFilter(req, res, chain); @@ -139,7 +139,7 @@ void addIpOrRange_blacklist(String ipOrCidr) throws IOException { @Test @DisplayName("Adding IP or CIDR range already in filter doesn't create duplicates") void doesNotAddDuplicates() { - IpFilter filter = new IpFilter(Set.of("127.0.0.1", "127.0.0.0/24"), null, false, false); + IpFilter filter = new IpFilter(Set.of("127.0.0.1", "127.0.0.0/24"), null, false); filter.addToWhitelist("127.0.0.1"); filter.addToWhitelist("127.0.0.0/24"); @@ -152,7 +152,7 @@ void doesNotAddDuplicates() { @ValueSource(strings = {"127.0.0.1", "127.0.0.0/24"}) @DisplayName("Fall back on blacklist/default after removing IP or CIDR range from whitelist") void removeIpOrRange_whitelist(String ipOrCidr) throws IOException { - IpFilter filter = new IpFilter(Set.of(ipOrCidr), null, false, false); + IpFilter filter = new IpFilter(Set.of(ipOrCidr), null, false); filter.removeFromWhitelist(ipOrCidr); filter.doFilter(req, res, chain); @@ -167,7 +167,7 @@ void removeIpOrRange_whitelist(String ipOrCidr) throws IOException { @ValueSource(strings = {"127.0.0.1", "127.0.0.0/24"}) @DisplayName("Fall back on whitelist/default after removing IP or CIDR range from blacklist") void removeIpOrRange_blacklist(String ipOrCidr) throws IOException { - IpFilter filter = new IpFilter(null, Set.of(ipOrCidr), true, false); + IpFilter filter = new IpFilter(null, Set.of(ipOrCidr), true); filter.removeFromBlacklist(ipOrCidr); filter.doFilter(req, res, chain); @@ -182,7 +182,7 @@ void removeIpOrRange_blacklist(String ipOrCidr) throws IOException { @NullAndEmptySource @DisplayName("Block null or empty IP") void nullOrBlankIp_blocked(String ip) throws IOException { - IpFilter filter = new IpFilter(null, null, true, false); + IpFilter filter = new IpFilter(null, null, true); when(req.remoteIp()).thenReturn(ip); @@ -195,7 +195,7 @@ void nullOrBlankIp_blocked(String ip) throws IOException { @Test @DisplayName("Ignore incorrectly formatted CIDR range") void invalidCidr_loggedAndIgnored() { - IpFilter filter = new IpFilter(null, null, false, false); + IpFilter filter = new IpFilter(null, null, false); filter.addToWhitelist("not-a-cidr/99"); @@ -205,7 +205,7 @@ void invalidCidr_loggedAndIgnored() { @Test @DisplayName("Get methods return immutable copies") void get_returnsImmutableCopy() { - IpFilter filter = new IpFilter(null, null, false, false); + IpFilter filter = new IpFilter(null, null, false); assertThrows(UnsupportedOperationException.class, () -> filter.getWhitelistIps().add("test"));