From abe1f453fcaf6c5cc8b48f0f191a3334baa81819 Mon Sep 17 00:00:00 2001 From: Andrej Mickov Date: Wed, 25 Feb 2026 00:34:06 +0100 Subject: [PATCH] Enhance API with health check endpoints and improve multipart upload management --- .env.example | 3 +- Dockerfile | 4 +- README.md | 4 + api/api.go | 353 +++++++++++++++++++++++++++++-------------- api/s3_errors.go | 31 +++- logging/logging.go | 7 + main.go | 4 +- metadata/metadata.go | 123 ++++++++++++++- service/service.go | 37 ++++- utils/config.go | 24 +-- 10 files changed, 452 insertions(+), 138 deletions(-) diff --git a/.env.example b/.env.example index 3b4db17..f5276d1 100644 --- a/.env.example +++ b/.env.example @@ -5,4 +5,5 @@ PORT=2600 AUDIT_LOG=true ADDRESS=0.0.0.0 GC_INTERVAL=10 -GC_ENABLED=true \ No newline at end of file +GC_ENABLED=true +MULTIPART_RETENTION_HOURS=24 diff --git a/Dockerfile b/Dockerfile index 843bf24..33c185d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,9 +8,11 @@ RUN go mod download COPY . . RUN CGO_ENABLED=0 GOOS=linux go build -o /app/fs . -FROM scratch AS runner +FROM alpine:3.23 AS runner COPY --from=build /app/fs /app/fs WORKDIR /app +EXPOSE 2600 +HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 CMD wget -q -O /dev/null "http://127.0.0.1:${PORT:-2600}/healthz" || exit 1 CMD ["/app/fs"] diff --git a/README.md b/README.md index f708d26..a4764a4 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,10 @@ Multi-object delete: AWS SigV4 streaming payload decoding for uploads (`aws-chunked` request bodies) +Health: +- `GET /healthz` +- `HEAD /healthz` + ## Limitations - No authentication/authorization yet. diff --git a/api/api.go b/api/api.go index 4dabc74..5f01252 100644 --- a/api/api.go +++ b/api/api.go @@ -13,11 +13,12 @@ import ( "fs/service" "io" "log/slog" + "net" "net/http" "net/url" - "sort" "strconv" "strings" + "sync" "time" "github.com/go-chi/chi/v5" @@ -31,8 +32,21 @@ type Handler struct { logConfig logging.Config } +const ( + maxXMLBodyBytes int64 = 1 << 20 + maxDeleteObjects = 1000 + maxObjectKeyBytes = 1024 + serverReadHeaderTimeout = 5 * time.Second + serverReadTimeout = 60 * time.Second + serverWriteTimeout = 120 * time.Second + serverIdleTimeout = 120 * time.Second + serverMaxHeaderBytes = 1 << 20 + serverMaxConnections = 1024 +) + func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig logging.Config) *Handler { r := chi.NewRouter() + r.Use(middleware.RequestID) r.Use(middleware.Recoverer) if logger == nil { logger = slog.Default() @@ -50,6 +64,8 @@ func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig loggi func (h *Handler) setupRoutes() { h.router.Use(logging.HTTPMiddleware(h.logger, h.logConfig)) + h.router.Get("/healthz", h.handleHealth) + h.router.Head("/healthz", h.handleHealth) h.router.Get("/", h.handleGetBuckets) h.router.Get("/{bucket}/", h.handleGetBucket) @@ -70,20 +86,40 @@ func (h *Handler) setupRoutes() { h.router.Delete("/{bucket}/*", h.handleDeleteObject) } -func (h *Handler) handleWelcome(w http.ResponseWriter) { - w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte("Welcome to the Object Storage API!")) - if err != nil { +func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) { + if _, err := h.svc.ListBuckets(); err != nil { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusServiceUnavailable) + if r.Method != http.MethodHead { + _, _ = w.Write([]byte("unhealthy")) + } return } + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusOK) + if r.Method != http.MethodHead { + _, _ = w.Write([]byte("ok")) + } +} + +func validateObjectKey(key string) *s3APIError { + if key == "" { + err := s3ErrInvalidObjectKey + return &err + } + if len(key) > maxObjectKeyBytes { + err := s3ErrKeyTooLong + return &err + } + return nil } func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) { bucket := chi.URLParam(r, "bucket") key := chi.URLParam(r, "*") - if key == "" { - writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) + if apiErr := validateObjectKey(key); apiErr != nil { + writeS3Error(w, r, *apiErr, r.URL.Path) return } @@ -140,8 +176,8 @@ func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) { func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) { bucket := chi.URLParam(r, "bucket") key := chi.URLParam(r, "*") - if key == "" { - writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) + if apiErr := validateObjectKey(key); apiErr != nil { + writeS3Error(w, r, *apiErr, r.URL.Path) return } defer r.Body.Close() @@ -171,8 +207,14 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) { } if uploadID := r.URL.Query().Get("uploadId"); uploadID != "" { + r.Body = http.MaxBytesReader(w, r.Body, maxXMLBodyBytes) var req models.CompleteMultipartUploadRequest if err := xml.NewDecoder(r.Body).Decode(&req); err != nil { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { + writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path) + return + } writeS3Error(w, r, s3ErrMalformedXML, r.URL.Path) return } @@ -209,8 +251,8 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) { func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) { bucket := chi.URLParam(r, "bucket") key := chi.URLParam(r, "*") - if key == "" { - writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) + if apiErr := validateObjectKey(key); apiErr != nil { + writeS3Error(w, r, *apiErr, r.URL.Path) return } defer r.Body.Close() @@ -382,28 +424,6 @@ func decodeAWSChunkedPayload(src io.Reader, dst io.Writer) error { } } -func (h *Handler) handleHeadObject(w http.ResponseWriter, r *http.Request) { - bucket := chi.URLParam(r, "bucket") - key := chi.URLParam(r, "*") - if key == "" { - writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) - return - } - - manifest, err := h.svc.HeadObject(bucket, key) - if err != nil { - writeMappedS3Error(w, r, err) - return - } - etag := manifest.ETag - size := strconv.FormatInt(manifest.Size, 10) - - w.Header().Set("ETag", `"`+etag+`"`) - w.Header().Set("Content-Length", size) - w.Header().Set("Last-Modified", time.Unix(manifest.CreatedAt, 0).UTC().Format(http.TimeFormat)) - w.WriteHeader(http.StatusOK) -} - func (h *Handler) handlePutBucket(w http.ResponseWriter, r *http.Request) { bucket := chi.URLParam(r, "bucket") if err := h.svc.CreateBucket(bucket); err != nil { @@ -429,6 +449,7 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) { return } defer r.Body.Close() + r.Body = http.MaxBytesReader(w, r.Body, maxXMLBodyBytes) bodyReader := io.Reader(r.Body) var decodeStream io.ReadCloser @@ -440,9 +461,18 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) { var req models.DeleteObjectsRequest if err := xml.NewDecoder(bodyReader).Decode(&req); err != nil { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { + writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path) + return + } writeS3Error(w, r, s3ErrMalformedXML, r.URL.Path) return } + if len(req.Objects) > maxDeleteObjects { + writeS3Error(w, r, s3ErrTooManyDeleteObjects, r.URL.Path) + return + } keys := make([]string, 0, len(req.Objects)) response := models.DeleteObjectsResult{ @@ -457,6 +487,14 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) { }) continue } + if len(obj.Key) > maxObjectKeyBytes { + response.Errors = append(response.Errors, models.DeleteError{ + Key: obj.Key, + Code: s3ErrKeyTooLong.Code, + Message: s3ErrKeyTooLong.Message, + }) + continue + } keys = append(keys, obj.Key) } @@ -488,8 +526,8 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request) { bucket := chi.URLParam(r, "bucket") key := chi.URLParam(r, "*") - if key == "" { - writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) + if apiErr := validateObjectKey(key); apiErr != nil { + writeS3Error(w, r, *apiErr, r.URL.Path) return } if uploadId := r.URL.Query().Get("uploadId"); uploadId != "" { @@ -522,6 +560,68 @@ func (h *Handler) handleHeadBucket(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } +func (h *Handler) handleHeadObject(w http.ResponseWriter, r *http.Request) { + bucket := chi.URLParam(r, "bucket") + key := chi.URLParam(r, "*") + if apiErr := validateObjectKey(key); apiErr != nil { + writeS3Error(w, r, *apiErr, r.URL.Path) + return + } + + manifest, err := h.svc.HeadObject(bucket, key) + if err != nil { + writeMappedS3Error(w, r, err) + return + } + etag := manifest.ETag + size := strconv.FormatInt(manifest.Size, 10) + + w.Header().Set("ETag", `"`+etag+`"`) + w.Header().Set("Content-Length", size) + w.Header().Set("Last-Modified", time.Unix(manifest.CreatedAt, 0).UTC().Format(http.TimeFormat)) + w.WriteHeader(http.StatusOK) +} + +type limitedListener struct { + net.Listener + slots chan struct{} +} + +func newLimitedListener(inner net.Listener, maxConns int) net.Listener { + if maxConns <= 0 { + return inner + } + return &limitedListener{ + Listener: inner, + slots: make(chan struct{}, maxConns), + } +} + +func (l *limitedListener) Accept() (net.Conn, error) { + l.slots <- struct{}{} + conn, err := l.Listener.Accept() + if err != nil { + <-l.slots + return nil, err + } + return &limitedConn{ + Conn: conn, + done: func() { <-l.slots }, + }, nil +} + +type limitedConn struct { + net.Conn + once sync.Once + done func() +} + +func (c *limitedConn) Close() error { + err := c.Conn.Close() + c.once.Do(c.done) + return err +} + func (h *Handler) handleGetBuckets(w http.ResponseWriter, r *http.Request) { buckets, err := h.svc.ListBuckets() if err != nil { @@ -613,6 +713,8 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu continuationToken := strings.TrimSpace(r.URL.Query().Get("continuation-token")) continuationMarker := "" + continuationType := "" + continuationValue := "" if continuationToken != "" { decoded, err := base64.StdEncoding.DecodeString(continuationToken) if err != nil || len(decoded) == 0 { @@ -620,33 +722,11 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu return } continuationMarker = string(decoded) - } - - objects, err := h.svc.ListObjects(bucket, prefix) - if err != nil { - writeMappedS3Error(w, r, err) - return - } - - entries := buildListV2Entries(objects, prefix, delimiter) - startIdx := 0 - if continuationMarker != "" { - found := false - for i, entry := range entries { - if entry.Marker == continuationMarker { - startIdx = i + 1 - found = true - break - } - } - if !found { + continuationType, continuationValue, _ = strings.Cut(continuationMarker, ":") + if (continuationType != "K" && continuationType != "C") || continuationValue == "" { writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path) return } - } else if startAfter != "" { - for startIdx < len(entries) && entries[startIdx].SortKey <= startAfter { - startIdx++ - } } result := models.ListBucketResultV2{ @@ -660,9 +740,91 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu EncodingType: encodingType, } - endIdx := startIdx - for endIdx < len(entries) && result.KeyCount < maxKeys { - entry := entries[endIdx] + type pageEntry struct { + Marker string + Object *models.ObjectManifest + CommonPrefix string + } + + entries := make([]pageEntry, 0, maxKeys) + seenCommonPrefixes := make(map[string]struct{}) + truncated := false + stopErr := errors.New("list_v2_page_complete") + + startKey := prefix + if continuationToken != "" { + startKey = continuationValue + } else if startAfter != "" && startAfter > startKey { + startKey = startAfter + } + + if maxKeys > 0 { + err := h.svc.ForEachObjectFrom(bucket, startKey, func(object *models.ObjectManifest) error { + if object == nil { + return nil + } + key := object.Key + + if prefix != "" { + if key < prefix { + return nil + } + if !strings.HasPrefix(key, prefix) { + return stopErr + } + } + + if continuationToken != "" { + if continuationType == "K" && key <= continuationValue { + return nil + } + if continuationType == "C" && strings.HasPrefix(key, continuationValue) { + return nil + } + } else if startAfter != "" && key <= startAfter { + return nil + } + + if delimiter != "" { + relative := strings.TrimPrefix(key, prefix) + if idx := strings.Index(relative, delimiter); idx >= 0 { + commonPrefix := prefix + relative[:idx+len(delimiter)] + if continuationToken == "" && startAfter != "" && commonPrefix <= startAfter { + return nil + } + if _, exists := seenCommonPrefixes[commonPrefix]; exists { + return nil + } + seenCommonPrefixes[commonPrefix] = struct{}{} + if len(entries) >= maxKeys { + truncated = true + return stopErr + } + entries = append(entries, pageEntry{ + Marker: "C:" + commonPrefix, + CommonPrefix: commonPrefix, + }) + return nil + } + } + + if len(entries) >= maxKeys { + truncated = true + return stopErr + } + entries = append(entries, pageEntry{ + Marker: "K:" + key, + Object: object, + }) + return nil + }) + if err != nil && !errors.Is(err, stopErr) { + writeMappedS3Error(w, r, err) + return + } + } + + for _, entry := range entries { if entry.Object != nil { result.Contents = append(result.Contents, models.Contents{ Key: s3EncodeIfNeeded(entry.Object.Key, encodingType), @@ -677,12 +839,11 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu }) } result.KeyCount++ - endIdx++ } - result.IsTruncated = endIdx < len(entries) + result.IsTruncated = truncated if result.IsTruncated && result.KeyCount > 0 { - result.NextContinuationToken = base64.StdEncoding.EncodeToString([]byte(entries[endIdx-1].Marker)) + result.NextContinuationToken = base64.StdEncoding.EncodeToString([]byte(entries[result.KeyCount-1].Marker)) } xmlResponse, err := xml.MarshalIndent(result, "", " ") @@ -699,52 +860,6 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu } -type listV2Entry struct { - Marker string - SortKey string - Object *models.ObjectManifest - CommonPrefix string -} - -func buildListV2Entries(objects []*models.ObjectManifest, prefix, delimiter string) []listV2Entry { - sorted := make([]*models.ObjectManifest, 0, len(objects)) - sorted = append(sorted, objects...) - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Key < sorted[j].Key - }) - - entries := make([]listV2Entry, 0, len(sorted)) - seenCommonPrefixes := make(map[string]struct{}) - for _, object := range sorted { - if object == nil { - continue - } - if delimiter != "" { - relative := strings.TrimPrefix(object.Key, prefix) - if idx := strings.Index(relative, delimiter); idx >= 0 { - commonPrefix := prefix + relative[:idx+len(delimiter)] - if _, exists := seenCommonPrefixes[commonPrefix]; exists { - continue - } - seenCommonPrefixes[commonPrefix] = struct{}{} - entries = append(entries, listV2Entry{ - Marker: "C:" + commonPrefix, - SortKey: commonPrefix, - CommonPrefix: commonPrefix, - }) - continue - } - } - - entries = append(entries, listV2Entry{ - Marker: "K:" + object.Key, - SortKey: object.Key, - Object: object, - }) - } - return entries -} - func s3EncodeIfNeeded(value, encodingType string) string { if encodingType != "url" || value == "" { return value @@ -815,13 +930,23 @@ func (h *Handler) Start(ctx context.Context, address string) error { h.setupRoutes() server := http.Server{ - Addr: address, - Handler: h.router, + Addr: address, + Handler: h.router, + ReadHeaderTimeout: serverReadHeaderTimeout, + ReadTimeout: serverReadTimeout, + WriteTimeout: serverWriteTimeout, + IdleTimeout: serverIdleTimeout, + MaxHeaderBytes: serverMaxHeaderBytes, } errCh := make(chan error, 1) + listener, err := net.Listen("tcp", address) + if err != nil { + return err + } + limitedListener := newLimitedListener(listener, serverMaxConnections) go func() { - if err := server.ListenAndServe(); err != nil { + if err := server.Serve(limitedListener); err != nil { if !errors.Is(err, http.ErrServerClosed) { errCh <- err } diff --git a/api/s3_errors.go b/api/s3_errors.go index f61afe1..7dbea68 100644 --- a/api/s3_errors.go +++ b/api/s3_errors.go @@ -7,6 +7,8 @@ import ( "fs/models" "fs/service" "net/http" + + "github.com/go-chi/chi/v5/middleware" ) type s3APIError struct { @@ -21,6 +23,11 @@ var ( Code: "InvalidArgument", Message: "Object key is required.", } + s3ErrKeyTooLong = s3APIError{ + Status: http.StatusBadRequest, + Code: "KeyTooLongError", + Message: "Your key is too long.", + } s3ErrNotImplemented = s3APIError{ Status: http.StatusNotImplemented, Code: "NotImplemented", @@ -56,6 +63,16 @@ var ( Code: "EntityTooSmall", Message: "Your proposed upload is smaller than the minimum allowed object size.", } + s3ErrEntityTooLarge = s3APIError{ + Status: http.StatusRequestEntityTooLarge, + Code: "EntityTooLarge", + Message: "Your proposed upload exceeds the maximum allowed size.", + } + s3ErrTooManyDeleteObjects = s3APIError{ + Status: http.StatusBadRequest, + Code: "MalformedXML", + Message: "The request must contain no more than 1000 object identifiers.", + } s3ErrInternal = s3APIError{ Status: http.StatusInternalServerError, Code: "InternalError", @@ -121,6 +138,13 @@ func mapToS3Error(err error) s3APIError { } func writeS3Error(w http.ResponseWriter, r *http.Request, apiErr s3APIError, resource string) { + requestID := "" + if r != nil { + requestID = middleware.GetReqID(r.Context()) + if requestID != "" { + w.Header().Set("x-amz-request-id", requestID) + } + } w.Header().Set("Content-Type", "application/xml; charset=utf-8") w.WriteHeader(apiErr.Status) @@ -129,9 +153,10 @@ func writeS3Error(w http.ResponseWriter, r *http.Request, apiErr s3APIError, res } payload := models.S3ErrorResponse{ - Code: apiErr.Code, - Message: apiErr.Message, - Resource: resource, + Code: apiErr.Code, + Message: apiErr.Message, + Resource: resource, + RequestID: requestID, } out, err := xml.MarshalIndent(payload, "", " ") diff --git a/logging/logging.go b/logging/logging.go index 7290bf2..28e1985 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -86,6 +86,10 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) + requestID := middleware.GetReqID(r.Context()) + if requestID != "" { + ww.Header().Set("x-amz-request-id", requestID) + } next.ServeHTTP(ww, r) @@ -106,6 +110,9 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han "duration_ms", float64(elapsed.Nanoseconds()) / 1_000_000.0, "remote_addr", r.RemoteAddr, } + if requestID != "" { + attrs = append(attrs, "request_id", requestID) + } if cfg.DebugMode { attrs = append(attrs, diff --git a/main.go b/main.go index e10a9bd..70ed5ba 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "path/filepath" "strconv" "syscall" + "time" ) func main() { @@ -24,6 +25,7 @@ func main() { "log_format", logConfig.Format, "audit_log", logConfig.Audit, "data_path", config.DataPath, + "multipart_retention_hours", int(config.MultipartCleanupRetention/time.Hour), ) if err := os.MkdirAll(config.DataPath, 0o755); err != nil { @@ -44,7 +46,7 @@ func main() { return } - objectService := service.NewObjectService(metadataHandler, blobHandler) + objectService := service.NewObjectService(metadataHandler, blobHandler, config.MultipartCleanupRetention) handler := api.NewHandler(objectService, logger, logConfig) addr := config.Address + ":" + strconv.Itoa(config.Port) diff --git a/metadata/metadata.go b/metadata/metadata.go index 9e50a3a..916ddfe 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "fs/models" + "net" "regexp" "sort" "strings" @@ -23,7 +24,7 @@ var systemIndex = []byte("__SYSTEM_BUCKETS__") var multipartUploadIndex = []byte("__MULTIPART_UPLOADS__") var multipartUploadPartsIndex = []byte("__MULTIPART_UPLOAD_PARTS__") -var validBucketName = regexp.MustCompile(`^[a-z0-9.-]{3,63}$`) +var validBucketName = regexp.MustCompile(`^[a-z0-9.-]+$`) var ( ErrInvalidBucketName = errors.New("invalid bucket name") @@ -70,12 +71,36 @@ func NewMetadataHandler(dbPath string) (*MetadataHandler, error) { return h, nil } +func isValidBucketName(bucketName string) bool { + if len(bucketName) < 3 || len(bucketName) > 63 { + return false + } + if !validBucketName.MatchString(bucketName) { + return false + } + if strings.Contains(bucketName, "..") { + return false + } + if bucketName[0] == '.' || bucketName[0] == '-' || bucketName[len(bucketName)-1] == '.' || bucketName[len(bucketName)-1] == '-' { + return false + } + for _, label := range strings.Split(bucketName, ".") { + if label == "" || label[0] == '-' || label[len(label)-1] == '-' { + return false + } + } + if ip := net.ParseIP(bucketName); ip != nil && ip.To4() != nil { + return false + } + return true +} + func (h *MetadataHandler) Close() error { return h.db.Close() } func (h *MetadataHandler) CreateBucket(bucketName string) error { - if !validBucketName.MatchString(bucketName) { + if !isValidBucketName(bucketName) { return fmt.Errorf("%w: %s", ErrInvalidBucketName, bucketName) } @@ -107,7 +132,7 @@ func (h *MetadataHandler) CreateBucket(bucketName string) error { } func (h *MetadataHandler) DeleteBucket(bucketName string) error { - if !validBucketName.MatchString(bucketName) { + if !isValidBucketName(bucketName) { return fmt.Errorf("%w: %s", ErrInvalidBucketName, bucketName) } @@ -290,6 +315,46 @@ func (h *MetadataHandler) ListObjects(bucket, prefix string) ([]*models.ObjectMa return objects, nil } +func (h *MetadataHandler) ForEachObjectFrom(bucket, startKey string, fn func(*models.ObjectManifest) error) error { + if fn == nil { + return errors.New("object callback is required") + } + + return h.db.View(func(tx *bbolt.Tx) error { + systemIndexBucket := tx.Bucket([]byte(systemIndex)) + if systemIndexBucket == nil { + return errors.New("system index not found") + } + if systemIndexBucket.Get([]byte(bucket)) == nil { + return fmt.Errorf("%w: %s", ErrBucketNotFound, bucket) + } + + metadataBucket := tx.Bucket([]byte(bucket)) + if metadataBucket == nil { + return fmt.Errorf("%w: %s", ErrBucketNotFound, bucket) + } + + cursor := metadataBucket.Cursor() + var k, v []byte + if startKey == "" { + k, v = cursor.First() + } else { + k, v = cursor.Seek([]byte(startKey)) + } + + for ; k != nil; k, v = cursor.Next() { + object := models.ObjectManifest{} + if err := json.Unmarshal(v, &object); err != nil { + return err + } + if err := fn(&object); err != nil { + return err + } + } + return nil + }) +} + func (h *MetadataHandler) DeleteManifest(bucket, key string) error { if _, err := h.GetManifest(bucket, key); err != nil { return err @@ -602,6 +667,58 @@ func (h *MetadataHandler) AbortMultipartUpload(uploadID string) error { return nil } +func (h *MetadataHandler) CleanupMultipartUploads(retention time.Duration) (int, error) { + if retention <= 0 { + return 0, nil + } + + cleaned := 0 + err := h.db.Update(func(tx *bbolt.Tx) error { + uploadsBucket, err := getMultipartUploadBucket(tx) + if err != nil { + return err + } + + now := time.Now().UTC() + keysToDelete := make([]string, 0) + if err := uploadsBucket.ForEach(func(k, v []byte) error { + upload := models.MultipartUpload{} + if err := json.Unmarshal(v, &upload); err != nil { + return err + } + if upload.State == "pending" { + return nil + } + createdAt, err := time.Parse(time.RFC3339, upload.CreatedAt) + if err != nil { + return nil + } + if now.Sub(createdAt) >= retention { + keysToDelete = append(keysToDelete, string(k)) + } + return nil + }); err != nil { + return err + } + + for _, uploadID := range keysToDelete { + if err := uploadsBucket.Delete([]byte(uploadID)); err != nil { + return err + } + if err := deleteMultipartPartsByUploadID(tx, uploadID); err != nil { + return err + } + cleaned++ + } + return nil + }) + if err != nil { + return 0, err + } + + return cleaned, nil +} + func (h *MetadataHandler) GetReferencedChunkSet() (map[string]struct{}, error) { chunkSet := make(map[string]struct{}) pendingUploadSet := make(map[string]struct{}) diff --git a/service/service.go b/service/service.go index c2553ad..09b04d6 100644 --- a/service/service.go +++ b/service/service.go @@ -17,9 +17,10 @@ import ( ) type ObjectService struct { - metadata *metadata.MetadataHandler - blob *storage.BlobStore - gcMu sync.RWMutex + metadata *metadata.MetadataHandler + blob *storage.BlobStore + multipartRetention time.Duration + gcMu sync.RWMutex } var ( @@ -29,8 +30,15 @@ var ( ErrEntityTooSmall = errors.New("multipart entity too small") ) -func NewObjectService(metadataHandler *metadata.MetadataHandler, blobHandler *storage.BlobStore) *ObjectService { - return &ObjectService{metadata: metadataHandler, blob: blobHandler} +func NewObjectService(metadataHandler *metadata.MetadataHandler, blobHandler *storage.BlobStore, multipartRetention time.Duration) *ObjectService { + if multipartRetention <= 0 { + multipartRetention = 24 * time.Hour + } + return &ObjectService{ + metadata: metadataHandler, + blob: blobHandler, + multipartRetention: multipartRetention, + } } func (s *ObjectService) PutObject(bucket, key, contentType string, input io.Reader) (*models.ObjectManifest, error) { @@ -111,6 +119,13 @@ func (s *ObjectService) ListObjects(bucket, prefix string) ([]*models.ObjectMani return s.metadata.ListObjects(bucket, prefix) } +func (s *ObjectService) ForEachObjectFrom(bucket, startKey string, fn func(*models.ObjectManifest) error) error { + s.gcMu.RLock() + defer s.gcMu.RUnlock() + + return s.metadata.ForEachObjectFrom(bucket, startKey, fn) +} + func (s *ObjectService) CreateBucket(bucket string) error { s.gcMu.RLock() defer s.gcMu.RUnlock() @@ -323,6 +338,7 @@ func (s *ObjectService) GarbageCollect() error { totalChunks := 0 deletedChunks := 0 deleteErrors := 0 + cleanedUploads := 0 if err := s.blob.ForEachChunk(func(chunkID string) error { totalChunks++ @@ -340,16 +356,27 @@ func (s *ObjectService) GarbageCollect() error { return err } + cleanedUploads, err = s.metadata.CleanupMultipartUploads(s.multipartRetention) + if err != nil { + return err + } + slog.Info("garbage_collect_completed", "referenced_chunks", len(referencedChunkSet), "total_chunks", totalChunks, "deleted_chunks", deletedChunks, "delete_errors", deleteErrors, + "cleaned_uploads", cleanedUploads, ) return nil } func (s *ObjectService) RunGC(ctx context.Context, interval time.Duration) { + if interval <= 0 { + slog.Warn("garbage_collect_disabled_invalid_interval", "interval", interval.String()) + return + } + ticker := time.NewTicker(interval) defer ticker.Stop() diff --git a/utils/config.go b/utils/config.go index 76d9247..b42bf11 100644 --- a/utils/config.go +++ b/utils/config.go @@ -11,15 +11,16 @@ import ( ) type Config struct { - DataPath string - Address string - Port int - ChunkSize int - LogLevel string - LogFormat string - AuditLog bool - GcInterval time.Duration - GcEnabled bool + DataPath string + Address string + Port int + ChunkSize int + LogLevel string + LogFormat string + AuditLog bool + GcInterval time.Duration + GcEnabled bool + MultipartCleanupRetention time.Duration } func NewConfig() *Config { @@ -33,8 +34,11 @@ func NewConfig() *Config { LogLevel: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_LEVEL")), "info")), LogFormat: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_FORMAT")), strings.TrimSpace(os.Getenv("LOG_TYPE")), "text")), AuditLog: envBool("AUDIT_LOG", true), - GcInterval: time.Duration(envIntRange("GC_INTERVAL", 10, -1, 60)) * time.Minute, + GcInterval: time.Duration(envIntRange("GC_INTERVAL", 10, 1, 60)) * time.Minute, GcEnabled: envBool("GC_ENABLED", true), + MultipartCleanupRetention: time.Duration( + envIntRange("MULTIPART_RETENTION_HOURS", 24, 1, 24*30), + ) * time.Hour, } if config.LogFormat != "json" && config.LogFormat != "text" {