20 Commits

Author SHA1 Message Date
f61cc3168b Reject unsupported aws-chunked uploads
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-05-16 10:24:32 +02:00
e928ebca15 Defer multi-delete authorization to handler
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-05-16 10:24:32 +02:00
654a505c0d Document S3 auth hardening
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-05-16 10:15:26 +02:00
0f9b461e8e Verify chunk integrity on read
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-05-16 10:11:25 +02:00
c3c9e3262f Add upload limits and multipart cleanup
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-05-16 10:11:15 +02:00
2425cd524e Harden S3 auth boundaries
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-05-16 10:11:04 +02:00
eac20f7fda Merge pull request #10 from ferdzo/hotfix/copy-object-fix-url-encoding-fix
HOTFIX: Copy object and key extraction improvement
2026-03-13 01:32:10 +01:00
9bfdceca08 HOTFIX: Copy object and encoding fixed 2026-03-13 01:29:29 +01:00
6473726a45 Merge pull request #9 from ferdzo/hotfix/streaming-error-and-equalsign-in-sigv4
HOTFIX: Fixed chunked stream problems with size and equal sign encoding
2026-03-13 00:27:57 +01:00
115d825234 HOTFIX: Fixed streaming problems with size and equal sign encoding 2026-03-13 00:25:48 +01:00
237063d9fc Merge pull request #8 from ferdzo/fix/bucket-creation-status201-to-200ok
Changed 201 Created to 200 OK when creating bucket
2026-03-12 00:29:05 +01:00
c2215d8589 Changed 201 Created to 200 OK when creating bucket 2026-03-12 00:28:01 +01:00
82cb58dff1 README.md edit 2026-03-11 23:39:25 +01:00
b592d6a2f0 Merge pull request #7 from ferdzo/feat/cli
Fs CLI
2026-03-11 20:27:22 +01:00
ef12326975 Copilot suggestions fixed 2026-03-11 20:26:17 +01:00
a23577d531 Update cmd/admin_snapshot.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-11 20:19:50 +01:00
e8256d66e3 Added Github Action for build and release. 2026-03-11 20:18:24 +01:00
ad53a6d8ac Backup/restore option for cli 2026-03-11 20:12:00 +01:00
cfb9b591ac Policy example documentation 2026-03-11 00:50:09 +01:00
b27f1186cf Remove Role command. 2026-03-11 00:47:24 +01:00
36 changed files with 2921 additions and 79 deletions

View File

@@ -1,6 +1,7 @@
LOG_LEVEL=debug LOG_LEVEL=debug
LOG_FORMAT=text LOG_FORMAT=text
DATA_PATH=data/ DATA_PATH=data/
FS_MAX_OBJECT_UPLOAD_BYTES=5368709120
PORT=2600 PORT=2600
AUDIT_LOG=true AUDIT_LOG=true
ADDRESS=0.0.0.0 ADDRESS=0.0.0.0

24
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: CI
on:
push:
branches: ["main"]
pull_request:
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Run tests
run: |
export GOCACHE=/tmp/go-build-cache
go test ./...

66
.github/workflows/release-image.yml vendored Normal file
View File

@@ -0,0 +1,66 @@
name: Release Image
on:
push:
tags:
- "v*.*.*"
workflow_dispatch:
permissions:
contents: read
packages: write
jobs:
docker:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set build date
id: date
run: echo "value=$(date -u +%Y-%m-%dT%H:%M:%SZ)" >> "$GITHUB_OUTPUT"
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Docker metadata
id: meta
uses: docker/metadata-action@v5
with:
images: ghcr.io/${{ github.repository }}
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=sha
labels: |
org.opencontainers.image.title=fs
org.opencontainers.image.description=Lightweight S3-compatible object storage
org.opencontainers.image.source=https://github.com/${{ github.repository }}
org.opencontainers.image.revision=${{ github.sha }}
org.opencontainers.image.created=${{ steps.date.outputs.value }}
- name: Build and push image
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
VERSION=${{ github.ref_name }}
COMMIT=${{ github.sha }}
DATE=${{ steps.date.outputs.value }}

View File

@@ -6,7 +6,13 @@ COPY go.mod go.sum ./
RUN go mod download RUN go mod download
COPY . . COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -o /app/fs . ARG VERSION=dev
ARG COMMIT=none
ARG DATE=unknown
RUN CGO_ENABLED=0 GOOS=linux go build \
-trimpath \
-ldflags "-s -w -X main.version=${VERSION} -X main.commit=${COMMIT} -X main.date=${DATE}" \
-o /app/fs .
FROM alpine:3.23 AS runner FROM alpine:3.23 AS runner

View File

@@ -49,6 +49,61 @@ Admin API (JSON):
- `PUT /_admin/v1/users/{accessKeyId}/status` - `PUT /_admin/v1/users/{accessKeyId}/status`
- `DELETE /_admin/v1/users/{accessKeyId}` - `DELETE /_admin/v1/users/{accessKeyId}`
Admin API policy examples (SigV4):
```bash
ENDPOINT="http://localhost:2600"
REGION="us-east-1"
ADMIN_ACCESS_KEY="${FS_ROOT_USER}"
ADMIN_SECRET_KEY="${FS_ROOT_PASSWORD}"
SIGV4="aws:amz:${REGION}:s3"
```
Replace user policy with one scoped statement:
```bash
curl --aws-sigv4 "$SIGV4" \
--user "${ADMIN_ACCESS_KEY}:${ADMIN_SECRET_KEY}" \
-H "Content-Type: application/json" \
-X PUT "${ENDPOINT}/_admin/v1/users/test-user/policy" \
-d '{
"policy": {
"statements": [
{
"effect": "allow",
"actions": ["s3:ListBucket", "s3:GetObject", "s3:PutObject", "s3:DeleteObject"],
"bucket": "backup",
"prefix": "restic/*"
}
]
}
}'
```
Set multiple statements (for multiple buckets):
```bash
curl --aws-sigv4 "$SIGV4" \
--user "${ADMIN_ACCESS_KEY}:${ADMIN_SECRET_KEY}" \
-H "Content-Type: application/json" \
-X PUT "${ENDPOINT}/_admin/v1/users/test-user/policy" \
-d '{
"policy": {
"statements": [
{
"effect": "allow",
"actions": ["s3:ListBucket", "s3:GetObject"],
"bucket": "test-bucket",
"prefix": "*"
},
{
"effect": "allow",
"actions": ["s3:ListBucket", "s3:GetObject", "s3:PutObject", "s3:DeleteObject"],
"bucket": "test-bucket-2",
"prefix": "*"
}
]
}
}'
```
Admin CLI: Admin CLI:
- `fs admin user create --access-key backup-user --role readwrite` - `fs admin user create --access-key backup-user --role readwrite`
- `fs admin user list` - `fs admin user list`
@@ -56,8 +111,12 @@ Admin CLI:
- `fs admin user set-status backup-user --status disabled` - `fs admin user set-status backup-user --status disabled`
- `fs admin user set-role backup-user --role readonly --bucket backup-bucket --prefix restic/` - `fs admin user set-role backup-user --role readonly --bucket backup-bucket --prefix restic/`
- `fs admin user set-role backup-user --role readwrite --bucket backups-2` (appends another statement) - `fs admin user set-role backup-user --role readwrite --bucket backups-2` (appends another statement)
- `fs admin user remove-role backup-user --role readonly --bucket backup-bucket --prefix restic/`
- `fs admin user set-role backup-user --role admin --replace` (replaces all statements) - `fs admin user set-role backup-user --role admin --replace` (replaces all statements)
- `fs admin user delete backup-user` - `fs admin user delete backup-user`
- `fs admin snapshot create --data-path /var/lib/fs --out /backup/fs-20260311.tar.gz`
- `fs admin snapshot inspect --file /backup/fs-20260311.tar.gz`
- `fs admin snapshot restore --file /backup/fs-20260311.tar.gz --data-path /var/lib/fs --force`
- `fs admin diag health` - `fs admin diag health`
- `fs admin diag version` - `fs admin diag version`
@@ -68,6 +127,9 @@ Required when `FS_AUTH_ENABLED=true`:
- `FS_ROOT_USER` and `FS_ROOT_PASSWORD` define initial credentials - `FS_ROOT_USER` and `FS_ROOT_PASSWORD` define initial credentials
- `ADMIN_API_ENABLED=true` enables `/_admin/v1/*` routes (bootstrap key only) - `ADMIN_API_ENABLED=true` enables `/_admin/v1/*` routes (bootstrap key only)
Upload limits:
- `FS_MAX_OBJECT_UPLOAD_BYTES` limits object PUT payloads, multipart upload parts, and completed multipart object size (default 5 GiB).
Reference: `auth/README.md` Reference: `auth/README.md`
Additional docs: Additional docs:
@@ -80,9 +142,12 @@ CLI credential/env resolution for `fs admin`:
- `FS_ROOT_USER` / `FS_ROOT_PASSWORD` (same defaults as server bootstrap) - `FS_ROOT_USER` / `FS_ROOT_PASSWORD` (same defaults as server bootstrap)
- `FSCLI_ACCESS_KEY` / `FSCLI_SECRET_KEY` - `FSCLI_ACCESS_KEY` / `FSCLI_SECRET_KEY`
- `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY` - `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY`
- `FSCLI_ENDPOINT` (fallback to `ADDRESS` + `PORT`, then `http://localhost:3000`) - `FSCLI_ENDPOINT` (fallback to `ADDRESS` + `PORT`, then `http://localhost:2600`)
- `FSCLI_REGION` (fallback `FS_AUTH_REGION`, default `us-east-1`) - `FSCLI_REGION` (fallback `FS_AUTH_REGION`, default `us-east-1`)
Note:
- `fs admin snapshot ...` commands operate locally on filesystem paths and do not require endpoint or auth credentials.
Health: Health:
- `GET /healthz` - `GET /healthz`
- `HEAD /healthz` - `HEAD /healthz`

View File

