Files
fs/auth/middleware.go
2026-05-16 10:11:04 +02:00

186 lines
4.6 KiB
Go

package auth
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fs/metrics"
"hash"
"io"
"log/slog"
"net"
"net/http"
"strings"
"github.com/go-chi/chi/v5/middleware"
)
func Middleware(
svc *Service,
logger *slog.Logger,
auditEnabled bool,
onError func(http.ResponseWriter, *http.Request, error),
) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authCtx := RequestContext{Authenticated: false, AuthType: "none"}
if svc == nil || !svc.Config().Enabled {
metrics.Default.ObserveAuth("bypass", "disabled", "auth_disabled")
next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), authCtx)))
return
}
if r.URL.Path == "/healthz" {
metrics.Default.ObserveAuth("bypass", "none", "public_endpoint")
next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), authCtx)))
return
}
resolvedCtx, err := svc.AuthenticateRequest(r)
if err != nil {
metrics.Default.ObserveAuth("error", "sigv4", authErrorClass(err))
if auditEnabled && logger != nil {
requestID := middleware.GetReqID(r.Context())
attrs := []any{
"method", r.Method,
"path", r.URL.Path,
"remote_ip", clientIP(r.RemoteAddr),
"error", err.Error(),
}
if requestID != "" {
attrs = append(attrs, "request_id", requestID)
}
logger.Warn("auth_failed", attrs...)
}
if onError != nil {
onError(w, r, err)
return
}
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
if err := wrapPayloadHashVerifier(r); err != nil {
metrics.Default.ObserveAuth("error", "sigv4", authErrorClass(err))
if onError != nil {
onError(w, r, err)
return
}
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
metrics.Default.ObserveAuth("ok", resolvedCtx.AuthType, "none")
if auditEnabled && logger != nil {
requestID := middleware.GetReqID(r.Context())
attrs := []any{
"method", r.Method,
"path", r.URL.Path,
"remote_ip", clientIP(r.RemoteAddr),
"access_key_id", resolvedCtx.AccessKeyID,
"auth_type", resolvedCtx.AuthType,
}
if requestID != "" {
attrs = append(attrs, "request_id", requestID)
}
logger.Info("auth_success", attrs...)
}
next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), resolvedCtx)))
})
}
}
func wrapPayloadHashVerifier(r *http.Request) error {
if r == nil || r.Body == nil || r.Body == http.NoBody {
return nil
}
payloadHash := resolvePayloadHash(r, false)
if !payloadHashRequiresVerification(payloadHash) {
return nil
}
if !isHexSHA256(payloadHash) {
return ErrAuthorizationHeaderMalformed
}
expected, err := hex.DecodeString(strings.ToLower(payloadHash))
if err != nil {
return ErrAuthorizationHeaderMalformed
}
r.Body = &payloadHashVerifyingReadCloser{
inner: r.Body,
hasher: sha256.New(),
expected: expected,
}
return nil
}
type payloadHashVerifyingReadCloser struct {
inner io.ReadCloser
hasher hash.Hash
expected []byte
done bool
}
func (r *payloadHashVerifyingReadCloser) Read(p []byte) (int, error) {
n, err := r.inner.Read(p)
if n > 0 {
_, _ = r.hasher.Write(p[:n])
}
if err == io.EOF && !r.done {
r.done = true
if !equalBytes(r.hasher.Sum(nil), r.expected) {
return n, ErrSignatureDoesNotMatch
}
}
return n, err
}
func (r *payloadHashVerifyingReadCloser) Close() error {
return r.inner.Close()
}
func equalBytes(left, right []byte) bool {
if len(left) != len(right) {
return false
}
var diff byte
for i := range left {
diff |= left[i] ^ right[i]
}
return diff == 0
}
func authErrorClass(err error) string {
switch {
case errors.Is(err, ErrInvalidAccessKeyID):
return "invalid_access_key"
case errors.Is(err, ErrSignatureDoesNotMatch):
return "signature_mismatch"
case errors.Is(err, ErrAuthorizationHeaderMalformed):
return "auth_header_malformed"
case errors.Is(err, ErrRequestTimeTooSkewed):
return "time_skew"
case errors.Is(err, ErrExpiredToken):
return "expired_token"
case errors.Is(err, ErrNoAuthCredentials):
return "missing_credentials"
case errors.Is(err, ErrUnsupportedAuthScheme):
return "unsupported_auth_scheme"
case errors.Is(err, ErrInvalidPresign):
return "invalid_presign"
case errors.Is(err, ErrCredentialDisabled):
return "credential_disabled"
case errors.Is(err, ErrAccessDenied):
return "access_denied"
default:
return "other"
}
}
func clientIP(remoteAddr string) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err == nil && host != "" {
return host
}
return remoteAddr
}