From 5833c706dbff35a55465db4405cf0ac035c63abe Mon Sep 17 00:00:00 2001
From: Yuchen Zeng <zyc@google.com>
Date: Fri, 13 May 2016 17:23:07 -0700
Subject: [PATCH] Add incremental decoding and input validation

---
 .../transport/chttp2/transport/bin_decoder.c  | 208 ++++++++++++++----
 .../transport/chttp2/transport/bin_decoder.h  |  22 +-
 test/core/transport/chttp2/bin_decoder_test.c |  36 ++-
 3 files changed, 212 insertions(+), 54 deletions(-)

diff --git a/src/core/ext/transport/chttp2/transport/bin_decoder.c b/src/core/ext/transport/chttp2/transport/bin_decoder.c
index fe6c84bfb8..640c29f63d 100644
--- a/src/core/ext/transport/chttp2/transport/bin_decoder.c
+++ b/src/core/ext/transport/chttp2/transport/bin_decoder.c
@@ -32,31 +32,130 @@
  */
 
 #include "src/core/ext/transport/chttp2/transport/bin_decoder.h"
+#include <grpc/support/alloc.h>
 #include <grpc/support/log.h>
 #include <stdio.h>
+#include "src/core/lib/support/string.h"
 
 static uint8_t decode_table[] = {
-    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
-    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
-    0,  0,  0,  0,  0,  62, 0,  0,  0,  63, 52, 53, 54, 55, 56, 57, 58, 59, 60,
-    61, 0,  0,  0,  0,  0,  0,  0,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10,
-    11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0,  0,  0,  0,
-    0,  0,  26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
-    43, 44, 45, 46, 47, 48, 49, 50, 51, 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
-    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
-    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
-    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
-    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
-    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
-    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
-    0,  0,  0,  0,  0,  0,  0,  0,  0};
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 62,   0x40, 0x40, 0x40, 63,
+    52,   53,   54,   55,   56,   57,   58,   59,   60,   61,   0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0,    1,    2,    3,    4,    5,    6,
+    7,    8,    9,    10,   11,   12,   13,   14,   15,   16,   17,   18,
+    19,   20,   21,   22,   23,   24,   25,   0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 26,   27,   28,   29,   30,   31,   32,   33,   34,   35,   36,
+    37,   38,   39,   40,   41,   42,   43,   44,   45,   46,   47,   48,
+    49,   50,   51,   0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
+    0x40, 0x40, 0x40, 0x40};
 
 static const uint8_t tail_xtra[4] = {0, 0, 1, 2};
 
