diff --git a/.env.example b/.env.example index d150fde..abb7639 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,7 @@ LOG_LEVEL=debug LOG_FORMAT=text DATA_PATH=data/ +FS_MAX_OBJECT_UPLOAD_BYTES=5368709120 PORT=2600 AUDIT_LOG=true ADDRESS=0.0.0.0 diff --git a/api/s3_errors.go b/api/s3_errors.go index 95e7424..cdf882f 100644 --- a/api/s3_errors.go +++ b/api/s3_errors.go @@ -174,6 +174,8 @@ func mapToS3Error(err error) s3APIError { return s3ErrMalformedXML case errors.Is(err, service.ErrEntityTooSmall): return s3ErrEntityTooSmall + case errors.Is(err, service.ErrEntityTooLarge): + return s3ErrEntityTooLarge case errors.Is(err, auth.ErrAccessDenied): return s3ErrAccessDenied case errors.Is(err, auth.ErrInvalidAccessKeyID): diff --git a/api/upload_limit_test.go b/api/upload_limit_test.go new file mode 100644 index 0000000..fd6fda4 --- /dev/null +++ b/api/upload_limit_test.go @@ -0,0 +1,79 @@ +package api + +import ( + "bytes" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + "time" + + "fs/logging" + "fs/metadata" + "fs/service" + "fs/storage" +) + +func TestPutObjectReturnsEntityTooLarge(t *testing.T) { + handler, svc := newUploadLimitHandler(t, 4) + if err := svc.CreateBucket("test-bucket"); err != nil { + t.Fatalf("CreateBucket: %v", err) + } + + req := httptest.NewRequest(http.MethodPut, "/test-bucket/too-large.txt", strings.NewReader("12345")) + rec := httptest.NewRecorder() + handler.router.ServeHTTP(rec, req) + + if rec.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusRequestEntityTooLarge, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "EntityTooLarge") { + t.Fatalf("expected EntityTooLarge response, body=%s", rec.Body.String()) + } +} + +func TestUploadPartReturnsEntityTooLarge(t *testing.T) { + handler, svc := newUploadLimitHandler(t, 4) + if err := svc.CreateBucket("test-bucket"); err != nil { + t.Fatalf("CreateBucket: %v", err) + } + upload, err := svc.CreateMultipartUpload("test-bucket", "object.txt") + if err != nil { + t.Fatalf("CreateMultipartUpload: %v", err) + } + + req := httptest.NewRequest(http.MethodPut, "/test-bucket/object.txt?partNumber=1&uploadId="+upload.UploadID, bytes.NewReader([]byte("12345"))) + rec := httptest.NewRecorder() + handler.router.ServeHTTP(rec, req) + + if rec.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusRequestEntityTooLarge, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "EntityTooLarge") { + t.Fatalf("expected EntityTooLarge response, body=%s", rec.Body.String()) + } +} + +func newUploadLimitHandler(t *testing.T, maxUploadSize int64) (*Handler, *service.ObjectService) { + t.Helper() + root := t.TempDir() + md, err := metadata.NewMetadataHandler(filepath.Join(root, "metadata.db")) + if err != nil { + t.Fatalf("new metadata handler: %v", err) + } + blob, err := storage.NewBlobStore(root, 4) + if err != nil { + t.Fatalf("new blob store: %v", err) + } + svc := service.NewObjectService(md, blob, time.Hour, maxUploadSize) + t.Cleanup(func() { + _ = svc.Close() + }) + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + handler := NewHandler(svc, logger, logging.Config{}, nil, false) + handler.setupRoutes() + return handler, svc +} diff --git a/app/server.go b/app/server.go index 2effadc..19cef92 100644 --- a/app/server.go +++ b/app/server.go @@ -39,6 +39,7 @@ func RunServer(ctx context.Context) error { "audit_log", logConfig.Audit, "data_path", config.DataPath, "multipart_retention_hours", int(config.MultipartCleanupRetention/time.Hour), + "max_object_upload_bytes", config.MaxObjectUploadBytes, "auth_enabled", authConfig.Enabled, "auth_region", authConfig.Region, "admin_api_enabled", config.AdminAPIEnabled, @@ -63,7 +64,7 @@ func RunServer(ctx context.Context) error { return err } - objectService := service.NewObjectService(metadataHandler, blobHandler, config.MultipartCleanupRetention) + objectService := service.NewObjectService(metadataHandler, blobHandler, config.MultipartCleanupRetention, config.MaxObjectUploadBytes) authService, err := auth.NewService(authConfig, metadataHandler) if err != nil { _ = metadataHandler.Close() diff --git a/metadata/metadata.go b/metadata/metadata.go index ed84c55..ec9621a 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -902,9 +902,6 @@ func (h *MetadataHandler) CleanupMultipartUploads(retention time.Duration) (int, if err := json.Unmarshal(v, &upload); err != nil { return err } - if upload.State == "pending" { - return nil - } createdAt, err := time.Parse(time.RFC3339, upload.CreatedAt) if err != nil { return nil diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go new file mode 100644 index 0000000..dfecdd9 --- /dev/null +++ b/metadata/metadata_test.go @@ -0,0 +1,99 @@ +package metadata + +import ( + "errors" + "fs/models" + "path/filepath" + "testing" + "time" + + "go.etcd.io/bbolt" +) + +func TestCleanupMultipartUploadsDeletesExpiredPendingUpload(t *testing.T) { + h := newTestMetadataHandler(t) + if err := h.CreateBucket("test-bucket"); err != nil { + t.Fatalf("CreateBucket: %v", err) + } + upload, err := h.CreateMultipartUpload("test-bucket", "object.txt") + if err != nil { + t.Fatalf("CreateMultipartUpload: %v", err) + } + if err := h.PutMultipartPart(upload.UploadID, models.UploadedPart{PartNumber: 1, ETag: "etag", Size: 4, Chunks: []string{"chunk-id"}}); err != nil { + t.Fatalf("PutMultipartPart: %v", err) + } + setMultipartUploadCreatedAt(t, h, upload.UploadID, time.Now().Add(-2*time.Hour)) + + cleaned, err := h.CleanupMultipartUploads(time.Hour) + if err != nil { + t.Fatalf("CleanupMultipartUploads: %v", err) + } + if cleaned != 1 { + t.Fatalf("cleaned = %d, want 1", cleaned) + } + if _, err := h.GetMultipartUpload(upload.UploadID); !errors.Is(err, ErrMultipartNotFound) { + t.Fatalf("GetMultipartUpload error = %v, want ErrMultipartNotFound", err) + } + if _, err := h.ListMultipartParts(upload.UploadID); !errors.Is(err, ErrMultipartNotFound) { + t.Fatalf("ListMultipartParts error = %v, want ErrMultipartNotFound", err) + } +} + +func TestCleanupMultipartUploadsKeepsRecentPendingUpload(t *testing.T) { + h := newTestMetadataHandler(t) + if err := h.CreateBucket("test-bucket"); err != nil { + t.Fatalf("CreateBucket: %v", err) + } + upload, err := h.CreateMultipartUpload("test-bucket", "object.txt") + if err != nil { + t.Fatalf("CreateMultipartUpload: %v", err) + } + + cleaned, err := h.CleanupMultipartUploads(time.Hour) + if err != nil { + t.Fatalf("CleanupMultipartUploads: %v", err) + } + if cleaned != 0 { + t.Fatalf("cleaned = %d, want 0", cleaned) + } + if _, err := h.GetMultipartUpload(upload.UploadID); err != nil { + t.Fatalf("recent upload should remain: %v", err) + } +} + +func TestCleanupMultipartUploadsDisabledForNonPositiveRetention(t *testing.T) { + h := newTestMetadataHandler(t) + cleaned, err := h.CleanupMultipartUploads(0) + if err != nil { + t.Fatalf("CleanupMultipartUploads: %v", err) + } + if cleaned != 0 { + t.Fatalf("cleaned = %d, want 0", cleaned) + } +} + +func newTestMetadataHandler(t *testing.T) *MetadataHandler { + t.Helper() + h, err := NewMetadataHandler(filepath.Join(t.TempDir(), "metadata.db")) + if err != nil { + t.Fatalf("NewMetadataHandler: %v", err) + } + t.Cleanup(func() { + _ = h.Close() + }) + return h +} + +func setMultipartUploadCreatedAt(t *testing.T, h *MetadataHandler, uploadID string, createdAt time.Time) { + t.Helper() + if err := h.update(func(tx *bbolt.Tx) error { + upload, uploadsBucket, err := getMultipartUploadFromTx(tx, uploadID) + if err != nil { + return err + } + upload.CreatedAt = createdAt.UTC().Format(time.RFC3339) + return putMultipartUpload(uploadsBucket, uploadID, upload) + }); err != nil { + t.Fatalf("set multipart created_at: %v", err) + } +} diff --git a/service/service.go b/service/service.go index 65dcead..1f754c8 100644 --- a/service/service.go +++ b/service/service.go @@ -21,6 +21,7 @@ type ObjectService struct { metadata *metadata.MetadataHandler blob *storage.BlobStore multipartRetention time.Duration + maxUploadSize int64 gcMu sync.RWMutex } @@ -29,16 +30,24 @@ var ( ErrInvalidPartOrder = errors.New("invalid multipart part order") ErrInvalidCompleteRequest = errors.New("invalid complete multipart request") ErrEntityTooSmall = errors.New("multipart entity too small") + ErrEntityTooLarge = errors.New("entity too large") ) -func NewObjectService(metadataHandler *metadata.MetadataHandler, blobHandler *storage.BlobStore, multipartRetention time.Duration) *ObjectService { +const DefaultMaxUploadSize int64 = 5 * 1024 * 1024 * 1024 + +func NewObjectService(metadataHandler *metadata.MetadataHandler, blobHandler *storage.BlobStore, multipartRetention time.Duration, maxUploadSize ...int64) *ObjectService { if multipartRetention <= 0 { multipartRetention = 24 * time.Hour } + limit := DefaultMaxUploadSize + if len(maxUploadSize) > 0 { + limit = maxUploadSize[0] + } return &ObjectService{ metadata: metadataHandler, blob: blobHandler, multipartRetention: multipartRetention, + maxUploadSize: limit, } } @@ -74,7 +83,7 @@ func (s *ObjectService) PutObject(bucket, key, contentType string, input io.Read unlock := s.acquireGCRLock() defer unlock() - chunks, size, etag, err := s.blob.IngestStream(input) + chunks, size, etag, err := s.blob.IngestStream(s.limitUpload(input)) if err != nil { return nil, err } @@ -158,7 +167,9 @@ func (s *ObjectService) GetObject(bucket, key string) (io.ReadCloser, *models.Ob defer func() { metrics.Default.ObserveService("get_object", time.Since(start), streamOK) }() - defer metrics.Default.ObserveLockHold("gc_mu_read", time.Since(holdStart)) + defer func() { + metrics.Default.ObserveLockHold("gc_mu_read", time.Since(holdStart)) + }() defer s.gcMu.RUnlock() if err := s.blob.AssembleStream(manifest.Chunks, pw); err != nil { _ = pw.CloseWithError(err) @@ -311,7 +322,7 @@ func (s *ObjectService) UploadPart(bucket, key, uploadId string, partNumber int, } var uploadedPart models.UploadedPart - chunkIds, totalSize, etag, err := s.blob.IngestStream(input) + chunkIds, totalSize, etag, err := s.blob.IngestStream(s.limitUpload(input)) if err != nil { return "", err } @@ -400,6 +411,9 @@ func (s *ObjectService) CompleteMultipartUpload(bucket, key, uploadID string, co orderedParts = append(orderedParts, storedPart) chunks = append(chunks, storedPart.Chunks...) totalSize += storedPart.Size + if s.maxUploadSize > 0 && totalSize > s.maxUploadSize { + return nil, ErrEntityTooLarge + } } finalETag := buildMultipartETag(orderedParts) @@ -435,6 +449,40 @@ func (s *ObjectService) AbortMultipartUpload(bucket, key, uploadID string) error return s.metadata.AbortMultipartUpload(uploadID) } +func (s *ObjectService) limitUpload(input io.Reader) io.Reader { + if s.maxUploadSize <= 0 || input == nil { + return input + } + return &maxBytesReader{inner: input, remaining: s.maxUploadSize} +} + +type maxBytesReader struct { + inner io.Reader + remaining int64 + tooLarge bool +} + +func (r *maxBytesReader) Read(p []byte) (int, error) { + if r.tooLarge { + return 0, ErrEntityTooLarge + } + if r.remaining <= 0 { + var probe [1]byte + n, err := r.inner.Read(probe[:]) + if n > 0 { + r.tooLarge = true + return 0, ErrEntityTooLarge + } + return 0, err + } + if int64(len(p)) > r.remaining { + p = p[:r.remaining] + } + n, err := r.inner.Read(p) + r.remaining -= int64(n) + return n, err +} + func normalizeETag(etag string) string { return strings.Trim(etag, "\"") } @@ -469,6 +517,12 @@ func (s *ObjectService) GarbageCollect() error { unlock := s.acquireGCLock() defer unlock() + var err error + cleanedUploads, err = s.metadata.CleanupMultipartUploads(s.multipartRetention) + if err != nil { + return err + } + referencedChunkSet, err := s.metadata.GetReferencedChunkSet() if err != nil { return err @@ -492,11 +546,6 @@ func (s *ObjectService) GarbageCollect() error { return err } - cleanedUploads, err = s.metadata.CleanupMultipartUploads(s.multipartRetention) - if err != nil { - return err - } - slog.Info("garbage_collect_completed", "referenced_chunks", len(referencedChunkSet), "total_chunks", totalChunks, diff --git a/service/upload_limit_test.go b/service/upload_limit_test.go new file mode 100644 index 0000000..dd81250 --- /dev/null +++ b/service/upload_limit_test.go @@ -0,0 +1,119 @@ +package service + +import ( + "errors" + "fs/metadata" + "fs/storage" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestPutObjectRejectsOversizedUpload(t *testing.T) { + svc := newTestObjectService(t, 4) + if err := svc.CreateBucket("test-bucket"); err != nil { + t.Fatalf("CreateBucket: %v", err) + } + + _, err := svc.PutObject("test-bucket", "too-large.txt", "text/plain", strings.NewReader("12345")) + if !errors.Is(err, ErrEntityTooLarge) { + t.Fatalf("PutObject error = %v, want ErrEntityTooLarge", err) + } + if _, err := svc.HeadObject("test-bucket", "too-large.txt"); !errors.Is(err, metadata.ErrObjectNotFound) { + t.Fatalf("HeadObject error = %v, want ErrObjectNotFound", err) + } +} + +func TestPutObjectAllowsExactUploadLimit(t *testing.T) { + svc := newTestObjectService(t, 4) + if err := svc.CreateBucket("test-bucket"); err != nil { + t.Fatalf("CreateBucket: %v", err) + } + + manifest, err := svc.PutObject("test-bucket", "exact.txt", "text/plain", strings.NewReader("1234")) + if err != nil { + t.Fatalf("PutObject: %v", err) + } + if manifest.Size != 4 { + t.Fatalf("manifest size = %d, want 4", manifest.Size) + } +} + +func TestUploadPartRejectsOversizedUpload(t *testing.T) { + svc := newTestObjectService(t, 4) + if err := svc.CreateBucket("test-bucket"); err != nil { + t.Fatalf("CreateBucket: %v", err) + } + upload, err := svc.CreateMultipartUpload("test-bucket", "object.txt") + if err != nil { + t.Fatalf("CreateMultipartUpload: %v", err) + } + + _, err = svc.UploadPart("test-bucket", "object.txt", upload.UploadID, 1, strings.NewReader("12345")) + if !errors.Is(err, ErrEntityTooLarge) { + t.Fatalf("UploadPart error = %v, want ErrEntityTooLarge", err) + } + parts, err := svc.ListMultipartParts("test-bucket", "object.txt", upload.UploadID) + if err != nil { + t.Fatalf("ListMultipartParts: %v", err) + } + if len(parts) != 0 { + t.Fatalf("stored parts = %d, want 0", len(parts)) + } +} + +func TestGarbageCollectRemovesExpiredPendingMultipartChunks(t *testing.T) { + svc := newTestObjectService(t, 1024) + svc.multipartRetention = time.Nanosecond + if err := svc.CreateBucket("test-bucket"); err != nil { + t.Fatalf("CreateBucket: %v", err) + } + upload, err := svc.CreateMultipartUpload("test-bucket", "object.txt") + if err != nil { + t.Fatalf("CreateMultipartUpload: %v", err) + } + if _, err := svc.UploadPart("test-bucket", "object.txt", upload.UploadID, 1, strings.NewReader("part-data")); err != nil { + t.Fatalf("UploadPart: %v", err) + } + chunks, err := svc.blob.ListChunks() + if err != nil { + t.Fatalf("ListChunks before GC: %v", err) + } + if len(chunks) == 0 { + t.Fatalf("expected uploaded part chunks") + } + time.Sleep(time.Millisecond) + + if err := svc.GarbageCollect(); err != nil { + t.Fatalf("GarbageCollect: %v", err) + } + if _, err := svc.metadata.GetMultipartUpload(upload.UploadID); !errors.Is(err, metadata.ErrMultipartNotFound) { + t.Fatalf("GetMultipartUpload error = %v, want ErrMultipartNotFound", err) + } + chunks, err = svc.blob.ListChunks() + if err != nil { + t.Fatalf("ListChunks after GC: %v", err) + } + if len(chunks) != 0 { + t.Fatalf("chunks after GC = %d, want 0", len(chunks)) + } +} + +func newTestObjectService(t *testing.T, maxUploadSize int64) *ObjectService { + t.Helper() + root := t.TempDir() + md, err := metadata.NewMetadataHandler(filepath.Join(root, "metadata.db")) + if err != nil { + t.Fatalf("NewMetadataHandler: %v", err) + } + blob, err := storage.NewBlobStore(root, 4) + if err != nil { + t.Fatalf("NewBlobStore: %v", err) + } + svc := NewObjectService(md, blob, time.Hour, maxUploadSize) + t.Cleanup(func() { + _ = svc.Close() + }) + return svc +} diff --git a/utils/config.go b/utils/config.go index c758b5e..fc7c2b9 100644 --- a/utils/config.go +++ b/utils/config.go @@ -15,6 +15,7 @@ type Config struct { Address string Port int ChunkSize int + MaxObjectUploadBytes int64 LogLevel string LogFormat string AuditLog bool @@ -36,15 +37,16 @@ func NewConfig() *Config { _ = godotenv.Load() config := &Config{ - DataPath: sanitizeDataPath(os.Getenv("DATA_PATH")), - Address: firstNonEmpty(strings.TrimSpace(os.Getenv("ADDRESS")), "0.0.0.0"), - Port: envIntRange("PORT", 2600, 1, 65535), - ChunkSize: envIntRange("CHUNK_SIZE", 8192000, 1, 64*1024*1024), - LogLevel: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_LEVEL")), "info")), - LogFormat: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_FORMAT")), strings.TrimSpace(os.Getenv("LOG_TYPE")), "text")), - AuditLog: envBool("AUDIT_LOG", true), - GcInterval: time.Duration(envIntRange("GC_INTERVAL", 10, 1, 60)) * time.Minute, - GcEnabled: envBool("GC_ENABLED", true), + DataPath: sanitizeDataPath(os.Getenv("DATA_PATH")), + Address: firstNonEmpty(strings.TrimSpace(os.Getenv("ADDRESS")), "0.0.0.0"), + Port: envIntRange("PORT", 2600, 1, 65535), + ChunkSize: envIntRange("CHUNK_SIZE", 8192000, 1, 64*1024*1024), + MaxObjectUploadBytes: envInt64Range("FS_MAX_OBJECT_UPLOAD_BYTES", 5*1024*1024*1024, 1, 5*1024*1024*1024), + LogLevel: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_LEVEL")), "info")), + LogFormat: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_FORMAT")), strings.TrimSpace(os.Getenv("LOG_TYPE")), "text")), + AuditLog: envBool("AUDIT_LOG", true), + GcInterval: time.Duration(envIntRange("GC_INTERVAL", 10, 1, 60)) * time.Minute, + GcEnabled: envBool("GC_ENABLED", true), MultipartCleanupRetention: time.Duration( envIntRange("MULTIPART_RETENTION_HOURS", 24, 1, 24*30), ) * time.Hour, @@ -82,6 +84,21 @@ func envIntRange(key string, defaultValue, minValue, maxValue int) int { return value } +func envInt64Range(key string, defaultValue, minValue, maxValue int64) int64 { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + return defaultValue + } + value, err := strconv.ParseInt(raw, 10, 64) + if err != nil { + return defaultValue + } + if value < minValue || value > maxValue { + return defaultValue + } + return value +} + func envBool(key string, defaultValue bool) bool { raw := strings.TrimSpace(os.Getenv(key)) if raw == "" { diff --git a/utils/config_test.go b/utils/config_test.go new file mode 100644 index 0000000..ef30b80 --- /dev/null +++ b/utils/config_test.go @@ -0,0 +1,21 @@ +package utils + +import "testing" + +func TestEnvInt64Range(t *testing.T) { + t.Setenv("TEST_INT64_RANGE", "42") + if got := envInt64Range("TEST_INT64_RANGE", 10, 1, 100); got != 42 { + t.Fatalf("envInt64Range valid = %d, want 42", got) + } +} + +func TestEnvInt64RangeFallsBackForInvalidValues(t *testing.T) { + t.Setenv("TEST_INT64_RANGE", "invalid") + if got := envInt64Range("TEST_INT64_RANGE", 10, 1, 100); got != 10 { + t.Fatalf("envInt64Range invalid = %d, want 10", got) + } + t.Setenv("TEST_INT64_RANGE", "101") + if got := envInt64Range("TEST_INT64_RANGE", 10, 1, 100); got != 10 { + t.Fatalf("envInt64Range too large = %d, want 10", got) + } +}