From 83613492b17fb58243d4fc3a8dcb66a1bc328ba5 Mon Sep 17 00:00:00 2001 From: Ilya Kazakov Date: Fri, 12 Jun 2026 11:45:20 +0200 Subject: [PATCH 1/2] client side implementation --- README.md | 15 +- pom.xml | 7 + .../connectjava/api/ConnectCallExchange.java | 3 +- .../api/ConnectClientResponseStart.java | 10 + .../connectjava/api/ConnectEndOfStream.java | 22 +- .../connectjava/api/ConnectMessage.java | 1 + .../api/ConnectResponseHeadersBuilder.java | 4 +- .../connectjava/api/ConnectResponseMeta.java | 35 + .../api/ConnectResponseTrailersBuilder.java | 4 +- .../ConnectCompressionNegotiation.java | 12 +- .../connectjava/protocol/ConnectEnvelope.java | 29 +- .../protocol/ConnectInterceptorPipeline.java | 93 --- .../protocol/ConnectMediaType.java | 10 +- .../protocol/ConnectProtocolHttpHeaders.java | 13 + .../protocol/ConnectProtocolVersion.java | 8 +- .../protocol/ConnectTransport.java | 10 - .../protocol/client/ClientHandlerSupport.java | 165 +++++ .../ConnectCallTerminatedException.java | 2 +- .../client/ConnectClientCallDispatcher.java | 116 +++ .../client/ConnectClientCallObserver.java | 67 ++ .../client/ConnectClientCallStart.java | 97 +++ .../ConnectClientChannelConfigurer.java | 19 + .../client/ConnectClientInterceptor.java | 70 ++ .../ConnectClientInterceptorPipeline.java | 95 +++ .../client/ConnectClientPipeline.java | 12 + .../client/ConnectClientProtocol.java | 13 + .../client/ConnectClientProtocolConfig.java | 64 ++ .../ConnectClientProtocolParameters.java | 11 + .../protocol/client/ConnectErrorBody.java | 17 + .../client/ConnectJsonDeserializer.java | 70 ++ .../ConnectStringBuilderJsonDeserializer.java | 315 ++++++++ .../client/StreamingClientHandler.java | 415 +++++++++++ .../client/UnaryGetRequestClientHandler.java | 147 ++++ .../client/UnaryPostRequestClientHandler.java | 145 ++++ .../client/UnaryResponseClientHandler.java | 186 +++++ .../protocol/client/package-info.java | 4 + .../{ => server}/ConnectCorsParameters.java | 4 +- .../{ => server}/ConnectEndStreamMeta.java | 2 +- .../ConnectEndStreamResponse.java | 2 +- .../{ => server}/ConnectJsonSerializer.java | 4 +- .../{ => server}/ConnectMetaBuilder.java | 2 +- .../protocol/{ => server}/ConnectRoute.java | 2 +- .../ConnectServerCallObserver.java} | 12 +- .../ConnectServerChannelConfigurer.java} | 16 +- .../ConnectServerInterceptor.java} | 28 +- .../ConnectServerInterceptorPipeline.java | 93 +++ .../ConnectServerPipeline.java} | 14 +- .../ConnectServerProtocol.java} | 34 +- .../ConnectServerProtocolConfig.java} | 33 +- .../ConnectServerProtocolParameters.java} | 10 +- .../ConnectStringBuilderJsonSerializer.java | 2 +- .../protocol/server/ConnectTransport.java | 10 + .../protocol/{ => server}/HttpResponses.java | 2 +- .../{ => server}/ResponseHeadersBuilder.java | 2 +- .../{ => server}/ResponseTrailersBuilder.java | 2 +- .../RoutingServerHandler.java} | 58 +- .../StreamingServerHandler.java} | 23 +- .../UnaryGetRequestServerHandler.java} | 29 +- .../UnaryPostRequestServerHandler.java} | 31 +- ...UnaryResponseProcessingServerHandler.java} | 12 +- .../protocol/server/package-info.java | 4 + .../protocol/ClientTestSupport.java | 169 +++++ .../ConnectClientProtocolIntegrationTest.java | 305 ++++++++ ... => ConnectServerProtocolVersionTest.java} | 2 +- .../ConnectClientCallDispatcherTest.java | 250 +++++++ .../client/ConnectClientCallStartTest.java | 90 +++ .../ConnectClientChannelConfigurerTest.java | 64 ++ .../ConnectClientInterceptorPipelineTest.java | 226 ++++++ .../ConnectClientProtocolConfigTest.java | 62 ++ ...nectStringBuilderJsonDeserializerTest.java | 182 +++++ .../client/StreamingClientHandlerTest.java | 694 ++++++++++++++++++ .../UnaryGetRequestClientHandlerTest.java | 183 +++++ .../UnaryPostRequestClientHandlerTest.java | 208 ++++++ .../UnaryResponseClientHandlerTest.java | 351 +++++++++ .../ConnectEndStreamMetaBuilderTest.java | 2 +- .../ConnectEndStreamMetaTest.java | 2 +- .../{ => server}/ConnectMetaBuilderTest.java | 2 +- .../{ => server}/ConnectRouteTest.java | 2 +- .../ConnectServerChannelConfigurerTest.java} | 24 +- ...ConnectServerInterceptorPipelineTest.java} | 104 +-- ...onnectStringBuilderJsonSerializerTest.java | 2 +- .../protocol/{ => server}/HttpAssertions.java | 2 +- .../ResponseHeadersBuilderTest.java | 2 +- .../ResponseTrailersBuilderTest.java | 2 +- .../RoutingServerHandlerTest.java} | 70 +- .../StreamingServerHandlerTest.java} | 38 +- .../UnaryGetRequestServerHandlerTest.java} | 24 +- .../UnaryPostRequestServerHandlerTest.java} | 26 +- ...yResponseProcessingServerHandlerTest.java} | 31 +- 89 files changed, 5393 insertions(+), 463 deletions(-) create mode 100644 src/main/java/io/suboptimal/connectjava/api/ConnectClientResponseStart.java create mode 100644 src/main/java/io/suboptimal/connectjava/api/ConnectResponseMeta.java delete mode 100644 src/main/java/io/suboptimal/connectjava/protocol/ConnectInterceptorPipeline.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolHttpHeaders.java delete mode 100644 src/main/java/io/suboptimal/connectjava/protocol/ConnectTransport.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ClientHandlerSupport.java rename src/main/java/io/suboptimal/connectjava/protocol/{ => client}/ConnectCallTerminatedException.java (95%) create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallDispatcher.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallObserver.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallStart.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientChannelConfigurer.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptor.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipeline.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientPipeline.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocol.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolConfig.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolParameters.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectErrorBody.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectJsonDeserializer.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/ConnectStringBuilderJsonDeserializer.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandler.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/UnaryGetRequestClientHandler.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/UnaryPostRequestClientHandler.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandler.java create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/client/package-info.java rename src/main/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectCorsParameters.java (97%) rename src/main/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectEndStreamMeta.java (97%) rename src/main/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectEndStreamResponse.java (90%) rename src/main/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectJsonSerializer.java (91%) rename src/main/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectMetaBuilder.java (94%) rename src/main/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectRoute.java (94%) rename src/main/java/io/suboptimal/connectjava/protocol/{ConnectCallObserver.java => server/ConnectServerCallObserver.java} (85%) rename src/main/java/io/suboptimal/connectjava/protocol/{ConnectChannelConfigurer.java => server/ConnectServerChannelConfigurer.java} (78%) rename src/main/java/io/suboptimal/connectjava/protocol/{ConnectInterceptor.java => server/ConnectServerInterceptor.java} (75%) create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerInterceptorPipeline.java rename src/main/java/io/suboptimal/connectjava/protocol/{ConnectPipeline.java => server/ConnectServerPipeline.java} (85%) rename src/main/java/io/suboptimal/connectjava/protocol/{ConnectProtocol.java => server/ConnectServerProtocol.java} (63%) rename src/main/java/io/suboptimal/connectjava/protocol/{ConnectProtocolConfig.java => server/ConnectServerProtocolConfig.java} (81%) rename src/main/java/io/suboptimal/connectjava/protocol/{ConnectProtocolParameters.java => server/ConnectServerProtocolParameters.java} (73%) rename src/main/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectStringBuilderJsonSerializer.java (98%) create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/server/ConnectTransport.java rename src/main/java/io/suboptimal/connectjava/protocol/{ => server}/HttpResponses.java (98%) rename src/main/java/io/suboptimal/connectjava/protocol/{ => server}/ResponseHeadersBuilder.java (96%) rename src/main/java/io/suboptimal/connectjava/protocol/{ => server}/ResponseTrailersBuilder.java (97%) rename src/main/java/io/suboptimal/connectjava/protocol/{RoutingHandler.java => server/RoutingServerHandler.java} (78%) rename src/main/java/io/suboptimal/connectjava/protocol/{StreamingHandler.java => server/StreamingServerHandler.java} (93%) rename src/main/java/io/suboptimal/connectjava/protocol/{UnaryGetRequestHandler.java => server/UnaryGetRequestServerHandler.java} (88%) rename src/main/java/io/suboptimal/connectjava/protocol/{UnaryPostRequestHandler.java => server/UnaryPostRequestServerHandler.java} (79%) rename src/main/java/io/suboptimal/connectjava/protocol/{UnaryResponseProcessingHandler.java => server/UnaryResponseProcessingServerHandler.java} (95%) create mode 100644 src/main/java/io/suboptimal/connectjava/protocol/server/package-info.java create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/ClientTestSupport.java create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/ConnectClientProtocolIntegrationTest.java rename src/test/java/io/suboptimal/connectjava/protocol/{ConnectProtocolVersionTest.java => ConnectServerProtocolVersionTest.java} (96%) create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallDispatcherTest.java create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallStartTest.java create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientChannelConfigurerTest.java create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipelineTest.java create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolConfigTest.java create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/client/ConnectStringBuilderJsonDeserializerTest.java create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandlerTest.java create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/client/UnaryGetRequestClientHandlerTest.java create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/client/UnaryPostRequestClientHandlerTest.java create mode 100644 src/test/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandlerTest.java rename src/test/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectEndStreamMetaBuilderTest.java (95%) rename src/test/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectEndStreamMetaTest.java (95%) rename src/test/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectMetaBuilderTest.java (97%) rename src/test/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectRouteTest.java (97%) rename src/test/java/io/suboptimal/connectjava/protocol/{ConnectChannelConfigurerTest.java => server/ConnectServerChannelConfigurerTest.java} (85%) rename src/test/java/io/suboptimal/connectjava/protocol/{ConnectInterceptorPipelineTest.java => server/ConnectServerInterceptorPipelineTest.java} (50%) rename src/test/java/io/suboptimal/connectjava/protocol/{ => server}/ConnectStringBuilderJsonSerializerTest.java (99%) rename src/test/java/io/suboptimal/connectjava/protocol/{ => server}/HttpAssertions.java (97%) rename src/test/java/io/suboptimal/connectjava/protocol/{ => server}/ResponseHeadersBuilderTest.java (94%) rename src/test/java/io/suboptimal/connectjava/protocol/{ => server}/ResponseTrailersBuilderTest.java (96%) rename src/test/java/io/suboptimal/connectjava/protocol/{RoutingHandlerTest.java => server/RoutingServerHandlerTest.java} (86%) rename src/test/java/io/suboptimal/connectjava/protocol/{StreamingHandlerTest.java => server/StreamingServerHandlerTest.java} (93%) rename src/test/java/io/suboptimal/connectjava/protocol/{UnaryGetRequestHandlerTest.java => server/UnaryGetRequestServerHandlerTest.java} (97%) rename src/test/java/io/suboptimal/connectjava/protocol/{UnaryPostRequestHandlerTest.java => server/UnaryPostRequestServerHandlerTest.java} (93%) rename src/test/java/io/suboptimal/connectjava/protocol/{UnaryResponseProcessingHandlerTest.java => server/UnaryResponseProcessingServerHandlerTest.java} (95%) diff --git a/README.md b/README.md index 04177c4..c12e48c 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,10 @@ import io.netty.handler.codec.http.HttpServerCodec; import io.suboptimal.connectjava.codec.protobuf.ConnectProtobufCodecs; import io.suboptimal.connectjava.model.*; -import io.suboptimal.connectjava.protocol.*; +import io.suboptimal.connectjava.protocol.server.ConnectCorsParameters; +import io.suboptimal.connectjava.protocol.server.ConnectServerProtocol; +import io.suboptimal.connectjava.protocol.server.ConnectServerProtocolConfig; +import io.suboptimal.connectjava.protocol.server.ConnectServerProtocolParameters; import java.util.Map; @@ -115,18 +118,18 @@ ConnectServiceDefinition greeter = new ConnectServiceDefinition( /* idempotent — also reachable via Unary-GET */ true)), /* optional descriptor for introspection */ null); -ConnectProtocolConfig config = ConnectProtocolConfig +ConnectServerProtocolConfig config = ConnectServerProtocolConfig .builder( Map.of(greeter.serviceName(), greeter), GreeterCallHandler::new, // ConnectCallHandlerFactory - new ConnectProtocolParameters( + new ConnectServerProtocolParameters( /* maxRequestBytes */ 4 * 1024 * 1024, /* maxFrameBytes */ 1 * 1024 * 1024, ConnectCorsParameters.disabled()), ConnectProtobufCodecs.defaults()) // proto + proto-json codecs .build(); -ConnectProtocol protocol = new ConnectProtocol(config); +ConnectServerProtocol protocol = new ConnectServerProtocol(config); ChannelInitializer http1Initializer = new ChannelInitializer<>() { @Override @@ -347,13 +350,13 @@ serves Connect alongside any other HTTP/1.1 or HTTP/2 protocol you implement, with ALPN, H2C prior-knowledge, and H2C upgrade negotiation done for you: ```java -import io.suboptimal.connectjava.protocol.ConnectProtocol; +import io.suboptimal.connectjava.protocol.server.ConnectServerProtocol; import io.suboptimal.nettymultiprotocol.AppChannelConfigurer; import io.suboptimal.nettymultiprotocol.AppProtocol; import io.suboptimal.nettymultiprotocol.AppProtocolRegistry; import io.suboptimal.nettymultiprotocol.NettyMultiprotocol; -ConnectProtocol connect = new ConnectProtocol(connectConfig); +io.suboptimal.connectjava.protocol.server.ConnectServerProtocol connect = new ConnectServerProtocol(connectConfig); AppProtocol connectAsApp = new AppProtocol() { @Override public AppChannelConfigurer http1() { return connect.http1()::configure; } diff --git a/pom.xml b/pom.xml index eac1468..4f9d3ce 100644 --- a/pom.xml +++ b/pom.xml @@ -46,6 +46,7 @@ 3.15.0 3.2.5 3.1.2 + 24.1.0 0.7.0 1.6.0 @@ -88,6 +89,12 @@ ${jspecify.version} + + org.jetbrains + annotations + ${jetbrains-annotations.version} + + com.google.protobuf protobuf-java diff --git a/src/main/java/io/suboptimal/connectjava/api/ConnectCallExchange.java b/src/main/java/io/suboptimal/connectjava/api/ConnectCallExchange.java index 31c1e3d..49e3353 100644 --- a/src/main/java/io/suboptimal/connectjava/api/ConnectCallExchange.java +++ b/src/main/java/io/suboptimal/connectjava/api/ConnectCallExchange.java @@ -2,9 +2,10 @@ import io.suboptimal.connectjava.model.ConnectMethodDefinition; import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import io.suboptimal.connectjava.protocol.server.ConnectServerInterceptor; /** - * Immutable view of a Connect RPC call passed to each {@link io.suboptimal.connectjava.protocol.ConnectInterceptor}. + * Immutable view of a Connect RPC call passed to each {@link ConnectServerInterceptor}. * *

{@link #responseHeadersBuilder()} accepts mutations until the first response payload (or the * terminal response for unary calls) is written; after that point any mutation throws diff --git a/src/main/java/io/suboptimal/connectjava/api/ConnectClientResponseStart.java b/src/main/java/io/suboptimal/connectjava/api/ConnectClientResponseStart.java new file mode 100644 index 0000000..bb448e2 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/api/ConnectClientResponseStart.java @@ -0,0 +1,10 @@ +package io.suboptimal.connectjava.api; + +import io.suboptimal.connectjava.model.ConnectMethodDefinition; +import io.suboptimal.connectjava.model.ConnectServiceDefinition; + +public record ConnectClientResponseStart ( + ConnectServiceDefinition serviceDefinition, + ConnectMethodDefinition methodDefinition, + ConnectResponseMeta responseMeta +) implements ConnectMessage {} diff --git a/src/main/java/io/suboptimal/connectjava/api/ConnectEndOfStream.java b/src/main/java/io/suboptimal/connectjava/api/ConnectEndOfStream.java index 5577c87..7180b63 100644 --- a/src/main/java/io/suboptimal/connectjava/api/ConnectEndOfStream.java +++ b/src/main/java/io/suboptimal/connectjava/api/ConnectEndOfStream.java @@ -1,8 +1,24 @@ package io.suboptimal.connectjava.api; +import org.jspecify.annotations.Nullable; + +import java.util.List; +import java.util.Map; + /** - * Terminal signal indicating the successful end of one side of an RPC payload stream. + * Terminal signal indicating the end of one side of an RPC payload stream. + * + *

For streaming calls the {@code trailers} field carries trailing metadata from the + * end-stream envelope, and {@code error} is non-null when the end-stream envelope carried + * an error. A consumer must treat this message as terminal in both cases: a non-null + * {@code error} means the call failed, and the trailers are still available. + * For unary calls use {@link #INSTANCE} (no trailers, no error). */ -public record ConnectEndOfStream() implements ConnectMessage { - public static final ConnectEndOfStream INSTANCE = new ConnectEndOfStream(); +public record ConnectEndOfStream(Map> trailers, @Nullable ConnectError error) + implements ConnectMessage { + public static final ConnectEndOfStream INSTANCE = new ConnectEndOfStream(Map.of(), null); + + public ConnectEndOfStream(Map> trailers) { + this(trailers, null); + } } diff --git a/src/main/java/io/suboptimal/connectjava/api/ConnectMessage.java b/src/main/java/io/suboptimal/connectjava/api/ConnectMessage.java index f278700..3a04885 100644 --- a/src/main/java/io/suboptimal/connectjava/api/ConnectMessage.java +++ b/src/main/java/io/suboptimal/connectjava/api/ConnectMessage.java @@ -2,6 +2,7 @@ public sealed interface ConnectMessage permits ConnectCallExchange, + ConnectClientResponseStart, ConnectPayload, ConnectEndOfStream, ConnectError diff --git a/src/main/java/io/suboptimal/connectjava/api/ConnectResponseHeadersBuilder.java b/src/main/java/io/suboptimal/connectjava/api/ConnectResponseHeadersBuilder.java index 513885c..fcdc436 100644 --- a/src/main/java/io/suboptimal/connectjava/api/ConnectResponseHeadersBuilder.java +++ b/src/main/java/io/suboptimal/connectjava/api/ConnectResponseHeadersBuilder.java @@ -1,9 +1,9 @@ package io.suboptimal.connectjava.api; -import io.suboptimal.connectjava.protocol.ConnectCallObserver; +import io.suboptimal.connectjava.protocol.server.ConnectServerCallObserver; /** - * Mutable Connect response headers collected by {@link ConnectCallObserver}s. + * Mutable Connect response headers collected by {@link ConnectServerCallObserver}s. * *

Mutations are applied to the wire response after all header observers run. Operation * order is preserved, including the difference between {@link #set(CharSequence, CharSequence)} diff --git a/src/main/java/io/suboptimal/connectjava/api/ConnectResponseMeta.java b/src/main/java/io/suboptimal/connectjava/api/ConnectResponseMeta.java new file mode 100644 index 0000000..ff00f6b --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/api/ConnectResponseMeta.java @@ -0,0 +1,35 @@ +package io.suboptimal.connectjava.api; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Metadata from a Connect RPC response. + * + * @param statusCode HTTP status code + * @param headers leading metadata (HTTP response headers without {@code trailer-} prefix) + * @param trailers trailing metadata; for unary responses these are extracted from HTTP headers + * with the {@code trailer-} prefix stripped; for streaming responses they + * come from the end-stream envelope's {@code metadata} field via + * {@link ConnectEndOfStream#trailers()} + */ +public record ConnectResponseMeta( + int statusCode, + Map> headers, + Map> trailers +) { + public ConnectResponseMeta { + headers = copyLower(headers); + trailers = copyLower(trailers); + } + + private static Map> copyLower(Map> source) { + return source + .entrySet() + .stream() + .collect(Collectors.toUnmodifiableMap(e -> e.getKey().toLowerCase(Locale.ROOT), + Map.Entry::getValue)); + } +} diff --git a/src/main/java/io/suboptimal/connectjava/api/ConnectResponseTrailersBuilder.java b/src/main/java/io/suboptimal/connectjava/api/ConnectResponseTrailersBuilder.java index 1944845..d61d974 100644 --- a/src/main/java/io/suboptimal/connectjava/api/ConnectResponseTrailersBuilder.java +++ b/src/main/java/io/suboptimal/connectjava/api/ConnectResponseTrailersBuilder.java @@ -1,9 +1,9 @@ package io.suboptimal.connectjava.api; -import io.suboptimal.connectjava.protocol.ConnectCallObserver; +import io.suboptimal.connectjava.protocol.server.ConnectServerCallObserver; /** - * Mutable Connect response trailers collected by {@link ConnectCallObserver}s. + * Mutable Connect response trailers collected by {@link ConnectServerCallObserver}s. * *

For unary RPCs, trailers are serialized as {@code Trailer-*} response headers. For * streaming RPCs, trailers are serialized as the {@code metadata} object in the diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectCompressionNegotiation.java b/src/main/java/io/suboptimal/connectjava/protocol/ConnectCompressionNegotiation.java index fec17ff..7e73656 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectCompressionNegotiation.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/ConnectCompressionNegotiation.java @@ -5,6 +5,7 @@ import io.suboptimal.connectjava.compression.ConnectCompression; import io.suboptimal.connectjava.compression.ConnectCompressionRegistry; import io.suboptimal.connectjava.compression.ConnectIdentityCompression; +import org.jetbrains.annotations.ApiStatus; import org.jspecify.annotations.Nullable; import java.io.IOException; @@ -12,12 +13,13 @@ import java.util.List; import java.util.Locale; -final class ConnectCompressionNegotiation { +@ApiStatus.Internal +public final class ConnectCompressionNegotiation { private ConnectCompressionNegotiation() { } - static @Nullable String compressionNameFor(@Nullable CharSequence encoding) { + public static @Nullable String compressionNameFor(@Nullable CharSequence encoding) { if (encoding == null) { return null; } @@ -25,11 +27,11 @@ private ConnectCompressionNegotiation() { return name.isEmpty() ? null : name; } - static String formatSupportedEncodings(ConnectCompressionRegistry registry) { + public static String formatSupportedEncodings(ConnectCompressionRegistry registry) { return String.join(",", registry.supportedNames()); } - static ConnectCompression selectResponseEncoding( + public static ConnectCompression selectResponseEncoding( ConnectCompression requestEncoding, @Nullable String responseEncoding, ConnectCompressionRegistry registry) { if (!requestEncoding.isIdentity()) { @@ -100,7 +102,7 @@ private static double qFor(List codings, String name, double fallback) { return null; } - static ByteBuf decompressMessage(ByteBufAllocator alloc, ByteBuf body, ConnectCompression compression) + public static ByteBuf decompressMessage(ByteBufAllocator alloc, ByteBuf body, ConnectCompression compression) throws IOException { if (body.readableBytes() == 0) { diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectEnvelope.java b/src/main/java/io/suboptimal/connectjava/protocol/ConnectEnvelope.java index 20b8862..17b2914 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectEnvelope.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/ConnectEnvelope.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import org.jetbrains.annotations.ApiStatus; import org.jspecify.annotations.Nullable; /** @@ -9,22 +10,23 @@ * *

Encoding is stateless; decoding is stateful (maintains a buffer accumulator). */ -final class ConnectEnvelope { - static final byte FLAG_COMPRESSED = 0x01; - static final byte FLAG_END_STREAM = 0x02; +@ApiStatus.Internal +public final class ConnectEnvelope { + public static final byte FLAG_COMPRESSED = 0x01; + public static final byte FLAG_END_STREAM = 0x02; static final int HEADER_SIZE = 5; private ConnectEnvelope() {} - record DecodedFrame(byte flags, ByteBuf payload) {} + public record DecodedFrame(byte flags, ByteBuf payload) {} - static final class FrameTooLargeException extends RuntimeException { + public static final class FrameTooLargeException extends RuntimeException { FrameTooLargeException(String message) { super(message, null, false, false); } } - static ByteBuf encode(ByteBufAllocator alloc, byte flags, byte[] payload) { + public static ByteBuf encode(ByteBufAllocator alloc, byte flags, byte[] payload) { ByteBuf buf = alloc.buffer(HEADER_SIZE + payload.length); buf.writeByte(flags); buf.writeInt(payload.length); @@ -32,7 +34,7 @@ static ByteBuf encode(ByteBufAllocator alloc, byte flags, byte[] payload) { return buf; } - static ByteBuf encode(ByteBufAllocator alloc, byte flags, ByteBuf payload) { + public static ByteBuf encode(ByteBufAllocator alloc, byte flags, ByteBuf payload) { ByteBuf buf = alloc.buffer(HEADER_SIZE + payload.readableBytes()); buf.writeByte(flags); buf.writeInt(payload.readableBytes()); @@ -40,12 +42,12 @@ static ByteBuf encode(ByteBufAllocator alloc, byte flags, ByteBuf payload) { return buf; } - static final class Decoder { + public static final class Decoder { private final ByteBuf accumulator; private final int maxFrameBytes; private boolean closed; - Decoder(ByteBufAllocator alloc, int maxFrameBytes) { + public Decoder(ByteBufAllocator alloc, int maxFrameBytes) { this.accumulator = alloc.buffer(); this.maxFrameBytes = maxFrameBytes; } @@ -54,7 +56,7 @@ static final class Decoder { * Appends incoming bytes to the internal accumulator. No-op if the decoder is closed * or {@code buf} has no readable bytes. */ - void append(ByteBuf buf) { + public void append(ByteBuf buf) { if (closed) { return; } @@ -70,7 +72,8 @@ void append(ByteBuf buf) { * the declared payload length exceeds the configured maximum. Returns {@code null} if * the decoder is closed. */ - @Nullable DecodedFrame pollFrame() { + @Nullable + public DecodedFrame pollFrame() { if (closed) { return null; } @@ -94,11 +97,11 @@ void append(ByteBuf buf) { return new DecodedFrame(flags, payload); } - int readableBytes() { + public int readableBytes() { return closed ? 0 : accumulator.readableBytes(); } - void close() { + public void close() { if (!closed) { closed = true; accumulator.release(); diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectInterceptorPipeline.java b/src/main/java/io/suboptimal/connectjava/protocol/ConnectInterceptorPipeline.java deleted file mode 100644 index a24e984..0000000 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectInterceptorPipeline.java +++ /dev/null @@ -1,93 +0,0 @@ -package io.suboptimal.connectjava.protocol; - -import io.suboptimal.connectjava.api.ConnectCallExchange; -import io.suboptimal.connectjava.api.ConnectError; -import io.suboptimal.connectjava.api.ConnectResponseHeadersBuilder; -import io.suboptimal.connectjava.api.ConnectResponseTrailersBuilder; -import org.jspecify.annotations.Nullable; - -import java.util.ArrayList; -import java.util.List; - -/** - * Builds the per-call observer chain for registered Connect interceptors. - */ -final class ConnectInterceptorPipeline { - static final ConnectInterceptorPipeline EMPTY = new ConnectInterceptorPipeline(List.of()); - - private final List interceptors; - - ConnectInterceptorPipeline(List interceptors) { - this.interceptors = List.copyOf(interceptors); - } - - /** - * Runs all registered interceptors in order, building a composite observer from the - * {@link ConnectInterceptor.Decision.Continue} results. - * - *

If any interceptor returns {@link ConnectInterceptor.Decision.Reject}, iteration stops - * immediately and a {@link ConnectInterceptor.Decision.Reject} carrying the composite of all - * prior {@link ConnectInterceptor.Decision.Continue} observers (and the rejection error) is - * returned. {@link ConnectCallObserver#NOOP} observers are filtered out of the composite. - */ - ConnectInterceptor.Decision interceptCall(ConnectCallExchange exchange) { - if (interceptors.isEmpty()) { - return ConnectInterceptor.continueCall(); - } - - List observers = new ArrayList<>(interceptors.size()); - for (ConnectInterceptor interceptor : interceptors) { - switch (interceptor.interceptCall(exchange)) { - case ConnectInterceptor.Decision.Continue(ConnectCallObserver observer) -> observers.add(observer); - case ConnectInterceptor.Decision.Reject(ConnectCallObserver ignore, ConnectError error) -> { - return new ConnectInterceptor.Decision.Reject(composite(observers), error); - } - } - } - return ConnectInterceptor.continueWith(composite(observers)); - } - - private static ConnectCallObserver composite(List observers) { - List filtered = observers.stream() - .filter(o -> o != ConnectCallObserver.NOOP) - .toList(); - - if (filtered.isEmpty()) { - return ConnectCallObserver.NOOP; - } - if (filtered.size() == 1) { - return filtered.getFirst(); - } - return new CompositeConnectCallObserver(filtered); - } - - private record CompositeConnectCallObserver(List observers) - implements ConnectCallObserver - { - @Override - public void onResponseHeaders(ConnectResponseHeadersBuilder headers) { - for (int i = observers.size() - 1; i >= 0; i--) { - observers.get(i).onResponseHeaders(headers); - } - } - - @Override - public void onResponsePayload(Object payload) { - observers.forEach(observer -> observer.onResponsePayload(payload)); - } - - @Override - public void onResponseTrailers(ConnectResponseTrailersBuilder trailers, @Nullable ConnectError error) { - for (int i = observers.size() - 1; i >= 0; i--) { - observers.get(i).onResponseTrailers(trailers, error); - } - } - - @Override - public void onCallComplete(@Nullable ConnectError error) { - for (int i = observers.size() - 1; i >= 0; i--) { - observers.get(i).onCallComplete(error); - } - } - } -} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectMediaType.java b/src/main/java/io/suboptimal/connectjava/protocol/ConnectMediaType.java index 6de365a..104e74e 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectMediaType.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/ConnectMediaType.java @@ -2,15 +2,17 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpUtil; +import org.jetbrains.annotations.ApiStatus; import org.jspecify.annotations.Nullable; import java.util.Locale; -final class ConnectMediaType { +@ApiStatus.Internal +public final class ConnectMediaType { private ConnectMediaType() {} - static @Nullable String codecNameFor(HttpRequest request) { + public static @Nullable String codecNameFor(HttpRequest request) { CharSequence mimeTypeRaw = HttpUtil.getMimeType(request); String mimeType = mimeTypeRaw == null ? "" : mimeTypeRaw.toString(); @@ -21,7 +23,7 @@ private ConnectMediaType() {} }; } - static String unaryContentTypeFor(String codecName) { + public static String unaryContentTypeFor(String codecName) { return switch (codecName) { case "proto" -> "application/proto"; case "json" -> "application/json"; @@ -29,7 +31,7 @@ static String unaryContentTypeFor(String codecName) { }; } - static String streamingContentTypeFor(String codecName) { + public static String streamingContentTypeFor(String codecName) { return switch (codecName) { case "proto" -> "application/connect+proto"; case "json" -> "application/connect+json"; diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolHttpHeaders.java b/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolHttpHeaders.java new file mode 100644 index 0000000..f31a368 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolHttpHeaders.java @@ -0,0 +1,13 @@ +package io.suboptimal.connectjava.protocol; + +import org.jetbrains.annotations.ApiStatus; + +@ApiStatus.Internal +public final class ConnectProtocolHttpHeaders { + public static final CharSequence CONNECT_PROTOCOL_VERSION = "connect-protocol-version"; + public static final CharSequence CONNECT_CONTENT_ENCODING = "connect-content-encoding"; + public static final CharSequence CONNECT_ACCEPT_ENCODING = "connect-accept-encoding"; + public static final CharSequence CONNECT_TIMEOUT_MS = "connect-timeout-ms"; + + private ConnectProtocolHttpHeaders() {} +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolVersion.java b/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolVersion.java index 522da3d..f65c02d 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolVersion.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolVersion.java @@ -1,17 +1,19 @@ package io.suboptimal.connectjava.protocol; import io.netty.handler.codec.http.HttpHeaders; +import org.jetbrains.annotations.ApiStatus; import org.jspecify.annotations.Nullable; import java.util.List; -final class ConnectProtocolVersion { +@ApiStatus.Internal +public final class ConnectProtocolVersion { public static final String HEADER_VERSION = "1"; public static final String QUERY_VERSION = "v1"; private ConnectProtocolVersion() {} - static @Nullable String validate(HttpHeaders headers) { + public static @Nullable String validate(HttpHeaders headers) { String connectVersion = headers.get("connect-protocol-version"); if (connectVersion == null || HEADER_VERSION.equals(connectVersion)) { return null; @@ -19,7 +21,7 @@ private ConnectProtocolVersion() {} return formatError(connectVersion); } - static @Nullable String validate(@Nullable List queryParams) { + public static @Nullable String validate(@Nullable List queryParams) { if (queryParams == null || queryParams.isEmpty()) { return null; } diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectTransport.java b/src/main/java/io/suboptimal/connectjava/protocol/ConnectTransport.java deleted file mode 100644 index 119080a..0000000 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectTransport.java +++ /dev/null @@ -1,10 +0,0 @@ -package io.suboptimal.connectjava.protocol; - -/** - * HTTP transport version a {@link ConnectChannelConfigurer} and its - * {@link RoutingHandler} are wired for. - */ -enum ConnectTransport { - HTTP_1_1, - HTTP_2 -} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ClientHandlerSupport.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ClientHandlerSupport.java new file mode 100644 index 0000000..04d28b7 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ClientHandlerSupport.java @@ -0,0 +1,165 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; +import io.suboptimal.connectjava.api.ConnectErrorCode; +import io.suboptimal.connectjava.codec.ConnectCodec; +import io.suboptimal.connectjava.compression.ConnectCompression; +import io.suboptimal.connectjava.compression.ConnectIdentityCompression; +import io.suboptimal.connectjava.protocol.ConnectCompressionNegotiation; +import io.suboptimal.connectjava.protocol.ConnectProtocolHttpHeaders; +import org.jspecify.annotations.Nullable; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Shared helpers for the client-side Connect handlers. The unary (POST/GET) and streaming + * handlers all derive the request codec and encoding from the caller's headers, copy + * user headers while skipping protocol-managed ones, and convert Netty headers to a map. + */ +class ClientHandlerSupport { + /** Protocol-managed headers for unary requests; user values for these are ignored. */ + private static final Set UNARY_RESERVED_HEADERS = Set.of( + HttpHeaderNames.CONTENT_TYPE.toString(), + HttpHeaderNames.CONTENT_LENGTH.toString(), + ConnectProtocolHttpHeaders.CONNECT_PROTOCOL_VERSION.toString(), + HttpHeaderNames.CONTENT_ENCODING.toString(), + HttpHeaderNames.ACCEPT_ENCODING.toString()); + + /** Protocol-managed headers for streaming requests; user values for these are ignored. */ + private static final Set STREAMING_RESERVED_HEADERS = Set.of( + HttpHeaderNames.CONTENT_TYPE.toString(), + HttpHeaderNames.CONTENT_LENGTH.toString(), + ConnectProtocolHttpHeaders.CONNECT_PROTOCOL_VERSION.toString(), + HttpHeaderNames.CONTENT_ENCODING.toString(), + HttpHeaderNames.TRANSFER_ENCODING.toString(), + ConnectProtocolHttpHeaders.CONNECT_CONTENT_ENCODING.toString(), + ConnectProtocolHttpHeaders.CONNECT_ACCEPT_ENCODING.toString()); + + private ClientHandlerSupport() {} + + /** Maps a {@code Content-Type} value to a codec name, or {@code null} if unrecognized. */ + @Nullable + static String codecNameForContentType(@Nullable String contentType) { + if (contentType == null) { + return null; + } + int semicolonIdx = contentType.indexOf(';'); + String mimeType = semicolonIdx >= 0 + ? contentType.substring(0, semicolonIdx).trim() + : contentType.trim(); + return switch (mimeType.toLowerCase(Locale.ROOT)) { + case "application/proto", "application/connect+proto" -> "proto"; + case "application/json", "application/connect+json" -> "json"; + default -> null; + }; + } + + /** + * Selects the request codec by explicit name, falling back to the registry's preferred codec + * when {@code codecName} is {@code null} or not registered. + */ + static ConnectCodec selectRequestCodec(ConnectClientProtocolConfig config, + @Nullable String codecName) + { + if (codecName != null) { + ConnectCodec codec = config.codecRegistry().byName(codecName); + if (codec != null) { + return codec; + } + } + + return config.codecRegistry().preferred().getFirst(); + } + + /** + * Resolves the request compression from the caller's {@code content-encoding} header, + * defaulting to identity when absent or unsupported. + */ + static ConnectCompression selectRequestEncoding(ConnectClientProtocolConfig config, + Map> requestHeaders) + { + List values = requestHeaders.get(HttpHeaderNames.CONTENT_ENCODING.toString()); + if (values == null || values.isEmpty()) { + return ConnectIdentityCompression.INSTANCE; + } + + String name = ConnectCompressionNegotiation.compressionNameFor(values.getFirst()); + if (name == null) { + return ConnectIdentityCompression.INSTANCE; + } + + ConnectCompression compression = config.compressionRegistry().resolve(name); + return compression != null ? compression : ConnectIdentityCompression.INSTANCE; + } + + static void copyUserHeadersForUnaryCall(Map> source, HttpHeaders target) { + fillHttpHeaders(source, target, UNARY_RESERVED_HEADERS); + } + + static void copyUserHeadersForStreamCall(Map> source, HttpHeaders target) { + fillHttpHeaders(source, target, STREAMING_RESERVED_HEADERS); + } + + private static void fillHttpHeaders(Map> source, + HttpHeaders target, + Set reserved) + { + source + .entrySet() + .stream() + .filter(e -> !reserved.contains(e.getKey().toLowerCase(Locale.ROOT))) + .forEach(e -> target.add(e.getKey(), e.getValue())); + } + + /** Converts Netty headers to a lower-cased name-to-values map. */ + static Map> toHeaderMap(HttpHeaders headers) { + return headers + .entries() + .stream() + .collect(Collectors.toUnmodifiableMap( + e -> e.getKey().toLowerCase(Locale.ROOT), + e -> List.of(e.getValue()), + (l1,l2) -> Stream.concat(l1.stream(), l2.stream()).toList())); + } + + /** Maps a non-200 HTTP status to the closest Connect error code (client-side table). */ + static ConnectErrorCode httpStatusToErrorCode(int status) { + return switch (status) { + case 400 -> ConnectErrorCode.INTERNAL; + case 401 -> ConnectErrorCode.UNAUTHENTICATED; + case 403 -> ConnectErrorCode.PERMISSION_DENIED; + case 404 -> ConnectErrorCode.UNIMPLEMENTED; + case 429, 502, 503, 504 -> ConnectErrorCode.UNAVAILABLE; + default -> ConnectErrorCode.UNKNOWN; + }; + } + + /** Returns the {@link ConnectErrorCode} whose wire name equals {@code wireName}, or {@code null}. */ + static @Nullable ConnectErrorCode findErrorCodeByWireName(String wireName) { + for (ConnectErrorCode code : ConnectErrorCode.values()) { + if (code.wireName().equals(wireName)) { + return code; + } + } + return null; + } + + /** Reads all readable bytes of {@code buf} into a new array without advancing its reader index. */ + static byte[] toByteArray(ByteBuf buf) { + int length = buf.readableBytes(); + if (length == 0) { + return new byte[0]; + } else { + byte[] bytes = new byte[length]; + buf.getBytes(buf.readerIndex(), bytes); + return bytes; + } + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectCallTerminatedException.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectCallTerminatedException.java similarity index 95% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectCallTerminatedException.java rename to src/main/java/io/suboptimal/connectjava/protocol/client/ConnectCallTerminatedException.java index 5b8c6e1..e87d3a8 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectCallTerminatedException.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectCallTerminatedException.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.client; /** * Singleton signal used when a Connect response handler receives outbound RPC messages after diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallDispatcher.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallDispatcher.java new file mode 100644 index 0000000..125ceee --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallDispatcher.java @@ -0,0 +1,116 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.codec.ConnectCodec; +import io.suboptimal.connectjava.model.ConnectMethodType; + +@ChannelHandler.Sharable +class ConnectClientCallDispatcher extends ChannelOutboundHandlerAdapter { + private final ConnectClientProtocolConfig config; + private final ConnectClientInterceptorPipeline interceptorPipeline; + + ConnectClientCallDispatcher(ConnectClientProtocolConfig config) { + this.config = config; + this.interceptorPipeline = config.interceptors().isEmpty() + ? ConnectClientInterceptorPipeline.EMPTY + : new ConnectClientInterceptorPipeline(config.interceptors()); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (!(msg instanceof ConnectClientCallStart callStart)) { + ctx.write(msg, promise); + return; + } + + ChannelPipeline pipeline = ctx.pipeline(); + + if (pipeline.get(ConnectClientPipeline.UNARY_POST_HANDLER) != null) { + pipeline.remove(ConnectClientPipeline.UNARY_POST_HANDLER); + } + if (pipeline.get(ConnectClientPipeline.UNARY_GET_HANDLER) != null) { + pipeline.remove(ConnectClientPipeline.UNARY_GET_HANDLER); + } + if (pipeline.get(ConnectClientPipeline.UNARY_RESPONSE_HANDLER) != null) { + pipeline.remove(ConnectClientPipeline.UNARY_RESPONSE_HANDLER); + } + if (pipeline.get(ConnectClientPipeline.AGGREGATOR_HANDLER) != null) { + pipeline.remove(ConnectClientPipeline.AGGREGATOR_HANDLER); + } + if (pipeline.get(ConnectClientPipeline.STREAMING_HANDLER) != null) { + pipeline.remove(ConnectClientPipeline.STREAMING_HANDLER); + } + + ConnectClientInterceptor.Decision decision = interceptorPipeline.interceptCall(callStart); + switch (decision) { + case ConnectClientInterceptor.Decision.Reject(ConnectClientCallObserver observer, var error) -> { + observer.onCallComplete(error); + ctx.fireChannelRead(error); + promise.setSuccess(); + } + case ConnectClientInterceptor.Decision.Continue(var observer, var rewrittenCallStart) -> { + // Interceptors may rewrite the outgoing request; use the effective call start + // (falling back to the original when no interceptor rewrote it) everywhere below. + ConnectClientCallStart effectiveCallStart = rewrittenCallStart != null + ? rewrittenCallStart + : callStart; + + // codecName is a plain String (not an enum) because codecs are extensible and are + // identified by name on the wire and in the registry. The downside of a String is + // that a typo would otherwise slip through and silently fall back to the default + // codec, sending the request in an unintended format. So an explicitly requested + // codec that is not registered fails the call here rather than being ignored; + // a null codecName still means "use the registry's preferred codec". + String codecName = effectiveCallStart.codecName(); + if (codecName != null && config.codecRegistry().byName(codecName) == null) { + ConnectError error = ConnectError.internal( + "Unknown codec '" + codecName + "'; registered: " + registeredCodecNames()); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + promise.setSuccess(); + return; + } + + ConnectMethodType type = effectiveCallStart.methodDefinition().type(); + switch (type) { + case UNARY -> { + pipeline.addBefore(ConnectClientPipeline.CALL_DISPATCHER, + ConnectClientPipeline.AGGREGATOR_HANDLER, + new HttpObjectAggregator(config.parameters().maxResponseBytes())); + if (effectiveCallStart.preferGet() && effectiveCallStart.methodDefinition().idempotent()) { + pipeline.addBefore(ConnectClientPipeline.CALL_DISPATCHER, + ConnectClientPipeline.UNARY_GET_HANDLER, + new UnaryGetRequestClientHandler(effectiveCallStart, config, observer)); + } else { + pipeline.addBefore(ConnectClientPipeline.CALL_DISPATCHER, + ConnectClientPipeline.UNARY_POST_HANDLER, + new UnaryPostRequestClientHandler(effectiveCallStart, config, observer)); + } + } + case SERVER_STREAMING, CLIENT_STREAMING, BIDI_STREAMING -> + pipeline.addBefore(ConnectClientPipeline.CALL_DISPATCHER, + ConnectClientPipeline.STREAMING_HANDLER, + new StreamingClientHandler(effectiveCallStart, config, observer)); + } + ctx.write(effectiveCallStart, promise); + } + } + } + + private String registeredCodecNames() { + StringBuilder sb = new StringBuilder(); + for (ConnectCodec codec : config.codecRegistry().preferred()) { + if (!sb.isEmpty()) { + sb.append(", "); + } + sb.append(codec.name()); + } + return sb.toString(); + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallObserver.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallObserver.java new file mode 100644 index 0000000..bd28810 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallObserver.java @@ -0,0 +1,67 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.suboptimal.connectjava.api.ConnectResponseMeta; +import io.suboptimal.connectjava.api.ConnectError; +import org.jspecify.annotations.Nullable; + +/** + * Stateful observer for one client-side Connect RPC call, attached by a + * {@link ConnectClientInterceptor} via {@link ConnectClientInterceptor#continueWith}. + * + *

Direction differs from the server observer. On the client the {@code onRequest*} + * callbacks are outbound — they fire as the client sends the request — whereas the + * server-side {@code io.suboptimal.connectjava.protocol.ConnectCallObserver} treats them as inbound. + * Likewise {@link #onResponseHeaders} here delivers an immutable {@link ConnectResponseMeta} that the + * client received and can only read; the server observer instead receives a mutable builder it + * can use to shape the response it emits. + * + *

Request callbacks are invoked in interceptor registration order (FIFO); response-header and + * completion callbacks are invoked in reverse registration order (LIFO) so outer interceptors observe + * the response last. Observer exceptions are not swallowed: they propagate through the Netty pipeline + * like other user-code failures. + * + *

If the interceptor returned a {@code Continue} decision, the pipeline guarantees exactly one + * {@link #onCallComplete} for the lifetime of the call, regardless of success, failure, or + * cancellation. + */ +public interface ConnectClientCallObserver { + /** + * Called for each outbound request payload as it is encoded and sent. + * + *

Invoked in interceptor registration order (FIFO). + */ + default void onRequestPayload(Object payload) {} + + /** + * Called once after the last request payload has been sent and the request is fully flushed. + * + *

Invoked in interceptor registration order (FIFO). + */ + default void onRequestFinished() {} + + /** + * Called once when the response headers arrive, before any response payload is delivered. + * + *

{@code meta} is an immutable, read-only view of the received leading metadata; it cannot be + * mutated. Invoked in reverse interceptor registration order (LIFO). + */ + default void onResponseHeaders(ConnectResponseMeta meta) {} + + /** + * Called for each successfully decoded response payload. + * + *

Invoked in interceptor registration order (FIFO). + */ + default void onResponsePayload(Object payload) {} + + /** + * Called exactly once when the call terminates. + * + *

{@code error} is {@code null} for a successful RPC and non-null for a failed or cancelled + * one. Invoked in reverse interceptor registration order (LIFO). + */ + default void onCallComplete(@Nullable ConnectError error) {} + + /** No-op implementation of the client call observer. */ + ConnectClientCallObserver NOOP = new ConnectClientCallObserver() {}; +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallStart.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallStart.java new file mode 100644 index 0000000..09ee427 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallStart.java @@ -0,0 +1,97 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.suboptimal.connectjava.model.ConnectMethodDefinition; +import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import org.jspecify.annotations.Nullable; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * Outbound trigger written by the terminal handler to initiate a Connect RPC call. + * + * @param serviceDefinition the target service + * @param methodDefinition the target method + * @param requestHeaders user request headers; keys are normalized to lower case. Protocol-managed + * headers (content-type, content-length, encodings, etc.) are ignored. + * @param preferGet when {@code true} and the method is idempotent, the call is sent as a GET + * @param codecName payload codec to use (e.g. {@code "proto"} or {@code "json"}); when + * {@code null} the codec registry's preferred codec is used + * @param timeoutMs call timeout in milliseconds; when non-null the {@code connect-timeout-ms} + * request header is set; {@code null} means no timeout header is sent + */ +public record ConnectClientCallStart( + ConnectServiceDefinition serviceDefinition, + ConnectMethodDefinition methodDefinition, + Map> requestHeaders, + boolean preferGet, + @Nullable String codecName, + @Nullable Long timeoutMs +) { + public ConnectClientCallStart { + Objects.requireNonNull(serviceDefinition); + Objects.requireNonNull(methodDefinition); + Objects.requireNonNull(requestHeaders); + requestHeaders = requestHeaders + .entrySet() + .stream() + .collect(Collectors.toUnmodifiableMap(e -> e.getKey().toLowerCase(Locale.ROOT), + Map.Entry::getValue)); + } + + public ConnectClientCallStart( + ConnectServiceDefinition serviceDefinition, + ConnectMethodDefinition methodDefinition, + Map> requestHeaders, + boolean preferGet, + @Nullable String codecName + ) { + this(serviceDefinition, methodDefinition, requestHeaders, preferGet, codecName, null); + } + + /** + * Returns a copy with {@code value} appended under {@code name} (case-insensitive); other + * fields are unchanged. Intended for interceptors adding an outgoing request header. + */ + public ConnectClientCallStart withHeader(String name, String value) { + Objects.requireNonNull(name); + Objects.requireNonNull(value); + String key = name.toLowerCase(Locale.ROOT); + Map> headers = new LinkedHashMap<>(requestHeaders); + List existing = headers.get(key); + List merged = new ArrayList<>(existing != null ? existing : List.of()); + merged.add(value); + headers.put(key, merged); + return new ConnectClientCallStart( + serviceDefinition, methodDefinition, headers, preferGet, codecName, timeoutMs); + } + + /** Returns a copy with {@code requestHeaders} replaced; other fields are unchanged. */ + public ConnectClientCallStart withRequestHeaders(Map> requestHeaders) { + return new ConnectClientCallStart( + serviceDefinition, methodDefinition, requestHeaders, preferGet, codecName, timeoutMs); + } + + /** Returns a copy with {@code timeoutMs} replaced; other fields are unchanged. */ + public ConnectClientCallStart withTimeoutMs(@Nullable Long timeoutMs) { + return new ConnectClientCallStart( + serviceDefinition, methodDefinition, requestHeaders, preferGet, codecName, timeoutMs); + } + + /** Returns a copy with {@code codecName} replaced; other fields are unchanged. */ + public ConnectClientCallStart withCodecName(@Nullable String codecName) { + return new ConnectClientCallStart( + serviceDefinition, methodDefinition, requestHeaders, preferGet, codecName, timeoutMs); + } + + /** Returns a copy with {@code preferGet} replaced; other fields are unchanged. */ + public ConnectClientCallStart withPreferGet(boolean preferGet) { + return new ConnectClientCallStart( + serviceDefinition, methodDefinition, requestHeaders, preferGet, codecName, timeoutMs); + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientChannelConfigurer.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientChannelConfigurer.java new file mode 100644 index 0000000..09f8372 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientChannelConfigurer.java @@ -0,0 +1,19 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.channel.Channel; + +public class ConnectClientChannelConfigurer { + private final ConnectClientProtocolConfig config; + private final ConnectClientCallDispatcher callDispatcher; + + ConnectClientChannelConfigurer(ConnectClientProtocolConfig config) { + this.config = config; + this.callDispatcher = new ConnectClientCallDispatcher(config); + } + + public void configure(Channel channel) { + channel.pipeline() + .addLast(ConnectClientPipeline.CALL_DISPATCHER, callDispatcher) + .addLast(config.callHandlerFactory().create()); + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptor.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptor.java new file mode 100644 index 0000000..c4cef75 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptor.java @@ -0,0 +1,70 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.suboptimal.connectjava.api.ConnectError; +import org.jspecify.annotations.Nullable; + +/** + * Factory for per-call client-side Connect interceptors. + * + *

Registered interceptors are invoked in registration order when a {@link ConnectClientCallStart} + * is written, before the request is sent. Each interceptor either continues the call by returning + * {@link #continueCall()} / {@link #continueWith(ConnectClientCallObserver)}, optionally rewriting + * the outgoing request via {@link #continueWith(ConnectClientCallStart)} / + * {@link #continueWith(ConnectClientCallStart, ConnectClientCallObserver)}, or rejects it with a + * Connect-native error by returning {@link #reject(ConnectError)}. + * + *

An interceptor that rewrites the {@link ConnectClientCallStart} shapes the request the client + * is about to emit (headers, timeout, codec, GET preference). The rewritten value is threaded to the + * next interceptor in the chain, so later interceptors observe earlier rewrites. The intent is to + * adjust request metadata; rebuilding the record with a different service or method is possible but + * outside the intended contract. + */ +public interface ConnectClientInterceptor { + /** + * Called once per call, before the request is sent. + * + * @param callStart the (possibly already rewritten) outgoing call description + * @return decision to continue (optionally rewriting the request) or reject the call + */ + Decision interceptCall(ConnectClientCallStart callStart); + + /** Continues the call without attaching an observer or rewriting the request. */ + static Decision continueCall() { + return new Decision.Continue(ConnectClientCallObserver.NOOP, null); + } + + /** Continues the call and attaches {@code observer}, without rewriting the request. */ + static Decision continueWith(ConnectClientCallObserver observer) { + return new Decision.Continue(observer, null); + } + + /** Continues the call with the rewritten {@code callStart}, without attaching an observer. */ + static Decision continueWith(ConnectClientCallStart callStart) { + return new Decision.Continue(ConnectClientCallObserver.NOOP, callStart); + } + + /** Continues the call with the rewritten {@code callStart} and attaches {@code observer}. */ + static Decision continueWith(ConnectClientCallStart callStart, ConnectClientCallObserver observer) { + return new Decision.Continue(observer, callStart); + } + + /** Rejects the call with the given Connect error; the request is never sent. */ + static Decision reject(ConnectError error) { + return new Decision.Reject(ConnectClientCallObserver.NOOP, error); + } + + sealed interface Decision permits Decision.Continue, Decision.Reject { + ConnectClientCallObserver observer(); + + /** + * Continue the call, optionally rewriting the outgoing request. + * + * @param observer observer that receives client call lifecycle callbacks + * @param callStart rewritten request, or {@code null} to leave it unchanged + */ + record Continue(ConnectClientCallObserver observer, + @Nullable ConnectClientCallStart callStart) implements Decision {} + + record Reject(ConnectClientCallObserver observer, ConnectError error) implements Decision {} + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipeline.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipeline.java new file mode 100644 index 0000000..5d4d531 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipeline.java @@ -0,0 +1,95 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.suboptimal.connectjava.api.ConnectResponseMeta; +import io.suboptimal.connectjava.api.ConnectError; +import org.jspecify.annotations.Nullable; + +import java.util.ArrayList; +import java.util.List; + +final class ConnectClientInterceptorPipeline { + static final ConnectClientInterceptorPipeline EMPTY = new ConnectClientInterceptorPipeline(List.of()); + + private final List interceptors; + + ConnectClientInterceptorPipeline(List interceptors) { + this.interceptors = List.copyOf(interceptors); + } + + /** + * Runs the interceptor chain. Returns a {@link ConnectClientInterceptor.Decision.Continue} whose + * {@code callStart()} is the effective (possibly rewritten) request, or a + * {@link ConnectClientInterceptor.Decision.Reject} if any interceptor rejected the call. + */ + ConnectClientInterceptor.Decision interceptCall(ConnectClientCallStart callStart) { + if (interceptors.isEmpty()) { + return ConnectClientInterceptor.continueWith(callStart); + } + + ConnectClientCallStart current = callStart; + List observers = new ArrayList<>(interceptors.size()); + for (ConnectClientInterceptor interceptor : interceptors) { + switch (interceptor.interceptCall(current)) { + case ConnectClientInterceptor.Decision.Continue(var observer, var modified) -> { + if (modified != null) { + current = modified; + } + observers.add(observer); + } + case ConnectClientInterceptor.Decision.Reject(ConnectClientCallObserver ignore, ConnectError error) -> { + return new ConnectClientInterceptor.Decision.Reject(composite(observers), error); + } + } + } + return ConnectClientInterceptor.continueWith(current, composite(observers)); + } + + private static ConnectClientCallObserver composite(List observers) { + List filtered = observers.stream() + .filter(o -> o != ConnectClientCallObserver.NOOP) + .toList(); + + if (filtered.isEmpty()) { + return ConnectClientCallObserver.NOOP; + } + + if (filtered.size() == 1) { + return filtered.getFirst(); + } + + return new CompositeConnectClientCallObserver(filtered); + } + + private record CompositeConnectClientCallObserver(List observers) + implements ConnectClientCallObserver + { + @Override + public void onRequestPayload(Object payload) { + observers.forEach(o -> o.onRequestPayload(payload)); + } + + @Override + public void onRequestFinished() { + observers.forEach(ConnectClientCallObserver::onRequestFinished); + } + + @Override + public void onResponseHeaders(ConnectResponseMeta meta) { + for (int i = observers.size() - 1; i >= 0; i--) { + observers.get(i).onResponseHeaders(meta); + } + } + + @Override + public void onResponsePayload(Object payload) { + observers.forEach(o -> o.onResponsePayload(payload)); + } + + @Override + public void onCallComplete(@Nullable ConnectError error) { + for (int i = observers.size() - 1; i >= 0; i--) { + observers.get(i).onCallComplete(error); + } + } + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientPipeline.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientPipeline.java new file mode 100644 index 0000000..1dc3098 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientPipeline.java @@ -0,0 +1,12 @@ +package io.suboptimal.connectjava.protocol.client; + +public final class ConnectClientPipeline { + public static final String CALL_DISPATCHER = "connectClientDispatcher"; + public static final String AGGREGATOR_HANDLER = "connectClientAggregator"; + public static final String UNARY_POST_HANDLER = "connectClientUnaryPost"; + public static final String UNARY_GET_HANDLER = "connectClientUnaryGet"; + public static final String UNARY_RESPONSE_HANDLER = "connectClientUnaryResponse"; + public static final String STREAMING_HANDLER = "connectClientStreaming"; + + private ConnectClientPipeline() {} +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocol.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocol.java new file mode 100644 index 0000000..8c76c27 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocol.java @@ -0,0 +1,13 @@ +package io.suboptimal.connectjava.protocol.client; + +public class ConnectClientProtocol { + private final ConnectClientChannelConfigurer http1Configurer; + + public ConnectClientProtocol(ConnectClientProtocolConfig config) { + this.http1Configurer = new ConnectClientChannelConfigurer(config); + } + + public ConnectClientChannelConfigurer http1() { + return http1Configurer; + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolConfig.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolConfig.java new file mode 100644 index 0000000..50548f7 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolConfig.java @@ -0,0 +1,64 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.suboptimal.connectjava.codec.ConnectCodecRegistry; +import io.suboptimal.connectjava.compression.ConnectCompressionRegistry; +import io.suboptimal.connectjava.protocol.ConnectCallHandlerFactory; + +import java.util.List; + +public record ConnectClientProtocolConfig( + ConnectCallHandlerFactory callHandlerFactory, + ConnectClientProtocolParameters parameters, + ConnectCodecRegistry codecRegistry, + ConnectCompressionRegistry compressionRegistry, + ConnectJsonDeserializer jsonDeserializer, + List interceptors +) { + public ConnectClientProtocolConfig { + interceptors = List.copyOf(interceptors); + } + + public static Builder builder(ConnectCallHandlerFactory callHandlerFactory, + ConnectClientProtocolParameters parameters, + ConnectCodecRegistry codecRegistry) { + return new Builder(callHandlerFactory, parameters, codecRegistry); + } + + public static final class Builder { + private final ConnectCallHandlerFactory callHandlerFactory; + private final ConnectClientProtocolParameters parameters; + private final ConnectCodecRegistry codecRegistry; + private ConnectCompressionRegistry compressionRegistry = ConnectCompressionRegistry.standard(); + private ConnectJsonDeserializer jsonDeserializer = ConnectStringBuilderJsonDeserializer.INSTANCE; + private List interceptors = List.of(); + + private Builder(ConnectCallHandlerFactory callHandlerFactory, + ConnectClientProtocolParameters parameters, + ConnectCodecRegistry codecRegistry) { + this.callHandlerFactory = callHandlerFactory; + this.parameters = parameters; + this.codecRegistry = codecRegistry; + } + + public Builder compressionRegistry(ConnectCompressionRegistry compressionRegistry) { + this.compressionRegistry = compressionRegistry; + return this; + } + + public Builder jsonDeserializer(ConnectJsonDeserializer jsonDeserializer) { + this.jsonDeserializer = jsonDeserializer; + return this; + } + + public Builder interceptors(List interceptors) { + this.interceptors = interceptors; + return this; + } + + public ConnectClientProtocolConfig build() { + return new ConnectClientProtocolConfig( + callHandlerFactory, parameters, codecRegistry, + compressionRegistry, jsonDeserializer, interceptors); + } + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolParameters.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolParameters.java new file mode 100644 index 0000000..aa5591b --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolParameters.java @@ -0,0 +1,11 @@ +package io.suboptimal.connectjava.protocol.client; + +public record ConnectClientProtocolParameters( + int maxResponseBytes, + int maxFrameBytes +) { + public ConnectClientProtocolParameters { + if (maxResponseBytes <= 0) throw new IllegalArgumentException("maxResponseBytes must be > 0"); + if (maxFrameBytes <= 0) throw new IllegalArgumentException("maxFrameBytes must be > 0"); + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectErrorBody.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectErrorBody.java new file mode 100644 index 0000000..9fce7a2 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectErrorBody.java @@ -0,0 +1,17 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.suboptimal.connectjava.api.ConnectErrorDetail; +import org.jspecify.annotations.Nullable; + +import java.util.List; + +/** + * Parsed fields of a Connect unary error body, independent of error code resolution. + * + * @param codeName wire name of the error code (e.g. {@code "not_found"}), or {@code null} if absent + * @param message human-readable message, or {@code null} if absent + * @param details rich error details; empty if the {@code details} array is absent + */ +public record ConnectErrorBody(@Nullable String codeName, + @Nullable String message, + List details) {} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectJsonDeserializer.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectJsonDeserializer.java new file mode 100644 index 0000000..24aa0c1 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectJsonDeserializer.java @@ -0,0 +1,70 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.suboptimal.connectjava.api.ConnectError; +import org.jspecify.annotations.Nullable; + +import java.util.List; +import java.util.Map; + +/** + * SPI for deserializing Connect protocol JSON bodies from bytes. + * + *

Two call sites exist in the client Connect implementation: + *

    + *
  • Unary error responses — parsed by {@link #parseError(byte[])} from the HTTP response + * body when the server returns a non-200 status.
  • + *
  • Streaming EndStreamResponse envelopes — parsed by {@link #parseEndStreamError(byte[])} + * and {@link #parseEndStreamMetadata(byte[])} from the final framed envelope with + * flag {@code 0x02}.
  • + *
+ * + *

The default implementation is {@link ConnectStringBuilderJsonDeserializer#INSTANCE}. + * A custom implementation can be supplied via + * {@link io.suboptimal.connectjava.protocol.client.ConnectClientProtocolConfig.Builder#jsonDeserializer(ConnectJsonDeserializer)}. + */ +public interface ConnectJsonDeserializer { + + /** + * Parses a Connect error from a UTF-8 JSON body. + * + *

Expected input: {@code {"code":"not_found","message":"...","details":[...]}} + * + * @param body UTF-8-encoded JSON bytes + * @return parsed error, or {@code null} if the body cannot be parsed as a Connect error + */ + @Nullable ConnectError parseError(byte[] body); + + /** + * Parses the error field from a Connect streaming EndStreamResponse JSON body. + * + *

Expected input: {@code {"error":{"code":"not_found","message":"...","details":[...]}}} + * for an error, or {@code {}} / {@code {"metadata":{...}}} for a successful completion. + * An unrecognized error code is treated as {@code unknown} rather than as a missing error. + * + * @param body UTF-8-encoded JSON bytes + * @return parsed error, or {@code null} if no error is present (successful completion) + */ + @Nullable ConnectError parseEndStreamError(byte[] body); + + /** + * Parses trailing metadata from a Connect streaming EndStreamResponse JSON body. + * + *

Extracts the top-level {@code "metadata"} field. Keys are returned as on the wire + * (no {@code trailer-} prefix is added or stripped). + * + * @param body UTF-8-encoded JSON bytes + * @return trailing metadata map; empty if the {@code metadata} field is absent + */ + Map> parseEndStreamMetadata(byte[] body); + + /** + * Parses the structured fields of a Connect unary error body without resolving the code. + * + *

Use this when the caller needs to apply its own fallback logic for the error code + * (e.g. fall back to the HTTP-to-Connect mapping when the code is absent or unrecognized). + * + * @param body UTF-8-encoded JSON bytes + * @return parsed body fields, or {@code null} if the body does not look like a Connect error + */ + @Nullable ConnectErrorBody parseErrorBody(byte[] body); +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectStringBuilderJsonDeserializer.java b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectStringBuilderJsonDeserializer.java new file mode 100644 index 0000000..37011d8 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/ConnectStringBuilderJsonDeserializer.java @@ -0,0 +1,315 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectErrorCode; +import io.suboptimal.connectjava.api.ConnectErrorDetail; +import org.jspecify.annotations.Nullable; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +public final class ConnectStringBuilderJsonDeserializer implements ConnectJsonDeserializer { + public static final ConnectStringBuilderJsonDeserializer INSTANCE = + new ConnectStringBuilderJsonDeserializer(); + + private ConnectStringBuilderJsonDeserializer() {} + + @Override + public @Nullable ConnectError parseError(byte[] body) { + String json = new String(body, StandardCharsets.UTF_8); + String code = extractJsonString(json, "code"); + if (code == null) { + return null; + } + ConnectErrorCode errorCode = findErrorCode(code); + if (errorCode == null) { + errorCode = ConnectErrorCode.UNKNOWN; + } + String message = extractJsonString(json, "message"); + List details = parseDetails(json); + return new ConnectError(errorCode, message != null ? message : "", details); + } + + @Override + public @Nullable ConnectError parseEndStreamError(byte[] body) { + String json = new String(body, StandardCharsets.UTF_8); + int errorIdx = json.indexOf("\"error\""); + if (errorIdx < 0) { + return null; + } + int colonIdx = json.indexOf(':', errorIdx + 7); + if (colonIdx < 0) { + return null; + } + int braceOpen = json.indexOf('{', colonIdx + 1); + if (braceOpen < 0) { + return null; + } + int braceEnd = findClose(json, braceOpen, '{', '}'); + if (braceEnd < 0) { + return null; + } + String errorJson = json.substring(braceOpen, braceEnd); + + String code = extractJsonString(errorJson, "code"); + ConnectErrorCode errorCode = (code != null) ? findErrorCode(code) : null; + if (errorCode == null) { + errorCode = ConnectErrorCode.UNKNOWN; + } + String message = extractJsonString(errorJson, "message"); + List details = parseDetails(errorJson); + return new ConnectError(errorCode, message != null ? message : "", details); + } + + @Override + public Map> parseEndStreamMetadata(byte[] body) { + String json = new String(body, StandardCharsets.UTF_8); + return parseMetadata(json); + } + + @Override + public @Nullable ConnectErrorBody parseErrorBody(byte[] body) { + String json = new String(body, StandardCharsets.UTF_8); + if (!json.contains("\"code\"") && !json.contains("\"message\"") + && !json.contains("\"details\"")) { + return null; + } + String codeName = extractJsonString(json, "code"); + String message = extractJsonString(json, "message"); + List details = parseDetails(json); + return new ConnectErrorBody(codeName, message, details); + } + + private static List parseDetails(String errorJson) { + int detailsIdx = errorJson.indexOf("\"details\""); + if (detailsIdx < 0) { + return List.of(); + } + int colonIdx = errorJson.indexOf(':', detailsIdx + 9); + if (colonIdx < 0) { + return List.of(); + } + int arrayStart = errorJson.indexOf('[', colonIdx + 1); + if (arrayStart < 0) { + return List.of(); + } + int arrayEnd = findClose(errorJson, arrayStart, '[', ']'); + if (arrayEnd < 0) { + return List.of(); + } + + List result = new ArrayList<>(); + int pos = arrayStart + 1; + while (pos < arrayEnd - 1) { + int objStart = errorJson.indexOf('{', pos); + if (objStart < 0 || objStart >= arrayEnd - 1) { + break; + } + int objEnd = findClose(errorJson, objStart, '{', '}'); + if (objEnd < 0) { + break; + } + String objJson = errorJson.substring(objStart, objEnd); + String type = extractJsonString(objJson, "type"); + String value = extractJsonString(objJson, "value"); + if (type != null && value != null) { + try { + byte[] decoded = Base64.getDecoder().decode(value); + result.add(new ConnectErrorDetail(type, decoded)); + } catch (IllegalArgumentException ignored) { + // skip malformed base64 detail + } + } + pos = objEnd; + } + return List.copyOf(result); + } + + private static Map> parseMetadata(String json) { + int metaIdx = json.indexOf("\"metadata\""); + if (metaIdx < 0) { + return Map.of(); + } + int colonIdx = json.indexOf(':', metaIdx + 10); + if (colonIdx < 0) { + return Map.of(); + } + int objStart = json.indexOf('{', colonIdx + 1); + if (objStart < 0) { + return Map.of(); + } + int objEnd = findClose(json, objStart, '{', '}'); + if (objEnd < 0) { + return Map.of(); + } + + Map> result = new LinkedHashMap<>(); + int pos = objStart + 1; + while (pos < objEnd - 1) { + // find next key (a quoted string) + while (pos < objEnd - 1 && json.charAt(pos) != '"') { + pos++; + } + if (pos >= objEnd - 1) { + break; + } + int keyEnd = skipString(json, pos); + if (keyEnd < 0) { + break; + } + String key = readString(json, pos); + if (key == null) { + break; + } + pos = keyEnd; + + // find ':' + while (pos < objEnd && json.charAt(pos) != ':') { + pos++; + } + if (pos >= objEnd) { + break; + } + pos++; // skip ':' + + // find '[' for the value array + while (pos < objEnd && json.charAt(pos) != '[') { + pos++; + } + if (pos >= objEnd) { + break; + } + int arrayEnd = findClose(json, pos, '[', ']'); + if (arrayEnd < 0) { + break; + } + + // extract string values from array + List values = new ArrayList<>(); + int aPos = pos + 1; + while (aPos < arrayEnd - 1) { + while (aPos < arrayEnd - 1 && json.charAt(aPos) != '"') { + aPos++; + } + if (aPos >= arrayEnd - 1) { + break; + } + int strEnd = skipString(json, aPos); + if (strEnd < 0) { + break; + } + String val = readString(json, aPos); + if (val != null) { + values.add(val); + } + aPos = strEnd; + } + result.put(key, List.copyOf(values)); + pos = arrayEnd; + } + return Collections.unmodifiableMap(result); + } + + /** Returns index past the closing {@code "}, handling backslash escapes. -1 if malformed. */ + private static int skipString(String json, int openQuote) { + int i = openQuote + 1; + while (i < json.length()) { + char c = json.charAt(i); + if (c == '\\') { + i += 2; + continue; + } + if (c == '"') { + return i + 1; + } + i++; + } + return -1; + } + + /** Returns the unescaped string value for the JSON string starting at {@code openQuote}. */ + private static @Nullable String readString(String json, int openQuote) { + StringBuilder sb = new StringBuilder(); + int i = openQuote + 1; + while (i < json.length()) { + char c = json.charAt(i); + if (c == '"') { + return sb.toString(); + } + if (c == '\\' && i + 1 < json.length()) { + char next = json.charAt(i + 1); + switch (next) { + case '"' -> sb.append('"'); + case '\\' -> sb.append('\\'); + case 'n' -> sb.append('\n'); + case 'r' -> sb.append('\r'); + case 't' -> sb.append('\t'); + default -> sb.append(next); + } + i += 2; + } else { + sb.append(c); + i++; + } + } + return null; + } + + /** Returns index past the matching close delimiter, tracking depth and skipping strings. */ + private static int findClose(String json, int openIdx, char open, char close) { + int depth = 1; + int i = openIdx + 1; + while (i < json.length() && depth > 0) { + char c = json.charAt(i); + if (c == '"') { + int end = skipString(json, i); + if (end < 0) { + return -1; + } + i = end; + continue; + } + if (c == open) { + depth++; + } else if (c == close) { + depth--; + } + i++; + } + return depth == 0 ? i : -1; + } + + private static @Nullable String extractJsonString(String json, String field) { + String key = "\"" + field + "\""; + int keyIdx = json.indexOf(key); + if (keyIdx < 0) { + return null; + } + int colonIdx = json.indexOf(':', keyIdx + key.length()); + if (colonIdx < 0) { + return null; + } + int p = colonIdx + 1; + while (p < json.length() && Character.isWhitespace(json.charAt(p))) { + p++; + } + if (p >= json.length() || json.charAt(p) != '"') { + return null; + } + return readString(json, p); + } + + private static @Nullable ConnectErrorCode findErrorCode(String wireName) { + for (ConnectErrorCode code : ConnectErrorCode.values()) { + if (code.wireName().equals(wireName)) { + return code; + } + } + return null; + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandler.java b/src/main/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandler.java new file mode 100644 index 0000000..12776a2 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandler.java @@ -0,0 +1,415 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.util.ReferenceCountUtil; +import io.suboptimal.connectjava.api.ConnectClientResponseStart; +import io.suboptimal.connectjava.api.ConnectResponseMeta; +import io.suboptimal.connectjava.api.ConnectEndOfStream; +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectPayload; +import io.suboptimal.connectjava.codec.ConnectCodec; +import io.suboptimal.connectjava.compression.ConnectCompression; +import io.suboptimal.connectjava.compression.ConnectIdentityCompression; +import io.suboptimal.connectjava.model.ConnectMethodType; +import io.suboptimal.connectjava.protocol.ConnectCompressionNegotiation; +import io.suboptimal.connectjava.protocol.ConnectEnvelope; +import io.suboptimal.connectjava.protocol.ConnectMediaType; +import io.suboptimal.connectjava.protocol.ConnectProtocolHttpHeaders; +import io.suboptimal.connectjava.protocol.ConnectProtocolVersion; +import org.jspecify.annotations.Nullable; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +class StreamingClientHandler extends ChannelDuplexHandler { + private final ConnectClientCallStart callStart; + private final ConnectClientProtocolConfig config; + private final ConnectClientCallObserver observer; + private final ConnectCodec codec; + private final ConnectCompression requestEncoding; + + private ConnectCompression responseEncoding = ConnectIdentityCompression.INSTANCE; + private ConnectEnvelope. @Nullable Decoder decoder; + private int requestPayloadsSent; + private int responsePayloadsReceived; + private boolean endStreamReceived; + private boolean closed; + + private enum OutboundState { IDLE, HEADERS_SENT, AWAITING_RESPONSE } + private OutboundState outboundState = OutboundState.IDLE; + + StreamingClientHandler(ConnectClientCallStart callStart, + ConnectClientProtocolConfig config, + ConnectClientCallObserver observer) + { + this.callStart = callStart; + this.config = config; + this.observer = observer; + this.codec = ClientHandlerSupport.selectRequestCodec(config, callStart.codecName()); + this.requestEncoding = ClientHandlerSupport.selectRequestEncoding(config, callStart.requestHeaders()); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (closed) { + promise.tryFailure(ConnectCallTerminatedException.INSTANCE); + ReferenceCountUtil.release(msg); + return; + } + + switch (msg) { + case ConnectClientCallStart ignored when outboundState == OutboundState.IDLE -> { + ConnectMethodType type = callStart.methodDefinition().type(); + if (type == ConnectMethodType.BIDI_STREAMING) { + closed = true; + promise.setSuccess(); + ConnectError error = ConnectError.unimplemented("Bidi streaming not supported on HTTP/1.1"); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + responseEncoding = ConnectIdentityCompression.INSTANCE; + + String uri = "/" + callStart.serviceDefinition().serviceName() + + "/" + callStart.methodDefinition().methodName(); + + DefaultHttpRequest request = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.POST, uri); + + request.headers() + .set(HttpHeaderNames.CONTENT_TYPE, ConnectMediaType.streamingContentTypeFor(codec.name())) + .set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) + .set(ConnectProtocolHttpHeaders.CONNECT_PROTOCOL_VERSION, ConnectProtocolVersion.HEADER_VERSION); + + if (callStart.timeoutMs() != null) { + request.headers().set(ConnectProtocolHttpHeaders.CONNECT_TIMEOUT_MS, callStart.timeoutMs()); + } + + if (!requestEncoding.isIdentity()) { + request.headers().set(ConnectProtocolHttpHeaders.CONNECT_CONTENT_ENCODING, requestEncoding.name()); + } + + String acceptEncoding = ConnectCompressionNegotiation.formatSupportedEncodings(config.compressionRegistry()); + if (!acceptEncoding.isEmpty()) { + request.headers().set(ConnectProtocolHttpHeaders.CONNECT_ACCEPT_ENCODING, acceptEncoding); + } + + ClientHandlerSupport.copyUserHeadersForStreamCall(callStart.requestHeaders(), request.headers()); + + outboundState = OutboundState.HEADERS_SENT; + promise.setSuccess(); + ctx.write(request); + ctx.flush(); + } + case ConnectPayload data when outboundState == OutboundState.HEADERS_SENT -> { + ConnectMethodType type = callStart.methodDefinition().type(); + if (type == ConnectMethodType.SERVER_STREAMING && requestPayloadsSent >= 1) { + closed = true; + promise.setSuccess(); + ConnectError error = ConnectError.unimplemented("Server-streaming method requires exactly one request message"); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + ByteBuf encoded; + try { + encoded = codec.encode(data.data(), ctx.alloc()); + } catch (IOException e) { + closed = true; + promise.setSuccess(); + ConnectError error = ConnectError.internal("Serialization failed: " + e.getMessage()); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + byte flags = 0; + ByteBuf payload = encoded; + if (!requestEncoding.isIdentity()) { + try { + payload = requestEncoding.compress(encoded, ctx.alloc()); + flags = ConnectEnvelope.FLAG_COMPRESSED; + } catch (IOException e) { + encoded.release(); + closed = true; + promise.setSuccess(); + ConnectError error = ConnectError.internal("Compression failed: " + e.getMessage()); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + encoded.release(); + } + + try { + ByteBuf buf = ConnectEnvelope.encode(ctx.alloc(), flags, payload); + observer.onRequestPayload(data.data()); + requestPayloadsSent++; + ctx.writeAndFlush(new DefaultHttpContent(buf), promise); + } finally { + payload.release(); + } + } + case ConnectEndOfStream ignored when outboundState == OutboundState.HEADERS_SENT -> { + observer.onRequestFinished(); + outboundState = OutboundState.AWAITING_RESPONSE; + ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT, promise); + } + default -> { + promise.tryFailure(ConnectCallTerminatedException.INSTANCE); + ReferenceCountUtil.release(msg); + } + } + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof HttpResponse response) { + handleHttpResponse(ctx, response); + } else if (msg instanceof HttpContent content) { + try { + handleHttpContent(ctx, content); + } finally { + content.release(); + } + if (msg instanceof LastHttpContent) { + handleLastHttpContent(ctx); + } + } else { + ctx.fireChannelRead(msg); + } + } + + private void handleHttpResponse(ChannelHandlerContext ctx, HttpResponse response) { + int statusCode = response.status().code(); + if (statusCode != 200) { + closed = true; + ConnectError error = new ConnectError(ClientHandlerSupport.httpStatusToErrorCode(statusCode), + response.status().reasonPhrase()); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + String respCodecName = ClientHandlerSupport.codecNameForContentType(response.headers().get(HttpHeaderNames.CONTENT_TYPE)); + ConnectCodec respCodec = respCodecName != null + ? config.codecRegistry().byName(respCodecName) + : null; + + if (respCodec == null) { + closed = true; + ConnectError error = ConnectError.unknown("Unsupported or missing Content-Type in response"); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + String requestCodecName = callStart.codecName(); + if (requestCodecName != null && !requestCodecName.equals(respCodecName)) { + closed = true; + ConnectError error = ConnectError.internal( + "Response codec '" + respCodecName + "' does not match request codec '" + requestCodecName + "'"); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + String encodingHeader = response.headers().get(ConnectProtocolHttpHeaders.CONNECT_CONTENT_ENCODING); + String encodingName = ConnectCompressionNegotiation.compressionNameFor(encodingHeader); + if (encodingName != null) { + ConnectCompression c = config.compressionRegistry().resolve(encodingName); + if (c != null) { + responseEncoding = c; + } + } + + Map> headersMap = ClientHandlerSupport.toHeaderMap(response.headers()); + ConnectResponseMeta responseMeta = new ConnectResponseMeta(statusCode, headersMap, Map.of()); + decoder = new ConnectEnvelope.Decoder(ctx.alloc(), config.parameters().maxFrameBytes()); + observer.onResponseHeaders(responseMeta); + ctx.fireChannelRead(new ConnectClientResponseStart( + callStart.serviceDefinition(), callStart.methodDefinition(), responseMeta)); + } + + private void handleHttpContent(ChannelHandlerContext ctx, HttpContent content) { + if (closed || decoder == null) { + return; + } + + decoder.append(content.content()); + try { + ConnectEnvelope.DecodedFrame frame; + while ((frame = decoder.pollFrame()) != null) { + ByteBuf payload = frame.payload(); + boolean isEndStream = (frame.flags() & ConnectEnvelope.FLAG_END_STREAM) != 0; + + if (isEndStream) { + handleEndStreamFrame(ctx, frame.flags(), payload); + return; + } else { + handleDataFrame(ctx, frame.flags(), payload); + if (closed) { + return; + } + } + } + } catch (ConnectEnvelope.FrameTooLargeException e) { + closed = true; + ConnectError error = ConnectError.resourceExhausted(e.getMessage()); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + } + } + + private void handleDataFrame(ChannelHandlerContext ctx, byte flags, ByteBuf payload) { + boolean isCompressed = (flags & ConnectEnvelope.FLAG_COMPRESSED) != 0; + if (isCompressed && responseEncoding.isIdentity()) { + payload.release(); + closed = true; + ConnectError error = ConnectError.internal("Received compressed message but no compression was negotiated"); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + ByteBuf decompressed = payload; + if ((flags & ConnectEnvelope.FLAG_COMPRESSED) != 0) { + try { + decompressed = responseEncoding.decompress(payload, ctx.alloc()); + } catch (IOException e) { + closed = true; + ConnectError error = ConnectError.internal("Decompression failed: " + e.getMessage()); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } finally { + payload.release(); + } + } + + ConnectMethodType type = callStart.methodDefinition().type(); + + if (type == ConnectMethodType.CLIENT_STREAMING && responsePayloadsReceived >= 1) { + decompressed.release(); + closed = true; + ConnectError error = ConnectError.unimplemented( + "Client-streaming method received more than one response message"); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + Object decoded; + try { + decoded = codec.decode(decompressed, callStart.methodDefinition().responseType()); + } catch (IOException e) { + closed = true; + ConnectError error = ConnectError.internal("Deserialization failed: " + e.getMessage()); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } finally { + decompressed.release(); + } + + observer.onResponsePayload(decoded); + responsePayloadsReceived++; + ctx.fireChannelRead(new ConnectPayload(decoded)); + } + + private void handleEndStreamFrame(ChannelHandlerContext ctx, byte flags, ByteBuf payload) { + endStreamReceived = true; + closed = true; + + ByteBuf decompressed = payload; + if ((flags & ConnectEnvelope.FLAG_COMPRESSED) != 0) { + try { + decompressed = responseEncoding.decompress(payload, ctx.alloc()); + } catch (IOException e) { + ConnectError error = ConnectError.internal("Decompression failed: " + e.getMessage()); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } finally { + payload.release(); + } + } + + byte[] jsonBytes = ClientHandlerSupport.toByteArray(decompressed); + decompressed.release(); + + ConnectError error = config.jsonDeserializer().parseEndStreamError(jsonBytes); + Map> trailers = config.jsonDeserializer().parseEndStreamMetadata(jsonBytes); + + if (error != null) { + observer.onCallComplete(error); + ctx.fireChannelRead(new ConnectEndOfStream(trailers, error)); + } else { + if (callStart.methodDefinition().type() == ConnectMethodType.CLIENT_STREAMING + && responsePayloadsReceived == 0) { + ConnectError e = ConnectError.unimplemented( + "Client-streaming method received no response message"); + observer.onCallComplete(e); + ctx.fireChannelRead(new ConnectEndOfStream(trailers, e)); + return; + } + + ctx.fireChannelRead(new ConnectEndOfStream(trailers, null)); + observer.onCallComplete(null); + } + } + + private void handleLastHttpContent(ChannelHandlerContext ctx) { + if (!endStreamReceived && !closed) { + closed = true; + ConnectError error = ConnectError.internal("Truncated stream"); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + closeDecoder(); + if (!closed) { + closed = true; + ConnectError error = ConnectError.canceled("Connection reset"); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + } + ctx.fireChannelInactive(); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + closeDecoder(); + if (!closed) { + closed = true; + ConnectError error = ConnectError.canceled("Connection reset"); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + } + } + + private void closeDecoder() { + if (decoder != null) { + decoder.close(); + decoder = null; + } + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryGetRequestClientHandler.java b/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryGetRequestClientHandler.java new file mode 100644 index 0000000..cc3c296 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryGetRequestClientHandler.java @@ -0,0 +1,147 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.ReferenceCountUtil; +import io.suboptimal.connectjava.api.ConnectEndOfStream; +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectPayload; +import io.suboptimal.connectjava.codec.ConnectCodec; +import io.suboptimal.connectjava.compression.ConnectCompression; +import io.suboptimal.connectjava.protocol.ConnectCompressionNegotiation; +import io.suboptimal.connectjava.protocol.ConnectProtocolHttpHeaders; +import io.suboptimal.connectjava.protocol.ConnectProtocolVersion; +import org.jspecify.annotations.Nullable; + +import java.io.IOException; +import java.util.Base64; + +class UnaryGetRequestClientHandler extends ChannelOutboundHandlerAdapter { + private final ConnectClientCallStart callStart; + private final ConnectClientProtocolConfig config; + private final ConnectClientCallObserver observer; + private final ConnectCodec codec; + private final ConnectCompression requestEncoding; + + private byte @Nullable [] payloadBytes; + private State state = State.IDLE; + + private enum State { IDLE, WAITING_PAYLOAD, WAITING_EOS, TERMINATED } + + UnaryGetRequestClientHandler(ConnectClientCallStart callStart, + ConnectClientProtocolConfig config, + ConnectClientCallObserver observer) { + this.callStart = callStart; + this.config = config; + this.observer = observer; + this.codec = ClientHandlerSupport.selectRequestCodec(config, callStart.codecName()); + this.requestEncoding = ClientHandlerSupport.selectRequestEncoding(config, callStart.requestHeaders()); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (state == State.TERMINATED) { + promise.tryFailure(ConnectCallTerminatedException.INSTANCE); + ReferenceCountUtil.release(msg); + return; + } + + switch (msg) { + case ConnectClientCallStart ignored when state == State.IDLE -> { + state = State.WAITING_PAYLOAD; + promise.setSuccess(); + } + case ConnectPayload data when state == State.WAITING_PAYLOAD -> { + try { + ByteBuf encoded = codec.encode(data.data(), ctx.alloc()); + if (!requestEncoding.isIdentity()) { + ByteBuf compressed; + try { + compressed = requestEncoding.compress(encoded, ctx.alloc()); + } finally { + encoded.release(); + } + payloadBytes = ClientHandlerSupport.toByteArray(compressed); + compressed.release(); + } else { + payloadBytes = ClientHandlerSupport.toByteArray(encoded); + encoded.release(); + } + } catch (IOException e) { + promise.tryFailure(e); + state = State.TERMINATED; + return; + } + observer.onRequestPayload(data.data()); + state = State.WAITING_EOS; + promise.setSuccess(); + } + case ConnectEndOfStream ignored when state == State.WAITING_EOS -> { + byte[] bytes = payloadBytes != null ? payloadBytes : new byte[0]; + payloadBytes = null; + + String encoded = Base64.getUrlEncoder().withoutPadding().encodeToString(bytes); + String codecName = codec.name(); + + StringBuilder uriBuilder = new StringBuilder(); + uriBuilder.append('/').append(callStart.serviceDefinition().serviceName()) + .append('/').append(callStart.methodDefinition().methodName()) + .append("?message=").append(encoded) + .append("&encoding=").append(codecName) + .append("&base64=1") + .append("&connect=").append(ConnectProtocolVersion.QUERY_VERSION); + + if (!requestEncoding.isIdentity()) { + uriBuilder.append("&compression=").append(requestEncoding.name()); + } + + FullHttpRequest request = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, uriBuilder.toString(), Unpooled.EMPTY_BUFFER); + + if (callStart.timeoutMs() != null) { + request.headers().set(ConnectProtocolHttpHeaders.CONNECT_TIMEOUT_MS, callStart.timeoutMs()); + } + + request.headers().set(HttpHeaderNames.ACCEPT_ENCODING, + ConnectCompressionNegotiation.formatSupportedEncodings(config.compressionRegistry())); + + ClientHandlerSupport.copyUserHeadersForUnaryCall(callStart.requestHeaders(), request.headers()); + + ctx.pipeline().addAfter( + ConnectClientPipeline.UNARY_GET_HANDLER, + ConnectClientPipeline.UNARY_RESPONSE_HANDLER, + new UnaryResponseClientHandler(callStart, config, observer)); + + observer.onRequestFinished(); + state = State.TERMINATED; + ctx.write(request, promise); + } + default -> { + promise.tryFailure(ConnectCallTerminatedException.INSTANCE); + ReferenceCountUtil.release(msg); + } + } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + // Once the request is sent (TERMINATED) the response handler owns the terminal + // callback. If the channel is torn down before that — i.e. this outbound-only + // handler is removed while still pending — it must deliver onCallComplete itself, + // since no response handler was installed yet. + if (state != State.TERMINATED) { + state = State.TERMINATED; + observer.onCallComplete(ConnectError.canceled("Connection reset")); + } + } + +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryPostRequestClientHandler.java b/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryPostRequestClientHandler.java new file mode 100644 index 0000000..0788ff8 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryPostRequestClientHandler.java @@ -0,0 +1,145 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.ReferenceCountUtil; +import io.suboptimal.connectjava.api.ConnectEndOfStream; +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectPayload; +import io.suboptimal.connectjava.codec.ConnectCodec; +import io.suboptimal.connectjava.compression.ConnectCompression; +import io.suboptimal.connectjava.protocol.ConnectCompressionNegotiation; +import io.suboptimal.connectjava.protocol.ConnectMediaType; +import io.suboptimal.connectjava.protocol.ConnectProtocolHttpHeaders; +import io.suboptimal.connectjava.protocol.ConnectProtocolVersion; +import org.jspecify.annotations.Nullable; + +import java.io.IOException; + +class UnaryPostRequestClientHandler extends ChannelOutboundHandlerAdapter { + private final ConnectClientCallStart callStart; + private final ConnectClientProtocolConfig config; + private final ConnectClientCallObserver observer; + private final ConnectCodec codec; + private final ConnectCompression requestEncoding; + + @Nullable + private ByteBuf payloadBuf; + private State state = State.IDLE; + + private enum State { IDLE, WAITING_PAYLOAD, WAITING_EOS, TERMINATED } + + UnaryPostRequestClientHandler(ConnectClientCallStart callStart, + ConnectClientProtocolConfig config, + ConnectClientCallObserver observer) { + this.callStart = callStart; + this.config = config; + this.observer = observer; + this.codec = ClientHandlerSupport.selectRequestCodec(config, callStart.codecName()); + this.requestEncoding = ClientHandlerSupport.selectRequestEncoding(config, callStart.requestHeaders()); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (state == State.TERMINATED) { + promise.tryFailure(ConnectCallTerminatedException.INSTANCE); + ReferenceCountUtil.release(msg); + return; + } + + switch (msg) { + case ConnectClientCallStart ignored when state == State.IDLE -> { + state = State.WAITING_PAYLOAD; + promise.setSuccess(); + } + case ConnectPayload data when state == State.WAITING_PAYLOAD -> { + try { + ByteBuf encoded = codec.encode(data.data(), ctx.alloc()); + if (!requestEncoding.isIdentity()) { + ByteBuf compressed; + try { + compressed = requestEncoding.compress(encoded, ctx.alloc()); + } finally { + encoded.release(); + } + payloadBuf = compressed; + } else { + payloadBuf = encoded; + } + } catch (IOException e) { + promise.tryFailure(e); + state = State.TERMINATED; + return; + } + observer.onRequestPayload(data.data()); + state = State.WAITING_EOS; + promise.setSuccess(); + } + case ConnectEndOfStream ignored when state == State.WAITING_EOS -> { + ByteBuf body = payloadBuf != null ? payloadBuf : Unpooled.EMPTY_BUFFER; + payloadBuf = null; + + String uri = "/" + callStart.serviceDefinition().serviceName() + + "/" + callStart.methodDefinition().methodName(); + + FullHttpRequest request = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.POST, uri, body); + + request.headers() + .set(HttpHeaderNames.CONTENT_TYPE, ConnectMediaType.unaryContentTypeFor(codec.name())) + .set(HttpHeaderNames.CONTENT_LENGTH, body.readableBytes()) + .set(ConnectProtocolHttpHeaders.CONNECT_PROTOCOL_VERSION, ConnectProtocolVersion.HEADER_VERSION); + + if (callStart.timeoutMs() != null) { + request.headers().set(ConnectProtocolHttpHeaders.CONNECT_TIMEOUT_MS, callStart.timeoutMs()); + } + + if (!requestEncoding.isIdentity()) { + request.headers().set(HttpHeaderNames.CONTENT_ENCODING, requestEncoding.name()); + } + + request.headers().set(HttpHeaderNames.ACCEPT_ENCODING, + ConnectCompressionNegotiation.formatSupportedEncodings(config.compressionRegistry())); + + ClientHandlerSupport.copyUserHeadersForUnaryCall(callStart.requestHeaders(), request.headers()); + + ctx.pipeline().addAfter( + ConnectClientPipeline.UNARY_POST_HANDLER, + ConnectClientPipeline.UNARY_RESPONSE_HANDLER, + new UnaryResponseClientHandler(callStart, config, observer)); + + observer.onRequestFinished(); + state = State.TERMINATED; + ctx.write(request, promise); + } + default -> { + promise.tryFailure(ConnectCallTerminatedException.INSTANCE); + ReferenceCountUtil.release(msg); + } + } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + if (payloadBuf != null) { + payloadBuf.release(); + payloadBuf = null; + } + // Once the request is sent (TERMINATED) the response handler owns the terminal + // callback. If the channel is torn down before that — i.e. this outbound-only + // handler is removed while still pending — it must deliver onCallComplete itself, + // since no response handler was installed yet. + if (state != State.TERMINATED) { + state = State.TERMINATED; + observer.onCallComplete(ConnectError.canceled("Connection reset")); + } + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandler.java b/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandler.java new file mode 100644 index 0000000..2acab6c --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandler.java @@ -0,0 +1,186 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.suboptimal.connectjava.api.ConnectClientResponseStart; +import io.suboptimal.connectjava.api.ConnectResponseMeta; +import io.suboptimal.connectjava.api.ConnectEndOfStream; +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectErrorCode; +import io.suboptimal.connectjava.api.ConnectErrorDetail; +import io.suboptimal.connectjava.api.ConnectPayload; +import io.suboptimal.connectjava.codec.ConnectCodec; +import io.suboptimal.connectjava.compression.ConnectCompression; +import io.suboptimal.connectjava.compression.ConnectIdentityCompression; +import io.suboptimal.connectjava.protocol.ConnectCompressionNegotiation; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +class UnaryResponseClientHandler extends SimpleChannelInboundHandler { + private final ConnectClientCallStart callStart; + private final ConnectClientProtocolConfig config; + private final ConnectClientCallObserver observer; + private boolean closed; + + UnaryResponseClientHandler(ConnectClientCallStart callStart, + ConnectClientProtocolConfig config, + ConnectClientCallObserver observer) { + super(true); + this.callStart = callStart; + this.config = config; + this.observer = observer; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse response) { + int statusCode = response.status().code(); + ConnectResponseMeta meta = buildMeta(statusCode, response); + + observer.onResponseHeaders(meta); + + ConnectClientResponseStart responseStart = + new ConnectClientResponseStart(callStart.serviceDefinition(), callStart.methodDefinition(), meta); + + ctx.fireChannelRead(responseStart); + + if (statusCode != 200) { + ConnectError error = parseErrorResponse(ctx, response, statusCode); + closed = true; + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + String codecName = ClientHandlerSupport.codecNameForContentType(response.headers().get(HttpHeaderNames.CONTENT_TYPE)); + ConnectCodec codec = codecName != null ? config.codecRegistry().byName(codecName) : null; + if (codec == null) { + ConnectError error = ConnectError.unknown("Unsupported or missing Content-Type in response"); + closed = true; + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + String requestCodecName = callStart.codecName(); + if (requestCodecName != null && !requestCodecName.equals(codecName)) { + ConnectError error = ConnectError.internal("Response codec '" + codecName + "' does not match request codec '" + requestCodecName + "'"); + closed = true; + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + ConnectCompression decompression = resolveResponseEncoding(response.headers().get(HttpHeaderNames.CONTENT_ENCODING)); + + ByteBuf body = response.content(); + ByteBuf decompressed; + try { + decompressed = ConnectCompressionNegotiation.decompressMessage(ctx.alloc(), body, decompression); + } catch (IOException e) { + ConnectError error = ConnectError.internal("Decompression failed: " + e.getMessage()); + closed = true; + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } + + Object decoded; + try { + decoded = codec.decode(decompressed, callStart.methodDefinition().responseType()); + } catch (IOException e) { + ConnectError error = ConnectError.internal("Deserialization failed: " + e.getMessage()); + closed = true; + observer.onCallComplete(error); + ctx.fireChannelRead(error); + return; + } finally { + decompressed.release(); + } + + closed = true; + observer.onResponsePayload(decoded); + ctx.fireChannelRead(new ConnectPayload(decoded)); + ctx.fireChannelRead(ConnectEndOfStream.INSTANCE); + observer.onCallComplete(null); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + if (!closed) { + closed = true; + ConnectError error = ConnectError.canceled("Connection reset"); + observer.onCallComplete(error); + ctx.fireChannelRead(error); + } + ctx.fireChannelInactive(); + } + + private ConnectError parseErrorResponse(ChannelHandlerContext ctx, FullHttpResponse response, int statusCode) { + ConnectCompression decompression = + resolveResponseEncoding(response.headers().get(HttpHeaderNames.CONTENT_ENCODING)); + + byte[] body; + try { + ByteBuf decompressed = ConnectCompressionNegotiation.decompressMessage( + ctx.alloc(), response.content(), decompression); + try { + body = ClientHandlerSupport.toByteArray(decompressed); + } finally { + decompressed.release(); + } + } catch (IOException e) { + // Body cannot be decoded; fall back to the HTTP-status mapping below. + body = new byte[0]; + } + + ConnectErrorBody parsed = body.length > 0 + ? config.jsonDeserializer().parseErrorBody(body) : null; + + ConnectErrorCode code = null; + if (parsed != null && parsed.codeName() != null) { + code = ClientHandlerSupport.findErrorCodeByWireName(parsed.codeName()); + } + if (code == null) { + code = ClientHandlerSupport.httpStatusToErrorCode(statusCode); + } + + String message = (parsed != null && parsed.message() != null) + ? parsed.message() : response.status().reasonPhrase(); + + java.util.List details = parsed != null ? parsed.details() : java.util.List.of(); + return new ConnectError(code, message, details); + } + + private ConnectCompression resolveResponseEncoding(String encodingHeader) { + String name = ConnectCompressionNegotiation.compressionNameFor(encodingHeader); + if (name == null) { + return ConnectIdentityCompression.INSTANCE; + } + ConnectCompression c = config.compressionRegistry().resolve(name); + return c != null ? c : ConnectIdentityCompression.INSTANCE; + } + + private static ConnectResponseMeta buildMeta(int statusCode, FullHttpResponse response) { + Map> all = new LinkedHashMap<>(); + all.putAll(ClientHandlerSupport.toHeaderMap(response.headers())); + all.putAll(ClientHandlerSupport.toHeaderMap(response.trailingHeaders())); + + Map> headers = new LinkedHashMap<>(); + Map> trailers = new LinkedHashMap<>(); + for (Map.Entry> entry : all.entrySet()) { + String name = entry.getKey(); + if (name.startsWith("trailer-")) { + trailers.put(name.substring("trailer-".length()), entry.getValue()); + } else { + headers.put(name, entry.getValue()); + } + } + return new ConnectResponseMeta(statusCode, headers, trailers); + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/package-info.java b/src/main/java/io/suboptimal/connectjava/protocol/client/package-info.java new file mode 100644 index 0000000..faa3d3a --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/package-info.java @@ -0,0 +1,4 @@ +@NullMarked +package io.suboptimal.connectjava.protocol.client; + +import org.jspecify.annotations.NullMarked; diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectCorsParameters.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectCorsParameters.java similarity index 97% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectCorsParameters.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectCorsParameters.java index b469fdb..228f63a 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectCorsParameters.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectCorsParameters.java @@ -1,9 +1,9 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import java.util.Set; /** - * CORS configuration for {@link ConnectProtocol}. + * CORS configuration for {@link ConnectServerProtocol}. * *

Use {@link #disabled()} when CORS is not needed, {@link #defaultsForAnyOrigin()} for * wildcard-origin CORS, or {@link #defaultsForOrigins(Set)} for exact-origin allowlists. diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectEndStreamMeta.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamMeta.java similarity index 97% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectEndStreamMeta.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamMeta.java index c28a954..12da294 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectEndStreamMeta.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamMeta.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import java.util.ArrayList; import java.util.LinkedHashMap; diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectEndStreamResponse.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamResponse.java similarity index 90% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectEndStreamResponse.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamResponse.java index 53cbd02..3ebf971 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectEndStreamResponse.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamResponse.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.suboptimal.connectjava.api.ConnectError; import org.jspecify.annotations.Nullable; diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectJsonSerializer.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectJsonSerializer.java similarity index 91% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectJsonSerializer.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectJsonSerializer.java index 9d38e4b..30bc688 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectJsonSerializer.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectJsonSerializer.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.suboptimal.connectjava.api.ConnectError; @@ -16,7 +16,7 @@ * *

The default implementation is {@link ConnectStringBuilderJsonSerializer#INSTANCE}. * A custom implementation (e.g. Jackson-based) can be supplied via - * {@link ConnectProtocolConfig.Builder#jsonSerializer(ConnectJsonSerializer)}. + * {@link ConnectServerProtocolConfig.Builder#jsonSerializer(ConnectJsonSerializer)}. */ public interface ConnectJsonSerializer { /** diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectMetaBuilder.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectMetaBuilder.java similarity index 94% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectMetaBuilder.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectMetaBuilder.java index 6aa596f..18abd02 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectMetaBuilder.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectMetaBuilder.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.handler.codec.http.HttpHeaders; import io.suboptimal.connectjava.api.ConnectRequestMeta; diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectRoute.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectRoute.java similarity index 94% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectRoute.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectRoute.java index e9acc19..3b156c9 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectRoute.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectRoute.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import org.jspecify.annotations.Nullable; diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectCallObserver.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerCallObserver.java similarity index 85% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectCallObserver.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerCallObserver.java index 7ede066..0ede55e 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectCallObserver.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerCallObserver.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.suboptimal.connectjava.api.ConnectEndOfStream; import io.suboptimal.connectjava.api.ConnectError; @@ -15,8 +15,8 @@ * interceptors see the final metadata state last. Observer exceptions are not swallowed: * they propagate through the Netty pipeline like other user-code failures. */ -public interface ConnectCallObserver { - ConnectCallObserver NOOP = new NoOpConnectCallObserver(); +public interface ConnectServerCallObserver { + ConnectServerCallObserver NOOP = new NoOpConnectCallObserver(); // Inbound (FIFO) @@ -25,7 +25,7 @@ public interface ConnectCallObserver { * {@link ConnectPayload} is forwarded to the terminal handler. * *

Invoked in interceptor registration order (FIFO). Not called if the call was rejected - * before {@link ConnectInterceptor#interceptCall} returned. + * before {@link ConnectServerInterceptor#interceptCall} returned. */ default void onRequestPayload(Object payload) {} @@ -34,7 +34,7 @@ default void onRequestPayload(Object payload) {} * {@link ConnectEndOfStream} is forwarded to the terminal handler. * *

Invoked in interceptor registration order (FIFO). Not called if the call was rejected - * before {@link ConnectInterceptor#interceptCall} returned. + * before {@link ConnectServerInterceptor#interceptCall} returned. */ default void onRequestFinished() {} @@ -65,5 +65,5 @@ default void onCallComplete(@Nullable ConnectError error) {} /** * No-op implementation of the call observer. */ - record NoOpConnectCallObserver() implements ConnectCallObserver {} + record NoOpConnectCallObserver() implements ConnectServerCallObserver {} } diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectChannelConfigurer.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerChannelConfigurer.java similarity index 78% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectChannelConfigurer.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerChannelConfigurer.java index 7524468..68e67f9 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectChannelConfigurer.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerChannelConfigurer.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.channel.Channel; import io.netty.handler.codec.http.HttpMethod; @@ -7,15 +7,15 @@ import io.netty.handler.codec.http.cors.CorsHandler; import io.netty.handler.codec.http2.Http2StreamFrameToHttpObjectCodec; -public class ConnectChannelConfigurer { +public class ConnectServerChannelConfigurer { private final ConnectTransport transport; - private final ConnectProtocolConfig config; - private final RoutingHandler routingHandler; + private final ConnectServerProtocolConfig config; + private final RoutingServerHandler routingHandler; - ConnectChannelConfigurer(ConnectTransport transport, ConnectProtocolConfig config) { + ConnectServerChannelConfigurer(ConnectTransport transport, ConnectServerProtocolConfig config) { this.transport = transport; this.config = config; - this.routingHandler = new RoutingHandler(transport, config); + this.routingHandler = new RoutingServerHandler(transport, config); } public void configure(Channel channel) { @@ -23,10 +23,10 @@ public void configure(Channel channel) { channel.pipeline().addLast(new Http2StreamFrameToHttpObjectCodec(true)); } if (config.parameters().cors().enabled()) { - channel.pipeline().addLast(ConnectPipeline.CORS_HANDLER, buildCorsHandler()); + channel.pipeline().addLast(ConnectServerPipeline.CORS_HANDLER, buildCorsHandler()); } channel.pipeline() - .addLast(ConnectPipeline.ROUTING_HANDLER, routingHandler) + .addLast(ConnectServerPipeline.ROUTING_HANDLER, routingHandler) .addLast(config.callHandlerFactory().create()); } diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectInterceptor.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerInterceptor.java similarity index 75% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectInterceptor.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerInterceptor.java index 1e103a7..3eb244d 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectInterceptor.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerInterceptor.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.suboptimal.connectjava.api.ConnectCallExchange; import io.suboptimal.connectjava.api.ConnectError; @@ -9,22 +9,22 @@ *

Registered interceptors are invoked in registration order after Connect request metadata is * decoded and after basic Connect protocol validation (version header, compression negotiation). * Each interceptor either continues the call by returning {@link #continueCall()} / - * {@link #continueWith(ConnectCallObserver)}, or rejects it with a Connect-native error by + * {@link #continueWith(ConnectServerCallObserver)}, or rejects it with a Connect-native error by * returning {@link #reject(ConnectError)}. * *

If the interceptor returns {@link Decision.Continue}, the pipeline guarantees that the - * supplied {@link ConnectCallObserver} receives exactly one {@link ConnectCallObserver#onCallComplete} + * supplied {@link ConnectServerCallObserver} receives exactly one {@link ConnectServerCallObserver#onCallComplete} * callback, regardless of whether the call succeeds, fails, or is cancelled by the client. * *

If the interceptor returns {@link Decision.Reject}, the call is stopped at the Connect layer; * the service is never invoked and no observer callbacks are delivered to that interceptor. */ -public interface ConnectInterceptor { +public interface ConnectServerInterceptor { /** * Called once per Connect RPC call after request metadata has been decoded and basic protocol * validation (version header, compression negotiation) has passed. * - *

Return {@link #continueCall()} or {@link #continueWith(ConnectCallObserver)} to let the + *

Return {@link #continueCall()} or {@link #continueWith(ConnectServerCallObserver)} to let the * call proceed, or {@link #reject(ConnectError)} to abort it with a Connect-native error. * * @param exchange call metadata and mutable response builders for this RPC @@ -36,15 +36,15 @@ public interface ConnectInterceptor { * Returns a {@link Decision} that continues the call without attaching an observer. */ static Decision continueCall() { - return new Decision.Continue(ConnectCallObserver.NOOP); + return new Decision.Continue(ConnectServerCallObserver.NOOP); } /** * Returns a {@link Decision} that continues the call and attaches {@code observer} to receive * lifecycle callbacks. The observer is guaranteed to receive exactly one - * {@link ConnectCallObserver#onCallComplete} for the lifetime of the call. + * {@link ConnectServerCallObserver#onCallComplete} for the lifetime of the call. */ - static Decision continueWith(ConnectCallObserver observer) { + static Decision continueWith(ConnectServerCallObserver observer) { return new Decision.Continue(observer); } @@ -56,26 +56,26 @@ static Decision continueWith(ConnectCallObserver observer) { * {@link Decision.Continue} still receive their terminal callbacks. */ static Decision reject(ConnectError error) { - return new Decision.Reject(ConnectCallObserver.NOOP, error); + return new Decision.Reject(ConnectServerCallObserver.NOOP, error); } /** - * Result of {@link ConnectInterceptor#interceptCall(ConnectCallExchange)}. + * Result of {@link ConnectServerInterceptor#interceptCall(ConnectCallExchange)}. * - *

Use the static factory methods on {@link ConnectInterceptor} to create instances. + *

Use the static factory methods on {@link ConnectServerInterceptor} to create instances. * The {@link #observer()} accessor returns the effective observer associated with this decision; * for {@link Reject} it is the composite of all observers from prior {@link Continue} decisions. */ sealed interface Decision { /** Returns the observer associated with this decision. */ - ConnectCallObserver observer(); + ConnectServerCallObserver observer(); /** * Continue the call and attach a per-call observer. * * @param observer observer that receives Connect call lifecycle callbacks */ - record Continue(ConnectCallObserver observer) implements Decision {} + record Continue(ConnectServerCallObserver observer) implements Decision {} /** * Reject the call with a Connect-native error. @@ -87,7 +87,7 @@ record Continue(ConnectCallObserver observer) implements Decision {} * @param observer composite observer of all prior {@link Continue} decisions * @param error error to serialize to the Connect response */ - record Reject(ConnectCallObserver observer, ConnectError error) implements Decision {} + record Reject(ConnectServerCallObserver observer, ConnectError error) implements Decision {} } } diff --git a/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerInterceptorPipeline.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerInterceptorPipeline.java new file mode 100644 index 0000000..0cbf1dc --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerInterceptorPipeline.java @@ -0,0 +1,93 @@ +package io.suboptimal.connectjava.protocol.server; + +import io.suboptimal.connectjava.api.ConnectCallExchange; +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectResponseHeadersBuilder; +import io.suboptimal.connectjava.api.ConnectResponseTrailersBuilder; +import org.jspecify.annotations.Nullable; + +import java.util.ArrayList; +import java.util.List; + +/** + * Builds the per-call observer chain for registered Connect interceptors. + */ +final class ConnectServerInterceptorPipeline { + static final ConnectServerInterceptorPipeline EMPTY = new ConnectServerInterceptorPipeline(List.of()); + + private final List interceptors; + + ConnectServerInterceptorPipeline(List interceptors) { + this.interceptors = List.copyOf(interceptors); + } + + /** + * Runs all registered interceptors in order, building a composite observer from the + * {@link ConnectServerInterceptor.Decision.Continue} results. + * + *

If any interceptor returns {@link ConnectServerInterceptor.Decision.Reject}, iteration stops + * immediately and a {@link ConnectServerInterceptor.Decision.Reject} carrying the composite of all + * prior {@link ConnectServerInterceptor.Decision.Continue} observers (and the rejection error) is + * returned. {@link ConnectServerCallObserver#NOOP} observers are filtered out of the composite. + */ + ConnectServerInterceptor.Decision interceptCall(ConnectCallExchange exchange) { + if (interceptors.isEmpty()) { + return ConnectServerInterceptor.continueCall(); + } + + List observers = new ArrayList<>(interceptors.size()); + for (ConnectServerInterceptor interceptor : interceptors) { + switch (interceptor.interceptCall(exchange)) { + case ConnectServerInterceptor.Decision.Continue(ConnectServerCallObserver observer) -> observers.add(observer); + case ConnectServerInterceptor.Decision.Reject(ConnectServerCallObserver ignore, ConnectError error) -> { + return new ConnectServerInterceptor.Decision.Reject(composite(observers), error); + } + } + } + return ConnectServerInterceptor.continueWith(composite(observers)); + } + + private static ConnectServerCallObserver composite(List observers) { + List filtered = observers.stream() + .filter(o -> o != ConnectServerCallObserver.NOOP) + .toList(); + + if (filtered.isEmpty()) { + return ConnectServerCallObserver.NOOP; + } + if (filtered.size() == 1) { + return filtered.getFirst(); + } + return new CompositeConnectCallObserver(filtered); + } + + private record CompositeConnectCallObserver(List observers) + implements ConnectServerCallObserver + { + @Override + public void onResponseHeaders(ConnectResponseHeadersBuilder headers) { + for (int i = observers.size() - 1; i >= 0; i--) { + observers.get(i).onResponseHeaders(headers); + } + } + + @Override + public void onResponsePayload(Object payload) { + observers.forEach(observer -> observer.onResponsePayload(payload)); + } + + @Override + public void onResponseTrailers(ConnectResponseTrailersBuilder trailers, @Nullable ConnectError error) { + for (int i = observers.size() - 1; i >= 0; i--) { + observers.get(i).onResponseTrailers(trailers, error); + } + } + + @Override + public void onCallComplete(@Nullable ConnectError error) { + for (int i = observers.size() - 1; i >= 0; i--) { + observers.get(i).onCallComplete(error); + } + } + } +} diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectPipeline.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerPipeline.java similarity index 85% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectPipeline.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerPipeline.java index ff8fc90..1f44550 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectPipeline.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerPipeline.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.cors.CorsHandler; @@ -11,15 +11,15 @@ * Exposed so that {@link io.netty.channel.ChannelPipeline#addAfter} (and friends) * can reference handlers by name without holding direct references. */ -public final class ConnectPipeline { - /** Name of the {@link CorsHandler} optionally installed by {@link ConnectChannelConfigurer}. */ +public final class ConnectServerPipeline { + /** Name of the {@link CorsHandler} optionally installed by {@link ConnectServerChannelConfigurer}. */ public static final String CORS_HANDLER = "connectCors"; - /** Name of the {@link RoutingHandler} installed by {@link ConnectChannelConfigurer}. */ + /** Name of the {@link RoutingServerHandler} installed by {@link ConnectServerChannelConfigurer}. */ public static final String ROUTING_HANDLER = "connectRouting"; /** - * Name of the standard Netty {@link HttpObjectAggregator} that {@link RoutingHandler} + * Name of the standard Netty {@link HttpObjectAggregator} that {@link RoutingServerHandler} * installs ahead of unary handlers. Connect itself does not * provide a custom aggregator — only this name slot in the pipeline. */ @@ -52,8 +52,8 @@ public final class ConnectPipeline { */ public static final String UNARY_RESPONSE_HANDLER = "connectUnaryResponse"; - /** Name of the {@link StreamingHandler} installed on the streaming path. */ + /** Name of the {@link StreamingServerHandler} installed on the streaming path. */ public static final String STREAMING_HANDLER = "connectStreaming"; - private ConnectPipeline() {} + private ConnectServerPipeline() {} } diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocol.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerProtocol.java similarity index 63% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocol.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerProtocol.java index 4471e2e..4c04214 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocol.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerProtocol.java @@ -1,4 +1,6 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; + +import io.suboptimal.connectjava.protocol.ConnectCallHandlerFactory; /** * Buf Connect protocol over HTTP/1.1 and HTTP/2. @@ -28,50 +30,50 @@ * protocol.http2().configure(channel); * } * - * @see ConnectProtocolConfig + * @see ConnectServerProtocolConfig */ -public class ConnectProtocol { - private final ConnectChannelConfigurer http1Configurer; - private final ConnectChannelConfigurer http2Configurer; +public class ConnectServerProtocol { + private final ConnectServerChannelConfigurer http1Configurer; + private final ConnectServerChannelConfigurer http2Configurer; /** - * Creates a new {@link ConnectProtocol} from the given configuration. + * Creates a new {@link ConnectServerProtocol} from the given configuration. * - *

Use {@link ConnectProtocolConfig#builder(java.util.Map, ConnectCallHandlerFactory, - * ConnectProtocolParameters, io.suboptimal.connectjava.codec.ConnectCodecRegistry)} + *

Use {@link ConnectServerProtocolConfig#builder(java.util.Map, ConnectCallHandlerFactory, + * ConnectServerProtocolParameters, io.suboptimal.connectjava.codec.ConnectCodecRegistry)} * to build a configuration with the desired options. ConnectCompression, JSON serializer, * and interceptors carry defaults and only need to be specified when overriding * those defaults. * * @param config protocol configuration; must not be {@code null} - * @see ConnectProtocolConfig + * @see ConnectServerProtocolConfig */ - public ConnectProtocol(ConnectProtocolConfig config) { - this.http1Configurer = new ConnectChannelConfigurer(ConnectTransport.HTTP_1_1, config); - this.http2Configurer = new ConnectChannelConfigurer(ConnectTransport.HTTP_2, config); + public ConnectServerProtocol(ConnectServerProtocolConfig config) { + this.http1Configurer = new ConnectServerChannelConfigurer(ConnectTransport.HTTP_1_1, config); + this.http2Configurer = new ConnectServerChannelConfigurer(ConnectTransport.HTTP_2, config); } /** * Returns the configurer for HTTP/1.1 channels. * - *

The channel passed to {@link ConnectChannelConfigurer#configure(io.netty.channel.Channel)} + *

The channel passed to {@link ConnectServerChannelConfigurer#configure(io.netty.channel.Channel)} * must already have an HTTP/1.1 codec installed (e.g. {@code HttpServerCodec}). * The configurer appends an optional CORS handler, the routing handler, and the terminal * call handler to the pipeline. */ - public ConnectChannelConfigurer http1() { + public ConnectServerChannelConfigurer http1() { return http1Configurer; } /** * Returns the configurer for HTTP/2 stream channels. * - *

The channel passed to {@link ConnectChannelConfigurer#configure(io.netty.channel.Channel)} + *

The channel passed to {@link ConnectServerChannelConfigurer#configure(io.netty.channel.Channel)} * must be a child channel produced by {@code Http2MultiplexHandler}. The configurer prepends * {@code Http2StreamFrameToHttpObjectCodec}, then adds an optional CORS handler, the routing * handler, and the terminal call handler. */ - public ConnectChannelConfigurer http2() { + public ConnectServerChannelConfigurer http2() { return http2Configurer; } } diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolConfig.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerProtocolConfig.java similarity index 81% rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolConfig.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerProtocolConfig.java index 1b01753..1724208 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolConfig.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerProtocolConfig.java @@ -1,16 +1,17 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.suboptimal.connectjava.codec.ConnectCodecRegistry; import io.suboptimal.connectjava.compression.ConnectCompressionRegistry; import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import io.suboptimal.connectjava.protocol.ConnectCallHandlerFactory; import java.util.List; import java.util.Map; /** - * Immutable configuration for a {@link ConnectProtocol} instance. + * Immutable configuration for a {@link ConnectServerProtocol} instance. * - *

Create an instance via {@link #builder(Map, ConnectCallHandlerFactory, ConnectProtocolParameters, ConnectCodecRegistry)}: + *

Create an instance via {@link #builder(Map, ConnectCallHandlerFactory, ConnectServerProtocolParameters, ConnectCodecRegistry)}: *

{@code
  * ConnectProtocolConfig config = ConnectProtocolConfig
  *     .builder(services, callHandlerFactory,
@@ -29,17 +30,17 @@
  * @param jsonSerializer      JSON serializer for Connect error bodies and end-stream envelopes
  * @param interceptors        Connect protocol interceptors, invoked in registration order
  */
-public record ConnectProtocolConfig(
+public record ConnectServerProtocolConfig(
     Map services,
     ConnectCallHandlerFactory callHandlerFactory,
-    ConnectProtocolParameters parameters,
+    ConnectServerProtocolParameters parameters,
     ConnectCodecRegistry codecRegistry,
     ConnectCompressionRegistry compressionRegistry,
     ConnectJsonSerializer jsonSerializer,
-    List interceptors
+    List interceptors
 ) {
     /** Validates and makes all collection components immutable. */
-    public ConnectProtocolConfig {
+    public ConnectServerProtocolConfig {
         services = Map.copyOf(services);
         interceptors = List.copyOf(interceptors);
     }
@@ -55,25 +56,25 @@ public record ConnectProtocolConfig(
      */
     public static Builder builder(Map services,
                                   ConnectCallHandlerFactory callHandlerFactory,
-                                  ConnectProtocolParameters parameters,
+                                  ConnectServerProtocolParameters parameters,
                                   ConnectCodecRegistry codecRegistry) {
         return new Builder(services, callHandlerFactory, parameters, codecRegistry);
     }
 
-    /** Builder for {@link ConnectProtocolConfig}. */
+    /** Builder for {@link ConnectServerProtocolConfig}. */
     public static final class Builder {
 
         private final Map services;
         private final ConnectCallHandlerFactory callHandlerFactory;
-        private final ConnectProtocolParameters parameters;
+        private final ConnectServerProtocolParameters parameters;
         private final ConnectCodecRegistry codecRegistry;
         private ConnectCompressionRegistry compressionRegistry = ConnectCompressionRegistry.standard();
         private ConnectJsonSerializer jsonSerializer = ConnectStringBuilderJsonSerializer.INSTANCE;
-        private List interceptors = List.of();
+        private List interceptors = List.of();
 
         private Builder(Map services,
                         ConnectCallHandlerFactory callHandlerFactory,
-                        ConnectProtocolParameters parameters,
+                        ConnectServerProtocolParameters parameters,
                         ConnectCodecRegistry codecRegistry) {
             this.services = services;
             this.callHandlerFactory = callHandlerFactory;
@@ -103,18 +104,18 @@ public Builder jsonSerializer(ConnectJsonSerializer jsonSerializer) {
          * Sets the Connect protocol interceptors, invoked in registration order.
          * Default: none.
          */
-        public Builder interceptors(List interceptors) {
+        public Builder interceptors(List interceptors) {
             this.interceptors = interceptors;
             return this;
         }
 
         /**
-         * Returns an immutable {@link ConnectProtocolConfig}.
+         * Returns an immutable {@link ConnectServerProtocolConfig}.
          *
          * @throws NullPointerException if any required component is {@code null}
          */
-        public ConnectProtocolConfig build() {
-            return new ConnectProtocolConfig(
+        public ConnectServerProtocolConfig build() {
+            return new ConnectServerProtocolConfig(
                 services, callHandlerFactory, parameters,
                 codecRegistry, compressionRegistry, jsonSerializer, interceptors);
         }
diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolParameters.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerProtocolParameters.java
similarity index 73%
rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolParameters.java
rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerProtocolParameters.java
index 68492d7..87542a3 100644
--- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectProtocolParameters.java
+++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectServerProtocolParameters.java
@@ -1,7 +1,7 @@
-package io.suboptimal.connectjava.protocol;
+package io.suboptimal.connectjava.protocol.server;
 
 /**
- * Configures request size limits and optional CORS policy for {@link ConnectProtocol}.
+ * Configures request size limits and optional CORS policy for {@link ConnectServerProtocol}.
  *
  * @param maxRequestBytes maximum aggregated request body size for unary calls
  * @param maxFrameBytes   maximum payload size of a single server-streaming Connect envelope,
@@ -9,8 +9,8 @@
  *                        length exceeds this limit are rejected with a streaming error
  * @param cors            CORS policy; use {@link ConnectCorsParameters#disabled()} to opt out
  */
-public record ConnectProtocolParameters(int maxRequestBytes, int maxFrameBytes, ConnectCorsParameters cors) {
-    public ConnectProtocolParameters {
+public record ConnectServerProtocolParameters(int maxRequestBytes, int maxFrameBytes, ConnectCorsParameters cors) {
+    public ConnectServerProtocolParameters {
         if (maxRequestBytes <= 0) {
             throw new IllegalArgumentException("maxRequestBytes must be positive");
         }
@@ -20,7 +20,7 @@ public record ConnectProtocolParameters(int maxRequestBytes, int maxFrameBytes,
     }
 
     /** Convenience constructor with CORS disabled. */
-    public ConnectProtocolParameters(int maxRequestBytes, int maxFrameBytes) {
+    public ConnectServerProtocolParameters(int maxRequestBytes, int maxFrameBytes) {
         this(maxRequestBytes, maxFrameBytes, ConnectCorsParameters.disabled());
     }
 }
diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ConnectStringBuilderJsonSerializer.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectStringBuilderJsonSerializer.java
similarity index 98%
rename from src/main/java/io/suboptimal/connectjava/protocol/ConnectStringBuilderJsonSerializer.java
rename to src/main/java/io/suboptimal/connectjava/protocol/server/ConnectStringBuilderJsonSerializer.java
index 8dec792..1389269 100644
--- a/src/main/java/io/suboptimal/connectjava/protocol/ConnectStringBuilderJsonSerializer.java
+++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectStringBuilderJsonSerializer.java
@@ -1,4 +1,4 @@
-package io.suboptimal.connectjava.protocol;
+package io.suboptimal.connectjava.protocol.server;
 
 import io.suboptimal.connectjava.api.ConnectError;
 import io.suboptimal.connectjava.api.ConnectErrorDetail;
diff --git a/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectTransport.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectTransport.java
new file mode 100644
index 0000000..b59efc0
--- /dev/null
+++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ConnectTransport.java
@@ -0,0 +1,10 @@
+package io.suboptimal.connectjava.protocol.server;
+
+/**
+ * HTTP transport version a {@link ConnectServerChannelConfigurer} and its
+ * {@link RoutingServerHandler} are wired for.
+ */
+enum ConnectTransport {
+    HTTP_1_1,
+    HTTP_2
+}
diff --git a/src/main/java/io/suboptimal/connectjava/protocol/HttpResponses.java b/src/main/java/io/suboptimal/connectjava/protocol/server/HttpResponses.java
similarity index 98%
rename from src/main/java/io/suboptimal/connectjava/protocol/HttpResponses.java
rename to src/main/java/io/suboptimal/connectjava/protocol/server/HttpResponses.java
index d867173..c7f2d1b 100644
--- a/src/main/java/io/suboptimal/connectjava/protocol/HttpResponses.java
+++ b/src/main/java/io/suboptimal/connectjava/protocol/server/HttpResponses.java
@@ -1,4 +1,4 @@
-package io.suboptimal.connectjava.protocol;
+package io.suboptimal.connectjava.protocol.server;
 
 import io.netty.buffer.Unpooled;
 import io.netty.handler.codec.http.DefaultFullHttpResponse;
diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ResponseHeadersBuilder.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ResponseHeadersBuilder.java
similarity index 96%
rename from src/main/java/io/suboptimal/connectjava/protocol/ResponseHeadersBuilder.java
rename to src/main/java/io/suboptimal/connectjava/protocol/server/ResponseHeadersBuilder.java
index d95cc62..7d67687 100644
--- a/src/main/java/io/suboptimal/connectjava/protocol/ResponseHeadersBuilder.java
+++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ResponseHeadersBuilder.java
@@ -1,4 +1,4 @@
-package io.suboptimal.connectjava.protocol;
+package io.suboptimal.connectjava.protocol.server;
 
 import io.netty.handler.codec.http.HttpHeaders;
 import io.suboptimal.connectjava.api.ConnectResponseHeadersBuilder;
diff --git a/src/main/java/io/suboptimal/connectjava/protocol/ResponseTrailersBuilder.java b/src/main/java/io/suboptimal/connectjava/protocol/server/ResponseTrailersBuilder.java
similarity index 97%
rename from src/main/java/io/suboptimal/connectjava/protocol/ResponseTrailersBuilder.java
rename to src/main/java/io/suboptimal/connectjava/protocol/server/ResponseTrailersBuilder.java
index e058359..5601ef7 100644
--- a/src/main/java/io/suboptimal/connectjava/protocol/ResponseTrailersBuilder.java
+++ b/src/main/java/io/suboptimal/connectjava/protocol/server/ResponseTrailersBuilder.java
@@ -1,4 +1,4 @@
-package io.suboptimal.connectjava.protocol;
+package io.suboptimal.connectjava.protocol.server;
 
 import io.netty.handler.codec.http.HttpHeaders;
 import io.netty.util.AsciiString;
diff --git a/src/main/java/io/suboptimal/connectjava/protocol/RoutingHandler.java b/src/main/java/io/suboptimal/connectjava/protocol/server/RoutingServerHandler.java
similarity index 78%
rename from src/main/java/io/suboptimal/connectjava/protocol/RoutingHandler.java
rename to src/main/java/io/suboptimal/connectjava/protocol/server/RoutingServerHandler.java
index 9b41335..868fc78 100644
--- a/src/main/java/io/suboptimal/connectjava/protocol/RoutingHandler.java
+++ b/src/main/java/io/suboptimal/connectjava/protocol/server/RoutingServerHandler.java
@@ -1,4 +1,4 @@
-package io.suboptimal.connectjava.protocol;
+package io.suboptimal.connectjava.protocol.server;
 
 import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelHandlerContext;
@@ -27,16 +27,16 @@
  *   
  • Bidirectional streaming over HTTP/1.1 → plain HTTP * {@code 505 HTTP Version Not Supported} with {@code Connection: close}.
  • *
  • {@code GET} on unary calls → installs an {@link HttpObjectAggregator} - * followed by {@link UnaryGetRequestHandler} for side-effect-free calls + * followed by {@link UnaryGetRequestServerHandler} for side-effect-free calls * with query payloads.
  • *
  • {@code GET} on streaming calls → plain HTTP {@code 405 Method Not Allowed} * with {@code Allow: POST}.
  • *
  • {@code POST} with {@code application/connect+proto}, {@code application/connect+json}, or another - * {@code application/connect+*} type → installs {@link StreamingHandler} + * {@code application/connect+*} type → installs {@link StreamingServerHandler} * for streaming procedures, or rejects unary procedures with {@code 415}.
  • *
  • {@code POST} with {@code application/proto}, {@code application/json}, or another * {@code application/*} type → installs an {@link HttpObjectAggregator} - * followed by {@link UnaryPostRequestHandler} for unary procedures, or + * followed by {@link UnaryPostRequestServerHandler} for unary procedures, or * rejects streaming procedures with {@code 415}.
  • *
  • Other methods on known procedures → plain HTTP {@code 405 Method Not Allowed}.
  • *
  • Other {@code POST} content types → plain HTTP {@code 415 Unsupported Media Type}.
  • @@ -44,7 +44,7 @@ * *

    The unary request handlers decode the accepted full request, emit {@link * ConnectCallExchange} and {@link ConnectPayload} into the terminal handler, and install - * {@link UnaryResponseProcessingHandler} for the outbound response leg. + * {@link UnaryResponseProcessingServerHandler} for the outbound response leg. * *

    On a successful match the handler removes itself from the pipeline and re-fires the * original {@link HttpRequest} into the newly configured chain. Plain HTTP error paths release @@ -53,33 +53,33 @@ *

    The {@link ConnectTransport} constructor parameter makes transport-sensitive * gates immutable for each handler instance. Annotated * {@link ChannelHandler.Sharable @Sharable}: one instance is created per transport - * configurer in {@link ConnectChannelConfigurer} and reused across every channel of - * the same {@link ConnectProtocol}. All fields are final and immutable. + * configurer in {@link ConnectServerChannelConfigurer} and reused across every channel of + * the same {@link ConnectServerProtocol}. All fields are final and immutable. */ @ChannelHandler.Sharable -class RoutingHandler extends SimpleChannelInboundHandler { +class RoutingServerHandler extends SimpleChannelInboundHandler { private final ConnectTransport transport; private final Map serviceDefinitions; - private final ConnectProtocolParameters parameters; - private final ConnectProtocolConfig config; - private final ConnectInterceptorPipeline interceptorPipeline; + private final ConnectServerProtocolParameters parameters; + private final ConnectServerProtocolConfig config; + private final ConnectServerInterceptorPipeline interceptorPipeline; - RoutingHandler(ConnectTransport transport, ConnectProtocolConfig config) { + RoutingServerHandler(ConnectTransport transport, ConnectServerProtocolConfig config) { super(false); this.transport = transport; this.serviceDefinitions = config.services(); this.parameters = config.parameters(); this.config = config; this.interceptorPipeline = config.interceptors().isEmpty() - ? ConnectInterceptorPipeline.EMPTY - : new ConnectInterceptorPipeline(config.interceptors()); + ? ConnectServerInterceptorPipeline.EMPTY + : new ConnectServerInterceptorPipeline(config.interceptors()); } @Override protected void channelRead0(ChannelHandlerContext ctx, HttpRequest request) { ChannelPipeline pipeline = ctx.pipeline(); - ConnectCallExchange exchange = buildConnectCallExchange(request); + ConnectCallExchange exchange = buildConnectServerCallExchange(request); if (exchange == null) { ReferenceCountUtil.release(request); ctx.writeAndFlush(HttpResponses.notFound()); @@ -101,13 +101,13 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpRequest request) { return; } pipeline.addAfter( - ConnectPipeline.ROUTING_HANDLER, - ConnectPipeline.AGGREGATOR_HANDLER, + ConnectServerPipeline.ROUTING_HANDLER, + ConnectServerPipeline.AGGREGATOR_HANDLER, new HttpObjectAggregator(parameters.maxRequestBytes())); pipeline.addAfter( - ConnectPipeline.AGGREGATOR_HANDLER, - ConnectPipeline.UNARY_GET_REQUEST_HANDLER, - new UnaryGetRequestHandler( + ConnectServerPipeline.AGGREGATOR_HANDLER, + ConnectServerPipeline.UNARY_GET_REQUEST_HANDLER, + new UnaryGetRequestServerHandler( exchange, config.codecRegistry(), config.compressionRegistry(), parameters.maxRequestBytes(), config.jsonSerializer(), interceptorPipeline)); pipeline.remove(this); @@ -133,9 +133,9 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpRequest request) { return; } pipeline.addAfter( - ConnectPipeline.ROUTING_HANDLER, - ConnectPipeline.STREAMING_HANDLER, - new StreamingHandler( + ConnectServerPipeline.ROUTING_HANDLER, + ConnectServerPipeline.STREAMING_HANDLER, + new StreamingServerHandler( exchange, parameters.maxFrameBytes(), config.codecRegistry(), config.compressionRegistry(), config.jsonSerializer(), interceptorPipeline)); pipeline.remove(this); @@ -147,13 +147,13 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpRequest request) { return; } pipeline.addAfter( - ConnectPipeline.ROUTING_HANDLER, - ConnectPipeline.AGGREGATOR_HANDLER, + ConnectServerPipeline.ROUTING_HANDLER, + ConnectServerPipeline.AGGREGATOR_HANDLER, new HttpObjectAggregator(parameters.maxRequestBytes())); pipeline.addAfter( - ConnectPipeline.AGGREGATOR_HANDLER, - ConnectPipeline.UNARY_POST_REQUEST_HANDLER, - new UnaryPostRequestHandler(exchange, config.codecRegistry(), + ConnectServerPipeline.AGGREGATOR_HANDLER, + ConnectServerPipeline.UNARY_POST_REQUEST_HANDLER, + new UnaryPostRequestServerHandler(exchange, config.codecRegistry(), config.compressionRegistry(), config.jsonSerializer(), interceptorPipeline)); pipeline.remove(this); ctx.fireChannelRead(request); @@ -163,7 +163,7 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpRequest request) { } } - private @Nullable ConnectCallExchange buildConnectCallExchange(HttpRequest request) { + private @Nullable ConnectCallExchange buildConnectServerCallExchange(HttpRequest request) { ConnectRoute route = ConnectRoute.parse(request.uri()); if (route == null) { return null; diff --git a/src/main/java/io/suboptimal/connectjava/protocol/StreamingHandler.java b/src/main/java/io/suboptimal/connectjava/protocol/server/StreamingServerHandler.java similarity index 93% rename from src/main/java/io/suboptimal/connectjava/protocol/StreamingHandler.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/StreamingServerHandler.java index 496de86..637b87e 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/StreamingHandler.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/StreamingServerHandler.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelDuplexHandler; @@ -25,19 +25,24 @@ import io.suboptimal.connectjava.compression.ConnectCompressionRegistry; import io.suboptimal.connectjava.compression.ConnectIdentityCompression; import io.suboptimal.connectjava.model.ConnectMethodType; +import io.suboptimal.connectjava.protocol.ConnectCompressionNegotiation; +import io.suboptimal.connectjava.protocol.ConnectEnvelope; +import io.suboptimal.connectjava.protocol.ConnectMediaType; +import io.suboptimal.connectjava.protocol.ConnectProtocolVersion; +import io.suboptimal.connectjava.protocol.client.ConnectCallTerminatedException; import org.jspecify.annotations.Nullable; import java.io.IOException; -class StreamingHandler extends ChannelDuplexHandler { +class StreamingServerHandler extends ChannelDuplexHandler { private final ConnectCallExchange exchange; private final int maxFrameBytes; private final ConnectCodecRegistry codecRegistry; private final ConnectCompressionRegistry compressionRegistry; private final ConnectJsonSerializer jsonSerializer; - private final ConnectInterceptorPipeline interceptorPipeline; + private final ConnectServerInterceptorPipeline interceptorPipeline; - private ConnectCallObserver observer = ConnectCallObserver.NOOP; + private ConnectServerCallObserver observer = ConnectServerCallObserver.NOOP; private @Nullable ConnectCodec codec; private ConnectEnvelope.@Nullable Decoder decoder; private ConnectCompression requestEncoding = ConnectIdentityCompression.INSTANCE; @@ -48,13 +53,13 @@ class StreamingHandler extends ChannelDuplexHandler { private int requestPayloadsForwarded; private int responsePayloadsWritten; - StreamingHandler( + StreamingServerHandler( ConnectCallExchange exchange, int maxFrameBytes, ConnectCodecRegistry codecRegistry, ConnectCompressionRegistry compressionRegistry, ConnectJsonSerializer jsonSerializer, - ConnectInterceptorPipeline interceptorPipeline) + ConnectServerInterceptorPipeline interceptorPipeline) { this.exchange = exchange; this.maxFrameBytes = maxFrameBytes; @@ -129,10 +134,10 @@ private void handleRequest(ChannelHandlerContext ctx, HttpRequest request) { responseEncoding = ConnectCompressionNegotiation.selectResponseEncoding( requestEncoding, request.headers().get("connect-accept-encoding"), compressionRegistry); - ConnectInterceptor.Decision decision = interceptorPipeline.interceptCall(exchange); + ConnectServerInterceptor.Decision decision = interceptorPipeline.interceptCall(exchange); switch (decision) { - case ConnectInterceptor.Decision.Continue(ConnectCallObserver o) -> observer = o; - case ConnectInterceptor.Decision.Reject(ConnectCallObserver o, ConnectError error) -> { + case ConnectServerInterceptor.Decision.Continue(ConnectServerCallObserver o) -> observer = o; + case ConnectServerInterceptor.Decision.Reject(ConnectServerCallObserver o, ConnectError error) -> { observer = o; writeStreamingError(ctx, error); return; diff --git a/src/main/java/io/suboptimal/connectjava/protocol/UnaryGetRequestHandler.java b/src/main/java/io/suboptimal/connectjava/protocol/server/UnaryGetRequestServerHandler.java similarity index 88% rename from src/main/java/io/suboptimal/connectjava/protocol/UnaryGetRequestHandler.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/UnaryGetRequestServerHandler.java index 042dc3c..1514220 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/UnaryGetRequestHandler.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/UnaryGetRequestServerHandler.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -17,6 +17,9 @@ import io.suboptimal.connectjava.compression.ConnectCompression; import io.suboptimal.connectjava.compression.ConnectCompressionRegistry; import io.suboptimal.connectjava.compression.ConnectIdentityCompression; +import io.suboptimal.connectjava.protocol.ConnectCompressionNegotiation; +import io.suboptimal.connectjava.protocol.ConnectMediaType; +import io.suboptimal.connectjava.protocol.ConnectProtocolVersion; import org.jspecify.annotations.Nullable; import java.io.IOException; @@ -31,24 +34,24 @@ * target method is a side-effect-free unary method, selects the payload codec from the * {@code encoding} query parameter, decodes the query {@code message}, and fires one * {@link ConnectCallExchange} followed by one {@link ConnectPayload}. After successful decoding it - * installs {@link UnaryResponseProcessingHandler} to own the outbound response + * installs {@link UnaryResponseProcessingServerHandler} to own the outbound response * state machine. */ -class UnaryGetRequestHandler extends SimpleChannelInboundHandler { +class UnaryGetRequestServerHandler extends SimpleChannelInboundHandler { private final ConnectCallExchange exchange; private final ConnectCodecRegistry codecRegistry; private final ConnectCompressionRegistry compressionRegistry; private final int maxRequestBytes; private final ConnectJsonSerializer jsonSerializer; - private final ConnectInterceptorPipeline interceptorPipeline; + private final ConnectServerInterceptorPipeline interceptorPipeline; - UnaryGetRequestHandler( + UnaryGetRequestServerHandler( ConnectCallExchange exchange, ConnectCodecRegistry codecRegistry, ConnectCompressionRegistry compressionRegistry, int maxRequestBytes, ConnectJsonSerializer jsonSerializer, - ConnectInterceptorPipeline interceptorPipeline) + ConnectServerInterceptorPipeline interceptorPipeline) { this.exchange = exchange; this.codecRegistry = codecRegistry; @@ -140,11 +143,11 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) decompressed.release(); } - ConnectInterceptor.Decision decision = interceptorPipeline.interceptCall(exchange); - ConnectCallObserver observer; + ConnectServerInterceptor.Decision decision = interceptorPipeline.interceptCall(exchange); + ConnectServerCallObserver observer; switch (decision) { - case ConnectInterceptor.Decision.Continue(ConnectCallObserver o) -> observer = o; - case ConnectInterceptor.Decision.Reject(ConnectCallObserver o, ConnectError error) -> { + case ConnectServerInterceptor.Decision.Continue(ConnectServerCallObserver o) -> observer = o; + case ConnectServerInterceptor.Decision.Reject(ConnectServerCallObserver o, ConnectError error) -> { observer = o; writeRejectedCall(ctx, observer, error); return; @@ -154,8 +157,8 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) ConnectCompression responseCompression = ConnectCompressionNegotiation.selectResponseEncoding( requestEncoding, request.headers().get(HttpHeaderNames.ACCEPT_ENCODING), compressionRegistry); - ctx.pipeline().addAfter(ConnectPipeline.UNARY_GET_REQUEST_HANDLER, ConnectPipeline.UNARY_RESPONSE_HANDLER, - new UnaryResponseProcessingHandler(exchange, selectedCodec, + ctx.pipeline().addAfter(ConnectServerPipeline.UNARY_GET_REQUEST_HANDLER, ConnectServerPipeline.UNARY_RESPONSE_HANDLER, + new UnaryResponseProcessingServerHandler(exchange, selectedCodec, responseCompression, true, observer, jsonSerializer)); ctx.fireChannelRead(exchange); @@ -167,7 +170,7 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) private void writeRejectedCall( ChannelHandlerContext ctx, - ConnectCallObserver observer, + ConnectServerCallObserver observer, ConnectError error) { var response = HttpResponses.protocolError(error, jsonSerializer); diff --git a/src/main/java/io/suboptimal/connectjava/protocol/UnaryPostRequestHandler.java b/src/main/java/io/suboptimal/connectjava/protocol/server/UnaryPostRequestServerHandler.java similarity index 79% rename from src/main/java/io/suboptimal/connectjava/protocol/UnaryPostRequestHandler.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/UnaryPostRequestServerHandler.java index 60d36bb..3c1d187 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/UnaryPostRequestHandler.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/UnaryPostRequestServerHandler.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; @@ -15,6 +15,9 @@ import io.suboptimal.connectjava.compression.ConnectCompressionRegistry; import io.suboptimal.connectjava.compression.ConnectIdentityCompression; import io.suboptimal.connectjava.model.ConnectMethodDefinition; +import io.suboptimal.connectjava.protocol.ConnectCompressionNegotiation; +import io.suboptimal.connectjava.protocol.ConnectMediaType; +import io.suboptimal.connectjava.protocol.ConnectProtocolVersion; import org.jspecify.annotations.Nullable; import java.io.IOException; @@ -26,19 +29,19 @@ * codec from {@code Content-Type}, validates Connect protocol/version and request * compression headers, decodes one request body, and fires one {@link ConnectCallExchange} * followed by one {@link ConnectPayload}. After successful decoding it installs - * {@link UnaryResponseProcessingHandler} to own the outbound response state + * {@link UnaryResponseProcessingServerHandler} to own the outbound response state * machine. */ -class UnaryPostRequestHandler extends SimpleChannelInboundHandler { +class UnaryPostRequestServerHandler extends SimpleChannelInboundHandler { private final ConnectCallExchange exchange; private final ConnectCodecRegistry codecRegistry; private final ConnectCompressionRegistry compressionRegistry; private final ConnectJsonSerializer jsonSerializer; - private final ConnectInterceptorPipeline interceptorPipeline; + private final ConnectServerInterceptorPipeline interceptorPipeline; - UnaryPostRequestHandler(ConnectCallExchange exchange, ConnectCodecRegistry codecRegistry, - ConnectCompressionRegistry compressionRegistry, ConnectJsonSerializer jsonSerializer, - ConnectInterceptorPipeline interceptorPipeline) + UnaryPostRequestServerHandler(ConnectCallExchange exchange, ConnectCodecRegistry codecRegistry, + ConnectCompressionRegistry compressionRegistry, ConnectJsonSerializer jsonSerializer, + ConnectServerInterceptorPipeline interceptorPipeline) { this.exchange = exchange; this.codecRegistry = codecRegistry; @@ -96,11 +99,11 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) decompressed.release(); } - ConnectInterceptor.Decision decision = interceptorPipeline.interceptCall(exchange); - ConnectCallObserver observer; + ConnectServerInterceptor.Decision decision = interceptorPipeline.interceptCall(exchange); + ConnectServerCallObserver observer; switch (decision) { - case ConnectInterceptor.Decision.Continue(ConnectCallObserver o) -> observer = o; - case ConnectInterceptor.Decision.Reject(ConnectCallObserver o, ConnectError error) -> { + case ConnectServerInterceptor.Decision.Continue(ConnectServerCallObserver o) -> observer = o; + case ConnectServerInterceptor.Decision.Reject(ConnectServerCallObserver o, ConnectError error) -> { observer = o; writeRejectedCall(ctx, observer, error); return; @@ -110,8 +113,8 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) ConnectCompression responseEncoding = ConnectCompressionNegotiation.selectResponseEncoding( requestCompression, request.headers().get(HttpHeaderNames.ACCEPT_ENCODING), compressionRegistry); - ctx.pipeline().addAfter(ConnectPipeline.UNARY_POST_REQUEST_HANDLER, ConnectPipeline.UNARY_RESPONSE_HANDLER, - new UnaryResponseProcessingHandler(exchange, selectedCodec, + ctx.pipeline().addAfter(ConnectServerPipeline.UNARY_POST_REQUEST_HANDLER, ConnectServerPipeline.UNARY_RESPONSE_HANDLER, + new UnaryResponseProcessingServerHandler(exchange, selectedCodec, responseEncoding, false, observer, jsonSerializer)); ctx.fireChannelRead(exchange); @@ -123,7 +126,7 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) private void writeRejectedCall( ChannelHandlerContext ctx, - ConnectCallObserver observer, + ConnectServerCallObserver observer, ConnectError error) { var response = HttpResponses.protocolError(error, jsonSerializer); diff --git a/src/main/java/io/suboptimal/connectjava/protocol/UnaryResponseProcessingHandler.java b/src/main/java/io/suboptimal/connectjava/protocol/server/UnaryResponseProcessingServerHandler.java similarity index 95% rename from src/main/java/io/suboptimal/connectjava/protocol/UnaryResponseProcessingHandler.java rename to src/main/java/io/suboptimal/connectjava/protocol/server/UnaryResponseProcessingServerHandler.java index 20dc2ed..55f554f 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/UnaryResponseProcessingHandler.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/UnaryResponseProcessingServerHandler.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; @@ -17,6 +17,8 @@ import io.suboptimal.connectjava.api.ConnectPayload; import io.suboptimal.connectjava.codec.ConnectCodec; import io.suboptimal.connectjava.compression.ConnectCompression; +import io.suboptimal.connectjava.protocol.ConnectMediaType; +import io.suboptimal.connectjava.protocol.client.ConnectCallTerminatedException; import org.jspecify.annotations.Nullable; import java.io.IOException; @@ -30,23 +32,23 @@ * to terminated after exactly one HTTP response is written. {@link ConnectError} and protocol * state violations write Connect JSON error responses and terminate the call. */ -class UnaryResponseProcessingHandler extends ChannelOutboundHandlerAdapter { +class UnaryResponseProcessingServerHandler extends ChannelOutboundHandlerAdapter { private final ConnectCallExchange exchange; private final ConnectCodec responseCodec; private final ConnectCompression responseCompression; private final boolean varyAcceptEncoding; - private final ConnectCallObserver observer; + private final ConnectServerCallObserver observer; private final ConnectJsonSerializer jsonSerializer; private State state = State.AWAITING_RESPONSE; private @Nullable FullHttpResponse pendingResponse; - UnaryResponseProcessingHandler( + UnaryResponseProcessingServerHandler( ConnectCallExchange exchange, ConnectCodec responseCodec, ConnectCompression responseCompression, boolean varyAcceptEncoding, - ConnectCallObserver observer, + ConnectServerCallObserver observer, ConnectJsonSerializer jsonSerializer) { this.exchange = exchange; diff --git a/src/main/java/io/suboptimal/connectjava/protocol/server/package-info.java b/src/main/java/io/suboptimal/connectjava/protocol/server/package-info.java new file mode 100644 index 0000000..5daa6f2 --- /dev/null +++ b/src/main/java/io/suboptimal/connectjava/protocol/server/package-info.java @@ -0,0 +1,4 @@ +@NullMarked +package io.suboptimal.connectjava.protocol.server; + +import org.jspecify.annotations.NullMarked; diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ClientTestSupport.java b/src/test/java/io/suboptimal/connectjava/protocol/ClientTestSupport.java new file mode 100644 index 0000000..e6efa4d --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/ClientTestSupport.java @@ -0,0 +1,169 @@ +package io.suboptimal.connectjava.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.suboptimal.connectjava.api.ConnectResponseMeta; +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.codec.ConnectCodec; +import io.suboptimal.connectjava.codec.protobuf.ConnectProtobufCodecs; +import io.suboptimal.connectjava.compression.ConnectCompression; +import io.suboptimal.connectjava.compression.ConnectCompressionRegistry; +import io.suboptimal.connectjava.protocol.client.ConnectClientCallObserver; +import io.suboptimal.connectjava.protocol.client.ConnectClientInterceptor; +import io.suboptimal.connectjava.protocol.client.ConnectClientProtocolConfig; +import io.suboptimal.connectjava.protocol.client.ConnectClientProtocolParameters; +import org.jspecify.annotations.Nullable; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** Shared helpers and recording doubles for client-side handler tests. */ +public final class ClientTestSupport { + static final int MAX_RESPONSE_BYTES = 4 * 1024 * 1024; + static final int MAX_FRAME_BYTES = 1024 * 1024; + + private static final ChannelHandler NOOP_TERMINAL = new ChannelInboundHandlerAdapter(); + + private ClientTestSupport() {} + + public static ConnectClientProtocolConfig config() { + return baseBuilder().build(); + } + + public static ConnectClientProtocolConfig config(List interceptors) { + return baseBuilder().interceptors(interceptors).build(); + } + + public static ConnectClientProtocolConfig configWithMaxFrameBytes(int maxFrameBytes) { + return ConnectClientProtocolConfig.builder( + () -> NOOP_TERMINAL, + new ConnectClientProtocolParameters(MAX_RESPONSE_BYTES, maxFrameBytes), + ConnectProtobufCodecs.defaults()).build(); + } + + private static ConnectClientProtocolConfig.Builder baseBuilder() { + return ConnectClientProtocolConfig.builder( + () -> NOOP_TERMINAL, + new ConnectClientProtocolParameters(MAX_RESPONSE_BYTES, MAX_FRAME_BYTES), + ConnectProtobufCodecs.defaults()); + } + + public static ConnectCodec protoCodec() { + ConnectCodec codec = ConnectProtobufCodecs.defaults().byName("proto"); + if (codec == null) { + throw new IllegalStateException("proto codec not registered"); + } + return codec; + } + + static ConnectCompression gzip() { + ConnectCompression gzip = ConnectCompressionRegistry.standard().resolve("gzip"); + if (gzip == null) { + throw new IllegalStateException("gzip not registered"); + } + return gzip; + } + + /** Encodes a message with the given codec and returns the raw bytes. */ + public static byte[] encode(ConnectCodec codec, Object message) { + ByteBuf buf = null; + try { + buf = codec.encode(message, UnpooledByteBufAllocator.DEFAULT); + return ByteBufUtil.getBytes(buf); + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + if (buf != null) { + buf.release(); + } + } + } + + /** gzip-compresses raw bytes. */ + public static byte[] gzipCompress(byte[] raw) { + ByteBuf input = UnpooledByteBufAllocator.DEFAULT.buffer(raw.length).writeBytes(raw); + ByteBuf compressed = null; + try { + compressed = gzip().compress(input, UnpooledByteBufAllocator.DEFAULT); + return ByteBufUtil.getBytes(compressed); + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + input.release(); + if (compressed != null) { + compressed.release(); + } + } + } + + /** gzip-decompresses raw bytes. */ + public static byte[] gzipDecompress(byte[] compressed) { + ByteBuf input = UnpooledByteBufAllocator.DEFAULT.buffer(compressed.length).writeBytes(compressed); + ByteBuf decompressed = null; + try { + decompressed = gzip().decompress(input, UnpooledByteBufAllocator.DEFAULT); + return ByteBufUtil.getBytes(decompressed); + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + input.release(); + if (decompressed != null) { + decompressed.release(); + } + } + } + + /** Records observer callbacks for ordering and arity assertions. */ + public static final class RecordingObserver implements ConnectClientCallObserver { + public final List events = new ArrayList<>(); + final List requestPayloads = new ArrayList<>(); + final List responsePayloads = new ArrayList<>(); + @Nullable ConnectResponseMeta responseMeta; + public @Nullable ConnectError completeError; + public int completeCount; + + @Override + public void onRequestPayload(Object payload) { + events.add("onRequestPayload"); + requestPayloads.add(payload); + } + + @Override + public void onRequestFinished() { + events.add("onRequestFinished"); + } + + @Override + public void onResponseHeaders(ConnectResponseMeta meta) { + events.add("onResponseHeaders"); + responseMeta = meta; + } + + @Override + public void onResponsePayload(Object payload) { + events.add("onResponsePayload"); + responsePayloads.add(payload); + } + + @Override + public void onCallComplete(@Nullable ConnectError error) { + events.add("onCallComplete"); + completeCount++; + completeError = error; + } + } + + /** Interceptor that always rejects with the given error. */ + public static ConnectClientInterceptor rejectingInterceptor(ConnectError error) { + return callStart -> ConnectClientInterceptor.reject(error); + } + + /** Interceptor that continues with the given observer. */ + public static ConnectClientInterceptor continuingInterceptor(ConnectClientCallObserver observer) { + return callStart -> ConnectClientInterceptor.continueWith(observer); + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ConnectClientProtocolIntegrationTest.java b/src/test/java/io/suboptimal/connectjava/protocol/ConnectClientProtocolIntegrationTest.java new file mode 100644 index 0000000..5ded74c --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/ConnectClientProtocolIntegrationTest.java @@ -0,0 +1,305 @@ +package io.suboptimal.connectjava.protocol; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.MultiThreadIoEventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalIoHandler; +import io.netty.channel.local.LocalServerChannel; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpServerCodec; +import io.suboptimal.connectjava.api.ConnectClientResponseStart; +import io.suboptimal.connectjava.api.ConnectResponseMeta; +import io.suboptimal.connectjava.api.ConnectEndOfStream; +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectErrorCode; +import io.suboptimal.connectjava.api.ConnectPayload; +import io.suboptimal.connectjava.api.ConnectCallExchange; +import io.suboptimal.connectjava.codec.protobuf.ConnectProtobufCodecs; +import io.suboptimal.connectjava.model.ConnectMethodDefinition; +import io.suboptimal.connectjava.model.ConnectMethodType; +import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import io.suboptimal.connectjava.protocol.client.ConnectClientCallStart; +import io.suboptimal.connectjava.protocol.client.ConnectClientProtocol; +import io.suboptimal.connectjava.protocol.client.ConnectClientProtocolConfig; +import io.suboptimal.connectjava.protocol.client.ConnectClientProtocolParameters; +import io.suboptimal.connectjava.protocol.server.ConnectServerProtocol; +import io.suboptimal.connectjava.protocol.server.ConnectServerProtocolConfig; +import io.suboptimal.connectjava.protocol.server.ConnectServerProtocolParameters; +import io.suboptimal.connectjava.testfixtures.StreamingRequest; +import io.suboptimal.connectjava.testfixtures.StreamingResponse; +import io.suboptimal.connectjava.testfixtures.UnaryGetRequest; +import io.suboptimal.connectjava.testfixtures.UnaryGetResponse; +import io.suboptimal.connectjava.testfixtures.UnaryPostRequest; +import io.suboptimal.connectjava.testfixtures.UnaryPostResponse; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import static org.assertj.core.api.Assertions.assertThat; + +class ConnectClientProtocolIntegrationTest { + private static final int MAX = 4 * 1024 * 1024; + private static final LocalAddress ADDRESS = new LocalAddress("connect-client-it"); + + private static final ConnectMethodDefinition UNARY_POST = new ConnectMethodDefinition( + "Unary", ConnectMethodType.UNARY, UnaryPostRequest.class, UnaryPostResponse.class, false); + private static final ConnectServiceDefinition UNARY_POST_SERVICE = new ConnectServiceDefinition( + "connectjava.test.v1.UnaryPostFixtureService", List.of(UNARY_POST), null); + + private static final ConnectMethodDefinition UNARY_GET = new ConnectMethodDefinition( + "SafeUnary", ConnectMethodType.UNARY, UnaryGetRequest.class, UnaryGetResponse.class, true); + private static final ConnectServiceDefinition UNARY_GET_SERVICE = new ConnectServiceDefinition( + "connectjava.test.v1.UnaryGetFixtureService", List.of(UNARY_GET), null); + + private static final ConnectMethodDefinition SERVER_STREAMING = new ConnectMethodDefinition( + "ServerStreaming", ConnectMethodType.SERVER_STREAMING, StreamingRequest.class, StreamingResponse.class, false); + private static final ConnectMethodDefinition CLIENT_STREAMING = new ConnectMethodDefinition( + "ClientStreaming", ConnectMethodType.CLIENT_STREAMING, StreamingRequest.class, StreamingResponse.class, false); + private static final ConnectServiceDefinition STREAMING_SERVICE = new ConnectServiceDefinition( + "connectjava.test.v1.StreamingFixtureService", List.of(SERVER_STREAMING, CLIENT_STREAMING), null); + + private static EventLoopGroup group; + private static Channel serverChannel; + + @BeforeAll + static void startServer() throws InterruptedException { + group = new MultiThreadIoEventLoopGroup(LocalIoHandler.newFactory()); + + Map services = Map.of( + UNARY_POST_SERVICE.serviceName(), UNARY_POST_SERVICE, + UNARY_GET_SERVICE.serviceName(), UNARY_GET_SERVICE, + STREAMING_SERVICE.serviceName(), STREAMING_SERVICE); + + ConnectServerProtocolConfig serverConfig = ConnectServerProtocolConfig.builder( + services, ServerEcho::new, + new ConnectServerProtocolParameters(MAX, MAX), + ConnectProtobufCodecs.defaults()).build(); + ConnectServerProtocol serverProtocol = new ConnectServerProtocol(serverConfig); + + serverChannel = new ServerBootstrap() + .group(group) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new HttpServerCodec()); + serverProtocol.http1().configure(ch); + } + }) + .bind(ADDRESS).sync().channel(); + } + + @AfterAll + static void stopServer() throws InterruptedException { + if (serverChannel != null) { + serverChannel.close().sync(); + } + if (group != null) { + group.shutdownGracefully(0, 1, TimeUnit.SECONDS).sync(); + } + } + + @Test + void unaryPostRoundTrip() throws Exception { + CallResult result = call(UNARY_POST_SERVICE, UNARY_POST, false, + List.of(UnaryPostRequest.newBuilder().setText("hi").build())); + + assertThat(result.error).isNull(); + assertThat(result.meta.statusCode()).isEqualTo(200); + assertThat(result.payloads).hasSize(1); + assertThat(((UnaryPostResponse) result.payloads.getFirst()).getText()).isEqualTo("echo:hi"); + } + + @Test + void unaryGetRoundTrip() throws Exception { + CallResult result = call(UNARY_GET_SERVICE, UNARY_GET, true, + List.of(UnaryGetRequest.newBuilder().setText("safe").build())); + + assertThat(result.error).isNull(); + assertThat(result.payloads).hasSize(1); + assertThat(((UnaryGetResponse) result.payloads.getFirst()).getText()).isEqualTo("echo:safe"); + } + + @Test + void serverStreamingRoundTrip() throws Exception { + CallResult result = call(STREAMING_SERVICE, SERVER_STREAMING, false, + List.of(StreamingRequest.newBuilder().setText("go").build())); + + assertThat(result.error).isNull(); + assertThat(result.payloads).hasSize(2); + assertThat(((StreamingResponse) result.payloads.get(0)).getText()).isEqualTo("echo:go#0"); + assertThat(((StreamingResponse) result.payloads.get(1)).getText()).isEqualTo("echo:go#1"); + } + + @Test + void clientStreamingRoundTrip() throws Exception { + CallResult result = call(STREAMING_SERVICE, CLIENT_STREAMING, false, + List.of(StreamingRequest.newBuilder().setText("a").build(), + StreamingRequest.newBuilder().setText("b").build())); + + assertThat(result.error).isNull(); + assertThat(result.payloads).hasSize(1); + assertThat(((StreamingResponse) result.payloads.getFirst()).getText()).isEqualTo("echo:a"); + } + + @Test + void unaryErrorPropagates() throws Exception { + CallResult result = call(UNARY_POST_SERVICE, UNARY_POST, false, + List.of(UnaryPostRequest.newBuilder().setText("FAIL").build())); + + assertThat(result.error).isNotNull(); + assertThat(result.error.code()).isEqualTo(ConnectErrorCode.NOT_FOUND); + } + + @Test + void serverStreamingErrorPropagates() throws Exception { + CallResult result = call(STREAMING_SERVICE, SERVER_STREAMING, false, + List.of(StreamingRequest.newBuilder().setText("FAIL").build())); + + assertThat(result.error).isNotNull(); + assertThat(result.error.code()).isEqualTo(ConnectErrorCode.NOT_FOUND); + } + + private CallResult call(ConnectServiceDefinition service, ConnectMethodDefinition method, + boolean preferGet, List requests) throws Exception { + CompletableFuture future = new CompletableFuture<>(); + ConnectClientCallStart callStart = + new ConnectClientCallStart(service, method, Map.of(), preferGet, "proto"); + + ConnectClientProtocolConfig clientConfig = ConnectClientProtocolConfig.builder( + () -> new DriverHandler(callStart, requests, future), + new ConnectClientProtocolParameters(MAX, MAX), + ConnectProtobufCodecs.defaults()).build(); + ConnectClientProtocol clientProtocol = new ConnectClientProtocol(clientConfig); + + Channel channel = new Bootstrap() + .group(group) + .channel(LocalChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new HttpClientCodec()); + clientProtocol.http1().configure(ch); + } + }) + .connect(ADDRESS).sync().channel(); + + try { + return future.get(5, TimeUnit.SECONDS); + } finally { + channel.close().sync(); + } + } + + private record CallResult(List payloads, ConnectError error, ConnectResponseMeta meta) {} + + /** Client terminal handler: drives one call on activation, captures the result. */ + private static final class DriverHandler extends ChannelInboundHandlerAdapter { + private final ConnectClientCallStart callStart; + private final List requests; + private final CompletableFuture future; + private final List payloads = new ArrayList<>(); + private ConnectResponseMeta meta; + + DriverHandler(ConnectClientCallStart callStart, List requests, + CompletableFuture future) { + this.callStart = callStart; + this.requests = requests; + this.future = future; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.write(callStart); + for (Object request : requests) { + ctx.write(new ConnectPayload(request)); + } + ctx.writeAndFlush(ConnectEndOfStream.INSTANCE); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + switch (msg) { + case ConnectClientResponseStart exchange -> meta = exchange.responseMeta(); + case ConnectPayload payload -> payloads.add(payload.data()); + case ConnectEndOfStream eos -> future.complete(new CallResult(payloads, eos.error(), meta)); + case ConnectError error -> future.complete(new CallResult(payloads, error, meta)); + default -> { } + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + future.completeExceptionally(cause); + } + } + + /** Server terminal handler: echoes requests; fails when a request carries the text "FAIL". */ + private static final class ServerEcho extends ChannelInboundHandlerAdapter { + private ConnectCallExchange exchange; + private final List requestTexts = new ArrayList<>(); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + switch (msg) { + case ConnectCallExchange ex -> exchange = ex; + case ConnectPayload payload -> requestTexts.add(textOf(payload.data())); + case ConnectEndOfStream ignored -> respond(ctx); + default -> { } + } + } + + private void respond(ChannelHandlerContext ctx) { + if (requestTexts.contains("FAIL")) { + ctx.writeAndFlush(ConnectError.notFound("requested failure")); + return; + } + String base = requestTexts.isEmpty() ? "" : requestTexts.getFirst(); + ConnectMethodType type = exchange.methodDefinition().type(); + int count = type == ConnectMethodType.SERVER_STREAMING ? 2 : 1; + for (int i = 0; i < count; i++) { + String text = type == ConnectMethodType.SERVER_STREAMING + ? "echo:" + base + "#" + i + : "echo:" + base; + ctx.write(new ConnectPayload(response(exchange, text))); + } + ctx.writeAndFlush(ConnectEndOfStream.INSTANCE); + } + + private static String textOf(Object request) { + return switch (request) { + case UnaryPostRequest r -> r.getText(); + case UnaryGetRequest r -> r.getText(); + case StreamingRequest r -> r.getText(); + default -> throw new IllegalStateException("unexpected request: " + request); + }; + } + + private static Object response(ConnectCallExchange exchange, String text) { + Class type = exchange.methodDefinition().responseType(); + if (type == UnaryPostResponse.class) { + return UnaryPostResponse.newBuilder().setText(text).build(); + } + if (type == UnaryGetResponse.class) { + return UnaryGetResponse.newBuilder().setText(text).build(); + } + if (type == StreamingResponse.class) { + return StreamingResponse.newBuilder().setText(text).build(); + } + throw new IllegalStateException("unexpected response type: " + type); + } + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ConnectProtocolVersionTest.java b/src/test/java/io/suboptimal/connectjava/protocol/ConnectServerProtocolVersionTest.java similarity index 96% rename from src/test/java/io/suboptimal/connectjava/protocol/ConnectProtocolVersionTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/ConnectServerProtocolVersionTest.java index 20c7aa2..bea52b8 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/ConnectProtocolVersionTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/ConnectServerProtocolVersionTest.java @@ -8,7 +8,7 @@ import static org.assertj.core.api.Assertions.assertThat; -class ConnectProtocolVersionTest { +class ConnectServerProtocolVersionTest { @Test void acceptsVersion1() { diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallDispatcherTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallDispatcherTest.java new file mode 100644 index 0000000..15a1033 --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallDispatcherTest.java @@ -0,0 +1,250 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.suboptimal.connectjava.api.ConnectEndOfStream; +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectErrorCode; +import io.suboptimal.connectjava.api.ConnectPayload; +import io.suboptimal.connectjava.model.ConnectMethodDefinition; +import io.suboptimal.connectjava.model.ConnectMethodType; +import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import io.netty.handler.codec.http.FullHttpRequest; +import io.suboptimal.connectjava.protocol.ClientTestSupport; +import io.suboptimal.connectjava.testfixtures.StreamingRequest; +import io.suboptimal.connectjava.testfixtures.StreamingResponse; +import io.suboptimal.connectjava.testfixtures.UnaryGetRequest; +import io.suboptimal.connectjava.testfixtures.UnaryGetResponse; +import io.suboptimal.connectjava.testfixtures.UnaryPostRequest; +import io.suboptimal.connectjava.testfixtures.UnaryPostResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import static org.assertj.core.api.Assertions.assertThat; + +class ConnectClientCallDispatcherTest { + private static final ConnectMethodDefinition UNARY_POST = new ConnectMethodDefinition( + "Unary", ConnectMethodType.UNARY, UnaryPostRequest.class, UnaryPostResponse.class, false); + private static final ConnectMethodDefinition UNARY_IDEMPOTENT = new ConnectMethodDefinition( + "SafeUnary", ConnectMethodType.UNARY, UnaryGetRequest.class, UnaryGetResponse.class, true); + private static final ConnectMethodDefinition SERVER_STREAMING = new ConnectMethodDefinition( + "ServerStreaming", ConnectMethodType.SERVER_STREAMING, StreamingRequest.class, StreamingResponse.class, false); + private static final ConnectServiceDefinition SERVICE = new ConnectServiceDefinition( + "svc.Service", List.of(UNARY_POST, UNARY_IDEMPOTENT, SERVER_STREAMING), null); + + private EmbeddedChannel channel; + + @BeforeEach + void setUp() { + channel = new EmbeddedChannel(); + } + + @AfterEach + void tearDown() { + channel.finishAndReleaseAll(); + } + + private void install(ConnectClientProtocolConfig config) { + channel.pipeline().addLast(ConnectClientPipeline.CALL_DISPATCHER, + new ConnectClientCallDispatcher(config)); + } + + private ConnectClientCallStart callStart(ConnectMethodDefinition method, boolean preferGet, String codecName) { + return new ConnectClientCallStart(SERVICE, method, Map.of(), preferGet, codecName); + } + + @Test + void installsPostHandlerForUnary() { + install(ClientTestSupport.config()); + channel.writeOutbound(callStart(UNARY_POST, false, "proto")); + + assertThat(channel.pipeline().get(ConnectClientPipeline.AGGREGATOR_HANDLER)).isNotNull(); + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_POST_HANDLER)).isNotNull(); + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_GET_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectClientPipeline.STREAMING_HANDLER)).isNull(); + } + + @Test + void installsGetHandlerForIdempotentPreferGet() { + install(ClientTestSupport.config()); + channel.writeOutbound(callStart(UNARY_IDEMPOTENT, true, "proto")); + + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_GET_HANDLER)).isNotNull(); + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_POST_HANDLER)).isNull(); + } + + @Test + void usesPostWhenPreferGetButNotIdempotent() { + install(ClientTestSupport.config()); + channel.writeOutbound(callStart(UNARY_POST, true, "proto")); + + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_POST_HANDLER)).isNotNull(); + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_GET_HANDLER)).isNull(); + } + + @Test + void installsStreamingHandlerForStreaming() { + install(ClientTestSupport.config()); + channel.writeOutbound(callStart(SERVER_STREAMING, false, "proto")); + + assertThat(channel.pipeline().get(ConnectClientPipeline.STREAMING_HANDLER)).isNotNull(); + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_POST_HANDLER)).isNull(); + } + + @Test + void nullCodecNameProceeds() { + install(ClientTestSupport.config()); + channel.writeOutbound(callStart(UNARY_POST, false, null)); + + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_POST_HANDLER)).isNotNull(); + Object inbound = channel.readInbound(); + assertThat(inbound).isNull(); + } + + @Test + void registeredCodecNameProceeds() { + install(ClientTestSupport.config()); + channel.writeOutbound(callStart(UNARY_POST, false, "json")); + + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_POST_HANDLER)).isNotNull(); + Object inbound = channel.readInbound(); + assertThat(inbound).isNull(); + } + + @Test + void unknownCodecNameFailsCall() { + install(ClientTestSupport.config()); + channel.writeOutbound(callStart(UNARY_POST, false, "bogus")); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.INTERNAL); + assertThat(((ConnectError) inbound).message()).contains("Unknown codec"); + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_POST_HANDLER)).isNull(); + } + + @Test + void interceptorRejectDeliversErrorAndInstallsNoHandler() { + ConnectError rejection = ConnectError.permissionDenied("denied"); + install(ClientTestSupport.config(List.of(ClientTestSupport.rejectingInterceptor(rejection)))); + + channel.writeOutbound(callStart(UNARY_POST, false, "proto")); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.PERMISSION_DENIED); + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_POST_HANDLER)).isNull(); + } + + @Test + void interceptorObserverReceivesCallbacks() { + var observer = new ClientTestSupport.RecordingObserver(); + install(ClientTestSupport.config(List.of(ClientTestSupport.continuingInterceptor(observer)))); + + ConnectClientCallStart callStart = callStart(UNARY_POST, false, "proto"); + channel.writeOutbound(callStart); + channel.writeOutbound(new ConnectPayload(UnaryPostRequest.newBuilder().setText("x").build())); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + + Object request = channel.readOutbound(); + if (request instanceof io.netty.util.ReferenceCounted rc) { + rc.release(); + } + assertThat(observer.events).containsExactly("onRequestPayload", "onRequestFinished"); + } + + @Test + void reuseRemovesPreviousHandlersWithoutSpuriousComplete() { + var observer = new ClientTestSupport.RecordingObserver(); + install(ClientTestSupport.config(List.of(ClientTestSupport.continuingInterceptor(observer)))); + + // First call: drive the request fully so the POST handler reaches TERMINATED and the + // response handler is installed. + channel.writeOutbound(callStart(UNARY_POST, false, "proto")); + channel.writeOutbound(new ConnectPayload(UnaryPostRequest.newBuilder().setText("x").build())); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + Object firstRequest = channel.readOutbound(); + if (firstRequest instanceof io.netty.util.ReferenceCounted rc) { + rc.release(); + } + + // Second call reuses the channel. + channel.writeOutbound(callStart(SERVER_STREAMING, false, "proto")); + + assertThat(channel.pipeline().get(ConnectClientPipeline.STREAMING_HANDLER)).isNotNull(); + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_POST_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_RESPONSE_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectClientPipeline.AGGREGATOR_HANDLER)).isNull(); + assertThat(observer.completeCount).isZero(); + } + + @Test + void interceptorCanRewriteOutgoingHeaders() { + ConnectClientInterceptor adder = cs -> + ConnectClientInterceptor.continueWith(cs.withHeader("x-test", "v1")); + install(ClientTestSupport.config(List.of(adder))); + + channel.writeOutbound(callStart(UNARY_POST, false, "proto")); + channel.writeOutbound(new ConnectPayload(UnaryPostRequest.newBuilder().setText("x").build())); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + + Object out = channel.readOutbound(); + assertThat(out).isInstanceOf(FullHttpRequest.class); + FullHttpRequest req = (FullHttpRequest) out; + assertThat(req.headers().get("x-test")).isEqualTo("v1"); + req.release(); + } + + @Test + void rewriteIsThreadedToNextInterceptor() { + AtomicReference seenBySecond = new AtomicReference<>(); + ConnectClientInterceptor first = cs -> + ConnectClientInterceptor.continueWith(cs.withHeader("x-a", "1")); + ConnectClientInterceptor second = cs -> { + seenBySecond.set(cs); + return ConnectClientInterceptor.continueWith(cs.withHeader("x-b", "2")); + }; + install(ClientTestSupport.config(List.of(first, second))); + + channel.writeOutbound(callStart(UNARY_POST, false, "proto")); + channel.writeOutbound(new ConnectPayload(UnaryPostRequest.newBuilder().setText("x").build())); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + + // the second interceptor observed the first interceptor's rewrite + assertThat(seenBySecond.get().requestHeaders()).containsKey("x-a"); + + Object out = channel.readOutbound(); + assertThat(out).isInstanceOf(FullHttpRequest.class); + FullHttpRequest req = (FullHttpRequest) out; + assertThat(req.headers().get("x-a")).isEqualTo("1"); + assertThat(req.headers().get("x-b")).isEqualTo("2"); + req.release(); + } + + @Test + void codecValidationAppliesToRewrittenCodec() { + // original codec is valid; an interceptor rewrites it to an unregistered one + ConnectClientInterceptor breaker = cs -> + ConnectClientInterceptor.continueWith(cs.withCodecName("bogus")); + install(ClientTestSupport.config(List.of(breaker))); + + channel.writeOutbound(callStart(UNARY_POST, false, "proto")); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).message()).contains("Unknown codec"); + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_POST_HANDLER)).isNull(); + } + + @Test + void passesThroughNonCallStartMessages() { + install(ClientTestSupport.config()); + channel.writeOutbound("ping"); + Object out = channel.readOutbound(); + assertThat(out).isEqualTo("ping"); + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallStartTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallStartTest.java new file mode 100644 index 0000000..6ff5ce9 --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientCallStartTest.java @@ -0,0 +1,90 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.suboptimal.connectjava.model.ConnectMethodDefinition; +import io.suboptimal.connectjava.model.ConnectMethodType; +import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import io.suboptimal.connectjava.testfixtures.UnaryPostRequest; +import io.suboptimal.connectjava.testfixtures.UnaryPostResponse; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class ConnectClientCallStartTest { + private static final ConnectMethodDefinition METHOD = new ConnectMethodDefinition( + "Unary", ConnectMethodType.UNARY, UnaryPostRequest.class, UnaryPostResponse.class, false); + private static final ConnectServiceDefinition SERVICE = new ConnectServiceDefinition( + "svc.Service", List.of(METHOD), null); + + @Test + void normalizesHeaderNamesToLowerCase() { + ConnectClientCallStart callStart = new ConnectClientCallStart( + SERVICE, METHOD, Map.of("Content-Type", List.of("application/proto")), false, "proto"); + + assertThat(callStart.requestHeaders()).containsKey("content-type"); + assertThat(callStart.requestHeaders()).doesNotContainKey("Content-Type"); + } + + @Test + void copiesHeadersDefensively() { + Map> source = new java.util.HashMap<>(); + source.put("x-test", new ArrayList<>(List.of("a"))); + + ConnectClientCallStart callStart = new ConnectClientCallStart( + SERVICE, METHOD, source, false, null); + + source.put("x-added", List.of("b")); + assertThat(callStart.requestHeaders()).doesNotContainKey("x-added"); + assertThatThrownBy(() -> callStart.requestHeaders().put("x-mutate", List.of("c"))) + .isInstanceOf(UnsupportedOperationException.class); + } + + @Test + void allowsNullCodecName() { + ConnectClientCallStart callStart = new ConnectClientCallStart( + SERVICE, METHOD, Map.of(), true, null); + assertThat(callStart.codecName()).isNull(); + } + + @Test + void withHeaderAppendsValueAndLowercasesName() { + ConnectClientCallStart base = new ConnectClientCallStart( + SERVICE, METHOD, Map.of("x-existing", List.of("a")), false, "proto"); + + ConnectClientCallStart out = base.withHeader("X-Existing", "b").withHeader("X-New", "c"); + + assertThat(out.requestHeaders().get("x-existing")).containsExactly("a", "b"); + assertThat(out.requestHeaders().get("x-new")).containsExactly("c"); + // original is unchanged + assertThat(base.requestHeaders().get("x-existing")).containsExactly("a"); + assertThat(base.requestHeaders()).doesNotContainKey("x-new"); + } + + @Test + void withTimeoutMsAndCodecReplaceSingleField() { + ConnectClientCallStart base = new ConnectClientCallStart( + SERVICE, METHOD, Map.of(), false, "proto", 1000L); + + assertThat(base.withTimeoutMs(5000L).timeoutMs()).isEqualTo(5000L); + assertThat(base.withTimeoutMs(5000L).codecName()).isEqualTo("proto"); + assertThat(base.withCodecName("json").codecName()).isEqualTo("json"); + assertThat(base.withCodecName("json").timeoutMs()).isEqualTo(1000L); + // original unchanged + assertThat(base.timeoutMs()).isEqualTo(1000L); + assertThat(base.codecName()).isEqualTo("proto"); + } + + @Test + void rejectsNullRequiredFields() { + assertThatThrownBy(() -> new ConnectClientCallStart(null, METHOD, Map.of(), false, "proto")) + .isInstanceOf(NullPointerException.class); + assertThatThrownBy(() -> new ConnectClientCallStart(SERVICE, null, Map.of(), false, "proto")) + .isInstanceOf(NullPointerException.class); + assertThatThrownBy(() -> new ConnectClientCallStart(SERVICE, METHOD, null, false, "proto")) + .isInstanceOf(NullPointerException.class); + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientChannelConfigurerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientChannelConfigurerTest.java new file mode 100644 index 0000000..28bc989 --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientChannelConfigurerTest.java @@ -0,0 +1,64 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.suboptimal.connectjava.codec.protobuf.ConnectProtobufCodecs; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class ConnectClientChannelConfigurerTest { + private EmbeddedChannel channel; + + @BeforeEach + void setUp() { + channel = new EmbeddedChannel(); + } + + @AfterEach + void tearDown() { + channel.finishAndReleaseAll(); + } + + private static ConnectClientProtocolConfig freshConfig() { + return ConnectClientProtocolConfig.builder( + ChannelInboundHandlerAdapter::new, + new ConnectClientProtocolParameters(4 * 1024 * 1024, 1024 * 1024), + ConnectProtobufCodecs.defaults()).build(); + } + + @Test + void installsCallDispatcher() { + new ConnectClientChannelConfigurer(freshConfig()).configure(channel); + + assertThat(channel.pipeline().get(ConnectClientPipeline.CALL_DISPATCHER)).isNotNull(); + } + + @Test + void installsTerminalHandlerAfterDispatcher() { + new ConnectClientChannelConfigurer(freshConfig()).configure(channel); + + assertThat(channel.pipeline().names()).contains(ConnectClientPipeline.CALL_DISPATCHER); + assertThat(channel.pipeline().last()).isNotInstanceOf(ConnectClientCallDispatcher.class); + } + + @Test + void dispatcherIsReusedAcrossChannels() { + ConnectClientChannelConfigurer cfg = new ConnectClientChannelConfigurer(freshConfig()); + + EmbeddedChannel chA = new EmbeddedChannel(); + EmbeddedChannel chB = new EmbeddedChannel(); + try { + cfg.configure(chA); + cfg.configure(chB); + + assertThat(chA.pipeline().get(ConnectClientPipeline.CALL_DISPATCHER)) + .isSameAs(chB.pipeline().get(ConnectClientPipeline.CALL_DISPATCHER)); + } finally { + chA.finishAndReleaseAll(); + chB.finishAndReleaseAll(); + } + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipelineTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipelineTest.java new file mode 100644 index 0000000..655daba --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipelineTest.java @@ -0,0 +1,226 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectResponseMeta; +import io.suboptimal.connectjava.model.ConnectMethodDefinition; +import io.suboptimal.connectjava.model.ConnectMethodType; +import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import io.suboptimal.connectjava.protocol.ClientTestSupport; +import io.suboptimal.connectjava.testfixtures.UnaryPostRequest; +import io.suboptimal.connectjava.testfixtures.UnaryPostResponse; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import static org.assertj.core.api.Assertions.assertThat; + +class ConnectClientInterceptorPipelineTest { + private static final ConnectMethodDefinition METHOD = new ConnectMethodDefinition( + "Unary", ConnectMethodType.UNARY, UnaryPostRequest.class, UnaryPostResponse.class, false); + private static final ConnectServiceDefinition SERVICE = new ConnectServiceDefinition( + "svc.Service", List.of(METHOD), null); + private static final ConnectClientCallStart CALL_START = new ConnectClientCallStart( + SERVICE, METHOD, Map.of(), false, "proto"); + private static final ConnectResponseMeta META = + new ConnectResponseMeta(200, Map.of(), Map.of()); + + @Test + void emptyPipelineContinues() { + ConnectClientInterceptor.Decision d = ConnectClientInterceptorPipeline.EMPTY.interceptCall(CALL_START); + + assertThat(d).isInstanceOf(ConnectClientInterceptor.Decision.Continue.class); + assertThat(d.observer()).isSameAs(ConnectClientCallObserver.NOOP); + assertThat(((ConnectClientInterceptor.Decision.Continue) d).callStart()).isSameAs(CALL_START); + } + + @Test + void allNoOpObserversProduceNoOp() { + ConnectClientInterceptorPipeline pipeline = new ConnectClientInterceptorPipeline(List.of( + cs -> ConnectClientInterceptor.continueCall(), + cs -> ConnectClientInterceptor.continueCall() + )); + + assertThat(pipeline.interceptCall(CALL_START).observer()).isSameAs(ConnectClientCallObserver.NOOP); + } + + @Test + void singleNonNoOpObserverIsReturnedDirectly() { + ClientTestSupport.RecordingObserver observer = new ClientTestSupport.RecordingObserver(); + ConnectClientInterceptorPipeline pipeline = new ConnectClientInterceptorPipeline(List.of( + cs -> ConnectClientInterceptor.continueWith(observer) + )); + + assertThat(pipeline.interceptCall(CALL_START).observer()).isSameAs(observer); + } + + @Test + void noOpObserversAreFilteredFromComposite() { + ClientTestSupport.RecordingObserver real = new ClientTestSupport.RecordingObserver(); + ConnectClientInterceptorPipeline pipeline = new ConnectClientInterceptorPipeline(List.of( + cs -> ConnectClientInterceptor.continueCall(), + cs -> ConnectClientInterceptor.continueWith(real), + cs -> ConnectClientInterceptor.continueCall() + )); + + assertThat(pipeline.interceptCall(CALL_START).observer()).isSameAs(real); + } + + @Test + void requestPayloadCallbacksAreFIFO() { + List log = new ArrayList<>(); + ConnectClientInterceptorPipeline pipeline = new ConnectClientInterceptorPipeline(List.of( + cs -> ConnectClientInterceptor.continueWith(new ConnectClientCallObserver() { + @Override public void onRequestPayload(Object p) { log.add("first"); } + }), + cs -> ConnectClientInterceptor.continueWith(new ConnectClientCallObserver() { + @Override public void onRequestPayload(Object p) { log.add("second"); } + }) + )); + + ConnectClientCallObserver composite = pipeline.interceptCall(CALL_START).observer(); + composite.onRequestPayload("x"); + + assertThat(log).containsExactly("first", "second"); + } + + @Test + void requestFinishedCallbacksAreFIFO() { + List log = new ArrayList<>(); + ConnectClientInterceptorPipeline pipeline = new ConnectClientInterceptorPipeline(List.of( + cs -> ConnectClientInterceptor.continueWith(new ConnectClientCallObserver() { + @Override public void onRequestFinished() { log.add("first"); } + }), + cs -> ConnectClientInterceptor.continueWith(new ConnectClientCallObserver() { + @Override public void onRequestFinished() { log.add("second"); } + }) + )); + + ConnectClientCallObserver composite = pipeline.interceptCall(CALL_START).observer(); + composite.onRequestFinished(); + + assertThat(log).containsExactly("first", "second"); + } + + @Test + void responsePayloadCallbacksAreFIFO() { + List log = new ArrayList<>(); + ConnectClientInterceptorPipeline pipeline = new ConnectClientInterceptorPipeline(List.of( + cs -> ConnectClientInterceptor.continueWith(new ConnectClientCallObserver() { + @Override public void onResponsePayload(Object p) { log.add("first"); } + }), + cs -> ConnectClientInterceptor.continueWith(new ConnectClientCallObserver() { + @Override public void onResponsePayload(Object p) { log.add("second"); } + }) + )); + + ConnectClientCallObserver composite = pipeline.interceptCall(CALL_START).observer(); + composite.onResponsePayload("x"); + + assertThat(log).containsExactly("first", "second"); + } + + @Test + void responseHeaderCallbacksAreLIFO() { + List log = new ArrayList<>(); + ConnectClientInterceptorPipeline pipeline = new ConnectClientInterceptorPipeline(List.of( + cs -> ConnectClientInterceptor.continueWith(new ConnectClientCallObserver() { + @Override public void onResponseHeaders(ConnectResponseMeta m) { log.add("first"); } + }), + cs -> ConnectClientInterceptor.continueWith(new ConnectClientCallObserver() { + @Override public void onResponseHeaders(ConnectResponseMeta m) { log.add("second"); } + }) + )); + + ConnectClientCallObserver composite = pipeline.interceptCall(CALL_START).observer(); + composite.onResponseHeaders(META); + + assertThat(log).containsExactly("second", "first"); + } + + @Test + void callCompleteCallbacksAreLIFO() { + List log = new ArrayList<>(); + ConnectClientInterceptorPipeline pipeline = new ConnectClientInterceptorPipeline(List.of( + cs -> ConnectClientInterceptor.continueWith(new ConnectClientCallObserver() { + @Override public void onCallComplete(ConnectError e) { log.add("first"); } + }), + cs -> ConnectClientInterceptor.continueWith(new ConnectClientCallObserver() { + @Override public void onCallComplete(ConnectError e) { log.add("second"); } + }) + )); + + ConnectClientCallObserver composite = pipeline.interceptCall(CALL_START).observer(); + composite.onCallComplete(null); + + assertThat(log).containsExactly("second", "first"); + } + + @Test + void rejectionStopsIterationAndReturnsCompositeOfPriorObservers() { + List callOrder = new ArrayList<>(); + List completedLog = new ArrayList<>(); + + ConnectClientInterceptor first = cs -> { + callOrder.add("first"); + return ConnectClientInterceptor.continueWith(new ConnectClientCallObserver() { + @Override public void onCallComplete(ConnectError e) { completedLog.add("first"); } + }); + }; + ConnectError rejectError = ConnectError.permissionDenied("no"); + ConnectClientInterceptor rejecting = cs -> { + callOrder.add("rejecting"); + return ConnectClientInterceptor.reject(rejectError); + }; + ConnectClientInterceptor notReached = cs -> { + callOrder.add("notReached"); + return ConnectClientInterceptor.continueCall(); + }; + + ConnectClientInterceptorPipeline pipeline = new ConnectClientInterceptorPipeline( + List.of(first, rejecting, notReached)); + + ConnectClientInterceptor.Decision d = pipeline.interceptCall(CALL_START); + + assertThat(d).isInstanceOf(ConnectClientInterceptor.Decision.Reject.class); + assertThat(((ConnectClientInterceptor.Decision.Reject) d).error()).isSameAs(rejectError); + assertThat(callOrder).containsExactly("first", "rejecting"); + + d.observer().onCallComplete(rejectError); + assertThat(completedLog).containsExactly("first"); + } + + @Test + void rejectionWithNoPriorContinueObserversReturnsNoOp() { + ConnectClientInterceptorPipeline pipeline = new ConnectClientInterceptorPipeline(List.of( + cs -> ConnectClientInterceptor.reject(ConnectError.unauthenticated("go away")) + )); + + ConnectClientInterceptor.Decision d = pipeline.interceptCall(CALL_START); + + assertThat(d).isInstanceOf(ConnectClientInterceptor.Decision.Reject.class); + assertThat(d.observer()).isSameAs(ConnectClientCallObserver.NOOP); + } + + @Test + void rewriteIsThreadedToNextInterceptor() { + AtomicReference seenBySecond = new AtomicReference<>(); + + ConnectClientInterceptor first = cs -> ConnectClientInterceptor.continueWith(cs.withHeader("x-a", "1")); + ConnectClientInterceptor second = cs -> { + seenBySecond.set(cs); + return ConnectClientInterceptor.continueWith(cs.withHeader("x-b", "2")); + }; + + ConnectClientInterceptorPipeline pipeline = new ConnectClientInterceptorPipeline(List.of(first, second)); + ConnectClientInterceptor.Decision d = pipeline.interceptCall(CALL_START); + + assertThat(seenBySecond.get().requestHeaders()).containsKey("x-a"); + + ConnectClientCallStart effective = ((ConnectClientInterceptor.Decision.Continue) d).callStart(); + assertThat(effective.requestHeaders()).containsKey("x-a"); + assertThat(effective.requestHeaders()).containsKey("x-b"); + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolConfigTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolConfigTest.java new file mode 100644 index 0000000..c330a3a --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientProtocolConfigTest.java @@ -0,0 +1,62 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.suboptimal.connectjava.codec.protobuf.ConnectProtobufCodecs; +import io.suboptimal.connectjava.compression.ConnectCompressionRegistry; +import io.suboptimal.connectjava.protocol.ConnectCallHandlerFactory; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class ConnectClientProtocolConfigTest { + private static final ConnectCallHandlerFactory FACTORY = ChannelInboundHandlerAdapter::new; + private static final ConnectClientProtocolParameters PARAMS = + new ConnectClientProtocolParameters(1024, 1024); + + @Test + void appliesDefaultsForOptionalFields() { + ConnectClientProtocolConfig config = ConnectClientProtocolConfig + .builder(FACTORY, PARAMS, ConnectProtobufCodecs.defaults()) + .build(); + + assertThat(config.jsonDeserializer()).isSameAs(ConnectStringBuilderJsonDeserializer.INSTANCE); + assertThat(config.compressionRegistry().supportedNames()).contains("gzip", "identity"); + assertThat(config.interceptors()).isEmpty(); + } + + @Test + void overridesOptionalFields() { + ConnectCompressionRegistry customCompression = ConnectCompressionRegistry.standard(); + ConnectClientInterceptor interceptor = callStart -> ConnectClientInterceptor.continueCall(); + + ConnectClientProtocolConfig config = ConnectClientProtocolConfig + .builder(FACTORY, PARAMS, ConnectProtobufCodecs.defaults()) + .compressionRegistry(customCompression) + .jsonDeserializer(ConnectStringBuilderJsonDeserializer.INSTANCE) + .interceptors(List.of(interceptor)) + .build(); + + assertThat(config.compressionRegistry()).isSameAs(customCompression); + assertThat(config.interceptors()).containsExactly(interceptor); + } + + @Test + void interceptorListIsImmutableCopy() { + List source = new ArrayList<>(); + source.add(callStart -> ConnectClientInterceptor.continueCall()); + + ConnectClientProtocolConfig config = ConnectClientProtocolConfig + .builder(FACTORY, PARAMS, ConnectProtobufCodecs.defaults()) + .interceptors(source) + .build(); + + source.add(callStart -> ConnectClientInterceptor.continueCall()); + assertThat(config.interceptors()).hasSize(1); + assertThatThrownBy(() -> config.interceptors().add(callStart -> ConnectClientInterceptor.continueCall())) + .isInstanceOf(UnsupportedOperationException.class); + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectStringBuilderJsonDeserializerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectStringBuilderJsonDeserializerTest.java new file mode 100644 index 0000000..6643a04 --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectStringBuilderJsonDeserializerTest.java @@ -0,0 +1,182 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectErrorCode; +import io.suboptimal.connectjava.api.ConnectErrorDetail; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +class ConnectStringBuilderJsonDeserializerTest { + private static final ConnectStringBuilderJsonDeserializer D = + ConnectStringBuilderJsonDeserializer.INSTANCE; + + private static byte[] utf8(String s) { + return s.getBytes(StandardCharsets.UTF_8); + } + + // ---- parseError ---- + + @Test + void parseErrorReadsCodeMessageAndDetails() { + String b64 = Base64.getEncoder().encodeToString(new byte[]{1, 2, 3}); + String json = "{\"code\":\"not_found\",\"message\":\"nope\",\"details\":[{\"type\":\"google.rpc.RetryInfo\",\"value\":\"" + b64 + "\"}]}"; + + ConnectError e = D.parseError(utf8(json)); + + assertThat(e).isNotNull(); + assertThat(e.code()).isEqualTo(ConnectErrorCode.NOT_FOUND); + assertThat(e.message()).isEqualTo("nope"); + assertThat(e.details()).hasSize(1); + assertThat(e.details().get(0).type()).isEqualTo("google.rpc.RetryInfo"); + assertThat(e.details().get(0).value()).isEqualTo(new byte[]{1, 2, 3}); + } + + @Test + void parseErrorWithoutCodeReturnsNull() { + ConnectError e = D.parseError(utf8("{\"message\":\"oops\"}")); + assertThat(e).isNull(); + } + + @Test + void parseErrorWithUnknownCodeNameFallsBackToUnknown() { + ConnectError e = D.parseError(utf8("{\"code\":\"totally_made_up\",\"message\":\"x\"}")); + assertThat(e).isNotNull(); + assertThat(e.code()).isEqualTo(ConnectErrorCode.UNKNOWN); + } + + @Test + void parseErrorUnescapesMessage() { + String json = "{\"code\":\"internal\",\"message\":\"a \\\"quote\\\" and\\nnewline\"}"; + + ConnectError e = D.parseError(utf8(json)); + + assertThat(e).isNotNull(); + assertThat(e.message()).isEqualTo("a \"quote\" and\nnewline"); + } + + @Test + void parseErrorMissingMessageGivesEmptyString() { + ConnectError e = D.parseError(utf8("{\"code\":\"internal\"}")); + assertThat(e).isNotNull(); + assertThat(e.message()).isEqualTo(""); + } + + // ---- parseErrorBody ---- + + @Test + void parseErrorBodyReturnsRawFields() { + String json = "{\"code\":\"permission_denied\",\"message\":\"denied\",\"details\":[]}"; + + ConnectErrorBody b = D.parseErrorBody(utf8(json)); + + assertThat(b).isNotNull(); + assertThat(b.codeName()).isEqualTo("permission_denied"); + assertThat(b.message()).isEqualTo("denied"); + assertThat(b.details()).isEmpty(); + } + + @Test + void parseErrorBodyReturnsNullForNonErrorJson() { + ConnectErrorBody b = D.parseErrorBody(utf8("{\"foo\":\"bar\"}")); + assertThat(b).isNull(); + } + + @Test + void parseErrorBodySkipsMalformedBase64Detail() { + String json = "{\"code\":\"internal\",\"details\":[{\"type\":\"t\",\"value\":\"!!!not-base64!!!\"}]}"; + + ConnectErrorBody b = D.parseErrorBody(utf8(json)); + + assertThat(b).isNotNull(); + assertThat(b.details()).isEmpty(); + } + + @Test + void parseErrorBodyDecodesMultipleDetailsInOrder() { + String v1 = Base64.getEncoder().encodeToString(new byte[]{1}); + String v2 = Base64.getEncoder().encodeToString(new byte[]{2}); + String json = "{\"code\":\"internal\",\"details\":[" + + "{\"type\":\"a\",\"value\":\"" + v1 + "\"}," + + "{\"type\":\"b\",\"value\":\"" + v2 + "\"}" + + "]}"; + + ConnectErrorBody b = D.parseErrorBody(utf8(json)); + + assertThat(b).isNotNull(); + assertThat(b.details()).hasSize(2); + assertThat(b.details().get(0).value()).isEqualTo(new byte[]{1}); + assertThat(b.details().get(1).value()).isEqualTo(new byte[]{2}); + } + + // ---- parseEndStreamError ---- + + @Test + void parseEndStreamErrorReadsNestedError() { + ConnectError e = D.parseEndStreamError(utf8("{\"error\":{\"code\":\"not_found\",\"message\":\"gone\"}}")); + + assertThat(e).isNotNull(); + assertThat(e.code()).isEqualTo(ConnectErrorCode.NOT_FOUND); + assertThat(e.message()).isEqualTo("gone"); + } + + @Test + void parseEndStreamErrorReturnsNullWhenNoError() { + assertThat(D.parseEndStreamError(utf8("{}"))).isNull(); + assertThat(D.parseEndStreamError(utf8("{\"metadata\":{\"a\":[\"b\"]}}"))).isNull(); + } + + @Test + void parseEndStreamErrorWithoutCodeIsUnknown() { + ConnectError e = D.parseEndStreamError(utf8("{\"error\":{\"message\":\"oops\"}}")); + + assertThat(e).isNotNull(); + assertThat(e.code()).isEqualTo(ConnectErrorCode.UNKNOWN); + assertThat(e.message()).isEqualTo("oops"); + } + + @Test + void parseEndStreamErrorHandlesNestedBracesInDetails() { + String b64 = Base64.getEncoder().encodeToString(new byte[]{9}); + String json = "{\"error\":{\"code\":\"internal\",\"message\":\"m\",\"details\":[{\"type\":\"t\",\"value\":\"" + b64 + "\"}]}}"; + + ConnectError e = D.parseEndStreamError(utf8(json)); + + assertThat(e).isNotNull(); + assertThat(e.code()).isEqualTo(ConnectErrorCode.INTERNAL); + assertThat(e.details()).hasSize(1); + assertThat(e.details().get(0).value()).isEqualTo(new byte[]{9}); + } + + // ---- parseEndStreamMetadata ---- + + @Test + void parseEndStreamMetadataExtractsMultiValue() { + String json = "{\"metadata\":{\"foo\":[\"a\",\"b\"],\"bar\":[\"c\"]}}"; + + Map> m = D.parseEndStreamMetadata(utf8(json)); + + assertThat(m.get("foo")).containsExactly("a", "b"); + assertThat(m.get("bar")).containsExactly("c"); + } + + @Test + void parseEndStreamMetadataAbsentReturnsEmpty() { + assertThat(D.parseEndStreamMetadata(utf8("{}"))).isEmpty(); + assertThat(D.parseEndStreamMetadata(utf8("{\"error\":{\"code\":\"internal\"}}"))).isEmpty(); + } + + @Test + void parseEndStreamMetadataUnescapesValues() { + String json = "{\"metadata\":{\"k\":[\"line1\\nline2\"]}}"; + + Map> m = D.parseEndStreamMetadata(utf8(json)); + + assertThat(m.get("k").get(0)).isEqualTo("line1\nline2"); + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandlerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandlerTest.java new file mode 100644 index 0000000..7f32d16 --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandlerTest.java @@ -0,0 +1,694 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.LastHttpContent; +import io.suboptimal.connectjava.api.ConnectClientResponseStart; +import io.suboptimal.connectjava.api.ConnectEndOfStream; +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectErrorCode; +import io.suboptimal.connectjava.api.ConnectPayload; +import io.suboptimal.connectjava.codec.ConnectCodec; +import io.suboptimal.connectjava.model.ConnectMethodDefinition; +import io.suboptimal.connectjava.model.ConnectMethodType; +import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import io.suboptimal.connectjava.protocol.ClientTestSupport; +import io.suboptimal.connectjava.protocol.ConnectEnvelope; +import io.suboptimal.connectjava.testfixtures.StreamingRequest; +import io.suboptimal.connectjava.testfixtures.StreamingResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +class StreamingClientHandlerTest { + private static final String SERVICE_NAME = "connectjava.test.v1.StreamingFixtureService"; + private static final ConnectMethodDefinition SERVER_STREAMING = new ConnectMethodDefinition( + "ServerStreaming", ConnectMethodType.SERVER_STREAMING, StreamingRequest.class, StreamingResponse.class, false); + private static final ConnectMethodDefinition CLIENT_STREAMING = new ConnectMethodDefinition( + "ClientStreaming", ConnectMethodType.CLIENT_STREAMING, StreamingRequest.class, StreamingResponse.class, false); + private static final ConnectMethodDefinition BIDI_STREAMING = new ConnectMethodDefinition( + "Bidi", ConnectMethodType.BIDI_STREAMING, StreamingRequest.class, StreamingResponse.class, false); + private static final ConnectServiceDefinition SERVICE = new ConnectServiceDefinition( + SERVICE_NAME, List.of(SERVER_STREAMING, CLIENT_STREAMING, BIDI_STREAMING), null); + + private final ConnectCodec proto = ClientTestSupport.protoCodec(); + private EmbeddedChannel channel; + private ClientTestSupport.RecordingObserver observer; + + @BeforeEach + void setUp() { + channel = new EmbeddedChannel(); + observer = new ClientTestSupport.RecordingObserver(); + } + + @AfterEach + void tearDown() { + channel.finishAndReleaseAll(); + } + + private void install(ConnectMethodDefinition method, Map> headers, + ConnectClientProtocolConfig config) { + ConnectClientCallStart callStart = + new ConnectClientCallStart(SERVICE, method, headers, false, "proto"); + channel.pipeline().addLast(new StreamingClientHandler(callStart, config, observer)); + } + + private void install(ConnectMethodDefinition method) { + install(method, Map.of(), ClientTestSupport.config()); + } + + private void start(ConnectMethodDefinition method) { + ConnectClientCallStart callStart = + new ConnectClientCallStart(SERVICE, method, Map.of(), false, "proto"); + channel.writeOutbound(callStart); + } + + private HttpResponse okStreamingResponse() { + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/connect+proto"); + response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + return response; + } + + private HttpContent dataFrame(StreamingResponse message) { + ByteBuf buf = ConnectEnvelope.encode(channel.alloc(), (byte) 0, + ClientTestSupport.encode(proto, message)); + return new DefaultHttpContent(buf); + } + + private HttpContent endStreamFrame(String json) { + ByteBuf buf = ConnectEnvelope.encode(channel.alloc(), ConnectEnvelope.FLAG_END_STREAM, + json.getBytes(StandardCharsets.UTF_8)); + return new DefaultHttpContent(buf); + } + + // ---- outbound ---- + + @Test + void sendsRequestHeadersOnCallStart() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + + HttpRequest request = channel.readOutbound(); + assertThat(request.method()).isEqualTo(HttpMethod.POST); + assertThat(request.uri()).isEqualTo("/" + SERVICE_NAME + "/ServerStreaming"); + assertThat(request.headers().get(HttpHeaderNames.CONTENT_TYPE)).isEqualTo("application/connect+proto"); + assertThat(request.headers().get(HttpHeaderNames.TRANSFER_ENCODING)).isEqualTo("chunked"); + assertThat(request.headers().get("connect-protocol-version")).isEqualTo("1"); + assertThat(request.headers().get("connect-accept-encoding")).contains("gzip"); + } + + @Test + void encodesPayloadAsEnvelopeFrame() throws IOException { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); // request headers + + StreamingRequest req = StreamingRequest.newBuilder().setText("one").build(); + channel.writeOutbound(new ConnectPayload(req)); + + HttpContent content = channel.readOutbound(); + ByteBuf buf = content.content(); + byte flags = buf.readByte(); + int length = buf.readInt(); + assertThat(flags).isEqualTo((byte) 0); + byte[] payload = new byte[length]; + buf.readBytes(payload); + assertThat(proto.decode(Unpooled.wrappedBuffer(payload), StreamingRequest.class)).isEqualTo(req); + content.release(); + + assertThat(observer.events).containsExactly("onRequestPayload"); + } + + @Test + void endOfStreamSendsLastContent() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeOutbound(new ConnectPayload(StreamingRequest.newBuilder().setText("x").build())); + channel.readOutbound(); + + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + Object last = channel.readOutbound(); + assertThat(last).isInstanceOf(LastHttpContent.class); + assertThat(observer.events).contains("onRequestFinished"); + } + + @Test + void rejectsBidiStreaming() { + install(BIDI_STREAMING); + start(BIDI_STREAMING); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.UNIMPLEMENTED); + Object outbound = channel.readOutbound(); + assertThat(outbound).isNull(); // no request sent + } + + @Test + void serverStreamingRejectsSecondRequest() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeOutbound(new ConnectPayload(StreamingRequest.newBuilder().setText("1").build())); + channel.readOutbound(); + + channel.writeOutbound(new ConnectPayload(StreamingRequest.newBuilder().setText("2").build())); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.UNIMPLEMENTED); + } + + // ---- inbound: server-streaming ---- + + @Test + void deliversExchangeBeforePayloads() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + + channel.writeInbound(okStreamingResponse()); + + // BUG 3 regression: CallExchange must arrive before any payload. + Object first = channel.readInbound(); + assertThat(first).isInstanceOf(ConnectClientResponseStart.class); + + channel.writeInbound(dataFrame(StreamingResponse.newBuilder().setText("a").build())); + channel.writeInbound(dataFrame(StreamingResponse.newBuilder().setText("b").build())); + channel.writeInbound(endStreamFrame("{}")); + + Object p1 = channel.readInbound(); + Object p2 = channel.readInbound(); + Object eos = channel.readInbound(); + assertThat(p1).isInstanceOf(ConnectPayload.class); + assertThat(((ConnectPayload) p1).data()) + .isEqualTo(StreamingResponse.newBuilder().setText("a").build()); + assertThat(p2).isInstanceOf(ConnectPayload.class); + assertThat(eos).isInstanceOf(ConnectEndOfStream.class); + + assertThat(observer.events) + .containsExactly("onResponseHeaders", "onResponsePayload", "onResponsePayload", "onCallComplete"); + assertThat(observer.completeError).isNull(); + } + + @Test + void reassemblesFrameSplitAcrossChunks() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + ByteBuf frame = ConnectEnvelope.encode(channel.alloc(), (byte) 0, + ClientTestSupport.encode(proto, StreamingResponse.newBuilder().setText("split").build())); + int mid = frame.readableBytes() / 2; + ByteBuf part1 = frame.readSlice(mid).retain(); + ByteBuf part2 = frame.readSlice(frame.readableBytes()).retain(); + frame.release(); + + channel.writeInbound(new DefaultHttpContent(part1)); + Object partial = channel.readInbound(); + assertThat(partial).isNull(); // not yet complete + channel.writeInbound(new DefaultHttpContent(part2)); + + Object payload = channel.readInbound(); + assertThat(payload).isInstanceOf(ConnectPayload.class); + assertThat(((ConnectPayload) payload).data()) + .isEqualTo(StreamingResponse.newBuilder().setText("split").build()); + } + + // ---- inbound: client-streaming ---- + + @Test + void clientStreamingRejectsSecondResponse() { + install(CLIENT_STREAMING); + start(CLIENT_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + channel.writeInbound(dataFrame(StreamingResponse.newBuilder().setText("1").build())); + channel.readInbound(); // first payload + channel.writeInbound(dataFrame(StreamingResponse.newBuilder().setText("2").build())); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.UNIMPLEMENTED); + assertThat(((ConnectError) inbound).message()).contains("more than one response message"); + } + + // ---- inbound: errors ---- + + @Test + void rejectsResponseWithMismatchedCodec() { + install(SERVER_STREAMING); // callStart has codecName="proto" + start(SERVER_STREAMING); + channel.readOutbound(); + + // server responds with JSON codec while client requested proto + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/connect+json"); + response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + channel.writeInbound(response); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.INTERNAL); + assertThat(((ConnectError) inbound).message()).contains("json"); + assertThat(observer.completeCount).isEqualTo(1); + // no exchange must be fired — check happened before onResponseHeaders + assertThat(observer.events).doesNotContain("onResponseHeaders"); + } + + @Test + void rejectsCompressedDataFrameWhenCompressionNotNegotiated() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + // response without connect-content-encoding — responseEncoding stays identity + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + // data frame with FLAG_COMPRESSED, but no compression was negotiated + ByteBuf buf = ConnectEnvelope.encode(channel.alloc(), ConnectEnvelope.FLAG_COMPRESSED, + new byte[] {1, 2, 3, 4}); + channel.writeInbound(new DefaultHttpContent(buf)); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.INTERNAL); + assertThat(((ConnectError) inbound).message()).contains("compression"); + assertThat(observer.completeCount).isEqualTo(1); + } + + @Test + void endStreamErrorPropagated() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + channel.writeInbound(endStreamFrame("{\"error\":{\"code\":\"not_found\",\"message\":\"gone\"}}")); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectEndOfStream.class); + ConnectEndOfStream eos = (ConnectEndOfStream) inbound; + assertThat(eos.error()).isNotNull(); + assertThat(eos.error().code()).isEqualTo(ConnectErrorCode.NOT_FOUND); + assertThat(observer.completeError).isNotNull(); + } + + @Test + void endStreamErrorCarriesTrailers() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + channel.writeInbound(endStreamFrame( + "{\"error\":{\"code\":\"not_found\",\"message\":\"gone\"}," + + "\"metadata\":{\"x-custom-trailer\":[\"a\",\"b\"]}}")); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectEndOfStream.class); + ConnectEndOfStream eos = (ConnectEndOfStream) inbound; + assertThat(eos.error()).isNotNull(); + assertThat(eos.error().code()).isEqualTo(ConnectErrorCode.NOT_FOUND); + assertThat(eos.trailers().get("x-custom-trailer")).containsExactly("a", "b"); + assertThat(observer.completeError).isNotNull(); + + // ровно одно терминальное сообщение, отдельного ConnectError быть не должно + assertThat((Object) channel.readInbound()).isNull(); + } + + @Test + void nonOkResponseFailsCall() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + + HttpResponse response = new DefaultHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE); + channel.writeInbound(response); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.UNAVAILABLE); + } + + @Test + void missingContentTypeOnOkFailsCall() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + channel.writeInbound(response); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.UNKNOWN); + assertThat(((ConnectError) inbound).message()).contains("Content-Type"); + } + + @Test + void truncatedStreamFails() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + channel.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).message()).contains("Truncated stream"); + } + + @Test + void frameLargerThanLimitFails() { + install(SERVER_STREAMING, Map.of(), ClientTestSupport.configWithMaxFrameBytes(4)); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + channel.writeInbound(dataFrame(StreamingResponse.newBuilder().setText("way too long").build())); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.RESOURCE_EXHAUSTED); + } + + private HttpContent compressedEndStreamFrame(String json) { + byte[] compressed = ClientTestSupport.gzipCompress(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ByteBuf buf = ConnectEnvelope.encode(channel.alloc(), + (byte) (ConnectEnvelope.FLAG_END_STREAM | ConnectEnvelope.FLAG_COMPRESSED), compressed); + return new DefaultHttpContent(buf); + } + + private HttpResponse gzipStreamingResponse() { + HttpResponse response = okStreamingResponse(); + response.headers().set("connect-content-encoding", "gzip"); + return response; + } + + @Test + void compressedEndStreamSuccessIsDecompressedAndParsed() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(gzipStreamingResponse()); + channel.readInbound(); // exchange + + channel.writeInbound(compressedEndStreamFrame("{}")); + + Object eos = channel.readInbound(); + assertThat(eos).isInstanceOf(ConnectEndOfStream.class); + assertThat(observer.completeError).isNull(); + assertThat(observer.completeCount).isEqualTo(1); + } + + @Test + void compressedEndStreamErrorIsDecompressedAndParsed() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(gzipStreamingResponse()); + channel.readInbound(); // exchange + + channel.writeInbound(compressedEndStreamFrame( + "{\"error\":{\"code\":\"not_found\",\"message\":\"gone\"}}")); + + Object err = channel.readInbound(); + assertThat(err).isInstanceOf(ConnectEndOfStream.class); + ConnectEndOfStream eos = (ConnectEndOfStream) err; + assertThat(eos.error()).isNotNull(); + assertThat(eos.error().code()).isEqualTo(ConnectErrorCode.NOT_FOUND); + assertThat(observer.completeError).isNotNull(); + } + + @Test + void corruptCompressedEndStreamFrameFails() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(gzipStreamingResponse()); + channel.readInbound(); // exchange + + ByteBuf buf = ConnectEnvelope.encode(channel.alloc(), + (byte) (ConnectEnvelope.FLAG_END_STREAM | ConnectEnvelope.FLAG_COMPRESSED), + new byte[] {1, 2, 3, 4}); + channel.writeInbound(new DefaultHttpContent(buf)); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).message()).contains("Decompression failed"); + assertThat(observer.completeCount).isEqualTo(1); + } + + @Test + void endStreamWithMissingErrorCodeIsUnknown() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + channel.writeInbound(endStreamFrame("{\"error\":{\"message\":\"oops\"}}")); + + Object err = channel.readInbound(); + assertThat(err).isInstanceOf(ConnectEndOfStream.class); + ConnectEndOfStream eos = (ConnectEndOfStream) err; + assertThat(eos.error()).isNotNull(); + assertThat(eos.error().code()).isEqualTo(ConnectErrorCode.UNKNOWN); + assertThat(eos.error().message()).isEqualTo("oops"); + } + + @Test + void corruptCompressedFrameFails() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + + HttpResponse response = okStreamingResponse(); + response.headers().set("connect-content-encoding", "gzip"); + channel.writeInbound(response); + channel.readInbound(); // exchange + + ByteBuf buf = ConnectEnvelope.encode(channel.alloc(), ConnectEnvelope.FLAG_COMPRESSED, + new byte[] {1, 2, 3, 4}); + channel.writeInbound(new DefaultHttpContent(buf)); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).message()).contains("Decompression failed"); + } + + // ---- lifecycle ---- + + @Test + void channelInactiveMidStreamCancels() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + channel.pipeline().fireChannelInactive(); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.CANCELED); + assertThat(observer.completeCount).isEqualTo(1); + } + + @Test + void channelInactiveAfterCompletionDoesNotCompleteTwice() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + channel.writeInbound(endStreamFrame("{}")); + + channel.pipeline().fireChannelInactive(); + + assertThat(observer.completeCount).isEqualTo(1); + } + + // ---- §5.1 end-stream trailers ---- + + @Test + void endStreamMetadataBecomesEndOfStreamTrailers() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + channel.writeInbound(endStreamFrame("{\"metadata\":{\"x-foo\":[\"a\",\"b\"]}}")); + + Object eos = channel.readInbound(); + assertThat(eos).isInstanceOf(ConnectEndOfStream.class); + assertThat(((ConnectEndOfStream) eos).trailers().get("x-foo")).containsExactly("a", "b"); + assertThat(observer.completeError).isNull(); + assertThat(observer.completeCount).isEqualTo(1); + } + + // ---- §5.2 outbound compression ---- + + @Test + void compressesOutboundDataFrameWhenContentEncodingGzip() { + install(SERVER_STREAMING, Map.of("content-encoding", List.of("gzip")), ClientTestSupport.config()); + start(SERVER_STREAMING); + + HttpRequest req = channel.readOutbound(); + assertThat(req.headers().get("connect-content-encoding")).isEqualTo("gzip"); + + StreamingRequest msg = StreamingRequest.newBuilder().setText("zip").build(); + channel.writeOutbound(new ConnectPayload(msg)); + + HttpContent content = channel.readOutbound(); + ByteBuf buf = content.content(); + byte flags = buf.readByte(); + int len = buf.readInt(); + + boolean compressed = (flags & ConnectEnvelope.FLAG_COMPRESSED) != 0; + assertThat(compressed).isTrue(); + + byte[] payload = new byte[len]; + buf.readBytes(payload); + byte[] decompressed = ClientTestSupport.gzipDecompress(payload); + assertThat(decompressed).isEqualTo(ClientTestSupport.encode(proto, msg)); + content.release(); + } + + // ---- §5.3 client-streaming multiple requests ---- + + @Test + void clientStreamingSendsMultipleRequestFrames() throws Exception { + install(CLIENT_STREAMING); + start(CLIENT_STREAMING); + channel.readOutbound(); // headers + + channel.writeOutbound(new ConnectPayload(StreamingRequest.newBuilder().setText("1").build())); + channel.writeOutbound(new ConnectPayload(StreamingRequest.newBuilder().setText("2").build())); + + HttpContent c1 = channel.readOutbound(); + HttpContent c2 = channel.readOutbound(); + + Object noError = channel.readInbound(); + assertThat(noError).isNull(); + + ByteBuf buf1 = c1.content(); + buf1.readByte(); // flags + int len1 = buf1.readInt(); + byte[] payload1 = new byte[len1]; + buf1.readBytes(payload1); + assertThat(proto.decode(Unpooled.wrappedBuffer(payload1), StreamingRequest.class).getText()).isEqualTo("1"); + + ByteBuf buf2 = c2.content(); + buf2.readByte(); // flags + int len2 = buf2.readInt(); + byte[] payload2 = new byte[len2]; + buf2.readBytes(payload2); + assertThat(proto.decode(Unpooled.wrappedBuffer(payload2), StreamingRequest.class).getText()).isEqualTo("2"); + + assertThat(observer.events).contains("onRequestPayload", "onRequestPayload"); + c1.release(); + c2.release(); + } + + // ---- §5.4 serialization failure ---- + + @Test + void serializationFailureOnOutboundPayloadFailsCall() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + + channel.writeOutbound(new ConnectPayload("not a protobuf message")); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.INTERNAL); + assertThat(((ConnectError) inbound).message()).contains("Serialization failed"); + assertThat(observer.completeCount).isEqualTo(1); + } + + // ---- §5.5 undecodable data frame ---- + + @Test + void undecodableDataFrameFailsCleanly() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + ByteBuf buf = ConnectEnvelope.encode(channel.alloc(), (byte) 0, + new byte[]{(byte) 0xFF, (byte) 0xFF, (byte) 0xFF}); + channel.writeInbound(new DefaultHttpContent(buf)); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.INTERNAL); + assertThat(((ConnectError) inbound).message()).contains("Deserialization failed"); + assertThat(observer.completeCount).isEqualTo(1); + } + + // ---- §5.6 handlerRemoved mid-stream ---- + + @Test + void handlerRemovedMidStreamCancels() { + install(SERVER_STREAMING); + start(SERVER_STREAMING); + channel.readOutbound(); + channel.writeInbound(okStreamingResponse()); + channel.readInbound(); // exchange + + channel.pipeline().remove(StreamingClientHandler.class); + + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.CANCELED); + assertThat(observer.completeCount).isEqualTo(1); + } + + // ---- §6.2 streaming timeout ---- + + @Test + void sendsConnectTimeoutHeaderWhenTimeoutSet() { + ConnectClientCallStart cs = new ConnectClientCallStart( + SERVICE, SERVER_STREAMING, Map.of(), false, "proto", 1500L); + channel.pipeline().addLast(new StreamingClientHandler(cs, ClientTestSupport.config(), observer)); + channel.writeOutbound(cs); + + HttpRequest request = channel.readOutbound(); + assertThat(request.headers().get("connect-timeout-ms")).isEqualTo("1500"); + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryGetRequestClientHandlerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryGetRequestClientHandlerTest.java new file mode 100644 index 0000000..a4ad0ac --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryGetRequestClientHandlerTest.java @@ -0,0 +1,183 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.QueryStringDecoder; +import io.suboptimal.connectjava.api.ConnectEndOfStream; +import io.suboptimal.connectjava.api.ConnectErrorCode; +import io.suboptimal.connectjava.api.ConnectPayload; +import io.suboptimal.connectjava.codec.ConnectCodec; +import io.suboptimal.connectjava.model.ConnectMethodDefinition; +import io.suboptimal.connectjava.model.ConnectMethodType; +import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import io.suboptimal.connectjava.protocol.ClientTestSupport; +import io.suboptimal.connectjava.testfixtures.UnaryGetRequest; +import io.suboptimal.connectjava.testfixtures.UnaryGetResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Base64; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class UnaryGetRequestClientHandlerTest { + private static final String SERVICE_NAME = "connectjava.test.v1.UnaryGetFixtureService"; + private static final ConnectMethodDefinition METHOD = new ConnectMethodDefinition( + "SafeUnary", ConnectMethodType.UNARY, UnaryGetRequest.class, UnaryGetResponse.class, true); + private static final ConnectServiceDefinition SERVICE = new ConnectServiceDefinition( + SERVICE_NAME, List.of(METHOD), null); + private static final UnaryGetRequest REQUEST = + UnaryGetRequest.newBuilder().setText("ping").build(); + + private final ConnectClientProtocolConfig config = ClientTestSupport.config(); + private final ConnectCodec proto = ClientTestSupport.protoCodec(); + private EmbeddedChannel channel; + + @BeforeEach + void setUp() { + channel = new EmbeddedChannel(); + } + + @AfterEach + void tearDown() { + channel.finishAndReleaseAll(); + } + + private ConnectClientCallStart callStart(Map> headers) { + return new ConnectClientCallStart(SERVICE, METHOD, headers, true, "proto"); + } + + private void install(ClientTestSupport.RecordingObserver observer, ConnectClientCallStart callStart) { + channel.pipeline().addLast(ConnectClientPipeline.UNARY_GET_HANDLER, + new UnaryGetRequestClientHandler(callStart, config, observer)); + } + + @Test + void sendsGetRequestWithBase64QueryAndInstallsResponseHandler() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart callStart = callStart(Map.of()); + install(observer, callStart); + + channel.writeOutbound(callStart); + channel.writeOutbound(new ConnectPayload(REQUEST)); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + + FullHttpRequest request = channel.readOutbound(); + assertThat(request.method()).isEqualTo(HttpMethod.GET); + assertThat(request.content().readableBytes()).isZero(); + + QueryStringDecoder query = new QueryStringDecoder(request.uri()); + assertThat(query.path()).isEqualTo("/" + SERVICE_NAME + "/SafeUnary"); + assertThat(first(query, "encoding")).isEqualTo("proto"); + assertThat(first(query, "base64")).isEqualTo("1"); + assertThat(first(query, "connect")).isEqualTo("v1"); + + byte[] decodedMessage = Base64.getUrlDecoder().decode(first(query, "message")); + assertThat(decodedMessage).isEqualTo(ClientTestSupport.encode(proto, REQUEST)); + request.release(); + + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_RESPONSE_HANDLER)).isNotNull(); + assertThat(observer.events).containsExactly("onRequestPayload", "onRequestFinished"); + } + + @Test + void compressesGetPayloadWhenContentEncodingGzip() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart callStart = callStart(Map.of("content-encoding", List.of("gzip"))); + install(observer, callStart); + + channel.writeOutbound(callStart); + channel.writeOutbound(new ConnectPayload(REQUEST)); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + + FullHttpRequest request = channel.readOutbound(); + QueryStringDecoder query = new QueryStringDecoder(request.uri()); + assertThat(first(query, "compression")).isEqualTo("gzip"); + + byte[] decodedMessage = Base64.getUrlDecoder().decode(first(query, "message")); + byte[] decompressed = ClientTestSupport.gzipDecompress(decodedMessage); + assertThat(decompressed).isEqualTo(ClientTestSupport.encode(proto, REQUEST)); + request.release(); + } + + @Test + void handlerRemovedBeforeRequestSentCompletesCallAsCanceled() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart callStart = callStart(Map.of()); + install(observer, callStart); + channel.writeOutbound(callStart); + + channel.pipeline().remove(ConnectClientPipeline.UNARY_GET_HANDLER); + + assertThat(observer.completeCount).isEqualTo(1); + assertThat(observer.completeError.code()).isEqualTo(ConnectErrorCode.CANCELED); + } + + // TDD: описывает корректное поведение; может падать на текущей реализации (см. ТЗ §6.3) + @Test + void sendsConnectTimeoutHeaderWhenTimeoutSet() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart cs = + new ConnectClientCallStart(SERVICE, METHOD, Map.of(), true, "proto", 3000L); + install(observer, cs); + + channel.writeOutbound(cs); + channel.writeOutbound(new ConnectPayload(REQUEST)); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + + FullHttpRequest request = channel.readOutbound(); + assertThat(request.headers().get("connect-timeout-ms")).isEqualTo("3000"); + request.release(); + } + + @Test + void failsLateWriteBeforeCallStart() { + var observer = new ClientTestSupport.RecordingObserver(); + install(observer, callStart(Map.of())); + + assertThatThrownBy(() -> channel.writeOutbound(new ConnectPayload(REQUEST))) + .isInstanceOf(ConnectCallTerminatedException.class); + } + + @Test + void handlerRemovedAfterRequestSentDoesNotComplete() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart cs = callStart(Map.of()); + install(observer, cs); + + channel.writeOutbound(cs); + channel.writeOutbound(new ConnectPayload(REQUEST)); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + FullHttpRequest request = channel.readOutbound(); + request.release(); + + channel.pipeline().remove(ConnectClientPipeline.UNARY_GET_HANDLER); + + assertThat(observer.completeCount).isZero(); + } + + @Test + void handlerRemovedWithBufferedPayloadCompletesCallAsCanceled() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart cs = callStart(Map.of()); + install(observer, cs); + + channel.writeOutbound(cs); + channel.writeOutbound(new ConnectPayload(REQUEST)); + + channel.pipeline().remove(ConnectClientPipeline.UNARY_GET_HANDLER); + + assertThat(observer.completeCount).isEqualTo(1); + assertThat(observer.completeError.code()).isEqualTo(ConnectErrorCode.CANCELED); + } + + private static String first(QueryStringDecoder query, String name) { + List values = query.parameters().get(name); + return values == null || values.isEmpty() ? null : values.getFirst(); + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryPostRequestClientHandlerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryPostRequestClientHandlerTest.java new file mode 100644 index 0000000..4ea6df2 --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryPostRequestClientHandlerTest.java @@ -0,0 +1,208 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.suboptimal.connectjava.api.ConnectEndOfStream; +import io.suboptimal.connectjava.api.ConnectErrorCode; +import io.suboptimal.connectjava.api.ConnectPayload; +import io.suboptimal.connectjava.codec.ConnectCodec; +import io.suboptimal.connectjava.model.ConnectMethodDefinition; +import io.suboptimal.connectjava.model.ConnectMethodType; +import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import io.suboptimal.connectjava.protocol.ClientTestSupport; +import io.suboptimal.connectjava.testfixtures.UnaryPostRequest; +import io.suboptimal.connectjava.testfixtures.UnaryPostResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class UnaryPostRequestClientHandlerTest { + private static final String SERVICE_NAME = "connectjava.test.v1.UnaryPostFixtureService"; + private static final ConnectMethodDefinition METHOD = new ConnectMethodDefinition( + "Unary", ConnectMethodType.UNARY, UnaryPostRequest.class, UnaryPostResponse.class, false); + private static final ConnectServiceDefinition SERVICE = new ConnectServiceDefinition( + SERVICE_NAME, List.of(METHOD), null); + private static final UnaryPostRequest REQUEST = + UnaryPostRequest.newBuilder().setText("hello").build(); + + private final ConnectClientProtocolConfig config = ClientTestSupport.config(); + private final ConnectCodec proto = ClientTestSupport.protoCodec(); + private EmbeddedChannel channel; + + @BeforeEach + void setUp() { + channel = new EmbeddedChannel(); + } + + @AfterEach + void tearDown() { + channel.finishAndReleaseAll(); + } + + private ConnectClientCallStart callStart(Map> headers) { + return new ConnectClientCallStart(SERVICE, METHOD, headers, false, "proto"); + } + + private void install(ClientTestSupport.RecordingObserver observer, ConnectClientCallStart callStart) { + channel.pipeline().addLast(ConnectClientPipeline.UNARY_POST_HANDLER, + new UnaryPostRequestClientHandler(callStart, config, observer)); + } + + @Test + void sendsPostRequestAndInstallsResponseHandler() throws IOException { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart callStart = callStart(Map.of()); + install(observer, callStart); + + channel.writeOutbound(callStart); + channel.writeOutbound(new ConnectPayload(REQUEST)); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + + FullHttpRequest request = channel.readOutbound(); + assertThat(request.method()).isEqualTo(HttpMethod.POST); + assertThat(request.uri()).isEqualTo("/" + SERVICE_NAME + "/Unary"); + assertThat(request.headers().get(HttpHeaderNames.CONTENT_TYPE)).isEqualTo("application/proto"); + assertThat(request.headers().get("connect-protocol-version")).isEqualTo("1"); + assertThat(request.headers().get(HttpHeaderNames.CONTENT_LENGTH)) + .isEqualTo(String.valueOf(request.content().readableBytes())); + assertThat(request.headers().get(HttpHeaderNames.ACCEPT_ENCODING)).contains("gzip"); + + UnaryPostRequest decoded = proto.decode(request.content(), UnaryPostRequest.class); + assertThat(decoded).isEqualTo(REQUEST); + request.release(); + + assertThat(channel.pipeline().get(ConnectClientPipeline.UNARY_RESPONSE_HANDLER)).isNotNull(); + assertThat(observer.events).containsExactly("onRequestPayload", "onRequestFinished"); + } + + @Test + void compressesRequestWhenContentEncodingGzip() throws IOException { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart callStart = + callStart(Map.of("content-encoding", List.of("gzip"))); + install(observer, callStart); + + channel.writeOutbound(callStart); + channel.writeOutbound(new ConnectPayload(REQUEST)); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + + FullHttpRequest request = channel.readOutbound(); + assertThat(request.headers().get(HttpHeaderNames.CONTENT_ENCODING)).isEqualTo("gzip"); + + byte[] body = new byte[request.content().readableBytes()]; + request.content().getBytes(request.content().readerIndex(), body); + byte[] decompressed = ClientTestSupport.gzipDecompress(body); + assertThat(decompressed).isEqualTo(ClientTestSupport.encode(proto, REQUEST)); + request.release(); + } + + @Test + void failsLateWriteBeforeCallStart() { + var observer = new ClientTestSupport.RecordingObserver(); + install(observer, callStart(Map.of())); + + assertThatThrownBy(() -> channel.writeOutbound(new ConnectPayload(REQUEST))) + .isInstanceOf(ConnectCallTerminatedException.class); + } + + @Test + void handlerRemovedBeforeRequestSentCompletesCallAsCanceled() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart callStart = callStart(Map.of()); + install(observer, callStart); + channel.writeOutbound(callStart); + + channel.pipeline().remove(ConnectClientPipeline.UNARY_POST_HANDLER); + + assertThat(observer.completeCount).isEqualTo(1); + assertThat(observer.completeError).isNotNull(); + assertThat(observer.completeError.code()).isEqualTo(ConnectErrorCode.CANCELED); + } + + @Test + void handlerRemovedAfterRequestSentDoesNotComplete() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart callStart = callStart(Map.of()); + install(observer, callStart); + + channel.writeOutbound(callStart); + channel.writeOutbound(new ConnectPayload(REQUEST)); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + FullHttpRequest request = channel.readOutbound(); + request.release(); + + channel.pipeline().remove(ConnectClientPipeline.UNARY_POST_HANDLER); + + assertThat(observer.completeCount).isZero(); + } + + @Test + void handlerRemovedWithBufferedPayloadCompletesCallAsCanceled() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart callStart = callStart(Map.of()); + install(observer, callStart); + + channel.writeOutbound(callStart); + channel.writeOutbound(new ConnectPayload(REQUEST)); + + channel.pipeline().remove(ConnectClientPipeline.UNARY_POST_HANDLER); + + assertThat(observer.completeCount).isEqualTo(1); + assertThat(observer.completeError.code()).isEqualTo(ConnectErrorCode.CANCELED); + } + + @Test + void sendsConnectTimeoutHeaderWhenTimeoutSet() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart callStartWithTimeout = + new ConnectClientCallStart(SERVICE, METHOD, Map.of(), false, "proto", 2500L); + install(observer, callStartWithTimeout); + + channel.writeOutbound(callStartWithTimeout); + channel.writeOutbound(new ConnectPayload(REQUEST)); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + + FullHttpRequest request = channel.readOutbound(); + assertThat(request.headers().get("connect-timeout-ms")).isEqualTo("2500"); + request.release(); + } + + @Test + void copiesUserHeadersToOutboundRequest() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart cs = callStart(Map.of("x-custom", List.of("v1", "v2"))); + install(observer, cs); + + channel.writeOutbound(cs); + channel.writeOutbound(new ConnectPayload(REQUEST)); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + + FullHttpRequest request = channel.readOutbound(); + assertThat(request.headers().getAll("x-custom")).containsExactly("v1", "v2"); + request.release(); + } + + @Test + void ignoresUserSuppliedReservedHeaders() { + var observer = new ClientTestSupport.RecordingObserver(); + ConnectClientCallStart cs = callStart(Map.of("content-type", List.of("text/bogus"))); + install(observer, cs); + + channel.writeOutbound(cs); + channel.writeOutbound(new ConnectPayload(REQUEST)); + channel.writeOutbound(ConnectEndOfStream.INSTANCE); + + FullHttpRequest request = channel.readOutbound(); + assertThat(request.headers().get(HttpHeaderNames.CONTENT_TYPE)).isEqualTo("application/proto"); + request.release(); + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandlerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandlerTest.java new file mode 100644 index 0000000..a221fdd --- /dev/null +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandlerTest.java @@ -0,0 +1,351 @@ +package io.suboptimal.connectjava.protocol.client; + +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.suboptimal.connectjava.api.ConnectClientResponseStart; +import io.suboptimal.connectjava.api.ConnectEndOfStream; +import io.suboptimal.connectjava.api.ConnectError; +import io.suboptimal.connectjava.api.ConnectErrorCode; +import io.suboptimal.connectjava.api.ConnectPayload; +import io.suboptimal.connectjava.api.ConnectResponseMeta; +import io.suboptimal.connectjava.codec.ConnectCodec; +import io.suboptimal.connectjava.model.ConnectMethodDefinition; +import io.suboptimal.connectjava.model.ConnectMethodType; +import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import io.suboptimal.connectjava.protocol.ClientTestSupport; +import io.suboptimal.connectjava.testfixtures.UnaryPostRequest; +import io.suboptimal.connectjava.testfixtures.UnaryPostResponse; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +class UnaryResponseClientHandlerTest { + private static final ConnectMethodDefinition METHOD = new ConnectMethodDefinition( + "Unary", ConnectMethodType.UNARY, UnaryPostRequest.class, UnaryPostResponse.class, false); + private static final ConnectServiceDefinition SERVICE = new ConnectServiceDefinition( + "connectjava.test.v1.UnaryPostFixtureService", List.of(METHOD), null); + private static final UnaryPostResponse RESPONSE = + UnaryPostResponse.newBuilder().setText("pong").build(); + + private final ConnectClientProtocolConfig config = ClientTestSupport.config(); + private final ConnectCodec proto = ClientTestSupport.protoCodec(); + + private ConnectClientCallStart callStart() { + return new ConnectClientCallStart(SERVICE, METHOD, Map.of(), false, "proto"); + } + + private EmbeddedChannel newChannel(ClientTestSupport.RecordingObserver observer) { + EmbeddedChannel channel = new EmbeddedChannel(); + channel.pipeline().addLast(new UnaryResponseClientHandler(callStart(), config, observer)); + return channel; + } + + private FullHttpResponse response(int status, byte[] body, String contentType, String contentEncoding) { + FullHttpResponse response = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(status), Unpooled.wrappedBuffer(body)); + if (contentType != null) { + response.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType); + } + if (contentEncoding != null) { + response.headers().set(HttpHeaderNames.CONTENT_ENCODING, contentEncoding); + } + return response; + } + + @Test + void deliversExchangePayloadEndOfStreamOnSuccess() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + channel.writeInbound(response(200, ClientTestSupport.encode(proto, RESPONSE), + "application/proto", null)); + + Object exchange = channel.readInbound(); + assertThat(exchange).isInstanceOf(ConnectClientResponseStart.class); + assertThat(((ConnectClientResponseStart) exchange).responseMeta().statusCode()).isEqualTo(200); + + Object payload = channel.readInbound(); + assertThat(payload).isInstanceOf(ConnectPayload.class); + assertThat(((ConnectPayload) payload).data()).isEqualTo(RESPONSE); + + Object endOfStream = channel.readInbound(); + assertThat(endOfStream).isSameAs(ConnectEndOfStream.INSTANCE); + + assertThat(observer.events) + .containsExactly("onResponseHeaders", "onResponsePayload", "onCallComplete"); + assertThat(observer.completeError).isNull(); + + channel.finishAndReleaseAll(); + } + + @Test + void decompressesGzipResponse() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + byte[] gzipped = ClientTestSupport.gzipCompress(ClientTestSupport.encode(proto, RESPONSE)); + channel.writeInbound(response(200, gzipped, "application/proto", "gzip")); + + channel.readInbound(); // exchange + ConnectPayload payload = channel.readInbound(); + assertThat(payload.data()).isEqualTo(RESPONSE); + + channel.finishAndReleaseAll(); + } + + @Test + void mapsEmptyBody404ToUnimplemented() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + channel.writeInbound(response(404, new byte[0], null, null)); + + Object first = channel.readInbound(); + assertThat(first).isInstanceOf(ConnectClientResponseStart.class); + Object second = channel.readInbound(); + assertThat(second).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) second).code()).isEqualTo(ConnectErrorCode.UNIMPLEMENTED); + assertThat(observer.completeCount).isEqualTo(1); + Object next = channel.readInbound(); + assertThat(next).isNull(); + + channel.finishAndReleaseAll(); + } + + @Test + void mapsHttpStatusesToErrorCodes() { + Map cases = Map.of( + 400, ConnectErrorCode.INTERNAL, + 401, ConnectErrorCode.UNAUTHENTICATED, + 403, ConnectErrorCode.PERMISSION_DENIED, + 429, ConnectErrorCode.UNAVAILABLE, + 500, ConnectErrorCode.UNKNOWN, + 503, ConnectErrorCode.UNAVAILABLE, + 418, ConnectErrorCode.UNKNOWN); + + cases.forEach((status, expected) -> { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + channel.writeInbound(response(status, new byte[0], null, null)); + Object first = channel.readInbound(); + assertThat(first).isInstanceOf(ConnectClientResponseStart.class); + Object second = channel.readInbound(); + assertThat(second).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) second).code()) + .as("status %d", status) + .isEqualTo(expected); + channel.finishAndReleaseAll(); + }); + } + + @Test + void prefersConnectErrorBodyOverHttpStatus() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + byte[] body = "{\"code\":\"not_found\",\"message\":\"nope\"}".getBytes(StandardCharsets.UTF_8); + channel.writeInbound(response(400, body, "application/json", null)); + + assertThat((Object) channel.readInbound()).isInstanceOf(ConnectClientResponseStart.class); + ConnectError error = channel.readInbound(); + assertThat(error.code()).isEqualTo(ConnectErrorCode.NOT_FOUND); + assertThat(error.message()).isEqualTo("nope"); + + channel.finishAndReleaseAll(); + } + + @Test + void rejectsResponseWithMismatchedCodec() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); // callStart has codecName="proto" + + // server responds with JSON codec while client requested proto + channel.writeInbound(response(200, new byte[0], "application/json", null)); + + ConnectClientResponseStart responseStart = channel.readInbound(); + assertThat(responseStart).isNotNull(); + + ConnectError error = channel.readInbound(); + assertThat(error.code()).isEqualTo(ConnectErrorCode.INTERNAL); + assertThat(error.message()).contains("json"); + assertThat(error.message()).contains("proto"); + assertThat(observer.completeCount).isEqualTo(1); + + channel.finishAndReleaseAll(); + } + + @Test + void failsOnMissingContentTypeForSuccess() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + channel.writeInbound(response(200, ClientTestSupport.encode(proto, RESPONSE), null, null)); + + ConnectClientResponseStart responseStart = channel.readInbound(); + assertThat(responseStart).isNotNull(); + + ConnectError error = channel.readInbound(); + assertThat(error.code()).isEqualTo(ConnectErrorCode.UNKNOWN); + assertThat(error.message()).contains("Content-Type"); + + channel.finishAndReleaseAll(); + } + + @Test + void failsCleanlyOnUndecodableBody() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + // Invalid protobuf payload — decode throws; must surface as a clean ConnectError, + // never an IllegalReferenceCountException from a double release (regression: BUG 1). + channel.writeInbound(response(200, new byte[] {(byte) 0xFF, (byte) 0xFF, (byte) 0xFF}, + "application/proto", null)); + + ConnectClientResponseStart responseStart = channel.readInbound(); + assertThat(responseStart).isNotNull(); + + ConnectError error = channel.readInbound(); + assertThat((error).code()).isEqualTo(ConnectErrorCode.INTERNAL); + assertThat((error).message()).contains("Deserialization failed"); + assertThat(observer.completeCount).isEqualTo(1); + + channel.finishAndReleaseAll(); + } + + @Test + void failsCleanlyOnCorruptGzip() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + channel.writeInbound(response(200, new byte[] {1, 2, 3, 4, 5}, "application/proto", "gzip")); + + ConnectClientResponseStart responseStart = channel.readInbound(); + assertThat(responseStart).isNotNull(); + + ConnectError error = channel.readInbound(); + assertThat(error.code()).isEqualTo(ConnectErrorCode.INTERNAL); + assertThat(error.message()).contains("Decompression failed"); + + channel.finishAndReleaseAll(); + } + + @Test + void channelInactiveBeforeResponseCancelsCall() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + channel.pipeline().fireChannelInactive(); + + ConnectError error = channel.readInbound(); + assertThat(error.code()).isEqualTo(ConnectErrorCode.CANCELED); + assertThat(observer.completeCount).isEqualTo(1); + + channel.finishAndReleaseAll(); + } + + @Test + void deliversTrailersFromTrailerPrefixedHeaders() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + byte[] body = ClientTestSupport.encode(proto, RESPONSE); + FullHttpResponse resp = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.wrappedBuffer(body)); + resp.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/proto"); + resp.headers().set("trailer-x-foo", "bar"); + resp.headers().set("x-plain", "v"); + + channel.writeInbound(resp); + + ConnectClientResponseStart start = channel.readInbound(); + ConnectResponseMeta meta = start.responseMeta(); + + assertThat(meta.trailers().get("x-foo")).containsExactly("bar"); + assertThat(meta.headers()).containsKey("x-plain"); + assertThat(meta.headers()).doesNotContainKey("x-foo"); + + channel.readInbound(); // payload + channel.readInbound(); // EOS + channel.finishAndReleaseAll(); + } + + @Test + void parsesErrorDetailsFromUnaryErrorBody() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + byte[] detailBytes = new byte[]{1, 2, 3}; + String b64 = Base64.getEncoder().encodeToString(detailBytes); + byte[] body = ("{\"code\":\"not_found\",\"message\":\"nope\",\"details\":" + + "[{\"type\":\"google.rpc.RetryInfo\",\"value\":\"" + b64 + "\"}]}") + .getBytes(StandardCharsets.UTF_8); + + channel.writeInbound(response(400, body, "application/json", null)); + + channel.readInbound(); // ConnectClientResponseStart + ConnectError error = channel.readInbound(); + + assertThat(error.code()).isEqualTo(ConnectErrorCode.NOT_FOUND); + assertThat(error.message()).isEqualTo("nope"); + assertThat(error.details()).hasSize(1); + assertThat(error.details().get(0).type()).isEqualTo("google.rpc.RetryInfo"); + assertThat(error.details().get(0).value()).isEqualTo(detailBytes); + + channel.finishAndReleaseAll(); + } + + @Test + void decompressesGzipErrorBody() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + byte[] json = "{\"code\":\"out_of_range\",\"message\":\"oops\"}" + .getBytes(StandardCharsets.UTF_8); + channel.writeInbound(response(422, ClientTestSupport.gzipCompress(json), + "application/json", "gzip")); + + channel.readInbound(); // ConnectClientResponseStart + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.OUT_OF_RANGE); + assertThat(((ConnectError) inbound).message()).isEqualTo("oops"); + } + + @Test + void fallsBackToHttpStatusWhenErrorBodyIsCorruptGzip() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + channel.writeInbound(response(503, new byte[] {1, 2, 3, 4, 5}, + "application/json", "gzip")); + + channel.readInbound(); // ConnectClientResponseStart + Object inbound = channel.readInbound(); + assertThat(inbound).isInstanceOf(ConnectError.class); + assertThat(((ConnectError) inbound).code()).isEqualTo(ConnectErrorCode.UNAVAILABLE); + } + + @Test + void channelInactiveAfterSuccessDoesNotCompleteTwice() { + var observer = new ClientTestSupport.RecordingObserver(); + EmbeddedChannel channel = newChannel(observer); + + channel.writeInbound(response(200, ClientTestSupport.encode(proto, RESPONSE), + "application/proto", null)); + channel.pipeline().fireChannelInactive(); + + assertThat(observer.completeCount).isEqualTo(1); + + channel.finishAndReleaseAll(); + } +} diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ConnectEndStreamMetaBuilderTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamMetaBuilderTest.java similarity index 95% rename from src/test/java/io/suboptimal/connectjava/protocol/ConnectEndStreamMetaBuilderTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamMetaBuilderTest.java index 2776a69..2d141c8 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/ConnectEndStreamMetaBuilderTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamMetaBuilderTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import org.junit.jupiter.api.Test; diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ConnectEndStreamMetaTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamMetaTest.java similarity index 95% rename from src/test/java/io/suboptimal/connectjava/protocol/ConnectEndStreamMetaTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamMetaTest.java index 6c99bf2..36c5335 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/ConnectEndStreamMetaTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectEndStreamMetaTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import org.junit.jupiter.api.Test; diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ConnectMetaBuilderTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectMetaBuilderTest.java similarity index 97% rename from src/test/java/io/suboptimal/connectjava/protocol/ConnectMetaBuilderTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/ConnectMetaBuilderTest.java index 3bf1e11..e4ddc9f 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/ConnectMetaBuilderTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectMetaBuilderTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpHeaders; diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ConnectRouteTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectRouteTest.java similarity index 97% rename from src/test/java/io/suboptimal/connectjava/protocol/ConnectRouteTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/ConnectRouteTest.java index ff2a4c3..7ffc4a1 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/ConnectRouteTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectRouteTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ConnectChannelConfigurerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectServerChannelConfigurerTest.java similarity index 85% rename from src/test/java/io/suboptimal/connectjava/protocol/ConnectChannelConfigurerTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/ConnectServerChannelConfigurerTest.java index 4a63b19..d9a415c 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/ConnectChannelConfigurerTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectServerChannelConfigurerTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.suboptimal.connectjava.codec.ConnectCodecRegistry; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -21,10 +21,10 @@ import static org.assertj.core.api.Assertions.assertThat; -class ConnectChannelConfigurerTest { +class ConnectServerChannelConfigurerTest { private static final String EXAMPLE_URL = "https://app.example"; - private static final ConnectProtocolParameters PARAMETERS = - new ConnectProtocolParameters(1024, 64, + private static final ConnectServerProtocolParameters PARAMETERS = + new ConnectServerProtocolParameters(1024, 64, ConnectCorsParameters.defaultsForOrigins(Set.of(EXAMPLE_URL))); private EmbeddedChannel channel; @@ -45,13 +45,13 @@ class Http1 { void installsExpectedPipeline() { configureChannel(ConnectTransport.HTTP_1_1); - assertThat(channel.pipeline().get(ConnectPipeline.CORS_HANDLER)) + assertThat(channel.pipeline().get(ConnectServerPipeline.CORS_HANDLER)) .isInstanceOf(CorsHandler.class); String terminalName = channel.pipeline().context(TerminalHandler.class).name(); assertThat(channel.pipeline().names()) .containsSequence( - ConnectPipeline.CORS_HANDLER, - ConnectPipeline.ROUTING_HANDLER, + ConnectServerPipeline.CORS_HANDLER, + ConnectServerPipeline.ROUTING_HANDLER, terminalName); } } @@ -62,15 +62,15 @@ class Http2 { void installsExpectedPipeline() { configureChannel(ConnectTransport.HTTP_2); - assertThat(channel.pipeline().get(ConnectPipeline.CORS_HANDLER)) + assertThat(channel.pipeline().get(ConnectServerPipeline.CORS_HANDLER)) .isInstanceOf(CorsHandler.class); String adapterName = channel.pipeline().context(Http2StreamFrameToHttpObjectCodec.class).name(); String terminalName = channel.pipeline().context(TerminalHandler.class).name(); assertThat(channel.pipeline().names()) .containsSequence( adapterName, - ConnectPipeline.CORS_HANDLER, - ConnectPipeline.ROUTING_HANDLER, + ConnectServerPipeline.CORS_HANDLER, + ConnectServerPipeline.ROUTING_HANDLER, terminalName); } } @@ -131,8 +131,8 @@ void deniesToOriginNotInAllowlist() { } void configureChannel(ConnectTransport transport) { - new ConnectChannelConfigurer(transport, - ConnectProtocolConfig.builder(Map.of(), TerminalHandler::new, PARAMETERS, ConnectCodecRegistry.builder().build()).build()) + new ConnectServerChannelConfigurer(transport, + ConnectServerProtocolConfig.builder(Map.of(), TerminalHandler::new, PARAMETERS, ConnectCodecRegistry.builder().build()).build()) .configure(channel); } diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ConnectInterceptorPipelineTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectServerInterceptorPipelineTest.java similarity index 50% rename from src/test/java/io/suboptimal/connectjava/protocol/ConnectInterceptorPipelineTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/ConnectServerInterceptorPipelineTest.java index 38f2937..188a6df 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/ConnectInterceptorPipelineTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectServerInterceptorPipelineTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.suboptimal.connectjava.api.ConnectCallExchange; import io.suboptimal.connectjava.api.ConnectError; @@ -17,7 +17,7 @@ import static org.assertj.core.api.Assertions.assertThat; -class ConnectInterceptorPipelineTest { +class ConnectServerInterceptorPipelineTest { private static final ConnectMethodDefinition METHOD = new ConnectMethodDefinition( "Method", ConnectMethodType.UNARY, String.class, String.class, false); @@ -30,47 +30,47 @@ class ConnectInterceptorPipelineTest { @Test void emptyPipelineContinues() { - ConnectInterceptor.Decision decision = ConnectInterceptorPipeline.EMPTY.interceptCall(EXCHANGE); + ConnectServerInterceptor.Decision decision = ConnectServerInterceptorPipeline.EMPTY.interceptCall(EXCHANGE); - assertThat(decision).isInstanceOf(ConnectInterceptor.Decision.Continue.class); - assertThat(decision.observer()).isSameAs(ConnectCallObserver.NOOP); + assertThat(decision).isInstanceOf(ConnectServerInterceptor.Decision.Continue.class); + assertThat(decision.observer()).isSameAs(ConnectServerCallObserver.NOOP); } @Test void allNoOpObserversProduceNoOp() { - ConnectInterceptorPipeline pipeline = new ConnectInterceptorPipeline(List.of( - exchange -> ConnectInterceptor.continueCall(), - exchange -> ConnectInterceptor.continueCall() + ConnectServerInterceptorPipeline pipeline = new ConnectServerInterceptorPipeline(List.of( + exchange -> ConnectServerInterceptor.continueCall(), + exchange -> ConnectServerInterceptor.continueCall() )); - ConnectInterceptor.Decision decision = pipeline.interceptCall(EXCHANGE); + ConnectServerInterceptor.Decision decision = pipeline.interceptCall(EXCHANGE); - assertThat(decision).isInstanceOf(ConnectInterceptor.Decision.Continue.class); - assertThat(decision.observer()).isSameAs(ConnectCallObserver.NOOP); + assertThat(decision).isInstanceOf(ConnectServerInterceptor.Decision.Continue.class); + assertThat(decision.observer()).isSameAs(ConnectServerCallObserver.NOOP); } @Test void singleNonNoOpObserverIsReturnedDirectly() { - ConnectCallObserver observer = new ConnectCallObserver() {}; - ConnectInterceptorPipeline pipeline = new ConnectInterceptorPipeline(List.of( - exchange -> ConnectInterceptor.continueWith(observer) + ConnectServerCallObserver observer = new ConnectServerCallObserver() {}; + ConnectServerInterceptorPipeline pipeline = new ConnectServerInterceptorPipeline(List.of( + exchange -> ConnectServerInterceptor.continueWith(observer) )); - ConnectInterceptor.Decision decision = pipeline.interceptCall(EXCHANGE); + ConnectServerInterceptor.Decision decision = pipeline.interceptCall(EXCHANGE); assertThat(decision.observer()).isSameAs(observer); } @Test void noOpObserversAreFilteredFromComposite() { - ConnectCallObserver real = new ConnectCallObserver() {}; - ConnectInterceptorPipeline pipeline = new ConnectInterceptorPipeline(List.of( - exchange -> ConnectInterceptor.continueCall(), - exchange -> ConnectInterceptor.continueWith(real), - exchange -> ConnectInterceptor.continueCall() + ConnectServerCallObserver real = new ConnectServerCallObserver() {}; + ConnectServerInterceptorPipeline pipeline = new ConnectServerInterceptorPipeline(List.of( + exchange -> ConnectServerInterceptor.continueCall(), + exchange -> ConnectServerInterceptor.continueWith(real), + exchange -> ConnectServerInterceptor.continueCall() )); - ConnectInterceptor.Decision decision = pipeline.interceptCall(EXCHANGE); + ConnectServerInterceptor.Decision decision = pipeline.interceptCall(EXCHANGE); assertThat(decision.observer()).isSameAs(real); } @@ -78,16 +78,16 @@ void noOpObserversAreFilteredFromComposite() { @Test void responsePayloadCallbacksAreFIFO() { List log = new ArrayList<>(); - ConnectInterceptorPipeline pipeline = new ConnectInterceptorPipeline(List.of( - exchange -> ConnectInterceptor.continueWith(new ConnectCallObserver() { + ConnectServerInterceptorPipeline pipeline = new ConnectServerInterceptorPipeline(List.of( + exchange -> ConnectServerInterceptor.continueWith(new ConnectServerCallObserver() { @Override public void onResponsePayload(Object p) { log.add("first"); } }), - exchange -> ConnectInterceptor.continueWith(new ConnectCallObserver() { + exchange -> ConnectServerInterceptor.continueWith(new ConnectServerCallObserver() { @Override public void onResponsePayload(Object p) { log.add("second"); } }) )); - ConnectCallObserver composite = pipeline.interceptCall(EXCHANGE).observer(); + ConnectServerCallObserver composite = pipeline.interceptCall(EXCHANGE).observer(); composite.onResponsePayload("x"); assertThat(log).containsExactly("first", "second"); @@ -96,16 +96,16 @@ void responsePayloadCallbacksAreFIFO() { @Test void responseHeaderCallbacksAreLIFO() { List log = new ArrayList<>(); - ConnectInterceptorPipeline pipeline = new ConnectInterceptorPipeline(List.of( - exchange -> ConnectInterceptor.continueWith(new ConnectCallObserver() { + ConnectServerInterceptorPipeline pipeline = new ConnectServerInterceptorPipeline(List.of( + exchange -> ConnectServerInterceptor.continueWith(new ConnectServerCallObserver() { @Override public void onResponseHeaders(ConnectResponseHeadersBuilder h) { log.add("first"); } }), - exchange -> ConnectInterceptor.continueWith(new ConnectCallObserver() { + exchange -> ConnectServerInterceptor.continueWith(new ConnectServerCallObserver() { @Override public void onResponseHeaders(ConnectResponseHeadersBuilder h) { log.add("second"); } }) )); - ConnectCallObserver composite = pipeline.interceptCall(EXCHANGE).observer(); + ConnectServerCallObserver composite = pipeline.interceptCall(EXCHANGE).observer(); composite.onResponseHeaders(new ResponseHeadersBuilder()); assertThat(log).containsExactly("second", "first"); @@ -114,16 +114,16 @@ void responseHeaderCallbacksAreLIFO() { @Test void responseTrailerCallbacksAreLIFO() { List log = new ArrayList<>(); - ConnectInterceptorPipeline pipeline = new ConnectInterceptorPipeline(List.of( - exchange -> ConnectInterceptor.continueWith(new ConnectCallObserver() { + ConnectServerInterceptorPipeline pipeline = new ConnectServerInterceptorPipeline(List.of( + exchange -> ConnectServerInterceptor.continueWith(new ConnectServerCallObserver() { @Override public void onResponseTrailers(ConnectResponseTrailersBuilder t, @Nullable ConnectError e) { log.add("first"); } }), - exchange -> ConnectInterceptor.continueWith(new ConnectCallObserver() { + exchange -> ConnectServerInterceptor.continueWith(new ConnectServerCallObserver() { @Override public void onResponseTrailers(ConnectResponseTrailersBuilder t, @Nullable ConnectError e) { log.add("second"); } }) )); - ConnectCallObserver composite = pipeline.interceptCall(EXCHANGE).observer(); + ConnectServerCallObserver composite = pipeline.interceptCall(EXCHANGE).observer(); composite.onResponseTrailers(new ResponseTrailersBuilder(), null); assertThat(log).containsExactly("second", "first"); @@ -132,16 +132,16 @@ void responseTrailerCallbacksAreLIFO() { @Test void callCompleteCallbacksAreLIFO() { List log = new ArrayList<>(); - ConnectInterceptorPipeline pipeline = new ConnectInterceptorPipeline(List.of( - exchange -> ConnectInterceptor.continueWith(new ConnectCallObserver() { + ConnectServerInterceptorPipeline pipeline = new ConnectServerInterceptorPipeline(List.of( + exchange -> ConnectServerInterceptor.continueWith(new ConnectServerCallObserver() { @Override public void onCallComplete(@Nullable ConnectError e) { log.add("first"); } }), - exchange -> ConnectInterceptor.continueWith(new ConnectCallObserver() { + exchange -> ConnectServerInterceptor.continueWith(new ConnectServerCallObserver() { @Override public void onCallComplete(@Nullable ConnectError e) { log.add("second"); } }) )); - ConnectCallObserver composite = pipeline.interceptCall(EXCHANGE).observer(); + ConnectServerCallObserver composite = pipeline.interceptCall(EXCHANGE).observer(); composite.onCallComplete(null); assertThat(log).containsExactly("second", "first"); @@ -152,28 +152,28 @@ void rejectionStopsIterationAndReturnsCompositeOfPriorObservers() { List callOrder = new ArrayList<>(); List completedLog = new ArrayList<>(); - ConnectInterceptor first = exchange -> { + ConnectServerInterceptor first = exchange -> { callOrder.add("first"); - return ConnectInterceptor.continueWith(new ConnectCallObserver() { + return ConnectServerInterceptor.continueWith(new ConnectServerCallObserver() { @Override public void onCallComplete(@Nullable ConnectError e) { completedLog.add("first"); } }); }; ConnectError rejectError = ConnectError.permissionDenied("no"); - ConnectInterceptor rejecting = exchange -> { + ConnectServerInterceptor rejecting = exchange -> { callOrder.add("rejecting"); - return ConnectInterceptor.reject(rejectError); + return ConnectServerInterceptor.reject(rejectError); }; - ConnectInterceptor notReached = exchange -> { + ConnectServerInterceptor notReached = exchange -> { callOrder.add("notReached"); - return ConnectInterceptor.continueCall(); + return ConnectServerInterceptor.continueCall(); }; - ConnectInterceptorPipeline pipeline = new ConnectInterceptorPipeline(List.of(first, rejecting, notReached)); + ConnectServerInterceptorPipeline pipeline = new ConnectServerInterceptorPipeline(List.of(first, rejecting, notReached)); - ConnectInterceptor.Decision decision = pipeline.interceptCall(EXCHANGE); + ConnectServerInterceptor.Decision decision = pipeline.interceptCall(EXCHANGE); - assertThat(decision).isInstanceOf(ConnectInterceptor.Decision.Reject.class); - assertThat(((ConnectInterceptor.Decision.Reject) decision).error()).isSameAs(rejectError); + assertThat(decision).isInstanceOf(ConnectServerInterceptor.Decision.Reject.class); + assertThat(((ConnectServerInterceptor.Decision.Reject) decision).error()).isSameAs(rejectError); assertThat(callOrder).containsExactly("first", "rejecting"); // The composite observer from prior Continue decisions can still be invoked for terminal callbacks @@ -183,13 +183,13 @@ void rejectionStopsIterationAndReturnsCompositeOfPriorObservers() { @Test void rejectionWithNoPriorContinueObserversReturnsNoOp() { - ConnectInterceptorPipeline pipeline = new ConnectInterceptorPipeline(List.of( - exchange -> ConnectInterceptor.reject(ConnectError.unauthenticated("go away")) + ConnectServerInterceptorPipeline pipeline = new ConnectServerInterceptorPipeline(List.of( + exchange -> ConnectServerInterceptor.reject(ConnectError.unauthenticated("go away")) )); - ConnectInterceptor.Decision decision = pipeline.interceptCall(EXCHANGE); + ConnectServerInterceptor.Decision decision = pipeline.interceptCall(EXCHANGE); - assertThat(decision).isInstanceOf(ConnectInterceptor.Decision.Reject.class); - assertThat(decision.observer()).isSameAs(ConnectCallObserver.NOOP); + assertThat(decision).isInstanceOf(ConnectServerInterceptor.Decision.Reject.class); + assertThat(decision.observer()).isSameAs(ConnectServerCallObserver.NOOP); } } diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ConnectStringBuilderJsonSerializerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectStringBuilderJsonSerializerTest.java similarity index 99% rename from src/test/java/io/suboptimal/connectjava/protocol/ConnectStringBuilderJsonSerializerTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/ConnectStringBuilderJsonSerializerTest.java index 980f3d9..1fff617 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/ConnectStringBuilderJsonSerializerTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/ConnectStringBuilderJsonSerializerTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.suboptimal.connectjava.api.ConnectError; import io.suboptimal.connectjava.api.ConnectErrorCode; diff --git a/src/test/java/io/suboptimal/connectjava/protocol/HttpAssertions.java b/src/test/java/io/suboptimal/connectjava/protocol/server/HttpAssertions.java similarity index 97% rename from src/test/java/io/suboptimal/connectjava/protocol/HttpAssertions.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/HttpAssertions.java index b4b3715..8eb5893 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/HttpAssertions.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/HttpAssertions.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ResponseHeadersBuilderTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/ResponseHeadersBuilderTest.java similarity index 94% rename from src/test/java/io/suboptimal/connectjava/protocol/ResponseHeadersBuilderTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/ResponseHeadersBuilderTest.java index d1c33ad..80e82d5 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/ResponseHeadersBuilderTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/ResponseHeadersBuilderTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpHeaders; diff --git a/src/test/java/io/suboptimal/connectjava/protocol/ResponseTrailersBuilderTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/ResponseTrailersBuilderTest.java similarity index 96% rename from src/test/java/io/suboptimal/connectjava/protocol/ResponseTrailersBuilderTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/ResponseTrailersBuilderTest.java index 4d00360..8706c2f 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/ResponseTrailersBuilderTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/ResponseTrailersBuilderTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpHeaders; diff --git a/src/test/java/io/suboptimal/connectjava/protocol/RoutingHandlerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/RoutingServerHandlerTest.java similarity index 86% rename from src/test/java/io/suboptimal/connectjava/protocol/RoutingHandlerTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/RoutingServerHandlerTest.java index f73b4fe..7ae7b1c 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/RoutingHandlerTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/RoutingServerHandlerTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -29,9 +29,9 @@ import static org.assertj.core.api.Assertions.assertThat; -class RoutingHandlerTest { - private static final ConnectProtocolParameters PARAMETERS = - new ConnectProtocolParameters(1024, 64); +class RoutingServerHandlerTest { + private static final ConnectServerProtocolParameters PARAMETERS = + new ConnectServerProtocolParameters(1024, 64); private EmbeddedChannel channel; @@ -50,7 +50,7 @@ void setUpHandler() { } void setUpHandler(ConnectTransport transport) { - channel.pipeline().addLast(ConnectPipeline.ROUTING_HANDLER, newRoutingHandler(transport)); + channel.pipeline().addLast(ConnectServerPipeline.ROUTING_HANDLER, newRoutingHandler(transport)); } @Test @@ -231,8 +231,8 @@ void unaryContentTypeWithParametersStillInstallsUnaryHandler() { // via Netty's parsed request MIME type. channel.writeInbound(fullRequest("application/proto; charset=utf-8")); - assertThat(channel.pipeline().get(ConnectPipeline.UNARY_POST_REQUEST_HANDLER)) - .isInstanceOf(UnaryPostRequestHandler.class); + assertThat(channel.pipeline().get(ConnectServerPipeline.UNARY_POST_REQUEST_HANDLER)) + .isInstanceOf(UnaryPostRequestServerHandler.class); } @Test @@ -242,11 +242,11 @@ void unsupportedApplicationContentTypeIsHandledByUnaryHandler() { channel.writeInbound(request); - assertThat(channel.pipeline().get(ConnectPipeline.ROUTING_HANDLER)).isNull(); - assertThat(channel.pipeline().get(ConnectPipeline.AGGREGATOR_HANDLER)) + assertThat(channel.pipeline().get(ConnectServerPipeline.ROUTING_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectServerPipeline.AGGREGATOR_HANDLER)) .isInstanceOf(HttpObjectAggregator.class); - assertThat(channel.pipeline().get(ConnectPipeline.UNARY_POST_REQUEST_HANDLER)) - .isInstanceOf(UnaryPostRequestHandler.class); + assertThat(channel.pipeline().get(ConnectServerPipeline.UNARY_POST_REQUEST_HANDLER)) + .isInstanceOf(UnaryPostRequestServerHandler.class); channel.writeInbound(new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER)); @@ -265,9 +265,9 @@ void streamingContentTypeInstallsNamedStreamingHandler() { channel.writeInbound(request("application/connect+proto", methodUri("ServerStream"))); - assertThat(channel.pipeline().get(ConnectPipeline.ROUTING_HANDLER)).isNull(); - assertThat(channel.pipeline().get(ConnectPipeline.STREAMING_HANDLER)) - .isInstanceOf(StreamingHandler.class); + assertThat(channel.pipeline().get(ConnectServerPipeline.ROUTING_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectServerPipeline.STREAMING_HANDLER)) + .isInstanceOf(StreamingServerHandler.class); Object inbound = channel.readInbound(); assertThat(inbound) @@ -283,13 +283,13 @@ void unaryContentTypeInstallsNamedAggregatorAndUnaryHandlerInOrder() { channel.writeInbound(fullRequest("application/proto")); - assertThat(channel.pipeline().get(ConnectPipeline.ROUTING_HANDLER)).isNull(); - assertThat(channel.pipeline().get(ConnectPipeline.AGGREGATOR_HANDLER)) + assertThat(channel.pipeline().get(ConnectServerPipeline.ROUTING_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectServerPipeline.AGGREGATOR_HANDLER)) .isInstanceOf(HttpObjectAggregator.class); - assertThat(channel.pipeline().get(ConnectPipeline.UNARY_POST_REQUEST_HANDLER)) - .isInstanceOf(UnaryPostRequestHandler.class); + assertThat(channel.pipeline().get(ConnectServerPipeline.UNARY_POST_REQUEST_HANDLER)) + .isInstanceOf(UnaryPostRequestServerHandler.class); assertThat(channel.pipeline().names()) - .containsSequence(ConnectPipeline.AGGREGATOR_HANDLER, ConnectPipeline.UNARY_POST_REQUEST_HANDLER); + .containsSequence(ConnectServerPipeline.AGGREGATOR_HANDLER, ConnectServerPipeline.UNARY_POST_REQUEST_HANDLER); } @Test @@ -300,11 +300,11 @@ void getInstallsNamedAggregatorAndUnaryGetHandler() { channel.writeInbound(request); - assertThat(channel.pipeline().get(ConnectPipeline.ROUTING_HANDLER)).isNull(); - assertThat(channel.pipeline().get(ConnectPipeline.AGGREGATOR_HANDLER)) + assertThat(channel.pipeline().get(ConnectServerPipeline.ROUTING_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectServerPipeline.AGGREGATOR_HANDLER)) .isInstanceOf(HttpObjectAggregator.class); - assertThat(channel.pipeline().get(ConnectPipeline.UNARY_GET_REQUEST_HANDLER)) - .isInstanceOf(UnaryGetRequestHandler.class); + assertThat(channel.pipeline().get(ConnectServerPipeline.UNARY_GET_REQUEST_HANDLER)) + .isInstanceOf(UnaryGetRequestServerHandler.class); FullHttpResponse response = channel.readOutbound(); try { @@ -338,9 +338,9 @@ void bidiOverHttp2InstallsStreamingHandler() { channel.writeInbound(request); - assertThat(channel.pipeline().get(ConnectPipeline.ROUTING_HANDLER)).isNull(); - assertThat(channel.pipeline().get(ConnectPipeline.STREAMING_HANDLER)) - .isInstanceOf(StreamingHandler.class); + assertThat(channel.pipeline().get(ConnectServerPipeline.ROUTING_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectServerPipeline.STREAMING_HANDLER)) + .isInstanceOf(StreamingServerHandler.class); Object outbound = channel.readOutbound(); assertThat(outbound).isNull(); @@ -380,7 +380,7 @@ void unaryMethodWithStreamingContentTypeReturns415(String contentType) { channel.writeInbound(request); assertThat(request.refCnt()).isZero(); - assertThat(channel.pipeline().get(ConnectPipeline.STREAMING_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectServerPipeline.STREAMING_HANDLER)).isNull(); FullHttpResponse response = channel.readOutbound(); try { HttpAssertions.assertThat(response).unsupportedMediaTypeError(); @@ -445,15 +445,15 @@ void unsupportedMethodOnUnaryProcedureReturns405WithAllowPostGet() { } private void assertNoConnectHandlerInstalled() { - assertThat(channel.pipeline().get(ConnectPipeline.AGGREGATOR_HANDLER)).isNull(); - assertThat(channel.pipeline().get(ConnectPipeline.UNARY_GET_REQUEST_HANDLER)).isNull(); - assertThat(channel.pipeline().get(ConnectPipeline.UNARY_POST_REQUEST_HANDLER)).isNull(); - assertThat(channel.pipeline().get(ConnectPipeline.UNARY_RESPONSE_HANDLER)).isNull(); - assertThat(channel.pipeline().get(ConnectPipeline.STREAMING_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectServerPipeline.AGGREGATOR_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectServerPipeline.UNARY_GET_REQUEST_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectServerPipeline.UNARY_POST_REQUEST_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectServerPipeline.UNARY_RESPONSE_HANDLER)).isNull(); + assertThat(channel.pipeline().get(ConnectServerPipeline.STREAMING_HANDLER)).isNull(); } - private static RoutingHandler newRoutingHandler(ConnectTransport transport) { - return new RoutingHandler(transport, ConnectProtocolConfig.builder( + private static RoutingServerHandler newRoutingHandler(ConnectTransport transport) { + return new RoutingServerHandler(transport, ConnectServerProtocolConfig.builder( Map.of("pkg.Service", SERVICE_DEFINITION), ChannelInboundHandlerAdapter::new, PARAMETERS, @@ -487,7 +487,7 @@ private static void setConnectHeaders(DefaultHttpRequest request, String content request.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType); // Required by ConnectStreamingHandler / ConnectUnaryPostHandler — without it the // downstream handler would short-circuit with a Connect validation error - // before producing the inbound ConnectCallExchange these tests assert on. + // before producing the inbound ConnectServerCallExchange these tests assert on. request.headers().set("connect-protocol-version", "1"); } diff --git a/src/test/java/io/suboptimal/connectjava/protocol/StreamingHandlerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/StreamingServerHandlerTest.java similarity index 93% rename from src/test/java/io/suboptimal/connectjava/protocol/StreamingHandlerTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/StreamingServerHandlerTest.java index 783a67f..6eac23d 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/StreamingHandlerTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/StreamingServerHandlerTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; @@ -27,6 +27,8 @@ import io.suboptimal.connectjava.model.ConnectMethodDefinition; import io.suboptimal.connectjava.model.ConnectMethodType; import io.suboptimal.connectjava.model.ConnectServiceDefinition; +import io.suboptimal.connectjava.protocol.ConnectEnvelope; +import io.suboptimal.connectjava.protocol.client.ConnectCallTerminatedException; import io.suboptimal.connectjava.testfixtures.StreamingRequest; import io.suboptimal.connectjava.testfixtures.StreamingResponse; import org.jspecify.annotations.Nullable; @@ -44,7 +46,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class StreamingHandlerTest { +class StreamingServerHandlerTest { private static final String SERVICE_NAME = "connectjava.test.v1.StreamingFixtureService"; private static final ConnectMethodDefinition SERVER_STREAMING_METHOD = new ConnectMethodDefinition( @@ -173,7 +175,7 @@ void serializesUnsupportedMethodAfterAcceptedRpcAsEndStreamError() { @Test void interceptorsAttachStreamingHeadersAndSuccessMetadata() { - ConnectCallObserver observer = new ConnectCallObserver() { + ConnectServerCallObserver observer = new ConnectServerCallObserver() { @Override public void onResponseHeaders(ConnectResponseHeadersBuilder headers) { headers.set("x-stream", "started"); @@ -201,7 +203,7 @@ public void onResponseTrailers(ConnectResponseTrailersBuilder trailers, @Nullabl @Test void interceptorsAttachStreamingMetadataOnError() { - ConnectCallObserver observer = new ConnectCallObserver() { + ConnectServerCallObserver observer = new ConnectServerCallObserver() { @Override public void onResponseTrailers(ConnectResponseTrailersBuilder trailers, @Nullable ConnectError error) { assertThat(error.code()).isEqualTo(ConnectErrorCode.INVALID_ARGUMENT); @@ -222,9 +224,9 @@ public void onResponseTrailers(ConnectResponseTrailersBuilder trailers, @Nullabl @Test void interceptorCanRejectStreamingCallBeforeRpcRequest() { - ConnectInterceptor rejecting = ctx -> ConnectInterceptor.reject(ConnectError.invalidArgument("nope")); + ConnectServerInterceptor rejecting = ctx -> ConnectServerInterceptor.reject(ConnectError.invalidArgument("nope")); setUpHandler(SERVER_STREAMING_METHOD, - new ConnectInterceptorPipeline(List.of(rejecting))); + new ConnectServerInterceptorPipeline(List.of(rejecting))); channel.writeInbound(request("/ServerStreaming")); @@ -314,7 +316,7 @@ void ignoresLastHttpContentAfterContentProcessingError() throws IOException { void cancelPathNotifiesObserverOnceWhenChannelClosedMidStream() { AtomicInteger callCount = new AtomicInteger(); List errors = new ArrayList<>(); - ConnectCallObserver observer = new ConnectCallObserver() { + ConnectServerCallObserver observer = new ConnectServerCallObserver() { @Override public void onCallComplete(@Nullable ConnectError error) { callCount.incrementAndGet(); @@ -337,7 +339,7 @@ public void onCallComplete(@Nullable ConnectError error) { @Test void cancelPathDoesNotFireAfterNormalCompletion() { AtomicInteger callCount = new AtomicInteger(); - ConnectCallObserver observer = new ConnectCallObserver() { + ConnectServerCallObserver observer = new ConnectServerCallObserver() { @Override public void onCallComplete(@Nullable ConnectError error) { callCount.incrementAndGet(); @@ -363,13 +365,13 @@ public void onCallComplete(@Nullable ConnectError error) { @Test void versionCheckFailureBeforeInterceptorDoesNotInvokeInterceptor() { boolean[] interceptorCalled = {false}; - ConnectInterceptor interceptor = exchange -> { + ConnectServerInterceptor interceptor = exchange -> { interceptorCalled[0] = true; - return ConnectInterceptor.continueCall(); + return ConnectServerInterceptor.continueCall(); }; setUpHandler(SERVER_STREAMING_METHOD, - new ConnectInterceptorPipeline(List.of(interceptor))); + new ConnectServerInterceptorPipeline(List.of(interceptor))); DefaultHttpRequest badRequest = request("/ServerStreaming"); badRequest.headers().set("connect-protocol-version", "999"); @@ -494,22 +496,22 @@ private void assertCallExchangeFired() { } void setUpHandler(ConnectMethodDefinition method) { - setUpHandler(method, ConnectInterceptorPipeline.EMPTY); + setUpHandler(method, ConnectServerInterceptorPipeline.EMPTY); } - void setUpHandler(ConnectMethodDefinition method, List observers) { - List interceptors = observers.stream() - .map(observer -> ctx -> ConnectInterceptor.continueWith(observer)) + void setUpHandler(ConnectMethodDefinition method, List observers) { + List interceptors = observers.stream() + .map(observer -> ctx -> ConnectServerInterceptor.continueWith(observer)) .toList(); - setUpHandler(method, new ConnectInterceptorPipeline(interceptors)); + setUpHandler(method, new ConnectServerInterceptorPipeline(interceptors)); } - void setUpHandler(ConnectMethodDefinition method, ConnectInterceptorPipeline interceptorPipeline) { + void setUpHandler(ConnectMethodDefinition method, ConnectServerInterceptorPipeline interceptorPipeline) { exchange = new ConnectCallExchange(SERVICE_DEF, method, new ConnectRequestMeta(Map.of()), new ResponseHeadersBuilder(), new ResponseTrailersBuilder()); - channel.pipeline().addLast(new StreamingHandler( + channel.pipeline().addLast(new StreamingServerHandler( exchange, 1024, ConnectProtobufCodecs.defaults(), ConnectCompressionRegistry.standard(), ConnectStringBuilderJsonSerializer.INSTANCE, interceptorPipeline)); diff --git a/src/test/java/io/suboptimal/connectjava/protocol/UnaryGetRequestHandlerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/UnaryGetRequestServerHandlerTest.java similarity index 97% rename from src/test/java/io/suboptimal/connectjava/protocol/UnaryGetRequestHandlerTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/UnaryGetRequestServerHandlerTest.java index 1c44528..4de37e7 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/UnaryGetRequestHandlerTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/UnaryGetRequestServerHandlerTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -45,7 +45,7 @@ import static org.assertj.core.api.Assertions.assertThat; -class UnaryGetRequestHandlerTest { +class UnaryGetRequestServerHandlerTest { private static final String SERVICE_NAME = "connectjava.test.v1.UnaryGetFixtureService"; private static final ConnectMethodDefinition UNARY_METHOD = new ConnectMethodDefinition( @@ -83,18 +83,18 @@ void setUpHandler(String methodName) { } void setUpHandler(String methodName, int maxRequestBytes) { - setUpHandler(methodName, maxRequestBytes, ConnectInterceptorPipeline.EMPTY); + setUpHandler(methodName, maxRequestBytes, ConnectServerInterceptorPipeline.EMPTY); } void setUpHandler(String methodName, int maxRequestBytes, - ConnectInterceptorPipeline interceptorPipeline) + ConnectServerInterceptorPipeline interceptorPipeline) { ConnectMethodDefinition method = SERVICE_DEF.methods().get(methodName); exchange = new ConnectCallExchange(SERVICE_DEF, method, new ConnectRequestMeta(Map.of()), new ResponseHeadersBuilder(), new ResponseTrailersBuilder()); - channel.pipeline().addLast(ConnectPipeline.UNARY_GET_REQUEST_HANDLER, - new UnaryGetRequestHandler( + channel.pipeline().addLast(ConnectServerPipeline.UNARY_GET_REQUEST_HANDLER, + new UnaryGetRequestServerHandler( exchange, ConnectProtobufCodecs.defaults(), ConnectCompressionRegistry.standard(), @@ -875,7 +875,7 @@ class InterceptorCallbacks { @Test void observerCallbacksAreInterleavedCorrectlyWithConnectMessages() throws IOException { List events = new ArrayList<>(); - ConnectCallObserver observer = new ConnectCallObserver() { + ConnectServerCallObserver observer = new ConnectServerCallObserver() { @Override public void onRequestPayload(Object payload) { events.add("onRequestPayload"); @@ -888,12 +888,12 @@ public void onRequestFinished() { }; setUpHandler("Unary", DEFAULT_MAX_REQUEST_BYTES, - new ConnectInterceptorPipeline(List.of(ctx -> ConnectInterceptor.continueWith(observer)))); + new ConnectServerInterceptorPipeline(List.of(ctx -> ConnectServerInterceptor.continueWith(observer)))); channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { switch (msg) { - case ConnectCallExchange ignore -> events.add("ConnectCallExchange"); + case ConnectCallExchange ignore -> events.add("ConnectServerCallExchange"); case ConnectPayload ignore -> events.add("ConnectPayload"); case ConnectEndOfStream ignore -> events.add("ConnectEndOfStream"); case null, default -> { @@ -908,7 +908,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { channel.writeInbound(getRequest(methodUri)); assertThat(events).containsExactly( - "ConnectCallExchange", + "ConnectServerCallExchange", "onRequestPayload", "ConnectPayload", "onRequestFinished", @@ -917,8 +917,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { } private void assertResponseHandlerInstalled() { - assertThat(channel.pipeline().get(ConnectPipeline.UNARY_RESPONSE_HANDLER)) - .isInstanceOf(UnaryResponseProcessingHandler.class); + assertThat(channel.pipeline().get(ConnectServerPipeline.UNARY_RESPONSE_HANDLER)) + .isInstanceOf(UnaryResponseProcessingServerHandler.class); } private static DefaultFullHttpRequest getRequest(String methodUri) { diff --git a/src/test/java/io/suboptimal/connectjava/protocol/UnaryPostRequestHandlerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/UnaryPostRequestServerHandlerTest.java similarity index 93% rename from src/test/java/io/suboptimal/connectjava/protocol/UnaryPostRequestHandlerTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/UnaryPostRequestServerHandlerTest.java index 222a0cb..aa60c38 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/UnaryPostRequestHandlerTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/UnaryPostRequestServerHandlerTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -44,7 +44,7 @@ import static org.assertj.core.api.Assertions.assertThat; -class UnaryPostRequestHandlerTest { +class UnaryPostRequestServerHandlerTest { private static final String SERVICE_NAME = "connectjava.test.v1.UnaryPostFixtureService"; private static final ConnectMethodDefinition UNARY_METHOD = new ConnectMethodDefinition( @@ -70,16 +70,16 @@ void tearDownChannel() { } void setUpHandler() { - setUpHandler(ConnectInterceptorPipeline.EMPTY); + setUpHandler(ConnectServerInterceptorPipeline.EMPTY); } - void setUpHandler(ConnectInterceptorPipeline interceptorPipeline) { + void setUpHandler(ConnectServerInterceptorPipeline interceptorPipeline) { exchange = new ConnectCallExchange(SERVICE_DEF, UNARY_METHOD, new ConnectRequestMeta(Map.of()), new ResponseHeadersBuilder(), new ResponseTrailersBuilder()); - channel.pipeline().addLast(ConnectPipeline.UNARY_POST_REQUEST_HANDLER, - new UnaryPostRequestHandler( + channel.pipeline().addLast(ConnectServerPipeline.UNARY_POST_REQUEST_HANDLER, + new UnaryPostRequestServerHandler( exchange, ConnectProtobufCodecs.defaults(), ConnectCompressionRegistry.standard(), @@ -193,7 +193,7 @@ void rejectsUnsupportedContentEncodingEvenForZeroLengthRequest() { } @ParameterizedTest - @MethodSource("io.suboptimal.connectjava.protocol.UnaryPostRequestHandlerTest#malformedBodies") + @MethodSource("io.suboptimal.connectjava.protocol.server.UnaryPostRequestServerHandlerTest#malformedBodies") void rejectsMalformedRequestBodyAsInvalidArgument(String contentType, byte[] body) { DefaultFullHttpRequest request = unaryPostRequest(contentType, body); @@ -218,7 +218,7 @@ class InterceptorCallbacks { @Test void observerCallbacksAreInterleavedCorrectlyWithConnectMessages() throws IOException { List events = new ArrayList<>(); - ConnectCallObserver observer = new ConnectCallObserver() { + ConnectServerCallObserver observer = new ConnectServerCallObserver() { @Override public void onRequestPayload(Object payload) { events.add("onRequestPayload"); @@ -231,12 +231,12 @@ public void onRequestFinished() { }; setUpHandler( - new ConnectInterceptorPipeline(List.of(ctx -> ConnectInterceptor.continueWith(observer)))); + new ConnectServerInterceptorPipeline(List.of(ctx -> ConnectServerInterceptor.continueWith(observer)))); channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { switch (msg) { - case ConnectCallExchange ignore -> events.add("ConnectCallExchange"); + case ConnectCallExchange ignore -> events.add("ConnectServerCallExchange"); case ConnectPayload ignore -> events.add("ConnectPayload"); case ConnectEndOfStream ignore -> events.add("ConnectEndOfStream"); case null, default -> { @@ -250,7 +250,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { channel.writeInbound(request); assertThat(events).containsExactly( - "ConnectCallExchange", + "ConnectServerCallExchange", "onRequestPayload", "ConnectPayload", "onRequestFinished", @@ -259,8 +259,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { } private void assertResponseHandlerInstalled() { - assertThat(channel.pipeline().get(ConnectPipeline.UNARY_RESPONSE_HANDLER)) - .isInstanceOf(UnaryResponseProcessingHandler.class); + assertThat(channel.pipeline().get(ConnectServerPipeline.UNARY_RESPONSE_HANDLER)) + .isInstanceOf(UnaryResponseProcessingServerHandler.class); } private void assertCallExchangeFired() { diff --git a/src/test/java/io/suboptimal/connectjava/protocol/UnaryResponseProcessingHandlerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/server/UnaryResponseProcessingServerHandlerTest.java similarity index 95% rename from src/test/java/io/suboptimal/connectjava/protocol/UnaryResponseProcessingHandlerTest.java rename to src/test/java/io/suboptimal/connectjava/protocol/server/UnaryResponseProcessingServerHandlerTest.java index f13a637..8c98916 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/UnaryResponseProcessingHandlerTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/server/UnaryResponseProcessingServerHandlerTest.java @@ -1,4 +1,4 @@ -package io.suboptimal.connectjava.protocol; +package io.suboptimal.connectjava.protocol.server; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.FullHttpResponse; @@ -18,6 +18,7 @@ import io.suboptimal.connectjava.compression.ConnectCompression; import io.suboptimal.connectjava.compression.ConnectGzipCompression; import io.suboptimal.connectjava.compression.ConnectIdentityCompression; +import io.suboptimal.connectjava.protocol.client.ConnectCallTerminatedException; import io.suboptimal.connectjava.testfixtures.UnaryPostRequest; import io.suboptimal.connectjava.testfixtures.UnaryPostResponse; import io.suboptimal.connectjava.model.ConnectMethodDefinition; @@ -39,7 +40,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class UnaryResponseProcessingHandlerTest { +class UnaryResponseProcessingServerHandlerTest { private static final String SERVICE_NAME = "connectjava.test.v1.UnaryResponseFixtureService"; private static final ConnectMethodDefinition UNARY_METHOD = new ConnectMethodDefinition( @@ -64,20 +65,20 @@ void tearDownChannel() { } void setUpHandler(ConnectCodec codec, ConnectCompression compression, boolean varyAcceptEncoding) { - setUpHandler(codec, compression, varyAcceptEncoding, ConnectCallObserver.NOOP); + setUpHandler(codec, compression, varyAcceptEncoding, ConnectServerCallObserver.NOOP); } void setUpHandler(ConnectCodec codec, ConnectCompression compression, boolean varyAcceptEncoding, - ConnectCallObserver observer) + ConnectServerCallObserver observer) { - channel.pipeline().addLast(ConnectPipeline.UNARY_RESPONSE_HANDLER, - new UnaryResponseProcessingHandler( + channel.pipeline().addLast(ConnectServerPipeline.UNARY_RESPONSE_HANDLER, + new UnaryResponseProcessingServerHandler( createExchange(), codec, compression, varyAcceptEncoding, observer, ConnectStringBuilderJsonSerializer.INSTANCE)); } void setUpHandler(ConnectCodec codec, ConnectCompression compression, boolean varyAcceptEncoding, - List observers) + List observers) { setUpHandler(codec, compression, varyAcceptEncoding, composite(observers)); } @@ -88,9 +89,9 @@ private static ConnectCallExchange createExchange() { new ResponseHeadersBuilder(), new ResponseTrailersBuilder()); } - private static ConnectCallObserver composite(List observers) { - return new ConnectInterceptorPipeline(observers.stream() - .map(observer -> ctx -> ConnectInterceptor.continueWith(observer)) + private static ConnectServerCallObserver composite(List observers) { + return new ConnectServerInterceptorPipeline(observers.stream() + .map(observer -> ctx -> ConnectServerInterceptor.continueWith(observer)) .toList()) .interceptCall(createExchange()) .observer(); @@ -240,7 +241,7 @@ void errorResponseOmitsVaryAcceptEncoding() { @Test void interceptorsMutateUnaryHeadersAndTrailersInExpectedOrder() { List events = new ArrayList<>(); - ConnectCallObserver first = new ConnectCallObserver() { + ConnectServerCallObserver first = new ConnectServerCallObserver() { @Override public void onResponseHeaders(ConnectResponseHeadersBuilder headers) { events.add("first:headers"); @@ -263,7 +264,7 @@ public void onCallComplete(ConnectError error) { events.add("first:complete"); } }; - ConnectCallObserver second = new ConnectCallObserver() { + ConnectServerCallObserver second = new ConnectServerCallObserver() { @Override public void onResponseHeaders(ConnectResponseHeadersBuilder headers) { events.add("second:headers"); @@ -313,7 +314,7 @@ public void onCallComplete(ConnectError error) { @Test void interceptorsCanAttachUnaryHeadersAndTrailersOnRpcError() { - ConnectCallObserver observer = new ConnectCallObserver() { + ConnectServerCallObserver observer = new ConnectServerCallObserver() { @Override public void onResponseHeaders(ConnectResponseHeadersBuilder headers) { headers.set("error-header", "from-interceptor"); @@ -472,7 +473,7 @@ void endOfStreamAfterSecondPayloadErrorProducesNoResponse() { @Test void cancelPathNotifiesObserverOnceWhenChannelClosedBeforeResponse() { List completions = new ArrayList<>(); - ConnectCallObserver observer = new ConnectCallObserver() { + ConnectServerCallObserver observer = new ConnectServerCallObserver() { @Override public void onCallComplete(ConnectError error) { completions.add(error); @@ -490,7 +491,7 @@ public void onCallComplete(ConnectError error) { @Test void onCallCompleteFiresExactlyOnceAfterSuccessfulResponse() { int[] callCount = {0}; - ConnectCallObserver observer = new ConnectCallObserver() { + ConnectServerCallObserver observer = new ConnectServerCallObserver() { @Override public void onCallComplete(ConnectError error) { callCount[0]++; From 5bcd8b064dfaa36412ece8d1cdda2d1fabf82812 Mon Sep 17 00:00:00 2001 From: Ilya Kazakov Date: Tue, 16 Jun 2026 13:33:33 +0200 Subject: [PATCH 2/2] trailers refactor --- .../connectjava/api/ConnectResponseMeta.java | 8 +------- .../protocol/client/StreamingClientHandler.java | 2 +- .../client/UnaryResponseClientHandler.java | 15 +++++++++------ .../ConnectClientInterceptorPipelineTest.java | 2 +- .../client/UnaryResponseClientHandlerTest.java | 3 +-- 5 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/main/java/io/suboptimal/connectjava/api/ConnectResponseMeta.java b/src/main/java/io/suboptimal/connectjava/api/ConnectResponseMeta.java index ff00f6b..bc778e9 100644 --- a/src/main/java/io/suboptimal/connectjava/api/ConnectResponseMeta.java +++ b/src/main/java/io/suboptimal/connectjava/api/ConnectResponseMeta.java @@ -10,19 +10,13 @@ * * @param statusCode HTTP status code * @param headers leading metadata (HTTP response headers without {@code trailer-} prefix) - * @param trailers trailing metadata; for unary responses these are extracted from HTTP headers - * with the {@code trailer-} prefix stripped; for streaming responses they - * come from the end-stream envelope's {@code metadata} field via - * {@link ConnectEndOfStream#trailers()} */ public record ConnectResponseMeta( int statusCode, - Map> headers, - Map> trailers + Map> headers ) { public ConnectResponseMeta { headers = copyLower(headers); - trailers = copyLower(trailers); } private static Map> copyLower(Map> source) { diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandler.java b/src/main/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandler.java index 12776a2..361cba4 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandler.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/StreamingClientHandler.java @@ -239,7 +239,7 @@ private void handleHttpResponse(ChannelHandlerContext ctx, HttpResponse response } Map> headersMap = ClientHandlerSupport.toHeaderMap(response.headers()); - ConnectResponseMeta responseMeta = new ConnectResponseMeta(statusCode, headersMap, Map.of()); + ConnectResponseMeta responseMeta = new ConnectResponseMeta(statusCode, headersMap); decoder = new ConnectEnvelope.Decoder(ctx.alloc(), config.parameters().maxFrameBytes()); observer.onResponseHeaders(responseMeta); ctx.fireChannelRead(new ConnectClientResponseStart( diff --git a/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandler.java b/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandler.java index 2acab6c..c7826a0 100644 --- a/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandler.java +++ b/src/main/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandler.java @@ -40,12 +40,12 @@ class UnaryResponseClientHandler extends SimpleChannelInboundHandler> all = new LinkedHashMap<>(); all.putAll(ClientHandlerSupport.toHeaderMap(response.headers())); all.putAll(ClientHandlerSupport.toHeaderMap(response.trailingHeaders())); @@ -181,6 +181,9 @@ private static ConnectResponseMeta buildMeta(int statusCode, FullHttpResponse re headers.put(name, entry.getValue()); } } - return new ConnectResponseMeta(statusCode, headers, trailers); + + return new MetaContainer(new ConnectResponseMeta(statusCode, headers), trailers); } + + private record MetaContainer(ConnectResponseMeta connectResponseMeta, Map> trailers) { } } diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipelineTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipelineTest.java index 655daba..fed3067 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipelineTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/ConnectClientInterceptorPipelineTest.java @@ -25,7 +25,7 @@ class ConnectClientInterceptorPipelineTest { private static final ConnectClientCallStart CALL_START = new ConnectClientCallStart( SERVICE, METHOD, Map.of(), false, "proto"); private static final ConnectResponseMeta META = - new ConnectResponseMeta(200, Map.of(), Map.of()); + new ConnectResponseMeta(200, Map.of()); @Test void emptyPipelineContinues() { diff --git a/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandlerTest.java b/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandlerTest.java index a221fdd..8a40372 100644 --- a/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandlerTest.java +++ b/src/test/java/io/suboptimal/connectjava/protocol/client/UnaryResponseClientHandlerTest.java @@ -79,7 +79,7 @@ void deliversExchangePayloadEndOfStreamOnSuccess() { assertThat(((ConnectPayload) payload).data()).isEqualTo(RESPONSE); Object endOfStream = channel.readInbound(); - assertThat(endOfStream).isSameAs(ConnectEndOfStream.INSTANCE); + assertThat(endOfStream).isEqualTo(ConnectEndOfStream.INSTANCE); assertThat(observer.events) .containsExactly("onResponseHeaders", "onResponsePayload", "onCallComplete"); @@ -270,7 +270,6 @@ void deliversTrailersFromTrailerPrefixedHeaders() { ConnectClientResponseStart start = channel.readInbound(); ConnectResponseMeta meta = start.responseMeta(); - assertThat(meta.trailers().get("x-foo")).containsExactly("bar"); assertThat(meta.headers()).containsKey("x-plain"); assertThat(meta.headers()).doesNotContainKey("x-foo");