@@ -2,6 +2,7 @@ package api
import ( import (
"bufio" "bufio"
"bytes"
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/xml" "encoding/xml"
@@ -40,6 +41,7 @@ const (
maxXMLBodyBytes int64 = 1 << 20 maxXMLBodyBytes int64 = 1 << 20
maxDeleteObjects = 1000 maxDeleteObjects = 1000
maxObjectKeyBytes = 1024 maxObjectKeyBytes = 1024
maxAWSChunkedLineBytes = 8 << 10
serverReadHeaderTimeout = 5 * time.Second serverReadHeaderTimeout = 5 * time.Second
serverReadTimeout = 60 * time.Second serverReadTimeout = 60 * time.Second
serverWriteTimeout = 120 * time.Second serverWriteTimeout = 120 * time.Second
@@ -138,11 +140,87 @@ func validateObjectKey(key string) *s3APIError {
return nil return nil
} }
func objectKeyFromRequest(r *http.Request) (string, *s3APIError) {
rawKey := rawObjectKeyFromRequest(r)
key, err := normalizeObjectKey(rawKey)
if err != nil {
apiErr := s3ErrInvalidArgument
return "", &apiErr
}
if apiErr := validateObjectKey(key); apiErr != nil {
return "", apiErr
}
return key, nil
}
func rawObjectKeyFromRequest(r *http.Request) string {
if r == nil || r.URL == nil {
return ""
}
bucket := chi.URLParam(r, "bucket")
if bucket == "" {
return chi.URLParam(r, "*")
}
escapedPath := r.URL.EscapedPath()
prefix := "/" + bucket + "/"
if strings.HasPrefix(escapedPath, prefix) {
return strings.TrimPrefix(escapedPath, prefix)
}
return chi.URLParam(r, "*")
}
func normalizeObjectKey(raw string) (string, error) {
if raw == "" {
return "", nil
}
return url.PathUnescape(raw)
}
func parseCopySource(raw string) (string, string, error) {
raw = strings.TrimSpace(raw)
raw = strings.TrimPrefix(raw, "/")
if idx := strings.IndexByte(raw, '?'); idx >= 0 {
raw = raw[:idx]
}
bucket, rawKey, found := strings.Cut(raw, "/")
if !found || strings.TrimSpace(bucket) == "" || rawKey == "" {
return "", "", errors.New("invalid copy source")
}
key, err := normalizeObjectKey(rawKey)
if err != nil {
return "", "", err
}
if apiErr := validateObjectKey(key); apiErr != nil {
return "", "", errors.New(apiErr.Code)
}
return bucket, key, nil
}
func (h *Handler) authorizeCopySource(r *http.Request, bucket, key string) error {
return h.authorizeObjectAction(r, auth.ActionGetObject, bucket, key)
}
func (h *Handler) authorizeObjectAction(r *http.Request, action auth.Action, bucket, key string) error {
if h.authSvc == nil || !h.authSvc.Config().Enabled {
return nil
}
authCtx, ok := auth.GetRequestContext(r.Context())
if !ok || !authCtx.Authenticated {
return auth.ErrAccessDenied
}
return h.authSvc.Authorize(authCtx.AccessKeyID, auth.RequestTarget{
Action: action,
Bucket: bucket,
Key: key,
})
}
func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket") bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*") key, apiErr := objectKeyFromRequest(r)
if apiErr != nil {
if apiErr := validateObjectKey(key); apiErr != nil {
writeS3Error(w, r, *apiErr, r.URL.Path) writeS3Error(w, r, *apiErr, r.URL.Path)
return return
} }
@@ -199,8 +277,8 @@ func (h *Handler) handleGetObject(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) { func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket") bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*") key, apiErr := objectKeyFromRequest(r)
if apiErr := validateObjectKey(key); apiErr != nil { if apiErr != nil {
writeS3Error(w, r, *apiErr, r.URL.Path) writeS3Error(w, r, *apiErr, r.URL.Path)
return return
} }
@@ -234,6 +312,10 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxXMLBodyBytes) r.Body = http.MaxBytesReader(w, r.Body, maxXMLBodyBytes)
var req models.CompleteMultipartUploadRequest var req models.CompleteMultipartUploadRequest
if err := xml.NewDecoder(r.Body).Decode(&req); err != nil { if err := xml.NewDecoder(r.Body).Decode(&req); err != nil {
if errors.Is(err, auth.ErrSignatureDoesNotMatch) {
writeMappedS3Error(w, r, err)
return
}
var maxErr *http.MaxBytesError var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) { if errors.As(err, &maxErr) {
writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path) writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path)
@@ -275,8 +357,8 @@ func (h *Handler) handlePostObject(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) { func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket") bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*") key, apiErr := objectKeyFromRequest(r)
if apiErr := validateObjectKey(key); apiErr != nil { if apiErr != nil {
writeS3Error(w, r, *apiErr, r.URL.Path) writeS3Error(w, r, *apiErr, r.URL.Path)
return return
} }
@@ -289,6 +371,10 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) {
writeS3Error(w, r, s3ErrInvalidPart, r.URL.Path) writeS3Error(w, r, s3ErrInvalidPart, r.URL.Path)
return return
} }
if strings.TrimSpace(r.Header.Get("x-amz-copy-source")) != "" {
writeS3Error(w, r, s3ErrNotImplemented, r.URL.Path)
return
}
partNumber, err := strconv.Atoi(partNumberRaw) partNumber, err := strconv.Atoi(partNumberRaw)
if err != nil { if err != nil {
@@ -302,6 +388,10 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) {
bodyReader := io.Reader(r.Body) bodyReader := io.Reader(r.Body)
var decodeStream io.ReadCloser var decodeStream io.ReadCloser
if hasUnsupportedAWSChunkedPayload(r) {
writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path)
return
}
if shouldDecodeAWSChunkedPayload(r) { if shouldDecodeAWSChunkedPayload(r) {
decodeStream = newAWSChunkedDecodingReader(r.Body) decodeStream = newAWSChunkedDecodingReader(r.Body)
defer decodeStream.Close() defer decodeStream.Close()
@@ -333,6 +423,42 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) {
} }
} }
if copySourceRaw := strings.TrimSpace(r.Header.Get("x-amz-copy-source")); copySourceRaw != "" {
srcBucket, srcKey, err := parseCopySource(copySourceRaw)
if err != nil {
writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path)
return
}
if err := h.authorizeCopySource(r, srcBucket, srcKey); err != nil {
writeMappedS3Error(w, r, err)
return
}
manifest, err := h.svc.CopyObject(srcBucket, srcKey, bucket, key)
if err != nil {
writeMappedS3Error(w, r, err)
return
}
response := models.CopyObjectResult{
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
LastModified: time.Unix(manifest.CreatedAt, 0).UTC().Format("2006-01-02T15:04:05.000Z"),
ETag: `"` + manifest.ETag + `"`,
}
payload, err := xml.MarshalIndent(response, "", " ")
if err != nil {
writeMappedS3Error(w, r, err)
return
}
w.Header().Set("Content-Type", "application/xml; charset=utf-8")
w.Header().Set("ETag", `"`+manifest.ETag+`"`)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(xml.Header))
_, _ = w.Write(payload)
return
}
contentType := r.Header.Get("Content-Type") contentType := r.Header.Get("Content-Type")
if contentType == "" { if contentType == "" {
contentType = "application/octet-stream" contentType = "application/octet-stream"
@@ -340,6 +466,10 @@ func (h *Handler) handlePutObject(w http.ResponseWriter, r *http.Request) {
bodyReader := io.Reader(r.Body) bodyReader := io.Reader(r.Body)
var decodeStream io.ReadCloser var decodeStream io.ReadCloser
if hasUnsupportedAWSChunkedPayload(r) {
writeS3Error(w, r, s3ErrInvalidArgument, r.URL.Path)
return
}
if shouldDecodeAWSChunkedPayload(r) { if shouldDecodeAWSChunkedPayload(r) {
decodeStream = newAWSChunkedDecodingReader(r.Body) decodeStream = newAWSChunkedDecodingReader(r.Body)
defer decodeStream.Close() defer decodeStream.Close()
@@ -395,18 +525,27 @@ func (h *Handler) handleListMultipartParts(w http.ResponseWriter, r *http.Reques
} }
func shouldDecodeAWSChunkedPayload(r *http.Request) bool { func shouldDecodeAWSChunkedPayload(r *http.Request) bool {
contentEncoding := strings.ToLower(r.Header.Get("Content-Encoding"))
if strings.Contains(contentEncoding, "aws-chunked") {
return true
}
signingMode := strings.ToLower(r.Header.Get("x-amz-content-sha256")) signingMode := strings.ToLower(r.Header.Get("x-amz-content-sha256"))
return strings.HasPrefix(signingMode, "streaming-aws4-hmac-sha256-payload") return strings.HasPrefix(signingMode, "streaming-unsigned-payload")
}
func hasUnsupportedAWSChunkedPayload(r *http.Request) bool {
contentEncoding := strings.ToLower(r.Header.Get("Content-Encoding"))
if !strings.Contains(contentEncoding, "aws-chunked") {
return false
}
return !shouldDecodeAWSChunkedPayload(r)
} }
func newAWSChunkedDecodingReader(src io.Reader) io.ReadCloser { func newAWSChunkedDecodingReader(src io.Reader) io.ReadCloser {
probedReader, isAWSChunked := probeAWSChunkedPayload(src)
if !isAWSChunked {
return io.NopCloser(probedReader)
}
pr, pw := io.Pipe() pr, pw := io.Pipe()
go func() { go func() {
if err := decodeAWSChunkedPayload(src, pw); err != nil { if err := decodeAWSChunkedPayload(probedReader, pw); err != nil {
_ = pw.CloseWithError(err) _ = pw.CloseWithError(err)
return return
} }
@@ -415,10 +554,37 @@ func newAWSChunkedDecodingReader(src io.Reader) io.ReadCloser {
return pr return pr
} }
func probeAWSChunkedPayload(src io.Reader) (io.Reader, bool) {
reader := bufio.NewReaderSize(src, maxAWSChunkedLineBytes)
headerLine, err := reader.ReadSlice('\n')
replay := io.MultiReader(bytes.NewReader(headerLine), reader)
if errors.Is(err, bufio.ErrBufferFull) {
return replay, true
}
if err != nil {
return replay, false
}
line := strings.TrimRight(string(headerLine), "\r\n")
chunkSizeToken := line
if idx := strings.IndexByte(chunkSizeToken, ';'); idx >= 0 {
chunkSizeToken = chunkSizeToken[:idx]
}
chunkSizeToken = strings.TrimSpace(chunkSizeToken)
if chunkSizeToken == "" {
return replay, false
}
size, parseErr := strconv.ParseInt(chunkSizeToken, 16, 64)
if parseErr != nil || size < 0 {
return replay, false
}
return replay, true
}
func decodeAWSChunkedPayload(src io.Reader, dst io.Writer) error { func decodeAWSChunkedPayload(src io.Reader, dst io.Writer) error {
reader := bufio.NewReader(src) reader := bufio.NewReaderSize(src, maxAWSChunkedLineBytes)
for { for {
headerLine, err := reader.ReadString('\n') headerLine, err := readAWSChunkedLine(reader)
if err != nil { if err != nil {
return err return err
} }
@@ -435,6 +601,17 @@ func decodeAWSChunkedPayload(src io.Reader, dst io.Writer) error {
if chunkSize < 0 { if chunkSize < 0 {
return fmt.Errorf("invalid aws-chunked size: %d", chunkSize) return fmt.Errorf("invalid aws-chunked size: %d", chunkSize)
} }
if chunkSize == 0 {
for {
line, err := readAWSChunkedLine(reader)
if err != nil {
return err
}
if line == "\r\n" || line == "\n" {
return nil
}
}
}
if chunkSize > 0 { if chunkSize > 0 {
if _, err := io.CopyN(dst, reader, chunkSize); err != nil { if _, err := io.CopyN(dst, reader, chunkSize); err != nil {
return err return err
@@ -448,21 +625,20 @@ func decodeAWSChunkedPayload(src io.Reader, dst io.Writer) error {
if crlf[0] != '\r' || crlf[1] != '\n' { if crlf[0] != '\r' || crlf[1] != '\n' {
return errors.New("invalid aws-chunked payload terminator") return errors.New("invalid aws-chunked payload terminator")
} }
if chunkSize == 0 {
for {
line, err := reader.ReadString('\n')
if err != nil {
return err
}
if line == "\r\n" || line == "\n" {
return nil
}
}
}
} }
} }
func readAWSChunkedLine(reader *bufio.Reader) (string, error) {
line, err := reader.ReadSlice('\n')
if errors.Is(err, bufio.ErrBufferFull) {
return "", service.ErrEntityTooLarge
}
if len(line) > maxAWSChunkedLineBytes {
return "", service.ErrEntityTooLarge
}
return string(line), err
}
func ifNoneMatchPreconditionFailed(headerValue, etag string) bool { func ifNoneMatchPreconditionFailed(headerValue, etag string) bool {
for _, rawToken := range strings.Split(headerValue, ",") { for _, rawToken := range strings.Split(headerValue, ",") {
token := strings.TrimSpace(rawToken) token := strings.TrimSpace(rawToken)
@@ -488,7 +664,7 @@ func (h *Handler) handlePutBucket(w http.ResponseWriter, r *http.Request) {
writeMappedS3Error(w, r, err) writeMappedS3Error(w, r, err)
return return
} }
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusOK)
} }
func (h *Handler) handleDeleteBucket(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleDeleteBucket(w http.ResponseWriter, r *http.Request) {
@@ -519,6 +695,10 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
var req models.DeleteObjectsRequest var req models.DeleteObjectsRequest
if err := xml.NewDecoder(bodyReader).Decode(&req); err != nil { if err := xml.NewDecoder(bodyReader).Decode(&req); err != nil {
if errors.Is(err, auth.ErrSignatureDoesNotMatch) {
writeMappedS3Error(w, r, err)
return
}
var maxErr *http.MaxBytesError var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) { if errors.As(err, &maxErr) {
writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path) writeS3Error(w, r, s3ErrEntityTooLarge, r.URL.Path)
@@ -554,6 +734,15 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
}) })
continue continue
} }
if err := h.authorizeObjectAction(r, auth.ActionDeleteObject, bucket, obj.Key); err != nil {
apiErr := mapToS3Error(err)
response.Errors = append(response.Errors, models.DeleteError{
Key: obj.Key,
Code: apiErr.Code,
Message: apiErr.Message,
})
continue
}
keys = append(keys, obj.Key) keys = append(keys, obj.Key)
} }
@@ -584,8 +773,8 @@ func (h *Handler) handlePostBucket(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket") bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*") key, apiErr := objectKeyFromRequest(r)
if apiErr := validateObjectKey(key); apiErr != nil { if apiErr != nil {
writeS3Error(w, r, *apiErr, r.URL.Path) writeS3Error(w, r, *apiErr, r.URL.Path)
return return
} }
@@ -621,8 +810,8 @@ func (h *Handler) handleHeadBucket(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleHeadObject(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleHeadObject(w http.ResponseWriter, r *http.Request) {
bucket := chi.URLParam(r, "bucket") bucket := chi.URLParam(r, "bucket")
key := chi.URLParam(r, "*") key, apiErr := objectKeyFromRequest(r)
if apiErr := validateObjectKey(key); apiErr != nil { if apiErr != nil {
writeS3Error(w, r, *apiErr, r.URL.Path) writeS3Error(w, r, *apiErr, r.URL.Path)
return return
} }

115
api/aws_chunked_test.go Normal file
View File

@@ -0,0 +1,115 @@
package api
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"fs/service"
)
func TestShouldDecodeAWSChunkedPayloadUnsignedTrailerMode(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("x-amz-content-sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER")
if !shouldDecodeAWSChunkedPayload(req) {
t.Fatalf("expected shouldDecodeAWSChunkedPayload to return true for STREAMING-UNSIGNED-PAYLOAD-TRAILER")
}
}
func TestUnsupportedAWSChunkedContentEncodingWithoutStreamingMode(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Encoding", "aws-chunked")
req.Header.Set("x-amz-content-sha256", "UNSIGNED-PAYLOAD")
if !hasUnsupportedAWSChunkedPayload(req) {
t.Fatalf("expected aws-chunked content encoding without streaming mode to be unsupported")
}
if shouldDecodeAWSChunkedPayload(req) {
t.Fatalf("non-streaming aws-chunked content encoding must not trigger decoding")
}
}
func TestPutObjectRejectsUnsignedAWSChunkedContentEncoding(t *testing.T) {
handler, svc := newUploadLimitHandler(t, 1024)
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("CreateBucket: %v", err)
}
req := httptest.NewRequest(http.MethodPut, "/test-bucket/object.txt", strings.NewReader("4\r\nWiki\r\n0\r\n\r\n"))
req.Header.Set("Content-Encoding", "aws-chunked")
req.Header.Set("x-amz-content-sha256", "UNSIGNED-PAYLOAD")
rec := httptest.NewRecorder()
handler.router.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "InvalidArgument") {
t.Fatalf("expected InvalidArgument response, body=%s", rec.Body.String())
}
}
func TestAWSChunkedReaderPassThroughForPlainPayload(t *testing.T) {
t.Parallel()
plain := "PAR1\x00\x01\x02\x03binary-without-aws-chunked-header"
reader := newAWSChunkedDecodingReader(strings.NewReader(plain))
defer reader.Close()
out, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("read failed: %v", err)
}
if string(out) != plain {
t.Fatalf("unexpected passthrough result: got %q want %q", string(out), plain)
}
}
func TestAWSChunkedReaderDecodesChunkedPayload(t *testing.T) {
t.Parallel()
encoded := "" +
"4\r\nWiki\r\n" +
"5\r\npedia\r\n" +
"0\r\n" +
"x-amz-checksum-crc32:xxxx\r\n" +
"\r\n"
reader := newAWSChunkedDecodingReader(strings.NewReader(encoded))
defer reader.Close()
out, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("read failed: %v", err)
}
if string(out) != "Wikipedia" {
t.Fatalf("decoded payload mismatch: got %q want %q", string(out), "Wikipedia")
}
}
func TestAWSChunkedReaderRejectsOversizedChunkHeader(t *testing.T) {
t.Parallel()
encoded := strings.Repeat("f", maxAWSChunkedLineBytes+1) + "\n"
reader := newAWSChunkedDecodingReader(strings.NewReader(encoded))
defer reader.Close()
_, err := io.ReadAll(reader)
if !errors.Is(err, service.ErrEntityTooLarge) {
t.Fatalf("read error = %v, want ErrEntityTooLarge", err)
}
}

View File

@@ -0,0 +1,283 @@
package api
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"sort"
"strings"
"testing"
"time"
"fs/auth"
"fs/logging"
"fs/metadata"
"fs/models"
"fs/service"
"fs/storage"
"github.com/go-chi/chi/v5"
)
func newAuthorizedDeleteHandler(t *testing.T) (*Handler, *service.ObjectService, *auth.Service) {
t.Helper()
root := t.TempDir()
md, err := metadata.NewMetadataHandler(filepath.Join(root, "metadata.db"))
if err != nil {
t.Fatalf("new metadata handler: %v", err)
}
blob, err := storage.NewBlobStore(root, 1024)
if err != nil {
t.Fatalf("new blob store: %v", err)
}
svc := service.NewObjectService(md, blob, time.Hour)
t.Cleanup(func() {
_ = svc.Close()
})
masterKey := base64.StdEncoding.EncodeToString(make([]byte, 32))
authSvc, err := auth.NewService(auth.ConfigFromValues(
true,
"us-east-1",
0,
0,
masterKey,
"",
"",
"",
), md)
if err != nil {
t.Fatalf("new auth service: %v", err)
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := NewHandler(svc, logger, logging.Config{}, authSvc, false)
return handler, svc, authSvc
}
func newBucketPostRequest(bucket, body string) *http.Request {
req := httptest.NewRequest(http.MethodPost, "/"+bucket+"?delete", strings.NewReader(body))
rctx := chi.NewRouteContext()
rctx.URLParams.Add("bucket", bucket)
return req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
}
func withAuthContext(req *http.Request, accessKeyID string) *http.Request {
authCtx := auth.RequestContext{
Authenticated: true,
AccessKeyID: accessKeyID,
AuthType: "test",
}
return req.WithContext(auth.WithRequestContext(req.Context(), authCtx))
}
func createDeleteUser(t *testing.T, authSvc *auth.Service, prefix string) {
t.Helper()
createDeleteUserWithStatements(t, authSvc, []models.AuthPolicyStatement{
{
Effect: "allow",
Actions: []string{"s3:DeleteObject"},
Bucket: "test-bucket",
Prefix: prefix,
},
})
}
func createDeleteUserWithStatements(t *testing.T, authSvc *auth.Service, statements []models.AuthPolicyStatement) {
t.Helper()
_, err := authSvc.CreateUser(auth.CreateUserInput{
AccessKeyID: "delete-user",
SecretKey: "delete-secret-1",
Policy: models.AuthPolicy{
Statements: statements,
},
})
if err != nil {
t.Fatalf("create delete user: %v", err)
}
}
func putTestObject(t *testing.T, svc *service.ObjectService, key string) {
t.Helper()
_, err := svc.PutObject("test-bucket", key, "text/plain", bytes.NewReader([]byte("data")))
if err != nil {
t.Fatalf("put object %q: %v", key, err)
}
}
func TestMultiDeleteAuthorizesEveryKey(t *testing.T) {
handler, svc, authSvc := newAuthorizedDeleteHandler(t)
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("create bucket: %v", err)
}
createDeleteUser(t, authSvc, "allowed/")
putTestObject(t, svc, "allowed/file.txt")
putTestObject(t, svc, "private/file.txt")
body := `<Delete><Object><Key>allowed/file.txt</Key></Object><Object><Key>private/file.txt</Key></Object></Delete>`
req := withAuthContext(newBucketPostRequest("test-bucket", body), "delete-user")
rec := httptest.NewRecorder()
handler.handlePostBucket(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: got %d body=%s", rec.Code, rec.Body.String())
}
responseBody := rec.Body.String()
if !strings.Contains(responseBody, "<Deleted>") || !strings.Contains(responseBody, "allowed/file.txt") {
t.Fatalf("expected allowed key to be deleted, body=%s", responseBody)
}
if !strings.Contains(responseBody, "<Error>") || !strings.Contains(responseBody, "private/file.txt") || !strings.Contains(responseBody, "AccessDenied") {
t.Fatalf("expected denied key error, body=%s", responseBody)
}
if _, err := svc.HeadObject("test-bucket", "allowed/file.txt"); !errors.Is(err, metadata.ErrObjectNotFound) {
t.Fatalf("allowed object should be deleted, got err=%v", err)
}
if _, err := svc.HeadObject("test-bucket", "private/file.txt"); err != nil {
t.Fatalf("private object should remain: %v", err)
}
}
func TestMultiDeleteAllowsScopedKeys(t *testing.T) {
handler, svc, authSvc := newAuthorizedDeleteHandler(t)
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("create bucket: %v", err)
}
createDeleteUser(t, authSvc, "allowed/")
putTestObject(t, svc, "allowed/file.txt")
body := `<Delete><Object><Key>allowed/file.txt</Key></Object></Delete>`
req := withAuthContext(newBucketPostRequest("test-bucket", body), "delete-user")
rec := httptest.NewRecorder()
handler.handlePostBucket(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: got %d body=%s", rec.Code, rec.Body.String())
}
if strings.Contains(rec.Body.String(), "<Error>") {
t.Fatalf("unexpected delete error body=%s", rec.Body.String())
}
if _, err := svc.HeadObject("test-bucket", "allowed/file.txt"); !errors.Is(err, metadata.ErrObjectNotFound) {
t.Fatalf("allowed object should be deleted, got err=%v", err)
}
}
func TestMultiDeleteRouteAuthorizesKeysAfterMiddleware(t *testing.T) {
handler, svc, authSvc := newAuthorizedDeleteHandler(t)
handler.setupRoutes()
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("create bucket: %v", err)
}
createDeleteUserWithStatements(t, authSvc, []models.AuthPolicyStatement{
{Effect: "allow", Actions: []string{"s3:DeleteObject"}, Bucket: "test-bucket", Prefix: "allowed/"},
{Effect: "deny", Actions: []string{"s3:DeleteObject"}, Bucket: "test-bucket", Prefix: "private/"},
})
putTestObject(t, svc, "allowed/file.txt")
putTestObject(t, svc, "private/file.txt")
body := `<Delete><Object><Key>allowed/file.txt</Key></Object><Object><Key>private/file.txt</Key></Object></Delete>`
req := httptest.NewRequest(http.MethodPost, "/test-bucket?delete", strings.NewReader(body))
signTestSigV4Request(t, req, "delete-user", "delete-secret-1")
rec := httptest.NewRecorder()
handler.router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: got %d body=%s", rec.Code, rec.Body.String())
}
responseBody := rec.Body.String()
if !strings.Contains(responseBody, "allowed/file.txt") || !strings.Contains(responseBody, "<Deleted>") {
t.Fatalf("expected allowed key deletion, body=%s", responseBody)
}
if !strings.Contains(responseBody, "private/file.txt") || !strings.Contains(responseBody, "AccessDenied") {
t.Fatalf("expected per-key AccessDenied, body=%s", responseBody)
}
if _, err := svc.HeadObject("test-bucket", "allowed/file.txt"); !errors.Is(err, metadata.ErrObjectNotFound) {
t.Fatalf("allowed object should be deleted, got err=%v", err)
}
if _, err := svc.HeadObject("test-bucket", "private/file.txt"); err != nil {
t.Fatalf("private object should remain: %v", err)
}
}
func signTestSigV4Request(t *testing.T, req *http.Request, accessKeyID, secretKey string) {
t.Helper()
amzDate := time.Now().UTC().Format("20060102T150405Z")
date := amzDate[:8]
region := "us-east-1"
serviceName := "s3"
scope := strings.Join([]string{date, region, serviceName, "aws4_request"}, "/")
signedHeaders := []string{"host", "x-amz-content-sha256", "x-amz-date"}
signedHeadersRaw := strings.Join(signedHeaders, ";")
payloadHash := "UNSIGNED-PAYLOAD"
req.Header.Set("x-amz-date", amzDate)
req.Header.Set("x-amz-content-sha256", payloadHash)
canonicalRequest := strings.Join([]string{
req.Method,
req.URL.EscapedPath(),
canonicalTestQuery(req.URL.RawQuery),
"host:" + strings.TrimSpace(req.Host) + "\n" +
"x-amz-content-sha256:" + payloadHash + "\n" +
"x-amz-date:" + amzDate + "\n",
signedHeadersRaw,
payloadHash,
}, "\n")
canonicalHash := sha256.Sum256([]byte(canonicalRequest))
stringToSign := strings.Join([]string{
"AWS4-HMAC-SHA256",
amzDate,
scope,
hex.EncodeToString(canonicalHash[:]),
}, "\n")
signingKey := testHMAC(testHMAC(testHMAC(testHMAC([]byte("AWS4"+secretKey), date), region), serviceName), "aws4_request")
signature := hex.EncodeToString(testHMAC(signingKey, stringToSign))
req.Header.Set("Authorization", "AWS4-HMAC-SHA256 "+
"Credential="+accessKeyID+"/"+scope+", "+
"SignedHeaders="+signedHeadersRaw+", "+
"Signature="+signature)
}
func canonicalTestQuery(rawQuery string) string {
values, _ := url.ParseQuery(rawQuery)
pairs := make([]string, 0)
for key, valueList := range values {
if len(valueList) == 0 {
pairs = append(pairs, awsTestQueryEscape(key)+"=")
continue
}
for _, value := range valueList {
pairs = append(pairs, awsTestQueryEscape(key)+"="+awsTestQueryEscape(value))
}
}
sort.Strings(pairs)
return strings.Join(pairs, "&")
}
func awsTestQueryEscape(value string) string {
encoded := url.QueryEscape(value)
encoded = strings.ReplaceAll(encoded, "+", "%20")
encoded = strings.ReplaceAll(encoded, "*", "%2A")
encoded = strings.ReplaceAll(encoded, "%7E", "~")
return encoded
}
func testHMAC(key []byte, value string) []byte {
mac := hmac.New(sha256.New, key)
_, _ = mac.Write([]byte(value))
return mac.Sum(nil)
}

107
api/object_copy_test.go Normal file
View File

@@ -0,0 +1,107 @@
package api
import (
"bytes"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"fs/logging"
"fs/metadata"
"fs/service"
"fs/storage"
)
func newTestObjectHandler(t *testing.T) (*Handler, *service.ObjectService) {
t.Helper()
root := t.TempDir()
md, err := metadata.NewMetadataHandler(filepath.Join(root, "metadata.db"))
if err != nil {
t.Fatalf("new metadata handler: %v", err)
}
blob, err := storage.NewBlobStore(root, 1024)
if err != nil {
t.Fatalf("new blob store: %v", err)
}
svc := service.NewObjectService(md, blob, time.Hour)
t.Cleanup(func() {
_ = svc.Close()
})
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := NewHandler(svc, logger, logging.Config{}, nil, false)
handler.setupRoutes()
return handler, svc
}
func TestPutObjectStoresDecodedKey(t *testing.T) {
handler, svc := newTestObjectHandler(t)
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("create bucket: %v", err)
}
req := httptest.NewRequest(http.MethodPut, "/test-bucket/jsp-data-raw/vehicle_positions/year%3D2026/month%3D03/day%3D12/file.parquet", bytes.NewReader([]byte("PAR1data")))
rec := httptest.NewRecorder()
handler.router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: got %d body=%s", rec.Code, rec.Body.String())
}
_, err := svc.HeadObject("test-bucket", "jsp-data-raw/vehicle_positions/year=2026/month=03/day=12/file.parquet")
if err != nil {
t.Fatalf("head decoded key: %v", err)
}
getReq := httptest.NewRequest(http.MethodGet, "/test-bucket/jsp-data-raw/vehicle_positions/year=2026/month=03/day=12/file.parquet", nil)
getRec := httptest.NewRecorder()
handler.router.ServeHTTP(getRec, getReq)
if getRec.Code != http.StatusOK {
t.Fatalf("unexpected get status: got %d body=%s", getRec.Code, getRec.Body.String())
}
if got := getRec.Body.String(); got != "PAR1data" {
t.Fatalf("unexpected get body: got %q", got)
}
}
func TestCopyObjectCopiesCanonicalObject(t *testing.T) {
handler, svc := newTestObjectHandler(t)
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("create bucket: %v", err)
}
putReq := httptest.NewRequest(http.MethodPut, "/test-bucket/source/year%3D2026/file.parquet", bytes.NewReader([]byte("PAR1copy")))
putRec := httptest.NewRecorder()
handler.router.ServeHTTP(putRec, putReq)
if putRec.Code != http.StatusOK {
t.Fatalf("unexpected put status: got %d body=%s", putRec.Code, putRec.Body.String())
}
copyReq := httptest.NewRequest(http.MethodPut, "/test-bucket/copied/year=2026/file.parquet", http.NoBody)
copyReq.Header.Set("x-amz-copy-source", "/test-bucket/source/year%3D2026/file.parquet")
copyRec := httptest.NewRecorder()
handler.router.ServeHTTP(copyRec, copyReq)
if copyRec.Code != http.StatusOK {
t.Fatalf("unexpected copy status: got %d body=%s", copyRec.Code, copyRec.Body.String())
}
if !strings.Contains(copyRec.Body.String(), "<CopyObjectResult") {
t.Fatalf("unexpected copy response body: %s", copyRec.Body.String())
}
getReq := httptest.NewRequest(http.MethodGet, "/test-bucket/copied/year=2026/file.parquet", nil)
getRec := httptest.NewRecorder()
handler.router.ServeHTTP(getRec, getReq)
if getRec.Code != http.StatusOK {
t.Fatalf("unexpected get status after copy: got %d body=%s", getRec.Code, getRec.Body.String())
}
if got := getRec.Body.String(); got != "PAR1copy" {
t.Fatalf("unexpected copied body: got %q", got)
}
}

