Enhance API with health check endpoints and improve multipart upload management

This commit is contained in:
2026-02-25 00:34:06 +01:00
parent a9fbc06dd0
commit abe1f453fc
10 changed files with 452 additions and 138 deletions

View File

@@ -13,11 +13,12 @@ import (
"fs/service"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
@@ -31,8 +32,21 @@ type Handler struct {
logConfig logging.Config
}
const (
maxXMLBodyBytes int64 = 1 << 20
maxDeleteObjects = 1000
maxObjectKeyBytes = 1024
serverReadHeaderTimeout = 5 * time.Second
serverReadTimeout = 60 * time.Second
serverWriteTimeout = 120 * time.Second
serverIdleTimeout = 120 * time.Second
serverMaxHeaderBytes = 1 << 20
serverMaxConnections = 1024
)
func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig logging.Config) *Handler {
r := chi.NewRouter()
r.Use(middleware.RequestID)
r.Use(middleware.Recoverer)
if logger == nil {
logger = slog.Default()
@@ -50,6 +64,8 @@ func NewHandler(svc *service.ObjectService, logger *slog.Logger, logConfig loggi
func (h *Handler) setupRoutes() {
h.router.Use(logging.HTTPMiddleware(h.logger, h.logConfig))
h.router.Get("/healthz", h.handleHealth)
h.router.Head("/healthz", h.handleHealth)
h.router.Get("/", h.handleGetBuckets)
h.router.Get("/{bucket}/", h.handleGetBucket)
@@ -70,20 +86,40 @@ func (h *Handler) setupRoutes() {
h.router.Delete("/{bucket}/*", h.handleDeleteObject)
}
func (h *Handler) handleWelcome(w http.ResponseWriter) {
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte("Welcome to the Object Storage API!"))
if err != nil {
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
if _, err := h.svc.ListBuckets(); err != nil {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusServiceUnavailable)
if r.Method != http.MethodHead {
_, _ = w.Write([]byte("unhealthy"))
}
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK)
if r.Method != http.MethodHead {
_, _ = w.Write([]byte("ok"))
}
}
func validateObjectKey(key string) *s3APIError {
if key == "" {
err := s3ErrInvalidObjectKey
return &err
}
if len(key) > maxObjectKeyBytes {
err := s3ErrKeyTooLong
return &err
}
return nil
}
func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*")
if key == "" {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path)
if apiErr := validateObjectKey(key); apiErr != nil {
writeS3Error(w, r, *apiErr, r.URL.Path)
return
}
@@ -140,8 +176,8 @@ func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*")
if key == "" {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path)
if apiErr := validateObjectKey(key); apiErr != nil {
writeS3Error(w, r, *apiErr, r.URL.Path)
return
}
defer r.Body.Close()
@@ -171,8 +207,14 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) {
}
if uploadID := r.URL.Query().Get("uploadId"); uploadID != "" {
r.Body = http.MaxBytesReader(w, r.Body, maxXMLBodyBytes)
var req models.CompleteMultipartUploadRequest
if err := xml.NewDecoder(r.Body).Decode(&req); err != nil {
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) {
writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path)
return
}
writeS3Error(w, r, s3ErrMalformedXML, r.URL.Path)
return
}
@@ -209,8 +251,8 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*")
if key == "" {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path)
if apiErr := validateObjectKey(key); apiErr != nil {
writeS3Error(w, r, *apiErr, r.URL.Path)
return
}
defer r.Body.Close()
@@ -382,28 +424,6 @@ func decodeAWSChunkedPayload(src io.Reader, dst io.Writer) error {
}
}
func (h *Handler) handleHeadObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*")
if key == "" {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path)
return
}
manifest, err := h.svc.HeadObject(bucket, key)
if err != nil {
writeMappedS3Error(w, r, err)
return
}
etag := manifest.ETag
size := strconv.FormatInt(manifest.Size, 10)
w.Header().Set("ETag", `"`+etag+`"`)
w.Header().Set("Content-Length", size)
w.Header().Set("Last-Modified", time.Unix(manifest.CreatedAt, 0).UTC().Format(http.TimeFormat))
w.WriteHeader(http.StatusOK)
}
func (h *Handler) handlePutBucket(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket")
if err := h.svc.CreateBucket(bucket); err != nil {
@@ -429,6 +449,7 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
return
}
defer r.Body.Close()
r.Body = http.MaxBytesReader(w, r.Body, maxXMLBodyBytes)
bodyReader := io.Reader(r.Body)
var decodeStream io.ReadCloser
@@ -440,9 +461,18 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
var req models.DeleteObjectsRequest
if err := xml.NewDecoder(bodyReader).Decode(&req); err != nil {
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) {
writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path)
return
}
writeS3Error(w, r, s3ErrMalformedXML, r.URL.Path)
return
}
if len(req.Objects) > maxDeleteObjects {
writeS3Error(w, r, s3ErrTooManyDeleteObjects, r.URL.Path)
return
}
keys := make([]string, 0, len(req.Objects))
response := models.DeleteObjectsResult{
@@ -457,6 +487,14 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
})
continue
}
if len(obj.Key) > maxObjectKeyBytes {
response.Errors = append(response.Errors, models.DeleteError{
Key: obj.Key,
Code: s3ErrKeyTooLong.Code,
Message: s3ErrKeyTooLong.Message,
})
continue
}
keys = append(keys, obj.Key)
}
@@ -488,8 +526,8 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*")
if key == "" {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path)
if apiErr := validateObjectKey(key); apiErr != nil {
writeS3Error(w, r, *apiErr, r.URL.Path)
return
}
if uploadId := r.URL.Query().Get("uploadId"); uploadId != "" {
@@ -522,6 +560,68 @@ func (h *Handler) handleHeadBucket(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func (h *Handler) handleHeadObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*")
if apiErr := validateObjectKey(key); apiErr != nil {
writeS3Error(w, r, *apiErr, r.URL.Path)
return
}
manifest, err := h.svc.HeadObject(bucket, key)
if err != nil {
writeMappedS3Error(w, r, err)
return
}
etag := manifest.ETag
size := strconv.FormatInt(manifest.Size, 10)
w.Header().Set("ETag", `"`+etag+`"`)
w.Header().Set("Content-Length", size)
w.Header().Set("Last-Modified", time.Unix(manifest.CreatedAt, 0).UTC().Format(http.TimeFormat))
w.WriteHeader(http.StatusOK)
}
type limitedListener struct {
net.Listener
slots chan struct{}
}
func newLimitedListener(inner net.Listener, maxConns int) net.Listener {
if maxConns <= 0 {
return inner
}
return &limitedListener{
Listener: inner,
slots: make(chan struct{}, maxConns),
}
}
func (l *limitedListener) Accept() (net.Conn, error) {
l.slots <- struct{}{}
conn, err := l.Listener.Accept()
if err != nil {
<-l.slots
return nil, err
}
return &limitedConn{
Conn: conn,
done: func() { <-l.slots },
}, nil
}
type limitedConn struct {
net.Conn
once sync.Once
done func()
}
func (c *limitedConn) Close() error {
err := c.Conn.Close()
c.once.Do(c.done)
return err
}
func (h *Handler) handleGetBuckets(w http.ResponseWriter, r *http.Request) {
buckets, err := h.svc.ListBuckets()
if err != nil {
@@ -613,6 +713,8 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu
continuationToken := strings.TrimSpace(r.URL.Query().Get("continuation-token"))
continuationMarker := ""
continuationType := ""
continuationValue := ""
if continuationToken != "" {
decoded, err := base64.StdEncoding.DecodeString(continuationToken)
if err != nil || len(decoded) == 0 {
@@ -620,33 +722,11 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu
return
}
continuationMarker = string(decoded)
}
objects, err := h.svc.ListObjects(bucket, prefix)
if err != nil {
writeMappedS3Error(w, r, err)
return
}
entries := buildListV2Entries(objects, prefix, delimiter)
startIdx := 0
if continuationMarker != "" {
found := false
for i, entry := range entries {
if entry.Marker == continuationMarker {
startIdx = i + 1
found = true
break
}
}
if !found {
continuationType, continuationValue, _ = strings.Cut(continuationMarker, ":")
if (continuationType != "K" && continuationType != "C") || continuationValue == "" {
writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path)
return
}
} else if startAfter != "" {
for startIdx < len(entries) && entries[startIdx].SortKey <= startAfter {
startIdx++
}
}
result := models.ListBucketResultV2{
@@ -660,9 +740,91 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu
EncodingType: encodingType,
}
endIdx := startIdx
for endIdx < len(entries) && result.KeyCount < maxKeys {
entry := entries[endIdx]
type pageEntry struct {
Marker string
Object *models.ObjectManifest
CommonPrefix string
}
entries := make([]pageEntry, 0, maxKeys)
seenCommonPrefixes := make(map[string]struct{})
truncated := false
stopErr := errors.New("list_v2_page_complete")
startKey := prefix
if continuationToken != "" {
startKey = continuationValue
} else if startAfter != "" && startAfter > startKey {
startKey = startAfter
}
if maxKeys > 0 {
err := h.svc.ForEachObjectFrom(bucket, startKey, func(object *models.ObjectManifest) error {
if object == nil {
return nil
}
key := object.Key
if prefix != "" {
if key < prefix {
return nil
}
if !strings.HasPrefix(key, prefix) {
return stopErr
}
}
if continuationToken != "" {
if continuationType == "K" && key <= continuationValue {
return nil
}
if continuationType == "C" && strings.HasPrefix(key, continuationValue) {
return nil
}
} else if startAfter != "" && key <= startAfter {
return nil
}
if delimiter != "" {
relative := strings.TrimPrefix(key, prefix)
if idx := strings.Index(relative, delimiter); idx >= 0 {
commonPrefix := prefix + relative[:idx+len(delimiter)]
if continuationToken == "" && startAfter != "" && commonPrefix <= startAfter {
return nil
}
if _, exists := seenCommonPrefixes[commonPrefix]; exists {
return nil
}
seenCommonPrefixes[commonPrefix] = struct{}{}
if len(entries) >= maxKeys {
truncated = true
return stopErr
}
entries = append(entries, pageEntry{
Marker: "C:" + commonPrefix,
CommonPrefix: commonPrefix,
})
return nil
}
}
if len(entries) >= maxKeys {
truncated = true
return stopErr
}
entries = append(entries, pageEntry{
Marker: "K:" + key,
Object: object,
})
return nil
})
if err != nil && !errors.Is(err, stopErr) {
writeMappedS3Error(w, r, err)
return
}
}
for _, entry := range entries {
if entry.Object != nil {
result.Contents = append(result.Contents, models.Contents{
Key: s3EncodeIfNeeded(entry.Object.Key, encodingType),
@@ -677,12 +839,11 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu
})
}
result.KeyCount++
endIdx++
}
result.IsTruncated = endIdx < len(entries)
result.IsTruncated = truncated
if result.IsTruncated && result.KeyCount > 0 {
result.NextContinuationToken = base64.StdEncoding.EncodeToString([]byte(entries[endIdx-1].Marker))
result.NextContinuationToken = base64.StdEncoding.EncodeToString([]byte(entries[result.KeyCount-1].Marker))
}
xmlResponse, err := xml.MarshalIndent(result, "", " ")
@@ -699,52 +860,6 @@ func (h *Handler) handleListObjectsV2(w http.ResponseWriter, r *http.Request, bu
}
type listV2Entry struct {
Marker string
SortKey string
Object *models.ObjectManifest
CommonPrefix string
}
func buildListV2Entries(objects []*models.ObjectManifest, prefix, delimiter string) []listV2Entry {
sorted := make([]*models.ObjectManifest, 0, len(objects))
sorted = append(sorted, objects...)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Key < sorted[j].Key
})
entries := make([]listV2Entry, 0, len(sorted))
seenCommonPrefixes := make(map[string]struct{})
for _, object := range sorted {
if object == nil {
continue
}
if delimiter != "" {
relative := strings.TrimPrefix(object.Key, prefix)
if idx := strings.Index(relative, delimiter); idx >= 0 {
commonPrefix := prefix + relative[:idx+len(delimiter)]
if _, exists := seenCommonPrefixes[commonPrefix]; exists {
continue
}
seenCommonPrefixes[commonPrefix] = struct{}{}
entries = append(entries, listV2Entry{
Marker: "C:" + commonPrefix,
SortKey: commonPrefix,
CommonPrefix: commonPrefix,
})
continue
}
}
entries = append(entries, listV2Entry{
Marker: "K:" + object.Key,
SortKey: object.Key,
Object: object,
})
}
return entries
}
func s3EncodeIfNeeded(value, encodingType string) string {
if encodingType != "url" || value == "" {
return value
@@ -815,13 +930,23 @@ func (h *Handler) Start(ctx context.Context, address string) error {
h.setupRoutes()
server := http.Server{
Addr: address,
Handler: h.router,
Addr: address,
Handler: h.router,
ReadHeaderTimeout: serverReadHeaderTimeout,
ReadTimeout: serverReadTimeout,
WriteTimeout: serverWriteTimeout,
IdleTimeout: serverIdleTimeout,
MaxHeaderBytes: serverMaxHeaderBytes,
}
errCh := make(chan error, 1)
listener, err := net.Listen("tcp", address)
if err != nil {
return err
}
limitedListener := newLimitedListener(listener, serverMaxConnections)
go func() {
if err := server.ListenAndServe(); err != nil {
if err := server.Serve(limitedListener); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
errCh <- err
}

View File

@@ -7,6 +7,8 @@ import (
"fs/models"
"fs/service"
"net/http"
"github.com/go-chi/chi/v5/middleware"
)
type s3APIError struct {
@@ -21,6 +23,11 @@ var (
Code: "InvalidArgument",
Message: "Object key is required.",
}
s3ErrKeyTooLong = s3APIError{
Status: http.StatusBadRequest,
Code: "KeyTooLongError",
Message: "Your key is too long.",
}
s3ErrNotImplemented = s3APIError{
Status: http.StatusNotImplemented,
Code: "NotImplemented",
@@ -56,6 +63,16 @@ var (
Code: "EntityTooSmall",
Message: "Your proposed upload is smaller than the minimum allowed object size.",
}
s3ErrEntityTooLarge = s3APIError{
Status: http.StatusRequestEntityTooLarge,
Code: "EntityTooLarge",
Message: "Your proposed upload exceeds the maximum allowed size.",
}
s3ErrTooManyDeleteObjects = s3APIError{
Status: http.StatusBadRequest,
Code: "MalformedXML",
Message: "The request must contain no more than 1000 object identifiers.",
}
s3ErrInternal = s3APIError{
Status: http.StatusInternalServerError,
Code: "InternalError",
@@ -121,6 +138,13 @@ func mapToS3Error(err error) s3APIError {
}
func writeS3Error(w http.ResponseWriter, r *http.Request, apiErr s3APIError, resource string) {
requestID := ""
if r != nil {
requestID = middleware.GetReqID(r.Context())
if requestID != "" {
w.Header().Set("x-amz-request-id", requestID)
}
}
w.Header().Set("Content-Type", "application/xml; charset=utf-8")
w.WriteHeader(apiErr.Status)
@@ -129,9 +153,10 @@ func writeS3Error(w http.ResponseWriter, r *http.Request, apiErr s3APIError, res
}
payload := models.S3ErrorResponse{
Code: apiErr.Code,
Message: apiErr.Message,
Resource: resource,
Code: apiErr.Code,
Message: apiErr.Message,
Resource: resource,
RequestID: requestID,
}
out, err := xml.MarshalIndent(payload, "", " ")