Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import io.modelcontextprotocol.util.Assert;

Expand Down Expand Up @@ -47,27 +47,18 @@ private DefaultServerTransportSecurityValidator(List<String> allowedOrigins, Lis
}

@Override
public void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
boolean missingHost = true;
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) {
List<String> values = entry.getValue();
if (values == null || values.isEmpty()) {
throw new ServerTransportSecurityException(403, "Invalid Origin header");
}
validateOrigin(values.get(0));
}
else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) {
missingHost = false;
List<String> values = entry.getValue();
if (values == null || values.isEmpty()) {
throw new ServerTransportSecurityException(421, "Invalid Host header");
}
validateHost(values.get(0));
}
public void validateHeaders(Function<String, List<String>> headerAccessor) throws ServerTransportSecurityException {
List<String> originValues = headerAccessor.apply(ORIGIN_HEADER);
if (originValues != null && !originValues.isEmpty()) {
validateOrigin(originValues.get(0));
}
if (!allowedHosts.isEmpty() && missingHost) {
throw new ServerTransportSecurityException(421, "Invalid Host header");

if (!allowedHosts.isEmpty()) {
List<String> hostValues = headerAccessor.apply(HOST_HEADER);
if (hostValues == null || hostValues.isEmpty()) {
throw new ServerTransportSecurityException(421, "Invalid Host header");
}
validateHost(hostValues.get(0));
}
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.io.PrintWriter;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
Expand Down Expand Up @@ -280,8 +281,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down Expand Up @@ -353,8 +353,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -134,8 +134,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.io.PrintWriter;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -271,8 +272,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down Expand Up @@ -407,8 +407,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down Expand Up @@ -588,8 +587,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
package io.modelcontextprotocol.server.transport;

import java.util.List;
import java.util.Map;
import java.util.function.Function;

/**
* Interface for validating HTTP requests in server transports. Implementations can
Expand All @@ -22,15 +22,16 @@ public interface ServerTransportSecurityValidator {
/**
* A no-op validator that accepts all requests without validation.
*/
ServerTransportSecurityValidator NOOP = headers -> {
ServerTransportSecurityValidator NOOP = headerAccessor -> {
};

/**
* Validates the HTTP headers from an incoming request.
* @param headers A map of header names to their values (multi-valued headers
* supported)
* @param headerAccessor A function that returns the list of values for a given header
* name, or an empty list if the header is not present. Header name lookup should be
* case-insensitive.
* @throws ServerTransportSecurityException if validation fails
*/
void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException;
void validateHeaders(Function<String, List<String>> headerAccessor) throws ServerTransportSecurityException;

}
Loading