+static inline bool input_is_valid(uint8_t *input_ptr, size_t length) {
+  size_t i;
+
+  for (i = 0; i < length; ++i) {
+    if ((decode_table[input_ptr[i]] & 0xC0) != 0) {
+      gpr_log(GPR_ERROR,
+              "Base64 decoding failed, invalid charactor '%c' in base64 "
+              "input.\n",
+              (char)(*input_ptr));
+      return false;
+    }
+  }
+  return true;
+}
+
+#define COMPOSE_OUTPUT_BYTE_0(input_ptr)        \
+  (uint8_t)((decode_table[input_ptr[0]] << 2) | \
+            (decode_table[input_ptr[1]] >> 4))
+
+#define COMPOSE_OUTPUT_BYTE_1(input_ptr)             \
+  (uint8_t)((decode_table[ctx->input_cur[1]] << 4) | \
+            (decode_table[ctx->input_cur[2]] >> 2))
+
+#define COMPOSE_OUTPUT_BYTE_2(input_ptr)             \
+  (uint8_t)((decode_table[ctx->input_cur[2]] << 6) | \
+            decode_table[ctx->input_cur[3]])
+
+bool grpc_base64_decode_partial(struct grpc_base64_decode_context *ctx) {
+  size_t input_tail;
+
+  if (ctx->input_cur > ctx->input_end || ctx->output_cur > ctx->output_end) {
+    return false;
+  }
+
+  while (ctx->input_end >= ctx->input_cur + 4 &&
+         ctx->output_end >= ctx->output_cur + 3) {
+    if (!input_is_valid(ctx->input_cur, 4)) return false;
+    ctx->output_cur[0] = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur);
+    ctx->output_cur[1] = COMPOSE_OUTPUT_BYTE_1(ctx->input_cur);
+    ctx->output_cur[2] = COMPOSE_OUTPUT_BYTE_2(ctx->input_cur);
+    ctx->output_cur += 3;
+    ctx->input_cur += 4;
+  }
+
+  input_tail = (size_t)(ctx->input_end - ctx->input_cur);
+  if (input_tail == 4) {
+    // Process the input data with pad chars
+    if (ctx->input_cur[3] == '=') {
+      if (ctx->input_cur[2] == '=' && ctx->output_end >= ctx->output_cur + 1) {
+        if (!input_is_valid(ctx->input_cur, 2)) return false;
+        *(ctx->output_cur++) = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur);
+        ctx->input_cur += 4;
+      } else if (ctx->output_end >= ctx->output_cur + 2) {
+        if (!input_is_valid(ctx->input_cur, 3)) return false;
+        *(ctx->output_cur++) = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur);
+        *(ctx->output_cur++) = COMPOSE_OUTPUT_BYTE_1(ctx->input_cur);
+        ;
+        ctx->input_cur += 4;
+      }
+    }
+
+  } else if (ctx->contains_tail && input_tail > 1) {
+    // Process the input data without pad chars, but constains_tail is set
+    if (ctx->output_end >= ctx->output_cur + tail_xtra[input_tail]) {
+      if (!input_is_valid(ctx->input_cur, input_tail)) return false;
+      switch (input_tail) {
+        case 3:
+          ctx->output_cur[1] = COMPOSE_OUTPUT_BYTE_1(ctx->input_cur);
+        case 2:
+          ctx->output_cur[0] = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur);
+      }
+      ctx->output_cur += tail_xtra[input_tail];
+      ctx->input_cur += input_tail;
+    }
+  }
+
+  return true;
+}
+
 gpr_slice grpc_chttp2_base64_decode(gpr_slice input) {
   size_t input_length = GPR_SLICE_LENGTH(input);
-  GPR_ASSERT(input_length % 4 == 0);
   size_t output_length = input_length / 4 * 3;
+  struct grpc_base64_decode_context ctx;
+  gpr_slice output;
+
+  if (input_length % 4 != 0) {
+    gpr_log(GPR_ERROR,
+            "Base64 decoding failed, input of "
+            "grpc_chttp2_base64_decode has a length of %zu, which is not a "
+            "multiple of 4.\n",
+            input_length);
+    return gpr_empty_slice();
+  }
 
   if (input_length > 0) {
     uint8_t *input_end = GPR_SLICE_END_PTR(input);
@@ -67,49 +166,66 @@ gpr_slice grpc_chttp2_base64_decode(gpr_slice input) {
       }
     }
   }
+  output = gpr_slice_malloc(output_length);
 
-  gpr_log(GPR_ERROR, "input_length: %d, output_length: %d\n", input_length,
-          output_length);
+  ctx.input_cur = GPR_SLICE_START_PTR(input);
+  ctx.input_end = GPR_SLICE_END_PTR(input);
+  ctx.output_cur = GPR_SLICE_START_PTR(output);
+  ctx.output_end = GPR_SLICE_END_PTR(output);
+  ctx.contains_tail = false;
 
