From 590c6c352695bc16ca5e1209fd97be433f12be7e Mon Sep 17 00:00:00 2001 From: neerajbhatt Date: Sun, 24 May 2026 14:07:01 +0530 Subject: [PATCH] fix: accept header accessor function in ServerTransportSecurityValidator Replace Map> parameter with Function> in validateHeaders(), allowing callers to pass a header accessor instead of extracting all headers upfront. This is more efficient (only requested headers are looked up) and delegates case-insensitive header matching to the underlying request implementation (e.g. HttpServletRequest.getHeaders). - Update DefaultServerTransportSecurityValidator to use the accessor directly for Origin and Host headers - Update all three servlet transport providers to pass name -> Collections.list(request.getHeaders(name)) - Remove HttpServletRequestUtils (no longer needed) - Update unit tests to use accessor-based API Closes #870 --- ...faultServerTransportSecurityValidator.java | 33 ++-- .../transport/HttpServletRequestUtils.java | 40 ---- ...HttpServletSseServerTransportProvider.java | 7 +- .../HttpServletStatelessServerTransport.java | 5 +- ...vletStreamableServerTransportProvider.java | 10 +- .../ServerTransportSecurityValidator.java | 11 +- ...ServerTransportSecurityValidatorTests.java | 176 ++++++++++-------- 7 files changed, 121 insertions(+), 161 deletions(-) delete mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java index e96403e48..39c6bcacf 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java @@ -6,7 +6,7 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; +import java.util.function.Function; import io.modelcontextprotocol.util.Assert; @@ -47,27 +47,18 @@ private DefaultServerTransportSecurityValidator(List allowedOrigins, Lis } @Override - public void validateHeaders(Map> headers) throws ServerTransportSecurityException { - boolean missingHost = true; - for (Map.Entry> entry : headers.entrySet()) { - if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) { - List values = entry.getValue(); - if (values == null || values.isEmpty()) { - throw new ServerTransportSecurityException(403, "Invalid Origin header"); - } - validateOrigin(values.get(0)); - } - else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) { - missingHost = false; - List values = entry.getValue(); - if (values == null || values.isEmpty()) { - throw new ServerTransportSecurityException(421, "Invalid Host header"); - } - validateHost(values.get(0)); - } + public void validateHeaders(Function> headerAccessor) throws ServerTransportSecurityException { + List originValues = headerAccessor.apply(ORIGIN_HEADER); + if (originValues != null && !originValues.isEmpty()) { + validateOrigin(originValues.get(0)); } - if (!allowedHosts.isEmpty() && missingHost) { - throw new ServerTransportSecurityException(421, "Invalid Host header"); + + if (!allowedHosts.isEmpty()) { + List hostValues = headerAccessor.apply(HOST_HEADER); + if (hostValues == null || hostValues.isEmpty()) { + throw new ServerTransportSecurityException(421, "Invalid Host header"); + } + validateHost(hostValues.get(0)); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java deleted file mode 100644 index 32246948c..000000000 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2026-2026 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.util.Collections; -import java.util.Enumeration; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import jakarta.servlet.http.HttpServletRequest; - -/** - * Utility methods for working with {@link HttpServletRequest}. For internal use only. - * - * @author Daniel Garnier-Moiroux - */ -final class HttpServletRequestUtils { - - private HttpServletRequestUtils() { - } - - /** - * Extracts all headers from the HTTP request into a map. - * @param request The HTTP servlet request - * @return A map of header names to their values - */ - static Map> extractHeaders(HttpServletRequest request) { - Map> headers = new HashMap<>(); - Enumeration names = request.getHeaderNames(); - while (names.hasMoreElements()) { - String name = names.nextElement(); - headers.put(name, Collections.list(request.getHeaders(name))); - } - return headers; - } - -} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index 69d73f7ab..0ac16971e 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -8,6 +8,7 @@ import java.io.IOException; import java.io.PrintWriter; import java.time.Duration; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.UUID; @@ -280,8 +281,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name))); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); @@ -353,8 +353,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name))); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java index 047aeebe8..245d1585e 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -7,8 +7,8 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.PrintWriter; +import java.util.Collections; import java.util.List; -import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -134,8 +134,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name))); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 9a785e150..4bfa68643 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -9,6 +9,7 @@ import java.io.PrintWriter; import java.time.Duration; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -271,8 +272,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name))); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); @@ -407,8 +407,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name))); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); @@ -588,8 +587,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name))); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java index ce805931f..b24e1b88f 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java @@ -5,7 +5,7 @@ package io.modelcontextprotocol.server.transport; import java.util.List; -import java.util.Map; +import java.util.function.Function; /** * Interface for validating HTTP requests in server transports. Implementations can @@ -22,15 +22,16 @@ public interface ServerTransportSecurityValidator { /** * A no-op validator that accepts all requests without validation. */ - ServerTransportSecurityValidator NOOP = headers -> { + ServerTransportSecurityValidator NOOP = headerAccessor -> { }; /** * Validates the HTTP headers from an incoming request. - * @param headers A map of header names to their values (multi-valued headers - * supported) + * @param headerAccessor A function that returns the list of values for a given header + * name, or an empty list if the header is not present. Header name lookup should be + * case-insensitive. * @throws ServerTransportSecurityException if validation fails */ - void validateHeaders(Map> headers) throws ServerTransportSecurityException; + void validateHeaders(Function> headerAccessor) throws ServerTransportSecurityException; } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java index d4cf8582d..38755febf 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java @@ -7,6 +7,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -43,50 +44,50 @@ class OriginHeader { @Test void originHeaderMissing() { - assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(emptyAccessor())).doesNotThrowAnyException(); } @Test void originHeaderListEmpty() { - assertThatThrownBy(() -> validator.validateHeaders(Map.of("Origin", List.of()))).isEqualTo(INVALID_ORIGIN); + assertThatCode(() -> validator.validateHeaders(name -> List.of())).doesNotThrowAnyException(); } @Test void caseInsensitive() { - var headers = Map.of("origin", List.of("http://localhost:8080")); + var accessor = headerAccessor("Origin", "http://localhost:8080"); - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void exactMatch() { - var headers = originHeader("http://localhost:8080"); + var accessor = originAccessor("http://localhost:8080"); - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentPort() { - var headers = originHeader("http://localhost:3000"); + var accessor = originAccessor("http://localhost:3000"); - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> validator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void differentHost() { - var headers = originHeader("http://example.com:8080"); + var accessor = originAccessor("http://example.com:8080"); - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> validator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void differentScheme() { - var headers = originHeader("https://localhost:8080"); + var accessor = originAccessor("https://localhost:8080"); - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> validator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Nested @@ -99,37 +100,37 @@ class WildcardPort { @Test void anyPortWithWildcard() { - var headers = originHeader("http://localhost:3000"); + var accessor = originAccessor("http://localhost:3000"); - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void noPortWithWildcard() { - var headers = originHeader("http://localhost"); + var accessor = originAccessor("http://localhost"); - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentPortWithWildcard() { - var headers = originHeader("http://localhost:8080"); + var accessor = originAccessor("http://localhost:8080"); - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentHostWithWildcard() { - var headers = originHeader("http://example.com:3000"); + var accessor = originAccessor("http://example.com:3000"); - assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> wildcardValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void differentSchemeWithWildcard() { - var headers = originHeader("https://localhost:3000"); + var accessor = originAccessor("https://localhost:3000"); - assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> wildcardValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } } @@ -146,23 +147,23 @@ class MultipleOrigins { @Test void matchingOneOfMultiple() { - var headers = originHeader("http://example.com:3000"); + var accessor = originAccessor("http://example.com:3000"); - assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleOriginsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void matchingWildcardInMultiple() { - var headers = originHeader("http://myapp.example.com:9999"); + var accessor = originAccessor("http://myapp.example.com:9999"); - assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleOriginsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void notMatchingAny() { - var headers = originHeader("http://malicious.example.com:1234"); + var accessor = originAccessor("http://malicious.example.com:1234"); - assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } } @@ -176,9 +177,9 @@ void shouldAddMultipleOriginsWithAllowedOriginsMethod() { .allowedOrigins(List.of("http://localhost:8080", "http://example.com:*")) .build(); - var headers = originHeader("http://example.com:3000"); + var accessor = originAccessor("http://example.com:3000"); - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test @@ -188,11 +189,11 @@ void shouldCombineAllowedOriginMethods() { .allowedOrigins(List.of("http://example.com:*", "http://test.com:3000")) .build(); - assertThatCode(() -> validator.validateHeaders(originHeader("http://localhost:8080"))) + assertThatCode(() -> validator.validateHeaders(originAccessor("http://localhost:8080"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(originHeader("http://example.com:9999"))) + assertThatCode(() -> validator.validateHeaders(originAccessor("http://example.com:9999"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(originHeader("http://test.com:3000"))) + assertThatCode(() -> validator.validateHeaders(originAccessor("http://test.com:3000"))) .doesNotThrowAnyException(); } @@ -210,45 +211,45 @@ class HostHeader { @Test void notConfigured() { - assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(emptyAccessor())).doesNotThrowAnyException(); } @Test void missing() { - assertThatThrownBy(() -> hostValidator.validateHeaders(new HashMap<>())).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(emptyAccessor())).isEqualTo(INVALID_HOST); } @Test void listEmpty() { - assertThatThrownBy(() -> hostValidator.validateHeaders(Map.of("Host", List.of()))).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(name -> List.of())).isEqualTo(INVALID_HOST); } @Test void caseInsensitive() { - var headers = Map.of("host", List.of("localhost:8080")); + var accessor = headerAccessor("Host", "localhost:8080"); - assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> hostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void exactMatch() { - var headers = hostHeader("localhost:8080"); + var accessor = hostAccessor("localhost:8080"); - assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> hostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentPort() { - var headers = hostHeader("localhost:3000"); + var accessor = hostAccessor("localhost:3000"); - assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Test void differentHost() { - var headers = hostHeader("example.com:8080"); + var accessor = hostAccessor("example.com:8080"); - assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Nested @@ -261,23 +262,23 @@ class HostWildcardPort { @Test void anyPort() { - var headers = hostHeader("localhost:3000"); + var accessor = hostAccessor("localhost:3000"); - assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardHostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void noPort() { - var headers = hostHeader("localhost"); + var accessor = hostAccessor("localhost"); - assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardHostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentHost() { - var headers = hostHeader("example.com:3000"); + var accessor = hostAccessor("example.com:3000"); - assertThatThrownBy(() -> wildcardHostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> wildcardHostValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } } @@ -293,30 +294,30 @@ class MultipleHosts { @Test void exactMatch() { - var headers = hostHeader("example.com:3000"); + var accessor = hostAccessor("example.com:3000"); - assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleHostsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void wildcard() { - var headers = hostHeader("myapp.example.com:9999"); + var accessor = hostAccessor("myapp.example.com:9999"); - assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleHostsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentHost() { - var headers = hostHeader("malicious.example.com:3000"); + var accessor = hostAccessor("malicious.example.com:3000"); - assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Test void differentPort() { - var headers = hostHeader("localhost:8080"); + var accessor = hostAccessor("localhost:8080"); - assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } } @@ -330,9 +331,9 @@ void multipleHosts() { .allowedHosts(List.of("localhost:8080", "example.com:*")) .build(); - assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:3000"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("example.com:3000"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("localhost:8080"))) .doesNotThrowAnyException(); } @@ -343,11 +344,12 @@ void combined() { .allowedHosts(List.of("example.com:*", "test.com:3000")) .build(); - assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("localhost:8080"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:9999"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("example.com:9999"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(hostAccessor("test.com:3000"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(hostHeader("test.com:3000"))).doesNotThrowAnyException(); } } @@ -365,60 +367,70 @@ class CombinedOriginAndHostValidation { @Test void bothValid() { - var header = headers("http://localhost:8080", "localhost:8080"); + var accessor = combinedAccessor("http://localhost:8080", "localhost:8080"); - assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException(); + assertThatCode(() -> combinedValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void originValidHostInvalid() { - var header = headers("http://localhost:8080", "malicious.example.com:8080"); + var accessor = combinedAccessor("http://localhost:8080", "malicious.example.com:8080"); - assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> combinedValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Test void originInvalidHostValid() { - var header = headers("http://malicious.example.com:8080", "localhost:8080"); + var accessor = combinedAccessor("http://malicious.example.com:8080", "localhost:8080"); - assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> combinedValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void originMissingHostValid() { // Origin missing is OK (same-origin request) - var header = headers(null, "localhost:8080"); + var accessor = combinedAccessor(null, "localhost:8080"); - assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException(); + assertThatCode(() -> combinedValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void originValidHostMissing() { // Host missing is NOT OK when allowedHosts is configured - var header = headers("http://localhost:8080", null); + var accessor = combinedAccessor("http://localhost:8080", null); - assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> combinedValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } } - private static Map> originHeader(String origin) { - return Map.of("Origin", List.of(origin)); + private static Function> emptyAccessor() { + return name -> List.of(); + } + + private static Function> headerAccessor(String headerName, String value) { + Map> headers = new HashMap<>(); + headers.put(headerName, List.of(value)); + return name -> headers.getOrDefault(name, List.of()); + } + + private static Function> originAccessor(String origin) { + return headerAccessor("Origin", origin); } - private static Map> hostHeader(String host) { - return Map.of("Host", List.of(host)); + private static Function> hostAccessor(String host) { + return headerAccessor("Host", host); } - private static Map> headers(String origin, String host) { - var map = new HashMap>(); + private static Function> combinedAccessor(String origin, String host) { + Map> headers = new HashMap<>(); if (origin != null) { - map.put("Origin", List.of(origin)); + headers.put("Origin", List.of(origin)); } if (host != null) { - map.put("Host", List.of(host)); + headers.put("Host", List.of(host)); } - return map; + return name -> headers.getOrDefault(name, List.of()); } }