diff --git a/include/grpc++/client_context.h b/include/grpc++/client_context.h
index 0a81f6a36664828e707c45600d834736bc17cabc..f74de8fad4ff908989b099e482bc924a2cbac857 100644
--- a/include/grpc++/client_context.h
+++ b/include/grpc++/client_context.h
@@ -95,6 +95,7 @@ class ClientContext {
 
   gpr_timespec RawDeadline() { return absolute_deadline_; }
 
+  bool initial_metadata_received_ = false;
   grpc_call *call_;
   grpc_completion_queue *cq_;
   gpr_timespec absolute_deadline_;
diff --git a/include/grpc++/impl/call.h b/include/grpc++/impl/call.h
index 5fafd0e89046de647b337d7baecbb4436954ca52..a1ef9268f0c63b83e4f666009a52d2d62f8905f0 100644
--- a/include/grpc++/impl/call.h
+++ b/include/grpc++/impl/call.h
@@ -134,16 +134,7 @@ class Call final {
   grpc_call *call() { return call_; }
   CompletionQueue *cq() { return cq_; }
 
-  // TODO(yangg) change it to a general state query function.
-  bool initial_metadata_received() {
-    return initial_metadata_received_;
-  }
-  void set_initial_metadata_received() {
-    initial_metadata_received_ = true;
-  }
-
  private:
-  bool initial_metadata_received_ = false;
   CallHook *call_hook_;
   CompletionQueue *cq_;
   grpc_call* call_;
diff --git a/include/grpc++/stream.h b/include/grpc++/stream.h
index 6265310c5a823c227d112d1fdc417edc77da5006..74e7539aa4756301555b1c025dbb4a204dc07201 100644
--- a/include/grpc++/stream.h
+++ b/include/grpc++/stream.h
@@ -99,21 +99,25 @@ class ClientReader final : public ClientStreamingInterface,
   }
 
   // Blocking wait for initial metadata from server. The received metadata
-  // can only be accessed after this call returns. Calling this method is
-  // optional as it will be called internally before the first Read.
+  // can only be accessed after this call returns. Should only be called before
+  // the first read. Calling this method is optional, and if it is not called
+  // the metadata will be available in ClientContext after the first read.
   void WaitForInitialMetadata() {
-    if (!call_.initial_metadata_received()) {
-      CallOpBuffer buf;
-      buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
-      call_.PerformOps(&buf);
-      GPR_ASSERT(cq_.Pluck(&buf));
-      call_.set_initial_metadata_received();
-    }
+    GPR_ASSERT(!context_->initial_metadata_received_);
+
+    CallOpBuffer buf;
+    buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
+    call_.PerformOps(&buf);
+    GPR_ASSERT(cq_.Pluck(&buf));
+    context_->initial_metadata_received_ = true;
   }
 
   virtual bool Read(R *msg) override {
-    WaitForInitialMetadata();
     CallOpBuffer buf;
+    if (!context_->initial_metadata_received_) {
+      buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
+      context_->initial_metadata_received_ = true;
+    }
     bool got_message;
     buf.AddRecvMessage(msg, &got_message);
     call_.PerformOps(&buf);
@@ -201,21 +205,25 @@ class ClientReaderWriter final : public ClientStreamingInterface,
   }
 
   // Blocking wait for initial metadata from server. The received metadata
-  // can only be accessed after this call returns. Calling this method is
-  // optional as it will be called internally before the first Read.
+  // can only be accessed after this call returns. Should only be called before
+  // the first read. Calling this method is optional, and if it is not called
+  // the metadata will be available in ClientContext after the first read.
   void WaitForInitialMetadata() {
-    if (!call_.initial_metadata_received()) {
-      CallOpBuffer buf;
-      buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
-      call_.PerformOps(&buf);
-      GPR_ASSERT(cq_.Pluck(&buf));
-      call_.set_initial_metadata_received();
-    }
+    GPR_ASSERT(!context_->initial_metadata_received_);
+
+    CallOpBuffer buf;
+    buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
+    call_.PerformOps(&buf);
+    GPR_ASSERT(cq_.Pluck(&buf));
+    context_->initial_metadata_received_ = true;
   }
 
   virtual bool Read(R *msg) override {
-    WaitForInitialMetadata();
     CallOpBuffer buf;
+    if (!context_->initial_metadata_received_) {
+      buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
+      context_->initial_metadata_received_ = true;
+    }
     bool got_message;
     buf.AddRecvMessage(msg, &got_message);
     call_.PerformOps(&buf);
@@ -257,13 +265,13 @@ class ServerReader final : public ReaderInterface<R> {
   ServerReader(Call* call, ServerContext* ctx) : call_(call), ctx_(ctx) {}
 
   void SendInitialMetadata() {
-    if (!ctx_->sent_initial_metadata_) {
-      CallOpBuffer buf;
-      buf.AddSendInitialMetadata(&ctx_->initial_metadata_);
-      ctx_->sent_initial_metadata_ = true;
-      call_->PerformOps(&buf);
-      call_->cq()->Pluck(&buf);
-    }
+    GPR_ASSERT(!ctx_->sent_initial_metadata_);
+
+    CallOpBuffer buf;
+    buf.AddSendInitialMetadata(&ctx_->initial_metadata_);
+    ctx_->sent_initial_metadata_ = true;
+    call_->PerformOps(&buf);
+    call_->cq()->Pluck(&buf);
   }
 
   virtual bool Read(R* msg) override {
@@ -285,18 +293,21 @@ class ServerWriter final : public WriterInterface<W> {
   ServerWriter(Call* call, ServerContext* ctx) : call_(call), ctx_(ctx) {}
 
   void SendInitialMetadata() {
-    if (!ctx_->sent_initial_metadata_) {
-      CallOpBuffer buf;
-      buf.AddSendInitialMetadata(&ctx_->initial_metadata_);
-      ctx_->sent_initial_metadata_ = true;
-      call_->PerformOps(&buf);
-      call_->cq()->Pluck(&buf);
-    }
+    GPR_ASSERT(!ctx_->sent_initial_metadata_);
+
+    CallOpBuffer buf;
+    buf.AddSendInitialMetadata(&ctx_->initial_metadata_);
+    ctx_->sent_initial_metadata_ = true;
+    call_->PerformOps(&buf);
+    call_->cq()->Pluck(&buf);
   }
 
   virtual bool Write(const W& msg) override {
-    SendInitialMetadata();
     CallOpBuffer buf;
+    if (!ctx_->sent_initial_metadata_) {
+      buf.AddSendInitialMetadata(&ctx_->initial_metadata_);
+      ctx_->sent_initial_metadata_ = true;
+    }
     buf.AddSendMessage(msg);
     call_->PerformOps(&buf);
     return call_->cq()->Pluck(&buf);
@@ -315,13 +326,13 @@ class ServerReaderWriter final : public WriterInterface<W>,
   ServerReaderWriter(Call* call, ServerContext* ctx) : call_(call), ctx_(ctx) {}
 
   void SendInitialMetadata() {
-    if (!ctx_->sent_initial_metadata_) {
-      CallOpBuffer buf;
-      buf.AddSendInitialMetadata(&ctx_->initial_metadata_);
-      ctx_->sent_initial_metadata_ = true;
-      call_->PerformOps(&buf);
-      call_->cq()->Pluck(&buf);
-    }
+    GPR_ASSERT(!ctx_->sent_initial_metadata_);
+
+    CallOpBuffer buf;
+    buf.AddSendInitialMetadata(&ctx_->initial_metadata_);
+    ctx_->sent_initial_metadata_ = true;
+    call_->PerformOps(&buf);
+    call_->cq()->Pluck(&buf);
   }
 
   virtual bool Read(R* msg) override {
@@ -333,8 +344,11 @@ class ServerReaderWriter final : public WriterInterface<W>,
   }
 
   virtual bool Write(const W& msg) override {
-    SendInitialMetadata();
     CallOpBuffer buf;
+    if (!ctx_->sent_initial_metadata_) {
+      buf.AddSendInitialMetadata(&ctx_->initial_metadata_);
+      ctx_->sent_initial_metadata_ = true;
+    }
     buf.AddSendMessage(msg);
     call_->PerformOps(&buf);
     return call_->cq()->Pluck(&buf);