package server import ( "context" "fmt" "io" "log" "net/http" "os/exec" "strings" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/mistifyio/go-zfs" ) // StorageBackend defines the interface for different storage types type StorageBackend interface { Upload(ctx context.Context, key string, data io.Reader, size int64) error Download(ctx context.Context, key string) (io.ReadCloser, error) Delete(ctx context.Context, key string) error List(ctx context.Context, prefix string) ([]string, error) GetSize(ctx context.Context, key string) (int64, error) } // S3Backend implements StorageBackend for S3-compatible storage using AWS SDK v2 type S3Backend struct { client *s3.Client bucketName string } // NewS3Backend creates a new S3 storage backend func NewS3Backend(endpoint, accessKey, secretKey, bucketName string, useSSL bool, region string) (*S3Backend, error) { // Ensure endpoint has valid URI scheme if endpoint != "" && !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") { if useSSL { endpoint = "https://" + endpoint } else { endpoint = "http://" + endpoint } } // Determine if using custom endpoint (non-AWS) customEndpoint := endpoint != "" && endpoint != "https://s3.amazonaws.com" && endpoint != "http://s3.amazonaws.com" // Load AWS config awsCfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region), config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")), ) if err != nil { return nil, fmt.Errorf("failed to load AWS config: %v", err) } // Create S3 client s3Client := s3.NewFromConfig(awsCfg, func(o *s3.Options) { if customEndpoint { o.BaseEndpoint = aws.String(endpoint) o.UsePathStyle = true // Required for MinIO and other S3-compatible storage } // Set HTTP client with extended timeout for large uploads o.HTTPClient = &http.Client{ Timeout: 0, // No timeout for large file uploads } }) // Check if bucket exists (or create it for AWS S3) ctx := context.Background() _, err = s3Client.HeadBucket(ctx, &s3.HeadBucketInput{ Bucket: aws.String(bucketName), }) if err != nil { // Try to create bucket _, err = s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ Bucket: aws.String(bucketName), }) if err != nil { log.Printf("Warning: failed to create bucket: %v", err) } else { log.Printf("Created S3 bucket: %s", bucketName) } } return &S3Backend{ client: s3Client, bucketName: bucketName, }, nil } // Upload uploads data to S3 func (s *S3Backend) Upload(ctx context.Context, key string, data io.Reader, size int64) error { _, err := s.client.PutObject(ctx, &s3.PutObjectInput{ Bucket: aws.String(s.bucketName), Key: aws.String(key), Body: data, ContentType: aws.String("application/octet-stream"), }) return err } // Download retrieves data from S3 func (s *S3Backend) Download(ctx context.Context, key string) (io.ReadCloser, error) { resp, err := s.client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(s.bucketName), Key: aws.String(key), }) if err != nil { return nil, err } return resp.Body, nil } // Delete removes an object from S3 func (s *S3Backend) Delete(ctx context.Context, key string) error { _, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(s.bucketName), Key: aws.String(key), }) return err } // List returns all objects with the given prefix func (s *S3Backend) List(ctx context.Context, prefix string) ([]string, error) { var keys []string paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{ Bucket: aws.String(s.bucketName), Prefix: aws.String(prefix), }) for paginator.HasMorePages() { page, err := paginator.NextPage(ctx) if err != nil { return nil, err } for _, obj := range page.Contents { keys = append(keys, *obj.Key) } } return keys, nil } // GetSize returns the size of an object in S3 func (s *S3Backend) GetSize(ctx context.Context, key string) (int64, error) { info, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{ Bucket: aws.String(s.bucketName), Key: aws.String(key), }) if err != nil { return 0, err } return *info.ContentLength, nil } // LocalBackend implements StorageBackend for local ZFS storage type LocalBackend struct { baseDataset string } // NewLocalBackend creates a new local ZFS storage backend func NewLocalBackend(baseDataset string) *LocalBackend { return &LocalBackend{baseDataset: baseDataset} } // Upload is not supported for local backend func (l *LocalBackend) Upload(ctx context.Context, key string, data io.Reader, size int64) error { return fmt.Errorf("local backend upload not supported via storage interface, use zfs receive endpoint") } // Download creates a zfs send stream func (l *LocalBackend) Download(ctx context.Context, key string) (io.ReadCloser, error) { cmd := exec.CommandContext(ctx, "zfs", "send", key) stdout, err := cmd.StdoutPipe() if err != nil { return nil, err } if err := cmd.Start(); err != nil { return nil, err } return &cmdReadCloser{stdout: stdout, cmd: cmd}, nil } // Delete destroys a ZFS dataset func (l *LocalBackend) Delete(ctx context.Context, key string) error { ds, err := zfs.GetDataset(key) if err != nil { return err } return ds.Destroy(zfs.DestroyDefault) } // List returns all snapshots with the given prefix func (l *LocalBackend) List(ctx context.Context, prefix string) ([]string, error) { snapshots, err := zfs.Snapshots(prefix) if err != nil { return nil, err } var names []string for _, snap := range snapshots { names = append(names, snap.Name) } return names, nil } // GetSize returns the used size of a ZFS dataset func (l *LocalBackend) GetSize(ctx context.Context, key string) (int64, error) { ds, err := zfs.GetDataset(key) if err != nil { return 0, err } return int64(ds.Used), nil } // cmdReadCloser wraps stdout pipe to properly wait for command completion type cmdReadCloser struct { stdout io.ReadCloser cmd *exec.Cmd closed bool } func (c *cmdReadCloser) Read(p []byte) (int, error) { return c.stdout.Read(p) } func (c *cmdReadCloser) Close() error { if c.closed { return nil } c.closed = true err := c.stdout.Close() waitErr := c.cmd.Wait() if err != nil { return err } return waitErr }