View File

@@ -174,6 +174,8 @@ func mapToS3Error(err error) s3APIError {
return s3ErrMalformedXML return s3ErrMalformedXML
case errors.Is(err, service.ErrEntityTooSmall): case errors.Is(err, service.ErrEntityTooSmall):
return s3ErrEntityTooSmall return s3ErrEntityTooSmall
case errors.Is(err, service.ErrEntityTooLarge):
return s3ErrEntityTooLarge
case errors.Is(err, auth.ErrAccessDenied): case errors.Is(err, auth.ErrAccessDenied):
return s3ErrAccessDenied return s3ErrAccessDenied
case errors.Is(err, auth.ErrInvalidAccessKeyID): case errors.Is(err, auth.ErrInvalidAccessKeyID):

79
api/upload_limit_test.go Normal file
View File

@@ -0,0 +1,79 @@
package api
import (
"bytes"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"fs/logging"
"fs/metadata"
"fs/service"
"fs/storage"
)
func TestPutObjectReturnsEntityTooLarge(t *testing.T) {
handler, svc := newUploadLimitHandler(t, 4)
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("CreateBucket: %v", err)
}
req := httptest.NewRequest(http.MethodPut, "/test-bucket/too-large.txt", strings.NewReader("12345"))
rec := httptest.NewRecorder()
handler.router.ServeHTTP(rec, req)
if rec.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusRequestEntityTooLarge, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "EntityTooLarge") {
t.Fatalf("expected EntityTooLarge response, body=%s", rec.Body.String())
}
}
func TestUploadPartReturnsEntityTooLarge(t *testing.T) {
handler, svc := newUploadLimitHandler(t, 4)
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("CreateBucket: %v", err)
}
upload, err := svc.CreateMultipartUpload("test-bucket", "object.txt")
if err != nil {
t.Fatalf("CreateMultipartUpload: %v", err)
}
req := httptest.NewRequest(http.MethodPut, "/test-bucket/object.txt?partNumber=1&uploadId="+upload.UploadID, bytes.NewReader([]byte("12345")))
rec := httptest.NewRecorder()
handler.router.ServeHTTP(rec, req)
if rec.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusRequestEntityTooLarge, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "EntityTooLarge") {
t.Fatalf("expected EntityTooLarge response, body=%s", rec.Body.String())
}
}
func newUploadLimitHandler(t *testing.T, maxUploadSize int64) (*Handler, *service.ObjectService) {
t.Helper()
root := t.TempDir()
md, err := metadata.NewMetadataHandler(filepath.Join(root, "metadata.db"))
if err != nil {
t.Fatalf("new metadata handler: %v", err)
}
blob, err := storage.NewBlobStore(root, 4)
if err != nil {
t.Fatalf("new blob store: %v", err)
}
svc := service.NewObjectService(md, blob, time.Hour, maxUploadSize)
t.Cleanup(func() {
_ = svc.Close()
})
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := NewHandler(svc, logger, logging.Config{}, nil, false)
handler.setupRoutes()
return handler, svc
}

View File

