Enhance API with health check endpoints and improve multipart upload management

This commit is contained in:
2026-02-25 00:34:06 +01:00
parent a9fbc06dd0
commit abe1f453fc
10 changed files with 452 additions and 138 deletions

View File

@@ -6,3 +6,4 @@ AUDIT_LOG=true
ADDRESS=0.0.0.0 ADDRESS=0.0.0.0
GC_INTERVAL=10 GC_INTERVAL=10
GC_ENABLED=true GC_ENABLED=true
MULTIPART_RETENTION_HOURS=24

View File

@@ -8,9 +8,11 @@ RUN go mod download
COPY . . COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -o /app/fs . RUN CGO_ENABLED=0 GOOS=linux go build -o /app/fs .
FROM scratch AS runner FROM alpine:3.23 AS runner
COPY --from=build /app/fs /app/fs COPY --from=build /app/fs /app/fs
WORKDIR /app WORKDIR /app
EXPOSE 2600
HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 CMD wget -q -O /dev/null "http://127.0.0.1:${PORT:-2600}/healthz" || exit 1
CMD ["/app/fs"] CMD ["/app/fs"]

View File

@@ -29,6 +29,10 @@ Multi-object delete:
AWS SigV4 streaming payload decoding for uploads (`aws-chunked` request bodies) AWS SigV4 streaming payload decoding for uploads (`aws-chunked` request bodies)
Health:
- `GET /healthz`
- `HEAD /healthz`
## Limitations ## Limitations
- No authentication/authorization yet. - No authentication/authorization yet.

View File

