Skip to content
4 changes: 3 additions & 1 deletion src/main/java/org/juv25d/config/IpFilterConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ public boolean allowByDefault() {
return allowByDefault;
}

public boolean trustProxyHeaders() {return trustProxyHeaders;}
public boolean trustProxyHeaders() {
return trustProxyHeaders;
}
}
34 changes: 32 additions & 2 deletions src/main/java/org/juv25d/filter/IpFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ public class IpFilter implements Filter {
private final boolean allowByDefault;
private final boolean trustProxyHeaders;

/**
* Constructs an IP filter with specified whitelist, blacklist, and default policy.
* This constructor sets {@code trustProxyHeaders = false}
* <p>
* To specify proxy trusting use {@link #IpFilter(Set, Set, boolean, boolean)}
*
* @param whitelist set of IPs/CIDR ranges to always allow (can be null)
* @param blacklist set of IPs/CIDR ranges to always block (can be null)
* @param allowByDefault whether to allow IPs not in either list
* */
public IpFilter(@Nullable Set<String> whitelist, @Nullable Set<String> blacklist, boolean allowByDefault) {
this(whitelist, blacklist, allowByDefault, false);
}

/**
* Constructs an IP filter with specified whitelist, blacklist, and default policy.
*
Expand Down Expand Up @@ -261,19 +275,35 @@ private String getClientIp(HttpRequest req) {

Map<String, String> headers = req.headers();

String ip = headers.get("X-Forwarded-For");
String ip = getHeaderIgnoreCase(headers, "X-Forwarded-For");
if (ip != null && !ip.isBlank()) {
return ip.split(",")[0].trim();
}

ip = headers.get("X-Real-IP");
ip = getHeaderIgnoreCase(headers, "X-Real-IP");
if (ip != null && !ip.isBlank()) {
return ip.trim();
}

return req.remoteIp();
}

/**
* Retrieves a header value using case-insensitive name matching.
*
* @param headers the header map to search
* @param name the header name
* @return the header value, or {@code null} if not found
*/
private @Nullable String getHeaderIgnoreCase(Map<String, String> headers, String name) {
for (Map.Entry<String, String> entry : headers.entrySet()) {
if (entry.getKey() != null && entry.getKey().equalsIgnoreCase(name)) {
return entry.getValue();
}
}
return null;
}

/**
* Sends a 403 Forbidden response to the client.
*
Expand Down
30 changes: 15 additions & 15 deletions src/test/java/org/juv25d/filter/IpFilterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void setUp() {
@Test
@DisplayName("Allow ip only in whitelist")
void whitelist_allowsIp() throws IOException {
IpFilter filter = new IpFilter(Set.of("127.0.0.1"), null, false, false);
IpFilter filter = new IpFilter(Set.of("127.0.0.1"), null, false);

filter.doFilter(req, res, chain);

Expand All @@ -45,7 +45,7 @@ void whitelist_allowsIp() throws IOException {
@Test
@DisplayName("Allow ip from CIDR range only in whitelist")
void whitelist_allowsIpInRange() throws IOException {
IpFilter filter = new IpFilter(Set.of("127.0.0.0/24"), null, false, false);
IpFilter filter = new IpFilter(Set.of("127.0.0.0/24"), null, false);

filter.doFilter(req, res, chain);

Expand All @@ -56,7 +56,7 @@ void whitelist_allowsIpInRange() throws IOException {
@Test
@DisplayName("Block ip only in blacklist")
void blacklist_blocksIp() throws IOException {
IpFilter filter = new IpFilter(null, Set.of("127.0.0.1"), true, false);
IpFilter filter = new IpFilter(null, Set.of("127.0.0.1"), true);

filter.doFilter(req, res, chain);
verify(chain, never()).doFilter(req, res);
Expand All @@ -68,7 +68,7 @@ void blacklist_blocksIp() throws IOException {
@Test
@DisplayName("Block ip from CIDR range only in blacklist")
void blacklist_blocksIpInRange() throws IOException {
IpFilter filter = new IpFilter(null, Set.of("127.0.0.0/24"), true, false);
IpFilter filter = new IpFilter(null, Set.of("127.0.0.0/24"), true);

filter.doFilter(req, res, chain);
verify(chain, never()).doFilter(req, res);
Expand All @@ -80,7 +80,7 @@ void blacklist_blocksIpInRange() throws IOException {
@Test
@DisplayName("Allow ip in both list (whitelist prio)")
void whitelist_overrides_blacklist() throws IOException {
IpFilter filter = new IpFilter(Set.of("127.0.0.1"), Set.of("127.0.0.0/24"), false, false);
IpFilter filter = new IpFilter(Set.of("127.0.0.1"), Set.of("127.0.0.0/24"), false);

filter.doFilter(req, res, chain);

Expand All @@ -91,8 +91,8 @@ void whitelist_overrides_blacklist() throws IOException {
@ParameterizedTest
@ValueSource(booleans = {true, false})
@DisplayName("Follow default when ip in neither list")
void Ip_inNeitherList_followsDefault(boolean allowByDefault) throws IOException {
IpFilter filter = new IpFilter(null, null, allowByDefault, false);
void ip_inNeitherList_followsDefault(boolean allowByDefault) throws IOException {
IpFilter filter = new IpFilter(null, null, allowByDefault);

filter.doFilter(req, res, chain);

Expand All @@ -110,7 +110,7 @@ void Ip_inNeitherList_followsDefault(boolean allowByDefault) throws IOException
@ValueSource(strings = {"127.0.0.1", "127.0.0.0/24"})
@DisplayName("Allow IP or CIDR range added in existing filter")
void addIpOrRange_whitelist(String ipOrCidr) throws IOException {
IpFilter filter = new IpFilter(null, null, false, false);
IpFilter filter = new IpFilter(null, null, false);
filter.addToWhitelist(ipOrCidr);

filter.doFilter(req, res, chain);
Expand All @@ -125,7 +125,7 @@ void addIpOrRange_whitelist(String ipOrCidr) throws IOException {
@ValueSource(strings = {"127.0.0.1", "127.0.0.0/24"})
@DisplayName("Block IP or CIDR range added in existing filter")
void addIpOrRange_blacklist(String ipOrCidr) throws IOException {
IpFilter filter = new IpFilter(null, null, false, false);
IpFilter filter = new IpFilter(null, null, false);
filter.addToBlacklist(ipOrCidr);

filter.doFilter(req, res, chain);
Expand All @@ -139,7 +139,7 @@ void addIpOrRange_blacklist(String ipOrCidr) throws IOException {
@Test
@DisplayName("Adding IP or CIDR range already in filter doesn't create duplicates")
void doesNotAddDuplicates() {
IpFilter filter = new IpFilter(Set.of("127.0.0.1", "127.0.0.0/24"), null, false, false);
IpFilter filter = new IpFilter(Set.of("127.0.0.1", "127.0.0.0/24"), null, false);

filter.addToWhitelist("127.0.0.1");
filter.addToWhitelist("127.0.0.0/24");
Expand All @@ -152,7 +152,7 @@ void doesNotAddDuplicates() {
@ValueSource(strings = {"127.0.0.1", "127.0.0.0/24"})
@DisplayName("Fall back on blacklist/default after removing IP or CIDR range from whitelist")
void removeIpOrRange_whitelist(String ipOrCidr) throws IOException {
IpFilter filter = new IpFilter(Set.of(ipOrCidr), null, false, false);
IpFilter filter = new IpFilter(Set.of(ipOrCidr), null, false);
filter.removeFromWhitelist(ipOrCidr);

filter.doFilter(req, res, chain);
Expand All @@ -167,7 +167,7 @@ void removeIpOrRange_whitelist(String ipOrCidr) throws IOException {
@ValueSource(strings = {"127.0.0.1", "127.0.0.0/24"})
@DisplayName("Fall back on whitelist/default after removing IP or CIDR range from blacklist")
void removeIpOrRange_blacklist(String ipOrCidr) throws IOException {
IpFilter filter = new IpFilter(null, Set.of(ipOrCidr), true, false);
IpFilter filter = new IpFilter(null, Set.of(ipOrCidr), true);
filter.removeFromBlacklist(ipOrCidr);

filter.doFilter(req, res, chain);
Expand All @@ -182,7 +182,7 @@ void removeIpOrRange_blacklist(String ipOrCidr) throws IOException {
@NullAndEmptySource
@DisplayName("Block null or empty IP")
void nullOrBlankIp_blocked(String ip) throws IOException {
IpFilter filter = new IpFilter(null, null, true, false);
IpFilter filter = new IpFilter(null, null, true);

when(req.remoteIp()).thenReturn(ip);

Expand All @@ -195,7 +195,7 @@ void nullOrBlankIp_blocked(String ip) throws IOException {
@Test
@DisplayName("Ignore incorrectly formatted CIDR range")
void invalidCidr_loggedAndIgnored() {
IpFilter filter = new IpFilter(null, null, false, false);
IpFilter filter = new IpFilter(null, null, false);

filter.addToWhitelist("not-a-cidr/99");

Expand All @@ -205,7 +205,7 @@ void invalidCidr_loggedAndIgnored() {
@Test
@DisplayName("Get methods return immutable copies")
void get_returnsImmutableCopy() {
IpFilter filter = new IpFilter(null, null, false, false);
IpFilter filter = new IpFilter(null, null, false);

assertThrows(UnsupportedOperationException.class, () ->
filter.getWhitelistIps().add("test"));
Expand Down