Harden S3 auth boundaries

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
2026-05-16 10:11:04 +02:00
parent eac20f7fda
commit 2425cd524e
10 changed files with 477 additions and 6 deletions

View File

@@ -27,6 +27,7 @@ type RequestTarget struct {
Action Action
Bucket string
Key string
Prefix string
}
func resolveTarget(r *http.Request) RequestTarget {
@@ -51,7 +52,7 @@ func resolveTarget(r *http.Request) RequestTarget {
case http.MethodDelete:
return RequestTarget{Action: ActionDeleteBucket, Bucket: bucket}
case http.MethodGet:
return RequestTarget{Action: ActionListBucket, Bucket: bucket}
return RequestTarget{Action: ActionListBucket, Bucket: bucket, Prefix: r.URL.Query().Get("prefix")}
case http.MethodPost:
if _, ok := r.URL.Query()["delete"]; ok {
return RequestTarget{Action: ActionDeleteObject, Bucket: bucket}

39
auth/action_test.go Normal file
View File

@@ -0,0 +1,39 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestResolveTargetIncludesListBucketPrefix(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com/test-bucket?list-type=2&prefix=allowed/", nil)
target := resolveTarget(req)
if target.Action != ActionListBucket {
t.Fatalf("action = %q, want %q", target.Action, ActionListBucket)
}
if target.Bucket != "test-bucket" {
t.Fatalf("bucket = %q, want test-bucket", target.Bucket)
}
if target.Prefix != "allowed/" {
t.Fatalf("prefix = %q, want allowed/", target.Prefix)
}
if target.Key != "" {
t.Fatalf("key = %q, want empty", target.Key)
}
}
func TestResolveTargetListBucketWithoutPrefix(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com/test-bucket", nil)
target := resolveTarget(req)
if target.Action != ActionListBucket {
t.Fatalf("action = %q, want %q", target.Action, ActionListBucket)
}
if target.Prefix != "" {
t.Fatalf("prefix = %q, want empty", target.Prefix)
}
}

View File

@@ -1,11 +1,16 @@
package auth
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fs/metrics"
"hash"
"io"
"log/slog"
"net"
"net/http"
"strings"
"github.com/go-chi/chi/v5/middleware"
)
@@ -55,6 +60,16 @@ func Middleware(
return
}
if err := wrapPayloadHashVerifier(r); err != nil {
metrics.Default.ObserveAuth("error", "sigv4", authErrorClass(err))
if onError != nil {
onError(w, r, err)
return
}
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
metrics.Default.ObserveAuth("ok", resolvedCtx.AuthType, "none")
if auditEnabled && logger != nil {
requestID := middleware.GetReqID(r.Context())
@@ -75,6 +90,65 @@ func Middleware(
}
}
func wrapPayloadHashVerifier(r *http.Request) error {
if r == nil || r.Body == nil || r.Body == http.NoBody {
return nil
}
payloadHash := resolvePayloadHash(r, false)
if !payloadHashRequiresVerification(payloadHash) {
return nil
}
if !isHexSHA256(payloadHash) {
return ErrAuthorizationHeaderMalformed
}
expected, err := hex.DecodeString(strings.ToLower(payloadHash))
if err != nil {
return ErrAuthorizationHeaderMalformed
}
r.Body = &payloadHashVerifyingReadCloser{
inner: r.Body,
hasher: sha256.New(),
expected: expected,
}
return nil
}
type payloadHashVerifyingReadCloser struct {
inner io.ReadCloser
hasher hash.Hash
expected []byte
done bool
}
func (r *payloadHashVerifyingReadCloser) Read(p []byte) (int, error) {
n, err := r.inner.Read(p)
if n > 0 {
_, _ = r.hasher.Write(p[:n])
}
if err == io.EOF && !r.done {
r.done = true
if !equalBytes(r.hasher.Sum(nil), r.expected) {
return n, ErrSignatureDoesNotMatch
}
}
return n, err
}
func (r *payloadHashVerifyingReadCloser) Close() error {
return r.inner.Close()
}
func equalBytes(left, right []byte) bool {
if len(left) != len(right) {
return false
}
var diff byte
for i := range left {
diff |= left[i] ^ right[i]
}
return diff == 0
}
func authErrorClass(err error) string {
switch {
case errors.Is(err, ErrInvalidAccessKeyID):

75
auth/payload_hash_test.go Normal file
View File

@@ -0,0 +1,75 @@
package auth
import (
"crypto/sha256"
"encoding/hex"
"errors"
"io"
"net/http"
"strings"
"testing"
)
func TestPayloadHashVerifierAllowsMatchingBody(t *testing.T) {
body := "payload"
req := newPayloadHashRequest(t, body, body)
if err := wrapPayloadHashVerifier(req); err != nil {
t.Fatalf("wrapPayloadHashVerifier returned error: %v", err)
}
got, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("ReadAll returned error: %v", err)
}
if string(got) != body {
t.Fatalf("unexpected body: got %q want %q", string(got), body)
}
}
func TestPayloadHashVerifierRejectsMismatchedBody(t *testing.T) {
req := newPayloadHashRequest(t, "signed-payload", "actual-payload")
if err := wrapPayloadHashVerifier(req); err != nil {
t.Fatalf("wrapPayloadHashVerifier returned error: %v", err)
}
_, err := io.ReadAll(req.Body)
if !errors.Is(err, ErrSignatureDoesNotMatch) {
t.Fatalf("ReadAll error = %v, want ErrSignatureDoesNotMatch", err)
}
}
func TestPayloadSigningRejectsSignedStreamingMode(t *testing.T) {
req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("x-amz-content-sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD")
err = validatePayloadSigningMode(req, &sigV4Input{})
if !errors.Is(err, ErrAuthorizationHeaderMalformed) {
t.Fatalf("validatePayloadSigningMode error = %v, want ErrAuthorizationHeaderMalformed", err)
}
}
func TestPayloadSigningAllowsUnsignedStreamingMode(t *testing.T) {
req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("x-amz-content-sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER")
if err := validatePayloadSigningMode(req, &sigV4Input{}); err != nil {
t.Fatalf("validatePayloadSigningMode returned error: %v", err)
}
}
func newPayloadHashRequest(t *testing.T, signedBody, actualBody string) *http.Request {
t.Helper()
req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", strings.NewReader(actualBody))
if err != nil {
t.Fatal(err)
}
sum := sha256.Sum256([]byte(signedBody))
req.Header.Set("x-amz-content-sha256", hex.EncodeToString(sum[:]))
return req
}

View File

@@ -33,14 +33,16 @@ func statementMatches(stmt models.AuthPolicyStatement, target RequestTarget) boo
if !bucketMatches(stmt.Bucket, target.Bucket) {
return false
}
if target.Key == "" {
return true
}
prefix := strings.TrimSpace(stmt.Prefix)
if prefix == "" || prefix == "*" {
return true
}
if target.Key == "" {
if target.Action == ActionListBucket {
return strings.HasPrefix(target.Prefix, prefix)
}
return true
}
return strings.HasPrefix(target.Key, prefix)
}

52
auth/policy_test.go Normal file
View File

@@ -0,0 +1,52 @@
package auth
import (
"fs/models"
"testing"
)
func TestListBucketPolicyAppliesPrefix(t *testing.T) {
policy := &models.AuthPolicy{
Statements: []models.AuthPolicyStatement{
{
Effect: "allow",
Actions: []string{"s3:ListBucket"},
Bucket: "test-bucket",
Prefix: "allowed/",
},
},
}
if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "allowed/"}) {
t.Fatalf("expected matching list prefix to be allowed")
}
if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "allowed/nested/"}) {
t.Fatalf("expected nested list prefix to be allowed")
}
if isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket"}) {
t.Fatalf("expected empty list prefix to be denied")
}
if isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "private/"}) {
t.Fatalf("expected non-matching list prefix to be denied")
}
}
func TestWildcardListBucketPolicyAllowsAnyPrefix(t *testing.T) {
policy := &models.AuthPolicy{
Statements: []models.AuthPolicyStatement{
{
Effect: "allow",
Actions: []string{"s3:ListBucket"},
Bucket: "test-bucket",
Prefix: "*",
},
},
}
if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket"}) {
t.Fatalf("expected wildcard list policy to allow empty prefix")
}
if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "private/"}) {
t.Fatalf("expected wildcard list policy to allow arbitrary prefix")
}
}

