Applied Copilot review suggestions

This commit is contained in:
2026-02-23 22:35:42 +01:00
parent a8204de914
commit d9a1bd9001
5 changed files with 79 additions and 53 deletions

View File

@@ -98,6 +98,7 @@ func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) {
writeMappedS3Error(w, r, err) writeMappedS3Error(w, r, err)
return return
} }
defer stream.Close()
w.Header().Set("Content-Type", manifest.ContentType) w.Header().Set("Content-Type", manifest.ContentType)
w.Header().Set("Content-Length", strconv.FormatInt(manifest.Size, 10)) w.Header().Set("Content-Length", strconv.FormatInt(manifest.Size, 10))
@@ -116,12 +117,7 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path)
return return
} }
defer func(Body io.ReadCloser) { defer r.Body.Close()
err := Body.Close()
if err != nil {
}
}(r.Body)
if _, ok := r.URL.Query()["uploads"]; ok { if _, ok := r.URL.Query()["uploads"]; ok {
upload, err := h.svc.CreateMultipartUpload(bucket, key) upload, err := h.svc.CreateMultipartUpload(bucket, key)
@@ -190,17 +186,7 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) {
writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path) writeS3Error(w, r, s3ErrInvalidObjectKey, r.URL.Path)
return return
} }
defer func(Body io.ReadCloser) { defer r.Body.Close()
err := Body.Close()
if err != nil {
}
}(r.Body)
bodyReader := io.Reader(r.Body)
if shouldDecodeAWSChunkedPayload(r) {
bodyReader = newAWSChunkedDecodingReader(r.Body)
}
uploadID := r.URL.Query().Get("uploadId") uploadID := r.URL.Query().Get("uploadId")
partNumberRaw := r.URL.Query().Get("partNumber") partNumberRaw := r.URL.Query().Get("partNumber")
@@ -215,6 +201,18 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) {
writeS3Error(w, r, s3ErrInvalidPart, r.URL.Path) writeS3Error(w, r, s3ErrInvalidPart, r.URL.Path)
return return
} }
if partNumber < 1 || partNumber > 10000 {
writeS3Error(w, r, s3ErrInvalidPart, r.URL.Path)
return
}
bodyReader := io.Reader(r.Body)
var decodeStream io.ReadCloser
if shouldDecodeAWSChunkedPayload(r) {
decodeStream = newAWSChunkedDecodingReader(r.Body)
defer decodeStream.Close()
bodyReader = decodeStream
}
etag, err := h.svc.UploadPart(bucket, key, uploadID, partNumber, bodyReader) etag, err := h.svc.UploadPart(bucket, key, uploadID, partNumber, bodyReader)
if err != nil { if err != nil {
@@ -232,6 +230,14 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) {
contentType = "application/octet-stream" contentType = "application/octet-stream"
} }
bodyReader := io.Reader(r.Body)
var decodeStream io.ReadCloser
if shouldDecodeAWSChunkedPayload(r) {
decodeStream = newAWSChunkedDecodingReader(r.Body)
defer decodeStream.Close()
bodyReader = decodeStream
}
manifest, err := h.svc.PutObject(bucket, key, contentType, bodyReader) manifest, err := h.svc.PutObject(bucket, key, contentType, bodyReader)
if err != nil { if err != nil {
@@ -289,7 +295,7 @@ func shouldDecodeAWSChunkedPayload(r *http.Request) bool {
return strings.HasPrefix(signingMode, "streaming-aws4-hmac-sha256-payload") return strings.HasPrefix(signingMode, "streaming-aws4-hmac-sha256-payload")
} }
func newAWSChunkedDecodingReader(src io.Reader) io.Reader { func newAWSChunkedDecodingReader(src io.Reader) io.ReadCloser {
pr, pw := io.Pipe() pr, pw := io.Pipe()
go func() { go func() {
if err := decodeAWSChunkedPayload(src, pw); err != nil { if err := decodeAWSChunkedPayload(src, pw); err != nil {
@@ -363,7 +369,7 @@ func (h *Handler) handleHeadObject(w http.ResponseWriter, r *http.Request) {
return return
} }
etag := manifest.ETag etag := manifest.ETag
size := strconv.Itoa(int(manifest.Size)) size := strconv.FormatInt(manifest.Size, 10)
w.Header().Set("ETag", `"`+etag+`"`) w.Header().Set("ETag", `"`+etag+`"`)
w.Header().Set("Content-Length", size) w.Header().Set("Content-Length", size)
@@ -395,16 +401,14 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
writeS3Error(w, r, s3ErrNotImplemented, r.URL.Path) writeS3Error(w, r, s3ErrNotImplemented, r.URL.Path)
return return
} }
defer func(Body io.ReadCloser) { defer r.Body.Close()
err := Body.Close()
if err != nil {
}
}(r.Body)
bodyReader := io.Reader(r.Body) bodyReader := io.Reader(r.Body)
var decodeStream io.ReadCloser
if shouldDecodeAWSChunkedPayload(r) { if shouldDecodeAWSChunkedPayload(r) {
bodyReader = newAWSChunkedDecodingReader(r.Body) decodeStream = newAWSChunkedDecodingReader(r.Body)
defer decodeStream.Close()
bodyReader = decodeStream
} }
var req models.DeleteObjectsRequest var req models.DeleteObjectsRequest

View File

@@ -8,6 +8,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/go-chi/chi/v5/middleware"
) )
type Config struct { type Config struct {
@@ -83,7 +85,7 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()
ww := &responseWriter{ResponseWriter: w, status: http.StatusOK} ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
next.ServeHTTP(ww, r) next.ServeHTTP(ww, r)
@@ -92,11 +94,15 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han
} }
elapsed := time.Since(start) elapsed := time.Since(start)
status := ww.Status()
if status == 0 {
status = http.StatusOK
}
attrs := []any{ attrs := []any{
"method", r.Method, "method", r.Method,
"path", r.URL.Path, "path", r.URL.Path,
"status", ww.status, "status", status,
"bytes", ww.bytes, "bytes", ww.BytesWritten(),
"duration_ms", float64(elapsed.Nanoseconds()) / 1_000_000.0, "duration_ms", float64(elapsed.Nanoseconds()) / 1_000_000.0,
"remote_addr", r.RemoteAddr, "remote_addr", r.RemoteAddr,
} }
@@ -118,23 +124,6 @@ func HTTPMiddleware(logger *slog.Logger, cfg Config) func(http.Handler) http.Han
} }
} }
type responseWriter struct {
http.ResponseWriter
status int
bytes int
}
func (w *responseWriter) WriteHeader(statusCode int) {
w.status = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *responseWriter) Write(p []byte) (int, error) {
n, err := w.ResponseWriter.Write(p)
w.bytes += n
return n, err
}
func envBool(key string, defaultValue bool) bool { func envBool(key string, defaultValue bool) bool {
raw := os.Getenv(key) raw := os.Getenv(key)
if raw == "" { if raw == "" {

View File

@@ -36,6 +36,7 @@ func main() {
} }
blobHandler, err := storage.NewBlobStore(config.DataPath, config.ChunkSize) blobHandler, err := storage.NewBlobStore(config.DataPath, config.ChunkSize)
if err != nil { if err != nil {
_ = metadataHandler.Close()
logger.Error("failed_to_initialize_blob_store", "error", err) logger.Error("failed_to_initialize_blob_store", "error", err)
return return
} }

View File

@@ -5,13 +5,15 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings"
) )
const chunkSize = 64 * 1024 const blobRoot = "blobs"
const blobRoot = "blobs/" const maxChunkSize = 64 * 1024 * 1024
type BlobStore struct { type BlobStore struct {
dataRoot string dataRoot string
@@ -19,10 +21,19 @@ type BlobStore struct {
} }
func NewBlobStore(root string, chunkSize int) (*BlobStore, error) { func NewBlobStore(root string, chunkSize int) (*BlobStore, error) {
if err := os.MkdirAll(filepath.Join(root, blobRoot), 0o755); err != nil { root = strings.TrimSpace(root)
if root == "" {
return nil, errors.New("blob root is required")
}
if chunkSize <= 0 || chunkSize > maxChunkSize {
return nil, fmt.Errorf("chunk size must be between 1 and %d bytes", maxChunkSize)
}
cleanRoot := filepath.Clean(root)
if err := os.MkdirAll(filepath.Join(cleanRoot, blobRoot), 0o755); err != nil {
return nil, err return nil, err
} }
return &BlobStore{chunkSize: chunkSize, dataRoot: root}, nil return &BlobStore{chunkSize: chunkSize, dataRoot: cleanRoot}, nil
} }
func (bs *BlobStore) IngestStream(stream io.Reader) ([]string, int64, string, error) { func (bs *BlobStore) IngestStream(stream io.Reader) ([]string, int64, string, error) {
@@ -67,6 +78,9 @@ func (bs *BlobStore) IngestStream(stream io.Reader) ([]string, int64, string, er
} }
func (bs *BlobStore) saveBlob(chunkID string, data []byte) error { func (bs *BlobStore) saveBlob(chunkID string, data []byte) error {
if !isValidChunkID(chunkID) {
return fmt.Errorf("invalid chunk id: %q", chunkID)
}
dir := filepath.Join(bs.dataRoot, blobRoot, chunkID[:2], chunkID[2:4]) dir := filepath.Join(bs.dataRoot, blobRoot, chunkID[:2], chunkID[2:4])
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0755); err != nil {
return err return err
@@ -95,5 +109,20 @@ func (bs *BlobStore) AssembleStream(chunkIDs []string, w *io.PipeWriter) error {
} }
func (bs *BlobStore) GetBlob(chunkID string) ([]byte, error) { func (bs *BlobStore) GetBlob(chunkID string) ([]byte, error) {
if !isValidChunkID(chunkID) {
return nil, fmt.Errorf("invalid chunk id: %q", chunkID)
}
return os.ReadFile(filepath.Join(bs.dataRoot, blobRoot, chunkID[:2], chunkID[2:4], chunkID)) return os.ReadFile(filepath.Join(bs.dataRoot, blobRoot, chunkID[:2], chunkID[2:4], chunkID))
} }
func isValidChunkID(chunkID string) bool {
if len(chunkID) != sha256.Size*2 {
return false
}
for _, ch := range chunkID {
if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') {
return false
}
}
return true
}

View File

@@ -25,8 +25,8 @@ func NewConfig() *Config {
config := &Config{ config := &Config{
DataPath: sanitizeDataPath(os.Getenv("DATA_PATH")), DataPath: sanitizeDataPath(os.Getenv("DATA_PATH")),
Address: firstNonEmpty(strings.TrimSpace(os.Getenv("ADDRESS")), "0.0.0.0"), Address: firstNonEmpty(strings.TrimSpace(os.Getenv("ADDRESS")), "0.0.0.0"),
Port: envInt("PORT", 3000), Port: envIntRange("PORT", 3000, 1, 65535),
ChunkSize: envInt("CHUNK_SIZE", 8192000), ChunkSize: envIntRange("CHUNK_SIZE", 8192000, 1, 64*1024*1024),
LogLevel: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_LEVEL")), "info")), LogLevel: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_LEVEL")), "info")),
LogFormat: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_FORMAT")), strings.TrimSpace(os.Getenv("LOG_TYPE")), "text")), LogFormat: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_FORMAT")), strings.TrimSpace(os.Getenv("LOG_TYPE")), "text")),
AuditLog: envBool("AUDIT_LOG", true), AuditLog: envBool("AUDIT_LOG", true),
@@ -40,7 +40,7 @@ func NewConfig() *Config {
} }
func envInt(key string, defaultValue int) int { func envIntRange(key string, defaultValue, minValue, maxValue int) int {
raw := strings.TrimSpace(os.Getenv(key)) raw := strings.TrimSpace(os.Getenv(key))
if raw == "" { if raw == "" {
return defaultValue return defaultValue
@@ -49,6 +49,9 @@ func envInt(key string, defaultValue int) int {
if err != nil { if err != nil {
return defaultValue return defaultValue
} }
if value < minValue || value > maxValue {
return defaultValue
}
return value return value
} }