From d9a1bd9001e93c1dbb39f303a45a55ceb881f301 Mon Sep 17 00:00:00 2001 From: Andrej Mickov Date: Mon, 23 Feb 2026 22:35:42 +0100 Subject: [PATCH] Applied Copilot review suggestions --- api/api.go | 56 +++++++++++++++++++++++++--------------------- logging/logging.go | 29 ++++++++---------------- main.go | 1 + storage/blob.go | 37 ++++++++++++++++++++++++++---- utils/config.go | 9 +++++--- 5 files changed, 79 insertions(+), 53 deletions(-) diff --git a/api/api.go b/api/api.go index 3de2450..ee70fc7 100644 --- a/api/api.go +++ b/api/api.go @@ -98,6 +98,7 @@ func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) { writeMappedS3Error(w, r, err) return } + defer stream.Close() w.Header().Set("Content-Type", manifest.ContentType) w.Header().Set("Content-Length", strconv.FormatInt(manifest.Size, 10)) @@ -116,12 +117,7 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) { writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) return } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - - } - }(r.Body) + defer r.Body.Close() if _, ok := r.URL.Query()["uploads"]; ok { upload, err := h.svc.CreateMultipartUpload(bucket, key) @@ -190,17 +186,7 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) { writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) return } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - - } - }(r.Body) - - bodyReader := io.Reader(r.Body) - if shouldDecodeAWSChunkedPayload(r) { - bodyReader = newAWSChunkedDecodingReader(r.Body) - } + defer r.Body.Close() uploadID := r.URL.Query().Get("uploadId") partNumberRaw := r.URL.Query().Get("partNumber") @@ -215,6 +201,18 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) { writeS3Error(w, r, s3ErrInvalidPart, r.URL.Path) return } + if partNumber < 1 || partNumber > 10000 { + writeS3Error(w, r, s3ErrInvalidPart, r.URL.Path) + return + } + + bodyReader := io.Reader(r.Body) + var decodeStream io.ReadCloser + if shouldDecodeAWSChunkedPayload(r) { + decodeStream = newAWSChunkedDecodingReader(r.Body) + defer decodeStream.Close() + bodyReader = decodeStream + } etag, err := h.svc.UploadPart(bucket, key, uploadID, partNumber, bodyReader) if err != nil { @@ -232,6 +230,14 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) { contentType = "application/octet-stream" } + bodyReader := io.Reader(r.Body) + var decodeStream io.ReadCloser + if shouldDecodeAWSChunkedPayload(r) { + decodeStream = newAWSChunkedDecodingReader(r.Body) + defer decodeStream.Close() + bodyReader = decodeStream + } + manifest, err := h.svc.PutObject(bucket, key, contentType, bodyReader) if err != nil { @@ -289,7 +295,7 @@ func shouldDecodeAWSChunkedPayload(r *http.Request) bool { return strings.HasPrefix(signingMode, "streaming-aws4-hmac-sha256-payload") } -func newAWSChunkedDecodingReader(src io.Reader) io.Reader { +func newAWSChunkedDecodingReader(src io.Reader) io.ReadCloser { pr, pw := io.Pipe() go func() { if err := decodeAWSChunkedPayload(src, pw); err != nil { @@ -363,7 +369,7 @@ func (h *Handler) handleHeadObject(w http.ResponseWriter, r *http.Request) { return } etag := manifest.ETag - size := strconv.Itoa(int(manifest.Size)) + size := strconv.FormatInt(manifest.Size, 10) w.Header().Set("ETag", `"`+etag+`"`) w.Header().Set("Content-Length", size) @@ -395,16 +401,14 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) { writeS3Error(w, r, s3ErrNotImplemented, r.URL.Path) return } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - - } - }(r.Body) + defer r.Body.Close() bodyReader := io.Reader(r.Body) + var decodeStream io.ReadCloser if shouldDecodeAWSChunkedPayload(r) { - bodyReader = newAWSChunkedDecodingReader(r.Body) + decodeStream = newAWSChunkedDecodingReader(r.Body) + defer decodeStream.Close() + bodyReader = decodeStream } var req models.DeleteObjectsRequest diff --git a/logging/logging.go b/logging/logging.go index 1edadf2..7290bf2 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -8,6 +8,8 @@ import ( "strconv" "strings" "time" + + "github.com/go-chi/chi/v5/middleware" ) type Config struct { @@ -83,7 +85,7 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() - ww := &responseWriter{ResponseWriter: w, status: http.StatusOK} + ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) next.ServeHTTP(ww, r) @@ -92,11 +94,15 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han } elapsed := time.Since(start) + status := ww.Status() + if status == 0 { + status = http.StatusOK + } attrs := []any{ "method", r.Method, "path", r.URL.Path, - "status", ww.status, - "bytes", ww.bytes, + "status", status, + "bytes", ww.BytesWritten(), "duration_ms", float64(elapsed.Nanoseconds()) / 1_000_000.0, "remote_addr", r.RemoteAddr, } @@ -118,23 +124,6 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han } } -type responseWriter struct { - http.ResponseWriter - status int - bytes int -} - -func (w *responseWriter) WriteHeader(statusCode int) { - w.status = statusCode - w.ResponseWriter.WriteHeader(statusCode) -} - -func (w *responseWriter) Write(p []byte) (int, error) { - n, err := w.ResponseWriter.Write(p) - w.bytes += n - return n, err -} - func envBool(key string, defaultValue bool) bool { raw := os.Getenv(key) if raw == "" { diff --git a/main.go b/main.go index 251fd00..53a28ed 100644 --- a/main.go +++ b/main.go @@ -36,6 +36,7 @@ func main() { } blobHandler, err := storage.NewBlobStore(config.DataPath, config.ChunkSize) if err != nil { + _ = metadataHandler.Close() logger.Error("failed_to_initialize_blob_store", "error", err) return } diff --git a/storage/blob.go b/storage/blob.go index 4fae764..23d21f6 100644 --- a/storage/blob.go +++ b/storage/blob.go @@ -5,13 +5,15 @@ import ( "crypto/sha256" "encoding/hex" "errors" + "fmt" "io" "os" "path/filepath" + "strings" ) -const chunkSize = 64 * 1024 -const blobRoot = "blobs/" +const blobRoot = "blobs" +const maxChunkSize = 64 * 1024 * 1024 type BlobStore struct { dataRoot string @@ -19,10 +21,19 @@ type BlobStore struct { } func NewBlobStore(root string, chunkSize int) (*BlobStore, error) { - if err := os.MkdirAll(filepath.Join(root, blobRoot), 0o755); err != nil { + root = strings.TrimSpace(root) + if root == "" { + return nil, errors.New("blob root is required") + } + if chunkSize <= 0 || chunkSize > maxChunkSize { + return nil, fmt.Errorf("chunk size must be between 1 and %d bytes", maxChunkSize) + } + + cleanRoot := filepath.Clean(root) + if err := os.MkdirAll(filepath.Join(cleanRoot, blobRoot), 0o755); err != nil { return nil, err } - return &BlobStore{chunkSize: chunkSize, dataRoot: root}, nil + return &BlobStore{chunkSize: chunkSize, dataRoot: cleanRoot}, nil } func (bs *BlobStore) IngestStream(stream io.Reader) ([]string, int64, string, error) { @@ -67,6 +78,9 @@ func (bs *BlobStore) IngestStream(stream io.Reader) ([]string, int64, string, er } func (bs *BlobStore) saveBlob(chunkID string, data []byte) error { + if !isValidChunkID(chunkID) { + return fmt.Errorf("invalid chunk id: %q", chunkID) + } dir := filepath.Join(bs.dataRoot, blobRoot, chunkID[:2], chunkID[2:4]) if err := os.MkdirAll(dir, 0755); err != nil { return err @@ -95,5 +109,20 @@ func (bs *BlobStore) AssembleStream(chunkIDs []string, w *io.PipeWriter) error { } func (bs *BlobStore) GetBlob(chunkID string) ([]byte, error) { + if !isValidChunkID(chunkID) { + return nil, fmt.Errorf("invalid chunk id: %q", chunkID) + } return os.ReadFile(filepath.Join(bs.dataRoot, blobRoot, chunkID[:2], chunkID[2:4], chunkID)) } + +func isValidChunkID(chunkID string) bool { + if len(chunkID) != sha256.Size*2 { + return false + } + for _, ch := range chunkID { + if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') { + return false + } + } + return true +} diff --git a/utils/config.go b/utils/config.go index f9f773c..a73e195 100644 --- a/utils/config.go +++ b/utils/config.go @@ -25,8 +25,8 @@ func NewConfig() *Config { config := &Config{ DataPath: sanitizeDataPath(os.Getenv("DATA_PATH")), Address: firstNonEmpty(strings.TrimSpace(os.Getenv("ADDRESS")), "0.0.0.0"), - Port: envInt("PORT", 3000), - ChunkSize: envInt("CHUNK_SIZE", 8192000), + Port: envIntRange("PORT", 3000, 1, 65535), + ChunkSize: envIntRange("CHUNK_SIZE", 8192000, 1, 64*1024*1024), LogLevel: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_LEVEL")), "info")), LogFormat: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_FORMAT")), strings.TrimSpace(os.Getenv("LOG_TYPE")), "text")), AuditLog: envBool("AUDIT_LOG", true), @@ -40,7 +40,7 @@ func NewConfig() *Config { } -func envInt(key string, defaultValue int) int { +func envIntRange(key string, defaultValue, minValue, maxValue int) int { raw := strings.TrimSpace(os.Getenv(key)) if raw == "" { return defaultValue @@ -49,6 +49,9 @@ func envInt(key string, defaultValue int) int { if err != nil { return defaultValue } + if value < minValue || value > maxValue { + return defaultValue + } return value }