View File

@@ -152,6 +152,9 @@ func (s *Service) AuthenticateRequest(r *http.Request) (RequestContext, error) {
if err := validateSigV4Input(s.now(), s.cfg, input); err != nil {
return RequestContext{}, err
}
if err := validatePayloadSigningMode(r, input); err != nil {
return RequestContext{}, err
}
identity, err := s.store.GetAuthIdentity(input.AccessKeyID)
if err != nil {

View File

@@ -210,6 +210,17 @@ func validateSigV4Input(now time.Time, cfg Config, input *sigV4Input) error {
return nil
}
func validatePayloadSigningMode(r *http.Request, input *sigV4Input) error {
payloadHash := resolvePayloadHash(r, input.Presigned)
if isSignedStreamingPayloadHash(payloadHash) {
return fmt.Errorf("%w: signed streaming payload verification is not supported", ErrAuthorizationHeaderMalformed)
}
if payloadHashRequiresVerification(payloadHash) && !isHexSHA256(payloadHash) {
return fmt.Errorf("%w: invalid x-amz-content-sha256", ErrAuthorizationHeaderMalformed)
}
return nil
}
func signatureMatches(secret string, r *http.Request, input *sigV4Input) (bool, error) {
payloadHash := resolvePayloadHash(r, input.Presigned)
canonicalRequest, err := buildCanonicalRequest(r, input.SignedHeaders, payloadHash, input.Presigned)
@@ -233,6 +244,34 @@ func resolvePayloadHash(r *http.Request, presigned bool) string {
return hash
}
func isSignedStreamingPayloadHash(payloadHash string) bool {
payloadHash = strings.ToUpper(strings.TrimSpace(payloadHash))
return strings.HasPrefix(payloadHash, "STREAMING-AWS4-HMAC-SHA256-PAYLOAD")
}
func payloadHashRequiresVerification(payloadHash string) bool {
payloadHash = strings.ToUpper(strings.TrimSpace(payloadHash))
if payloadHash == "" || payloadHash == "UNSIGNED-PAYLOAD" {
return false
}
if strings.HasPrefix(payloadHash, "STREAMING-UNSIGNED-PAYLOAD") {
return false
}
return true
}
func isHexSHA256(value string) bool {
if len(value) != sha256.Size*2 {
return false
}
for _, ch := range value {
if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') && (ch < 'A' || ch > 'F') {
return false
}
}
return true
}
func buildCanonicalRequest(r *http.Request, signedHeaders []string, payloadHash string, presigned bool) (string, error) {
canonicalURI := canonicalPath(r.URL)
canonicalQuery := canonicalQueryString(r.URL.RawQuery, presigned)