mirror of
https://github.com/ferdzo/fs.git
synced 2026-06-04 05:26:46 +00:00
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
186 lines
4.6 KiB
Go
186 lines
4.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fs/metrics"
|
|
"hash"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
)
|
|
|
|
func Middleware(
|
|
svc *Service,
|
|
logger *slog.Logger,
|
|
auditEnabled bool,
|
|
onError func(http.ResponseWriter, *http.Request, error),
|
|
) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
authCtx := RequestContext{Authenticated: false, AuthType: "none"}
|
|
if svc == nil || !svc.Config().Enabled {
|
|
metrics.Default.ObserveAuth("bypass", "disabled", "auth_disabled")
|
|
next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), authCtx)))
|
|
return
|
|
}
|
|
|
|
if r.URL.Path == "/healthz" {
|
|
metrics.Default.ObserveAuth("bypass", "none", "public_endpoint")
|
|
next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), authCtx)))
|
|
return
|
|
}
|
|
|
|
resolvedCtx, err := svc.AuthenticateRequest(r)
|
|
if err != nil {
|
|
metrics.Default.ObserveAuth("error", "sigv4", authErrorClass(err))
|
|
if auditEnabled && logger != nil {
|
|
requestID := middleware.GetReqID(r.Context())
|
|
attrs := []any{
|
|
"method", r.Method,
|
|
"path", r.URL.Path,
|
|
"remote_ip", clientIP(r.RemoteAddr),
|
|
"error", err.Error(),
|
|
}
|
|
if requestID != "" {
|
|
attrs = append(attrs, "request_id", requestID)
|
|
}
|
|
logger.Warn("auth_failed", attrs...)
|
|
}
|
|
if onError != nil {
|
|
onError(w, r, err)
|
|
return
|
|
}
|
|
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
if err := wrapPayloadHashVerifier(r); err != nil {
|
|
metrics.Default.ObserveAuth("error", "sigv4", authErrorClass(err))
|
|
if onError != nil {
|
|
onError(w, r, err)
|
|
return
|
|
}
|
|
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
metrics.Default.ObserveAuth("ok", resolvedCtx.AuthType, "none")
|
|
if auditEnabled && logger != nil {
|
|
requestID := middleware.GetReqID(r.Context())
|
|
attrs := []any{
|
|
"method", r.Method,
|
|
"path", r.URL.Path,
|
|
"remote_ip", clientIP(r.RemoteAddr),
|
|
"access_key_id", resolvedCtx.AccessKeyID,
|
|
"auth_type", resolvedCtx.AuthType,
|
|
}
|
|
if requestID != "" {
|
|
attrs = append(attrs, "request_id", requestID)
|
|
}
|
|
logger.Info("auth_success", attrs...)
|
|
}
|
|
next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), resolvedCtx)))
|
|
})
|
|
}
|
|
}
|
|
|
|
func wrapPayloadHashVerifier(r *http.Request) error {
|
|
if r == nil || r.Body == nil || r.Body == http.NoBody {
|
|
return nil
|
|
}
|
|
payloadHash := resolvePayloadHash(r, false)
|
|
if !payloadHashRequiresVerification(payloadHash) {
|
|
return nil
|
|
}
|
|
if !isHexSHA256(payloadHash) {
|
|
return ErrAuthorizationHeaderMalformed
|
|
}
|
|
expected, err := hex.DecodeString(strings.ToLower(payloadHash))
|
|
if err != nil {
|
|
return ErrAuthorizationHeaderMalformed
|
|
}
|
|
r.Body = &payloadHashVerifyingReadCloser{
|
|
inner: r.Body,
|
|
hasher: sha256.New(),
|
|
expected: expected,
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type payloadHashVerifyingReadCloser struct {
|
|
inner io.ReadCloser
|
|
hasher hash.Hash
|
|
expected []byte
|
|
done bool
|
|
}
|
|
|
|
func (r *payloadHashVerifyingReadCloser) Read(p []byte) (int, error) {
|
|
n, err := r.inner.Read(p)
|
|
if n > 0 {
|
|
_, _ = r.hasher.Write(p[:n])
|
|
}
|
|
if err == io.EOF && !r.done {
|
|
r.done = true
|
|
if !equalBytes(r.hasher.Sum(nil), r.expected) {
|
|
return n, ErrSignatureDoesNotMatch
|
|
}
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
func (r *payloadHashVerifyingReadCloser) Close() error {
|
|
return r.inner.Close()
|
|
}
|
|
|
|
// equalBytes reports whether left and right contain exactly the same bytes.
// Slices of different length are unequal. For equal lengths, every byte pair is
// XOR-ed into an accumulator rather than returning at the first difference, so
// the loop always walks the full length.
func equalBytes(left, right []byte) bool {
	if len(left) != len(right) {
		return false
	}

	var mismatch byte
	for idx, lb := range left {
		mismatch |= lb ^ right[idx]
	}
	return mismatch == 0
}
|
|
|
|
func authErrorClass(err error) string {
|
|
switch {
|
|
case errors.Is(err, ErrInvalidAccessKeyID):
|
|
return "invalid_access_key"
|
|
case errors.Is(err, ErrSignatureDoesNotMatch):
|
|
return "signature_mismatch"
|
|
case errors.Is(err, ErrAuthorizationHeaderMalformed):
|
|
return "auth_header_malformed"
|
|
case errors.Is(err, ErrRequestTimeTooSkewed):
|
|
return "time_skew"
|
|
case errors.Is(err, ErrExpiredToken):
|
|
return "expired_token"
|
|
case errors.Is(err, ErrNoAuthCredentials):
|
|
return "missing_credentials"
|
|
case errors.Is(err, ErrUnsupportedAuthScheme):
|
|
return "unsupported_auth_scheme"
|
|
case errors.Is(err, ErrInvalidPresign):
|
|
return "invalid_presign"
|
|
case errors.Is(err, ErrCredentialDisabled):
|
|
return "credential_disabled"
|
|
case errors.Is(err, ErrAccessDenied):
|
|
return "access_denied"
|
|
default:
|
|
return "other"
|
|
}
|
|
}
|
|
|
|
// clientIP returns the host part of a "host:port" remote address, for use in
// audit logs. If remoteAddr cannot be split, or the host part is empty,
// remoteAddr is returned unchanged.
func clientIP(remoteAddr string) string {
	host, _, splitErr := net.SplitHostPort(remoteAddr)
	if splitErr != nil || host == "" {
		return remoteAddr
	}
	return host
}
|