diff --git a/api/api.go b/api/api.go index 3c98517..6f296e5 100644 --- a/api/api.go +++ b/api/api.go @@ -196,6 +196,10 @@ func parseCopySource(raw string) (string, string, error) { } func (h *Handler) authorizeCopySource(r *http.Request, bucket, key string) error { + return h.authorizeObjectAction(r, auth.ActionGetObject, bucket, key) +} + +func (h *Handler) authorizeObjectAction(r *http.Request, action auth.Action, bucket, key string) error { if h.authSvc == nil || !h.authSvc.Config().Enabled { return nil } @@ -206,7 +210,7 @@ func (h *Handler) authorizeCopySource(r *http.Request, bucket, key string) error } return h.authSvc.Authorize(authCtx.AccessKeyID, auth.RequestTarget{ - Action: auth.ActionGetObject, + Action: action, Bucket: bucket, Key: key, }) @@ -307,6 +311,10 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxXMLBodyBytes) var req models.CompleteMultipartUploadRequest if err := xml.NewDecoder(r.Body).Decode(&req); err != nil { + if errors.Is(err, auth.ErrSignatureDoesNotMatch) { + writeMappedS3Error(w, r, err) + return + } var maxErr *http.MaxBytesError if errors.As(err, &maxErr) { writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path) @@ -664,6 +672,10 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) { var req models.DeleteObjectsRequest if err := xml.NewDecoder(bodyReader).Decode(&req); err != nil { + if errors.Is(err, auth.ErrSignatureDoesNotMatch) { + writeMappedS3Error(w, r, err) + return + } var maxErr *http.MaxBytesError if errors.As(err, &maxErr) { writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path) @@ -699,6 +711,15 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) { }) continue } + if err := h.authorizeObjectAction(r, auth.ActionDeleteObject, bucket, obj.Key); err != nil { + apiErr := mapToS3Error(err) + response.Errors = append(response.Errors, models.DeleteError{ + Key: obj.Key, + Code: apiErr.Code, + Message: apiErr.Message, + }) + continue + } keys = append(keys, obj.Key) } diff --git a/api/multi_delete_auth_test.go b/api/multi_delete_auth_test.go new file mode 100644 index 0000000..495f504 --- /dev/null +++ b/api/multi_delete_auth_test.go @@ -0,0 +1,165 @@ +package api + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + "time" + + "fs/auth" + "fs/logging" + "fs/metadata" + "fs/models" + "fs/service" + "fs/storage" + + "github.com/go-chi/chi/v5" +) + +func newAuthorizedDeleteHandler(t *testing.T) (*Handler, *service.ObjectService, *auth.Service) { + t.Helper() + + root := t.TempDir() + md, err := metadata.NewMetadataHandler(filepath.Join(root, "metadata.db")) + if err != nil { + t.Fatalf("new metadata handler: %v", err) + } + blob, err := storage.NewBlobStore(root, 1024) + if err != nil { + t.Fatalf("new blob store: %v", err) + } + svc := service.NewObjectService(md, blob, time.Hour) + t.Cleanup(func() { + _ = svc.Close() + }) + + masterKey := base64.StdEncoding.EncodeToString(make([]byte, 32)) + authSvc, err := auth.NewService(auth.ConfigFromValues( + true, + "us-east-1", + 0, + 0, + masterKey, + "", + "", + "", + ), md) + if err != nil { + t.Fatalf("new auth service: %v", err) + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + handler := NewHandler(svc, logger, logging.Config{}, authSvc, false) + return handler, svc, authSvc +} + +func newBucketPostRequest(bucket, body string) *http.Request { + req := httptest.NewRequest(http.MethodPost, "/"+bucket+"?delete", strings.NewReader(body)) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("bucket", bucket) + return req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) +} + +func withAuthContext(req *http.Request, accessKeyID string) *http.Request { + authCtx := auth.RequestContext{ + Authenticated: true, + AccessKeyID: accessKeyID, + AuthType: "test", + } + return req.WithContext(auth.WithRequestContext(req.Context(), authCtx)) +} + +func createDeleteUser(t *testing.T, authSvc *auth.Service, prefix string) { + t.Helper() + _, err := authSvc.CreateUser(auth.CreateUserInput{ + AccessKeyID: "delete-user", + SecretKey: "delete-secret-1", + Policy: models.AuthPolicy{ + Statements: []models.AuthPolicyStatement{ + { + Effect: "allow", + Actions: []string{"s3:DeleteObject"}, + Bucket: "test-bucket", + Prefix: prefix, + }, + }, + }, + }) + if err != nil { + t.Fatalf("create delete user: %v", err) + } +} + +func putTestObject(t *testing.T, svc *service.ObjectService, key string) { + t.Helper() + _, err := svc.PutObject("test-bucket", key, "text/plain", bytes.NewReader([]byte("data"))) + if err != nil { + t.Fatalf("put object %q: %v", key, err) + } +} + +func TestMultiDeleteAuthorizesEveryKey(t *testing.T) { + handler, svc, authSvc := newAuthorizedDeleteHandler(t) + if err := svc.CreateBucket("test-bucket"); err != nil { + t.Fatalf("create bucket: %v", err) + } + createDeleteUser(t, authSvc, "allowed/") + putTestObject(t, svc, "allowed/file.txt") + putTestObject(t, svc, "private/file.txt") + + body := `allowed/file.txtprivate/file.txt` + req := withAuthContext(newBucketPostRequest("test-bucket", body), "delete-user") + rec := httptest.NewRecorder() + + handler.handlePostBucket(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d body=%s", rec.Code, rec.Body.String()) + } + responseBody := rec.Body.String() + if !strings.Contains(responseBody, "") || !strings.Contains(responseBody, "allowed/file.txt") { + t.Fatalf("expected allowed key to be deleted, body=%s", responseBody) + } + if !strings.Contains(responseBody, "") || !strings.Contains(responseBody, "private/file.txt") || !strings.Contains(responseBody, "AccessDenied") { + t.Fatalf("expected denied key error, body=%s", responseBody) + } + if _, err := svc.HeadObject("test-bucket", "allowed/file.txt"); !errors.Is(err, metadata.ErrObjectNotFound) { + t.Fatalf("allowed object should be deleted, got err=%v", err) + } + if _, err := svc.HeadObject("test-bucket", "private/file.txt"); err != nil { + t.Fatalf("private object should remain: %v", err) + } +} + +func TestMultiDeleteAllowsScopedKeys(t *testing.T) { + handler, svc, authSvc := newAuthorizedDeleteHandler(t) + if err := svc.CreateBucket("test-bucket"); err != nil { + t.Fatalf("create bucket: %v", err) + } + createDeleteUser(t, authSvc, "allowed/") + putTestObject(t, svc, "allowed/file.txt") + + body := `allowed/file.txt` + req := withAuthContext(newBucketPostRequest("test-bucket", body), "delete-user") + rec := httptest.NewRecorder() + + handler.handlePostBucket(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d body=%s", rec.Code, rec.Body.String()) + } + if strings.Contains(rec.Body.String(), "") { + t.Fatalf("unexpected delete error body=%s", rec.Body.String()) + } + if _, err := svc.HeadObject("test-bucket", "allowed/file.txt"); !errors.Is(err, metadata.ErrObjectNotFound) { + t.Fatalf("allowed object should be deleted, got err=%v", err) + } +} diff --git a/auth/action.go b/auth/action.go index 576da2e..988029f 100644 --- a/auth/action.go +++ b/auth/action.go @@ -27,6 +27,7 @@ type RequestTarget struct { Action Action Bucket string Key string + Prefix string } func resolveTarget(r *http.Request) RequestTarget { @@ -51,7 +52,7 @@ func resolveTarget(r *http.Request) RequestTarget { case http.MethodDelete: return RequestTarget{Action: ActionDeleteBucket, Bucket: bucket} case http.MethodGet: - return RequestTarget{Action: ActionListBucket, Bucket: bucket} + return RequestTarget{Action: ActionListBucket, Bucket: bucket, Prefix: r.URL.Query().Get("prefix")} case http.MethodPost: if _, ok := r.URL.Query()["delete"]; ok { return RequestTarget{Action: ActionDeleteObject, Bucket: bucket} diff --git a/auth/action_test.go b/auth/action_test.go new file mode 100644 index 0000000..8f0b257 --- /dev/null +++ b/auth/action_test.go @@ -0,0 +1,39 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestResolveTargetIncludesListBucketPrefix(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com/test-bucket?list-type=2&prefix=allowed/", nil) + + target := resolveTarget(req) + + if target.Action != ActionListBucket { + t.Fatalf("action = %q, want %q", target.Action, ActionListBucket) + } + if target.Bucket != "test-bucket" { + t.Fatalf("bucket = %q, want test-bucket", target.Bucket) + } + if target.Prefix != "allowed/" { + t.Fatalf("prefix = %q, want allowed/", target.Prefix) + } + if target.Key != "" { + t.Fatalf("key = %q, want empty", target.Key) + } +} + +func TestResolveTargetListBucketWithoutPrefix(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com/test-bucket", nil) + + target := resolveTarget(req) + + if target.Action != ActionListBucket { + t.Fatalf("action = %q, want %q", target.Action, ActionListBucket) + } + if target.Prefix != "" { + t.Fatalf("prefix = %q, want empty", target.Prefix) + } +} diff --git a/auth/middleware.go b/auth/middleware.go index 24f656e..d6e2497 100644 --- a/auth/middleware.go +++ b/auth/middleware.go @@ -1,11 +1,16 @@ package auth import ( + "crypto/sha256" + "encoding/hex" "errors" "fs/metrics" + "hash" + "io" "log/slog" "net" "net/http" + "strings" "github.com/go-chi/chi/v5/middleware" ) @@ -55,6 +60,16 @@ func Middleware( return } + if err := wrapPayloadHashVerifier(r); err != nil { + metrics.Default.ObserveAuth("error", "sigv4", authErrorClass(err)) + if onError != nil { + onError(w, r, err) + return + } + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) + return + } + metrics.Default.ObserveAuth("ok", resolvedCtx.AuthType, "none") if auditEnabled && logger != nil { requestID := middleware.GetReqID(r.Context()) @@ -75,6 +90,65 @@ func Middleware( } } +func wrapPayloadHashVerifier(r *http.Request) error { + if r == nil || r.Body == nil || r.Body == http.NoBody { + return nil + } + payloadHash := resolvePayloadHash(r, false) + if !payloadHashRequiresVerification(payloadHash) { + return nil + } + if !isHexSHA256(payloadHash) { + return ErrAuthorizationHeaderMalformed + } + expected, err := hex.DecodeString(strings.ToLower(payloadHash)) + if err != nil { + return ErrAuthorizationHeaderMalformed + } + r.Body = &payloadHashVerifyingReadCloser{ + inner: r.Body, + hasher: sha256.New(), + expected: expected, + } + return nil +} + +type payloadHashVerifyingReadCloser struct { + inner io.ReadCloser + hasher hash.Hash + expected []byte + done bool +} + +func (r *payloadHashVerifyingReadCloser) Read(p []byte) (int, error) { + n, err := r.inner.Read(p) + if n > 0 { + _, _ = r.hasher.Write(p[:n]) + } + if err == io.EOF && !r.done { + r.done = true + if !equalBytes(r.hasher.Sum(nil), r.expected) { + return n, ErrSignatureDoesNotMatch + } + } + return n, err +} + +func (r *payloadHashVerifyingReadCloser) Close() error { + return r.inner.Close() +} + +func equalBytes(left, right []byte) bool { + if len(left) != len(right) { + return false + } + var diff byte + for i := range left { + diff |= left[i] ^ right[i] + } + return diff == 0 +} + func authErrorClass(err error) string { switch { case errors.Is(err, ErrInvalidAccessKeyID): diff --git a/auth/payload_hash_test.go b/auth/payload_hash_test.go new file mode 100644 index 0000000..c209aeb --- /dev/null +++ b/auth/payload_hash_test.go @@ -0,0 +1,75 @@ +package auth + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "io" + "net/http" + "strings" + "testing" +) + +func TestPayloadHashVerifierAllowsMatchingBody(t *testing.T) { + body := "payload" + req := newPayloadHashRequest(t, body, body) + + if err := wrapPayloadHashVerifier(req); err != nil { + t.Fatalf("wrapPayloadHashVerifier returned error: %v", err) + } + got, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("ReadAll returned error: %v", err) + } + if string(got) != body { + t.Fatalf("unexpected body: got %q want %q", string(got), body) + } +} + +func TestPayloadHashVerifierRejectsMismatchedBody(t *testing.T) { + req := newPayloadHashRequest(t, "signed-payload", "actual-payload") + + if err := wrapPayloadHashVerifier(req); err != nil { + t.Fatalf("wrapPayloadHashVerifier returned error: %v", err) + } + _, err := io.ReadAll(req.Body) + if !errors.Is(err, ErrSignatureDoesNotMatch) { + t.Fatalf("ReadAll error = %v, want ErrSignatureDoesNotMatch", err) + } +} + +func TestPayloadSigningRejectsSignedStreamingMode(t *testing.T) { + req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("x-amz-content-sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD") + + err = validatePayloadSigningMode(req, &sigV4Input{}) + if !errors.Is(err, ErrAuthorizationHeaderMalformed) { + t.Fatalf("validatePayloadSigningMode error = %v, want ErrAuthorizationHeaderMalformed", err) + } +} + +func TestPayloadSigningAllowsUnsignedStreamingMode(t *testing.T) { + req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("x-amz-content-sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER") + + if err := validatePayloadSigningMode(req, &sigV4Input{}); err != nil { + t.Fatalf("validatePayloadSigningMode returned error: %v", err) + } +} + +func newPayloadHashRequest(t *testing.T, signedBody, actualBody string) *http.Request { + t.Helper() + req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", strings.NewReader(actualBody)) + if err != nil { + t.Fatal(err) + } + sum := sha256.Sum256([]byte(signedBody)) + req.Header.Set("x-amz-content-sha256", hex.EncodeToString(sum[:])) + return req +} diff --git a/auth/policy.go b/auth/policy.go index 2508fc9..80899c1 100644 --- a/auth/policy.go +++ b/auth/policy.go @@ -33,14 +33,16 @@ func statementMatches(stmt models.AuthPolicyStatement, target RequestTarget) boo if !bucketMatches(stmt.Bucket, target.Bucket) { return false } - if target.Key == "" { - return true - } - prefix := strings.TrimSpace(stmt.Prefix) if prefix == "" || prefix == "*" { return true } + if target.Key == "" { + if target.Action == ActionListBucket { + return strings.HasPrefix(target.Prefix, prefix) + } + return true + } return strings.HasPrefix(target.Key, prefix) } diff --git a/auth/policy_test.go b/auth/policy_test.go new file mode 100644 index 0000000..474e7a8 --- /dev/null +++ b/auth/policy_test.go @@ -0,0 +1,52 @@ +package auth + +import ( + "fs/models" + "testing" +) + +func TestListBucketPolicyAppliesPrefix(t *testing.T) { + policy := &models.AuthPolicy{ + Statements: []models.AuthPolicyStatement{ + { + Effect: "allow", + Actions: []string{"s3:ListBucket"}, + Bucket: "test-bucket", + Prefix: "allowed/", + }, + }, + } + + if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "allowed/"}) { + t.Fatalf("expected matching list prefix to be allowed") + } + if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "allowed/nested/"}) { + t.Fatalf("expected nested list prefix to be allowed") + } + if isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket"}) { + t.Fatalf("expected empty list prefix to be denied") + } + if isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "private/"}) { + t.Fatalf("expected non-matching list prefix to be denied") + } +} + +func TestWildcardListBucketPolicyAllowsAnyPrefix(t *testing.T) { + policy := &models.AuthPolicy{ + Statements: []models.AuthPolicyStatement{ + { + Effect: "allow", + Actions: []string{"s3:ListBucket"}, + Bucket: "test-bucket", + Prefix: "*", + }, + }, + } + + if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket"}) { + t.Fatalf("expected wildcard list policy to allow empty prefix") + } + if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "private/"}) { + t.Fatalf("expected wildcard list policy to allow arbitrary prefix") + } +} diff --git a/auth/service.go b/auth/service.go index 9a5b01a..2bb0eee 100644 --- a/auth/service.go +++ b/auth/service.go @@ -152,6 +152,9 @@ func (s *Service) AuthenticateRequest(r *http.Request) (RequestContext, error) { if err := validateSigV4Input(s.now(), s.cfg, input); err != nil { return RequestContext{}, err } + if err := validatePayloadSigningMode(r, input); err != nil { + return RequestContext{}, err + } identity, err := s.store.GetAuthIdentity(input.AccessKeyID) if err != nil { diff --git a/auth/sigv4.go b/auth/sigv4.go index 8988498..8ede60d 100644 --- a/auth/sigv4.go +++ b/auth/sigv4.go @@ -210,6 +210,17 @@ func validateSigV4Input(now time.Time, cfg Config, input *sigV4Input) error { return nil } +func validatePayloadSigningMode(r *http.Request, input *sigV4Input) error { + payloadHash := resolvePayloadHash(r, input.Presigned) + if isSignedStreamingPayloadHash(payloadHash) { + return fmt.Errorf("%w: signed streaming payload verification is not supported", ErrAuthorizationHeaderMalformed) + } + if payloadHashRequiresVerification(payloadHash) && !isHexSHA256(payloadHash) { + return fmt.Errorf("%w: invalid x-amz-content-sha256", ErrAuthorizationHeaderMalformed) + } + return nil +} + func signatureMatches(secret string, r *http.Request, input *sigV4Input) (bool, error) { payloadHash := resolvePayloadHash(r, input.Presigned) canonicalRequest, err := buildCanonicalRequest(r, input.SignedHeaders, payloadHash, input.Presigned) @@ -233,6 +244,34 @@ func resolvePayloadHash(r *http.Request, presigned bool) string { return hash } +func isSignedStreamingPayloadHash(payloadHash string) bool { + payloadHash = strings.ToUpper(strings.TrimSpace(payloadHash)) + return strings.HasPrefix(payloadHash, "STREAMING-AWS4-HMAC-SHA256-PAYLOAD") +} + +func payloadHashRequiresVerification(payloadHash string) bool { + payloadHash = strings.ToUpper(strings.TrimSpace(payloadHash)) + if payloadHash == "" || payloadHash == "UNSIGNED-PAYLOAD" { + return false + } + if strings.HasPrefix(payloadHash, "STREAMING-UNSIGNED-PAYLOAD") { + return false + } + return true +} + +func isHexSHA256(value string) bool { + if len(value) != sha256.Size*2 { + return false + } + for _, ch := range value { + if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') && (ch < 'A' || ch > 'F') { + return false + } + } + return true +} + func buildCanonicalRequest(r *http.Request, signedHeaders []string, payloadHash string, presigned bool) (string, error) { canonicalURI := canonicalPath(r.URL) canonicalQuery := canonicalQueryString(r.URL.RawQuery, presigned)