diff --git a/.env.example b/.env.example index f5276d1..a4dcab2 100644 --- a/.env.example +++ b/.env.example @@ -7,3 +7,14 @@ ADDRESS=0.0.0.0 GC_INTERVAL=10 GC_ENABLED=true MULTIPART_RETENTION_HOURS=24 +AUTH_ENABLED=false +AUTH_REGION=us-east-1 +AUTH_SKEW_SECONDS=300 +AUTH_MAX_PRESIGN_SECONDS=86400 +# When AUTH_ENABLED=true you MUST set AUTH_MASTER_KEY to a strong random value, e.g.: +# openssl rand -base64 32 +AUTH_MASTER_KEY=REPLACE_WITH_SECURE_RANDOM_KEY +AUTH_BOOTSTRAP_ACCESS_KEY= +AUTH_BOOTSTRAP_SECRET_KEY= +AUTH_BOOTSTRAP_POLICY= +ADMIN_API_ENABLED=true diff --git a/README.md b/README.md index a4764a4..1a8c5cc 100644 --- a/README.md +++ b/README.md @@ -29,13 +29,40 @@ Multi-object delete: AWS SigV4 streaming payload decoding for uploads (`aws-chunked` request bodies) +Authentication: +- AWS SigV4 request verification (header and presigned URL forms) +- Local credential/policy store in bbolt +- Bootstrap access key/secret via environment variables + +Admin API (JSON): +- `POST /_admin/v1/users` +- `GET /_admin/v1/users` +- `GET /_admin/v1/users/{accessKeyId}` +- `PUT /_admin/v1/users/{accessKeyId}/policy` +- `PUT /_admin/v1/users/{accessKeyId}/status` +- `DELETE /_admin/v1/users/{accessKeyId}` + +## Auth Setup + +Required when `AUTH_ENABLED=true`: +- `AUTH_MASTER_KEY` must be base64 for 32 decoded bytes (AES-256 key), e.g. `openssl rand -base64 32` +- `AUTH_BOOTSTRAP_ACCESS_KEY` and `AUTH_BOOTSTRAP_SECRET_KEY` define initial credentials +- `ADMIN_API_ENABLED=true` enables `/_admin/v1/*` routes (bootstrap key only) + +Reference: `auth/README.md` + +Additional docs: +- Admin OpenAPI spec: `docs/admin-api-openapi.yaml` +- S3 compatibility matrix: `docs/s3-compatibility.md` + Health: - `GET /healthz` - `HEAD /healthz` +- `GET /metrics` (Prometheus exposition format) +- `HEAD /metrics` ## Limitations -- No authentication/authorization yet. - Not full S3 API coverage. - No versioning or lifecycle policies. 
- Error and edge-case behavior is still being refined for client compatibility. diff --git a/api/admin_api.go b/api/admin_api.go new file mode 100644 index 0000000..8130132 --- /dev/null +++ b/api/admin_api.go @@ -0,0 +1,300 @@ +package api + +import ( + "encoding/json" + "errors" + "fs/auth" + "fs/models" + "io" + "net/http" + "strconv" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" +) + +const ( + maxAdminJSONBodyBytes = 1 << 20 + defaultAdminPageSize = 100 + maxAdminPageSize = 1000 +) + +type adminErrorResponse struct { + Code string `json:"code"` + Message string `json:"message"` + RequestID string `json:"requestId,omitempty"` +} + +type adminCreateUserRequest struct { + AccessKeyID string `json:"accessKeyId"` + SecretKey string `json:"secretKey,omitempty"` + Status string `json:"status,omitempty"` + Policy models.AuthPolicy `json:"policy"` +} + +type adminSetPolicyRequest struct { + Policy models.AuthPolicy `json:"policy"` +} + +type adminSetStatusRequest struct { + Status string `json:"status"` +} + +type adminUserListItem struct { + AccessKeyID string `json:"accessKeyId"` + Status string `json:"status"` + CreatedAt int64 `json:"createdAt"` + UpdatedAt int64 `json:"updatedAt"` +} + +type adminUserListResponse struct { + Items []adminUserListItem `json:"items"` + NextCursor string `json:"nextCursor,omitempty"` +} + +type adminUserResponse struct { + AccessKeyID string `json:"accessKeyId"` + Status string `json:"status"` + CreatedAt int64 `json:"createdAt"` + UpdatedAt int64 `json:"updatedAt"` + Policy *models.AuthPolicy `json:"policy,omitempty"` + SecretKey string `json:"secretKey,omitempty"` +} + +func (h *Handler) registerAdminRoutes() { + h.router.Route("/_admin/v1", func(r chi.Router) { + r.Post("/users", h.handleAdminCreateUser) + r.Get("/users", h.handleAdminListUsers) + r.Get("/users/{accessKeyId}", h.handleAdminGetUser) + r.Put("/users/{accessKeyId}/policy", h.handleAdminSetUserPolicy) + 
r.Put("/users/{accessKeyId}/status", h.handleAdminSetUserStatus) + r.Delete("/users/{accessKeyId}", h.handleAdminDeleteUser) + }) +} + +func (h *Handler) handleAdminCreateUser(w http.ResponseWriter, r *http.Request) { + if !h.requireBootstrapAdmin(w, r) { + return + } + + var req adminCreateUserRequest + if err := decodeJSONBody(w, r, &req); err != nil { + writeAdminError(w, r, http.StatusBadRequest, "InvalidRequest", err.Error()) + return + } + + created, err := h.authSvc.CreateUser(auth.CreateUserInput{ + AccessKeyID: req.AccessKeyID, + SecretKey: req.SecretKey, + Status: req.Status, + Policy: req.Policy, + }) + if err != nil { + writeMappedAdminError(w, r, err) + return + } + + resp := adminUserResponse{ + AccessKeyID: created.AccessKeyID, + Status: created.Status, + CreatedAt: created.CreatedAt, + UpdatedAt: created.UpdatedAt, + Policy: &created.Policy, + SecretKey: created.SecretKey, + } + writeJSON(w, http.StatusCreated, resp) +} + +func (h *Handler) handleAdminListUsers(w http.ResponseWriter, r *http.Request) { + if !h.requireBootstrapAdmin(w, r) { + return + } + + limit := defaultAdminPageSize + if raw := strings.TrimSpace(r.URL.Query().Get("limit")); raw != "" { + parsed, err := strconv.Atoi(raw) + if err != nil || parsed < 1 || parsed > maxAdminPageSize { + writeAdminError(w, r, http.StatusBadRequest, "InvalidRequest", "limit must be between 1 and 1000") + return + } + limit = parsed + } + cursor := strings.TrimSpace(r.URL.Query().Get("cursor")) + + users, nextCursor, err := h.authSvc.ListUsers(limit, cursor) + if err != nil { + writeMappedAdminError(w, r, err) + return + } + + items := make([]adminUserListItem, 0, len(users)) + for _, user := range users { + items = append(items, adminUserListItem{ + AccessKeyID: user.AccessKeyID, + Status: user.Status, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + }) + } + + writeJSON(w, http.StatusOK, adminUserListResponse{ + Items: items, + NextCursor: nextCursor, + }) +} + +func (h *Handler) 
handleAdminGetUser(w http.ResponseWriter, r *http.Request) { + if !h.requireBootstrapAdmin(w, r) { + return + } + + accessKeyID := chi.URLParam(r, "accessKeyId") + user, err := h.authSvc.GetUser(accessKeyID) + if err != nil { + writeMappedAdminError(w, r, err) + return + } + + resp := adminUserResponse{ + AccessKeyID: user.AccessKeyID, + Status: user.Status, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + Policy: &user.Policy, + } + writeJSON(w, http.StatusOK, resp) +} + +func (h *Handler) handleAdminDeleteUser(w http.ResponseWriter, r *http.Request) { + if !h.requireBootstrapAdmin(w, r) { + return + } + accessKeyID := chi.URLParam(r, "accessKeyId") + if err := h.authSvc.DeleteUser(accessKeyID); err != nil { + writeMappedAdminError(w, r, err) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) handleAdminSetUserPolicy(w http.ResponseWriter, r *http.Request) { + if !h.requireBootstrapAdmin(w, r) { + return + } + + accessKeyID := chi.URLParam(r, "accessKeyId") + var req adminSetPolicyRequest + if err := decodeJSONBody(w, r, &req); err != nil { + writeAdminError(w, r, http.StatusBadRequest, "InvalidRequest", err.Error()) + return + } + + user, err := h.authSvc.SetUserPolicy(accessKeyID, req.Policy) + if err != nil { + writeMappedAdminError(w, r, err) + return + } + resp := adminUserResponse{ + AccessKeyID: user.AccessKeyID, + Status: user.Status, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + Policy: &user.Policy, + } + writeJSON(w, http.StatusOK, resp) +} + +func (h *Handler) handleAdminSetUserStatus(w http.ResponseWriter, r *http.Request) { + if !h.requireBootstrapAdmin(w, r) { + return + } + + accessKeyID := chi.URLParam(r, "accessKeyId") + var req adminSetStatusRequest + if err := decodeJSONBody(w, r, &req); err != nil { + writeAdminError(w, r, http.StatusBadRequest, "InvalidRequest", err.Error()) + return + } + + user, err := h.authSvc.SetUserStatus(accessKeyID, req.Status) + if err != nil { + 
writeMappedAdminError(w, r, err) + return + } + resp := adminUserResponse{ + AccessKeyID: user.AccessKeyID, + Status: user.Status, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + Policy: &user.Policy, + } + writeJSON(w, http.StatusOK, resp) +} + +func (h *Handler) requireBootstrapAdmin(w http.ResponseWriter, r *http.Request) bool { + authCtx, ok := auth.GetRequestContext(r.Context()) + if !ok || !authCtx.Authenticated { + writeAdminError(w, r, http.StatusForbidden, "Forbidden", "admin credentials are required") + return false + } + if h.authSvc == nil { + writeAdminError(w, r, http.StatusForbidden, "Forbidden", "admin access is not configured") + return false + } + + bootstrap := strings.TrimSpace(h.authSvc.Config().BootstrapAccessKey) + if bootstrap == "" || authCtx.AccessKeyID != bootstrap { + writeAdminError(w, r, http.StatusForbidden, "Forbidden", "admin access denied") + return false + } + return true +} + +func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) error { + r.Body = http.MaxBytesReader(w, r.Body, maxAdminJSONBodyBytes) + decoder := json.NewDecoder(r.Body) + decoder.DisallowUnknownFields() + if err := decoder.Decode(dst); err != nil { + return err + } + if err := decoder.Decode(&struct{}{}); err != io.EOF { + return errors.New("request body must contain a single JSON object") + } + return nil +} + +func writeMappedAdminError(w http.ResponseWriter, r *http.Request, err error) { + switch { + case errors.Is(err, auth.ErrInvalidUserInput): + writeAdminError(w, r, http.StatusBadRequest, "InvalidRequest", err.Error()) + case errors.Is(err, auth.ErrUserAlreadyExists): + writeAdminError(w, r, http.StatusConflict, "UserAlreadyExists", "user already exists") + case errors.Is(err, auth.ErrUserNotFound): + writeAdminError(w, r, http.StatusNotFound, "UserNotFound", "user was not found") + case errors.Is(err, auth.ErrAuthNotEnabled): + writeAdminError(w, r, http.StatusServiceUnavailable, "AuthDisabled", "authentication subsystem is 
disabled") + default: + writeAdminError(w, r, http.StatusInternalServerError, "InternalError", "internal server error") + } +} + +func writeAdminError(w http.ResponseWriter, r *http.Request, status int, code string, message string) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + requestID := middleware.GetReqID(r.Context()) + if requestID != "" { + w.Header().Set("x-amz-request-id", requestID) + } + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(adminErrorResponse{ + Code: code, + Message: message, + RequestID: requestID, + }) +} + +func writeJSON(w http.ResponseWriter, status int, payload any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(payload) +} diff --git a/api/api.go b/api/api.go index 5f01252..94fc10c 100644 --- a/api/api.go +++ b/api/api.go @@ -7,8 +7,10 @@ import ( "encoding/xml" "errors" "fmt" + "fs/auth" "fs/logging" "fs/metadata" + "fs/metrics" "fs/models" "fs/service" "io" @@ -30,6 +32,8 @@ type Handler struct { svc *service.ObjectService logger *slog.Logger logConfig logging.Config + authSvc *auth.Service + adminAPI bool } const ( @@ -44,7 +48,7 @@ const ( serverMaxConnections = 1024 ) -func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig logging.Config) *Handler { +func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig logging.Config, authSvc *auth.Service, adminAPI bool) *Handler { r := chi.NewRouter() r.Use(middleware.RequestID) r.Use(middleware.Recoverer) @@ -57,16 +61,24 @@ func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig loggi svc: svc, logger: logger, logConfig: logConfig, + authSvc: authSvc, + adminAPI: adminAPI, } return h } func (h *Handler) setupRoutes() { h.router.Use(logging.HTTPMiddleware(h.logger, h.logConfig)) + h.router.Use(auth.Middleware(h.authSvc, h.logger, h.logConfig.Audit, writeMappedS3Error)) h.router.Get("/healthz", h.handleHealth) 
h.router.Head("/healthz", h.handleHealth) + h.router.Get("/metrics", h.handleMetrics) + h.router.Head("/metrics", h.handleMetrics) h.router.Get("/", h.handleGetBuckets) + if h.adminAPI { + h.registerAdminRoutes() + } h.router.Get("/{bucket}/", h.handleGetBucket) h.router.Get("/{bucket}", h.handleGetBucket) @@ -102,6 +114,18 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) { } } +func (h *Handler) handleMetrics(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8") + if r.Method == http.MethodHead { + w.WriteHeader(http.StatusOK) + return + } + payload := metrics.Default.RenderPrometheus() + w.Header().Set("Content-Length", strconv.Itoa(len(payload))) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(payload)) +} + func validateObjectKey(key string) *s3APIError { if key == "" { err := s3ErrInvalidObjectKey @@ -218,6 +242,7 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) { writeS3Error(w, r, s3ErrMalformedXML, r.URL.Path) return } + metrics.Default.ObserveBatchSize(len(req.Parts)) manifest, err := h.svc.CompleteMultipartUpload(bucket, key, uploadID, req.Parts) if err != nil { @@ -293,6 +318,20 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) return } + metrics.Default.ObserveBatchSize(1) + + if ifNoneMatch := strings.TrimSpace(r.Header.Get("If-None-Match")); ifNoneMatch != "" { + manifest, err := h.svc.HeadObject(bucket, key) + if err != nil { + if !errors.Is(err, metadata.ErrObjectNotFound) { + writeMappedS3Error(w, r, err) + return + } + } else if ifNoneMatchPreconditionFailed(ifNoneMatch, manifest.ETag) { + writeS3Error(w, r, s3ErrPreconditionFailed, r.URL.Path) + return + } + } contentType := r.Header.Get("Content-Type") if contentType == "" { @@ -424,6 +463,25 @@ func decodeAWSChunkedPayload(src io.Reader, dst io.Writer) error { } } +func ifNoneMatchPreconditionFailed(headerValue, 
etag string) bool { + for _, rawToken := range strings.Split(headerValue, ",") { + token := strings.TrimSpace(rawToken) + if token == "" { + continue + } + if token == "*" { + return true + } + + token = strings.TrimPrefix(token, "W/") + token = strings.Trim(token, `"`) + if strings.EqualFold(token, etag) { + return true + } + } + return false +} + func (h *Handler) handlePutBucket(w http.ResponseWriter, r *http.Request) { bucket := chi.URLParam(r, "bucket") if err := h.svc.CreateBucket(bucket); err != nil { @@ -473,6 +531,7 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) { writeS3Error(w, r, s3ErrTooManyDeleteObjects, r.URL.Path) return } + metrics.Default.ObserveBatchSize(len(req.Objects)) keys := make([]string, 0, len(req.Objects)) response := models.DeleteObjectsResult{ @@ -591,6 +650,7 @@ func newLimitedListener(inner net.Listener, maxConns int) net.Listener { if maxConns <= 0 { return inner } + metrics.Default.SetConnectionPoolMax(maxConns) return &limitedListener{ Listener: inner, slots: make(chan struct{}, maxConns), @@ -598,15 +658,26 @@ func newLimitedListener(inner net.Listener, maxConns int) net.Listener { } func (l *limitedListener) Accept() (net.Conn, error) { - l.slots <- struct{}{} + select { + case l.slots <- struct{}{}: + default: + metrics.Default.IncConnectionPoolWait() + metrics.Default.IncRequestQueueLength() + l.slots <- struct{}{} + metrics.Default.DecRequestQueueLength() + } conn, err := l.Listener.Accept() if err != nil { <-l.slots return nil, err } + metrics.Default.IncConnectionPoolActive() return &limitedConn{ Conn: conn, - done: func() { <-l.slots }, + done: func() { + <-l.slots + metrics.Default.DecConnectionPoolActive() + }, }, nil } @@ -666,26 +737,203 @@ func (h *Handler) handleGetBuckets(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleGetBucket(w http.ResponseWriter, r *http.Request) { bucket := chi.URLParam(r, "bucket") + query := r.URL.Query() - if r.URL.Query().Get("list-type") == 
"2" { - h.handleListObjectsV2(w, r, bucket) - return - } - if r.URL.Query().Has("location") { - xmlResponse := ` - us-east-1` + if query.Has("location") { + region := "us-east-1" + if h.authSvc != nil { + candidate := strings.TrimSpace(h.authSvc.Config().Region) + if candidate != "" { + region = candidate + } + } + xmlResponse := fmt.Sprintf(` +%s`, region) w.Header().Set("Content-Type", "application/xml; charset=utf-8") w.Header().Set("Content-Length", strconv.Itoa(len(xmlResponse))) w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(xmlResponse)) - if err != nil { - return - } + _, _ = w.Write([]byte(xmlResponse)) return } - writeS3Error(w, r, s3ErrNotImplemented, r.URL.Path) + listType := strings.TrimSpace(query.Get("list-type")) + if listType == "2" { + h.handleListObjectsV2(w, r, bucket) + return + } + if listType != "" { + writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path) + return + } + + if shouldUseListObjectsV1(query) { + h.handleListObjectsV1(w, r, bucket) + return + } + + writeS3Error(w, r, s3ErrNotImplemented, r.URL.Path) +} + +func shouldUseListObjectsV1(query url.Values) bool { + if len(query) == 0 { + return true + } + + listingParams := map[string]struct{}{ + "delimiter": {}, + "encoding-type": {}, + "marker": {}, + "max-keys": {}, + "prefix": {}, + } + for key := range query { + if _, ok := listingParams[key]; !ok { + return false + } + } + return true +} + +func (h *Handler) handleListObjectsV1(w http.ResponseWriter, r *http.Request, bucket string) { + prefix := r.URL.Query().Get("prefix") + delimiter := r.URL.Query().Get("delimiter") + marker := r.URL.Query().Get("marker") + encodingType := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("encoding-type"))) + if encodingType != "" && encodingType != "url" { + writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path) + return + } + + maxKeys := 1000 + if rawMaxKeys := strings.TrimSpace(r.URL.Query().Get("max-keys")); rawMaxKeys != "" { + parsed, err := strconv.Atoi(rawMaxKeys) + if err != 
nil || parsed < 0 { + writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path) + return + } + if parsed > 1000 { + parsed = 1000 + } + maxKeys = parsed + } + + result := models.ListBucketResultV1{ + Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/", + Name: bucket, + Prefix: s3EncodeIfNeeded(prefix, encodingType), + Marker: s3EncodeIfNeeded(marker, encodingType), + Delimiter: s3EncodeIfNeeded(delimiter, encodingType), + MaxKeys: maxKeys, + EncodingType: encodingType, + } + + type pageEntry struct { + Object *models.ObjectManifest + CommonPrefix string + } + + entries := make([]pageEntry, 0, maxKeys) + seenCommonPrefixes := make(map[string]struct{}) + truncated := false + stopErr := errors.New("list_v1_page_complete") + + startKey := prefix + if marker != "" && marker > startKey { + startKey = marker + } + + if maxKeys > 0 { + err := h.svc.ForEachObjectFrom(bucket, startKey, func(object *models.ObjectManifest) error { + if object == nil { + return nil + } + key := object.Key + + if prefix != "" { + if key < prefix { + return nil + } + if !strings.HasPrefix(key, prefix) { + return stopErr + } + } + if marker != "" && key <= marker { + return nil + } + + if delimiter != "" { + relative := strings.TrimPrefix(key, prefix) + if idx := strings.Index(relative, delimiter); idx >= 0 { + commonPrefix := prefix + relative[:idx+len(delimiter)] + if marker != "" && commonPrefix <= marker { + return nil + } + if _, exists := seenCommonPrefixes[commonPrefix]; exists { + return nil + } + seenCommonPrefixes[commonPrefix] = struct{}{} + if len(entries) >= maxKeys { + truncated = true + return stopErr + } + entries = append(entries, pageEntry{ + CommonPrefix: commonPrefix, + }) + return nil + } + } + + if len(entries) >= maxKeys { + truncated = true + return stopErr + } + entries = append(entries, pageEntry{Object: object}) + return nil + }) + if err != nil && !errors.Is(err, stopErr) { + writeMappedS3Error(w, r, err) + return + } + } + + for _, entry := range entries { + if entry.Object 
!= nil { + result.Contents = append(result.Contents, models.Contents{ + Key: s3EncodeIfNeeded(entry.Object.Key, encodingType), + LastModified: time.Unix(entry.Object.CreatedAt, 0).UTC().Format("2006-01-02T15:04:05.000Z"), + ETag: `"` + entry.Object.ETag + `"`, + Size: entry.Object.Size, + StorageClass: "STANDARD", + }) + } else { + result.CommonPrefixes = append(result.CommonPrefixes, models.CommonPrefixes{ + Prefix: s3EncodeIfNeeded(entry.CommonPrefix, encodingType), + }) + } + } + + result.IsTruncated = truncated + if result.IsTruncated && result.NextMarker == "" && len(entries) > 0 { + last := entries[len(entries)-1] + if last.Object != nil { + result.NextMarker = s3EncodeIfNeeded(last.Object.Key, encodingType) + } else { + result.NextMarker = s3EncodeIfNeeded(last.CommonPrefix, encodingType) + } + } + + xmlResponse, err := xml.MarshalIndent(result, "", " ") + if err != nil { + writeMappedS3Error(w, r, err) + return + } + + w.Header().Set("Content-Type", "application/xml; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(xml.Header)+len(xmlResponse))) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(xml.Header)) + _, _ = w.Write(xmlResponse) } func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bucket string) { diff --git a/api/s3_errors.go b/api/s3_errors.go index 7dbea68..95e7424 100644 --- a/api/s3_errors.go +++ b/api/s3_errors.go @@ -3,7 +3,9 @@ package api import ( "encoding/xml" "errors" + "fs/auth" "fs/metadata" + "fs/metrics" "fs/models" "fs/service" "net/http" @@ -58,6 +60,11 @@ var ( Code: "InvalidRange", Message: "The requested range is not satisfiable.", } + s3ErrPreconditionFailed = s3APIError{ + Status: http.StatusPreconditionFailed, + Code: "PreconditionFailed", + Message: "At least one of the pre-conditions you specified did not hold.", + } s3ErrEntityTooSmall = s3APIError{ Status: http.StatusBadRequest, Code: "EntityTooSmall", @@ -73,6 +80,41 @@ var ( Code: "MalformedXML", Message: "The request 
must contain no more than 1000 object identifiers.", } + s3ErrAccessDenied = s3APIError{ + Status: http.StatusForbidden, + Code: "AccessDenied", + Message: "Access Denied.", + } + s3ErrInvalidAccessKeyID = s3APIError{ + Status: http.StatusForbidden, + Code: "InvalidAccessKeyId", + Message: "The AWS Access Key Id you provided does not exist in our records.", + } + s3ErrSignatureDoesNotMatch = s3APIError{ + Status: http.StatusForbidden, + Code: "SignatureDoesNotMatch", + Message: "The request signature we calculated does not match the signature you provided.", + } + s3ErrAuthorizationHeaderMalformed = s3APIError{ + Status: http.StatusBadRequest, + Code: "AuthorizationHeaderMalformed", + Message: "The authorization header is malformed; the region/service/date is wrong or missing.", + } + s3ErrRequestTimeTooSkewed = s3APIError{ + Status: http.StatusForbidden, + Code: "RequestTimeTooSkewed", + Message: "The difference between the request time and the server's time is too large.", + } + s3ErrExpiredToken = s3APIError{ + Status: http.StatusBadRequest, + Code: "ExpiredToken", + Message: "The provided token has expired.", + } + s3ErrInvalidPresign = s3APIError{ + Status: http.StatusBadRequest, + Code: "AuthorizationQueryParametersError", + Message: "Error parsing the X-Amz-Credential parameter.", + } s3ErrInternal = s3APIError{ Status: http.StatusInternalServerError, Code: "InternalError", @@ -132,6 +174,26 @@ func mapToS3Error(err error) s3APIError { return s3ErrMalformedXML case errors.Is(err, service.ErrEntityTooSmall): return s3ErrEntityTooSmall + case errors.Is(err, auth.ErrAccessDenied): + return s3ErrAccessDenied + case errors.Is(err, auth.ErrInvalidAccessKeyID): + return s3ErrInvalidAccessKeyID + case errors.Is(err, auth.ErrSignatureDoesNotMatch): + return s3ErrSignatureDoesNotMatch + case errors.Is(err, auth.ErrAuthorizationHeaderMalformed): + return s3ErrAuthorizationHeaderMalformed + case errors.Is(err, auth.ErrRequestTimeTooSkewed): + return 
s3ErrRequestTimeTooSkewed + case errors.Is(err, auth.ErrExpiredToken): + return s3ErrExpiredToken + case errors.Is(err, auth.ErrCredentialDisabled): + return s3ErrAccessDenied + case errors.Is(err, auth.ErrNoAuthCredentials): + return s3ErrAccessDenied + case errors.Is(err, auth.ErrUnsupportedAuthScheme): + return s3ErrAuthorizationHeaderMalformed + case errors.Is(err, auth.ErrInvalidPresign): + return s3ErrInvalidPresign default: return s3ErrInternal } @@ -139,12 +201,19 @@ func mapToS3Error(err error) s3APIError { func writeS3Error(w http.ResponseWriter, r *http.Request, apiErr s3APIError, resource string) { requestID := "" + op := "other" if r != nil { requestID = middleware.GetReqID(r.Context()) + isDeletePost := false + if r.Method == http.MethodPost { + _, isDeletePost = r.URL.Query()["delete"] + } + op = metrics.NormalizeHTTPOperation(r.Method, isDeletePost) if requestID != "" { w.Header().Set("x-amz-request-id", requestID) } } + metrics.Default.ObserveError(op, apiErr.Code) w.Header().Set("Content-Type", "application/xml; charset=utf-8") w.WriteHeader(apiErr.Status) diff --git a/auth/README.md b/auth/README.md new file mode 100644 index 0000000..74a79f0 --- /dev/null +++ b/auth/README.md @@ -0,0 +1,150 @@ +# Authentication Design + +This folder implements S3-compatible request authentication using AWS Signature Version 4 (SigV4), with local identity and policy data stored in bbolt. + +## Goals +- Keep S3 client compatibility for request signing. +- Avoid external auth databases. +- Store secrets encrypted at rest (not plaintext in bbolt). +- Keep authorization simple and explicit. + +## High-Level Architecture +- `auth/middleware.go` + - HTTP middleware that enforces auth before API handlers. + - Exempts `/healthz`. + - Calls auth service and writes mapped S3 XML errors on failure. 
+- `auth/service.go` + - Main auth orchestration: + - parse SigV4 from request + - validate timestamp/scope/service/region + - load identity from metadata + - decrypt secret + - verify signature + - evaluate policy against requested S3 action +- `auth/sigv4.go` + - Canonical SigV4 parsing and verification helpers. + - Supports header auth and presigned query auth. +- `auth/policy.go` + - Authorization evaluator (deny overrides allow). +- `auth/action.go` + - Maps HTTP method/path/query to logical S3 action + resource target. +- `auth/crypto.go` + - AES-256-GCM encryption/decryption for stored secret keys. +- `auth/context.go` + - Carries authentication result in request context for downstream logic. +- `auth/config.go` + - Normalized auth configuration. +- `auth/errors.go` + - Domain auth errors used by API S3 error mapping. + +## Config Model +Auth is configured through env (read in `utils/config.go`, converted in `auth/config.go`): + +- `AUTH_ENABLED` +- `AUTH_REGION` +- `AUTH_SKEW_SECONDS` +- `AUTH_MAX_PRESIGN_SECONDS` +- `AUTH_MASTER_KEY` +- `AUTH_BOOTSTRAP_ACCESS_KEY` +- `AUTH_BOOTSTRAP_SECRET_KEY` +- `AUTH_BOOTSTRAP_POLICY` (optional JSON) + +Important: +- If `AUTH_ENABLED=true`, `AUTH_MASTER_KEY` is required. +- `AUTH_MASTER_KEY` must be base64 that decodes to exactly 32 bytes (AES-256 key). + +## Persistence Model (bbolt) +Implemented in metadata layer: +- `__AUTH_IDENTITIES__` bucket stores `models.AuthIdentity` + - `access_key_id` + - encrypted secret (`secret_enc`, `secret_nonce`) + - status (`active`/disabled) + - timestamps +- `__AUTH_POLICIES__` bucket stores `models.AuthPolicy` + - `principal` + - statements (`effect`, `actions`, `bucket`, `prefix`) + +## Bootstrap Identity +On startup (`main.go`): +1. Build auth config. +2. Create auth service with metadata store. +3. Call `EnsureBootstrap()`. 
+ +If bootstrap env key/secret are set: +- identity is created/updated +- secret is encrypted with AES-GCM and stored +- policy is created: + - default: full access (`s3:*`, `bucket=*`, `prefix=*`) + - or overridden by `AUTH_BOOTSTRAP_POLICY` + +## Request Authentication Flow +For each non-health request: +1. Parse SigV4 input (header or presigned query). +2. Validate structural fields: + - algorithm + - credential scope + - service must be `s3` + - region must match config +3. Validate time: + - `x-amz-date` format + - skew within `AUTH_SKEW_SECONDS` + - presigned expiry within `AUTH_MAX_PRESIGN_SECONDS` +4. Load identity by access key id. +5. Ensure identity status is active. +6. Decrypt stored secret using master key. +7. Recompute canonical request and expected signature. +8. Compare signatures. +9. Resolve target action from request. +10. Evaluate policy; deny overrides allow. +11. Store auth result in request context and continue. + +## Authorization Semantics +Policy evaluator rules: +- No matching allow => denied. +- Any matching deny => denied (even if allow also matches). 
+- Wildcards supported: + - action: `*` or `s3:*` + - bucket: `*` + - prefix: `*` + +Action resolution includes: +- bucket APIs (`CreateBucket`, `ListBucket`, `HeadBucket`, `DeleteBucket`) +- object APIs (`GetObject`, `PutObject`, `DeleteObject`) +- multipart APIs (`CreateMultipartUpload`, `UploadPart`, `ListMultipartUploadParts`, `CompleteMultipartUpload`, `AbortMultipartUpload`) + +## Error Behavior +Auth errors are mapped to S3-style XML errors in `api/s3_errors.go`, including: +- `AccessDenied` +- `InvalidAccessKeyId` +- `SignatureDoesNotMatch` +- `AuthorizationHeaderMalformed` +- `RequestTimeTooSkewed` +- `ExpiredToken` +- `AuthorizationQueryParametersError` + +## Audit Logging +When `AUDIT_LOG=true` and auth is enabled: +- successful auth attempts emit `auth_success` +- failed auth attempts emit `auth_failed` + +Each audit entry includes method, path, remote IP, and request ID (if present). Success logs also include access key ID and auth type. + +## Security Notes +- Secret keys are recoverable by server design (required for SigV4 verification). +- They are encrypted at rest, not hashed. +- Master key rotation is not implemented yet. +- Keep `AUTH_MASTER_KEY` protected (secret manager/systemd env file/etc.). + +## Current Scope / Limitations +- No STS/session-token auth yet. +- No admin API for managing multiple users yet. +- Policy language is intentionally minimal, not full IAM. +- No automatic key rotation workflows. 
+ +## Practical Next Step +To support multiple users cleanly, add admin operations in auth service + API: +- create user +- rotate secret +- set policy +- disable/enable +- delete user diff --git a/auth/action.go b/auth/action.go new file mode 100644 index 0000000..576da2e --- /dev/null +++ b/auth/action.go @@ -0,0 +1,93 @@ +package auth + +import ( + "net/http" + "strings" +) + +type Action string + +const ( + ActionListAllMyBuckets Action = "s3:ListAllMyBuckets" + ActionCreateBucket Action = "s3:CreateBucket" + ActionHeadBucket Action = "s3:HeadBucket" + ActionDeleteBucket Action = "s3:DeleteBucket" + ActionListBucket Action = "s3:ListBucket" + ActionGetObject Action = "s3:GetObject" + ActionPutObject Action = "s3:PutObject" + ActionDeleteObject Action = "s3:DeleteObject" + ActionCreateMultipartUpload Action = "s3:CreateMultipartUpload" + ActionUploadPart Action = "s3:UploadPart" + ActionListMultipartParts Action = "s3:ListMultipartUploadParts" + ActionCompleteMultipart Action = "s3:CompleteMultipartUpload" + ActionAbortMultipartUpload Action = "s3:AbortMultipartUpload" +) + +type RequestTarget struct { + Action Action + Bucket string + Key string +} + +func resolveTarget(r *http.Request) RequestTarget { + path := strings.TrimPrefix(r.URL.Path, "/") + if path == "" { + return RequestTarget{Action: ActionListAllMyBuckets} + } + + parts := strings.SplitN(path, "/", 2) + bucket := parts[0] + key := "" + if len(parts) > 1 { + key = parts[1] + } + + if key == "" { + switch r.Method { + case http.MethodPut: + return RequestTarget{Action: ActionCreateBucket, Bucket: bucket} + case http.MethodHead: + return RequestTarget{Action: ActionHeadBucket, Bucket: bucket} + case http.MethodDelete: + return RequestTarget{Action: ActionDeleteBucket, Bucket: bucket} + case http.MethodGet: + return RequestTarget{Action: ActionListBucket, Bucket: bucket} + case http.MethodPost: + if _, ok := r.URL.Query()["delete"]; ok { + return RequestTarget{Action: ActionDeleteObject, Bucket: 
bucket} + } + } + return RequestTarget{Bucket: bucket} + } + + uploadID := r.URL.Query().Get("uploadId") + partNumber := r.URL.Query().Get("partNumber") + if _, ok := r.URL.Query()["uploads"]; ok && r.Method == http.MethodPost { + return RequestTarget{Action: ActionCreateMultipartUpload, Bucket: bucket, Key: key} + } + if uploadID != "" { + switch r.Method { + case http.MethodPut: + if partNumber != "" { + return RequestTarget{Action: ActionUploadPart, Bucket: bucket, Key: key} + } + case http.MethodGet: + return RequestTarget{Action: ActionListMultipartParts, Bucket: bucket, Key: key} + case http.MethodPost: + return RequestTarget{Action: ActionCompleteMultipart, Bucket: bucket, Key: key} + case http.MethodDelete: + return RequestTarget{Action: ActionAbortMultipartUpload, Bucket: bucket, Key: key} + } + } + + switch r.Method { + case http.MethodGet, http.MethodHead: + return RequestTarget{Action: ActionGetObject, Bucket: bucket, Key: key} + case http.MethodPut: + return RequestTarget{Action: ActionPutObject, Bucket: bucket, Key: key} + case http.MethodDelete: + return RequestTarget{Action: ActionDeleteObject, Bucket: bucket, Key: key} + } + + return RequestTarget{Bucket: bucket, Key: key} +} diff --git a/auth/config.go b/auth/config.go new file mode 100644 index 0000000..3d78144 --- /dev/null +++ b/auth/config.go @@ -0,0 +1,50 @@ +package auth + +import ( + "strings" + "time" +) + +type Config struct { + Enabled bool + Region string + ClockSkew time.Duration + MaxPresignDuration time.Duration + MasterKey string + BootstrapAccessKey string + BootstrapSecretKey string + BootstrapPolicy string +} + +func ConfigFromValues( + enabled bool, + region string, + skew time.Duration, + maxPresign time.Duration, + masterKey string, + bootstrapAccessKey string, + bootstrapSecretKey string, + bootstrapPolicy string, +) Config { + region = strings.TrimSpace(region) + if region == "" { + region = "us-east-1" + } + if skew <= 0 { + skew = 5 * time.Minute + } + if maxPresign <= 0 { 
+ maxPresign = 24 * time.Hour + } + + return Config{ + Enabled: enabled, + Region: region, + ClockSkew: skew, + MaxPresignDuration: maxPresign, + MasterKey: strings.TrimSpace(masterKey), + BootstrapAccessKey: strings.TrimSpace(bootstrapAccessKey), + BootstrapSecretKey: strings.TrimSpace(bootstrapSecretKey), + BootstrapPolicy: strings.TrimSpace(bootstrapPolicy), + } +} diff --git a/auth/context.go b/auth/context.go new file mode 100644 index 0000000..9fab0b0 --- /dev/null +++ b/auth/context.go @@ -0,0 +1,23 @@ +package auth + +import "context" + +type RequestContext struct { + Authenticated bool + AccessKeyID string + AuthType string +} + +type contextKey int + +const requestContextKey contextKey = iota + +func WithRequestContext(ctx context.Context, authCtx RequestContext) context.Context { + return context.WithValue(ctx, requestContextKey, authCtx) +} + +func GetRequestContext(ctx context.Context) (RequestContext, bool) { + value := ctx.Value(requestContextKey) + authCtx, ok := value.(RequestContext) + return authCtx, ok +} diff --git a/auth/crypto.go b/auth/crypto.go new file mode 100644 index 0000000..58f810e --- /dev/null +++ b/auth/crypto.go @@ -0,0 +1,74 @@ +package auth + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" + "io" +) + +const ( + masterKeyLength = 32 + gcmNonceLength = 12 +) + +func decodeMasterKey(raw string) ([]byte, error) { + decoded, err := base64.StdEncoding.DecodeString(raw) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrInvalidMasterKey, err) + } + if len(decoded) != masterKeyLength { + return nil, fmt.Errorf("%w: expected %d-byte decoded key", ErrInvalidMasterKey, masterKeyLength) + } + return decoded, nil +} + +func encryptSecret(masterKey []byte, accessKeyID, secret string) (ciphertextB64 string, nonceB64 string, err error) { + block, err := aes.NewCipher(masterKey) + if err != nil { + return "", "", err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", "", err + } + + 
nonce := make([]byte, gcmNonceLength)
+	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
+		return "", "", err
+	}
+
+	ciphertext := gcm.Seal(nil, nonce, []byte(secret), []byte(accessKeyID))
+	return base64.StdEncoding.EncodeToString(ciphertext), base64.StdEncoding.EncodeToString(nonce), nil
+}
+
+// decryptSecret reverses encryptSecret: AES-256-GCM with the access key ID as
+// additional authenticated data, so a ciphertext cannot be replayed for a
+// different user.
+func decryptSecret(masterKey []byte, accessKeyID, ciphertextB64, nonceB64 string) (string, error) {
+	block, err := aes.NewCipher(masterKey)
+	if err != nil {
+		return "", err
+	}
+	gcm, err := cipher.NewGCM(block)
+	if err != nil {
+		return "", err
+	}
+
+	ciphertext, err := base64.StdEncoding.DecodeString(ciphertextB64)
+	if err != nil {
+		return "", err
+	}
+	nonce, err := base64.StdEncoding.DecodeString(nonceB64)
+	if err != nil {
+		return "", err
+	}
+	if len(nonce) != gcmNonceLength {
+		return "", fmt.Errorf("invalid nonce length: %d", len(nonce))
+	}
+
+	plaintext, err := gcm.Open(nil, nonce, ciphertext, []byte(accessKeyID))
+	if err != nil {
+		return "", err
+	}
+	return string(plaintext), nil
+}
diff --git a/auth/errors.go b/auth/errors.go
new file mode 100644
index 0000000..1c6b65a
--- /dev/null
+++ b/auth/errors.go
@@ -0,0 +1,22 @@
+package auth
+
+import "errors"
+
+var (
+	ErrAccessDenied                 = errors.New("access denied")
+	ErrInvalidAccessKeyID           = errors.New("invalid access key id")
+	ErrUserAlreadyExists            = errors.New("user already exists")
+	ErrUserNotFound                 = errors.New("user not found")
+	ErrInvalidUserInput             = errors.New("invalid user input")
+	ErrSignatureDoesNotMatch        = errors.New("signature does not match")
+	ErrAuthorizationHeaderMalformed = errors.New("authorization header malformed")
+	ErrRequestTimeTooSkewed         = errors.New("request time too skewed")
+	ErrExpiredToken                 = errors.New("expired token")
+	ErrCredentialDisabled           = errors.New("credential disabled")
+	ErrAuthNotEnabled               = errors.New("authentication is not enabled")
+	ErrMasterKeyRequired            = errors.New("auth master key is required")
+	ErrInvalidMasterKey             = errors.New("invalid auth master key")
+
ErrNoAuthCredentials = errors.New("no auth credentials found") + ErrUnsupportedAuthScheme = errors.New("unsupported auth scheme") + ErrInvalidPresign = errors.New("invalid presigned request") +) diff --git a/auth/middleware.go b/auth/middleware.go new file mode 100644 index 0000000..24f656e --- /dev/null +++ b/auth/middleware.go @@ -0,0 +1,111 @@ +package auth + +import ( + "errors" + "fs/metrics" + "log/slog" + "net" + "net/http" + + "github.com/go-chi/chi/v5/middleware" +) + +func Middleware( + svc *Service, + logger *slog.Logger, + auditEnabled bool, + onError func(http.ResponseWriter, *http.Request, error), +) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authCtx := RequestContext{Authenticated: false, AuthType: "none"} + if svc == nil || !svc.Config().Enabled { + metrics.Default.ObserveAuth("bypass", "disabled", "auth_disabled") + next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), authCtx))) + return + } + + if r.URL.Path == "/healthz" { + metrics.Default.ObserveAuth("bypass", "none", "public_endpoint") + next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), authCtx))) + return + } + + resolvedCtx, err := svc.AuthenticateRequest(r) + if err != nil { + metrics.Default.ObserveAuth("error", "sigv4", authErrorClass(err)) + if auditEnabled && logger != nil { + requestID := middleware.GetReqID(r.Context()) + attrs := []any{ + "method", r.Method, + "path", r.URL.Path, + "remote_ip", clientIP(r.RemoteAddr), + "error", err.Error(), + } + if requestID != "" { + attrs = append(attrs, "request_id", requestID) + } + logger.Warn("auth_failed", attrs...) 
+			}
+			if onError != nil {
+				onError(w, r, err)
+				return
+			}
+			http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
+			return
+		}
+
+		metrics.Default.ObserveAuth("ok", resolvedCtx.AuthType, "none")
+		if auditEnabled && logger != nil {
+			requestID := middleware.GetReqID(r.Context())
+			attrs := []any{
+				"method", r.Method,
+				"path", r.URL.Path,
+				"remote_ip", clientIP(r.RemoteAddr),
+				"access_key_id", resolvedCtx.AccessKeyID,
+				"auth_type", resolvedCtx.AuthType,
+			}
+			if requestID != "" {
+				attrs = append(attrs, "request_id", requestID)
+			}
+			logger.Info("auth_success", attrs...)
+		}
+		next.ServeHTTP(w, r.WithContext(WithRequestContext(r.Context(), resolvedCtx)))
+	})
+	}
+}
+
+// authErrorClass maps an authentication error to a low-cardinality label for
+// metrics; unknown errors fall through to "other".
+func authErrorClass(err error) string {
+	switch {
+	case errors.Is(err, ErrInvalidAccessKeyID):
+		return "invalid_access_key"
+	case errors.Is(err, ErrSignatureDoesNotMatch):
+		return "signature_mismatch"
+	case errors.Is(err, ErrAuthorizationHeaderMalformed):
+		return "auth_header_malformed"
+	case errors.Is(err, ErrRequestTimeTooSkewed):
+		return "time_skew"
+	case errors.Is(err, ErrExpiredToken):
+		return "expired_token"
+	case errors.Is(err, ErrNoAuthCredentials):
+		return "missing_credentials"
+	case errors.Is(err, ErrUnsupportedAuthScheme):
+		return "unsupported_auth_scheme"
+	case errors.Is(err, ErrInvalidPresign):
+		return "invalid_presign"
+	case errors.Is(err, ErrCredentialDisabled):
+		return "credential_disabled"
+	case errors.Is(err, ErrAccessDenied):
+		return "access_denied"
+	default:
+		return "other"
+	}
+}
+
+// clientIP strips the port from a RemoteAddr, falling back to the raw value
+// when it is not in host:port form.
+func clientIP(remoteAddr string) string {
+	host, _, err := net.SplitHostPort(remoteAddr)
+	if err == nil && host != "" {
+		return host
+	}
+	return remoteAddr
+}
diff --git a/auth/policy.go b/auth/policy.go
new file mode 100644
index 0000000..2508fc9
--- /dev/null
+++ b/auth/policy.go
@@ -0,0 +1,66 @@
+package auth
+
+import (
+	"fs/models"
+	"strings"
+)
+
+func isAllowed(policy *models.AuthPolicy, target RequestTarget) bool {
+	if policy == nil {
+
return false
+	}
+
+	// Deny overrides: any matching deny statement short-circuits to false;
+	// otherwise at least one matching allow is required.
+	allowed := false
+	for _, stmt := range policy.Statements {
+		if !statementMatches(stmt, target) {
+			continue
+		}
+		effect := strings.ToLower(strings.TrimSpace(stmt.Effect))
+		if effect == "deny" {
+			return false
+		}
+		if effect == "allow" {
+			allowed = true
+		}
+	}
+	return allowed
+}
+
+// statementMatches reports whether a single statement covers the target.
+// NOTE(review): when target.Key is empty (bucket-level operations) the
+// statement's Prefix is not consulted — confirm per-key checks happen
+// elsewhere for multi-object delete.
+func statementMatches(stmt models.AuthPolicyStatement, target RequestTarget) bool {
+	if !actionMatches(stmt.Actions, target.Action) {
+		return false
+	}
+	if !bucketMatches(stmt.Bucket, target.Bucket) {
+		return false
+	}
+	if target.Key == "" {
+		return true
+	}
+
+	prefix := strings.TrimSpace(stmt.Prefix)
+	if prefix == "" || prefix == "*" {
+		return true
+	}
+	return strings.HasPrefix(target.Key, prefix)
+}
+
+// actionMatches accepts the wildcards "*" and "s3:*" and otherwise compares
+// case-insensitively against the resolved action.
+func actionMatches(actions []string, action Action) bool {
+	if len(actions) == 0 {
+		return false
+	}
+	for _, current := range actions {
+		normalized := strings.TrimSpace(current)
+		if normalized == "*" || normalized == "s3:*" || strings.EqualFold(normalized, string(action)) {
+			return true
+		}
+	}
+	return false
+}
+
+// bucketMatches treats empty and "*" patterns as match-all; anything else is
+// an exact bucket name.
+func bucketMatches(pattern, bucket string) bool {
+	pattern = strings.TrimSpace(pattern)
+	if pattern == "" || pattern == "*" {
+		return true
+	}
+	return pattern == bucket
+}
diff --git a/auth/service.go b/auth/service.go
new file mode 100644
index 0000000..8ed0587
--- /dev/null
+++ b/auth/service.go
@@ -0,0 +1,554 @@
+package auth
+
+import (
+	"crypto/rand"
+	"encoding/base64"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"fs/metadata"
+	"fs/models"
+	"net/http"
+	"regexp"
+	"strings"
+	"time"
+)
+
+type Store interface {
+	GetAuthIdentity(accessKeyID string) (*models.AuthIdentity, error)
+	PutAuthIdentity(identity *models.AuthIdentity) error
+	DeleteAuthIdentity(accessKeyID string) error
+	ListAuthIdentities(limit int, after string) ([]models.AuthIdentity, string, error)
+	GetAuthPolicy(accessKeyID string) (*models.AuthPolicy, error)
+	PutAuthPolicy(policy *models.AuthPolicy) error
+	DeleteAuthPolicy(accessKeyID string) error
+}
+
+type
CreateUserInput struct { + AccessKeyID string + SecretKey string + Status string + Policy models.AuthPolicy +} + +type UserSummary struct { + AccessKeyID string + Status string + CreatedAt int64 + UpdatedAt int64 +} + +type UserDetails struct { + AccessKeyID string + Status string + CreatedAt int64 + UpdatedAt int64 + Policy models.AuthPolicy +} + +type CreateUserResult struct { + UserDetails + SecretKey string +} + +var validAccessKeyID = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9._-]{2,127}$`) + +type Service struct { + cfg Config + store Store + masterKey []byte + now func() time.Time +} + +func NewService(cfg Config, store Store) (*Service, error) { + if store == nil { + return nil, errors.New("auth store is required") + } + + svc := &Service{ + cfg: cfg, + store: store, + now: func() time.Time { return time.Now().UTC() }, + } + if !cfg.Enabled { + return svc, nil + } + + if strings.TrimSpace(cfg.MasterKey) == "" { + return nil, ErrMasterKeyRequired + } + masterKey, err := decodeMasterKey(cfg.MasterKey) + if err != nil { + return nil, err + } + svc.masterKey = masterKey + return svc, nil +} + +func (s *Service) Config() Config { + return s.cfg +} + +func (s *Service) EnsureBootstrap() error { + if !s.cfg.Enabled { + return nil + } + accessKey := strings.TrimSpace(s.cfg.BootstrapAccessKey) + secret := strings.TrimSpace(s.cfg.BootstrapSecretKey) + if accessKey == "" || secret == "" { + return nil + } + + if len(accessKey) < 3 { + return errors.New("bootstrap access key must be at least 3 characters") + } + if len(secret) < 8 { + return errors.New("bootstrap secret key must be at least 8 characters") + } + + now := s.now().Unix() + ciphertext, nonce, err := encryptSecret(s.masterKey, accessKey, secret) + if err != nil { + return err + } + identity := &models.AuthIdentity{ + AccessKeyID: accessKey, + SecretEnc: ciphertext, + SecretNonce: nonce, + EncAlg: "AES-256-GCM", + KeyVersion: "v1", + Status: "active", + CreatedAt: now, + UpdatedAt: now, + } + if existing, 
err := s.store.GetAuthIdentity(accessKey); err == nil && existing != nil { + identity.CreatedAt = existing.CreatedAt + } + if err := s.store.PutAuthIdentity(identity); err != nil { + return err + } + + policy := defaultBootstrapPolicy(accessKey) + if strings.TrimSpace(s.cfg.BootstrapPolicy) != "" { + parsed, err := parsePolicyJSON(s.cfg.BootstrapPolicy) + if err != nil { + return err + } + policy = parsed + policy.Principal = accessKey + } + return s.store.PutAuthPolicy(policy) +} + +func (s *Service) AuthenticateRequest(r *http.Request) (RequestContext, error) { + if !s.cfg.Enabled { + return RequestContext{Authenticated: false, AuthType: "disabled"}, nil + } + input, err := parseSigV4(r) + if err != nil { + return RequestContext{}, err + } + + if err := validateSigV4Input(s.now(), s.cfg, input); err != nil { + return RequestContext{}, err + } + + identity, err := s.store.GetAuthIdentity(input.AccessKeyID) + if err != nil { + return RequestContext{}, ErrInvalidAccessKeyID + } + if !strings.EqualFold(identity.Status, "active") { + return RequestContext{}, ErrCredentialDisabled + } + + secret, err := decryptSecret(s.masterKey, identity.AccessKeyID, identity.SecretEnc, identity.SecretNonce) + if err != nil { + return RequestContext{}, ErrSignatureDoesNotMatch + } + ok, err := signatureMatches(secret, r, input) + if err != nil { + return RequestContext{}, err + } + if !ok { + return RequestContext{}, ErrSignatureDoesNotMatch + } + + authType := "sigv4-header" + if input.Presigned { + authType = "sigv4-presign" + } + + if strings.HasPrefix(r.URL.Path, "/_admin/") { + return RequestContext{ + Authenticated: true, + AccessKeyID: identity.AccessKeyID, + AuthType: authType, + }, nil + } + + policy, err := s.store.GetAuthPolicy(identity.AccessKeyID) + if err != nil { + return RequestContext{}, ErrAccessDenied + } + target := resolveTarget(r) + if target.Action == "" { + return RequestContext{}, ErrAccessDenied + } + if !isAllowed(policy, target) { + return RequestContext{}, 
ErrAccessDenied + } + + return RequestContext{ + Authenticated: true, + AccessKeyID: identity.AccessKeyID, + AuthType: authType, + }, nil +} + +func (s *Service) CreateUser(input CreateUserInput) (*CreateUserResult, error) { + if !s.cfg.Enabled { + return nil, ErrAuthNotEnabled + } + + accessKeyID := strings.TrimSpace(input.AccessKeyID) + if !validAccessKeyID.MatchString(accessKeyID) { + return nil, fmt.Errorf("%w: invalid access key id", ErrInvalidUserInput) + } + + secretKey := strings.TrimSpace(input.SecretKey) + if secretKey == "" { + generated, err := generateSecretKey(32) + if err != nil { + return nil, err + } + secretKey = generated + } + if len(secretKey) < 8 { + return nil, fmt.Errorf("%w: secret key must be at least 8 characters", ErrInvalidUserInput) + } + + status := normalizeUserStatus(input.Status) + if status == "" { + return nil, fmt.Errorf("%w: status must be active or disabled", ErrInvalidUserInput) + } + + policy, err := normalizePolicy(input.Policy, accessKeyID) + if err != nil { + return nil, err + } + + existing, err := s.store.GetAuthIdentity(accessKeyID) + if err == nil && existing != nil { + return nil, ErrUserAlreadyExists + } + if err != nil && !errors.Is(err, metadata.ErrAuthIdentityNotFound) { + return nil, err + } + + now := s.now().Unix() + ciphertext, nonce, err := encryptSecret(s.masterKey, accessKeyID, secretKey) + if err != nil { + return nil, err + } + identity := &models.AuthIdentity{ + AccessKeyID: accessKeyID, + SecretEnc: ciphertext, + SecretNonce: nonce, + EncAlg: "AES-256-GCM", + KeyVersion: "v1", + Status: status, + CreatedAt: now, + UpdatedAt: now, + } + if err := s.store.PutAuthIdentity(identity); err != nil { + return nil, err + } + if err := s.store.PutAuthPolicy(&policy); err != nil { + return nil, err + } + + return &CreateUserResult{ + UserDetails: UserDetails{ + AccessKeyID: accessKeyID, + Status: status, + CreatedAt: now, + UpdatedAt: now, + Policy: policy, + }, + SecretKey: secretKey, + }, nil +} + +func (s 
*Service) ListUsers(limit int, cursor string) ([]UserSummary, string, error) { + if !s.cfg.Enabled { + return nil, "", ErrAuthNotEnabled + } + if limit <= 0 { + limit = 100 + } + if limit > 1000 { + limit = 1000 + } + + identities, nextCursor, err := s.store.ListAuthIdentities(limit, cursor) + if err != nil { + return nil, "", err + } + users := make([]UserSummary, 0, len(identities)) + for _, identity := range identities { + users = append(users, UserSummary{ + AccessKeyID: identity.AccessKeyID, + Status: normalizeUserStatus(identity.Status), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + }) + } + return users, nextCursor, nil +} + +func (s *Service) GetUser(accessKeyID string) (*UserDetails, error) { + if !s.cfg.Enabled { + return nil, ErrAuthNotEnabled + } + accessKeyID = strings.TrimSpace(accessKeyID) + if accessKeyID == "" { + return nil, fmt.Errorf("%w: access key id is required", ErrInvalidUserInput) + } + + identity, err := s.store.GetAuthIdentity(accessKeyID) + if err != nil { + if errors.Is(err, metadata.ErrAuthIdentityNotFound) { + return nil, ErrUserNotFound + } + return nil, err + } + policy, err := s.store.GetAuthPolicy(accessKeyID) + if err != nil { + if errors.Is(err, metadata.ErrAuthPolicyNotFound) { + return nil, ErrUserNotFound + } + return nil, err + } + + return &UserDetails{ + AccessKeyID: identity.AccessKeyID, + Status: normalizeUserStatus(identity.Status), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + Policy: *policy, + }, nil +} + +func (s *Service) DeleteUser(accessKeyID string) error { + if !s.cfg.Enabled { + return ErrAuthNotEnabled + } + accessKeyID = strings.TrimSpace(accessKeyID) + if !validAccessKeyID.MatchString(accessKeyID) { + return fmt.Errorf("%w: invalid access key id", ErrInvalidUserInput) + } + + bootstrap := strings.TrimSpace(s.cfg.BootstrapAccessKey) + if bootstrap != "" && accessKeyID == bootstrap { + return fmt.Errorf("%w: bootstrap user cannot be deleted", ErrInvalidUserInput) + 
} + + if _, err := s.store.GetAuthIdentity(accessKeyID); err != nil { + if errors.Is(err, metadata.ErrAuthIdentityNotFound) { + return ErrUserNotFound + } + return err + } + + if err := s.store.DeleteAuthIdentity(accessKeyID); err != nil { + if errors.Is(err, metadata.ErrAuthIdentityNotFound) { + return ErrUserNotFound + } + return err + } + + if err := s.store.DeleteAuthPolicy(accessKeyID); err != nil && !errors.Is(err, metadata.ErrAuthPolicyNotFound) { + return err + } + return nil +} + +func (s *Service) SetUserPolicy(accessKeyID string, policy models.AuthPolicy) (*UserDetails, error) { + if !s.cfg.Enabled { + return nil, ErrAuthNotEnabled + } + accessKeyID = strings.TrimSpace(accessKeyID) + if !validAccessKeyID.MatchString(accessKeyID) { + return nil, fmt.Errorf("%w: invalid access key id", ErrInvalidUserInput) + } + + identity, err := s.store.GetAuthIdentity(accessKeyID) + if err != nil { + if errors.Is(err, metadata.ErrAuthIdentityNotFound) { + return nil, ErrUserNotFound + } + return nil, err + } + + normalizedPolicy, err := normalizePolicy(policy, accessKeyID) + if err != nil { + return nil, err + } + if err := s.store.PutAuthPolicy(&normalizedPolicy); err != nil { + return nil, err + } + + identity.UpdatedAt = s.now().Unix() + if err := s.store.PutAuthIdentity(identity); err != nil { + return nil, err + } + + return &UserDetails{ + AccessKeyID: identity.AccessKeyID, + Status: normalizeUserStatus(identity.Status), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + Policy: normalizedPolicy, + }, nil +} + +func (s *Service) SetUserStatus(accessKeyID, status string) (*UserDetails, error) { + if !s.cfg.Enabled { + return nil, ErrAuthNotEnabled + } + accessKeyID = strings.TrimSpace(accessKeyID) + if !validAccessKeyID.MatchString(accessKeyID) { + return nil, fmt.Errorf("%w: invalid access key id", ErrInvalidUserInput) + } + + status = strings.TrimSpace(status) + if status == "" { + return nil, fmt.Errorf("%w: status is required", 
ErrInvalidUserInput) + } + normalizedStatus := normalizeUserStatus(status) + if normalizedStatus == "" { + return nil, fmt.Errorf("%w: status must be active or disabled", ErrInvalidUserInput) + } + + identity, err := s.store.GetAuthIdentity(accessKeyID) + if err != nil { + if errors.Is(err, metadata.ErrAuthIdentityNotFound) { + return nil, ErrUserNotFound + } + return nil, err + } + identity.Status = normalizedStatus + identity.UpdatedAt = s.now().Unix() + if err := s.store.PutAuthIdentity(identity); err != nil { + return nil, err + } + + policy, err := s.store.GetAuthPolicy(accessKeyID) + if err != nil { + if errors.Is(err, metadata.ErrAuthPolicyNotFound) { + return nil, ErrUserNotFound + } + return nil, err + } + + return &UserDetails{ + AccessKeyID: identity.AccessKeyID, + Status: normalizeUserStatus(identity.Status), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + Policy: *policy, + }, nil +} + +func parsePolicyJSON(raw string) (*models.AuthPolicy, error) { + policy := models.AuthPolicy{} + if err := json.Unmarshal([]byte(raw), &policy); err != nil { + return nil, fmt.Errorf("invalid bootstrap policy: %w", err) + } + if len(policy.Statements) == 0 { + return nil, errors.New("bootstrap policy must contain at least one statement") + } + return &policy, nil +} + +func defaultBootstrapPolicy(principal string) *models.AuthPolicy { + return &models.AuthPolicy{ + Principal: principal, + Statements: []models.AuthPolicyStatement{ + { + Effect: "allow", + Actions: []string{"s3:*"}, + Bucket: "*", + Prefix: "*", + }, + }, + } +} + +func generateSecretKey(length int) (string, error) { + if length <= 0 { + length = 32 + } + buf := make([]byte, length) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func normalizeUserStatus(raw string) string { + status := strings.ToLower(strings.TrimSpace(raw)) + if status == "" { + return "active" + } + if status != "active" && status != 
"disabled" { + return "" + } + return status +} + +func normalizePolicy(policy models.AuthPolicy, principal string) (models.AuthPolicy, error) { + if len(policy.Statements) == 0 { + return models.AuthPolicy{}, fmt.Errorf("%w: at least one policy statement is required", ErrInvalidUserInput) + } + + out := models.AuthPolicy{ + Principal: principal, + Statements: make([]models.AuthPolicyStatement, 0, len(policy.Statements)), + } + for _, stmt := range policy.Statements { + effect := strings.ToLower(strings.TrimSpace(stmt.Effect)) + if effect != "allow" && effect != "deny" { + return models.AuthPolicy{}, fmt.Errorf("%w: invalid policy effect %q", ErrInvalidUserInput, stmt.Effect) + } + + actions := make([]string, 0, len(stmt.Actions)) + for _, action := range stmt.Actions { + action = strings.TrimSpace(action) + if action == "" { + continue + } + actions = append(actions, action) + } + if len(actions) == 0 { + return models.AuthPolicy{}, fmt.Errorf("%w: policy statement must include at least one action", ErrInvalidUserInput) + } + + bucket := strings.TrimSpace(stmt.Bucket) + if bucket == "" { + bucket = "*" + } + prefix := strings.TrimSpace(stmt.Prefix) + if prefix == "" { + prefix = "*" + } + + out.Statements = append(out.Statements, models.AuthPolicyStatement{ + Effect: effect, + Actions: actions, + Bucket: bucket, + Prefix: prefix, + }) + } + return out, nil +} diff --git a/auth/service_admin_test.go b/auth/service_admin_test.go new file mode 100644 index 0000000..bb167e8 --- /dev/null +++ b/auth/service_admin_test.go @@ -0,0 +1,223 @@ +package auth + +import ( + "encoding/base64" + "errors" + "fs/metadata" + "fs/models" + "path/filepath" + "testing" +) + +func TestAdminCreateListGetUser(t *testing.T) { + meta, svc := newTestAuthService(t) + + created, err := svc.CreateUser(CreateUserInput{ + AccessKeyID: "backup-user", + Policy: models.AuthPolicy{ + Statements: []models.AuthPolicyStatement{ + { + Effect: "allow", + Actions: []string{"s3:GetObject"}, + Bucket: 
"backup-bucket", + Prefix: "restic/", + }, + }, + }, + }) + if err != nil { + t.Fatalf("CreateUser returned error: %v", err) + } + if created.SecretKey == "" { + t.Fatalf("CreateUser should return generated secret") + } + if created.AccessKeyID != "backup-user" { + t.Fatalf("CreateUser access key mismatch: got %q", created.AccessKeyID) + } + if created.Policy.Principal != "backup-user" { + t.Fatalf("policy principal mismatch: got %q", created.Policy.Principal) + } + + users, nextCursor, err := svc.ListUsers(100, "") + if err != nil { + t.Fatalf("ListUsers returned error: %v", err) + } + if nextCursor != "" { + t.Fatalf("unexpected next cursor: %q", nextCursor) + } + if len(users) != 1 { + t.Fatalf("ListUsers returned %d users, want 1", len(users)) + } + if users[0].AccessKeyID != "backup-user" { + t.Fatalf("ListUsers returned wrong user: %q", users[0].AccessKeyID) + } + + got, err := svc.GetUser("backup-user") + if err != nil { + t.Fatalf("GetUser returned error: %v", err) + } + if got.AccessKeyID != "backup-user" { + t.Fatalf("GetUser access key mismatch: got %q", got.AccessKeyID) + } + if got.Policy.Principal != "backup-user" { + t.Fatalf("GetUser policy principal mismatch: got %q", got.Policy.Principal) + } + if len(got.Policy.Statements) != 1 { + t.Fatalf("GetUser policy statement count = %d, want 1", len(got.Policy.Statements)) + } + + _ = meta +} + +func TestCreateUserDuplicateFails(t *testing.T) { + _, svc := newTestAuthService(t) + + input := CreateUserInput{ + AccessKeyID: "duplicate-user", + SecretKey: "super-secret-1", + Policy: models.AuthPolicy{ + Statements: []models.AuthPolicyStatement{ + {Effect: "allow", Actions: []string{"s3:*"}, Bucket: "*", Prefix: "*"}, + }, + }, + } + if _, err := svc.CreateUser(input); err != nil { + t.Fatalf("first CreateUser returned error: %v", err) + } + if _, err := svc.CreateUser(input); !errors.Is(err, ErrUserAlreadyExists) { + t.Fatalf("second CreateUser error = %v, want ErrUserAlreadyExists", err) + } +} + +func 
TestCreateUserRejectsInvalidAccessKey(t *testing.T) { + _, svc := newTestAuthService(t) + + _, err := svc.CreateUser(CreateUserInput{ + AccessKeyID: "x", + SecretKey: "super-secret-1", + Policy: models.AuthPolicy{ + Statements: []models.AuthPolicyStatement{ + {Effect: "allow", Actions: []string{"s3:*"}, Bucket: "*", Prefix: "*"}, + }, + }, + }) + if !errors.Is(err, ErrInvalidUserInput) { + t.Fatalf("CreateUser error = %v, want ErrInvalidUserInput", err) + } +} + +func TestDeleteUser(t *testing.T) { + _, svc := newTestAuthService(t) + + _, err := svc.CreateUser(CreateUserInput{ + AccessKeyID: "delete-user", + SecretKey: "super-secret-1", + Policy: models.AuthPolicy{ + Statements: []models.AuthPolicyStatement{ + {Effect: "allow", Actions: []string{"s3:*"}, Bucket: "*", Prefix: "*"}, + }, + }, + }) + if err != nil { + t.Fatalf("CreateUser returned error: %v", err) + } + + if err := svc.DeleteUser("delete-user"); err != nil { + t.Fatalf("DeleteUser returned error: %v", err) + } + if _, err := svc.GetUser("delete-user"); !errors.Is(err, ErrUserNotFound) { + t.Fatalf("GetUser after delete error = %v, want ErrUserNotFound", err) + } +} + +func TestDeleteBootstrapUserRejected(t *testing.T) { + _, svc := newTestAuthService(t) + + if err := svc.DeleteUser("root-user"); !errors.Is(err, ErrInvalidUserInput) { + t.Fatalf("DeleteUser bootstrap error = %v, want ErrInvalidUserInput", err) + } +} + +func TestSetUserPolicy(t *testing.T) { + _, svc := newTestAuthService(t) + + _, err := svc.CreateUser(CreateUserInput{ + AccessKeyID: "policy-user", + SecretKey: "super-secret-1", + Policy: models.AuthPolicy{ + Statements: []models.AuthPolicyStatement{ + {Effect: "allow", Actions: []string{"s3:GetObject"}, Bucket: "b1", Prefix: "*"}, + }, + }, + }) + if err != nil { + t.Fatalf("CreateUser returned error: %v", err) + } + + updated, err := svc.SetUserPolicy("policy-user", models.AuthPolicy{ + Statements: []models.AuthPolicyStatement{ + {Effect: "allow", Actions: []string{"s3:PutObject"}, 
Bucket: "b2", Prefix: "p/"}, + }, + }) + if err != nil { + t.Fatalf("SetUserPolicy returned error: %v", err) + } + if len(updated.Policy.Statements) != 1 || updated.Policy.Statements[0].Actions[0] != "s3:PutObject" { + t.Fatalf("SetUserPolicy did not apply new policy: %+v", updated.Policy) + } +} + +func TestSetUserStatus(t *testing.T) { + _, svc := newTestAuthService(t) + + _, err := svc.CreateUser(CreateUserInput{ + AccessKeyID: "status-user", + SecretKey: "super-secret-1", + Policy: models.AuthPolicy{ + Statements: []models.AuthPolicyStatement{ + {Effect: "allow", Actions: []string{"s3:*"}, Bucket: "*", Prefix: "*"}, + }, + }, + }) + if err != nil { + t.Fatalf("CreateUser returned error: %v", err) + } + + updated, err := svc.SetUserStatus("status-user", "disabled") + if err != nil { + t.Fatalf("SetUserStatus returned error: %v", err) + } + if updated.Status != "disabled" { + t.Fatalf("SetUserStatus status = %q, want disabled", updated.Status) + } +} + +func newTestAuthService(t *testing.T) (*metadata.MetadataHandler, *Service) { + t.Helper() + + dbPath := filepath.Join(t.TempDir(), "metadata.db") + meta, err := metadata.NewMetadataHandler(dbPath) + if err != nil { + t.Fatalf("NewMetadataHandler returned error: %v", err) + } + t.Cleanup(func() { + _ = meta.Close() + }) + + masterKey := base64.StdEncoding.EncodeToString(make([]byte, 32)) + cfg := ConfigFromValues( + true, + "us-east-1", + 0, + 0, + masterKey, + "root-user", + "root-secret-123", + "", + ) + svc, err := NewService(cfg, meta) + if err != nil { + t.Fatalf("NewService returned error: %v", err) + } + return meta, svc +} diff --git a/auth/sigv4.go b/auth/sigv4.go new file mode 100644 index 0000000..d79b628 --- /dev/null +++ b/auth/sigv4.go @@ -0,0 +1,372 @@ +package auth + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "net/url" + "sort" + "strconv" + "strings" + "time" +) + +const ( + sigV4Algorithm = "AWS4-HMAC-SHA256" +) + +type sigV4Input struct { + AccessKeyID 
// sigV4Input is the normalized result of parsing an AWS SigV4
// authentication attempt, from either the Authorization-header form or
// the presigned-URL query form.
type sigV4Input struct {
	AccessKeyID      string
	Date             string // credential-scope date (YYYYMMDD)
	Region           string
	Service          string
	Scope            string // "<date>/<region>/<service>/aws4_request"
	SignedHeaders    []string
	SignedHeadersRaw string // lowercased ";"-joined signed-header list
	SignatureHex     string // lowercased hex signature supplied by the client
	AmzDate          string // full request timestamp (x-amz-date / X-Amz-Date)
	ExpiresSeconds   int    // presigned form only (X-Amz-Expires)
	Presigned        bool
}

// parseSigV4 dispatches to the presigned-query parser when the request
// carries X-Amz-Algorithm=AWS4-HMAC-SHA256, otherwise to the
// Authorization-header parser.
func parseSigV4(r *http.Request) (*sigV4Input, error) {
	if r == nil {
		return nil, fmt.Errorf("%w: nil request", ErrAuthorizationHeaderMalformed)
	}
	if strings.EqualFold(r.URL.Query().Get("X-Amz-Algorithm"), sigV4Algorithm) {
		return parsePresignedSigV4(r)
	}
	return parseHeaderSigV4(r)
}

// parseHeaderSigV4 parses the "Authorization: AWS4-HMAC-SHA256 ..." header
// form. It requires the Credential, SignedHeaders and Signature parameters
// plus an x-amz-date header.
//
// NOTE(review): a request carrying only a Date header (no x-amz-date) is
// rejected here — confirm whether that SigV4 fallback must be supported.
func parseHeaderSigV4(r *http.Request) (*sigV4Input, error) {
	header := strings.TrimSpace(r.Header.Get("Authorization"))
	if header == "" {
		return nil, ErrNoAuthCredentials
	}
	if !strings.HasPrefix(header, sigV4Algorithm+" ") {
		return nil, fmt.Errorf("%w: unsupported authorization algorithm", ErrUnsupportedAuthScheme)
	}

	params := parseAuthorizationParams(strings.TrimSpace(strings.TrimPrefix(header, sigV4Algorithm)))
	credentialRaw := params["Credential"]
	signedHeadersRaw := params["SignedHeaders"]
	signatureHex := params["Signature"]
	if credentialRaw == "" || signedHeadersRaw == "" || signatureHex == "" {
		return nil, fmt.Errorf("%w: missing required authorization fields", ErrAuthorizationHeaderMalformed)
	}

	accessKeyID, date, region, service, scope, err := parseCredential(credentialRaw)
	if err != nil {
		return nil, err
	}

	amzDate := strings.TrimSpace(r.Header.Get("x-amz-date"))
	if amzDate == "" {
		return nil, fmt.Errorf("%w: x-amz-date is required", ErrAuthorizationHeaderMalformed)
	}
	signedHeaders := splitSignedHeaders(signedHeadersRaw)
	if len(signedHeaders) == 0 {
		return nil, fmt.Errorf("%w: signed headers are required", ErrAuthorizationHeaderMalformed)
	}

	return &sigV4Input{
		AccessKeyID:      accessKeyID,
		Date:             date,
		Region:           region,
		Service:          service,
		Scope:            scope,
		SignedHeaders:    signedHeaders,
		SignedHeadersRaw: strings.ToLower(strings.TrimSpace(signedHeadersRaw)),
		SignatureHex:     strings.ToLower(strings.TrimSpace(signatureHex)),
		AmzDate:          amzDate,
		Presigned:        false,
	}, nil
}

// parsePresignedSigV4 parses the presigned-URL query form
// (X-Amz-Algorithm, X-Amz-Credential, X-Amz-SignedHeaders,
// X-Amz-Signature, X-Amz-Date, X-Amz-Expires).
//
// NOTE(review): X-Amz-Expires == 0 is accepted here although AWS requires
// a positive value; expiry is still enforced during validation, so a
// zero-expiry URL is effectively always expired.
func parsePresignedSigV4(r *http.Request) (*sigV4Input, error) {
	query := r.URL.Query()
	if !strings.EqualFold(query.Get("X-Amz-Algorithm"), sigV4Algorithm) {
		return nil, fmt.Errorf("%w: invalid X-Amz-Algorithm", ErrInvalidPresign)
	}

	credentialRaw := strings.TrimSpace(query.Get("X-Amz-Credential"))
	signedHeadersRaw := strings.TrimSpace(query.Get("X-Amz-SignedHeaders"))
	signatureHex := strings.TrimSpace(query.Get("X-Amz-Signature"))
	amzDate := strings.TrimSpace(query.Get("X-Amz-Date"))
	expiresRaw := strings.TrimSpace(query.Get("X-Amz-Expires"))
	if credentialRaw == "" || signedHeadersRaw == "" || signatureHex == "" || amzDate == "" || expiresRaw == "" {
		return nil, fmt.Errorf("%w: missing presigned query fields", ErrInvalidPresign)
	}
	expires, err := strconv.Atoi(expiresRaw)
	if err != nil || expires < 0 {
		return nil, fmt.Errorf("%w: invalid X-Amz-Expires", ErrInvalidPresign)
	}

	accessKeyID, date, region, service, scope, err := parseCredential(credentialRaw)
	if err != nil {
		return nil, err
	}
	signedHeaders := splitSignedHeaders(signedHeadersRaw)
	if len(signedHeaders) == 0 {
		return nil, fmt.Errorf("%w: signed headers are required", ErrInvalidPresign)
	}

	return &sigV4Input{
		AccessKeyID:      accessKeyID,
		Date:             date,
		Region:           region,
		Service:          service,
		Scope:            scope,
		SignedHeaders:    signedHeaders,
		SignedHeadersRaw: strings.ToLower(strings.TrimSpace(signedHeadersRaw)),
		SignatureHex:     strings.ToLower(signatureHex),
		AmzDate:          amzDate,
		ExpiresSeconds:   expires,
		Presigned:        true,
	}, nil
}
// parseCredential splits a SigV4 credential string of the form
// "ACCESSKEY/date/region/service/aws4_request" into its components and
// also returns the joined credential scope (everything after the key).
func parseCredential(raw string) (accessKeyID string, date string, region string, service string, scope string, err error) {
	parts := strings.Split(strings.TrimSpace(raw), "/")
	if len(parts) != 5 {
		return "", "", "", "", "", fmt.Errorf("%w: invalid credential scope", ErrAuthorizationHeaderMalformed)
	}
	accessKeyID = strings.TrimSpace(parts[0])
	date = strings.TrimSpace(parts[1])
	region = strings.TrimSpace(parts[2])
	service = strings.TrimSpace(parts[3])
	terminal := strings.TrimSpace(parts[4])
	// The terminal element is fixed by the SigV4 spec.
	if accessKeyID == "" || date == "" || region == "" || service == "" || terminal != "aws4_request" {
		return "", "", "", "", "", fmt.Errorf("%w: invalid credential scope", ErrAuthorizationHeaderMalformed)
	}
	scope = strings.Join(parts[1:], "/")
	return accessKeyID, date, region, service, scope, nil
}

// splitSignedHeaders lowercases and splits a ";"-separated signed-header
// list, dropping empty entries. Returns nil when nothing remains.
func splitSignedHeaders(raw string) []string {
	raw = strings.ToLower(strings.TrimSpace(raw))
	if raw == "" {
		return nil
	}
	parts := strings.Split(raw, ";")
	headers := make([]string, 0, len(parts))
	for _, current := range parts {
		current = strings.TrimSpace(current)
		if current == "" {
			continue
		}
		headers = append(headers, current)
	}
	return headers
}

// parseAuthorizationParams parses the comma-separated "Key=Value" list
// that follows the algorithm name in a SigV4 Authorization header.
// Tokens without '=' are silently skipped.
func parseAuthorizationParams(raw string) map[string]string {
	params := make(map[string]string)
	raw = strings.TrimSpace(raw)
	// Redundant after TrimSpace above; kept for parity with the original.
	raw = strings.TrimPrefix(raw, " ")
	for _, token := range strings.Split(raw, ",") {
		token = strings.TrimSpace(token)
		key, value, found := strings.Cut(token, "=")
		if !found {
			continue
		}
		params[strings.TrimSpace(key)] = strings.TrimSpace(value)
	}
	return params
}
return ErrRequestTimeTooSkewed + } + + if input.Presigned { + if input.ExpiresSeconds > int(cfg.MaxPresignDuration.Seconds()) { + return fmt.Errorf("%w: presign expires too large", ErrInvalidPresign) + } + expiresAt := requestTime.Add(time.Duration(input.ExpiresSeconds) * time.Second) + if now.After(expiresAt) { + return ErrExpiredToken + } + } + return nil +} + +func signatureMatches(secret string, r *http.Request, input *sigV4Input) (bool, error) { + payloadHash := resolvePayloadHash(r, input.Presigned) + canonicalRequest, err := buildCanonicalRequest(r, input.SignedHeaders, payloadHash, input.Presigned) + if err != nil { + return false, err + } + stringToSign := buildStringToSign(input.AmzDate, input.Scope, canonicalRequest) + signingKey := deriveSigningKey(secret, input.Date, input.Region, input.Service) + expectedSig := hex.EncodeToString(hmacSHA256(signingKey, stringToSign)) + return hmac.Equal([]byte(expectedSig), []byte(input.SignatureHex)), nil +} + +func resolvePayloadHash(r *http.Request, presigned bool) string { + if presigned { + return "UNSIGNED-PAYLOAD" + } + hash := strings.TrimSpace(r.Header.Get("x-amz-content-sha256")) + if hash == "" { + return "UNSIGNED-PAYLOAD" + } + return hash +} + +func buildCanonicalRequest(r *http.Request, signedHeaders []string, payloadHash string, presigned bool) (string, error) { + canonicalURI := canonicalPath(r.URL) + canonicalQuery := canonicalQueryString(r.URL.RawQuery, presigned) + canonicalHeaders, signedHeadersRaw, err := canonicalHeadersForRequest(r, signedHeaders) + if err != nil { + return "", err + } + + return strings.Join([]string{ + r.Method, + canonicalURI, + canonicalQuery, + canonicalHeaders, + signedHeadersRaw, + payloadHash, + }, "\n"), nil +} + +func canonicalPath(u *url.URL) string { + if u == nil { + return "/" + } + path := u.EscapedPath() + if path == "" { + return "/" + } + return path +} + +type queryPair struct { + Key string + Value string +} + +func canonicalQueryString(rawQuery string, 
// queryPair is one key/value entry of the canonical query string.
type queryPair struct {
	Key   string
	Value string
}

// canonicalQueryString renders rawQuery in AWS SigV4 canonical form: each
// key and value percent-encoded per the SigV4 rules and the pairs sorted
// by key, then by value. For presigned requests the X-Amz-Signature
// parameter is excluded from the canonical form.
func canonicalQueryString(rawQuery string, presigned bool) string {
	if rawQuery == "" {
		return ""
	}
	// Best-effort parse: a malformed query yields whatever ParseQuery
	// managed to extract, matching the lenient behavior callers expect.
	parsed, _ := url.ParseQuery(rawQuery)

	var pairs []queryPair
	for name, list := range parsed {
		if presigned && strings.EqualFold(name, "X-Amz-Signature") {
			continue
		}
		if len(list) == 0 {
			pairs = append(pairs, queryPair{Key: name})
			continue
		}
		for _, item := range list {
			pairs = append(pairs, queryPair{Key: name, Value: item})
		}
	}

	sort.Slice(pairs, func(a, b int) bool {
		if pairs[a].Key != pairs[b].Key {
			return pairs[a].Key < pairs[b].Key
		}
		return pairs[a].Value < pairs[b].Value
	})

	parts := make([]string, len(pairs))
	for i, pair := range pairs {
		parts[i] = awsEncodeQuery(pair.Key) + "=" + awsEncodeQuery(pair.Value)
	}
	return strings.Join(parts, "&")
}

// awsEncodeQuery percent-encodes value the way SigV4 requires: space as
// %20 (never '+'), '*' as %2A, and '~' left unescaped.
func awsEncodeQuery(value string) string {
	out := url.QueryEscape(value)
	out = strings.ReplaceAll(out, "+", "%20")
	out = strings.ReplaceAll(out, "*", "%2A")
	return strings.ReplaceAll(out, "%7E", "~")
}
append(lines, headerName+":"+value) + } + + if len(lines) == 0 { + return "", "", fmt.Errorf("%w: no valid signed headers", ErrAuthorizationHeaderMalformed) + } + signedRaw = strings.Join(normalized, ";") + canonical = strings.Join(lines, "\n") + "\n" + return canonical, signedRaw, nil +} + +func normalizeHeaderValue(value string) string { + value = strings.TrimSpace(value) + parts := strings.Fields(value) + return strings.Join(parts, " ") +} + +func buildStringToSign(amzDate string, scope string, canonicalRequest string) string { + canonicalHash := sha256.Sum256([]byte(canonicalRequest)) + return strings.Join([]string{ + sigV4Algorithm, + amzDate, + scope, + hex.EncodeToString(canonicalHash[:]), + }, "\n") +} + +func deriveSigningKey(secret, date, region, service string) []byte { + kDate := hmacSHA256([]byte("AWS4"+secret), date) + kRegion := hmacSHA256(kDate, region) + kService := hmacSHA256(kRegion, service) + return hmacSHA256(kService, "aws4_request") +} + +func hmacSHA256(key []byte, value string) []byte { + mac := hmac.New(sha256.New, key) + _, _ = mac.Write([]byte(value)) + return mac.Sum(nil) +} diff --git a/docs/admin-api-openapi.yaml b/docs/admin-api-openapi.yaml new file mode 100644 index 0000000..fc5f744 --- /dev/null +++ b/docs/admin-api-openapi.yaml @@ -0,0 +1,336 @@ +openapi: 3.1.0 +info: + title: fs Admin API + version: 1.0.0 + description: | + JSON admin API for managing local users and policies. + + Notes: + - Base path is `/_admin/v1`. + - Requests must be AWS SigV4 signed. + - Only the bootstrap access key is authorized for admin endpoints. 
+servers: + - url: http://localhost:2600 + description: Local development +security: + - AwsSigV4: [] +paths: + /_admin/v1/users: + post: + summary: Create user + operationId: createUser + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/CreateUserRequest' + responses: + '201': + description: User created + content: + application/json: + schema: + $ref: '#/components/schemas/UserResponse' + '400': + $ref: '#/components/responses/InvalidRequest' + '403': + $ref: '#/components/responses/Forbidden' + '409': + $ref: '#/components/responses/UserAlreadyExists' + '503': + $ref: '#/components/responses/AuthDisabled' + '500': + $ref: '#/components/responses/InternalError' + get: + summary: List users + operationId: listUsers + parameters: + - name: limit + in: query + required: false + schema: + type: integer + minimum: 1 + maximum: 1000 + default: 100 + - name: cursor + in: query + required: false + schema: + type: string + responses: + '200': + description: User summaries + content: + application/json: + schema: + $ref: '#/components/schemas/UserListResponse' + '400': + $ref: '#/components/responses/InvalidRequest' + '403': + $ref: '#/components/responses/Forbidden' + '503': + $ref: '#/components/responses/AuthDisabled' + '500': + $ref: '#/components/responses/InternalError' + + /_admin/v1/users/{accessKeyId}: + get: + summary: Get user with policy + operationId: getUser + parameters: + - $ref: '#/components/parameters/AccessKeyId' + responses: + '200': + description: User details + content: + application/json: + schema: + $ref: '#/components/schemas/UserResponse' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/UserNotFound' + '503': + $ref: '#/components/responses/AuthDisabled' + '500': + $ref: '#/components/responses/InternalError' + delete: + summary: Delete user + operationId: deleteUser + parameters: + - $ref: '#/components/parameters/AccessKeyId' + responses: + '204': + 
description: User deleted + '400': + $ref: '#/components/responses/InvalidRequest' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/UserNotFound' + '503': + $ref: '#/components/responses/AuthDisabled' + '500': + $ref: '#/components/responses/InternalError' + + /_admin/v1/users/{accessKeyId}/policy: + put: + summary: Replace user policy + operationId: setUserPolicy + parameters: + - $ref: '#/components/parameters/AccessKeyId' + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/SetPolicyRequest' + responses: + '200': + description: User details with updated policy + content: + application/json: + schema: + $ref: '#/components/schemas/UserResponse' + '400': + $ref: '#/components/responses/InvalidRequest' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/UserNotFound' + '503': + $ref: '#/components/responses/AuthDisabled' + '500': + $ref: '#/components/responses/InternalError' + + /_admin/v1/users/{accessKeyId}/status: + put: + summary: Set user status + operationId: setUserStatus + parameters: + - $ref: '#/components/parameters/AccessKeyId' + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/SetStatusRequest' + responses: + '200': + description: User details with updated status + content: + application/json: + schema: + $ref: '#/components/schemas/UserResponse' + '400': + $ref: '#/components/responses/InvalidRequest' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/UserNotFound' + '503': + $ref: '#/components/responses/AuthDisabled' + '500': + $ref: '#/components/responses/InternalError' + +components: + securitySchemes: + AwsSigV4: + type: apiKey + in: header + name: Authorization + description: | + AWS Signature Version 4 headers are required (`Authorization`, `x-amz-date`, + and for payload-signed requests `x-amz-content-sha256`). 
+ Only bootstrap credential is authorized for admin endpoints. + parameters: + AccessKeyId: + name: accessKeyId + in: path + required: true + schema: + type: string + description: User access key ID + responses: + InvalidRequest: + description: Invalid request input + content: + application/json: + schema: + $ref: '#/components/schemas/AdminError' + Forbidden: + description: Authenticated but not allowed + content: + application/json: + schema: + $ref: '#/components/schemas/AdminError' + UserAlreadyExists: + description: User already exists + content: + application/json: + schema: + $ref: '#/components/schemas/AdminError' + UserNotFound: + description: User not found + content: + application/json: + schema: + $ref: '#/components/schemas/AdminError' + AuthDisabled: + description: Authentication subsystem disabled + content: + application/json: + schema: + $ref: '#/components/schemas/AdminError' + InternalError: + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/AdminError' + schemas: + AdminError: + type: object + properties: + code: + type: string + message: + type: string + requestId: + type: string + required: [code, message] + PolicyStatement: + type: object + properties: + effect: + type: string + enum: [allow, deny] + actions: + type: array + items: + type: string + minItems: 1 + bucket: + type: string + default: "*" + prefix: + type: string + default: "*" + required: [effect, actions] + Policy: + type: object + properties: + principal: + type: string + description: Server-managed; overwritten with target access key ID. + statements: + type: array + items: + $ref: '#/components/schemas/PolicyStatement' + minItems: 1 + required: [statements] + CreateUserRequest: + type: object + properties: + accessKeyId: + type: string + secretKey: + type: string + description: If omitted, server generates one. 
+ status: + type: string + enum: [active, disabled] + default: active + policy: + $ref: '#/components/schemas/Policy' + required: [accessKeyId, policy] + SetPolicyRequest: + type: object + properties: + policy: + $ref: '#/components/schemas/Policy' + required: [policy] + SetStatusRequest: + type: object + properties: + status: + type: string + enum: [active, disabled] + required: [status] + UserListItem: + type: object + properties: + accessKeyId: + type: string + status: + type: string + enum: [active, disabled] + createdAt: + type: integer + format: int64 + updatedAt: + type: integer + format: int64 + required: [accessKeyId, status, createdAt, updatedAt] + UserListResponse: + type: object + properties: + items: + type: array + items: + $ref: '#/components/schemas/UserListItem' + nextCursor: + type: string + required: [items] + UserResponse: + allOf: + - $ref: '#/components/schemas/UserListItem' + - type: object + properties: + policy: + $ref: '#/components/schemas/Policy' + secretKey: + type: string + description: Returned only on create. diff --git a/docs/s3-compatibility.md b/docs/s3-compatibility.md new file mode 100644 index 0000000..0850239 --- /dev/null +++ b/docs/s3-compatibility.md @@ -0,0 +1,53 @@ +# S3 Compatibility Matrix + +This project is S3-compatible for a focused subset of operations. 
+ +## Implemented + +### Service and account +- `GET /` list buckets + +### Bucket +- `PUT /{bucket}` create bucket +- `HEAD /{bucket}` head bucket +- `DELETE /{bucket}` delete bucket (must be empty) +- `GET /{bucket}?list-type=2...` list objects v2 +- `GET /{bucket}?location` get bucket location +- `POST /{bucket}?delete` delete multiple objects + +### Object +- `PUT /{bucket}/{key}` put object +- `GET /{bucket}/{key}` get object +- `HEAD /{bucket}/{key}` head object +- `DELETE /{bucket}/{key}` delete object +- `GET /{bucket}/{key}` supports single-range requests + +### Multipart upload +- `POST /{bucket}/{key}?uploads` initiate +- `PUT /{bucket}/{key}?uploadId=...&partNumber=N` upload part +- `GET /{bucket}/{key}?uploadId=...` list parts +- `POST /{bucket}/{key}?uploadId=...` complete +- `DELETE /{bucket}/{key}?uploadId=...` abort + +### Authentication +- AWS SigV4 header auth +- AWS SigV4 presigned query auth +- `aws-chunked` payload decode for streaming uploads + +## Partially Implemented / Differences +- Exact parity with AWS S3 error codes/headers is still evolving. +- Some S3 edge-case behaviors may differ (especially uncommon query/header combinations). +- Admin API is custom JSON (`/_admin/v1/*`). 
+ +## Not Implemented (Current) +- Bucket versioning +- Lifecycle rules +- Replication +- Object lock / legal hold / retention +- SSE-S3 / SSE-KMS / SSE-C +- ACL APIs and IAM-compatible policy APIs +- STS / temporary credentials +- Event notifications +- Tagging APIs +- CORS APIs +- Website hosting APIs diff --git a/logging/logging.go b/logging/logging.go index 28e1985..3600da7 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -1,6 +1,7 @@ package logging import ( + "fs/metrics" "log/slog" "net/http" "os" @@ -9,6 +10,7 @@ import ( "strings" "time" + "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" ) @@ -86,6 +88,11 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) + op := metricOperationLabel(r) + metrics.Default.IncHTTPInFlightOp(op) + defer func() { + metrics.Default.DecHTTPInFlightOp(op) + }() requestID := middleware.GetReqID(r.Context()) if requestID != "" { ww.Header().Set("x-amz-request-id", requestID) @@ -93,15 +100,18 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han next.ServeHTTP(ww, r) - if !cfg.Audit && !cfg.DebugMode { - return - } - elapsed := time.Since(start) status := ww.Status() if status == 0 { status = http.StatusOK } + route := metricRouteLabel(r) + metrics.Default.ObserveHTTPRequestDetailed(r.Method, route, op, status, elapsed, ww.BytesWritten()) + + if !cfg.Audit && !cfg.DebugMode { + return + } + attrs := []any{ "method", r.Method, "path", r.URL.Path, @@ -131,6 +141,46 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han } } +func metricRouteLabel(r *http.Request) string { + if r == nil || r.URL == nil { + return "/unknown" + } + + if routeCtx := chi.RouteContext(r.Context()); routeCtx != nil { + if pattern := strings.TrimSpace(routeCtx.RoutePattern()); pattern != "" { + 
// metricRouteLabel maps a request to a low-cardinality route label for
// metrics. It prefers the chi route pattern when available; otherwise it
// falls back to "/", "/healthz", "/metrics", "/{bucket}" (single path
// segment) or "/{bucket}/*" (deeper paths), so object keys never leak
// into label values.
func metricRouteLabel(r *http.Request) string {
	if r == nil || r.URL == nil {
		return "/unknown"
	}

	if routeCtx := chi.RouteContext(r.Context()); routeCtx != nil {
		if pattern := strings.TrimSpace(routeCtx.RoutePattern()); pattern != "" {
			return pattern
		}
	}

	path := strings.TrimSpace(r.URL.Path)
	if path == "" || path == "/" {
		return "/"
	}
	if path == "/healthz" || path == "/metrics" {
		return path
	}

	trimmed := strings.Trim(path, "/")
	if trimmed == "" {
		return "/"
	}
	if !strings.Contains(trimmed, "/") {
		return "/{bucket}"
	}
	return "/{bucket}/*"
}

// metricOperationLabel derives the metrics operation label for a request,
// treating POST with a "delete" query parameter as the multi-object
// delete operation.
func metricOperationLabel(r *http.Request) string {
	if r == nil {
		return "other"
	}
	isDeletePost := false
	if r.Method == http.MethodPost && r.URL != nil {
		_, isDeletePost = r.URL.Query()["delete"]
	}
	return metrics.NormalizeHTTPOperation(r.Method, isDeletePost)
}

// TestMetricRouteLabelFallbacks exercises the non-chi fallback paths of
// metricRouteLabel (no route context is attached by httptest).
func TestMetricRouteLabelFallbacks(t *testing.T) {
	testCases := []struct {
		name string
		path string
		want string
	}{
		{name: "root", path: "/", want: "/"},
		{name: "health", path: "/healthz", want: "/healthz"},
		{name: "metrics", path: "/metrics", want: "/metrics"},
		{name: "bucket", path: "/some-bucket", want: "/{bucket}"},
		{name: "object", path: "/some-bucket/private/path/file.jpg", want: "/{bucket}/*"},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest("GET", tc.path, nil)
			got := metricRouteLabel(req)
			if got != tc.want {
				t.Fatalf("metricRouteLabel(%q) = %q, want %q", tc.path, got, tc.want)
			}
		})
	}
}
config.LogFormat, config.AuditLog) + authConfig := auth.ConfigFromValues( + config.AuthEnabled, + config.AuthRegion, + config.AuthSkew, + config.AuthMaxPresign, + config.AuthMasterKey, + config.AuthBootstrapAccessKey, + config.AuthBootstrapSecretKey, + config.AuthBootstrapPolicy, + ) logger := logging.NewLogger(logConfig) logger.Info("boot", "log_level", logConfig.LevelName, @@ -26,6 +37,9 @@ func main() { "audit_log", logConfig.Audit, "data_path", config.DataPath, "multipart_retention_hours", int(config.MultipartCleanupRetention/time.Hour), + "auth_enabled", authConfig.Enabled, + "auth_region", authConfig.Region, + "admin_api_enabled", config.AdminAPIEnabled, ) if err := os.MkdirAll(config.DataPath, 0o755); err != nil { @@ -47,7 +61,19 @@ func main() { } objectService := service.NewObjectService(metadataHandler, blobHandler, config.MultipartCleanupRetention) - handler := api.NewHandler(objectService, logger, logConfig) + authService, err := auth.NewService(authConfig, metadataHandler) + if err != nil { + _ = metadataHandler.Close() + logger.Error("failed_to_initialize_auth_service", "error", err) + return + } + if err := authService.EnsureBootstrap(); err != nil { + _ = metadataHandler.Close() + logger.Error("failed_to_ensure_bootstrap_auth_identity", "error", err) + return + } + + handler := api.NewHandler(objectService, logger, logConfig, authService, config.AdminAPIEnabled) addr := config.Address + ":" + strconv.Itoa(config.Port) ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) diff --git a/metadata/metadata.go b/metadata/metadata.go index 916ddfe..ed84c55 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "fs/metrics" "fs/models" "net" "regexp" @@ -23,17 +24,21 @@ type MetadataHandler struct { var systemIndex = []byte("__SYSTEM_BUCKETS__") var multipartUploadIndex = []byte("__MULTIPART_UPLOADS__") var multipartUploadPartsIndex = 
[]byte("__MULTIPART_UPLOAD_PARTS__") +var authIdentitiesIndex = []byte("__AUTH_IDENTITIES__") +var authPoliciesIndex = []byte("__AUTH_POLICIES__") var validBucketName = regexp.MustCompile(`^[a-z0-9.-]+$`) var ( - ErrInvalidBucketName = errors.New("invalid bucket name") - ErrBucketAlreadyExists = errors.New("bucket already exists") - ErrBucketNotFound = errors.New("bucket not found") - ErrBucketNotEmpty = errors.New("bucket not empty") - ErrObjectNotFound = errors.New("object not found") - ErrMultipartNotFound = errors.New("multipart upload not found") - ErrMultipartNotPending = errors.New("multipart upload is not pending") + ErrInvalidBucketName = errors.New("invalid bucket name") + ErrBucketAlreadyExists = errors.New("bucket already exists") + ErrBucketNotFound = errors.New("bucket not found") + ErrBucketNotEmpty = errors.New("bucket not empty") + ErrObjectNotFound = errors.New("object not found") + ErrMultipartNotFound = errors.New("multipart upload not found") + ErrMultipartNotPending = errors.New("multipart upload is not pending") + ErrAuthIdentityNotFound = errors.New("auth identity not found") + ErrAuthPolicyNotFound = errors.New("auth policy not found") ) func NewMetadataHandler(dbPath string) (*MetadataHandler, error) { @@ -43,7 +48,7 @@ func NewMetadataHandler(dbPath string) (*MetadataHandler, error) { } h := &MetadataHandler{db: db} - err = h.db.Update(func(tx *bbolt.Tx) error { + err = h.update(func(tx *bbolt.Tx) error { _, err := tx.CreateBucketIfNotExists(systemIndex) return err }) @@ -51,7 +56,7 @@ func NewMetadataHandler(dbPath string) (*MetadataHandler, error) { _ = db.Close() return nil, err } - err = h.db.Update(func(tx *bbolt.Tx) error { + err = h.update(func(tx *bbolt.Tx) error { _, err := tx.CreateBucketIfNotExists(multipartUploadIndex) return err }) @@ -59,7 +64,7 @@ func NewMetadataHandler(dbPath string) (*MetadataHandler, error) { _ = db.Close() return nil, err } - err = h.db.Update(func(tx *bbolt.Tx) error { + err = h.update(func(tx 
*bbolt.Tx) error { _, err := tx.CreateBucketIfNotExists(multipartUploadPartsIndex) return err }) @@ -67,6 +72,22 @@ func NewMetadataHandler(dbPath string) (*MetadataHandler, error) { _ = db.Close() return nil, err } + err = h.update(func(tx *bbolt.Tx) error { + _, err := tx.CreateBucketIfNotExists(authIdentitiesIndex) + return err + }) + if err != nil { + _ = db.Close() + return nil, err + } + err = h.update(func(tx *bbolt.Tx) error { + _, err := tx.CreateBucketIfNotExists(authPoliciesIndex) + return err + }) + if err != nil { + _ = db.Close() + return nil, err + } return h, nil } @@ -99,12 +120,207 @@ func (h *MetadataHandler) Close() error { return h.db.Close() } +func (h *MetadataHandler) view(fn func(tx *bbolt.Tx) error) error { + start := time.Now() + err := h.db.View(fn) + metrics.Default.ObserveMetadataTx("view", time.Since(start), err == nil) + return err +} + +func (h *MetadataHandler) update(fn func(tx *bbolt.Tx) error) error { + start := time.Now() + err := h.db.Update(fn) + metrics.Default.ObserveMetadataTx("update", time.Since(start), err == nil) + return err +} + +func (h *MetadataHandler) PutAuthIdentity(identity *models.AuthIdentity) error { + if identity == nil { + return errors.New("auth identity is required") + } + if strings.TrimSpace(identity.AccessKeyID) == "" { + return errors.New("access key id is required") + } + return h.update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(authIdentitiesIndex) + if bucket == nil { + return errors.New("auth identities index not found") + } + payload, err := json.Marshal(identity) + if err != nil { + return err + } + return bucket.Put([]byte(identity.AccessKeyID), payload) + }) +} + +func (h *MetadataHandler) DeleteAuthIdentity(accessKeyID string) error { + accessKeyID = strings.TrimSpace(accessKeyID) + if accessKeyID == "" { + return errors.New("access key id is required") + } + return h.update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(authIdentitiesIndex) + if bucket == nil { + return 
errors.New("auth identities index not found") + } + if bucket.Get([]byte(accessKeyID)) == nil { + return fmt.Errorf("%w: %s", ErrAuthIdentityNotFound, accessKeyID) + } + return bucket.Delete([]byte(accessKeyID)) + }) +} + +func (h *MetadataHandler) GetAuthIdentity(accessKeyID string) (*models.AuthIdentity, error) { + accessKeyID = strings.TrimSpace(accessKeyID) + if accessKeyID == "" { + return nil, errors.New("access key id is required") + } + + var identity *models.AuthIdentity + err := h.view(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(authIdentitiesIndex) + if bucket == nil { + return errors.New("auth identities index not found") + } + payload := bucket.Get([]byte(accessKeyID)) + if payload == nil { + return fmt.Errorf("%w: %s", ErrAuthIdentityNotFound, accessKeyID) + } + record := models.AuthIdentity{} + if err := json.Unmarshal(payload, &record); err != nil { + return err + } + identity = &record + return nil + }) + if err != nil { + return nil, err + } + return identity, nil +} + +func (h *MetadataHandler) PutAuthPolicy(policy *models.AuthPolicy) error { + if policy == nil { + return errors.New("auth policy is required") + } + principal := strings.TrimSpace(policy.Principal) + if principal == "" { + return errors.New("auth policy principal is required") + } + policy.Principal = principal + return h.update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(authPoliciesIndex) + if bucket == nil { + return errors.New("auth policies index not found") + } + payload, err := json.Marshal(policy) + if err != nil { + return err + } + return bucket.Put([]byte(principal), payload) + }) +} + +func (h *MetadataHandler) DeleteAuthPolicy(accessKeyID string) error { + accessKeyID = strings.TrimSpace(accessKeyID) + if accessKeyID == "" { + return errors.New("access key id is required") + } + return h.update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(authPoliciesIndex) + if bucket == nil { + return errors.New("auth policies index not found") + } + if 
// DeleteAuthPolicy removes the policy stored under accessKeyID.
// Returns a wrapped ErrAuthPolicyNotFound when no record exists.
func (h *MetadataHandler) DeleteAuthPolicy(accessKeyID string) error {
	accessKeyID = strings.TrimSpace(accessKeyID)
	if accessKeyID == "" {
		return errors.New("access key id is required")
	}
	return h.update(func(tx *bbolt.Tx) error {
		bucket := tx.Bucket(authPoliciesIndex)
		if bucket == nil {
			return errors.New("auth policies index not found")
		}
		if bucket.Get([]byte(accessKeyID)) == nil {
			return fmt.Errorf("%w: %s", ErrAuthPolicyNotFound, accessKeyID)
		}
		return bucket.Delete([]byte(accessKeyID))
	})
}

// GetAuthPolicy loads the policy stored under accessKeyID.
// Returns a wrapped ErrAuthPolicyNotFound when no record exists.
func (h *MetadataHandler) GetAuthPolicy(accessKeyID string) (*models.AuthPolicy, error) {
	accessKeyID = strings.TrimSpace(accessKeyID)
	if accessKeyID == "" {
		return nil, errors.New("access key id is required")
	}

	var policy *models.AuthPolicy
	err := h.view(func(tx *bbolt.Tx) error {
		bucket := tx.Bucket(authPoliciesIndex)
		if bucket == nil {
			return errors.New("auth policies index not found")
		}
		payload := bucket.Get([]byte(accessKeyID))
		if payload == nil {
			return fmt.Errorf("%w: %s", ErrAuthPolicyNotFound, accessKeyID)
		}
		record := models.AuthPolicy{}
		if err := json.Unmarshal(payload, &record); err != nil {
			return err
		}
		policy = &record
		return nil
	})
	if err != nil {
		return nil, err
	}
	return policy, nil
}

// ListAuthIdentities returns up to limit identities in bbolt key order
// using keyset pagination: "after" is the exclusive cursor (the key of
// the last item from the previous page), and the returned cursor is the
// key to resume from (empty when the listing is exhausted). A limit <= 0
// defaults to 100.
func (h *MetadataHandler) ListAuthIdentities(limit int, after string) ([]models.AuthIdentity, string, error) {
	if limit <= 0 {
		limit = 100
	}
	after = strings.TrimSpace(after)

	identities := make([]models.AuthIdentity, 0, limit)
	nextCursor := ""

	err := h.view(func(tx *bbolt.Tx) error {
		bucket := tx.Bucket(authIdentitiesIndex)
		if bucket == nil {
			return errors.New("auth identities index not found")
		}

		cursor := bucket.Cursor()
		var k, v []byte
		if after == "" {
			k, v = cursor.First()
		} else {
			// Seek lands on the cursor key itself when present; skip it
			// so pagination is exclusive of "after".
			k, v = cursor.Seek([]byte(after))
			if k != nil && string(k) == after {
				k, v = cursor.Next()
			}
		}

		count := 0
		for ; k != nil; k, v = cursor.Next() {
			if count >= limit {
				// One item beyond the page: its key becomes the cursor.
				nextCursor = string(k)
				break
			}
			record := models.AuthIdentity{}
			if err := json.Unmarshal(v, &record); err != nil {
				return err
			}
			identities = append(identities, record)
			count++
		}
		return nil
	})
	if err != nil {
		return nil, "", err
	}
	return identities, nextCursor, nil
}
error { if !isValidBucketName(bucketName) { return fmt.Errorf("%w: %s", ErrInvalidBucketName, bucketName) } - err := h.db.Update(func(tx *bbolt.Tx) error { + err := h.update(func(tx *bbolt.Tx) error { indexBucket, err := tx.CreateBucketIfNotExists([]byte(systemIndex)) if err != nil { return err @@ -136,7 +352,7 @@ func (h *MetadataHandler) DeleteBucket(bucketName string) error { return fmt.Errorf("%w: %s", ErrInvalidBucketName, bucketName) } - err := h.db.Update(func(tx *bbolt.Tx) error { + err := h.update(func(tx *bbolt.Tx) error { indexBucket, err := tx.CreateBucketIfNotExists([]byte(systemIndex)) if err != nil { return err @@ -183,7 +399,7 @@ func (h *MetadataHandler) DeleteBucket(bucketName string) error { func (h *MetadataHandler) ListBuckets() ([]string, error) { buckets := []string{} - err := h.db.View(func(tx *bbolt.Tx) error { + err := h.view(func(tx *bbolt.Tx) error { systemIndexBucket := tx.Bucket([]byte(systemIndex)) if systemIndexBucket == nil { return errors.New("system index not found") @@ -203,7 +419,7 @@ func (h *MetadataHandler) ListBuckets() ([]string, error) { func (h *MetadataHandler) GetBucketManifest(bucketName string) (*models.BucketManifest, error) { var manifest *models.BucketManifest - err := h.db.View(func(tx *bbolt.Tx) error { + err := h.view(func(tx *bbolt.Tx) error { systemIndexBucket := tx.Bucket([]byte(systemIndex)) if systemIndexBucket == nil { return errors.New("system index not found") @@ -233,7 +449,7 @@ func (h *MetadataHandler) PutManifest(manifest *models.ObjectManifest) error { return err } - err := h.db.Update(func(tx *bbolt.Tx) error { + err := h.update(func(tx *bbolt.Tx) error { data, err := json.Marshal(manifest) if err != nil { return err @@ -253,7 +469,7 @@ func (h *MetadataHandler) PutManifest(manifest *models.ObjectManifest) error { func (h *MetadataHandler) GetManifest(bucket, key string) (*models.ObjectManifest, error) { var manifest *models.ObjectManifest - err := h.db.View(func(tx *bbolt.Tx) error { + err := 
h.view(func(tx *bbolt.Tx) error { metadataBucket := tx.Bucket([]byte(bucket)) if metadataBucket == nil { return fmt.Errorf("%w: %s", ErrBucketNotFound, bucket) @@ -280,7 +496,7 @@ func (h *MetadataHandler) ListObjects(bucket, prefix string) ([]*models.ObjectMa var objects []*models.ObjectManifest - err := h.db.View(func(tx *bbolt.Tx) error { + err := h.view(func(tx *bbolt.Tx) error { systemIndexBucket := tx.Bucket([]byte(systemIndex)) if systemIndexBucket == nil { return errors.New("system index not found") @@ -320,7 +536,7 @@ func (h *MetadataHandler) ForEachObjectFrom(bucket, startKey string, fn func(*mo return errors.New("object callback is required") } - return h.db.View(func(tx *bbolt.Tx) error { + return h.view(func(tx *bbolt.Tx) error { systemIndexBucket := tx.Bucket([]byte(systemIndex)) if systemIndexBucket == nil { return errors.New("system index not found") @@ -360,7 +576,7 @@ func (h *MetadataHandler) DeleteManifest(bucket, key string) error { return err } - err := h.db.Update(func(tx *bbolt.Tx) error { + err := h.update(func(tx *bbolt.Tx) error { metadataBucket := tx.Bucket([]byte(bucket)) if metadataBucket == nil { return fmt.Errorf("%w: %s", ErrBucketNotFound, bucket) @@ -377,7 +593,7 @@ func (h *MetadataHandler) DeleteManifest(bucket, key string) error { func (h *MetadataHandler) DeleteManifests(bucket string, keys []string) ([]string, error) { deleted := make([]string, 0, len(keys)) - err := h.db.Update(func(tx *bbolt.Tx) error { + err := h.update(func(tx *bbolt.Tx) error { metadataBucket := tx.Bucket([]byte(bucket)) if metadataBucket == nil { return fmt.Errorf("%w: %s", ErrBucketNotFound, bucket) @@ -405,7 +621,7 @@ func (h *MetadataHandler) DeleteManifests(bucket string, keys []string) ([]strin func (h *MetadataHandler) CreateMultipartUpload(bucket, key string) (*models.MultipartUpload, error) { var upload *models.MultipartUpload - err := h.db.View(func(tx *bbolt.Tx) error { + err := h.view(func(tx *bbolt.Tx) error { systemIndexBucket := 
tx.Bucket([]byte(systemIndex)) if systemIndexBucket == nil { return errors.New("system index not found") @@ -428,7 +644,7 @@ func (h *MetadataHandler) CreateMultipartUpload(bucket, key string) (*models.Mul State: "pending", } - err = h.db.Update(func(tx *bbolt.Tx) error { + err = h.update(func(tx *bbolt.Tx) error { multipartUploadBucket := tx.Bucket([]byte(multipartUploadIndex)) if multipartUploadBucket == nil { return errors.New("multipart upload index not found") @@ -523,7 +739,7 @@ func deleteMultipartPartsByUploadID(tx *bbolt.Tx, uploadID string) error { func (h *MetadataHandler) GetMultipartUpload(uploadID string) (*models.MultipartUpload, error) { var upload *models.MultipartUpload - err := h.db.View(func(tx *bbolt.Tx) error { + err := h.view(func(tx *bbolt.Tx) error { var err error upload, _, err = getMultipartUploadFromTx(tx, uploadID) if err != nil { @@ -541,7 +757,7 @@ func (h *MetadataHandler) PutMultipartPart(uploadID string, part models.Uploaded return fmt.Errorf("invalid part number: %d", part.PartNumber) } - err := h.db.Update(func(tx *bbolt.Tx) error { + err := h.update(func(tx *bbolt.Tx) error { upload, _, err := getMultipartUploadFromTx(tx, uploadID) if err != nil { return err @@ -570,7 +786,7 @@ func (h *MetadataHandler) PutMultipartPart(uploadID string, part models.Uploaded func (h *MetadataHandler) ListMultipartParts(uploadID string) ([]models.UploadedPart, error) { parts := make([]models.UploadedPart, 0) - err := h.db.View(func(tx *bbolt.Tx) error { + err := h.view(func(tx *bbolt.Tx) error { if _, _, err := getMultipartUploadFromTx(tx, uploadID); err != nil { return err } @@ -604,7 +820,7 @@ func (h *MetadataHandler) CompleteMultipartUpload(uploadID string, final *models return errors.New("final object manifest is required") } - err := h.db.Update(func(tx *bbolt.Tx) error { + err := h.update(func(tx *bbolt.Tx) error { upload, multipartUploadBucket, err := getMultipartUploadFromTx(tx, uploadID) if err != nil { return err @@ -643,7 +859,7 @@ 
func (h *MetadataHandler) CompleteMultipartUpload(uploadID string, final *models return nil } func (h *MetadataHandler) AbortMultipartUpload(uploadID string) error { - err := h.db.Update(func(tx *bbolt.Tx) error { + err := h.update(func(tx *bbolt.Tx) error { upload, multipartUploadBucket, err := getMultipartUploadFromTx(tx, uploadID) if err != nil { return err @@ -673,7 +889,7 @@ func (h *MetadataHandler) CleanupMultipartUploads(retention time.Duration) (int, } cleaned := 0 - err := h.db.Update(func(tx *bbolt.Tx) error { + err := h.update(func(tx *bbolt.Tx) error { uploadsBucket, err := getMultipartUploadBucket(tx) if err != nil { return err @@ -723,7 +939,7 @@ func (h *MetadataHandler) GetReferencedChunkSet() (map[string]struct{}, error) { chunkSet := make(map[string]struct{}) pendingUploadSet := make(map[string]struct{}) - err := h.db.View(func(tx *bbolt.Tx) error { + err := h.view(func(tx *bbolt.Tx) error { systemIndexBucket := tx.Bucket([]byte(systemIndex)) if systemIndexBucket == nil { return errors.New("system index not found") diff --git a/metrics/metrics.go b/metrics/metrics.go new file mode 100644 index 0000000..7df70b8 --- /dev/null +++ b/metrics/metrics.go @@ -0,0 +1,795 @@ +package metrics + +import ( + "fmt" + "os" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" +) + +var defaultBuckets = []float64{ + 0.0005, 0.001, 0.0025, 0.005, 0.01, + 0.025, 0.05, 0.1, 0.25, 0.5, + 1, 2.5, 5, 10, +} + +var lockBuckets = []float64{ + 0.000001, 0.000005, 0.00001, 0.00005, + 0.0001, 0.0005, 0.001, 0.005, 0.01, + 0.025, 0.05, 0.1, 0.25, 0.5, 1, +} + +var batchBuckets = []float64{1, 2, 4, 8, 16, 32, 64, 100, 128, 256, 512, 1000, 5000} + +var Default = NewRegistry() + +type histogram struct { + bounds []float64 + counts []uint64 + sum float64 + count uint64 +} + +func newHistogram(bounds []float64) *histogram { + cloned := make([]float64, len(bounds)) + copy(cloned, bounds) + return &histogram{ + bounds: cloned, + counts: 
make([]uint64, len(bounds)+1), + } +} + +func (h *histogram) observe(v float64) { + h.count++ + h.sum += v + for i, bound := range h.bounds { + if v <= bound { + h.counts[i]++ + return + } + } + h.counts[len(h.counts)-1]++ +} + +func (h *histogram) snapshot() (bounds []float64, counts []uint64, sum float64, count uint64) { + bounds = make([]float64, len(h.bounds)) + copy(bounds, h.bounds) + counts = make([]uint64, len(h.counts)) + copy(counts, h.counts) + return bounds, counts, h.sum, h.count +} + +type Registry struct { + startedAt time.Time + + httpInFlight atomic.Int64 + + connectionPoolActive atomic.Int64 + connectionPoolMax atomic.Int64 + connectionPoolWaits atomic.Uint64 + + requestQueueLength atomic.Int64 + + mu sync.Mutex + + httpRequestsRoute map[string]uint64 + httpResponseBytesRoute map[string]uint64 + httpDurationRoute map[string]*histogram + + httpRequestsOp map[string]uint64 + httpDurationOp map[string]*histogram + httpInFlightOp map[string]int64 + + authRequests map[string]uint64 + + serviceOps map[string]uint64 + serviceDuration map[string]*histogram + + dbTxTotal map[string]uint64 + dbTxDuration map[string]*histogram + + blobOps map[string]uint64 + blobBytes map[string]uint64 + blobDuration map[string]*histogram + + lockWait map[string]*histogram + lockHold map[string]*histogram + + cacheHits map[string]uint64 + cacheMisses map[string]uint64 + + batchSize *histogram + + retries map[string]uint64 + errors map[string]uint64 + + gcRuns map[string]uint64 + gcDuration *histogram + gcDeletedChunks uint64 + gcDeleteErrors uint64 + gcCleanedUpload uint64 +} + +func NewRegistry() *Registry { + return &Registry{ + startedAt: time.Now(), + httpRequestsRoute: make(map[string]uint64), + httpResponseBytesRoute: make(map[string]uint64), + httpDurationRoute: make(map[string]*histogram), + httpRequestsOp: make(map[string]uint64), + httpDurationOp: make(map[string]*histogram), + httpInFlightOp: make(map[string]int64), + authRequests: make(map[string]uint64), + 
serviceOps: make(map[string]uint64), + serviceDuration: make(map[string]*histogram), + dbTxTotal: make(map[string]uint64), + dbTxDuration: make(map[string]*histogram), + blobOps: make(map[string]uint64), + blobBytes: make(map[string]uint64), + blobDuration: make(map[string]*histogram), + lockWait: make(map[string]*histogram), + lockHold: make(map[string]*histogram), + cacheHits: make(map[string]uint64), + cacheMisses: make(map[string]uint64), + batchSize: newHistogram(batchBuckets), + retries: make(map[string]uint64), + errors: make(map[string]uint64), + gcRuns: make(map[string]uint64), + gcDuration: newHistogram(defaultBuckets), + } +} + +func NormalizeHTTPOperation(method string, isDeletePost bool) string { + switch strings.ToUpper(strings.TrimSpace(method)) { + case "GET": + return "get" + case "PUT": + return "put" + case "DELETE": + return "delete" + case "HEAD": + return "head" + case "POST": + if isDeletePost { + return "delete" + } + return "put" + default: + return "other" + } +} + +func statusResult(status int) string { + if status >= 200 && status < 400 { + return "ok" + } + return "error" +} + +func normalizeRoute(route string) string { + route = strings.TrimSpace(route) + if route == "" { + return "/unknown" + } + return route +} + +func normalizeOp(op string) string { + op = strings.ToLower(strings.TrimSpace(op)) + if op == "" { + return "other" + } + return op +} + +func (r *Registry) IncHTTPInFlight() { + r.httpInFlight.Add(1) +} + +func (r *Registry) DecHTTPInFlight() { + r.httpInFlight.Add(-1) +} + +func (r *Registry) IncHTTPInFlightOp(op string) { + r.httpInFlight.Add(1) + op = normalizeOp(op) + r.mu.Lock() + r.httpInFlightOp[op]++ + r.mu.Unlock() +} + +func (r *Registry) DecHTTPInFlightOp(op string) { + r.httpInFlight.Add(-1) + op = normalizeOp(op) + r.mu.Lock() + r.httpInFlightOp[op]-- + if r.httpInFlightOp[op] < 0 { + r.httpInFlightOp[op] = 0 + } + r.mu.Unlock() +} + +func (r *Registry) ObserveHTTPRequest(method, route string, status int, d 
time.Duration, responseBytes int) { + op := NormalizeHTTPOperation(method, false) + r.ObserveHTTPRequestDetailed(method, route, op, status, d, responseBytes) +} + +func (r *Registry) ObserveHTTPRequestDetailed(method, route, op string, status int, d time.Duration, responseBytes int) { + route = normalizeRoute(route) + op = normalizeOp(op) + result := statusResult(status) + + routeKey := method + "|" + route + "|" + strconv.Itoa(status) + routeDurKey := method + "|" + route + opKey := op + "|" + result + + r.mu.Lock() + r.httpRequestsRoute[routeKey]++ + if responseBytes > 0 { + r.httpResponseBytesRoute[routeKey] += uint64(responseBytes) + } + hRoute := r.httpDurationRoute[routeDurKey] + if hRoute == nil { + hRoute = newHistogram(defaultBuckets) + r.httpDurationRoute[routeDurKey] = hRoute + } + hRoute.observe(d.Seconds()) + + r.httpRequestsOp[opKey]++ + hOp := r.httpDurationOp[opKey] + if hOp == nil { + hOp = newHistogram(defaultBuckets) + r.httpDurationOp[opKey] = hOp + } + hOp.observe(d.Seconds()) + r.mu.Unlock() +} + +func (r *Registry) ObserveAuth(result, authType, reason string) { + authType = strings.TrimSpace(authType) + if authType == "" { + authType = "unknown" + } + reason = strings.TrimSpace(reason) + if reason == "" { + reason = "none" + } + key := result + "|" + authType + "|" + reason + r.mu.Lock() + r.authRequests[key]++ + r.mu.Unlock() +} + +func (r *Registry) ObserveService(operation string, d time.Duration, ok bool) { + result := "error" + if ok { + result = "ok" + } + key := operation + "|" + result + r.mu.Lock() + r.serviceOps[key]++ + h := r.serviceDuration[operation] + if h == nil { + h = newHistogram(defaultBuckets) + r.serviceDuration[operation] = h + } + h.observe(d.Seconds()) + r.mu.Unlock() +} + +func (r *Registry) ObserveMetadataTx(txType string, d time.Duration, ok bool) { + result := "error" + if ok { + result = "ok" + } + key := txType + "|" + result + r.mu.Lock() + r.dbTxTotal[key]++ + h := r.dbTxDuration[txType] + if h == nil { + h = 
newHistogram(defaultBuckets) + r.dbTxDuration[txType] = h + } + h.observe(d.Seconds()) + r.mu.Unlock() +} + +func (r *Registry) ObserveBlob(operation string, d time.Duration, bytes int64, ok bool, backend ...string) { + be := "disk" + if len(backend) > 0 { + candidate := strings.TrimSpace(backend[0]) + if candidate != "" { + be = strings.ToLower(candidate) + } + } + result := "error" + if ok { + result = "ok" + } + op := strings.ToLower(strings.TrimSpace(operation)) + if op == "" { + op = "unknown" + } + + histKey := op + "|" + be + "|" + result + opsKey := histKey + + r.mu.Lock() + r.blobOps[opsKey]++ + h := r.blobDuration[histKey] + if h == nil { + h = newHistogram(defaultBuckets) + r.blobDuration[histKey] = h + } + h.observe(d.Seconds()) + + if bytes > 0 { + r.blobBytes[op] += uint64(bytes) + } + r.mu.Unlock() +} + +func (r *Registry) SetConnectionPoolMax(max int) { + if max < 0 { + max = 0 + } + r.connectionPoolMax.Store(int64(max)) +} + +func (r *Registry) IncConnectionPoolActive() { + r.connectionPoolActive.Add(1) +} + +func (r *Registry) DecConnectionPoolActive() { + r.connectionPoolActive.Add(-1) +} + +func (r *Registry) IncConnectionPoolWait() { + r.connectionPoolWaits.Add(1) +} + +func (r *Registry) IncRequestQueueLength() { + r.requestQueueLength.Add(1) +} + +func (r *Registry) DecRequestQueueLength() { + r.requestQueueLength.Add(-1) +} + +func (r *Registry) ObserveLockWait(lockName string, d time.Duration) { + lockName = strings.TrimSpace(lockName) + if lockName == "" { + lockName = "unknown" + } + r.mu.Lock() + h := r.lockWait[lockName] + if h == nil { + h = newHistogram(lockBuckets) + r.lockWait[lockName] = h + } + h.observe(d.Seconds()) + r.mu.Unlock() +} + +func (r *Registry) ObserveLockHold(lockName string, d time.Duration) { + lockName = strings.TrimSpace(lockName) + if lockName == "" { + lockName = "unknown" + } + r.mu.Lock() + h := r.lockHold[lockName] + if h == nil { + h = newHistogram(lockBuckets) + r.lockHold[lockName] = h + } + 
h.observe(d.Seconds()) + r.mu.Unlock() +} + +func (r *Registry) ObserveCacheHit(cache string) { + cache = strings.TrimSpace(cache) + if cache == "" { + cache = "unknown" + } + r.mu.Lock() + r.cacheHits[cache]++ + r.mu.Unlock() +} + +func (r *Registry) ObserveCacheMiss(cache string) { + cache = strings.TrimSpace(cache) + if cache == "" { + cache = "unknown" + } + r.mu.Lock() + r.cacheMisses[cache]++ + r.mu.Unlock() +} + +func (r *Registry) ObserveBatchSize(size int) { + if size < 0 { + size = 0 + } + r.mu.Lock() + r.batchSize.observe(float64(size)) + r.mu.Unlock() +} + +func (r *Registry) ObserveRetry(op, reason string) { + op = normalizeOp(op) + reason = strings.TrimSpace(reason) + if reason == "" { + reason = "unknown" + } + key := op + "|" + reason + r.mu.Lock() + r.retries[key]++ + r.mu.Unlock() +} + +func (r *Registry) ObserveError(op, reason string) { + op = normalizeOp(op) + reason = strings.TrimSpace(reason) + if reason == "" { + reason = "unknown" + } + key := op + "|" + reason + r.mu.Lock() + r.errors[key]++ + r.mu.Unlock() +} + +func (r *Registry) ObserveGC(d time.Duration, deletedChunks, deleteErrors, cleanedUploads int, ok bool) { + result := "error" + if ok { + result = "ok" + } + r.mu.Lock() + r.gcRuns[result]++ + r.gcDuration.observe(d.Seconds()) + if deletedChunks > 0 { + r.gcDeletedChunks += uint64(deletedChunks) + } + if deleteErrors > 0 { + r.gcDeleteErrors += uint64(deleteErrors) + } + if cleanedUploads > 0 { + r.gcCleanedUpload += uint64(cleanedUploads) + } + r.mu.Unlock() +} + +func (r *Registry) RenderPrometheus() string { + now := time.Now() + var mem runtime.MemStats + runtime.ReadMemStats(&mem) + + r.mu.Lock() + httpReqRoute := copyCounterMap(r.httpRequestsRoute) + httpRespRoute := copyCounterMap(r.httpResponseBytesRoute) + httpDurRoute := copyHistogramMap(r.httpDurationRoute) + httpReqOp := copyCounterMap(r.httpRequestsOp) + httpDurOp := copyHistogramMap(r.httpDurationOp) + httpInFlightOp := copyIntGaugeMap(r.httpInFlightOp) + authReq := 
copyCounterMap(r.authRequests) + serviceOps := copyCounterMap(r.serviceOps) + serviceDur := copyHistogramMap(r.serviceDuration) + dbTx := copyCounterMap(r.dbTxTotal) + dbTxDur := copyHistogramMap(r.dbTxDuration) + blobOps := copyCounterMap(r.blobOps) + blobBytes := copyCounterMap(r.blobBytes) + blobDur := copyHistogramMap(r.blobDuration) + lockWait := copyHistogramMap(r.lockWait) + lockHold := copyHistogramMap(r.lockHold) + cacheHits := copyCounterMap(r.cacheHits) + cacheMisses := copyCounterMap(r.cacheMisses) + batchBounds, batchCounts, batchSum, batchCount := r.batchSize.snapshot() + retries := copyCounterMap(r.retries) + errorsTotal := copyCounterMap(r.errors) + gcRuns := copyCounterMap(r.gcRuns) + gcDurBounds, gcDurCounts, gcDurSum, gcDurCount := r.gcDuration.snapshot() + gcDeletedChunks := r.gcDeletedChunks + gcDeleteErrors := r.gcDeleteErrors + gcCleanedUploads := r.gcCleanedUpload + r.mu.Unlock() + + connectionActive := float64(r.connectionPoolActive.Load()) + connectionMax := float64(r.connectionPoolMax.Load()) + connectionWaits := r.connectionPoolWaits.Load() + queueLength := float64(r.requestQueueLength.Load()) + + resident, hasResident := readResidentMemoryBytes() + cpuSeconds, hasCPU := readProcessCPUSeconds() + + var b strings.Builder + + httpInFlightOp["all"] = r.httpInFlight.Load() + writeGaugeVecFromInt64(&b, "fs_http_inflight_requests", "Current in-flight HTTP requests by operation.", httpInFlightOp, "op") + writeCounterVecKV(&b, "fs_http_requests_total", "Total HTTP requests by operation and result.", httpReqOp, []string{"op", "result"}) + writeHistogramVecKV(&b, "fs_http_request_duration_seconds", "HTTP request latency by operation and result.", httpDurOp, []string{"op", "result"}) + + writeCounterVecKV(&b, "fs_http_requests_by_route_total", "Total HTTP requests by method/route/status.", httpReqRoute, []string{"method", "route", "status"}) + writeCounterVecKV(&b, "fs_http_response_bytes_total", "Total HTTP response bytes written.", httpRespRoute, 
[]string{"method", "route", "status"}) + writeHistogramVecKV(&b, "fs_http_request_duration_by_route_seconds", "HTTP request latency by method/route.", httpDurRoute, []string{"method", "route"}) + + writeCounterVecKV(&b, "fs_auth_requests_total", "Authentication attempts by result.", authReq, []string{"result", "auth_type", "reason"}) + + writeCounterVecKV(&b, "fs_service_operations_total", "Service-level operation calls.", serviceOps, []string{"operation", "result"}) + writeHistogramVecKV(&b, "fs_service_operation_duration_seconds", "Service-level operation latency.", serviceDur, []string{"operation"}) + + writeCounterVecKV(&b, "fs_metadata_tx_total", "Metadata transaction calls.", dbTx, []string{"type", "result"}) + writeHistogramVecKV(&b, "fs_metadata_tx_duration_seconds", "Metadata transaction latency.", dbTxDur, []string{"type"}) + + writeHistogramVecKV(&b, "fs_blob_operation_duration_seconds", "Blob backend operation latency.", blobDur, []string{"op", "backend", "result"}) + writeCounterVecKV(&b, "fs_blob_operations_total", "Blob store operations.", blobOps, []string{"op", "backend", "result"}) + writeCounterVecKV(&b, "fs_blob_bytes_total", "Blob bytes processed by operation.", blobBytes, []string{"op"}) + + writeGauge(&b, "fs_connection_pool_active", "Active pooled connections.", connectionActive) + writeGauge(&b, "fs_connection_pool_max", "Maximum pooled connections.", connectionMax) + writeCounter(&b, "fs_connection_pool_waits_total", "Number of waits due to pool saturation.", connectionWaits) + + writeGauge(&b, "fs_request_queue_length", "Requests waiting for an execution slot.", queueLength) + + writeHistogramVecKV(&b, "fs_lock_wait_seconds", "Time spent waiting for locks.", lockWait, []string{"lock_name"}) + writeHistogramVecKV(&b, "fs_lock_hold_seconds", "Time locks were held.", lockHold, []string{"lock_name"}) + + writeCounterVecKV(&b, "fs_cache_hits_total", "Cache hits by cache name.", cacheHits, []string{"cache"}) + writeCounterVecKV(&b, 
"fs_cache_misses_total", "Cache misses by cache name.", cacheMisses, []string{"cache"}) + + writeHistogram(&b, "fs_batch_size_histogram", "Observed batch sizes.", nil, batchBounds, batchCounts, batchSum, batchCount) + + writeCounterVecKV(&b, "fs_retries_total", "Retries by operation and reason.", retries, []string{"op", "reason"}) + writeCounterVecKV(&b, "fs_errors_total", "Errors by operation and reason.", errorsTotal, []string{"op", "reason"}) + + writeCounterVecKV(&b, "fs_gc_runs_total", "Garbage collection runs.", gcRuns, []string{"result"}) + writeHistogram(&b, "fs_gc_duration_seconds", "Garbage collection runtime.", nil, gcDurBounds, gcDurCounts, gcDurSum, gcDurCount) + writeCounter(&b, "fs_gc_deleted_chunks_total", "Deleted chunks during GC.", gcDeletedChunks) + writeCounter(&b, "fs_gc_delete_errors_total", "Chunk delete errors during GC.", gcDeleteErrors) + writeCounter(&b, "fs_gc_cleaned_uploads_total", "Cleaned multipart uploads during GC.", gcCleanedUploads) + + writeGauge(&b, "fs_uptime_seconds", "Process uptime in seconds.", now.Sub(r.startedAt).Seconds()) + writeGauge(&b, "fs_runtime_goroutines", "Number of goroutines.", float64(runtime.NumGoroutine())) + writeGaugeVec(&b, "fs_runtime_memory_bytes", "Runtime memory in bytes.", map[string]float64{ + "alloc": float64(mem.Alloc), + "total": float64(mem.TotalAlloc), + "sys": float64(mem.Sys), + "heap_alloc": float64(mem.HeapAlloc), + "heap_sys": float64(mem.HeapSys), + "stack_sys": float64(mem.StackSys), + }, "type") + writeCounter(&b, "fs_runtime_gc_cycles_total", "Completed GC cycles.", uint64(mem.NumGC)) + writeCounterFloat(&b, "fs_runtime_gc_pause_seconds_total", "Total GC pause time in seconds.", float64(mem.PauseTotalNs)/1e9) + + if hasCPU { + writeCounterFloat(&b, "process_cpu_seconds_total", "Total user and system CPU time spent in seconds.", cpuSeconds) + } + if hasResident { + writeGauge(&b, "process_resident_memory_bytes", "Resident memory size in bytes.", resident) + } + + return b.String() +} 
+ +type histogramSnapshot struct { + bounds []float64 + counts []uint64 + sum float64 + count uint64 +} + +func copyCounterMap(src map[string]uint64) map[string]uint64 { + out := make(map[string]uint64, len(src)) + for k, v := range src { + out[k] = v + } + return out +} + +func copyIntGaugeMap(src map[string]int64) map[string]int64 { + out := make(map[string]int64, len(src)) + for k, v := range src { + out[k] = v + } + return out +} + +func copyHistogramMap(src map[string]*histogram) map[string]histogramSnapshot { + out := make(map[string]histogramSnapshot, len(src)) + for k, h := range src { + bounds, counts, sum, count := h.snapshot() + out[k] = histogramSnapshot{bounds: bounds, counts: counts, sum: sum, count: count} + } + return out +} + +func writeCounter(b *strings.Builder, name, help string, value uint64) { + fmt.Fprintf(b, "# HELP %s %s\n", name, help) + fmt.Fprintf(b, "# TYPE %s counter\n", name) + fmt.Fprintf(b, "%s %d\n", name, value) +} + +func writeCounterFloat(b *strings.Builder, name, help string, value float64) { + fmt.Fprintf(b, "# HELP %s %s\n", name, help) + fmt.Fprintf(b, "# TYPE %s counter\n", name) + fmt.Fprintf(b, "%s %.9f\n", name, value) +} + +func writeGauge(b *strings.Builder, name, help string, value float64) { + fmt.Fprintf(b, "# HELP %s %s\n", name, help) + fmt.Fprintf(b, "# TYPE %s gauge\n", name) + fmt.Fprintf(b, "%s %.9f\n", name, value) +} + +func writeGaugeVec(b *strings.Builder, name, help string, values map[string]float64, labelName string) { + fmt.Fprintf(b, "# HELP %s %s\n", name, help) + fmt.Fprintf(b, "# TYPE %s gauge\n", name) + keys := make([]string, 0, len(values)) + for k := range values { + keys = append(keys, k) + } + sort.Strings(keys) + for _, key := range keys { + fmt.Fprintf(b, "%s{%s=\"%s\"} %.9f\n", name, labelName, escapeLabelValue(key), values[key]) + } +} + +func writeGaugeVecFromInt64(b *strings.Builder, name, help string, values map[string]int64, labelName string) { + fmt.Fprintf(b, "# HELP %s %s\n", name, 
help) + fmt.Fprintf(b, "# TYPE %s gauge\n", name) + keys := make([]string, 0, len(values)) + for k := range values { + keys = append(keys, k) + } + sort.Strings(keys) + for _, key := range keys { + fmt.Fprintf(b, "%s{%s=\"%s\"} %.9f\n", name, labelName, escapeLabelValue(key), float64(values[key])) + } +} + +func writeCounterVecKV(b *strings.Builder, name, help string, values map[string]uint64, labels []string) { + fmt.Fprintf(b, "# HELP %s %s\n", name, help) + fmt.Fprintf(b, "# TYPE %s counter\n", name) + keys := make([]string, 0, len(values)) + for k := range values { + keys = append(keys, k) + } + sort.Strings(keys) + for _, key := range keys { + parts := strings.Split(key, "|") + fmt.Fprintf(b, "%s{%s} %d\n", name, formatLabels(labels, parts), values[key]) + } +} + +func writeHistogramVecKV(b *strings.Builder, name, help string, values map[string]histogramSnapshot, labels []string) { + fmt.Fprintf(b, "# HELP %s %s\n", name, help) + fmt.Fprintf(b, "# TYPE %s histogram\n", name) + keys := make([]string, 0, len(values)) + for k := range values { + keys = append(keys, k) + } + sort.Strings(keys) + for _, key := range keys { + parts := strings.Split(key, "|") + labelsMap := make(map[string]string, len(labels)) + for i, label := range labels { + if i < len(parts) { + labelsMap[label] = parts[i] + } else { + labelsMap[label] = "" + } + } + writeHistogramWithLabelsMap(b, name, labelsMap, values[key]) + } +} + +func writeHistogram(b *strings.Builder, name, help string, labels map[string]string, bounds []float64, counts []uint64, sum float64, count uint64) { + fmt.Fprintf(b, "# HELP %s %s\n", name, help) + fmt.Fprintf(b, "# TYPE %s histogram\n", name) + writeHistogramWithLabelsMap(b, name, labels, histogramSnapshot{bounds: bounds, counts: counts, sum: sum, count: count}) +} + +func writeHistogramWithLabelsMap(b *strings.Builder, name string, labels map[string]string, s histogramSnapshot) { + var cumulative uint64 + for i, bucketCount := range s.counts { + cumulative += 
// labelEscaper escapes backslash, newline, and double quote per the
// Prometheus text exposition format. Single-character, disjoint patterns,
// so a one-pass Replacer matches sequential ReplaceAll semantics exactly.
var labelEscaper = strings.NewReplacer(`\`, `\\`, "\n", `\n`, `"`, `\"`)

// escapeLabelValue makes a string safe for use as a label value.
func escapeLabelValue(value string) string {
	return labelEscaper.Replace(value)
}

// formatLabelsSuffix renders "{...}" for a non-empty label set, else "".
func formatLabelsSuffix(labels map[string]string) string {
	if len(labels) == 0 {
		return ""
	}
	return "{" + labelsToString(labels) + "}"
}

// formatLabels pairs positional label names with values (missing values
// become the empty string) and renders them as k="v" joined by commas.
func formatLabels(keys, values []string) string {
	var sb strings.Builder
	for i, key := range keys {
		if i > 0 {
			sb.WriteByte(',')
		}
		value := ""
		if i < len(values) {
			value = values[i]
		}
		sb.WriteString(key)
		sb.WriteString(`="`)
		sb.WriteString(escapeLabelValue(value))
		sb.WriteByte('"')
	}
	return sb.String()
}

// labelsToString renders a label map as k="v" pairs in sorted key order.
func labelsToString(labels map[string]string) string {
	if len(labels) == 0 {
		return ""
	}
	keys := make([]string, 0, len(labels))
	for key := range labels {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	var sb strings.Builder
	for i, key := range keys {
		if i > 0 {
			sb.WriteByte(',')
		}
		sb.WriteString(key)
		sb.WriteString(`="`)
		sb.WriteString(escapeLabelValue(labels[key]))
		sb.WriteByte('"')
	}
	return sb.String()
}

// cloneLabels copies a label map; the result is never nil and is sized with
// room for one extra entry (the "le" bucket label added by callers).
func cloneLabels(in map[string]string) map[string]string {
	out := make(map[string]string, len(in)+1)
	for key, val := range in {
		out[key] = val
	}
	return out
}

// trimFloat formats a float with the shortest exact decimal representation.
func trimFloat(v float64) string {
	return strconv.FormatFloat(v, 'f', -1, 64)
}

// readResidentMemoryBytes reports RSS in bytes from /proc/self/statm
// (Linux only); ok is false when the file is missing or malformed.
func readResidentMemoryBytes() (float64, bool) {
	data, err := os.ReadFile("/proc/self/statm")
	if err != nil {
		return 0, false
	}
	fields := strings.Fields(string(data))
	if len(fields) < 2 {
		return 0, false
	}
	rssPages, err := strconv.ParseInt(fields[1], 10, 64)
	if err != nil || rssPages < 0 {
		return 0, false
	}
	return float64(rssPages) * float64(os.Getpagesize()), true
}
len(fields) < 2 { + return 0, false + } + rssPages, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil || rssPages < 0 { + return 0, false + } + return float64(rssPages * int64(os.Getpagesize())), true +} + +func readProcessCPUSeconds() (float64, bool) { + var usage syscall.Rusage + if err := syscall.Getrusage(syscall.RUSAGE_SELF, &usage); err != nil { + return 0, false + } + user := float64(usage.Utime.Sec) + float64(usage.Utime.Usec)/1e6 + sys := float64(usage.Stime.Sec) + float64(usage.Stime.Usec)/1e6 + return user + sys, true +} diff --git a/metrics/metrics_test.go b/metrics/metrics_test.go new file mode 100644 index 0000000..2f25c23 --- /dev/null +++ b/metrics/metrics_test.go @@ -0,0 +1,34 @@ +package metrics + +import ( + "strings" + "testing" +) + +func TestRenderPrometheusHistogramNoEmptyLabelSet(t *testing.T) { + reg := NewRegistry() + reg.ObserveBatchSize(3) + reg.ObserveGC(0, 0, 0, 0, true) + + out := reg.RenderPrometheus() + if strings.Contains(out, "fs_batch_size_histogram_sum{}") { + t.Fatalf("unexpected empty label set for batch sum metric") + } + if strings.Contains(out, "fs_batch_size_histogram_count{}") { + t.Fatalf("unexpected empty label set for batch count metric") + } + if strings.Contains(out, "fs_gc_duration_seconds_sum{}") { + t.Fatalf("unexpected empty label set for gc sum metric") + } + if strings.Contains(out, "fs_gc_duration_seconds_count{}") { + t.Fatalf("unexpected empty label set for gc count metric") + } +} + +func TestEscapeLabelValueEscapesSingleBackslash(t *testing.T) { + got := escapeLabelValue(`a\b`) + want := `a\\b` + if got != want { + t.Fatalf("escapeLabelValue returned %q, want %q", got, want) + } +} diff --git a/models/models.go b/models/models.go index df363e1..4b0d38a 100644 --- a/models/models.go +++ b/models/models.go @@ -68,6 +68,23 @@ type ListBucketResult struct { CommonPrefixes []CommonPrefixes `xml:"CommonPrefixes,omitempty"` } +type ListBucketResultV1 struct { + XMLName xml.Name `xml:"ListBucketResult"` + 
Xmlns string `xml:"xmlns,attr"` + + Name string `xml:"Name"` + Prefix string `xml:"Prefix"` + Marker string `xml:"Marker,omitempty"` + NextMarker string `xml:"NextMarker,omitempty"` + Delimiter string `xml:"Delimiter,omitempty"` + MaxKeys int `xml:"MaxKeys"` + IsTruncated bool `xml:"IsTruncated"` + EncodingType string `xml:"EncodingType,omitempty"` + + Contents []Contents `xml:"Contents,omitempty"` + CommonPrefixes []CommonPrefixes `xml:"CommonPrefixes,omitempty"` +} + type ListBucketResultV2 struct { XMLName xml.Name `xml:"ListBucketResult"` Xmlns string `xml:"xmlns,attr"` @@ -183,3 +200,26 @@ type DeleteError struct { Code string `xml:"Code"` Message string `xml:"Message"` } + +type AuthIdentity struct { + AccessKeyID string `json:"access_key_id"` + SecretEnc string `json:"secret_enc"` + SecretNonce string `json:"secret_nonce"` + EncAlg string `json:"enc_alg"` + KeyVersion string `json:"key_version"` + Status string `json:"status"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type AuthPolicy struct { + Principal string `json:"principal"` + Statements []AuthPolicyStatement `json:"statements"` +} + +type AuthPolicyStatement struct { + Effect string `json:"effect"` + Actions []string `json:"actions"` + Bucket string `json:"bucket"` + Prefix string `json:"prefix"` +} diff --git a/service/service.go b/service/service.go index 09b04d6..9d7db43 100644 --- a/service/service.go +++ b/service/service.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "fs/metadata" + "fs/metrics" "fs/models" "fs/storage" "io" @@ -41,9 +42,37 @@ func NewObjectService(metadataHandler *metadata.MetadataHandler, blobHandler *st } } -func (s *ObjectService) PutObject(bucket, key, contentType string, input io.Reader) (*models.ObjectManifest, error) { +func (s *ObjectService) acquireGCRLock() func() { + waitStart := time.Now() s.gcMu.RLock() - defer s.gcMu.RUnlock() + metrics.Default.ObserveLockWait("gc_mu_read", time.Since(waitStart)) + holdStart := time.Now() + 
return func() { + metrics.Default.ObserveLockHold("gc_mu_read", time.Since(holdStart)) + s.gcMu.RUnlock() + } +} + +func (s *ObjectService) acquireGCLock() func() { + waitStart := time.Now() + s.gcMu.Lock() + metrics.Default.ObserveLockWait("gc_mu_write", time.Since(waitStart)) + holdStart := time.Now() + return func() { + metrics.Default.ObserveLockHold("gc_mu_write", time.Since(holdStart)) + s.gcMu.Unlock() + } +} + +func (s *ObjectService) PutObject(bucket, key, contentType string, input io.Reader) (*models.ObjectManifest, error) { + start := time.Now() + success := false + defer func() { + metrics.Default.ObserveService("put_object", time.Since(start), success) + }() + + unlock := s.acquireGCRLock() + defer unlock() chunks, size, etag, err := s.blob.IngestStream(input) if err != nil { @@ -71,110 +100,171 @@ func (s *ObjectService) PutObject(bucket, key, contentType string, input io.Read return nil, err } + success = true return manifest, nil } func (s *ObjectService) GetObject(bucket, key string) (io.ReadCloser, *models.ObjectManifest, error) { + start := time.Now() + + waitStart := time.Now() s.gcMu.RLock() + metrics.Default.ObserveLockWait("gc_mu_read", time.Since(waitStart)) + holdStart := time.Now() manifest, err := s.metadata.GetManifest(bucket, key) if err != nil { + metrics.Default.ObserveLockHold("gc_mu_read", time.Since(holdStart)) s.gcMu.RUnlock() + metrics.Default.ObserveService("get_object", time.Since(start), false) return nil, nil, err } pr, pw := io.Pipe() go func() { + streamOK := false + defer func() { + metrics.Default.ObserveService("get_object", time.Since(start), streamOK) + }() + defer metrics.Default.ObserveLockHold("gc_mu_read", time.Since(holdStart)) defer s.gcMu.RUnlock() if err := s.blob.AssembleStream(manifest.Chunks, pw); err != nil { _ = pw.CloseWithError(err) return } - _ = pw.Close() + if err := pw.Close(); err != nil { + return + } + streamOK = true }() return pr, manifest, nil } func (s *ObjectService) HeadObject(bucket, key 
string) (models.ObjectManifest, error) { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + start := time.Now() + success := false + defer func() { + metrics.Default.ObserveService("head_object", time.Since(start), success) + }() + + unlock := s.acquireGCRLock() + defer unlock() manifest, err := s.metadata.GetManifest(bucket, key) if err != nil { return models.ObjectManifest{}, err } + success = true return *manifest, nil } func (s *ObjectService) DeleteObject(bucket, key string) error { - s.gcMu.RLock() - defer s.gcMu.RUnlock() - return s.metadata.DeleteManifest(bucket, key) + start := time.Now() + success := false + defer func() { + metrics.Default.ObserveService("delete_object", time.Since(start), success) + }() + + unlock := s.acquireGCRLock() + defer unlock() + err := s.metadata.DeleteManifest(bucket, key) + success = err == nil + return err } func (s *ObjectService) ListObjects(bucket, prefix string) ([]*models.ObjectManifest, error) { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + unlock := s.acquireGCRLock() + defer unlock() return s.metadata.ListObjects(bucket, prefix) } func (s *ObjectService) ForEachObjectFrom(bucket, startKey string, fn func(*models.ObjectManifest) error) error { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + start := time.Now() + success := false + defer func() { + metrics.Default.ObserveService("for_each_object_from", time.Since(start), success) + }() - return s.metadata.ForEachObjectFrom(bucket, startKey, fn) + unlock := s.acquireGCRLock() + defer unlock() + + err := s.metadata.ForEachObjectFrom(bucket, startKey, fn) + success = err == nil + return err } func (s *ObjectService) CreateBucket(bucket string) error { - s.gcMu.RLock() - defer s.gcMu.RUnlock() - return s.metadata.CreateBucket(bucket) + start := time.Now() + success := false + defer func() { + metrics.Default.ObserveService("create_bucket", time.Since(start), success) + }() + + unlock := s.acquireGCRLock() + defer unlock() + err := s.metadata.CreateBucket(bucket) + success = err == nil 
+ return err } func (s *ObjectService) HeadBucket(bucket string) error { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + unlock := s.acquireGCRLock() + defer unlock() _, err := s.metadata.GetBucketManifest(bucket) return err } func (s *ObjectService) GetBucketManifest(bucket string) (*models.BucketManifest, error) { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + unlock := s.acquireGCRLock() + defer unlock() return s.metadata.GetBucketManifest(bucket) } func (s *ObjectService) DeleteBucket(bucket string) error { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + unlock := s.acquireGCRLock() + defer unlock() return s.metadata.DeleteBucket(bucket) } func (s *ObjectService) ListBuckets() ([]string, error) { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + start := time.Now() + success := false + defer func() { + metrics.Default.ObserveService("list_buckets", time.Since(start), success) + }() - return s.metadata.ListBuckets() + unlock := s.acquireGCRLock() + defer unlock() + + buckets, err := s.metadata.ListBuckets() + success = err == nil + return buckets, err } func (s *ObjectService) DeleteObjects(bucket string, keys []string) ([]string, error) { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + unlock := s.acquireGCRLock() + defer unlock() return s.metadata.DeleteManifests(bucket, keys) } func (s *ObjectService) CreateMultipartUpload(bucket, key string) (*models.MultipartUpload, error) { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + unlock := s.acquireGCRLock() + defer unlock() return s.metadata.CreateMultipartUpload(bucket, key) } func (s *ObjectService) UploadPart(bucket, key, uploadId string, partNumber int, input io.Reader) (string, error) { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + start := time.Now() + success := false + defer func() { + metrics.Default.ObserveService("upload_part", time.Since(start), success) + }() + + unlock := s.acquireGCRLock() + defer unlock() if partNumber < 1 || partNumber > 10000 { return "", ErrInvalidPart @@ -204,12 +294,13 @@ func (s *ObjectService) 
UploadPart(bucket, key, uploadId string, partNumber int, if err != nil { return "", err } + success = true return etag, nil } func (s *ObjectService) ListMultipartParts(bucket, key, uploadID string) ([]models.UploadedPart, error) { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + unlock := s.acquireGCRLock() + defer unlock() upload, err := s.metadata.GetMultipartUpload(uploadID) if err != nil { @@ -222,8 +313,14 @@ func (s *ObjectService) ListMultipartParts(bucket, key, uploadID string) ([]mode } func (s *ObjectService) CompleteMultipartUpload(bucket, key, uploadID string, completed []models.CompletedPart) (*models.ObjectManifest, error) { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + start := time.Now() + success := false + defer func() { + metrics.Default.ObserveService("complete_multipart_upload", time.Since(start), success) + }() + + unlock := s.acquireGCRLock() + defer unlock() if len(completed) == 0 { return nil, ErrInvalidCompleteRequest @@ -288,12 +385,13 @@ func (s *ObjectService) CompleteMultipartUpload(bucket, key, uploadID string, co return nil, err } + success = true return manifest, nil } func (s *ObjectService) AbortMultipartUpload(bucket, key, uploadID string) error { - s.gcMu.RLock() - defer s.gcMu.RUnlock() + unlock := s.acquireGCRLock() + defer unlock() upload, err := s.metadata.GetMultipartUpload(uploadID) if err != nil { @@ -327,8 +425,17 @@ func (s *ObjectService) Close() error { } func (s *ObjectService) GarbageCollect() error { - s.gcMu.Lock() - defer s.gcMu.Unlock() + start := time.Now() + success := false + deletedChunks := 0 + deleteErrors := 0 + cleanedUploads := 0 + defer func() { + metrics.Default.ObserveGC(time.Since(start), deletedChunks, deleteErrors, cleanedUploads, success) + }() + + unlock := s.acquireGCLock() + defer unlock() referencedChunkSet, err := s.metadata.GetReferencedChunkSet() if err != nil { @@ -336,9 +443,6 @@ func (s *ObjectService) GarbageCollect() error { } totalChunks := 0 - deletedChunks := 0 - deleteErrors := 0 - 
cleanedUploads := 0 if err := s.blob.ForEachChunk(func(chunkID string) error { totalChunks++ @@ -368,6 +472,7 @@ func (s *ObjectService) GarbageCollect() error { "delete_errors", deleteErrors, "cleaned_uploads", cleanedUploads, ) + success = true return nil } diff --git a/storage/blob.go b/storage/blob.go index 6215a6f..667958f 100644 --- a/storage/blob.go +++ b/storage/blob.go @@ -6,10 +6,12 @@ import ( "encoding/hex" "errors" "fmt" + "fs/metrics" "io" "os" "path/filepath" "strings" + "time" ) const blobRoot = "blobs" @@ -37,11 +39,16 @@ func NewBlobStore(root string, chunkSize int) (*BlobStore, error) { } func (bs *BlobStore) IngestStream(stream io.Reader) ([]string, int64, string, error) { + start := time.Now() fullFileHasher := md5.New() buffer := make([]byte, bs.chunkSize) var totalSize int64 var chunkIDs []string + success := false + defer func() { + metrics.Default.ObserveBlob("ingest_stream", time.Since(start), 0, success) + }() for { bytesRead, err := io.ReadFull(stream, buffer) @@ -74,10 +81,18 @@ func (bs *BlobStore) IngestStream(stream io.Reader) ([]string, int64, string, er } etag := hex.EncodeToString(fullFileHasher.Sum(nil)) + success = true return chunkIDs, totalSize, etag, nil } func (bs *BlobStore) saveBlob(chunkID string, data []byte) error { + start := time.Now() + success := false + writtenBytes := int64(0) + defer func() { + metrics.Default.ObserveBlob("write_chunk", time.Since(start), writtenBytes, success) + }() + if !isValidChunkID(chunkID) { return fmt.Errorf("invalid chunk id: %q", chunkID) } @@ -88,6 +103,7 @@ func (bs *BlobStore) saveBlob(chunkID string, data []byte) error { fullPath := filepath.Join(dir, chunkID) if _, err := os.Stat(fullPath); err == nil { + success = true return nil } else if !os.IsNotExist(err) { return err @@ -119,6 +135,7 @@ func (bs *BlobStore) saveBlob(chunkID string, data []byte) error { if err := os.Rename(tmpPath, fullPath); err != nil { if _, statErr := os.Stat(fullPath); statErr == nil { + success = true 
return nil } return err @@ -128,10 +145,18 @@ func (bs *BlobStore) saveBlob(chunkID string, data []byte) error { if err := syncDir(dir); err != nil { return err } + writtenBytes = int64(len(data)) + success = true return nil } func (bs *BlobStore) AssembleStream(chunkIDs []string, w *io.PipeWriter) error { + start := time.Now() + success := false + defer func() { + metrics.Default.ObserveBlob("assemble_stream", time.Since(start), 0, success) + }() + for _, chunkID := range chunkIDs { chunkData, err := bs.GetBlob(chunkID) if err != nil { @@ -141,14 +166,28 @@ func (bs *BlobStore) AssembleStream(chunkIDs []string, w *io.PipeWriter) error { return err } } + success = true return nil } func (bs *BlobStore) GetBlob(chunkID string) ([]byte, error) { + start := time.Now() + success := false + var size int64 + defer func() { + metrics.Default.ObserveBlob("read_chunk", time.Since(start), size, success) + }() + if !isValidChunkID(chunkID) { return nil, fmt.Errorf("invalid chunk id: %q", chunkID) } - return os.ReadFile(filepath.Join(bs.dataRoot, blobRoot, chunkID[:2], chunkID[2:4], chunkID)) + data, err := os.ReadFile(filepath.Join(bs.dataRoot, blobRoot, chunkID[:2], chunkID[2:4], chunkID)) + if err != nil { + return nil, err + } + size = int64(len(data)) + success = true + return data, nil } func (bs *BlobStore) DeleteBlob(chunkID string) error { diff --git a/utils/config.go b/utils/config.go index b42bf11..88c38ee 100644 --- a/utils/config.go +++ b/utils/config.go @@ -21,6 +21,15 @@ type Config struct { GcInterval time.Duration GcEnabled bool MultipartCleanupRetention time.Duration + AuthEnabled bool + AuthRegion string + AuthSkew time.Duration + AuthMaxPresign time.Duration + AuthMasterKey string + AuthBootstrapAccessKey string + AuthBootstrapSecretKey string + AuthBootstrapPolicy string + AdminAPIEnabled bool } func NewConfig() *Config { @@ -39,6 +48,15 @@ func NewConfig() *Config { MultipartCleanupRetention: time.Duration( envIntRange("MULTIPART_RETENTION_HOURS", 24, 1, 
24*30), ) * time.Hour, + AuthEnabled: envBool("AUTH_ENABLED", false), + AuthRegion: firstNonEmpty(strings.TrimSpace(os.Getenv("AUTH_REGION")), "us-east-1"), + AuthSkew: time.Duration(envIntRange("AUTH_SKEW_SECONDS", 300, 30, 3600)) * time.Second, + AuthMaxPresign: time.Duration(envIntRange("AUTH_MAX_PRESIGN_SECONDS", 86400, 60, 86400)) * time.Second, + AuthMasterKey: strings.TrimSpace(os.Getenv("AUTH_MASTER_KEY")), + AuthBootstrapAccessKey: strings.TrimSpace(os.Getenv("AUTH_BOOTSTRAP_ACCESS_KEY")), + AuthBootstrapSecretKey: strings.TrimSpace(os.Getenv("AUTH_BOOTSTRAP_SECRET_KEY")), + AuthBootstrapPolicy: strings.TrimSpace(os.Getenv("AUTH_BOOTSTRAP_POLICY")), + AdminAPIEnabled: envBool("ADMIN_API_ENABLED", true), } if config.LogFormat != "json" && config.LogFormat != "text" {