@@ -2,7 +2,6 @@ package app
import ( import (
"context" "context"
"fmt"
"fs/api" "fs/api"
"fs/auth" "fs/auth"
"fs/logging" "fs/logging"
@@ -40,6 +39,7 @@ func RunServer(ctx context.Context) error {
"audit_log", logConfig.Audit, "audit_log", logConfig.Audit,
"data_path", config.DataPath, "data_path", config.DataPath,
"multipart_retention_hours", int(config.MultipartCleanupRetention/time.Hour), "multipart_retention_hours", int(config.MultipartCleanupRetention/time.Hour),
"max_object_upload_bytes", config.MaxObjectUploadBytes,
"auth_enabled", authConfig.Enabled, "auth_enabled", authConfig.Enabled,
"auth_region", authConfig.Region, "auth_region", authConfig.Region,
"admin_api_enabled", config.AdminAPIEnabled, "admin_api_enabled", config.AdminAPIEnabled,
@@ -64,7 +64,7 @@ func RunServer(ctx context.Context) error {
return err return err
} }
objectService := service.NewObjectService(metadataHandler, blobHandler, config.MultipartCleanupRetention) objectService := service.NewObjectService(metadataHandler, blobHandler, config.MultipartCleanupRetention, config.MaxObjectUploadBytes)
authService, err := auth.NewService(authConfig, metadataHandler) authService, err := auth.NewService(authConfig, metadataHandler)
if err != nil { if err != nil {
_ = metadataHandler.Close() _ = metadataHandler.Close()
@@ -85,7 +85,7 @@ func RunServer(ctx context.Context) error {
if err := handler.Start(ctx, addr); err != nil { if err := handler.Start(ctx, addr); err != nil {
logger.Error("server_stopped_with_error", "error", err) logger.Error("server_stopped_with_error", "error", err)
return fmt.Errorf("server start failed: %w", err) return err
} }
return nil return nil
} }

View File

@@ -94,9 +94,11 @@ For each non-health request:
6. Decrypt stored secret using master key. 6. Decrypt stored secret using master key.
7. Recompute canonical request and expected signature. 7. Recompute canonical request and expected signature.
8. Compare signatures. 8. Compare signatures.
9. Resolve target action from request. 9. Reject signed streaming payload modes that require per-chunk signature verification.
10. Evaluate policy; deny overrides allow. 10. Wrap fixed-size signed payloads so the actual body must match `x-amz-content-sha256`.
11. Store auth result in request context and continue. 11. Resolve target action from request.
12. Evaluate policy; deny overrides allow.
13. Store auth result in request context and continue.
## Authorization Semantics ## Authorization Semantics
Policy evaluator rules: Policy evaluator rules:
@@ -106,6 +108,9 @@ Policy evaluator rules:
- action: `*` or `s3:*` - action: `*` or `s3:*`
- bucket: `*` - bucket: `*`
- prefix: `*` - prefix: `*`
- Object actions apply `prefix` to the object key.
- `ListBucket` applies `prefix` to the requested list `prefix` query value; a scoped list policy such as `prefix=backups/` does not allow an empty-prefix or sibling-prefix bucket listing.
- Multi-object delete is authorized per object key after the XML body is parsed; denied keys are returned as per-key `AccessDenied` errors and are not deleted.
Action resolution includes: Action resolution includes:
- bucket APIs (`CreateBucket`, `ListBucket`, `HeadBucket`, `DeleteBucket`) - bucket APIs (`CreateBucket`, `ListBucket`, `HeadBucket`, `DeleteBucket`)
@@ -137,14 +142,7 @@ Each audit entry includes method, path, remote IP, and request ID (if present).
## Current Scope / Limitations ## Current Scope / Limitations
- No STS/session-token auth yet. - No STS/session-token auth yet.
- No admin API for managing multiple users yet. - Signed aws-chunked streaming payloads are not accepted until per-chunk signature verification is implemented. Unsigned streaming payload modes can still be decoded by the API layer.
- Policy language is intentionally minimal, not full IAM. - Policy language is intentionally minimal, not full IAM.
- No automatic key rotation workflows. - No automatic key rotation workflows.
- No key rotation endpoint for existing users yet.
## Practical Next Step
To support multiple users cleanly, add admin operations in auth service + API:
- create user
- rotate secret
- set policy
- disable/enable
- delete user

View File

@@ -27,6 +27,18 @@ type RequestTarget struct {
Action Action Action Action
Bucket string Bucket string
Key string Key string
Prefix string
}
func RequiresHandlerAuthorization(r *http.Request) bool {
if r == nil || r.URL == nil {
return false
}
if r.Method == http.MethodPost {
_, isDelete := r.URL.Query()["delete"]
return isDelete
}
return false
} }
func resolveTarget(r *http.Request) RequestTarget { func resolveTarget(r *http.Request) RequestTarget {
@@ -51,7 +63,7 @@ func resolveTarget(r *http.Request) RequestTarget {
case http.MethodDelete: case http.MethodDelete:
return RequestTarget{Action: ActionDeleteBucket, Bucket: bucket} return RequestTarget{Action: ActionDeleteBucket, Bucket: bucket}
case http.MethodGet: case http.MethodGet:
return RequestTarget{Action: ActionListBucket, Bucket: bucket} return RequestTarget{Action: ActionListBucket, Bucket: bucket, Prefix: r.URL.Query().Get("prefix")}
case http.MethodPost: case http.MethodPost:
if _, ok := r.URL.Query()["delete"]; ok { if _, ok := r.URL.Query()["delete"]; ok {
return RequestTarget{Action: ActionDeleteObject, Bucket: bucket} return RequestTarget{Action: ActionDeleteObject, Bucket: bucket}

39
auth/action_test.go Normal file
View File

@@ -0,0 +1,39 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestResolveTargetIncludesListBucketPrefix(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com/test-bucket?list-type=2&prefix=allowed/", nil)
target := resolveTarget(req)
if target.Action != ActionListBucket {
t.Fatalf("action = %q, want %q", target.Action, ActionListBucket)
}
if target.Bucket != "test-bucket" {
t.Fatalf("bucket = %q, want test-bucket", target.Bucket)
}
if target.Prefix != "allowed/" {
t.Fatalf("prefix = %q, want allowed/", target.Prefix)
}
if target.Key != "" {
t.Fatalf("key = %q, want empty", target.Key)
}
}
func TestResolveTargetListBucketWithoutPrefix(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com/test-bucket", nil)
target := resolveTarget(req)
if target.Action != ActionListBucket {
t.Fatalf("action = %q, want %q", target.Action, ActionListBucket)
}
if target.Prefix != "" {
t.Fatalf("prefix = %q, want empty", target.Prefix)
}
}

View File

@@ -1,11 +1,16 @@
package auth package auth
import ( import (
"crypto/sha256"
"encoding/hex"
"errors" "errors"
"fs/metrics" "fs/metrics"
"hash"
"io"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"strings"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
) )
@@ -55,6 +60,16 @@ func Middleware(
return return
} }
if err := wrapPayloadHashVerifier(r); err != nil {
metrics.Default.ObserveAuth("error", "sigv4", authErrorClass(err))
if onError != nil {
onError(w, r, err)
return
}
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
metrics.Default.ObserveAuth("ok", resolvedCtx.AuthType, "none") metrics.Default.ObserveAuth("ok", resolvedCtx.AuthType, "none")
if auditEnabled && logger != nil { if auditEnabled && logger != nil {
requestID := middleware.GetReqID(r.Context()) requestID := middleware.GetReqID(r.Context())
@@ -75,6 +90,65 @@ func Middleware(
} }
} }
func wrapPayloadHashVerifier(r *http.Request) error {
if r == nil || r.Body == nil || r.Body == http.NoBody {
return nil
}
payloadHash := resolvePayloadHash(r, false)
if !payloadHashRequiresVerification(payloadHash) {
return nil
}
if !isHexSHA256(payloadHash) {
return ErrAuthorizationHeaderMalformed
}
expected, err := hex.DecodeString(strings.ToLower(payloadHash))
if err != nil {
return ErrAuthorizationHeaderMalformed
}
r.Body = &payloadHashVerifyingReadCloser{
inner: r.Body,
hasher: sha256.New(),
expected: expected,
}
return nil
}
type payloadHashVerifyingReadCloser struct {
inner io.ReadCloser
hasher hash.Hash
expected []byte
done bool
}
func (r *payloadHashVerifyingReadCloser) Read(p []byte) (int, error) {
n, err := r.inner.Read(p)
if n > 0 {
_, _ = r.hasher.Write(p[:n])
}
if err == io.EOF && !r.done {
r.done = true
if !equalBytes(r.hasher.Sum(nil), r.expected) {
return n, ErrSignatureDoesNotMatch
}
}
return n, err
}
func (r *payloadHashVerifyingReadCloser) Close() error {
return r.inner.Close()
}
func equalBytes(left, right []byte) bool {
if len(left) != len(right) {
return false
}
var diff byte
for i := range left {
diff |= left[i] ^ right[i]
}
return diff == 0
}
func authErrorClass(err error) string { func authErrorClass(err error) string {
switch { switch {
case errors.Is(err, ErrInvalidAccessKeyID): case errors.Is(err, ErrInvalidAccessKeyID):

75
auth/payload_hash_test.go Normal file
View File

@@ -0,0 +1,75 @@
package auth
import (
"crypto/sha256"
"encoding/hex"
"errors"
"io"
"net/http"
"strings"
"testing"
)
func TestPayloadHashVerifierAllowsMatchingBody(t *testing.T) {
body := "payload"
req := newPayloadHashRequest(t, body, body)
if err := wrapPayloadHashVerifier(req); err != nil {
t.Fatalf("wrapPayloadHashVerifier returned error: %v", err)
}
got, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("ReadAll returned error: %v", err)
}
if string(got) != body {
t.Fatalf("unexpected body: got %q want %q", string(got), body)
}
}
func TestPayloadHashVerifierRejectsMismatchedBody(t *testing.T) {
req := newPayloadHashRequest(t, "signed-payload", "actual-payload")
if err := wrapPayloadHashVerifier(req); err != nil {
t.Fatalf("wrapPayloadHashVerifier returned error: %v", err)
}
_, err := io.ReadAll(req.Body)
if !errors.Is(err, ErrSignatureDoesNotMatch) {
t.Fatalf("ReadAll error = %v, want ErrSignatureDoesNotMatch", err)
}
}
func TestPayloadSigningRejectsSignedStreamingMode(t *testing.T) {
req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("x-amz-content-sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD")
err = validatePayloadSigningMode(req, &sigV4Input{})
if !errors.Is(err, ErrAuthorizationHeaderMalformed) {
t.Fatalf("validatePayloadSigningMode error = %v, want ErrAuthorizationHeaderMalformed", err)
}
}
func TestPayloadSigningAllowsUnsignedStreamingMode(t *testing.T) {
req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("x-amz-content-sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER")
if err := validatePayloadSigningMode(req, &sigV4Input{}); err != nil {
t.Fatalf("validatePayloadSigningMode returned error: %v", err)
}
}
func newPayloadHashRequest(t *testing.T, signedBody, actualBody string) *http.Request {
t.Helper()
req, err := http.NewRequest(http.MethodPut, "http://example.com/b/k", strings.NewReader(actualBody))
if err != nil {
t.Fatal(err)
}
sum := sha256.Sum256([]byte(signedBody))
req.Header.Set("x-amz-content-sha256", hex.EncodeToString(sum[:]))
return req
}

View File

@@ -33,14 +33,16 @@ func statementMatches(stmt models.AuthPolicyStatement, target RequestTarget) boo
if !bucketMatches(stmt.Bucket, target.Bucket) { if !bucketMatches(stmt.Bucket, target.Bucket) {
return false return false
} }
if target.Key == "" {
return true
}
prefix := strings.TrimSpace(stmt.Prefix) prefix := strings.TrimSpace(stmt.Prefix)
if prefix == "" || prefix == "*" { if prefix == "" || prefix == "*" {
return true return true
} }
if target.Key == "" {
if target.Action == ActionListBucket {
return strings.HasPrefix(target.Prefix, prefix)
}
return true
}
return strings.HasPrefix(target.Key, prefix) return strings.HasPrefix(target.Key, prefix)
} }

52
auth/policy_test.go Normal file
View File

@@ -0,0 +1,52 @@
package auth
import (
"fs/models"
"testing"
)
func TestListBucketPolicyAppliesPrefix(t *testing.T) {
policy := &models.AuthPolicy{
Statements: []models.AuthPolicyStatement{
{
Effect: "allow",
Actions: []string{"s3:ListBucket"},
Bucket: "test-bucket",
Prefix: "allowed/",
},
},
}
if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "allowed/"}) {
t.Fatalf("expected matching list prefix to be allowed")
}
if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "allowed/nested/"}) {
t.Fatalf("expected nested list prefix to be allowed")
}
if isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket"}) {
t.Fatalf("expected empty list prefix to be denied")
}
if isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "private/"}) {
t.Fatalf("expected non-matching list prefix to be denied")
}
}
func TestWildcardListBucketPolicyAllowsAnyPrefix(t *testing.T) {
policy := &models.AuthPolicy{
Statements: []models.AuthPolicyStatement{
{
Effect: "allow",
Actions: []string{"s3:ListBucket"},
Bucket: "test-bucket",
Prefix: "*",
},
},
}
if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket"}) {
t.Fatalf("expected wildcard list policy to allow empty prefix")
}
if !isAllowed(policy, RequestTarget{Action: ActionListBucket, Bucket: "test-bucket", Prefix: "private/"}) {
t.Fatalf("expected wildcard list policy to allow arbitrary prefix")
}
}

View File

