mirror of
https://github.com/ferdzo/fs.git
synced 2026-04-04 20:36:25 +00:00
Initial working authentication with SigV4
This commit is contained in:
@@ -7,3 +7,11 @@ ADDRESS=0.0.0.0
|
||||
GC_INTERVAL=10
|
||||
GC_ENABLED=true
|
||||
MULTIPART_RETENTION_HOURS=24
|
||||
AUTH_ENABLED=true
|
||||
AUTH_REGION=us-east-1
|
||||
AUTH_SKEW_SECONDS=300
|
||||
AUTH_MAX_PRESIGN_SECONDS=86400
|
||||
AUTH_MASTER_KEY=
|
||||
AUTH_BOOTSTRAP_ACCESS_KEY=
|
||||
AUTH_BOOTSTRAP_SECRET_KEY=
|
||||
AUTH_BOOTSTRAP_POLICY=
|
||||
|
||||
14
README.md
14
README.md
@@ -29,13 +29,25 @@ Multi-object delete:
|
||||
|
||||
AWS SigV4 streaming payload decoding for uploads (`aws-chunked` request bodies)
|
||||
|
||||
Authentication:
|
||||
- AWS SigV4 request verification (header and presigned URL forms)
|
||||
- Local credential/policy store in bbolt
|
||||
- Bootstrap access key/secret via environment variables
|
||||
|
||||
## Auth Setup
|
||||
|
||||
Required when `AUTH_ENABLED=true`:
|
||||
- `AUTH_MASTER_KEY` must be base64 for 32 decoded bytes (AES-256 key), e.g. `openssl rand -base64 32`
|
||||
- `AUTH_BOOTSTRAP_ACCESS_KEY` and `AUTH_BOOTSTRAP_SECRET_KEY` define initial credentials
|
||||
|
||||
Reference: `docs/auth-spec.md`
|
||||
|
||||
Health:
|
||||
- `GET /healthz`
|
||||
- `HEAD /healthz`
|
||||
|
||||
## Limitations
|
||||
|
||||
- No authentication/authorization yet.
|
||||
- Not full S3 API coverage.
|
||||
- No versioning or lifecycle policies.
|
||||
- Error and edge-case behavior is still being refined for client compatibility.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"fs/auth"
|
||||
"fs/logging"
|
||||
"fs/metadata"
|
||||
"fs/models"
|
||||
@@ -30,6 +31,7 @@ type Handler struct {
|
||||
svc *service.ObjectService
|
||||
logger *slog.Logger
|
||||
logConfig logging.Config
|
||||
authSvc *auth.Service
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -44,7 +46,7 @@ const (
|
||||
serverMaxConnections = 1024
|
||||
)
|
||||
|
||||
func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig logging.Config) *Handler {
|
||||
func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig logging.Config, authSvc *auth.Service) *Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.Recoverer)
|
||||
@@ -57,12 +59,14 @@ func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig loggi
|
||||
svc: svc,
|
||||
logger: logger,
|
||||
logConfig: logConfig,
|
||||
authSvc: authSvc,
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Handler) setupRoutes() {
|
||||
h.router.Use(logging.HTTPMiddleware(h.logger, h.logConfig))
|
||||
h.router.Use(auth.Middleware(h.authSvc, h.logger, h.logConfig.Audit, writeMappedS3Error))
|
||||
|
||||
h.router.Get("/healthz", h.handleHealth)
|
||||
h.router.Head("/healthz", h.handleHealth)
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fs/auth"
|
||||
"fs/metadata"
|
||||
"fs/models"
|
||||
"fs/service"
|
||||
@@ -73,6 +74,41 @@ var (
|
||||
Code: "MalformedXML",
|
||||
Message: "The request must contain no more than 1000 object identifiers.",
|
||||
}
|
||||
s3ErrAccessDenied = s3APIError{
|
||||
Status: http.StatusForbidden,
|
||||
Code: "AccessDenied",
|
||||
Message: "Access Denied.",
|
||||
}
|
||||
s3ErrInvalidAccessKeyID = s3APIError{
|
||||
Status: http.StatusForbidden,
|
||||
Code: "InvalidAccessKeyId",
|
||||
Message: "The AWS Access Key Id you provided does not exist in our records.",
|
||||
}
|
||||
s3ErrSignatureDoesNotMatch = s3APIError{
|
||||
Status: http.StatusForbidden,
|
||||
Code: "SignatureDoesNotMatch",
|
||||
Message: "The request signature we calculated does not match the signature you provided.",
|
||||
}
|
||||
s3ErrAuthorizationHeaderMalformed = s3APIError{
|
||||
Status: http.StatusBadRequest,
|
||||
Code: "AuthorizationHeaderMalformed",
|
||||
Message: "The authorization header is malformed; the region/service/date is wrong or missing.",
|
||||
}
|
||||
s3ErrRequestTimeTooSkewed = s3APIError{
|
||||
Status: http.StatusForbidden,
|
||||
Code: "RequestTimeTooSkewed",
|
||||
Message: "The difference between the request time and the server's time is too large.",
|
||||
}
|
||||
s3ErrExpiredToken = s3APIError{
|
||||
Status: http.StatusBadRequest,
|
||||
Code: "ExpiredToken",
|
||||
Message: "The provided token has expired.",
|
||||
}
|
||||
s3ErrInvalidPresign = s3APIError{
|
||||
Status: http.StatusBadRequest,
|
||||
Code: "AuthorizationQueryParametersError",
|
||||
Message: "Error parsing the X-Amz-Credential parameter.",
|
||||
}
|
||||
s3ErrInternal = s3APIError{
|
||||
Status: http.StatusInternalServerError,
|
||||
Code: "InternalError",
|
||||
@@ -132,6 +168,26 @@ func mapToS3Error(err error) s3APIError {
|
||||
return s3ErrMalformedXML
|
||||
case errors.Is(err, service.ErrEntityTooSmall):
|
||||
return s3ErrEntityTooSmall
|
||||
case errors.Is(err, auth.ErrAccessDenied):
|
||||
return s3ErrAccessDenied
|
||||
case errors.Is(err, auth.ErrInvalidAccessKeyID):
|
||||
return s3ErrInvalidAccessKeyID
|
||||
case errors.Is(err, auth.ErrSignatureDoesNotMatch):
|
||||
return s3ErrSignatureDoesNotMatch
|
||||
case errors.Is(err, auth.ErrAuthorizationHeaderMalformed):
|
||||
return s3ErrAuthorizationHeaderMalformed
|
||||
case errors.Is(err, auth.ErrRequestTimeTooSkewed):
|
||||
return s3ErrRequestTimeTooSkewed
|
||||
case errors.Is(err, auth.ErrExpiredToken):
|
||||
return s3ErrExpiredToken
|
||||
case errors.Is(err, auth.ErrCredentialDisabled):
|
||||
return s3ErrAccessDenied
|
||||
case errors.Is(err, auth.ErrNoAuthCredentials):
|
||||
return s3ErrAccessDenied
|
||||
case errors.Is(err, auth.ErrUnsupportedAuthScheme):
|
||||
return s3ErrAuthorizationHeaderMalformed
|
||||
case errors.Is(err, auth.ErrInvalidPresign):
|
||||
return s3ErrInvalidPresign
|
||||
default:
|
||||
return s3ErrInternal
|
||||
}
|
||||
|
||||
150
auth/README.md
Normal file
150
auth/README.md
Normal file
@@ -0,0 +1,150 @@
|
||||
# Authentication Design
|
||||
|
||||
This folder implements S3-compatible request authentication using AWS Signature Version 4 (SigV4), with local identity and policy data stored in bbolt.
|
||||
|
||||
## Goals
|
||||
- Keep S3 client compatibility for request signing.
|
||||
- Avoid external auth databases.
|
||||
- Store secrets encrypted at rest (not plaintext in bbolt).
|
||||
- Keep authorization simple and explicit.
|
||||
|
||||
## High-Level Architecture
|
||||
- `auth/middleware.go`
|
||||
- HTTP middleware that enforces auth before API handlers.
|
||||
- Exempts `/healthz`.
|
||||
- Calls auth service and writes mapped S3 XML errors on failure.
|
||||
- `auth/service.go`
|
||||
- Main auth orchestration:
|
||||
- parse SigV4 from request
|
||||
- validate timestamp/scope/service/region
|
||||
- load identity from metadata
|
||||
- decrypt secret
|
||||
- verify signature
|
||||
- evaluate policy against requested S3 action
|
||||
- `auth/sigv4.go`
|
||||
- Canonical SigV4 parsing and verification helpers.
|
||||
- Supports header auth and presigned query auth.
|
||||
- `auth/policy.go`
|
||||
- Authorization evaluator (deny overrides allow).
|
||||
- `auth/action.go`
|
||||
- Maps HTTP method/path/query to logical S3 action + resource target.
|
||||
- `auth/crypto.go`
|
||||
- AES-256-GCM encryption/decryption for stored secret keys.
|
||||
- `auth/context.go`
|
||||
- Carries authentication result in request context for downstream logic.
|
||||
- `auth/config.go`
|
||||
- Normalized auth configuration.
|
||||
- `auth/errors.go`
|
||||
- Domain auth errors used by API S3 error mapping.
|
||||
|
||||
## Config Model
|
||||
Auth is configured through env (read in `utils/config.go`, converted in `auth/config.go`):
|
||||
|
||||
- `AUTH_ENABLED`
|
||||
- `AUTH_REGION`
|
||||
- `AUTH_SKEW_SECONDS`
|
||||
- `AUTH_MAX_PRESIGN_SECONDS`
|
||||
- `AUTH_MASTER_KEY`
|
||||
- `AUTH_BOOTSTRAP_ACCESS_KEY`
|
||||
- `AUTH_BOOTSTRAP_SECRET_KEY`
|
||||
- `AUTH_BOOTSTRAP_POLICY` (optional JSON)
|
||||
|
||||
Important:
|
||||
- If `AUTH_ENABLED=true`, `AUTH_MASTER_KEY` is required.
|
||||
- `AUTH_MASTER_KEY` must be base64 that decodes to exactly 32 bytes (AES-256 key).
|
||||
|
||||
## Persistence Model (bbolt)
|
||||
Implemented in metadata layer:
|
||||
- `__AUTH_IDENTITIES__` bucket stores `models.AuthIdentity`
|
||||
- `access_key_id`
|
||||
- encrypted secret (`secret_enc`, `secret_nonce`)
|
||||
- status (`active`/disabled)
|
||||
- timestamps
|
||||
- `__AUTH_POLICIES__` bucket stores `models.AuthPolicy`
|
||||
- `principal`
|
||||
- statements (`effect`, `actions`, `bucket`, `prefix`)
|
||||
|
||||
## Bootstrap Identity
|
||||
On startup (`main.go`):
|
||||
1. Build auth config.
|
||||
2. Create auth service with metadata store.
|
||||
3. Call `EnsureBootstrap()`.
|
||||
|
||||
If bootstrap env key/secret are set:
|
||||
- identity is created/updated
|
||||
- secret is encrypted with AES-GCM and stored
|
||||
- policy is created:
|
||||
- default: full access (`s3:*`, `bucket=*`, `prefix=*`)
|
||||
- or overridden by `AUTH_BOOTSTRAP_POLICY`
|
||||
|
||||
## Request Authentication Flow
|
||||
For each non-health request:
|
||||
1. Parse SigV4 input (header or presigned query).
|
||||
2. Validate structural fields:
|
||||
- algorithm
|
||||
- credential scope
|
||||
- service must be `s3`
|
||||
- region must match config
|
||||
3. Validate time:
|
||||
- `x-amz-date` format
|
||||
- skew within `AUTH_SKEW_SECONDS`
|
||||
- presigned expiry within `AUTH_MAX_PRESIGN_SECONDS`
|
||||
4. Load identity by access key id.
|
||||
5. Ensure identity status is active.
|
||||
6. Decrypt stored secret using master key.
|
||||
7. Recompute canonical request and expected signature.
|
||||
8. Compare signatures.
|
||||
9. Resolve target action from request.
|
||||
10. Evaluate policy; deny overrides allow.
|
||||
11. Store auth result in request context and continue.
|
||||
|
||||
## Authorization Semantics
|
||||
Policy evaluator rules:
|
||||
- No matching allow => denied.
|
||||
- Any matching deny => denied (even if allow also matches).
|
||||
- Wildcards supported:
|
||||
- action: `*` or `s3:*`
|
||||
- bucket: `*`
|
||||
- prefix: `*`
|
||||
|
||||
Action resolution includes:
|
||||
- bucket APIs (`CreateBucket`, `ListBucket`, `HeadBucket`, `DeleteBucket`)
|
||||
- object APIs (`GetObject`, `PutObject`, `DeleteObject`)
|
||||
- multipart APIs (`CreateMultipartUpload`, `UploadPart`, `ListMultipartUploadParts`, `CompleteMultipartUpload`, `AbortMultipartUpload`)
|
||||
|
||||
## Error Behavior
|
||||
Auth errors are mapped to S3-style XML errors in `api/s3_errors.go`, including:
|
||||
- `AccessDenied`
|
||||
- `InvalidAccessKeyId`
|
||||
- `SignatureDoesNotMatch`
|
||||
- `AuthorizationHeaderMalformed`
|
||||
- `RequestTimeTooSkewed`
|
||||
- `ExpiredToken`
|
||||
- `AuthorizationQueryParametersError`
|
||||
|
||||
## Audit Logging
|
||||
When `AUDIT_LOG=true` and auth is enabled:
|
||||
- successful auth attempts emit `auth_success`
|
||||
- failed auth attempts emit `auth_failed`
|
||||
|
||||
Each audit entry includes method, path, remote IP, and request ID (if present). Success logs also include access key ID and auth type.
|
||||
|
||||
## Security Notes
|
||||
- Secret keys are recoverable by server design (required for SigV4 verification).
|
||||
- They are encrypted at rest, not hashed.
|
||||
- Master key rotation is not implemented yet.
|
||||
- Keep `AUTH_MASTER_KEY` protected (secret manager/systemd env file/etc.).
|
||||
|
||||
## Current Scope / Limitations
|
||||
- No STS/session-token auth yet.
|
||||
- No admin API for managing multiple users yet.
|
||||
- Policy language is intentionally minimal, not full IAM.
|
||||
- No automatic key rotation workflows.
|
||||
|
||||
## Practical Next Step
|
||||
To support multiple users cleanly, add admin operations in auth service + API:
|
||||
- create user
|
||||
- rotate secret
|
||||
- set policy
|
||||
- disable/enable
|
||||
- delete user
|
||||
93
auth/action.go
Normal file
93
auth/action.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Action string
|
||||
|
||||
const (
|
||||
ActionListAllMyBuckets Action = "s3:ListAllMyBuckets"
|
||||
ActionCreateBucket Action = "s3:CreateBucket"
|
||||
ActionHeadBucket Action = "s3:HeadBucket"
|
||||
ActionDeleteBucket Action = "s3:DeleteBucket"
|
||||
ActionListBucket Action = "s3:ListBucket"
|
||||
ActionGetObject Action = "s3:GetObject"
|
||||
ActionPutObject Action = "s3:PutObject"
|
||||
ActionDeleteObject Action = "s3:DeleteObject"
|
||||
ActionCreateMultipartUpload Action = "s3:CreateMultipartUpload"
|
||||
ActionUploadPart Action = "s3:UploadPart"
|
||||
ActionListMultipartParts Action = "s3:ListMultipartUploadParts"
|
||||
ActionCompleteMultipart Action = "s3:CompleteMultipartUpload"
|
||||
ActionAbortMultipartUpload Action = "s3:AbortMultipartUpload"
|
||||
)
|
||||
|
||||
type RequestTarget struct {
|
||||
Action Action
|
||||
Bucket string
|
||||
Key string
|
||||
}
|
||||
|
||||
func resolveTarget(r *http.Request) RequestTarget {
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
if path == "" {
|
||||
return RequestTarget{Action: ActionListAllMyBuckets}
|
||||
}
|
||||
|
||||
parts := strings.SplitN(path, "/", 2)
|
||||
bucket := parts[0]
|
||||
key := ""
|
||||
if len(parts) > 1 {
|
||||
key = parts[1]
|
||||
}
|
||||
|
||||
if key == "" {
|
||||
switch r.Method {
|
||||
case http.MethodPut:
|
||||
return RequestTarget{Action: ActionCreateBucket, Bucket: bucket}
|
||||
case http.MethodHead:
|
||||
return RequestTarget{Action: ActionHeadBucket, Bucket: bucket}
|
||||
case http.MethodDelete:
|
||||
return RequestTarget{Action: ActionDeleteBucket, Bucket: bucket}
|
||||
case http.MethodGet:
|
||||
return RequestTarget{Action: ActionListBucket, Bucket: bucket}
|
||||
case http.MethodPost:
|
||||
if _, ok := r.URL.Query()["delete"]; ok {
|
||||
return RequestTarget{Action: ActionDeleteObject, Bucket: bucket}
|
||||
}
|
||||
}
|
||||
return RequestTarget{Bucket: bucket}
|
||||
}
|
||||
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
partNumber := r.URL.Query().Get("partNumber")
|
||||
if _, ok := r.URL.Query()["uploads"]; ok && r.Method == http.MethodPost {
|
||||
return RequestTarget{Action: ActionCreateMultipartUpload, Bucket: bucket, Key: key}
|
||||
}
|
||||
if uploadID != "" {
|
||||
switch r.Method {
|
||||
case http.MethodPut:
|
||||
if partNumber != "" {
|
||||
return RequestTarget{Action: ActionUploadPart, Bucket: bucket, Key: key}
|
||||
}
|
||||
case http.MethodGet:
|
||||
return RequestTarget{Action: ActionListMultipartParts, Bucket: bucket, Key: key}
|
||||
case http.MethodPost:
|
||||
return RequestTarget{Action: ActionCompleteMultipart, Bucket: bucket, Key: key}
|
||||
case http.MethodDelete:
|
||||
return RequestTarget{Action: ActionAbortMultipartUpload, Bucket: bucket, Key: key}
|
||||
}
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet, http.MethodHead:
|
||||
return RequestTarget{Action: ActionGetObject, Bucket: bucket, Key: key}
|
||||
case http.MethodPut:
|
||||
return RequestTarget{Action: ActionPutObject, Bucket: bucket, Key: key}
|
||||
case http.MethodDelete:
|
||||
return RequestTarget{Action: ActionDeleteObject, Bucket: bucket, Key: key}
|
||||
}
|
||||
|
||||
return RequestTarget{Bucket: bucket, Key: key}
|
||||
}
|
||||
50
auth/config.go
Normal file
50
auth/config.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
Region string
|
||||
ClockSkew time.Duration
|
||||
MaxPresignDuration time.Duration
|
||||
MasterKey string
|
||||
BootstrapAccessKey string
|
||||
BootstrapSecretKey string
|
||||
BootstrapPolicy string
|
||||
}
|
||||
|
||||
func ConfigFromValues(
|
||||
enabled bool,
|
||||
region string,
|
||||
skew time.Duration,
|
||||
maxPresign time.Duration,
|
||||
masterKey string,
|
||||
bootstrapAccessKey string,
|
||||
bootstrapSecretKey string,
|
||||
bootstrapPolicy string,
|
||||
) Config {
|
||||
region = strings.TrimSpace(region)
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
if skew <= 0 {
|
||||
skew = 5 * time.Minute
|
||||
}
|
||||
if maxPresign <= 0 {
|
||||
maxPresign = 24 * time.Hour
|
||||
}
|
||||
|
||||
return Config{
|
||||
Enabled: enabled,
|
||||
Region: region,
|
||||
ClockSkew: skew,
|
||||
MaxPresignDuration: maxPresign,
|
||||
MasterKey: strings.TrimSpace(masterKey),
|
||||
BootstrapAccessKey: strings.TrimSpace(bootstrapAccessKey),
|
||||
BootstrapSecretKey: strings.TrimSpace(bootstrapSecretKey),
|
||||
BootstrapPolicy: strings.TrimSpace(bootstrapPolicy),
|
||||
}
|
||||
}
|
||||
23
auth/context.go
Normal file
23
auth/context.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package auth
|
||||
|
||||
import "context"
|
||||
|
||||
type RequestContext struct {
|
||||
Authenticated bool
|
||||
AccessKeyID string
|
||||
AuthType string
|
||||
}
|
||||
|
||||
type contextKey int
|
||||
|
||||
const requestContextKey contextKey = iota
|
||||
|
||||
func WithRequestContext(ctx context.Context, authCtx RequestContext) context.Context {
|
||||
return context.WithValue(ctx, requestContextKey, authCtx)
|
||||
}
|
||||
|
||||
func GetRequestContext(ctx context.Context) (RequestContext, bool) {
|
||||
value := ctx.Value(requestContextKey)
|
||||
authCtx, ok := value.(RequestContext)
|
||||
return authCtx, ok
|
||||
}
|
||||
74
auth/crypto.go
Normal file
74
auth/crypto.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
masterKeyLength = 32
|
||||
gcmNonceLength = 12
|
||||
)
|
||||
|
||||
func decodeMasterKey(raw string) ([]byte, error) {
|
||||
decoded, err := base64.StdEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidMasterKey, err)
|
||||
}
|
||||
if len(decoded) != masterKeyLength {
|
||||
return nil, fmt.Errorf("%w: expected %d-byte decoded key", ErrInvalidMasterKey, masterKeyLength)
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func encryptSecret(masterKey []byte, accessKeyID, secret string) (ciphertextB64 string, nonceB64 string, err error) {
|
||||
block, err := aes.NewCipher(masterKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcmNonceLength)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nil, nonce, []byte(secret), []byte(accessKeyID))
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), base64.StdEncoding.EncodeToString(nonce), nil
|
||||
}
|
||||
|
||||
func decryptSecret(masterKey []byte, accessKeyID, ciphertextB64, nonceB64 string) (string, error) {
|
||||
block, err := aes.NewCipher(masterKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(ciphertextB64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nonce, err := base64.StdEncoding.DecodeString(nonceB64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(nonce) != gcmNonceLength {
|
||||
return "", fmt.Errorf("invalid nonce length: %d", len(nonce))
|
||||
}
|
||||
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, []byte(accessKeyID))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(plaintext), nil
|
||||
}
|
||||
19
auth/errors.go
Normal file
19
auth/errors.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package auth
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrAccessDenied = errors.New("access denied")
|
||||
ErrInvalidAccessKeyID = errors.New("invalid access key id")
|
||||
ErrSignatureDoesNotMatch = errors.New("signature does not match")
|
||||
ErrAuthorizationHeaderMalformed = errors.New("authorization header malformed")
|
||||
ErrRequestTimeTooSkewed = errors.New("request time too skewed")
|
||||
ErrExpiredToken = errors.New("expired token")
|
||||
ErrCredentialDisabled = errors.New("credential disabled")
|
||||
ErrAuthNotEnabled = errors.New("authentication is not enabled")
|
||||
ErrMasterKeyRequired = errors.New("auth master key is required")
|
||||
ErrInvalidMasterKey = errors.New("invalid auth master key")
|
||||
ErrNoAuthCredentials = errors.New("no auth credentials found")
|
||||
ErrUnsupportedAuthScheme = errors.New("unsupported auth scheme")
|
||||
ErrInvalidPresign = errors.New("invalid presigned request")
|
||||
)
|
||||
78
auth/middleware.go
Normal file
78
auth/middleware.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
func Middleware(
|
||||
svc *Service,
|
||||
logger *slog.Logger,
|
||||
auditEnabled bool,
|
||||
onError func(http.ResponseWriter, *http.Request, error),
|
||||
) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authCtx := RequestContext{Authenticated: false, AuthType: "none"}
|
||||
if svc == nil || !svc.Config().Enabled {
|
||||
next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), authCtx)))
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/healthz" {
|
||||
next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), authCtx)))
|
||||
return
|
||||
}
|
||||
|
||||
resolvedCtx, err := svc.AuthenticateRequest(r)
|
||||
if err != nil {
|
||||
if auditEnabled && logger != nil {
|
||||
requestID := middleware.GetReqID(r.Context())
|
||||
attrs := []any{
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"remote_ip", clientIP(r.RemoteAddr),
|
||||
"error", err.Error(),
|
||||
}
|
||||
if requestID != "" {
|
||||
attrs = append(attrs, "request_id", requestID)
|
||||
}
|
||||
logger.Warn("auth_failed", attrs...)
|
||||
}
|
||||
if onError != nil {
|
||||
onError(w, r, err)
|
||||
return
|
||||
}
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if auditEnabled && logger != nil {
|
||||
requestID := middleware.GetReqID(r.Context())
|
||||
attrs := []any{
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"remote_ip", clientIP(r.RemoteAddr),
|
||||
"access_key_id", resolvedCtx.AccessKeyID,
|
||||
"auth_type", resolvedCtx.AuthType,
|
||||
}
|
||||
if requestID != "" {
|
||||
attrs = append(attrs, "request_id", requestID)
|
||||
}
|
||||
logger.Info("auth_success", attrs...)
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), resolvedCtx)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func clientIP(remoteAddr string) string {
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err == nil && host != "" {
|
||||
return host
|
||||
}
|
||||
return remoteAddr
|
||||
}
|
||||
66
auth/policy.go
Normal file
66
auth/policy.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fs/models"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isAllowed(policy *models.AuthPolicy, target RequestTarget) bool {
|
||||
if policy == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
allowed := false
|
||||
for _, stmt := range policy.Statements {
|
||||
if !statementMatches(stmt, target) {
|
||||
continue
|
||||
}
|
||||
effect := strings.ToLower(strings.TrimSpace(stmt.Effect))
|
||||
if effect == "deny" {
|
||||
return false
|
||||
}
|
||||
if effect == "allow" {
|
||||
allowed = true
|
||||
}
|
||||
}
|
||||
return allowed
|
||||
}
|
||||
|
||||
func statementMatches(stmt models.AuthPolicyStatement, target RequestTarget) bool {
|
||||
if !actionMatches(stmt.Actions, target.Action) {
|
||||
return false
|
||||
}
|
||||
if !bucketMatches(stmt.Bucket, target.Bucket) {
|
||||
return false
|
||||
}
|
||||
if target.Key == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
prefix := strings.TrimSpace(stmt.Prefix)
|
||||
if prefix == "" || prefix == "*" {
|
||||
return true
|
||||
}
|
||||
return strings.HasPrefix(target.Key, prefix)
|
||||
}
|
||||
|
||||
func actionMatches(actions []string, action Action) bool {
|
||||
if len(actions) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, current := range actions {
|
||||
normalized := strings.TrimSpace(current)
|
||||
if normalized == "*" || normalized == "s3:*" || strings.EqualFold(normalized, string(action)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func bucketMatches(pattern, bucket string) bool {
|
||||
pattern = strings.TrimSpace(pattern)
|
||||
if pattern == "" || pattern == "*" {
|
||||
return true
|
||||
}
|
||||
return pattern == bucket
|
||||
}
|
||||
186
auth/service.go
Normal file
186
auth/service.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"fs/models"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
GetAuthIdentity(accessKeyID string) (*models.AuthIdentity, error)
|
||||
PutAuthIdentity(identity *models.AuthIdentity) error
|
||||
GetAuthPolicy(accessKeyID string) (*models.AuthPolicy, error)
|
||||
PutAuthPolicy(policy *models.AuthPolicy) error
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
cfg Config
|
||||
store Store
|
||||
masterKey []byte
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func NewService(cfg Config, store Store) (*Service, error) {
|
||||
if store == nil {
|
||||
return nil, errors.New("auth store is required")
|
||||
}
|
||||
|
||||
svc := &Service{
|
||||
cfg: cfg,
|
||||
store: store,
|
||||
now: func() time.Time { return time.Now().UTC() },
|
||||
}
|
||||
if !cfg.Enabled {
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
if strings.TrimSpace(cfg.MasterKey) == "" {
|
||||
return nil, ErrMasterKeyRequired
|
||||
}
|
||||
masterKey, err := decodeMasterKey(cfg.MasterKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
svc.masterKey = masterKey
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (s *Service) Config() Config {
|
||||
return s.cfg
|
||||
}
|
||||
|
||||
func (s *Service) EnsureBootstrap() error {
|
||||
if !s.cfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
accessKey := strings.TrimSpace(s.cfg.BootstrapAccessKey)
|
||||
secret := strings.TrimSpace(s.cfg.BootstrapSecretKey)
|
||||
if accessKey == "" || secret == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(accessKey) < 3 {
|
||||
return errors.New("bootstrap access key must be at least 3 characters")
|
||||
}
|
||||
if len(secret) < 8 {
|
||||
return errors.New("bootstrap secret key must be at least 8 characters")
|
||||
}
|
||||
|
||||
now := s.now().Unix()
|
||||
ciphertext, nonce, err := encryptSecret(s.masterKey, accessKey, secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
identity := &models.AuthIdentity{
|
||||
AccessKeyID: accessKey,
|
||||
SecretEnc: ciphertext,
|
||||
SecretNonce: nonce,
|
||||
EncAlg: "AES-256-GCM",
|
||||
KeyVersion: "v1",
|
||||
Status: "active",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if existing, err := s.store.GetAuthIdentity(accessKey); err == nil && existing != nil {
|
||||
identity.CreatedAt = existing.CreatedAt
|
||||
}
|
||||
if err := s.store.PutAuthIdentity(identity); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy := defaultBootstrapPolicy(accessKey)
|
||||
if strings.TrimSpace(s.cfg.BootstrapPolicy) != "" {
|
||||
parsed, err := parsePolicyJSON(s.cfg.BootstrapPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
policy = parsed
|
||||
policy.Principal = accessKey
|
||||
}
|
||||
return s.store.PutAuthPolicy(policy)
|
||||
}
|
||||
|
||||
func (s *Service) AuthenticateRequest(r *http.Request) (RequestContext, error) {
|
||||
if !s.cfg.Enabled {
|
||||
return RequestContext{Authenticated: false, AuthType: "disabled"}, nil
|
||||
}
|
||||
input, err := parseSigV4(r)
|
||||
if err != nil {
|
||||
return RequestContext{}, err
|
||||
}
|
||||
|
||||
if err := validateSigV4Input(s.now(), s.cfg, input); err != nil {
|
||||
return RequestContext{}, err
|
||||
}
|
||||
|
||||
identity, err := s.store.GetAuthIdentity(input.AccessKeyID)
|
||||
if err != nil {
|
||||
return RequestContext{}, ErrInvalidAccessKeyID
|
||||
}
|
||||
if !strings.EqualFold(identity.Status, "active") {
|
||||
return RequestContext{}, ErrCredentialDisabled
|
||||
}
|
||||
|
||||
secret, err := decryptSecret(s.masterKey, identity.AccessKeyID, identity.SecretEnc, identity.SecretNonce)
|
||||
if err != nil {
|
||||
return RequestContext{}, ErrSignatureDoesNotMatch
|
||||
}
|
||||
ok, err := signatureMatches(secret, r, input)
|
||||
if err != nil {
|
||||
return RequestContext{}, err
|
||||
}
|
||||
if !ok {
|
||||
return RequestContext{}, ErrSignatureDoesNotMatch
|
||||
}
|
||||
|
||||
policy, err := s.store.GetAuthPolicy(identity.AccessKeyID)
|
||||
if err != nil {
|
||||
return RequestContext{}, ErrAccessDenied
|
||||
}
|
||||
target := resolveTarget(r)
|
||||
if target.Action == "" {
|
||||
return RequestContext{}, ErrAccessDenied
|
||||
}
|
||||
if !isAllowed(policy, target) {
|
||||
return RequestContext{}, ErrAccessDenied
|
||||
}
|
||||
|
||||
authType := "sigv4-header"
|
||||
if input.Presigned {
|
||||
authType = "sigv4-presign"
|
||||
}
|
||||
return RequestContext{
|
||||
Authenticated: true,
|
||||
AccessKeyID: identity.AccessKeyID,
|
||||
AuthType: authType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parsePolicyJSON(raw string) (*models.AuthPolicy, error) {
|
||||
policy := models.AuthPolicy{}
|
||||
if err := json.Unmarshal([]byte(raw), &policy); err != nil {
|
||||
return nil, fmt.Errorf("invalid bootstrap policy: %w", err)
|
||||
}
|
||||
if len(policy.Statements) == 0 {
|
||||
return nil, errors.New("bootstrap policy must contain at least one statement")
|
||||
}
|
||||
return &policy, nil
|
||||
}
|
||||
|
||||
func defaultBootstrapPolicy(principal string) *models.AuthPolicy {
|
||||
return &models.AuthPolicy{
|
||||
Principal: principal,
|
||||
Statements: []models.AuthPolicyStatement{
|
||||
{
|
||||
Effect: "allow",
|
||||
Actions: []string{"s3:*"},
|
||||
Bucket: "*",
|
||||
Prefix: "*",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
372
auth/sigv4.go
Normal file
372
auth/sigv4.go
Normal file
@@ -0,0 +1,372 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
sigV4Algorithm = "AWS4-HMAC-SHA256"
|
||||
)
|
||||
|
||||
type sigV4Input struct {
|
||||
AccessKeyID string
|
||||
Date string
|
||||
Region string
|
||||
Service string
|
||||
Scope string
|
||||
SignedHeaders []string
|
||||
SignedHeadersRaw string
|
||||
SignatureHex string
|
||||
AmzDate string
|
||||
ExpiresSeconds int
|
||||
Presigned bool
|
||||
}
|
||||
|
||||
func parseSigV4(r *http.Request) (*sigV4Input, error) {
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("%w: nil request", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
if strings.EqualFold(r.URL.Query().Get("X-Amz-Algorithm"), sigV4Algorithm) {
|
||||
return parsePresignedSigV4(r)
|
||||
}
|
||||
return parseHeaderSigV4(r)
|
||||
}
|
||||
|
||||
func parseHeaderSigV4(r *http.Request) (*sigV4Input, error) {
|
||||
header := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||
if header == "" {
|
||||
return nil, ErrNoAuthCredentials
|
||||
}
|
||||
if !strings.HasPrefix(header, sigV4Algorithm+" ") {
|
||||
return nil, fmt.Errorf("%w: unsupported authorization algorithm", ErrUnsupportedAuthScheme)
|
||||
}
|
||||
|
||||
params := parseAuthorizationParams(strings.TrimSpace(strings.TrimPrefix(header, sigV4Algorithm)))
|
||||
credentialRaw := params["Credential"]
|
||||
signedHeadersRaw := params["SignedHeaders"]
|
||||
signatureHex := params["Signature"]
|
||||
if credentialRaw == "" || signedHeadersRaw == "" || signatureHex == "" {
|
||||
return nil, fmt.Errorf("%w: missing required authorization fields", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
|
||||
accessKeyID, date, region, service, scope, err := parseCredential(credentialRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
amzDate := strings.TrimSpace(r.Header.Get("x-amz-date"))
|
||||
if amzDate == "" {
|
||||
return nil, fmt.Errorf("%w: x-amz-date is required", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
signedHeaders := splitSignedHeaders(signedHeadersRaw)
|
||||
if len(signedHeaders) == 0 {
|
||||
return nil, fmt.Errorf("%w: signed headers are required", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
|
||||
return &sigV4Input{
|
||||
AccessKeyID: accessKeyID,
|
||||
Date: date,
|
||||
Region: region,
|
||||
Service: service,
|
||||
Scope: scope,
|
||||
SignedHeaders: signedHeaders,
|
||||
SignedHeadersRaw: strings.ToLower(strings.TrimSpace(signedHeadersRaw)),
|
||||
SignatureHex: strings.ToLower(strings.TrimSpace(signatureHex)),
|
||||
AmzDate: amzDate,
|
||||
Presigned: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parsePresignedSigV4(r *http.Request) (*sigV4Input, error) {
|
||||
query := r.URL.Query()
|
||||
if !strings.EqualFold(query.Get("X-Amz-Algorithm"), sigV4Algorithm) {
|
||||
return nil, fmt.Errorf("%w: invalid X-Amz-Algorithm", ErrInvalidPresign)
|
||||
}
|
||||
|
||||
credentialRaw := strings.TrimSpace(query.Get("X-Amz-Credential"))
|
||||
signedHeadersRaw := strings.TrimSpace(query.Get("X-Amz-SignedHeaders"))
|
||||
signatureHex := strings.TrimSpace(query.Get("X-Amz-Signature"))
|
||||
amzDate := strings.TrimSpace(query.Get("X-Amz-Date"))
|
||||
expiresRaw := strings.TrimSpace(query.Get("X-Amz-Expires"))
|
||||
if credentialRaw == "" || signedHeadersRaw == "" || signatureHex == "" || amzDate == "" || expiresRaw == "" {
|
||||
return nil, fmt.Errorf("%w: missing presigned query fields", ErrInvalidPresign)
|
||||
}
|
||||
expires, err := strconv.Atoi(expiresRaw)
|
||||
if err != nil || expires < 0 {
|
||||
return nil, fmt.Errorf("%w: invalid X-Amz-Expires", ErrInvalidPresign)
|
||||
}
|
||||
|
||||
accessKeyID, date, region, service, scope, err := parseCredential(credentialRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signedHeaders := splitSignedHeaders(signedHeadersRaw)
|
||||
if len(signedHeaders) == 0 {
|
||||
return nil, fmt.Errorf("%w: signed headers are required", ErrInvalidPresign)
|
||||
}
|
||||
|
||||
return &sigV4Input{
|
||||
AccessKeyID: accessKeyID,
|
||||
Date: date,
|
||||
Region: region,
|
||||
Service: service,
|
||||
Scope: scope,
|
||||
SignedHeaders: signedHeaders,
|
||||
SignedHeadersRaw: strings.ToLower(strings.TrimSpace(signedHeadersRaw)),
|
||||
SignatureHex: strings.ToLower(signatureHex),
|
||||
AmzDate: amzDate,
|
||||
ExpiresSeconds: expires,
|
||||
Presigned: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseCredential(raw string) (accessKeyID string, date string, region string, service string, scope string, err error) {
|
||||
parts := strings.Split(strings.TrimSpace(raw), "/")
|
||||
if len(parts) != 5 {
|
||||
return "", "", "", "", "", fmt.Errorf("%w: invalid credential scope", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
accessKeyID = strings.TrimSpace(parts[0])
|
||||
date = strings.TrimSpace(parts[1])
|
||||
region = strings.TrimSpace(parts[2])
|
||||
service = strings.TrimSpace(parts[3])
|
||||
terminal := strings.TrimSpace(parts[4])
|
||||
if accessKeyID == "" || date == "" || region == "" || service == "" || terminal != "aws4_request" {
|
||||
return "", "", "", "", "", fmt.Errorf("%w: invalid credential scope", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
scope = strings.Join(parts[1:], "/")
|
||||
return accessKeyID, date, region, service, scope, nil
|
||||
}
|
||||
|
||||
func splitSignedHeaders(raw string) []string {
|
||||
raw = strings.ToLower(strings.TrimSpace(raw))
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(raw, ";")
|
||||
headers := make([]string, 0, len(parts))
|
||||
for _, current := range parts {
|
||||
current = strings.TrimSpace(current)
|
||||
if current == "" {
|
||||
continue
|
||||
}
|
||||
headers = append(headers, current)
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func parseAuthorizationParams(raw string) map[string]string {
|
||||
params := make(map[string]string)
|
||||
raw = strings.TrimSpace(raw)
|
||||
raw = strings.TrimPrefix(raw, " ")
|
||||
for _, token := range strings.Split(raw, ",") {
|
||||
token = strings.TrimSpace(token)
|
||||
key, value, found := strings.Cut(token, "=")
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
params[strings.TrimSpace(key)] = strings.TrimSpace(value)
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func validateSigV4Input(now time.Time, cfg Config, input *sigV4Input) error {
|
||||
if input == nil {
|
||||
return fmt.Errorf("%w: empty signature input", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
if !strings.EqualFold(input.Service, "s3") {
|
||||
return fmt.Errorf("%w: unsupported service", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
if !strings.EqualFold(input.Region, cfg.Region) {
|
||||
return fmt.Errorf("%w: region mismatch", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
|
||||
requestTime, err := time.Parse("20060102T150405Z", input.AmzDate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: invalid x-amz-date", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
delta := now.Sub(requestTime)
|
||||
if delta > cfg.ClockSkew || delta < -cfg.ClockSkew {
|
||||
return ErrRequestTimeTooSkewed
|
||||
}
|
||||
|
||||
if input.Presigned {
|
||||
if input.ExpiresSeconds > int(cfg.MaxPresignDuration.Seconds()) {
|
||||
return fmt.Errorf("%w: presign expires too large", ErrInvalidPresign)
|
||||
}
|
||||
expiresAt := requestTime.Add(time.Duration(input.ExpiresSeconds) * time.Second)
|
||||
if now.After(expiresAt) {
|
||||
return ErrExpiredToken
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func signatureMatches(secret string, r *http.Request, input *sigV4Input) (bool, error) {
|
||||
payloadHash := resolvePayloadHash(r, input.Presigned)
|
||||
canonicalRequest, err := buildCanonicalRequest(r, input.SignedHeaders, payloadHash, input.Presigned)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
stringToSign := buildStringToSign(input.AmzDate, input.Scope, canonicalRequest)
|
||||
signingKey := deriveSigningKey(secret, input.Date, input.Region, input.Service)
|
||||
expectedSig := hex.EncodeToString(hmacSHA256(signingKey, stringToSign))
|
||||
return hmac.Equal([]byte(expectedSig), []byte(input.SignatureHex)), nil
|
||||
}
|
||||
|
||||
func resolvePayloadHash(r *http.Request, presigned bool) string {
|
||||
if presigned {
|
||||
return "UNSIGNED-PAYLOAD"
|
||||
}
|
||||
hash := strings.TrimSpace(r.Header.Get("x-amz-content-sha256"))
|
||||
if hash == "" {
|
||||
return "UNSIGNED-PAYLOAD"
|
||||
}
|
||||
return hash
|
||||
}
|
||||
|
||||
func buildCanonicalRequest(r *http.Request, signedHeaders []string, payloadHash string, presigned bool) (string, error) {
|
||||
canonicalURI := canonicalPath(r.URL)
|
||||
canonicalQuery := canonicalQueryString(r.URL.RawQuery, presigned)
|
||||
canonicalHeaders, signedHeadersRaw, err := canonicalHeadersForRequest(r, signedHeaders)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return strings.Join([]string{
|
||||
r.Method,
|
||||
canonicalURI,
|
||||
canonicalQuery,
|
||||
canonicalHeaders,
|
||||
signedHeadersRaw,
|
||||
payloadHash,
|
||||
}, "\n"), nil
|
||||
}
|
||||
|
||||
func canonicalPath(u *url.URL) string {
|
||||
if u == nil {
|
||||
return "/"
|
||||
}
|
||||
path := u.EscapedPath()
|
||||
if path == "" {
|
||||
return "/"
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
type queryPair struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
|
||||
func canonicalQueryString(rawQuery string, presigned bool) string {
|
||||
if rawQuery == "" {
|
||||
return ""
|
||||
}
|
||||
values, _ := url.ParseQuery(rawQuery)
|
||||
pairs := make([]queryPair, 0)
|
||||
for key, valueList := range values {
|
||||
if presigned && strings.EqualFold(key, "X-Amz-Signature") {
|
||||
continue
|
||||
}
|
||||
if len(valueList) == 0 {
|
||||
pairs = append(pairs, queryPair{Key: key, Value: ""})
|
||||
continue
|
||||
}
|
||||
for _, value := range valueList {
|
||||
pairs = append(pairs, queryPair{Key: key, Value: value})
|
||||
}
|
||||
}
|
||||
sort.Slice(pairs, func(i, j int) bool {
|
||||
if pairs[i].Key == pairs[j].Key {
|
||||
return pairs[i].Value < pairs[j].Value
|
||||
}
|
||||
return pairs[i].Key < pairs[j].Key
|
||||
})
|
||||
|
||||
encoded := make([]string, 0, len(pairs))
|
||||
for _, pair := range pairs {
|
||||
encoded = append(encoded, awsEncodeQuery(pair.Key)+"="+awsEncodeQuery(pair.Value))
|
||||
}
|
||||
return strings.Join(encoded, "&")
|
||||
}
|
||||
|
||||
func awsEncodeQuery(value string) string {
|
||||
encoded := url.QueryEscape(value)
|
||||
encoded = strings.ReplaceAll(encoded, "+", "%20")
|
||||
encoded = strings.ReplaceAll(encoded, "*", "%2A")
|
||||
encoded = strings.ReplaceAll(encoded, "%7E", "~")
|
||||
return encoded
|
||||
}
|
||||
|
||||
func canonicalHeadersForRequest(r *http.Request, signedHeaders []string) (canonical string, signedRaw string, err error) {
|
||||
if len(signedHeaders) == 0 {
|
||||
return "", "", fmt.Errorf("%w: empty signed headers", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
|
||||
normalized := make([]string, 0, len(signedHeaders))
|
||||
lines := make([]string, 0, len(signedHeaders))
|
||||
for _, headerName := range signedHeaders {
|
||||
headerName = strings.ToLower(strings.TrimSpace(headerName))
|
||||
if headerName == "" {
|
||||
continue
|
||||
}
|
||||
var value string
|
||||
if headerName == "host" {
|
||||
value = r.Host
|
||||
} else {
|
||||
values, ok := r.Header[http.CanonicalHeaderKey(headerName)]
|
||||
if !ok || len(values) == 0 {
|
||||
return "", "", fmt.Errorf("%w: missing signed header %q", ErrAuthorizationHeaderMalformed, headerName)
|
||||
}
|
||||
value = strings.Join(values, ",")
|
||||
}
|
||||
value = normalizeHeaderValue(value)
|
||||
normalized = append(normalized, headerName)
|
||||
lines = append(lines, headerName+":"+value)
|
||||
}
|
||||
|
||||
if len(lines) == 0 {
|
||||
return "", "", fmt.Errorf("%w: no valid signed headers", ErrAuthorizationHeaderMalformed)
|
||||
}
|
||||
signedRaw = strings.Join(normalized, ";")
|
||||
canonical = strings.Join(lines, "\n") + "\n"
|
||||
return canonical, signedRaw, nil
|
||||
}
|
||||
|
||||
func normalizeHeaderValue(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
parts := strings.Fields(value)
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func buildStringToSign(amzDate string, scope string, canonicalRequest string) string {
|
||||
canonicalHash := sha256.Sum256([]byte(canonicalRequest))
|
||||
return strings.Join([]string{
|
||||
sigV4Algorithm,
|
||||
amzDate,
|
||||
scope,
|
||||
hex.EncodeToString(canonicalHash[:]),
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func deriveSigningKey(secret, date, region, service string) []byte {
|
||||
kDate := hmacSHA256([]byte("AWS4"+secret), date)
|
||||
kRegion := hmacSHA256(kDate, region)
|
||||
kService := hmacSHA256(kRegion, service)
|
||||
return hmacSHA256(kService, "aws4_request")
|
||||
}
|
||||
|
||||
func hmacSHA256(key []byte, value string) []byte {
|
||||
mac := hmac.New(sha256.New, key)
|
||||
_, _ = mac.Write([]byte(value))
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
27
main.go
27
main.go
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fs/api"
|
||||
"fs/auth"
|
||||
"fs/logging"
|
||||
"fs/metadata"
|
||||
"fs/service"
|
||||
@@ -19,6 +20,16 @@ import (
|
||||
func main() {
|
||||
config := utils.NewConfig()
|
||||
logConfig := logging.ConfigFromValues(config.LogLevel, config.LogFormat, config.AuditLog)
|
||||
authConfig := auth.ConfigFromValues(
|
||||
config.AuthEnabled,
|
||||
config.AuthRegion,
|
||||
config.AuthSkew,
|
||||
config.AuthMaxPresign,
|
||||
config.AuthMasterKey,
|
||||
config.AuthBootstrapAccessKey,
|
||||
config.AuthBootstrapSecretKey,
|
||||
config.AuthBootstrapPolicy,
|
||||
)
|
||||
logger := logging.NewLogger(logConfig)
|
||||
logger.Info("boot",
|
||||
"log_level", logConfig.LevelName,
|
||||
@@ -26,6 +37,8 @@ func main() {
|
||||
"audit_log", logConfig.Audit,
|
||||
"data_path", config.DataPath,
|
||||
"multipart_retention_hours", int(config.MultipartCleanupRetention/time.Hour),
|
||||
"auth_enabled", authConfig.Enabled,
|
||||
"auth_region", authConfig.Region,
|
||||
)
|
||||
|
||||
if err := os.MkdirAll(config.DataPath, 0o755); err != nil {
|
||||
@@ -47,7 +60,19 @@ func main() {
|
||||
}
|
||||
|
||||
objectService := service.NewObjectService(metadataHandler, blobHandler, config.MultipartCleanupRetention)
|
||||
handler := api.NewHandler(objectService, logger, logConfig)
|
||||
authService, err := auth.NewService(authConfig, metadataHandler)
|
||||
if err != nil {
|
||||
_ = metadataHandler.Close()
|
||||
logger.Error("failed_to_initialize_auth_service", "error", err)
|
||||
return
|
||||
}
|
||||
if err := authService.EnsureBootstrap(); err != nil {
|
||||
_ = metadataHandler.Close()
|
||||
logger.Error("failed_to_ensure_bootstrap_auth_identity", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
handler := api.NewHandler(objectService, logger, logConfig, authService)
|
||||
addr := config.Address + ":" + strconv.Itoa(config.Port)
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
@@ -23,17 +23,21 @@ type MetadataHandler struct {
|
||||
var systemIndex = []byte("__SYSTEM_BUCKETS__")
|
||||
var multipartUploadIndex = []byte("__MULTIPART_UPLOADS__")
|
||||
var multipartUploadPartsIndex = []byte("__MULTIPART_UPLOAD_PARTS__")
|
||||
var authIdentitiesIndex = []byte("__AUTH_IDENTITIES__")
|
||||
var authPoliciesIndex = []byte("__AUTH_POLICIES__")
|
||||
|
||||
var validBucketName = regexp.MustCompile(`^[a-z0-9.-]+$`)
|
||||
|
||||
var (
|
||||
ErrInvalidBucketName = errors.New("invalid bucket name")
|
||||
ErrBucketAlreadyExists = errors.New("bucket already exists")
|
||||
ErrBucketNotFound = errors.New("bucket not found")
|
||||
ErrBucketNotEmpty = errors.New("bucket not empty")
|
||||
ErrObjectNotFound = errors.New("object not found")
|
||||
ErrMultipartNotFound = errors.New("multipart upload not found")
|
||||
ErrMultipartNotPending = errors.New("multipart upload is not pending")
|
||||
ErrInvalidBucketName = errors.New("invalid bucket name")
|
||||
ErrBucketAlreadyExists = errors.New("bucket already exists")
|
||||
ErrBucketNotFound = errors.New("bucket not found")
|
||||
ErrBucketNotEmpty = errors.New("bucket not empty")
|
||||
ErrObjectNotFound = errors.New("object not found")
|
||||
ErrMultipartNotFound = errors.New("multipart upload not found")
|
||||
ErrMultipartNotPending = errors.New("multipart upload is not pending")
|
||||
ErrAuthIdentityNotFound = errors.New("auth identity not found")
|
||||
ErrAuthPolicyNotFound = errors.New("auth policy not found")
|
||||
)
|
||||
|
||||
func NewMetadataHandler(dbPath string) (*MetadataHandler, error) {
|
||||
@@ -67,6 +71,22 @@ func NewMetadataHandler(dbPath string) (*MetadataHandler, error) {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
err = h.db.Update(func(tx *bbolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists(authIdentitiesIndex)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
err = h.db.Update(func(tx *bbolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists(authPoliciesIndex)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
@@ -99,6 +119,106 @@ func (h *MetadataHandler) Close() error {
|
||||
return h.db.Close()
|
||||
}
|
||||
|
||||
func (h *MetadataHandler) PutAuthIdentity(identity *models.AuthIdentity) error {
|
||||
if identity == nil {
|
||||
return errors.New("auth identity is required")
|
||||
}
|
||||
if strings.TrimSpace(identity.AccessKeyID) == "" {
|
||||
return errors.New("access key id is required")
|
||||
}
|
||||
return h.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(authIdentitiesIndex)
|
||||
if bucket == nil {
|
||||
return errors.New("auth identities index not found")
|
||||
}
|
||||
payload, err := json.Marshal(identity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bucket.Put([]byte(identity.AccessKeyID), payload)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *MetadataHandler) GetAuthIdentity(accessKeyID string) (*models.AuthIdentity, error) {
|
||||
accessKeyID = strings.TrimSpace(accessKeyID)
|
||||
if accessKeyID == "" {
|
||||
return nil, errors.New("access key id is required")
|
||||
}
|
||||
|
||||
var identity *models.AuthIdentity
|
||||
err := h.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(authIdentitiesIndex)
|
||||
if bucket == nil {
|
||||
return errors.New("auth identities index not found")
|
||||
}
|
||||
payload := bucket.Get([]byte(accessKeyID))
|
||||
if payload == nil {
|
||||
return fmt.Errorf("%w: %s", ErrAuthIdentityNotFound, accessKeyID)
|
||||
}
|
||||
record := models.AuthIdentity{}
|
||||
if err := json.Unmarshal(payload, &record); err != nil {
|
||||
return err
|
||||
}
|
||||
identity = &record
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
func (h *MetadataHandler) PutAuthPolicy(policy *models.AuthPolicy) error {
|
||||
if policy == nil {
|
||||
return errors.New("auth policy is required")
|
||||
}
|
||||
principal := strings.TrimSpace(policy.Principal)
|
||||
if principal == "" {
|
||||
return errors.New("auth policy principal is required")
|
||||
}
|
||||
policy.Principal = principal
|
||||
return h.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(authPoliciesIndex)
|
||||
if bucket == nil {
|
||||
return errors.New("auth policies index not found")
|
||||
}
|
||||
payload, err := json.Marshal(policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bucket.Put([]byte(principal), payload)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *MetadataHandler) GetAuthPolicy(accessKeyID string) (*models.AuthPolicy, error) {
|
||||
accessKeyID = strings.TrimSpace(accessKeyID)
|
||||
if accessKeyID == "" {
|
||||
return nil, errors.New("access key id is required")
|
||||
}
|
||||
|
||||
var policy *models.AuthPolicy
|
||||
err := h.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(authPoliciesIndex)
|
||||
if bucket == nil {
|
||||
return errors.New("auth policies index not found")
|
||||
}
|
||||
payload := bucket.Get([]byte(accessKeyID))
|
||||
if payload == nil {
|
||||
return fmt.Errorf("%w: %s", ErrAuthPolicyNotFound, accessKeyID)
|
||||
}
|
||||
record := models.AuthPolicy{}
|
||||
if err := json.Unmarshal(payload, &record); err != nil {
|
||||
return err
|
||||
}
|
||||
policy = &record
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (h *MetadataHandler) CreateBucket(bucketName string) error {
|
||||
if !isValidBucketName(bucketName) {
|
||||
return fmt.Errorf("%w: %s", ErrInvalidBucketName, bucketName)
|
||||
|
||||
@@ -183,3 +183,26 @@ type DeleteError struct {
|
||||
Code string `xml:"Code"`
|
||||
Message string `xml:"Message"`
|
||||
}
|
||||
|
||||
type AuthIdentity struct {
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretEnc string `json:"secret_enc"`
|
||||
SecretNonce string `json:"secret_nonce"`
|
||||
EncAlg string `json:"enc_alg"`
|
||||
KeyVersion string `json:"key_version"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
type AuthPolicy struct {
|
||||
Principal string `json:"principal"`
|
||||
Statements []AuthPolicyStatement `json:"statements"`
|
||||
}
|
||||
|
||||
type AuthPolicyStatement struct {
|
||||
Effect string `json:"effect"`
|
||||
Actions []string `json:"actions"`
|
||||
Bucket string `json:"bucket"`
|
||||
Prefix string `json:"prefix"`
|
||||
}
|
||||
|
||||
@@ -21,6 +21,14 @@ type Config struct {
|
||||
GcInterval time.Duration
|
||||
GcEnabled bool
|
||||
MultipartCleanupRetention time.Duration
|
||||
AuthEnabled bool
|
||||
AuthRegion string
|
||||
AuthSkew time.Duration
|
||||
AuthMaxPresign time.Duration
|
||||
AuthMasterKey string
|
||||
AuthBootstrapAccessKey string
|
||||
AuthBootstrapSecretKey string
|
||||
AuthBootstrapPolicy string
|
||||
}
|
||||
|
||||
func NewConfig() *Config {
|
||||
@@ -39,6 +47,14 @@ func NewConfig() *Config {
|
||||
MultipartCleanupRetention: time.Duration(
|
||||
envIntRange("MULTIPART_RETENTION_HOURS", 24, 1, 24*30),
|
||||
) * time.Hour,
|
||||
AuthEnabled: envBool("AUTH_ENABLED", true),
|
||||
AuthRegion: firstNonEmpty(strings.TrimSpace(os.Getenv("AUTH_REGION")), "us-east-1"),
|
||||
AuthSkew: time.Duration(envIntRange("AUTH_SKEW_SECONDS", 300, 30, 3600)) * time.Second,
|
||||
AuthMaxPresign: time.Duration(envIntRange("AUTH_MAX_PRESIGN_SECONDS", 86400, 60, 86400)) * time.Second,
|
||||
AuthMasterKey: strings.TrimSpace(os.Getenv("AUTH_MASTER_KEY")),
|
||||
AuthBootstrapAccessKey: strings.TrimSpace(os.Getenv("AUTH_BOOTSTRAP_ACCESS_KEY")),
|
||||
AuthBootstrapSecretKey: strings.TrimSpace(os.Getenv("AUTH_BOOTSTRAP_SECRET_KEY")),
|
||||
AuthBootstrapPolicy: strings.TrimSpace(os.Getenv("AUTH_BOOTSTRAP_POLICY")),
|
||||
}
|
||||
|
||||
if config.LogFormat != "json" && config.LogFormat != "text" {
|
||||
|
||||
Reference in New Issue
Block a user