@@ -13,11 +13,12 @@ import (
"fs/service" "fs/service"
"io" "io"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"net/url" "net/url"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
@@ -31,8 +32,21 @@ type Handler struct {
logConfig logging.Config logConfig logging.Config
} }
const (
maxXMLBodyBytes int64 = 1 << 20
maxDeleteObjects = 1000
maxObjectKeyBytes = 1024
serverReadHeaderTimeout = 5 * time.Second
serverReadTimeout = 60 * time.Second
serverWriteTimeout = 120 * time.Second
serverIdleTimeout = 120 * time.Second
serverMaxHeaderBytes = 1 << 20
serverMaxConnections = 1024
)
func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig logging.Config) *Handler { func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig logging.Config) *Handler {
r := chi.NewRouter() r := chi.NewRouter()
r.Use(middleware.RequestID)
r.Use(middleware.Recoverer) r.Use(middleware.Recoverer)
if logger == nil { if logger == nil {
logger = slog.Default() logger = slog.Default()
@@ -50,6 +64,8 @@ func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig loggi
func (h *Handler) setupRoutes() { func (h *Handler) setupRoutes() {
h.router.Use(logging.HTTPMiddleware(h.logger, h.logConfig)) h.router.Use(logging.HTTPMiddleware(h.logger, h.logConfig))
h.router.Get("/healthz", h.handleHealth)
h.router.Head("/healthz", h.handleHealth)
h.router.Get("/", h.handleGetBuckets) h.router.Get("/", h.handleGetBuckets)
h.router.Get("/{bucket}/", h.handleGetBucket) h.router.Get("/{bucket}/", h.handleGetBucket)
@@ -70,20 +86,40 @@ func (h *Handler) setupRoutes() {
h.router.Delete("/{bucket}/*", h.handleDeleteObject) h.router.Delete("/{bucket}/*", h.handleDeleteObject)
} }
func (h *Handler) handleWelcome(w http.ResponseWriter) { func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) if _, err := h.svc.ListBuckets(); err != nil {
_, err := w.Write([]byte("Welcome to the Object Storage API!")) w.Header().Set("Content-Type", "text/plain; charset=utf-8")
if err != nil { w.WriteHeader(http.StatusServiceUnavailable)
if r.Method != http.MethodHead {
_, _ = w.Write([]byte("unhealthy"))
}
return return
} }
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK)
if r.Method != http.MethodHead {
_, _ = w.Write([]byte("ok"))
}
}
func validateObjectKey(key string) *s3APIError {
if key == "" {
err := s3ErrInvalidObjectKey
return &err
}
if len(key) > maxObjectKeyBytes {
err := s3ErrKeyTooLong
return &err
}
return nil
} }
func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket") bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*") key := chi.URLParam(r, "*")
if key == "" { if apiErr := validateObjectKey(key); apiErr != nil {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) writeS3Error(w, r, *apiErr, r.URL.Path)
return return
} }
@@ -140,8 +176,8 @@ func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) { func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket") bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*") key := chi.URLParam(r, "*")
if key == "" { if apiErr := validateObjectKey(key); apiErr != nil {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) writeS3Error(w, r, *apiErr, r.URL.Path)
return return
} }
defer r.Body.Close() defer r.Body.Close()
@@ -171,8 +207,14 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) {
} }
if uploadID := r.URL.Query().Get("uploadId"); uploadID != "" { if uploadID := r.URL.Query().Get("uploadId"); uploadID != "" {
r.Body = http.MaxBytesReader(w, r.Body, maxXMLBodyBytes)
var req models.CompleteMultipartUploadRequest var req models.CompleteMultipartUploadRequest
if err := xml.NewDecoder(r.Body).Decode(&req); err != nil { if err := xml.NewDecoder(r.Body).Decode(&req); err != nil {
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) {
writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path)
return
}
writeS3Error(w, r, s3ErrMalformedXML, r.URL.Path) writeS3Error(w, r, s3ErrMalformedXML, r.URL.Path)
return return
} }
@@ -209,8 +251,8 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) { func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket") bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*") key := chi.URLParam(r, "*")
if key == "" { if apiErr := validateObjectKey(key); apiErr != nil {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) writeS3Error(w, r, *apiErr, r.URL.Path)
return return
} }
defer r.Body.Close() defer r.Body.Close()
@@ -382,28 +424,6 @@ func decodeAWSChunkedPayload(src io.Reader, dst io.Writer) error {
} }
} }
func (h *Handler) handleHeadObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*")
if key == "" {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path)
return
}
manifest, err := h.svc.HeadObject(bucket, key)
if err != nil {
writeMappedS3Error(w, r, err)
return
}
etag := manifest.ETag
size := strconv.FormatInt(manifest.Size, 10)
w.Header().Set("ETag", `"`+etag+`"`)
w.Header().Set("Content-Length", size)
w.Header().Set("Last-Modified", time.Unix(manifest.CreatedAt, 0).UTC().Format(http.TimeFormat))
w.WriteHeader(http.StatusOK)
}
func (h *Handler) handlePutBucket(w http.ResponseWriter, r *http.Request) { func (h *Handler) handlePutBucket(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket") bucket := chi.URLParam(r, "bucket")
if err := h.svc.CreateBucket(bucket); err != nil { if err := h.svc.CreateBucket(bucket); err != nil {
@@ -429,6 +449,7 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
return return
} }
defer r.Body.Close() defer r.Body.Close()
r.Body = http.MaxBytesReader(w, r.Body, maxXMLBodyBytes)
bodyReader := io.Reader(r.Body) bodyReader := io.Reader(r.Body)
var decodeStream io.ReadCloser var decodeStream io.ReadCloser
@@ -440,9 +461,18 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
var req models.DeleteObjectsRequest var req models.DeleteObjectsRequest
if err := xml.NewDecoder(bodyReader).Decode(&req); err != nil { if err := xml.NewDecoder(bodyReader).Decode(&req); err != nil {
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) {
writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path)
return
}
writeS3Error(w, r, s3ErrMalformedXML, r.URL.Path) writeS3Error(w, r, s3ErrMalformedXML, r.URL.Path)
return return
} }
if len(req.Objects) > maxDeleteObjects {
writeS3Error(w, r, s3ErrTooManyDeleteObjects, r.URL.Path)
return
}
keys := make([]string, 0, len(req.Objects)) keys := make([]string, 0, len(req.Objects))
response := models.DeleteObjectsResult{ response := models.DeleteObjectsResult{
@@ -457,6 +487,14 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
}) })
continue continue
} }
if len(obj.Key) > maxObjectKeyBytes {
response.Errors = append(response.Errors, models.DeleteError{
Key: obj.Key,
Code: s3ErrKeyTooLong.Code,
Message: s3ErrKeyTooLong.Message,
})
continue
}
keys = append(keys, obj.Key) keys = append(keys, obj.Key)
} }
@@ -488,8 +526,8 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket") bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*") key := chi.URLParam(r, "*")
if key == "" { if apiErr := validateObjectKey(key); apiErr != nil {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) writeS3Error(w, r, *apiErr, r.URL.Path)
return return
} }
if uploadId := r.URL.Query().Get("uploadId"); uploadId != "" { if uploadId := r.URL.Query().Get("uploadId"); uploadId != "" {
@@ -522,6 +560,68 @@ func (h *Handler) handleHeadBucket(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
func (h *Handler) handleHeadObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*")
if apiErr := validateObjectKey(key); apiErr != nil {
writeS3Error(w, r, *apiErr, r.URL.Path)
return
}
manifest, err := h.svc.HeadObject(bucket, key)
if err != nil {
writeMappedS3Error(w, r, err)
return
}
etag := manifest.ETag
size := strconv.FormatInt(manifest.Size, 10)
w.Header().Set("ETag", `"`+etag+`"`)
w.Header().Set("Content-Length", size)
w.Header().Set("Last-Modified", time.Unix(manifest.CreatedAt, 0).UTC().Format(http.TimeFormat))
w.WriteHeader(http.StatusOK)
}
type limitedListener struct {
net.Listener
slots chan struct{}
}
func newLimitedListener(inner net.Listener, maxConns int) net.Listener {
if maxConns <= 0 {
return inner
}
return &limitedListener{
Listener: inner,
slots: make(chan struct{}, maxConns),
}
}
func (l *limitedListener) Accept() (net.Conn, error) {
l.slots <- struct{}{}
conn, err := l.Listener.Accept()
if err != nil {
<-l.slots
return nil, err
}
return &limitedConn{
Conn: conn,
done: func() { <-l.slots },
}, nil
}
type limitedConn struct {
net.Conn
once sync.Once
done func()
}
func (c *limitedConn) Close() error {
err := c.Conn.Close()
c.once.Do(c.done)
return err
}
func (h *Handler) handleGetBuckets(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleGetBuckets(w http.ResponseWriter, r *http.Request) {
buckets, err := h.svc.ListBuckets() buckets, err := h.svc.ListBuckets()
if err != nil { if err != nil {
@@ -613,6 +713,8 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu
continuationToken := strings.TrimSpace(r.URL.Query().Get("continuation-token")) continuationToken := strings.TrimSpace(r.URL.Query().Get("continuation-token"))
continuationMarker := "" continuationMarker := ""
continuationType := ""
continuationValue := ""
if continuationToken != "" { if continuationToken != "" {
decoded, err := base64.StdEncoding.DecodeString(continuationToken) decoded, err := base64.StdEncoding.DecodeString(continuationToken)
if err != nil || len(decoded) == 0 { if err != nil || len(decoded) == 0 {
@@ -620,33 +722,11 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu
return return
} }
continuationMarker = string(decoded) continuationMarker = string(decoded)
} continuationType, continuationValue, _ = strings.Cut(continuationMarker, ":")
if (continuationType != "K" && continuationType != "C") || continuationValue == "" {
objects, err := h.svc.ListObjects(bucket, prefix)
if err != nil {
writeMappedS3Error(w, r, err)
return
}
entries := buildListV2Entries(objects, prefix, delimiter)
startIdx := 0
if continuationMarker != "" {
found := false
for i, entry := range entries {
if entry.Marker == continuationMarker {
startIdx = i + 1
found = true
break
}
}
if !found {
writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path) writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path)
return return
} }
} else if startAfter != "" {
for startIdx < len(entries) && entries[startIdx].SortKey <= startAfter {
startIdx++
}
} }
result := models.ListBucketResultV2{ result := models.ListBucketResultV2{
@@ -660,9 +740,91 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu
EncodingType: encodingType, EncodingType: encodingType,
} }
endIdx := startIdx type pageEntry struct {
for endIdx < len(entries) && result.KeyCount < maxKeys { Marker string
entry := entries[endIdx] Object *models.ObjectManifest
CommonPrefix string
}
entries := make([]pageEntry, 0, maxKeys)
seenCommonPrefixes := make(map[string]struct{})
truncated := false
stopErr := errors.New("list_v2_page_complete")
startKey := prefix
if continuationToken != "" {
startKey = continuationValue
} else if startAfter != "" && startAfter > startKey {
startKey = startAfter
}
if maxKeys > 0 {
err := h.svc.ForEachObjectFrom(bucket, startKey, func(object *models.ObjectManifest) error {
if object == nil {
return nil
}
key := object.Key
if prefix != "" {
if key < prefix {
return nil
}
if !strings.HasPrefix(key, prefix) {
return stopErr
}
}
if continuationToken != "" {
if continuationType == "K" && key <= continuationValue {
return nil
}
if continuationType == "C" && strings.HasPrefix(key, continuationValue) {
return nil
}
} else if startAfter != "" && key <= startAfter {
return nil
}
if delimiter != "" {
relative := strings.TrimPrefix(key, prefix)
if idx := strings.Index(relative, delimiter); idx >= 0 {
commonPrefix := prefix + relative[:idx+len(delimiter)]
if continuationToken == "" && startAfter != "" && commonPrefix <= startAfter {
return nil
}
if _, exists := seenCommonPrefixes[commonPrefix]; exists {
return nil
}
seenCommonPrefixes[commonPrefix] = struct{}{}
if len(entries) >= maxKeys {
truncated = true
return stopErr
}
entries = append(entries, pageEntry{
Marker: "C:" + commonPrefix,
CommonPrefix: commonPrefix,
})
return nil
}
}
if len(entries) >= maxKeys {
truncated = true
return stopErr
}
entries = append(entries, pageEntry{
Marker: "K:" + key,
Object: object,
})
return nil
})
if err != nil && !errors.Is(err, stopErr) {
writeMappedS3Error(w, r, err)
return
}
}
for _, entry := range entries {
if entry.Object != nil { if entry.Object != nil {
result.Contents = append(result.Contents, models.Contents{ result.Contents = append(result.Contents, models.Contents{
Key: s3EncodeIfNeeded(entry.Object.Key, encodingType), Key: s3EncodeIfNeeded(entry.Object.Key, encodingType),
@@ -677,12 +839,11 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu
}) })
} }
result.KeyCount++ result.KeyCount++
endIdx++
} }
result.IsTruncated = endIdx < len(entries) result.IsTruncated = truncated
if result.IsTruncated && result.KeyCount > 0 { if result.IsTruncated && result.KeyCount > 0 {
result.NextContinuationToken = base64.StdEncoding.EncodeToString([]byte(entries[endIdx-1].Marker)) result.NextContinuationToken = base64.StdEncoding.EncodeToString([]byte(entries[result.KeyCount-1].Marker))
} }
xmlResponse, err := xml.MarshalIndent(result, "", " ") xmlResponse, err := xml.MarshalIndent(result, "", " ")
@@ -699,52 +860,6 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu
} }
type listV2Entry struct {
Marker string
SortKey string
Object *models.ObjectManifest
CommonPrefix string
}
func buildListV2Entries(objects []*models.ObjectManifest, prefix, delimiter string) []listV2Entry {
sorted := make([]*models.ObjectManifest, 0, len(objects))
sorted = append(sorted, objects...)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Key < sorted[j].Key
})
entries := make([]listV2Entry, 0, len(sorted))
seenCommonPrefixes := make(map[string]struct{})
for _, object := range sorted {
if object == nil {
continue
}
if delimiter != "" {
relative := strings.TrimPrefix(object.Key, prefix)
if idx := strings.Index(relative, delimiter); idx >= 0 {
commonPrefix := prefix + relative[:idx+len(delimiter)]
if _, exists := seenCommonPrefixes[commonPrefix]; exists {
continue
}
seenCommonPrefixes[commonPrefix] = struct{}{}
entries = append(entries, listV2Entry{
Marker: "C:" + commonPrefix,
SortKey: commonPrefix,
CommonPrefix: commonPrefix,
})
continue
}
}
entries = append(entries, listV2Entry{
Marker: "K:" + object.Key,
SortKey: object.Key,
Object: object,
})
}
return entries
}
func s3EncodeIfNeeded(value, encodingType string) string { func s3EncodeIfNeeded(value, encodingType string) string {
if encodingType != "url" || value == "" { if encodingType != "url" || value == "" {
return value return value
@@ -815,13 +930,23 @@ func (h *Handler) Start(ctx context.Context, address string) error {
h.setupRoutes() h.setupRoutes()
server := http.Server{ server := http.Server{
Addr: address, Addr: address,
Handler: h.router, Handler: h.router,
ReadHeaderTimeout: serverReadHeaderTimeout,
ReadTimeout: serverReadTimeout,
WriteTimeout: serverWriteTimeout,
IdleTimeout: serverIdleTimeout,
MaxHeaderBytes: serverMaxHeaderBytes,
} }
errCh := make(chan error, 1) errCh := make(chan error, 1)
listener, err := net.Listen("tcp", address)
if err != nil {
return err
}
limitedListener := newLimitedListener(listener, serverMaxConnections)
go func() { go func() {
if err := server.ListenAndServe(); err != nil { if err := server.Serve(limitedListener); err != nil {
if !errors.Is(err, http.ErrServerClosed) { if !errors.Is(err, http.ErrServerClosed) {
errCh <- err errCh <- err
} }

View File

@@ -7,6 +7,8 @@ import (
"fs/models" "fs/models"
"fs/service" "fs/service"
"net/http" "net/http"
"github.com/go-chi/chi/v5/middleware"
) )
type s3APIError struct { type s3APIError struct {
@@ -21,6 +23,11 @@ var (
Code: "InvalidArgument", Code: "InvalidArgument",
Message: "Object key is required.", Message: "Object key is required.",
} }
s3ErrKeyTooLong = s3APIError{
Status: http.StatusBadRequest,
Code: "KeyTooLongError",
Message: "Your key is too long.",
}
s3ErrNotImplemented = s3APIError{ s3ErrNotImplemented = s3APIError{
Status: http.StatusNotImplemented, Status: http.StatusNotImplemented,
Code: "NotImplemented", Code: "NotImplemented",
@@ -56,6 +63,16 @@ var (
Code: "EntityTooSmall", Code: "EntityTooSmall",
Message: "Your proposed upload is smaller than the minimum allowed object size.", Message: "Your proposed upload is smaller than the minimum allowed object size.",
} }
s3ErrEntityTooLarge = s3APIError{
Status: http.StatusRequestEntityTooLarge,
Code: "EntityTooLarge",
Message: "Your proposed upload exceeds the maximum allowed size.",
}
s3ErrTooManyDeleteObjects = s3APIError{
Status: http.StatusBadRequest,
Code: "MalformedXML",
Message: "The request must contain no more than 1000 object identifiers.",
}
s3ErrInternal = s3APIError{ s3ErrInternal = s3APIError{
Status: http.StatusInternalServerError, Status: http.StatusInternalServerError,
Code: "InternalError", Code: "InternalError",
@@ -121,6 +138,13 @@ func mapToS3Error(err error) s3APIError {
} }
func writeS3Error(w http.ResponseWriter, r *http.Request, apiErr s3APIError, resource string) { func writeS3Error(w http.ResponseWriter, r *http.Request, apiErr s3APIError, resource string) {
requestID := ""
if r != nil {
requestID = middleware.GetReqID(r.Context())
if requestID != "" {
w.Header().Set("x-amz-request-id", requestID)
}
}
w.Header().Set("Content-Type", "application/xml; charset=utf-8") w.Header().Set("Content-Type", "application/xml; charset=utf-8")
w.WriteHeader(apiErr.Status) w.WriteHeader(apiErr.Status)
@@ -129,9 +153,10 @@ func writeS3Error(w http.ResponseWriter, r *http.Request, apiErr s3APIError, res
} }
payload := models.S3ErrorResponse{ payload := models.S3ErrorResponse{
Code: apiErr.Code, Code: apiErr.Code,
Message: apiErr.Message, Message: apiErr.Message,
Resource: resource, Resource: resource,
RequestID: requestID,
} }
out, err := xml.MarshalIndent(payload, "", " ") out, err := xml.MarshalIndent(payload, "", " ")

View File

@@ -86,6 +86,10 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
requestID := middleware.GetReqID(r.Context())
if requestID != "" {
ww.Header().Set("x-amz-request-id", requestID)
}
next.ServeHTTP(ww, r) next.ServeHTTP(ww, r)
@@ -106,6 +110,9 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han
"duration_ms", float64(elapsed.Nanoseconds()) / 1_000_000.0, "duration_ms", float64(elapsed.Nanoseconds()) / 1_000_000.0,
"remote_addr", r.RemoteAddr, "remote_addr", r.RemoteAddr,
} }
if requestID != "" {
attrs = append(attrs, "request_id", requestID)
}
if cfg.DebugMode { if cfg.DebugMode {
attrs = append(attrs, attrs = append(attrs,

View File

@@ -13,6 +13,7 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"syscall" "syscall"
"time"
) )
func main() { func main() {
@@ -24,6 +25,7 @@ func main() {
"log_format", logConfig.Format, "log_format", logConfig.Format,
"audit_log", logConfig.Audit, "audit_log", logConfig.Audit,
"data_path", config.DataPath, "data_path", config.DataPath,
"multipart_retention_hours", int(config.MultipartCleanupRetention/time.Hour),
) )
if err := os.MkdirAll(config.DataPath, 0o755); err != nil { if err := os.MkdirAll(config.DataPath, 0o755); err != nil {
@@ -44,7 +46,7 @@ func main() {
return return
} }
objectService := service.NewObjectService(metadataHandler, blobHandler) objectService := service.NewObjectService(metadataHandler, blobHandler, config.MultipartCleanupRetention)
handler := api.NewHandler(objectService, logger, logConfig) handler := api.NewHandler(objectService, logger, logConfig)
addr := config.Address + ":" + strconv.Itoa(config.Port) addr := config.Address + ":" + strconv.Itoa(config.Port)

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"fs/models" "fs/models"
"net"
"regexp" "regexp"
"sort" "sort"
"strings" "strings"
@@ -23,7 +24,7 @@ var systemIndex = []byte("__SYSTEM_BUCKETS__")
var multipartUploadIndex = []byte("__MULTIPART_UPLOADS__") var multipartUploadIndex = []byte("__MULTIPART_UPLOADS__")
var multipartUploadPartsIndex = []byte("__MULTIPART_UPLOAD_PARTS__") var multipartUploadPartsIndex = []byte("__MULTIPART_UPLOAD_PARTS__")
var validBucketName = regexp.MustCompile(`^[a-z0-9.-]{3,63}$`) var validBucketName = regexp.MustCompile(`^[a-z0-9.-]+$`)
var ( var (
ErrInvalidBucketName = errors.New("invalid bucket name") ErrInvalidBucketName = errors.New("invalid bucket name")
@@ -70,12 +71,36 @@ func NewMetadataHandler(dbPath string) (*MetadataHandler, error) {
return h, nil return h, nil
} }
func isValidBucketName(bucketName string) bool {
if len(bucketName) < 3 || len(bucketName) > 63 {
return false
}
if !validBucketName.MatchString(bucketName) {
return false
}
if strings.Contains(bucketName, "..") {
return false
}
if bucketName[0] == '.' || bucketName[0] == '-' || bucketName[len(bucketName)-1] == '.' || bucketName[len(bucketName)-1] == '-' {
return false
}
for _, label := range strings.Split(bucketName, ".") {
if label == "" || label[0] == '-' || label[len(label)-1] == '-' {
return false
}
}
if ip := net.ParseIP(bucketName); ip != nil && ip.To4() != nil {
return false
}
return true
}
func (h *MetadataHandler) Close() error { func (h *MetadataHandler) Close() error {
return h.db.Close() return h.db.Close()
} }
func (h *MetadataHandler) CreateBucket(bucketName string) error { func (h *MetadataHandler) CreateBucket(bucketName string) error {
if !validBucketName.MatchString(bucketName) { if !isValidBucketName(bucketName) {
return fmt.Errorf("%w: %s", ErrInvalidBucketName, bucketName) return fmt.Errorf("%w: %s", ErrInvalidBucketName, bucketName)
} }
@@ -107,7 +132,7 @@ func (h *MetadataHandler) CreateBucket(bucketName string) error {
} }
func (h *MetadataHandler) DeleteBucket(bucketName string) error { func (h *MetadataHandler) DeleteBucket(bucketName string) error {
if !validBucketName.MatchString(bucketName) { if !isValidBucketName(bucketName) {
return fmt.Errorf("%w: %s", ErrInvalidBucketName, bucketName) return fmt.Errorf("%w: %s", ErrInvalidBucketName, bucketName)
} }
@@ -290,6 +315,46 @@ func (h *MetadataHandler) ListObjects(bucket, prefix string) ([]*models.ObjectMa
return objects, nil return objects, nil
} }
func (h *MetadataHandler) ForEachObjectFrom(bucket, startKey string, fn func(*models.ObjectManifest) error) error {
if fn == nil {
return errors.New("object callback is required")
}
return h.db.View(func(tx *bbolt.Tx) error {
systemIndexBucket := tx.Bucket([]byte(systemIndex))
if systemIndexBucket == nil {
return errors.New("system index not found")
}
if systemIndexBucket.Get([]byte(bucket)) == nil {
return fmt.Errorf("%w: %s", ErrBucketNotFound, bucket)
}
metadataBucket := tx.Bucket([]byte(bucket))
if metadataBucket == nil {
return fmt.Errorf("%w: %s", ErrBucketNotFound, bucket)
}
cursor := metadataBucket.Cursor()
var k, v []byte
if startKey == "" {
k, v = cursor.First()
} else {
k, v = cursor.Seek([]byte(startKey))
}
for ; k != nil; k, v = cursor.Next() {
object := models.ObjectManifest{}
if err := json.Unmarshal(v, &object); err != nil {
return err
}
if err := fn(&object); err != nil {
return err
}
}
return nil
})
}
func (h *MetadataHandler) DeleteManifest(bucket, key string) error { func (h *MetadataHandler) DeleteManifest(bucket, key string) error {
if _, err := h.GetManifest(bucket, key); err != nil { if _, err := h.GetManifest(bucket, key); err != nil {
return err return err
@@ -602,6 +667,58 @@ func (h *MetadataHandler) AbortMultipartUpload(uploadID string) error {
return nil return nil
} }
func (h *MetadataHandler) CleanupMultipartUploads(retention time.Duration) (int, error) {
if retention <= 0 {
return 0, nil
}
cleaned := 0
err := h.db.Update(func(tx *bbolt.Tx) error {
uploadsBucket, err := getMultipartUploadBucket(tx)
if err != nil {
return err
}
now := time.Now().UTC()
keysToDelete := make([]string, 0)
if err := uploadsBucket.ForEach(func(k, v []byte) error {
upload := models.MultipartUpload{}
if err := json.Unmarshal(v, &upload); err != nil {
return err
}
if upload.State == "pending" {
return nil
}
createdAt, err := time.Parse(time.RFC3339, upload.CreatedAt)
if err != nil {
return nil
}
if now.Sub(createdAt) >= retention {
keysToDelete = append(keysToDelete, string(k))
}
return nil
}); err != nil {
return err
}
for _, uploadID := range keysToDelete {
if err := uploadsBucket.Delete([]byte(uploadID)); err != nil {
return err
}
if err := deleteMultipartPartsByUploadID(tx, uploadID); err != nil {
return err
}
cleaned++
}
return nil
})
if err != nil {
return 0, err
}
return cleaned, nil
}
func (h *MetadataHandler) GetReferencedChunkSet() (map[string]struct{}, error) { func (h *MetadataHandler) GetReferencedChunkSet() (map[string]struct{}, error) {
chunkSet := make(map[string]struct{}) chunkSet := make(map[string]struct{})
pendingUploadSet := make(map[string]struct{}) pendingUploadSet := make(map[string]struct{})

View File

@@ -17,9 +17,10 @@ import (
) )
type ObjectService struct { type ObjectService struct {
metadata *metadata.MetadataHandler metadata *metadata.MetadataHandler
blob *storage.BlobStore blob *storage.BlobStore
gcMu sync.RWMutex multipartRetention time.Duration
gcMu sync.RWMutex
} }
var ( var (
@@ -29,8 +30,15 @@ var (
ErrEntityTooSmall = errors.New("multipart entity too small") ErrEntityTooSmall = errors.New("multipart entity too small")
) )
func NewObjectService(metadataHandler *metadata.MetadataHandler, blobHandler *storage.BlobStore) *ObjectService { func NewObjectService(metadataHandler *metadata.MetadataHandler, blobHandler *storage.BlobStore, multipartRetention time.Duration) *ObjectService {
return &ObjectService{metadata: metadataHandler, blob: blobHandler} if multipartRetention <= 0 {
multipartRetention = 24 * time.Hour
}
return &ObjectService{
metadata: metadataHandler,
blob: blobHandler,
multipartRetention: multipartRetention,
}
} }
func (s *ObjectService) PutObject(bucket, key, contentType string, input io.Reader) (*models.ObjectManifest, error) { func (s *ObjectService) PutObject(bucket, key, contentType string, input io.Reader) (*models.ObjectManifest, error) {
@@ -111,6 +119,13 @@ func (s *ObjectService) ListObjects(bucket, prefix string) ([]*models.ObjectMani
return s.metadata.ListObjects(bucket, prefix) return s.metadata.ListObjects(bucket, prefix)
} }
func (s *ObjectService) ForEachObjectFrom(bucket, startKey string, fn func(*models.ObjectManifest) error) error {
s.gcMu.RLock()
defer s.gcMu.RUnlock()
return s.metadata.ForEachObjectFrom(bucket, startKey, fn)
}
func (s *ObjectService) CreateBucket(bucket string) error { func (s *ObjectService) CreateBucket(bucket string) error {
s.gcMu.RLock() s.gcMu.RLock()
defer s.gcMu.RUnlock() defer s.gcMu.RUnlock()
@@ -323,6 +338,7 @@ func (s *ObjectService) GarbageCollect() error {
totalChunks := 0 totalChunks := 0
deletedChunks := 0 deletedChunks := 0
deleteErrors := 0 deleteErrors := 0
cleanedUploads := 0
if err := s.blob.ForEachChunk(func(chunkID string) error { if err := s.blob.ForEachChunk(func(chunkID string) error {
totalChunks++ totalChunks++
@@ -340,16 +356,27 @@ func (s *ObjectService) GarbageCollect() error {
return err return err
} }
cleanedUploads, err = s.metadata.CleanupMultipartUploads(s.multipartRetention)
if err != nil {
return err
}
slog.Info("garbage_collect_completed", slog.Info("garbage_collect_completed",
"referenced_chunks", len(referencedChunkSet), "referenced_chunks", len(referencedChunkSet),
"total_chunks", totalChunks, "total_chunks", totalChunks,
"deleted_chunks", deletedChunks, "deleted_chunks", deletedChunks,
"delete_errors", deleteErrors, "delete_errors", deleteErrors,
"cleaned_uploads", cleanedUploads,
) )
return nil return nil
} }
func (s *ObjectService) RunGC(ctx context.Context, interval time.Duration) { func (s *ObjectService) RunGC(ctx context.Context, interval time.Duration) {
if interval <= 0 {
slog.Warn("garbage_collect_disabled_invalid_interval", "interval", interval.String())
return
}
ticker := time.NewTicker(interval) ticker := time.NewTicker(interval)
defer ticker.Stop() defer ticker.Stop()

View File

@@ -11,15 +11,16 @@ import (
) )
type Config struct { type Config struct {
DataPath string DataPath string
Address string Address string
Port int Port int
ChunkSize int ChunkSize int
LogLevel string LogLevel string
LogFormat string LogFormat string
AuditLog bool AuditLog bool
GcInterval time.Duration GcInterval time.Duration
GcEnabled bool GcEnabled bool
MultipartCleanupRetention time.Duration
} }
func NewConfig() *Config { func NewConfig() *Config {
@@ -33,8 +34,11 @@ func NewConfig() *Config {
LogLevel: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_LEVEL")), "info")), LogLevel: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_LEVEL")), "info")),
LogFormat: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_FORMAT")), strings.TrimSpace(os.Getenv("LOG_TYPE")), "text")), LogFormat: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_FORMAT")), strings.TrimSpace(os.Getenv("LOG_TYPE")), "text")),
AuditLog: envBool("AUDIT_LOG", true), AuditLog: envBool("AUDIT_LOG", true),
GcInterval: time.Duration(envIntRange("GC_INTERVAL", 10, -1, 60)) * time.Minute, GcInterval: time.Duration(envIntRange("GC_INTERVAL", 10, 1, 60)) * time.Minute,
GcEnabled: envBool("GC_ENABLED", true), GcEnabled: envBool("GC_ENABLED", true),
MultipartCleanupRetention: time.Duration(
envIntRange("MULTIPART_RETENTION_HOURS", 24, 1, 24*30),
) * time.Hour,
} }
if config.LogFormat != "json" && config.LogFormat != "text" { if config.LogFormat != "json" && config.LogFormat != "text" {