@@ -152,6 +152,9 @@ func (s *Service) AuthenticateRequest(r *http.Request) (RequestContext, error) {
if err := validateSigV4Input(s.now(), s.cfg, input); err != nil { if err := validateSigV4Input(s.now(), s.cfg, input); err != nil {
return RequestContext{}, err return RequestContext{}, err
} }
if err := validatePayloadSigningMode(r, input); err != nil {
return RequestContext{}, err
}
identity, err := s.store.GetAuthIdentity(input.AccessKeyID) identity, err := s.store.GetAuthIdentity(input.AccessKeyID)
if err != nil { if err != nil {
@@ -185,6 +188,13 @@ func (s *Service) AuthenticateRequest(r *http.Request) (RequestContext, error) {
AuthType: authType, AuthType: authType,
}, nil }, nil
} }
if RequiresHandlerAuthorization(r) {
return RequestContext{
Authenticated: true,
AccessKeyID: identity.AccessKeyID,
AuthType: authType,
}, nil
}
policy, err := s.store.GetAuthPolicy(identity.AccessKeyID) policy, err := s.store.GetAuthPolicy(identity.AccessKeyID)
if err != nil { if err != nil {
@@ -205,6 +215,29 @@ func (s *Service) AuthenticateRequest(r *http.Request) (RequestContext, error) {
}, nil }, nil
} }
func (s *Service) Authorize(accessKeyID string, target RequestTarget) error {
if !s.cfg.Enabled {
return nil
}
accessKeyID = strings.TrimSpace(accessKeyID)
if accessKeyID == "" {
return ErrAccessDenied
}
if target.Action == "" {
return ErrAccessDenied
}
policy, err := s.store.GetAuthPolicy(accessKeyID)
if err != nil {
return ErrAccessDenied
}
if !isAllowed(policy, target) {
return ErrAccessDenied
}
return nil
}
func (s *Service) CreateUser(input CreateUserInput) (*CreateUserResult, error) { func (s *Service) CreateUser(input CreateUserInput) (*CreateUserResult, error) {
if !s.cfg.Enabled { if !s.cfg.Enabled {
return nil, ErrAuthNotEnabled return nil, ErrAuthNotEnabled

View File

@@ -210,6 +210,17 @@ func validateSigV4Input(now time.Time, cfg Config, input *sigV4Input) error {
return nil return nil
} }
func validatePayloadSigningMode(r *http.Request, input *sigV4Input) error {
payloadHash := resolvePayloadHash(r, input.Presigned)
if isSignedStreamingPayloadHash(payloadHash) {
return fmt.Errorf("%w: signed streaming payload verification is not supported", ErrAuthorizationHeaderMalformed)
}
if payloadHashRequiresVerification(payloadHash) && !isHexSHA256(payloadHash) {
return fmt.Errorf("%w: invalid x-amz-content-sha256", ErrAuthorizationHeaderMalformed)
}
return nil
}
func signatureMatches(secret string, r *http.Request, input *sigV4Input) (bool, error) { func signatureMatches(secret string, r *http.Request, input *sigV4Input) (bool, error) {
payloadHash := resolvePayloadHash(r, input.Presigned) payloadHash := resolvePayloadHash(r, input.Presigned)
canonicalRequest, err := buildCanonicalRequest(r, input.SignedHeaders, payloadHash, input.Presigned) canonicalRequest, err := buildCanonicalRequest(r, input.SignedHeaders, payloadHash, input.Presigned)
@@ -233,6 +244,34 @@ func resolvePayloadHash(r *http.Request, presigned bool) string {
return hash return hash
} }
func isSignedStreamingPayloadHash(payloadHash string) bool {
payloadHash = strings.ToUpper(strings.TrimSpace(payloadHash))
return strings.HasPrefix(payloadHash, "STREAMING-AWS4-HMAC-SHA256-PAYLOAD")
}
func payloadHashRequiresVerification(payloadHash string) bool {
payloadHash = strings.ToUpper(strings.TrimSpace(payloadHash))
if payloadHash == "" || payloadHash == "UNSIGNED-PAYLOAD" {
return false
}
if strings.HasPrefix(payloadHash, "STREAMING-UNSIGNED-PAYLOAD") {
return false
}
return true
}
func isHexSHA256(value string) bool {
if len(value) != sha256.Size*2 {
return false
}
for _, ch := range value {
if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') && (ch < 'A' || ch > 'F') {
return false
}
}
return true
}
func buildCanonicalRequest(r *http.Request, signedHeaders []string, payloadHash string, presigned bool) (string, error) { func buildCanonicalRequest(r *http.Request, signedHeaders []string, payloadHash string, presigned bool) (string, error) {
canonicalURI := canonicalPath(r.URL) canonicalURI := canonicalPath(r.URL)
canonicalQuery := canonicalQueryString(r.URL.RawQuery, presigned) canonicalQuery := canonicalQueryString(r.URL.RawQuery, presigned)
@@ -259,7 +298,57 @@ func canonicalPath(u *url.URL) string {
if path == "" { if path == "" {
return "/" return "/"
} }
return path return awsEncodePath(path)
}
func awsEncodePath(path string) string {
var b strings.Builder
b.Grow(len(path))
for i := 0; i < len(path); i++ {
ch := path[i]
if ch == '/' || isUnreserved(ch) {
b.WriteByte(ch)
continue
}
if ch == '%' && i+2 < len(path) && isHex(path[i+1]) && isHex(path[i+2]) {
b.WriteByte('%')
b.WriteByte(toUpperHex(path[i+1]))
b.WriteByte(toUpperHex(path[i+2]))
i += 2
continue
}
b.WriteByte('%')
b.WriteByte(hexUpper(ch >> 4))
b.WriteByte(hexUpper(ch & 0x0F))
}
return b.String()
}
func isUnreserved(ch byte) bool {
return (ch >= 'A' && ch <= 'Z') ||
(ch >= 'a' && ch <= 'z') ||
(ch >= '0' && ch <= '9') ||
ch == '-' || ch == '_' || ch == '.' || ch == '~'
}
func isHex(ch byte) bool {
return (ch >= '0' && ch <= '9') ||
(ch >= 'a' && ch <= 'f') ||
(ch >= 'A' && ch <= 'F')
}
func toUpperHex(ch byte) byte {
if ch >= 'a' && ch <= 'f' {
return ch - ('a' - 'A')
}
return ch
}
func hexUpper(nibble byte) byte {
if nibble < 10 {
return '0' + nibble
}
return 'A' + (nibble - 10)
} }
type queryPair struct { type queryPair struct {

50
auth/sigv4_test.go Normal file
View File

@@ -0,0 +1,50 @@
package auth
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestCanonicalPathEncodesEquals(t *testing.T) {
u := &url.URL{Path: "/test-bucket/jsp-data-raw/year=2026/month=03/day=12/vehicle_positions.parquet"}
got := canonicalPath(u)
want := "/test-bucket/jsp-data-raw/year%3D2026/month%3D03/day%3D12/vehicle_positions.parquet"
if got != want {
t.Fatalf("unexpected canonical path: got %q want %q", got, want)
}
}
func TestCanonicalPathPreservesExistingEscapes(t *testing.T) {
u, err := url.Parse("http://localhost:2600/test-bucket/jsp-data-raw/year%3d2026/file%2Eparquet")
if err != nil {
t.Fatalf("url.Parse failed: %v", err)
}
got := canonicalPath(u)
want := "/test-bucket/jsp-data-raw/year%3D2026/file%2Eparquet"
if got != want {
t.Fatalf("unexpected canonical path: got %q want %q", got, want)
}
}
func TestBuildCanonicalRequestUsesAwsEncodedPath(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost:2600/test-bucket/jsp-data-raw/year=2026/month=03/day=12/vehicle_positions.parquet", nil)
req.Header.Set("x-amz-date", "20260313T120000Z")
req.Header.Set("x-amz-content-sha256", "UNSIGNED-PAYLOAD")
canonical, err := buildCanonicalRequest(req, []string{"host", "x-amz-content-sha256", "x-amz-date"}, "UNSIGNED-PAYLOAD", false)
if err != nil {
t.Fatalf("buildCanonicalRequest failed: %v", err)
}
lines := strings.Split(canonical, "\n")
if len(lines) < 2 {
t.Fatalf("canonical request has unexpected format: %q", canonical)
}
wantPath := "/test-bucket/jsp-data-raw/year%3D2026/month%3D03/day%3D12/vehicle_positions.parquet"
if lines[1] != wantPath {
t.Fatalf("unexpected canonical path line: got %q want %q", lines[1], wantPath)
}
}

View File

@@ -15,7 +15,7 @@ import (
) )
const ( const (
defaultAdminEndpoint = "http://localhost:3000" defaultAdminEndpoint = "http://localhost:2600"
defaultAdminRegion = "us-east-1" defaultAdminRegion = "us-east-1"
) )
@@ -48,6 +48,7 @@ func newAdminCommand(build BuildInfo) *cobra.Command {
cmd.AddCommand(newAdminUserCommand(opts)) cmd.AddCommand(newAdminUserCommand(opts))
cmd.AddCommand(newAdminDiagCommand(opts, build)) cmd.AddCommand(newAdminDiagCommand(opts, build))
cmd.AddCommand(newAdminSnapshotCommand(opts))
return cmd return cmd
} }
@@ -107,7 +108,7 @@ func endpointFromServerConfig(address string, port int) string {
host = "localhost" host = "localhost"
} }
if port <= 0 || port > 65535 { if port <= 0 || port > 65535 {
port = 3000 port = 2600
} }
return "http://" + net.JoinHostPort(host, strconv.Itoa(port)) return "http://" + net.JoinHostPort(host, strconv.Itoa(port))
} }

728
cmd/admin_snapshot.go Normal file
View File

@@ -0,0 +1,728 @@
package cmd
import (
"archive/tar"
"compress/gzip"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
bolt "go.etcd.io/bbolt"
"github.com/spf13/cobra"
)
const (
snapshotManifestPath = ".fs-snapshot/manifest.json"
snapshotFormat = 1
)
type snapshotFileEntry struct {
Path string `json:"path"`
Size int64 `json:"size"`
SHA256 string `json:"sha256"`
}
type snapshotManifest struct {
FormatVersion int `json:"formatVersion"`
CreatedAt string `json:"createdAt"`
SourcePath string `json:"sourcePath"`
Files []snapshotFileEntry `json:"files"`
}
type snapshotSummary struct {
SnapshotFile string `json:"snapshotFile"`
CreatedAt string `json:"createdAt"`
SourcePath string `json:"sourcePath"`
FileCount int `json:"fileCount"`
TotalBytes int64 `json:"totalBytes"`
}
func newAdminSnapshotCommand(opts *adminOptions) *cobra.Command {
cmd := &cobra.Command{
Use: "snapshot",
Short: "Offline snapshot and restore utilities",
RunE: func(cmd *cobra.Command, args []string) error {
return cmd.Help()
},
}
cmd.AddCommand(newAdminSnapshotCreateCommand(opts))
cmd.AddCommand(newAdminSnapshotInspectCommand(opts))
cmd.AddCommand(newAdminSnapshotRestoreCommand(opts))
return cmd
}
func newAdminSnapshotCreateCommand(opts *adminOptions) *cobra.Command {
var dataPath string
var outFile string
cmd := &cobra.Command{
Use: "create",
Short: "Create offline snapshot tarball (.tar.gz)",
RunE: func(cmd *cobra.Command, args []string) error {
dataPath = strings.TrimSpace(dataPath)
outFile = strings.TrimSpace(outFile)
if dataPath == "" {
return usageError("fs admin snapshot create --data-path <path> --out <snapshot.tar.gz>", "--data-path is required")
}
if outFile == "" {
return usageError("fs admin snapshot create --data-path <path> --out <snapshot.tar.gz>", "--out is required")
}
result, err := createSnapshotArchive(context.Background(), dataPath, outFile)
if err != nil {
return err
}
if opts.JSON {
return writeJSON(cmd.OutOrStdout(), result)
}
_, err = fmt.Fprintf(cmd.OutOrStdout(), "snapshot created: %s (files=%d bytes=%d)\n", result.SnapshotFile, result.FileCount, result.TotalBytes)
return err
},
}
cmd.Flags().StringVar(&dataPath, "data-path", "", "Source data path (must contain metadata.db)")
cmd.Flags().StringVar(&outFile, "out", "", "Output snapshot file path (.tar.gz)")
return cmd
}
func newAdminSnapshotInspectCommand(opts *adminOptions) *cobra.Command {
var filePath string
cmd := &cobra.Command{
Use: "inspect",
Short: "Inspect and verify snapshot archive integrity",
RunE: func(cmd *cobra.Command, args []string) error {
filePath = strings.TrimSpace(filePath)
if filePath == "" {
return usageError("fs admin snapshot inspect --file <snapshot.tar.gz>", "--file is required")
}
manifest, summary, err := inspectSnapshotArchive(filePath)
if err != nil {
return err
}
if opts.JSON {
return writeJSON(cmd.OutOrStdout(), map[string]any{
"summary": summary,
"manifest": manifest,
})
}
_, err = fmt.Fprintf(
cmd.OutOrStdout(),
"snapshot ok: %s\ncreated_at=%s source=%s files=%d bytes=%d\n",
summary.SnapshotFile,
summary.CreatedAt,
summary.SourcePath,
summary.FileCount,
summary.TotalBytes,
)
return err
},
}
cmd.Flags().StringVar(&filePath, "file", "", "Snapshot file path (.tar.gz)")
return cmd
}
func newAdminSnapshotRestoreCommand(opts *adminOptions) *cobra.Command {
var filePath string
var dataPath string
var force bool
cmd := &cobra.Command{
Use: "restore",
Short: "Restore snapshot into a data path (offline only)",
RunE: func(cmd *cobra.Command, args []string) error {
filePath = strings.TrimSpace(filePath)
dataPath = strings.TrimSpace(dataPath)
if filePath == "" {
return usageError("fs admin snapshot restore --file <snapshot.tar.gz> --data-path <path> [--force]", "--file is required")
}
if dataPath == "" {
return usageError("fs admin snapshot restore --file <snapshot.tar.gz> --data-path <path> [--force]", "--data-path is required")
}
result, err := restoreSnapshotArchive(context.Background(), filePath, dataPath, force)
if err != nil {
return err
}
if opts.JSON {
return writeJSON(cmd.OutOrStdout(), result)
}
_, err = fmt.Fprintf(
cmd.OutOrStdout(),
"snapshot restored to %s (files=%d bytes=%d)\n",
result.SourcePath,
result.FileCount,
result.TotalBytes,
)
return err
},
}
cmd.Flags().StringVar(&filePath, "file", "", "Snapshot file path (.tar.gz)")
cmd.Flags().StringVar(&dataPath, "data-path", "", "Destination data path")
cmd.Flags().BoolVar(&force, "force", false, "Overwrite destination data path if it exists")
return cmd
}
func createSnapshotArchive(ctx context.Context, dataPath, outFile string) (*snapshotSummary, error) {
_ = ctx
sourceAbs, err := filepath.Abs(filepath.Clean(dataPath))
if err != nil {
return nil, err
}
outAbs, err := filepath.Abs(filepath.Clean(outFile))
if err != nil {
return nil, err
}
if isPathWithin(sourceAbs, outAbs) {
return nil, errors.New("output file cannot be inside --data-path")
}
info, err := os.Stat(sourceAbs)
if err != nil {
return nil, err
}
if !info.IsDir() {
return nil, fmt.Errorf("data path %q is not a directory", sourceAbs)
}
if err := ensureMetadataExists(sourceAbs); err != nil {
return nil, err
}
if err := ensureDataPathOffline(sourceAbs); err != nil {
return nil, err
}
manifest, totalBytes, err := buildSnapshotManifest(sourceAbs)
if err != nil {
return nil, err
}
if err := os.MkdirAll(filepath.Dir(outAbs), 0o755); err != nil {
return nil, err
}
tmpPath := outAbs + ".tmp-" + strconvNowNano()
file, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o600)
if err != nil {
return nil, err
}
defer func() {
_ = file.Close()
}()
gzw := gzip.NewWriter(file)
tw := tar.NewWriter(gzw)
if err := writeManifestToTar(tw, manifest); err != nil {
_ = tw.Close()
_ = gzw.Close()
_ = os.Remove(tmpPath)
return nil, err
}
for _, entry := range manifest.Files {
absPath := filepath.Join(sourceAbs, filepath.FromSlash(entry.Path))
if err := writeFileToTar(tw, absPath, entry.Path); err != nil {
_ = tw.Close()
_ = gzw.Close()
_ = os.Remove(tmpPath)
return nil, err
}
}
if err := tw.Close(); err != nil {
_ = gzw.Close()
_ = os.Remove(tmpPath)
return nil, err
}
if err := gzw.Close(); err != nil {
_ = os.Remove(tmpPath)
return nil, err
}
if err := file.Sync(); err != nil {
_ = os.Remove(tmpPath)
return nil, err
}
if err := file.Close(); err != nil {
_ = os.Remove(tmpPath)
return nil, err
}
if err := os.Rename(tmpPath, outAbs); err != nil {
_ = os.Remove(tmpPath)
return nil, err
}
if err := syncDir(filepath.Dir(outAbs)); err != nil {
return nil, err
}
return &snapshotSummary{
SnapshotFile: outAbs,
CreatedAt: manifest.CreatedAt,
SourcePath: sourceAbs,
FileCount: len(manifest.Files),
TotalBytes: totalBytes,
}, nil
}
func inspectSnapshotArchive(filePath string) (*snapshotManifest, *snapshotSummary, error) {
fileAbs, err := filepath.Abs(filepath.Clean(filePath))
if err != nil {
return nil, nil, err
}
manifest, actual, err := readSnapshotArchive(fileAbs)
if err != nil {
return nil, nil, err
}
expected := map[string]snapshotFileEntry{}
var totalBytes int64
for _, entry := range manifest.Files {
expected[entry.Path] = entry
totalBytes += entry.Size
}
if len(expected) != len(actual) {
return nil, nil, fmt.Errorf("snapshot validation failed: expected %d files, got %d", len(expected), len(actual))
}
for path, exp := range expected {
got, ok := actual[path]
if !ok {
return nil, nil, fmt.Errorf("snapshot validation failed: missing file %s", path)
}
if got.Size != exp.Size || got.SHA256 != exp.SHA256 {
return nil, nil, fmt.Errorf("snapshot validation failed: checksum mismatch for %s", path)
}
}
return manifest, &snapshotSummary{
SnapshotFile: fileAbs,
CreatedAt: manifest.CreatedAt,
SourcePath: manifest.SourcePath,
FileCount: len(manifest.Files),
TotalBytes: totalBytes,
}, nil
}
func restoreSnapshotArchive(ctx context.Context, filePath, destinationPath string, force bool) (*snapshotSummary, error) {
_ = ctx
manifest, summary, err := inspectSnapshotArchive(filePath)
if err != nil {
return nil, err
}
destAbs, err := filepath.Abs(filepath.Clean(destinationPath))
if err != nil {
return nil, err
}
if fi, statErr := os.Stat(destAbs); statErr == nil && fi.IsDir() {
if err := ensureDataPathOffline(destAbs); err != nil {
return nil, err
}
entries, err := os.ReadDir(destAbs)
if err == nil && len(entries) > 0 && !force {
return nil, errors.New("destination data path is not empty; use --force to overwrite")
}
}
parent := filepath.Dir(destAbs)
if err := os.MkdirAll(parent, 0o755); err != nil {
return nil, err
}
stage := filepath.Join(parent, "."+filepath.Base(destAbs)+".restore-"+strconvNowNano())
if err := os.MkdirAll(stage, 0o755); err != nil {
return nil, err
}
cleanupStage := true
defer func() {
if cleanupStage {
_ = os.RemoveAll(stage)
}
}()
if err := extractSnapshotArchive(filePath, stage, manifest); err != nil {
return nil, err
}
if err := syncDir(stage); err != nil {
return nil, err
}
if _, err := os.Stat(destAbs); err == nil {
if !force {
return nil, errors.New("destination data path exists; use --force to overwrite")
}
if err := os.RemoveAll(destAbs); err != nil {
return nil, err
}
}
if err := os.Rename(stage, destAbs); err != nil {
return nil, err
}
if err := syncDir(parent); err != nil {
return nil, err
}
cleanupStage = false
summary.SourcePath = destAbs
return summary, nil
}
func ensureMetadataExists(dataPath string) error {
dbPath := filepath.Join(dataPath, "metadata.db")
info, err := os.Stat(dbPath)
if err != nil {
return fmt.Errorf("metadata.db not found in %s", dataPath)
}
if !info.Mode().IsRegular() {
return fmt.Errorf("metadata.db in %s is not a regular file", dataPath)
}
return nil
}
func ensureDataPathOffline(dataPath string) error {
dbPath := filepath.Join(dataPath, "metadata.db")
if _, err := os.Stat(dbPath); err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
db, err := bolt.Open(dbPath, 0o600, &bolt.Options{
Timeout: 100 * time.Millisecond,
ReadOnly: true,
})
if err != nil {
return fmt.Errorf("data path appears in use (metadata.db locked): %w", err)
}
return db.Close()
}
func buildSnapshotManifest(dataPath string) (*snapshotManifest, int64, error) {
entries := make([]snapshotFileEntry, 0, 128)
var totalBytes int64
err := filepath.WalkDir(dataPath, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if d.IsDir() {
return nil
}
info, err := d.Info()
if err != nil {
return err
}
if !info.Mode().IsRegular() {
return nil
}
rel, err := filepath.Rel(dataPath, path)
if err != nil {
return err
}
rel = filepath.ToSlash(filepath.Clean(rel))
if rel == "." || rel == "" {
return nil
}
sum, err := sha256File(path)
if err != nil {
return err
}
totalBytes += info.Size()
entries = append(entries, snapshotFileEntry{
Path: rel,
Size: info.Size(),
SHA256: sum,
})
return nil
})
if err != nil {
return nil, 0, err
}
if len(entries) == 0 {
return nil, 0, errors.New("data path contains no regular files to snapshot")
}
return &snapshotManifest{
FormatVersion: snapshotFormat,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
SourcePath: dataPath,
Files: entries,
}, totalBytes, nil
}
func writeManifestToTar(tw *tar.Writer, manifest *snapshotManifest) error {
if tw == nil || manifest == nil {
return errors.New("invalid manifest writer input")
}
payload, err := json.Marshal(manifest)
if err != nil {
return err
}
header := &tar.Header{
Name: snapshotManifestPath,
Mode: 0o600,
Size: int64(len(payload)),
ModTime: time.Now(),
}
if err := tw.WriteHeader(header); err != nil {
return err
}
_, err = tw.Write(payload)
return err
}
func writeFileToTar(tw *tar.Writer, absPath, relPath string) error {
file, err := os.Open(absPath)
if err != nil {
return err
}
defer file.Close()
info, err := file.Stat()
if err != nil {
return err
}
header := &tar.Header{
Name: relPath,
Mode: int64(info.Mode().Perm()),
Size: info.Size(),
ModTime: info.ModTime(),
}
if err := tw.WriteHeader(header); err != nil {
return err
}
_, err = io.Copy(tw, file)
return err
}
func readSnapshotArchive(filePath string) (*snapshotManifest, map[string]snapshotFileEntry, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, nil, err
}
defer file.Close()
gzr, err := gzip.NewReader(file)
if err != nil {
return nil, nil, err
}
defer gzr.Close()
tr := tar.NewReader(gzr)
actual := make(map[string]snapshotFileEntry)
var manifest *snapshotManifest
for {
header, err := tr.Next()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, nil, err
}
name, err := cleanArchivePath(header.Name)
if err != nil {
return nil, nil, err
}
if header.Typeflag == tar.TypeDir {
continue
}
if header.Typeflag != tar.TypeReg && header.Typeflag != tar.TypeRegA {
return nil, nil, fmt.Errorf("unsupported tar entry type for %s", name)
}
if name == snapshotManifestPath {
raw, err := io.ReadAll(tr)
if err != nil {
return nil, nil, err
}
current := &snapshotManifest{}
if err := json.Unmarshal(raw, current); err != nil {
return nil, nil, err
}
manifest = current
continue
}
size, hashHex, err := digestReader(tr)
if err != nil {
return nil, nil, err
}
actual[name] = snapshotFileEntry{
Path: name,
Size: size,
SHA256: hashHex,
}
}
if manifest == nil {
return nil, nil, errors.New("snapshot manifest.json not found")
}
if manifest.FormatVersion != snapshotFormat {
return nil, nil, fmt.Errorf("unsupported snapshot format version %d", manifest.FormatVersion)
}
return manifest, actual, nil
}
func extractSnapshotArchive(filePath, destination string, manifest *snapshotManifest) error {
expected := make(map[string]snapshotFileEntry, len(manifest.Files))
for _, entry := range manifest.Files {
expected[entry.Path] = entry
}
seen := make(map[string]struct{}, len(expected))
file, err := os.Open(filePath)
if err != nil {
return err
}
defer file.Close()
gzr, err := gzip.NewReader(file)
if err != nil {
return err
}
defer gzr.Close()
tr := tar.NewReader(gzr)
for {
header, err := tr.Next()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return err
}
name, err := cleanArchivePath(header.Name)
if err != nil {
return err
}
if name == snapshotManifestPath {
if _, err := io.Copy(io.Discard, tr); err != nil {
return err
}
continue
}
if header.Typeflag == tar.TypeDir {
continue
}
if header.Typeflag != tar.TypeReg && header.Typeflag != tar.TypeRegA {
return fmt.Errorf("unsupported tar entry type for %s", name)
}
exp, ok := expected[name]
if !ok {
return fmt.Errorf("snapshot contains unexpected file %s", name)
}
targetPath := filepath.Join(destination, filepath.FromSlash(name))
if !isPathWithin(destination, targetPath) {
return fmt.Errorf("invalid archive path %s", name)
}
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
return err
}
out, err := os.OpenFile(targetPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(header.Mode)&0o777)
if err != nil {
return err
}
hasher := sha256.New()
written, copyErr := io.Copy(io.MultiWriter(out, hasher), tr)
syncErr := out.Sync()
closeErr := out.Close()
if copyErr != nil {
return copyErr
}
if syncErr != nil {
return syncErr
}
if closeErr != nil {
return closeErr
}
if err := syncDir(filepath.Dir(targetPath)); err != nil {
return err
}
sum := hex.EncodeToString(hasher.Sum(nil))
if written != exp.Size || sum != exp.SHA256 {
return fmt.Errorf("checksum mismatch while extracting %s", name)
}
seen[name] = struct{}{}
}
if len(seen) != len(expected) {
return fmt.Errorf("restore validation failed: extracted %d files, expected %d", len(seen), len(expected))
}
for path := range expected {
if _, ok := seen[path]; !ok {
return fmt.Errorf("restore validation failed: missing file %s", path)
}
}
return nil
}
func cleanArchivePath(name string) (string, error) {
name = strings.TrimSpace(name)
if name == "" {
return "", errors.New("empty archive path")
}
name = filepath.ToSlash(filepath.Clean(name))
if strings.HasPrefix(name, "/") || strings.HasPrefix(name, "../") || strings.Contains(name, "/../") || name == ".." {
return "", fmt.Errorf("unsafe archive path %q", name)
}
return name, nil
}
func digestReader(r io.Reader) (int64, string, error) {
hasher := sha256.New()
n, err := io.Copy(hasher, r)
if err != nil {
return 0, "", err
}
return n, hex.EncodeToString(hasher.Sum(nil)), nil
}
func sha256File(path string) (string, error) {
file, err := os.Open(path)
if err != nil {
return "", err
}
defer file.Close()
return sha256FromReader(file)
}
func sha256FromReader(reader io.Reader) (string, error) {
hasher := sha256.New()
if _, err := io.Copy(hasher, reader); err != nil {
return "", err
}
return hex.EncodeToString(hasher.Sum(nil)), nil
}
func isPathWithin(base, candidate string) bool {
base = filepath.Clean(base)
candidate = filepath.Clean(candidate)
rel, err := filepath.Rel(base, candidate)
if err != nil {
return false
}
return rel == "." || (rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)))
}
func syncDir(path string) error {
dir, err := os.Open(path)
if err != nil {
return err
}
defer dir.Close()
return dir.Sync()
}
func strconvNowNano() string {
return fmt.Sprintf("%d", time.Now().UnixNano())
}

