diff --git a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs index 141af7760c8d376492b1c0c7b9dfe6e8fd46428c..1fa895ba711f4c2bf1a3a54f655ea9cf9b2f0cfe 100644 --- a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs +++ b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs @@ -66,7 +66,7 @@ namespace Grpc.Core.Internal.Tests public void AsyncUnary_CompletionSuccess() { var resultTask = asyncCall.UnaryCallAsync("abc"); - fakeCall.UnaryResponseClientHandler(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()), new byte[] { 1, 2, 3 }); + fakeCall.UnaryResponseClientHandler(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()), new byte[] { 1, 2, 3 }, new Metadata()); Assert.IsTrue(resultTask.IsCompleted); Assert.IsTrue(fakeCall.IsDisposed); Assert.AreEqual(Status.DefaultSuccess, asyncCall.GetStatus()); @@ -76,7 +76,7 @@ namespace Grpc.Core.Internal.Tests public void AsyncUnary_CompletionFailure() { var resultTask = asyncCall.UnaryCallAsync("abc"); - fakeCall.UnaryResponseClientHandler(false, new ClientSideStatus(), null); + fakeCall.UnaryResponseClientHandler(false, new ClientSideStatus(new Status(StatusCode.Internal, ""), null), new byte[] { 1, 2, 3 }, new Metadata()); Assert.IsTrue(resultTask.IsCompleted); Assert.IsTrue(fakeCall.IsDisposed); diff --git a/src/csharp/Grpc.Core.Tests/ResponseHeadersTest.cs b/src/csharp/Grpc.Core.Tests/ResponseHeadersTest.cs index 706006702e5d1c5f2c3e14977fcbe6ec62bf28f3..8ad41af1b81497477d62ebe8eed872c48b92bdfa 100644 --- a/src/csharp/Grpc.Core.Tests/ResponseHeadersTest.cs +++ b/src/csharp/Grpc.Core.Tests/ResponseHeadersTest.cs @@ -73,6 +73,25 @@ namespace Grpc.Core.Tests server.ShutdownAsync().Wait(); } + [Test] + public async Task ResponseHeadersAsync_UnaryCall() + { + helper.UnaryHandler = new UnaryServerMethod<string, string>(async (request, context) => + { + await context.WriteResponseHeadersAsync(headers); + return "PASS"; + }); + + var call = Calls.AsyncUnaryCall(helper.CreateUnaryCall(), ""); + var responseHeaders = await call.ResponseHeadersAsync; + + Assert.AreEqual(headers.Count, responseHeaders.Count); + Assert.AreEqual("ascii-header", responseHeaders[0].Key); + Assert.AreEqual("abcdefg", responseHeaders[0].Value); + + Assert.AreEqual("PASS", await call.ResponseAsync); + } + [Test] public void WriteResponseHeaders_NullNotAllowed() { diff --git a/src/csharp/Grpc.Core/AsyncClientStreamingCall.cs b/src/csharp/Grpc.Core/AsyncClientStreamingCall.cs index fb9b562c77b762e959f84015d9a0c3609bb2c1d0..dbaa3085c54501fb098222298d3c375ce3aa6abe 100644 --- a/src/csharp/Grpc.Core/AsyncClientStreamingCall.cs +++ b/src/csharp/Grpc.Core/AsyncClientStreamingCall.cs @@ -44,14 +44,16 @@ namespace Grpc.Core { readonly IClientStreamWriter<TRequest> requestStream; readonly Task<TResponse> responseAsync; + readonly Task<Metadata> responseHeadersAsync; readonly Func<Status> getStatusFunc; readonly Func<Metadata> getTrailersFunc; readonly Action disposeAction; - public AsyncClientStreamingCall(IClientStreamWriter<TRequest> requestStream, Task<TResponse> responseAsync, Func<Status> getStatusFunc, Func<Metadata> getTrailersFunc, Action disposeAction) + public AsyncClientStreamingCall(IClientStreamWriter<TRequest> requestStream, Task<TResponse> responseAsync, Task<Metadata> responseHeadersAsync, Func<Status> getStatusFunc, Func<Metadata> getTrailersFunc, Action disposeAction) { this.requestStream = requestStream; this.responseAsync = responseAsync; + this.responseHeadersAsync = responseHeadersAsync; this.getStatusFunc = getStatusFunc; this.getTrailersFunc = getTrailersFunc; this.disposeAction = disposeAction; @@ -68,6 +70,17 @@ namespace Grpc.Core } } + /// <summary> + /// Asynchronous access to response headers. + /// </summary> + public Task<Metadata> ResponseHeadersAsync + { + get + { + return this.responseHeadersAsync; + } + } + /// <summary> /// Async stream to send streaming requests. /// </summary> diff --git a/src/csharp/Grpc.Core/AsyncUnaryCall.cs b/src/csharp/Grpc.Core/AsyncUnaryCall.cs index 224e34391603fd68a23b2ef1a24dfd593f2d73bb..154a17a33ef480539a54d5f147f545f6189547a6 100644 --- a/src/csharp/Grpc.Core/AsyncUnaryCall.cs +++ b/src/csharp/Grpc.Core/AsyncUnaryCall.cs @@ -43,13 +43,15 @@ namespace Grpc.Core public sealed class AsyncUnaryCall<TResponse> : IDisposable { readonly Task<TResponse> responseAsync; + readonly Task<Metadata> responseHeadersAsync; readonly Func<Status> getStatusFunc; readonly Func<Metadata> getTrailersFunc; readonly Action disposeAction; - public AsyncUnaryCall(Task<TResponse> responseAsync, Func<Status> getStatusFunc, Func<Metadata> getTrailersFunc, Action disposeAction) + public AsyncUnaryCall(Task<TResponse> responseAsync, Task<Metadata> responseHeadersAsync, Func<Status> getStatusFunc, Func<Metadata> getTrailersFunc, Action disposeAction) { this.responseAsync = responseAsync; + this.responseHeadersAsync = responseHeadersAsync; this.getStatusFunc = getStatusFunc; this.getTrailersFunc = getTrailersFunc; this.disposeAction = disposeAction; @@ -66,6 +68,17 @@ namespace Grpc.Core } } + /// <summary> + /// Asynchronous access to response headers. + /// </summary> + public Task<Metadata> ResponseHeadersAsync + { + get + { + return this.responseHeadersAsync; + } + } + /// <summary> /// Allows awaiting this object directly. /// </summary> diff --git a/src/csharp/Grpc.Core/Calls.cs b/src/csharp/Grpc.Core/Calls.cs index 7067456638a59f46a73f3aac054c4688acf27a54..ada3616aa4aff7b555fc4975120d86189286ff6b 100644 --- a/src/csharp/Grpc.Core/Calls.cs +++ b/src/csharp/Grpc.Core/Calls.cs @@ -74,7 +74,7 @@ namespace Grpc.Core { var asyncCall = new AsyncCall<TRequest, TResponse>(call); var asyncResult = asyncCall.UnaryCallAsync(req); - return new AsyncUnaryCall<TResponse>(asyncResult, asyncCall.GetStatus, asyncCall.GetTrailers, asyncCall.Cancel); + return new AsyncUnaryCall<TResponse>(asyncResult, asyncCall.ResponseHeadersAsync, asyncCall.GetStatus, asyncCall.GetTrailers, asyncCall.Cancel); } /// <summary> @@ -110,7 +110,7 @@ namespace Grpc.Core var asyncCall = new AsyncCall<TRequest, TResponse>(call); var resultTask = asyncCall.ClientStreamingCallAsync(); var requestStream = new ClientRequestStream<TRequest, TResponse>(asyncCall); - return new AsyncClientStreamingCall<TRequest, TResponse>(requestStream, resultTask, asyncCall.GetStatus, asyncCall.GetTrailers, asyncCall.Cancel); + return new AsyncClientStreamingCall<TRequest, TResponse>(requestStream, resultTask, asyncCall.ResponseHeadersAsync, asyncCall.GetStatus, asyncCall.GetTrailers, asyncCall.Cancel); } /// <summary> diff --git a/src/csharp/Grpc.Core/Internal/AsyncCall.cs b/src/csharp/Grpc.Core/Internal/AsyncCall.cs index 30d60077f01e214acc5ee1739377f7639470a741..132b4264243fabb7fc2b6ff8e581621758ceae1e 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCall.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCall.cs @@ -56,6 +56,9 @@ namespace Grpc.Core.Internal // Completion of a pending unary response if not null. TaskCompletionSource<TResponse> unaryResponseTcs; + // Response headers set here once received. + TaskCompletionSource<Metadata> responseHeadersTcs = new TaskCompletionSource<Metadata>(); + // Set after status is received. Used for both unary and streaming response calls. ClientSideStatus? finishedStatus; @@ -110,7 +113,7 @@ namespace Grpc.Core.Internal bool success = (ev.success != 0); try { - HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage()); + HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage(), ctx.GetReceivedInitialMetadata()); } catch (Exception e) { @@ -257,6 +260,17 @@ namespace Grpc.Core.Internal } } + /// <summary> + /// Get the task that completes once response headers are received. + /// </summary> + public Task<Metadata> ResponseHeadersAsync + { + get + { + return responseHeadersTcs.Task; + } + } + /// <summary> /// Gets the resulting status if the call has already finished. /// Throws InvalidOperationException otherwise. @@ -371,7 +385,7 @@ namespace Grpc.Core.Internal /// <summary> /// Handler for unary response completion. /// </summary> - private void HandleUnaryResponse(bool success, ClientSideStatus receivedStatus, byte[] receivedMessage) + private void HandleUnaryResponse(bool success, ClientSideStatus receivedStatus, byte[] receivedMessage, Metadata responseHeaders) { lock (myLock) { @@ -383,18 +397,13 @@ namespace Grpc.Core.Internal ReleaseResourcesIfPossible(); } - if (!success) - { - var internalError = new Status(StatusCode.Internal, "Internal error occured."); - finishedStatus = new ClientSideStatus(internalError, null); - unaryResponseTcs.SetException(new RpcException(internalError)); - return; - } + responseHeadersTcs.SetResult(responseHeaders); var status = receivedStatus.Status; - if (status.StatusCode != StatusCode.OK) + if (!success || status.StatusCode != StatusCode.OK) { + unaryResponseTcs.SetException(new RpcException(status)); return; } diff --git a/src/csharp/Grpc.Core/Internal/CallSafeHandle.cs b/src/csharp/Grpc.Core/Internal/CallSafeHandle.cs index e1466da65b3786826fc03da2f0472bce8c38bf2a..ed6747ea93adab1369d01b7384472cf87158d3dd 100644 --- a/src/csharp/Grpc.Core/Internal/CallSafeHandle.cs +++ b/src/csharp/Grpc.Core/Internal/CallSafeHandle.cs @@ -112,7 +112,7 @@ namespace Grpc.Core.Internal public void StartUnary(UnaryResponseClientHandler callback, byte[] payload, MetadataArraySafeHandle metadataArray, WriteFlags writeFlags) { var ctx = BatchContextSafeHandle.Create(); - completionRegistry.RegisterBatchCompletion(ctx, (success, context) => callback(success, context.GetReceivedStatusOnClient(), context.GetReceivedMessage())); + completionRegistry.RegisterBatchCompletion(ctx, (success, context) => callback(success, context.GetReceivedStatusOnClient(), context.GetReceivedMessage(), context.GetReceivedInitialMetadata())); grpcsharp_call_start_unary(this, ctx, payload, new UIntPtr((ulong)payload.Length), metadataArray, writeFlags) .CheckOk(); } @@ -126,7 +126,7 @@ namespace Grpc.Core.Internal public void StartClientStreaming(UnaryResponseClientHandler callback, MetadataArraySafeHandle metadataArray) { var ctx = BatchContextSafeHandle.Create(); - completionRegistry.RegisterBatchCompletion(ctx, (success, context) => callback(success, context.GetReceivedStatusOnClient(), context.GetReceivedMessage())); + completionRegistry.RegisterBatchCompletion(ctx, (success, context) => callback(success, context.GetReceivedStatusOnClient(), context.GetReceivedMessage(), context.GetReceivedInitialMetadata())); grpcsharp_call_start_client_streaming(this, ctx, metadataArray).CheckOk(); } diff --git a/src/csharp/Grpc.Core/Internal/INativeCall.cs b/src/csharp/Grpc.Core/Internal/INativeCall.cs index 42028e458cf9b491735e8bcf1524269b1b4be350..ef2e230ff8de6f9fd2ff48fb007504b6dc2db370 100644 --- a/src/csharp/Grpc.Core/Internal/INativeCall.cs +++ b/src/csharp/Grpc.Core/Internal/INativeCall.cs @@ -33,8 +33,9 @@ using System; namespace Grpc.Core.Internal { - internal delegate void UnaryResponseClientHandler(bool success, ClientSideStatus receivedStatus, byte[] receivedMessage); + internal delegate void UnaryResponseClientHandler(bool success, ClientSideStatus receivedStatus, byte[] receivedMessage, Metadata responseHeaders); + // Received status for streaming response calls. internal delegate void ReceivedStatusOnClientHandler(bool success, ClientSideStatus receivedStatus); internal delegate void ReceivedMessageHandler(bool success, byte[] receivedMessage);