-  return grpc_chttp2_base64_decode_with_length(input, output_length);
+  if (!grpc_base64_decode_partial(&ctx)) {
+    char *s = gpr_dump_slice(input, GPR_DUMP_ASCII);
+    gpr_log(GPR_ERROR, "Base64 decoding failed, input string:\n%s\n", s);
+    gpr_free(s);
+    gpr_slice_unref(output);
+    return gpr_empty_slice();
+  }
+  GPR_ASSERT(ctx.output_cur == GPR_SLICE_END_PTR(output));
+  GPR_ASSERT(ctx.input_cur == GPR_SLICE_END_PTR(input));
+  return output;
 }
 
 gpr_slice grpc_chttp2_base64_decode_with_length(gpr_slice input,
                                                 size_t output_length) {
   size_t input_length = GPR_SLICE_LENGTH(input);
-  // The length of a base64 string cannot be 4 * n + 1
-  GPR_ASSERT(input_length % 4 != 1);
-  GPR_ASSERT(output_length <=
-             input_length / 4 * 3 + tail_xtra[input_length % 4]);
-  size_t output_triplets = output_length / 3;
-  size_t tail_case = output_length % 3;
   gpr_slice output = gpr_slice_malloc(output_length);
-  uint8_t *in = GPR_SLICE_START_PTR(input);
-  uint8_t *out = GPR_SLICE_START_PTR(output);
-  size_t i;
+  struct grpc_base64_decode_context ctx;
 
-  for (i = 0; i < output_triplets; i++) {
-    out[0] = (uint8_t)((decode_table[in[0]] << 2) | (decode_table[in[1]] >> 4));
-    out[1] = (uint8_t)((decode_table[in[1]] << 4) | (decode_table[in[2]] >> 2));
-    out[2] = (uint8_t)((decode_table[in[2]] << 6) | decode_table[in[3]]);
-    out += 3;
-    in += 4;
+  // The length of a base64 string cannot be 4 * n + 1
+  if (input_length % 4 == 1) {
+    gpr_log(GPR_ERROR,
+            "Base64 decoding failed, input of "
+            "grpc_chttp2_base64_decode_with_length has a length of %zu, which "
+            "has a tail of 1 byte.\n",
+            input_length);
+    gpr_slice_unref(output);
+    return gpr_empty_slice();
   }
 
-  if (tail_case > 0) {
-    switch (tail_case) {
-      case 2:
-        out[1] =
-            (uint8_t)((decode_table[in[1]] << 4) | (decode_table[in[2]] >> 2));
-      case 1:
-        out[0] =
-            (uint8_t)((decode_table[in[0]] << 2) | (decode_table[in[1]] >> 4));
-    }
-    out += tail_case;
-    in += tail_case + 1;
+  if (output_length > input_length / 4 * 3 + tail_xtra[input_length % 4]) {
+    gpr_log(GPR_ERROR,
+            "Base64 decoding failed, output_length %zu is longer "
+            "than the max possible output length %zu./\n",
+            output_length, input_length / 4 * 3 + tail_xtra[input_length % 4]);
+    gpr_slice_unref(output);
+    return gpr_empty_slice();
   }
 
-  GPR_ASSERT(out == GPR_SLICE_END_PTR(output));
-  GPR_ASSERT(in <= GPR_SLICE_END_PTR(input));
+  ctx.input_cur = GPR_SLICE_START_PTR(input);
+  ctx.input_end = GPR_SLICE_END_PTR(input);
+  ctx.output_cur = GPR_SLICE_START_PTR(output);
+  ctx.output_end = GPR_SLICE_END_PTR(output);
+  ctx.contains_tail = true;
+
+  if (!grpc_base64_decode_partial(&ctx)) {
+    char *s = gpr_dump_slice(input, GPR_DUMP_ASCII);
+    gpr_log(GPR_ERROR, "Base64 decoding failed, input string:\n%s\n", s);
+    gpr_free(s);
+    gpr_slice_unref(output);
+    return gpr_empty_slice();
+  }
+  GPR_ASSERT(ctx.output_cur == GPR_SLICE_END_PTR(output));
+  GPR_ASSERT(ctx.input_cur <= GPR_SLICE_END_PTR(input));
   return output;
 }
diff --git a/src/core/ext/transport/chttp2/transport/bin_decoder.h b/src/core/ext/transport/chttp2/transport/bin_decoder.h
index 5516f86d53..b9d40c9b74 100644
--- a/src/core/ext/transport/chttp2/transport/bin_decoder.h
+++ b/src/core/ext/transport/chttp2/transport/bin_decoder.h
@@ -35,13 +35,31 @@
 #define GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_BIN_DECODER_H
 
 #include <grpc/support/slice.h>
+#include <stdbool.h>
+
+struct grpc_base64_decode_context {
+  /* input/output: */
+  uint8_t *input_cur;
+  uint8_t *input_end;
+  uint8_t *output_cur;
+  uint8_t *output_end;
+  /* Indicate if the decoder should handle the tail of input data*/
+  bool contains_tail;
+};
+
+/* base64 decode a grpc_base64_decode_context util either input_end is reached
+   or output_end is reached. When input_end is reached, (input_end - input_cur)
+   is less than 4. When output_end is reached, (output_end - output_cur) is less
+   than 3. Returns false if decoding is failed. */
+bool grpc_base64_decode_partial(struct grpc_base64_decode_context *ctx);
 
 /* base64 decode a slice with pad chars. Returns a new slice, does not take
-   ownership of the input */
+   ownership of the input. Returns an empty slice if decoding is failed. */
 gpr_slice grpc_chttp2_base64_decode(gpr_slice input);
 
 /* base64 decode a slice without pad chars, data length is needed. Returns a new
-   slice, does not take ownership of the input */
+   slice, does not take ownership of the input. Returns an empty slice if
+   decoding is failed. */
 gpr_slice grpc_chttp2_base64_decode_with_length(gpr_slice input,
                                                 size_t output_length);
 
diff --git a/test/core/transport/chttp2/bin_decoder_test.c b/test/core/transport/chttp2/bin_decoder_test.c
index 980da02dc3..c4e6cd332f 100644
--- a/test/core/transport/chttp2/bin_decoder_test.c
+++ b/test/core/transport/chttp2/bin_decoder_test.c
@@ -37,7 +37,6 @@
 
 #include <grpc/support/alloc.h>
 #include <grpc/support/log.h>
-#include <grpc/support/log.h>
 #include "src/core/ext/transport/chttp2/transport/bin_encoder.h"
 #include "src/core/lib/support/string.h"
 
@@ -72,6 +71,14 @@ static gpr_slice base64_decode(const char *s) {
   return out;
 }
 
+static gpr_slice base64_decode_with_length(const char *s,
+                                           size_t output_length) {
+  gpr_slice ss = gpr_slice_from_copied_string(s);
+  gpr_slice out = grpc_chttp2_base64_decode_with_length(ss, output_length);
+  gpr_slice_unref(ss);
+  return out;
+}
+
 #define EXPECT_SLICE_EQ(expected, slice)                                   \
   expect_slice_eq(                                                         \
       gpr_slice_from_copied_buffer(expected, sizeof(expected) - 1), slice, \
@@ -82,11 +89,9 @@ static gpr_slice base64_decode(const char *s) {
       s, grpc_chttp2_base64_decode_with_length(base64_encode(s), strlen(s)));
 
 int main(int argc, char **argv) {
-  /*
-   * ENCODE_AND_DECODE tests grpc_chttp2_base64_decode_with_length(), which
-   * takes encoded base64 strings without pad chars, but output length is
-   * required
-   */
+  /* ENCODE_AND_DECODE tests grpc_chttp2_base64_decode_with_length(), which
+     takes encoded base64 strings without pad chars, but output length is
+     required. */
   /* Base64 test vectors from RFC 4648 */
   ENCODE_AND_DECODE("");
   ENCODE_AND_DECODE("f");
@@ -116,5 +121,24 @@ int main(int argc, char **argv) {
 
   EXPECT_SLICE_EQ("\xc0\xc1\xc2\xc3\xc4\xc5", base64_decode("wMHCw8TF"));
 
+  // Test illegal input length in grpc_chttp2_base64_decode
+  EXPECT_SLICE_EQ("", base64_decode("a"));
+  EXPECT_SLICE_EQ("", base64_decode("ab"));
+  EXPECT_SLICE_EQ("", base64_decode("abc"));
+
+  // Test illegal charactors in grpc_chttp2_base64_decode
+  EXPECT_SLICE_EQ("", base64_decode("Zm:v"));
+  EXPECT_SLICE_EQ("", base64_decode("Zm=v"));
+
+  // Test output_length longer than max possible output length in
+  // grpc_chttp2_base64_decode_with_length
+  EXPECT_SLICE_EQ("", base64_decode_with_length("Zg", 2));
+  EXPECT_SLICE_EQ("", base64_decode_with_length("Zm8", 3));
+  EXPECT_SLICE_EQ("", base64_decode_with_length("Zm9v", 4));
+
+  // Test illegal charactors in grpc_chttp2_base64_decode_with_length
+  EXPECT_SLICE_EQ("", base64_decode_with_length("Zm:v", 3));
+  EXPECT_SLICE_EQ("", base64_decode_with_length("Zm=v", 3));
+
   return all_ok ? 0 : 1;
 }
-- 
GitLab