240
cmd/admin_snapshot_test.go Normal file
View File

@@ -0,0 +1,240 @@
package cmd
import (
"archive/tar"
"bytes"
"compress/gzip"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
bolt "go.etcd.io/bbolt"
)
type snapshotArchiveEntry struct {
Path string
Data []byte
}
func TestInspectSnapshotArchiveRejectsUnsafePath(t *testing.T) {
t.Parallel()
archive := filepath.Join(t.TempDir(), "bad.tar.gz")
manifest := manifestForEntries([]snapshotArchiveEntry{
{Path: "metadata.db", Data: []byte("db")},
})
err := writeSnapshotArchiveForTest(archive, manifest, []snapshotArchiveEntry{
{Path: "../escape", Data: []byte("oops")},
}, true)
if err != nil {
t.Fatalf("write test archive: %v", err)
}
_, _, err = inspectSnapshotArchive(archive)
if err == nil || !strings.Contains(err.Error(), "unsafe archive path") {
t.Fatalf("expected unsafe archive path error, got %v", err)
}
}
func TestInspectSnapshotArchiveChecksumMismatch(t *testing.T) {
t.Parallel()
archive := filepath.Join(t.TempDir(), "mismatch.tar.gz")
manifest := manifestForEntries([]snapshotArchiveEntry{
{Path: "chunks/c1", Data: []byte("good")},
})
err := writeSnapshotArchiveForTest(archive, manifest, []snapshotArchiveEntry{
{Path: "chunks/c1", Data: []byte("bad")},
}, true)
if err != nil {
t.Fatalf("write test archive: %v", err)
}
_, _, err = inspectSnapshotArchive(archive)
if err == nil || !strings.Contains(err.Error(), "checksum mismatch") {
t.Fatalf("expected checksum mismatch error, got %v", err)
}
}
func TestInspectSnapshotArchiveMissingManifest(t *testing.T) {
t.Parallel()
archive := filepath.Join(t.TempDir(), "no-manifest.tar.gz")
err := writeSnapshotArchiveForTest(archive, nil, []snapshotArchiveEntry{
{Path: "chunks/c1", Data: []byte("x")},
}, false)
if err != nil {
t.Fatalf("write test archive: %v", err)
}
_, _, err = inspectSnapshotArchive(archive)
if err == nil || !strings.Contains(err.Error(), "manifest.json not found") {
t.Fatalf("expected missing manifest error, got %v", err)
}
}
func TestInspectSnapshotArchiveUnsupportedFormat(t *testing.T) {
t.Parallel()
archive := filepath.Join(t.TempDir(), "unsupported-format.tar.gz")
manifest := manifestForEntries([]snapshotArchiveEntry{
{Path: "chunks/c1", Data: []byte("x")},
})
manifest.FormatVersion = 99
err := writeSnapshotArchiveForTest(archive, manifest, []snapshotArchiveEntry{
{Path: "chunks/c1", Data: []byte("x")},
}, true)
if err != nil {
t.Fatalf("write test archive: %v", err)
}
_, _, err = inspectSnapshotArchive(archive)
if err == nil || !strings.Contains(err.Error(), "unsupported snapshot format version") {
t.Fatalf("expected unsupported format error, got %v", err)
}
}
func TestRestoreSnapshotArchiveDestinationBehavior(t *testing.T) {
t.Parallel()
root := t.TempDir()
archive := filepath.Join(root, "ok.tar.gz")
destination := filepath.Join(root, "dst")
entries := []snapshotArchiveEntry{
{Path: "metadata.db", Data: []byte("db-bytes")},
{Path: "chunks/c1", Data: []byte("chunk-1")},
}
manifest := manifestForEntries(entries)
if err := writeSnapshotArchiveForTest(archive, manifest, entries, true); err != nil {
t.Fatalf("write test archive: %v", err)
}
if err := os.MkdirAll(destination, 0o755); err != nil {
t.Fatalf("mkdir destination: %v", err)
}
if err := os.WriteFile(filepath.Join(destination, "old.txt"), []byte("old"), 0o600); err != nil {
t.Fatalf("seed destination: %v", err)
}
if _, err := restoreSnapshotArchive(context.Background(), archive, destination, false); err == nil || !strings.Contains(err.Error(), "not empty") {
t.Fatalf("expected non-empty destination error, got %v", err)
}
if _, err := restoreSnapshotArchive(context.Background(), archive, destination, true); err != nil {
t.Fatalf("restore with force: %v", err)
}
if _, err := os.Stat(filepath.Join(destination, "old.txt")); !os.IsNotExist(err) {
t.Fatalf("expected old file to be removed, stat err=%v", err)
}
got, err := os.ReadFile(filepath.Join(destination, "chunks/c1"))
if err != nil {
t.Fatalf("read restored chunk: %v", err)
}
if string(got) != "chunk-1" {
t.Fatalf("restored chunk mismatch: got %q", string(got))
}
}
func TestCreateSnapshotArchiveRejectsOutputInsideDataPath(t *testing.T) {
t.Parallel()
root := t.TempDir()
if err := os.MkdirAll(filepath.Join(root, "chunks"), 0o755); err != nil {
t.Fatalf("mkdir chunks: %v", err)
}
if err := createBoltDBForTest(filepath.Join(root, "metadata.db")); err != nil {
t.Fatalf("create metadata db: %v", err)
}
if err := os.WriteFile(filepath.Join(root, "chunks/c1"), []byte("x"), 0o600); err != nil {
t.Fatalf("write chunk: %v", err)
}
out := filepath.Join(root, "inside.tar.gz")
if _, err := createSnapshotArchive(context.Background(), root, out); err == nil || !strings.Contains(err.Error(), "cannot be inside") {
t.Fatalf("expected output-inside-data-path error, got %v", err)
}
}
func writeSnapshotArchiveForTest(path string, manifest *snapshotManifest, entries []snapshotArchiveEntry, includeManifest bool) error {
file, err := os.Create(path)
if err != nil {
return err
}
defer file.Close()
gzw := gzip.NewWriter(file)
defer gzw.Close()
tw := tar.NewWriter(gzw)
defer tw.Close()
if includeManifest {
raw, err := json.Marshal(manifest)
if err != nil {
return err
}
if err := writeTarEntry(tw, snapshotManifestPath, raw); err != nil {
return err
}
}
for _, entry := range entries {
if err := writeTarEntry(tw, entry.Path, entry.Data); err != nil {
return err
}
}
return nil
}
func writeTarEntry(tw *tar.Writer, name string, data []byte) error {
header := &tar.Header{
Name: name,
Mode: 0o600,
Size: int64(len(data)),
}
if err := tw.WriteHeader(header); err != nil {
return err
}
_, err := ioCopyBytes(tw, data)
return err
}
func manifestForEntries(entries []snapshotArchiveEntry) *snapshotManifest {
files := make([]snapshotFileEntry, 0, len(entries))
for _, entry := range entries {
sum := sha256.Sum256(entry.Data)
files = append(files, snapshotFileEntry{
Path: filepath.ToSlash(filepath.Clean(entry.Path)),
Size: int64(len(entry.Data)),
SHA256: hex.EncodeToString(sum[:]),
})
}
return &snapshotManifest{
FormatVersion: snapshotFormat,
CreatedAt: "2026-03-11T00:00:00Z",
SourcePath: "/tmp/source",
Files: files,
}
}
func createBoltDBForTest(path string) error {
db, err := bolt.Open(path, 0o600, nil)
if err != nil {
return err
}
defer db.Close()
return db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte("x"))
return err
})
}
func ioCopyBytes(w *tar.Writer, data []byte) (int64, error) {
n, err := bytes.NewReader(data).WriteTo(w)
return n, err
}

