diff --git a/api/api.go b/api/api.go index 6f296e5..6a3f6aa 100644 --- a/api/api.go +++ b/api/api.go @@ -41,6 +41,7 @@ const ( maxXMLBodyBytes int64 = 1 << 20 maxDeleteObjects = 1000 maxObjectKeyBytes = 1024 + maxAWSChunkedLineBytes = 8 << 10 serverReadHeaderTimeout = 5 * time.Second serverReadTimeout = 60 * time.Second serverWriteTimeout = 120 * time.Second @@ -387,6 +388,10 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) { bodyReader := io.Reader(r.Body) var decodeStream io.ReadCloser + if hasUnsupportedAWSChunkedPayload(r) { + writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path) + return + } if shouldDecodeAWSChunkedPayload(r) { decodeStream = newAWSChunkedDecodingReader(r.Body) defer decodeStream.Close() @@ -461,6 +466,10 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) { bodyReader := io.Reader(r.Body) var decodeStream io.ReadCloser + if hasUnsupportedAWSChunkedPayload(r) { + writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path) + return + } if shouldDecodeAWSChunkedPayload(r) { decodeStream = newAWSChunkedDecodingReader(r.Body) defer decodeStream.Close() @@ -516,17 +525,18 @@ func (h *Handler) handleListMultipartParts(w http.ResponseWriter, r *http.Reques } func shouldDecodeAWSChunkedPayload(r *http.Request) bool { - contentEncoding := strings.ToLower(r.Header.Get("Content-Encoding")) - if strings.Contains(contentEncoding, "aws-chunked") { - return true - } signingMode := strings.ToLower(r.Header.Get("x-amz-content-sha256")) - if strings.HasPrefix(signingMode, "streaming-aws4-hmac-sha256-payload") { - return true - } return strings.HasPrefix(signingMode, "streaming-unsigned-payload") } +func hasUnsupportedAWSChunkedPayload(r *http.Request) bool { + contentEncoding := strings.ToLower(r.Header.Get("Content-Encoding")) + if !strings.Contains(contentEncoding, "aws-chunked") { + return false + } + return !shouldDecodeAWSChunkedPayload(r) +} + func newAWSChunkedDecodingReader(src io.Reader) io.ReadCloser { probedReader, isAWSChunked := probeAWSChunkedPayload(src) if !isAWSChunked { @@ -545,9 +555,12 @@ func newAWSChunkedDecodingReader(src io.Reader) io.ReadCloser { } func probeAWSChunkedPayload(src io.Reader) (io.Reader, bool) { - reader := bufio.NewReaderSize(src, 512) + reader := bufio.NewReaderSize(src, maxAWSChunkedLineBytes) headerLine, err := reader.ReadSlice('\n') replay := io.MultiReader(bytes.NewReader(headerLine), reader) + if errors.Is(err, bufio.ErrBufferFull) { + return replay, true + } if err != nil { return replay, false } @@ -569,9 +582,9 @@ func probeAWSChunkedPayload(src io.Reader) (io.Reader, bool) { } func decodeAWSChunkedPayload(src io.Reader, dst io.Writer) error { - reader := bufio.NewReader(src) + reader := bufio.NewReaderSize(src, maxAWSChunkedLineBytes) for { - headerLine, err := reader.ReadString('\n') + headerLine, err := readAWSChunkedLine(reader) if err != nil { return err } @@ -588,6 +601,17 @@ func decodeAWSChunkedPayload(src io.Reader, dst io.Writer) error { if chunkSize < 0 { return fmt.Errorf("invalid aws-chunked size: %d", chunkSize) } + if chunkSize == 0 { + for { + line, err := readAWSChunkedLine(reader) + if err != nil { + return err + } + if line == "\r\n" || line == "\n" { + return nil + } + } + } if chunkSize > 0 { if _, err := io.CopyN(dst, reader, chunkSize); err != nil { return err @@ -601,21 +625,20 @@ func decodeAWSChunkedPayload(src io.Reader, dst io.Writer) error { if crlf[0] != '\r' || crlf[1] != '\n' { return errors.New("invalid aws-chunked payload terminator") } - - if chunkSize == 0 { - for { - line, err := reader.ReadString('\n') - if err != nil { - return err - } - if line == "\r\n" || line == "\n" { - return nil - } - } - } } } +func readAWSChunkedLine(reader *bufio.Reader) (string, error) { + line, err := reader.ReadSlice('\n') + if errors.Is(err, bufio.ErrBufferFull) { + return "", service.ErrEntityTooLarge + } + if len(line) > maxAWSChunkedLineBytes { + return "", service.ErrEntityTooLarge + } + return string(line), err +} + func ifNoneMatchPreconditionFailed(headerValue, etag string) bool { for _, rawToken := range strings.Split(headerValue, ",") { token := strings.TrimSpace(rawToken) diff --git a/api/aws_chunked_test.go b/api/aws_chunked_test.go index 870b830..7e594b5 100644 --- a/api/aws_chunked_test.go +++ b/api/aws_chunked_test.go @@ -1,10 +1,14 @@ package api import ( + "errors" "io" "net/http" + "net/http/httptest" "strings" "testing" + + "fs/service" ) func TestShouldDecodeAWSChunkedPayloadUnsignedTrailerMode(t *testing.T) { @@ -20,6 +24,45 @@ func TestShouldDecodeAWSChunkedPayloadUnsignedTrailerMode(t *testing.T) { } } +func TestUnsupportedAWSChunkedContentEncodingWithoutStreamingMode(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Encoding", "aws-chunked") + req.Header.Set("x-amz-content-sha256", "UNSIGNED-PAYLOAD") + + if !hasUnsupportedAWSChunkedPayload(req) { + t.Fatalf("expected aws-chunked content encoding without streaming mode to be unsupported") + } + if shouldDecodeAWSChunkedPayload(req) { + t.Fatalf("non-streaming aws-chunked content encoding must not trigger decoding") + } +} + +func TestPutObjectRejectsUnsignedAWSChunkedContentEncoding(t *testing.T) { + handler, svc := newUploadLimitHandler(t, 1024) + if err := svc.CreateBucket("test-bucket"); err != nil { + t.Fatalf("CreateBucket: %v", err) + } + + req := httptest.NewRequest(http.MethodPut, "/test-bucket/object.txt", strings.NewReader("4\r\nWiki\r\n0\r\n\r\n")) + req.Header.Set("Content-Encoding", "aws-chunked") + req.Header.Set("x-amz-content-sha256", "UNSIGNED-PAYLOAD") + rec := httptest.NewRecorder() + + handler.router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "InvalidArgument") { + t.Fatalf("expected InvalidArgument response, body=%s", rec.Body.String()) + } +} + func TestAWSChunkedReaderPassThroughForPlainPayload(t *testing.T) { t.Parallel() @@ -43,7 +86,6 @@ func TestAWSChunkedReaderDecodesChunkedPayload(t *testing.T) { "4\r\nWiki\r\n" + "5\r\npedia\r\n" + "0\r\n" + - "\r\n" + "x-amz-checksum-crc32:xxxx\r\n" + "\r\n" @@ -58,3 +100,16 @@ func TestAWSChunkedReaderDecodesChunkedPayload(t *testing.T) { t.Fatalf("decoded payload mismatch: got %q want %q", string(out), "Wikipedia") } } + +func TestAWSChunkedReaderRejectsOversizedChunkHeader(t *testing.T) { + t.Parallel() + + encoded := strings.Repeat("f", maxAWSChunkedLineBytes+1) + "\n" + reader := newAWSChunkedDecodingReader(strings.NewReader(encoded)) + defer reader.Close() + + _, err := io.ReadAll(reader) + if !errors.Is(err, service.ErrEntityTooLarge) { + t.Fatalf("read error = %v, want ErrEntityTooLarge", err) + } +}