View File

@@ -22,6 +22,7 @@ func newAdminUserCommand(opts *adminOptions) *cobra.Command {
cmd.AddCommand(newAdminUserDeleteCommand(opts)) cmd.AddCommand(newAdminUserDeleteCommand(opts))
cmd.AddCommand(newAdminUserSetStatusCommand(opts)) cmd.AddCommand(newAdminUserSetStatusCommand(opts))
cmd.AddCommand(newAdminUserSetRoleCommand(opts)) cmd.AddCommand(newAdminUserSetRoleCommand(opts))
cmd.AddCommand(newAdminUserRemoveRoleCommand(opts))
return cmd return cmd
} }
@@ -253,6 +254,68 @@ func newAdminUserSetRoleCommand(opts *adminOptions) *cobra.Command {
return cmd return cmd
} }
func newAdminUserRemoveRoleCommand(opts *adminOptions) *cobra.Command {
var (
role string
bucket string
prefix string
)
cmd := &cobra.Command{
Use: "remove-role <access-key-id>",
Short: "Remove one role policy statement from user",
Args: requireAccessKeyArg("fs admin user remove-role <access-key-id> --role admin|readwrite|readonly [--bucket <name>] [--prefix <path>]"),
RunE: func(cmd *cobra.Command, args []string) error {
policy, err := buildPolicyFromRole(rolePolicyOptions{
Role: role,
Bucket: bucket,
Prefix: prefix,
})
if err != nil {
return usageError("fs admin user remove-role <access-key-id> --role admin|readwrite|readonly [--bucket <name>] [--prefix <path>]", err.Error())
}
if len(policy.Statements) == 0 {
return usageError("fs admin user remove-role <access-key-id> --role admin|readwrite|readonly [--bucket <name>] [--prefix <path>]", "no statement to remove")
}
client, err := newAdminAPIClient(opts, true)
if err != nil {
return err
}
existing, err := client.GetUser(context.Background(), args[0])
if err != nil {
return err
}
if existing.Policy == nil || len(existing.Policy.Statements) == 0 {
return fmt.Errorf("user %q has no policy statements", args[0])
}
target := policy.Statements[0]
nextPolicy, removed := removePolicyStatements(existing.Policy, target)
if removed == 0 {
return fmt.Errorf("no matching statement found for role=%s bucket=%s prefix=%s", role, bucket, prefix)
}
if len(nextPolicy.Statements) == 0 {
return fmt.Errorf("cannot remove the last policy statement; add another role first or use set-role --replace")
}
out, err := client.SetUserPolicy(context.Background(), args[0], nextPolicy)
if err != nil {
return err
}
if opts.JSON {
return writeJSON(cmd.OutOrStdout(), out)
}
return writeUserTable(cmd.OutOrStdout(), out, false)
},
}
cmd.Flags().StringVar(&role, "role", "readwrite", "Role: admin|readwrite|readonly")
cmd.Flags().StringVar(&bucket, "bucket", "*", "Bucket scope, defaults to *")
cmd.Flags().StringVar(&prefix, "prefix", "*", "Prefix scope, defaults to *")
return cmd
}
func mergePolicyStatements(existing *adminPolicy, addition adminPolicy) adminPolicy { func mergePolicyStatements(existing *adminPolicy, addition adminPolicy) adminPolicy {
merged := adminPolicy{} merged := adminPolicy{}
if existing != nil { if existing != nil {
@@ -289,6 +352,25 @@ func policyStatementsEqual(a, b adminPolicyStatement) bool {
return true return true
} }
func removePolicyStatements(existing *adminPolicy, target adminPolicyStatement) (adminPolicy, int) {
out := adminPolicy{}
if existing == nil {
return out, 0
}
out.Principal = existing.Principal
out.Statements = make([]adminPolicyStatement, 0, len(existing.Statements))
removed := 0
for _, stmt := range existing.Statements {
if policyStatementsEqual(stmt, target) {
removed++
continue
}
out.Statements = append(out.Statements, stmt)
}
return out, removed
}
func requireAccessKeyArg(usage string) cobra.PositionalArgs { func requireAccessKeyArg(usage string) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error { return func(cmd *cobra.Command, args []string) error {
if len(args) != 1 { if len(args) != 1 {

View File

@@ -32,12 +32,15 @@ This project is S3-compatible for a focused subset of operations.
### Authentication ### Authentication
- AWS SigV4 header auth - AWS SigV4 header auth
- AWS SigV4 presigned query auth - AWS SigV4 presigned query auth
- `aws-chunked` payload decode for streaming uploads - `aws-chunked` payload decode for unsigned streaming upload modes
- SigV4 payload hash verification for fixed-size signed payloads
## Partially Implemented / Differences ## Partially Implemented / Differences
- Exact parity with AWS S3 error codes/headers is still evolving. - Exact parity with AWS S3 error codes/headers is still evolving.
- Some S3 edge-case behaviors may differ (especially uncommon query/header combinations). - Some S3 edge-case behaviors may differ (especially uncommon query/header combinations).
- Admin API is custom JSON (`/_admin/v1/*`). - Admin API is custom JSON (`/_admin/v1/*`).
- Object and upload-part payloads are limited by `FS_MAX_OBJECT_UPLOAD_BYTES` (default 5 GiB).
- Signed `aws-chunked` payload modes that require per-chunk signature verification are rejected until chunk-signature validation is implemented.
## Not Implemented (Current) ## Not Implemented (Current)
- Bucket versioning - Bucket versioning

View File

@@ -902,9 +902,6 @@ func (h *MetadataHandler) CleanupMultipartUploads(retention time.Duration) (int,
if err := json.Unmarshal(v, &upload); err != nil { if err := json.Unmarshal(v, &upload); err != nil {
return err return err
} }
if upload.State == "pending" {
return nil
}
createdAt, err := time.Parse(time.RFC3339, upload.CreatedAt) createdAt, err := time.Parse(time.RFC3339, upload.CreatedAt)
if err != nil { if err != nil {
return nil return nil

99
metadata/metadata_test.go Normal file
View File

@@ -0,0 +1,99 @@
package metadata
import (
"errors"
"fs/models"
"path/filepath"
"testing"
"time"
"go.etcd.io/bbolt"
)
func TestCleanupMultipartUploadsDeletesExpiredPendingUpload(t *testing.T) {
h := newTestMetadataHandler(t)
if err := h.CreateBucket("test-bucket"); err != nil {
t.Fatalf("CreateBucket: %v", err)
}
upload, err := h.CreateMultipartUpload("test-bucket", "object.txt")
if err != nil {
t.Fatalf("CreateMultipartUpload: %v", err)
}
if err := h.PutMultipartPart(upload.UploadID, models.UploadedPart{PartNumber: 1, ETag: "etag", Size: 4, Chunks: []string{"chunk-id"}}); err != nil {
t.Fatalf("PutMultipartPart: %v", err)
}
setMultipartUploadCreatedAt(t, h, upload.UploadID, time.Now().Add(-2*time.Hour))
cleaned, err := h.CleanupMultipartUploads(time.Hour)
if err != nil {
t.Fatalf("CleanupMultipartUploads: %v", err)
}
if cleaned != 1 {
t.Fatalf("cleaned = %d, want 1", cleaned)
}
if _, err := h.GetMultipartUpload(upload.UploadID); !errors.Is(err, ErrMultipartNotFound) {
t.Fatalf("GetMultipartUpload error = %v, want ErrMultipartNotFound", err)
}
if _, err := h.ListMultipartParts(upload.UploadID); !errors.Is(err, ErrMultipartNotFound) {
t.Fatalf("ListMultipartParts error = %v, want ErrMultipartNotFound", err)
}
}
func TestCleanupMultipartUploadsKeepsRecentPendingUpload(t *testing.T) {
h := newTestMetadataHandler(t)
if err := h.CreateBucket("test-bucket"); err != nil {
t.Fatalf("CreateBucket: %v", err)
}
upload, err := h.CreateMultipartUpload("test-bucket", "object.txt")
if err != nil {
t.Fatalf("CreateMultipartUpload: %v", err)
}
cleaned, err := h.CleanupMultipartUploads(time.Hour)
if err != nil {
t.Fatalf("CleanupMultipartUploads: %v", err)
}
if cleaned != 0 {
t.Fatalf("cleaned = %d, want 0", cleaned)
}
if _, err := h.GetMultipartUpload(upload.UploadID); err != nil {
t.Fatalf("recent upload should remain: %v", err)
}
}
func TestCleanupMultipartUploadsDisabledForNonPositiveRetention(t *testing.T) {
h := newTestMetadataHandler(t)
cleaned, err := h.CleanupMultipartUploads(0)
if err != nil {
t.Fatalf("CleanupMultipartUploads: %v", err)
}
if cleaned != 0 {
t.Fatalf("cleaned = %d, want 0", cleaned)
}
}
func newTestMetadataHandler(t *testing.T) *MetadataHandler {
t.Helper()
h, err := NewMetadataHandler(filepath.Join(t.TempDir(), "metadata.db"))
if err != nil {
t.Fatalf("NewMetadataHandler: %v", err)
}
t.Cleanup(func() {
_ = h.Close()
})
return h
}
func setMultipartUploadCreatedAt(t *testing.T, h *MetadataHandler, uploadID string, createdAt time.Time) {
t.Helper()
if err := h.update(func(tx *bbolt.Tx) error {
upload, uploadsBucket, err := getMultipartUploadFromTx(tx, uploadID)
if err != nil {
return err
}
upload.CreatedAt = createdAt.UTC().Format(time.RFC3339)
return putMultipartUpload(uploadsBucket, uploadID, upload)
}); err != nil {
t.Fatalf("set multipart created_at: %v", err)
}
}

View File

@@ -158,6 +158,13 @@ type CompleteMultipartUploadResult struct {
Location string `xml:"Location,omitempty"` Location string `xml:"Location,omitempty"`
} }
type CopyObjectResult struct {
XMLName xml.Name `xml:"CopyObjectResult"`
Xmlns string `xml:"xmlns,attr,omitempty"`
LastModified string `xml:"LastModified"`
ETag string `xml:"ETag"`
}
type ListPartsResult struct { type ListPartsResult struct {
XMLName xml.Name `xml:"ListPartsResult"` XMLName xml.Name `xml:"ListPartsResult"`
Xmlns string `xml:"xmlns,attr"` Xmlns string `xml:"xmlns,attr"`

View File

@@ -21,6 +21,7 @@ type ObjectService struct {
metadata *metadata.MetadataHandler metadata *metadata.MetadataHandler
blob *storage.BlobStore blob *storage.BlobStore
multipartRetention time.Duration multipartRetention time.Duration
maxUploadSize int64
gcMu sync.RWMutex gcMu sync.RWMutex
} }
@@ -29,16 +30,24 @@ var (
ErrInvalidPartOrder = errors.New("invalid multipart part order") ErrInvalidPartOrder = errors.New("invalid multipart part order")
ErrInvalidCompleteRequest = errors.New("invalid complete multipart request") ErrInvalidCompleteRequest = errors.New("invalid complete multipart request")
ErrEntityTooSmall = errors.New("multipart entity too small") ErrEntityTooSmall = errors.New("multipart entity too small")
ErrEntityTooLarge = errors.New("entity too large")
) )
func NewObjectService(metadataHandler *metadata.MetadataHandler, blobHandler *storage.BlobStore, multipartRetention time.Duration) *ObjectService { const DefaultMaxUploadSize int64 = 5 * 1024 * 1024 * 1024
func NewObjectService(metadataHandler *metadata.MetadataHandler, blobHandler *storage.BlobStore, multipartRetention time.Duration, maxUploadSize ...int64) *ObjectService {
if multipartRetention <= 0 { if multipartRetention <= 0 {
multipartRetention = 24 * time.Hour multipartRetention = 24 * time.Hour
} }
limit := DefaultMaxUploadSize
if len(maxUploadSize) > 0 {
limit = maxUploadSize[0]
}
return &ObjectService{ return &ObjectService{
metadata: metadataHandler, metadata: metadataHandler,
blob: blobHandler, blob: blobHandler,
multipartRetention: multipartRetention, multipartRetention: multipartRetention,
maxUploadSize: limit,
} }
} }
@@ -74,7 +83,7 @@ func (s *ObjectService) PutObject(bucket, key, contentType string, input io.Read
unlock := s.acquireGCRLock() unlock := s.acquireGCRLock()
defer unlock() defer unlock()
chunks, size, etag, err := s.blob.IngestStream(input) chunks, size, etag, err := s.blob.IngestStream(s.limitUpload(input))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -104,6 +113,38 @@ func (s *ObjectService) PutObject(bucket, key, contentType string, input io.Read
return manifest, nil return manifest, nil
} }
func (s *ObjectService) CopyObject(srcBucket, srcKey, dstBucket, dstKey string) (*models.ObjectManifest, error) {
start := time.Now()
success := false
defer func() {
metrics.Default.ObserveService("copy_object", time.Since(start), success)
}()
unlock := s.acquireGCRLock()
defer unlock()
source, err := s.metadata.GetManifest(srcBucket, srcKey)
if err != nil {
return nil, err
}
manifest := &models.ObjectManifest{
Bucket: dstBucket,
Key: dstKey,
Size: source.Size,
ContentType: source.ContentType,
ETag: source.ETag,
Chunks: append([]string(nil), source.Chunks...),
CreatedAt: time.Now().Unix(),
}
if err := s.metadata.PutManifest(manifest); err != nil {
return nil, err
}
success = true
return manifest, nil
}
func (s *ObjectService) GetObject(bucket, key string) (io.ReadCloser, *models.ObjectManifest, error) { func (s *ObjectService) GetObject(bucket, key string) (io.ReadCloser, *models.ObjectManifest, error) {
start := time.Now() start := time.Now()
@@ -126,7 +167,9 @@ func (s *ObjectService) GetObject(bucket, key string) (io.ReadCloser, *models.Ob
defer func() { defer func() {
metrics.Default.ObserveService("get_object", time.Since(start), streamOK) metrics.Default.ObserveService("get_object", time.Since(start), streamOK)
}() }()
defer metrics.Default.ObserveLockHold("gc_mu_read", time.Since(holdStart)) defer func() {
metrics.Default.ObserveLockHold("gc_mu_read", time.Since(holdStart))
}()
defer s.gcMu.RUnlock() defer s.gcMu.RUnlock()
if err := s.blob.AssembleStream(manifest.Chunks, pw); err != nil { if err := s.blob.AssembleStream(manifest.Chunks, pw); err != nil {
_ = pw.CloseWithError(err) _ = pw.CloseWithError(err)
@@ -279,7 +322,7 @@ func (s *ObjectService) UploadPart(bucket, key, uploadId string, partNumber int,
} }
var uploadedPart models.UploadedPart var uploadedPart models.UploadedPart
chunkIds, totalSize, etag, err := s.blob.IngestStream(input) chunkIds, totalSize, etag, err := s.blob.IngestStream(s.limitUpload(input))
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -368,6 +411,9 @@ func (s *ObjectService) CompleteMultipartUpload(bucket, key, uploadID string, co
orderedParts = append(orderedParts, storedPart) orderedParts = append(orderedParts, storedPart)
chunks = append(chunks, storedPart.Chunks...) chunks = append(chunks, storedPart.Chunks...)
totalSize += storedPart.Size totalSize += storedPart.Size
if s.maxUploadSize > 0 && totalSize > s.maxUploadSize {
return nil, ErrEntityTooLarge
}
} }
finalETag := buildMultipartETag(orderedParts) finalETag := buildMultipartETag(orderedParts)
@@ -403,6 +449,40 @@ func (s *ObjectService) AbortMultipartUpload(bucket, key, uploadID string) error
return s.metadata.AbortMultipartUpload(uploadID) return s.metadata.AbortMultipartUpload(uploadID)
} }
func (s *ObjectService) limitUpload(input io.Reader) io.Reader {
if s.maxUploadSize <= 0 || input == nil {
return input
}
return &maxBytesReader{inner: input, remaining: s.maxUploadSize}
}
type maxBytesReader struct {
inner io.Reader
remaining int64
tooLarge bool
}
func (r *maxBytesReader) Read(p []byte) (int, error) {
if r.tooLarge {
return 0, ErrEntityTooLarge
}
if r.remaining <= 0 {
var probe [1]byte
n, err := r.inner.Read(probe[:])
if n > 0 {
r.tooLarge = true
return 0, ErrEntityTooLarge
}
return 0, err
}
if int64(len(p)) > r.remaining {
p = p[:r.remaining]
}
n, err := r.inner.Read(p)
r.remaining -= int64(n)
return n, err
}
func normalizeETag(etag string) string { func normalizeETag(etag string) string {
return strings.Trim(etag, "\"") return strings.Trim(etag, "\"")
} }
@@ -437,6 +517,12 @@ func (s *ObjectService) GarbageCollect() error {
unlock := s.acquireGCLock() unlock := s.acquireGCLock()
defer unlock() defer unlock()
var err error
cleanedUploads, err = s.metadata.CleanupMultipartUploads(s.multipartRetention)
if err != nil {
return err
}
referencedChunkSet, err := s.metadata.GetReferencedChunkSet() referencedChunkSet, err := s.metadata.GetReferencedChunkSet()
if err != nil { if err != nil {
return err return err
@@ -460,11 +546,6 @@ func (s *ObjectService) GarbageCollect() error {
return err return err
} }
cleanedUploads, err = s.metadata.CleanupMultipartUploads(s.multipartRetention)
if err != nil {
return err
}
slog.Info("garbage_collect_completed", slog.Info("garbage_collect_completed",
"referenced_chunks", len(referencedChunkSet), "referenced_chunks", len(referencedChunkSet),
"total_chunks", totalChunks, "total_chunks", totalChunks,

View File

@@ -0,0 +1,119 @@
package service
import (
"errors"
"fs/metadata"
"fs/storage"
"path/filepath"
"strings"
"testing"
"time"
)
func TestPutObjectRejectsOversizedUpload(t *testing.T) {
svc := newTestObjectService(t, 4)
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("CreateBucket: %v", err)
}
_, err := svc.PutObject("test-bucket", "too-large.txt", "text/plain", strings.NewReader("12345"))
if !errors.Is(err, ErrEntityTooLarge) {
t.Fatalf("PutObject error = %v, want ErrEntityTooLarge", err)
}
if _, err := svc.HeadObject("test-bucket", "too-large.txt"); !errors.Is(err, metadata.ErrObjectNotFound) {
t.Fatalf("HeadObject error = %v, want ErrObjectNotFound", err)
}
}
func TestPutObjectAllowsExactUploadLimit(t *testing.T) {
svc := newTestObjectService(t, 4)
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("CreateBucket: %v", err)
}
manifest, err := svc.PutObject("test-bucket", "exact.txt", "text/plain", strings.NewReader("1234"))
if err != nil {
t.Fatalf("PutObject: %v", err)
}
if manifest.Size != 4 {
t.Fatalf("manifest size = %d, want 4", manifest.Size)
}
}
func TestUploadPartRejectsOversizedUpload(t *testing.T) {
svc := newTestObjectService(t, 4)
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("CreateBucket: %v", err)
}
upload, err := svc.CreateMultipartUpload("test-bucket", "object.txt")
if err != nil {
t.Fatalf("CreateMultipartUpload: %v", err)
}
_, err = svc.UploadPart("test-bucket", "object.txt", upload.UploadID, 1, strings.NewReader("12345"))
if !errors.Is(err, ErrEntityTooLarge) {
t.Fatalf("UploadPart error = %v, want ErrEntityTooLarge", err)
}
parts, err := svc.ListMultipartParts("test-bucket", "object.txt", upload.UploadID)
if err != nil {
t.Fatalf("ListMultipartParts: %v", err)
}
if len(parts) != 0 {
t.Fatalf("stored parts = %d, want 0", len(parts))
}
}
func TestGarbageCollectRemovesExpiredPendingMultipartChunks(t *testing.T) {
svc := newTestObjectService(t, 1024)
svc.multipartRetention = time.Nanosecond
if err := svc.CreateBucket("test-bucket"); err != nil {
t.Fatalf("CreateBucket: %v", err)
}
upload, err := svc.CreateMultipartUpload("test-bucket", "object.txt")
if err != nil {
t.Fatalf("CreateMultipartUpload: %v", err)
}
if _, err := svc.UploadPart("test-bucket", "object.txt", upload.UploadID, 1, strings.NewReader("part-data")); err != nil {
t.Fatalf("UploadPart: %v", err)
}
chunks, err := svc.blob.ListChunks()
if err != nil {
t.Fatalf("ListChunks before GC: %v", err)
}
if len(chunks) == 0 {
t.Fatalf("expected uploaded part chunks")
}
time.Sleep(time.Millisecond)
if err := svc.GarbageCollect(); err != nil {
t.Fatalf("GarbageCollect: %v", err)
}
if _, err := svc.metadata.GetMultipartUpload(upload.UploadID); !errors.Is(err, metadata.ErrMultipartNotFound) {
t.Fatalf("GetMultipartUpload error = %v, want ErrMultipartNotFound", err)
}
chunks, err = svc.blob.ListChunks()
if err != nil {
t.Fatalf("ListChunks after GC: %v", err)
}
if len(chunks) != 0 {
t.Fatalf("chunks after GC = %d, want 0", len(chunks))
}
}
func newTestObjectService(t *testing.T, maxUploadSize int64) *ObjectService {
t.Helper()
root := t.TempDir()
md, err := metadata.NewMetadataHandler(filepath.Join(root, "metadata.db"))
if err != nil {
t.Fatalf("NewMetadataHandler: %v", err)
}
blob, err := storage.NewBlobStore(root, 4)
if err != nil {
t.Fatalf("NewBlobStore: %v", err)
}
svc := NewObjectService(md, blob, time.Hour, maxUploadSize)
t.Cleanup(func() {
_ = svc.Close()
})
return svc
}

View File

@@ -17,6 +17,8 @@ import (
const blobRoot = "blobs" const blobRoot = "blobs"
const maxChunkSize = 64 * 1024 * 1024 const maxChunkSize = 64 * 1024 * 1024
var ErrChunkIntegrity = errors.New("chunk integrity check failed")
type BlobStore struct { type BlobStore struct {
dataRoot string dataRoot string
chunkSize int chunkSize int
@@ -185,6 +187,11 @@ func (bs *BlobStore) GetBlob(chunkID string) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
chunkHash := sha256.Sum256(data)
actualChunkID := hex.EncodeToString(chunkHash[:])
if actualChunkID != chunkID {
return nil, fmt.Errorf("%w: expected %s, got %s", ErrChunkIntegrity, chunkID, actualChunkID)
}
size = int64(len(data)) size = int64(len(data))
success = true success = true
return data, nil return data, nil

79
storage/blob_test.go Normal file
View File

@@ -0,0 +1,79 @@
package storage
import (
"errors"
"io"
"os"
"path/filepath"
"strings"
"testing"
)
func TestGetBlobDetectsCorruptedChunk(t *testing.T) {
root := t.TempDir()
bs, err := NewBlobStore(root, 4)
if err != nil {
t.Fatalf("new blob store: %v", err)
}
chunks, _, _, err := bs.IngestStream(strings.NewReader("good"))
if err != nil {
t.Fatalf("ingest: %v", err)
}
chunkID := chunks[0]
corruptChunk(t, root, chunkID, []byte("bad"))
got, err := bs.GetBlob(chunkID)
if !errors.Is(err, ErrChunkIntegrity) {
t.Fatalf("GetBlob error = %v, want ErrChunkIntegrity", err)
}
if got != nil {
t.Fatalf("GetBlob returned data for corrupted chunk: %q", got)
}
}
func TestAssembleStreamDetectsCorruptedChunk(t *testing.T) {
root := t.TempDir()
bs, err := NewBlobStore(root, 4)
if err != nil {
t.Fatalf("new blob store: %v", err)
}
chunks, _, _, err := bs.IngestStream(strings.NewReader("abcdefgh"))
if err != nil {
t.Fatalf("ingest: %v", err)
}
if len(chunks) != 2 {
t.Fatalf("chunk count = %d, want 2", len(chunks))
}
corruptChunk(t, root, chunks[1], []byte("corrupt"))
pr, pw := io.Pipe()
errCh := make(chan error, 1)
go func() {
err := bs.AssembleStream(chunks, pw)
if err != nil {
_ = pw.CloseWithError(err)
} else {
_ = pw.Close()
}
errCh <- err
}()
_, readErr := io.ReadAll(pr)
assembleErr := <-errCh
if !errors.Is(assembleErr, ErrChunkIntegrity) {
t.Fatalf("AssembleStream error = %v, want ErrChunkIntegrity", assembleErr)
}
if !errors.Is(readErr, ErrChunkIntegrity) {
t.Fatalf("pipe read error = %v, want ErrChunkIntegrity", readErr)
}
}
func corruptChunk(t *testing.T, root, chunkID string, data []byte) {
t.Helper()
path := filepath.Join(root, blobRoot, chunkID[:2], chunkID[2:4], chunkID)
if err := os.WriteFile(path, data, 0o600); err != nil {
t.Fatalf("corrupt chunk: %v", err)
}
}

View File

@@ -15,6 +15,7 @@ type Config struct {
Address string Address string
Port int Port int
ChunkSize int ChunkSize int
MaxObjectUploadBytes int64
LogLevel string LogLevel string
LogFormat string LogFormat string
AuditLog bool AuditLog bool
@@ -36,15 +37,16 @@ func NewConfig() *Config {
_ = godotenv.Load() _ = godotenv.Load()
config := &Config{ config := &Config{
DataPath: sanitizeDataPath(os.Getenv("DATA_PATH")), DataPath: sanitizeDataPath(os.Getenv("DATA_PATH")),
Address: firstNonEmpty(strings.TrimSpace(os.Getenv("ADDRESS")), "0.0.0.0"), Address: firstNonEmpty(strings.TrimSpace(os.Getenv("ADDRESS")), "0.0.0.0"),
Port: envIntRange("PORT", 3000, 1, 65535), Port: envIntRange("PORT", 2600, 1, 65535),
ChunkSize: envIntRange("CHUNK_SIZE", 8192000, 1, 64*1024*1024), ChunkSize: envIntRange("CHUNK_SIZE", 8192000, 1, 64*1024*1024),
LogLevel: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_LEVEL")), "info")), MaxObjectUploadBytes: envInt64Range("FS_MAX_OBJECT_UPLOAD_BYTES", 5*1024*1024*1024, 1, 5*1024*1024*1024),
LogFormat: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_FORMAT")), strings.TrimSpace(os.Getenv("LOG_TYPE")), "text")), LogLevel: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_LEVEL")), "info")),
AuditLog: envBool("AUDIT_LOG", true), LogFormat: strings.ToLower(firstNonEmpty(strings.TrimSpace(os.Getenv("LOG_FORMAT")), strings.TrimSpace(os.Getenv("LOG_TYPE")), "text")),
GcInterval: time.Duration(envIntRange("GC_INTERVAL", 10, 1, 60)) * time.Minute, AuditLog: envBool("AUDIT_LOG", true),
GcEnabled: envBool("GC_ENABLED", true), GcInterval: time.Duration(envIntRange("GC_INTERVAL", 10, 1, 60)) * time.Minute,
GcEnabled: envBool("GC_ENABLED", true),
MultipartCleanupRetention: time.Duration( MultipartCleanupRetention: time.Duration(
envIntRange("MULTIPART_RETENTION_HOURS", 24, 1, 24*30), envIntRange("MULTIPART_RETENTION_HOURS", 24, 1, 24*30),
) * time.Hour, ) * time.Hour,
@@ -82,6 +84,21 @@ func envIntRange(key string, defaultValue, minValue, maxValue int) int {
return value return value
} }
func envInt64Range(key string, defaultValue, minValue, maxValue int64) int64 {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
return defaultValue
}
value, err := strconv.ParseInt(raw, 10, 64)
if err != nil {
return defaultValue
}
if value < minValue || value > maxValue {
return defaultValue
}
return value
}
func envBool(key string, defaultValue bool) bool { func envBool(key string, defaultValue bool) bool {
raw := strings.TrimSpace(os.Getenv(key)) raw := strings.TrimSpace(os.Getenv(key))
if raw == "" { if raw == "" {

21
utils/config_test.go Normal file
View File

@@ -0,0 +1,21 @@
package utils
import "testing"
func TestEnvInt64Range(t *testing.T) {
t.Setenv("TEST_INT64_RANGE", "42")
if got := envInt64Range("TEST_INT64_RANGE", 10, 1, 100); got != 42 {
t.Fatalf("envInt64Range valid = %d, want 42", got)
}
}
func TestEnvInt64RangeFallsBackForInvalidValues(t *testing.T) {
t.Setenv("TEST_INT64_RANGE", "invalid")
if got := envInt64Range("TEST_INT64_RANGE", 10, 1, 100); got != 10 {
t.Fatalf("envInt64Range invalid = %d, want 10", got)
}
t.Setenv("TEST_INT64_RANGE", "101")
if got := envInt64Range("TEST_INT64_RANGE", 10, 1, 100); got != 10 {
t.Fatalf("envInt64Range too large = %d, want 10", got)
}
}