// NOTE(review): removed pasted viewer metadata ("714 lines / 19 KiB / Go",
// duplicated) that preceded the package clause and broke compilation.
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/dutchcoders/go-clamd"
|
|
"github.com/go-redis/redis/v8"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
"github.com/shirou/gopsutil/cpu"
|
|
"github.com/shirou/gopsutil/mem"
|
|
"gopkg.in/natefinch/lumberjack.v2"
|
|
)
|
|
|
|
// WorkerPool represents a pool of workers that consume upload and scan
// tasks from bounded queues until the pool context is cancelled.
type WorkerPool struct {
	workers   int                // number of worker goroutines launched by Start
	taskQueue chan UploadTask    // buffered queue of upload tasks
	scanQueue chan ScanTask      // buffered queue of scan tasks
	ctx       context.Context    // cancelled to signal workers to exit
	cancel    context.CancelFunc // cancels ctx; invoked by Stop
}
// NewWorkerPool creates a new worker pool
|
|
func NewWorkerPool(workers int, queueSize int) *WorkerPool {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
return &WorkerPool{
|
|
workers: workers,
|
|
taskQueue: make(chan UploadTask, queueSize),
|
|
scanQueue: make(chan ScanTask, queueSize),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
|
|
// Start starts the worker pool
|
|
func (wp *WorkerPool) Start() {
|
|
for i := 0; i < wp.workers; i++ {
|
|
go wp.worker()
|
|
}
|
|
}
|
|
|
|
// Stop stops the worker pool: it cancels the pool context so workers can
// observe ctx.Done(), then closes both task queues.
//
// NOTE(review): closing the queues assumes no producer is still sending —
// a send on a closed channel panics. Confirm all enqueue paths have
// quiesced before Stop is called.
func (wp *WorkerPool) Stop() {
	wp.cancel()
	close(wp.taskQueue)
	close(wp.scanQueue)
}
// worker is the worker function
|
|
func (wp *WorkerPool) worker() {
|
|
for {
|
|
select {
|
|
case <-wp.ctx.Done():
|
|
return
|
|
case task := <-wp.taskQueue:
|
|
if task.Result != nil {
|
|
task.Result <- nil // Simple implementation
|
|
}
|
|
case scanTask := <-wp.scanQueue:
|
|
err := processScan(scanTask)
|
|
if scanTask.Result != nil {
|
|
scanTask.Result <- err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stub for precacheStoragePath
|
|
func precacheStoragePath(storagePath string) error {
|
|
// TODO: Implement actual pre-caching logic
|
|
// This would typically involve walking the storagePath
|
|
// and loading file information into a cache.
|
|
log.Infof("Pre-caching for storage path '%s' is a stub and not yet implemented.", storagePath)
|
|
return nil
|
|
}
|
|
|
|
func checkFreeSpaceWithRetry(path string, retries int, delay time.Duration) error {
|
|
for i := 0; i < retries; i++ {
|
|
minFreeBytes, err := parseSize(conf.Server.MinFreeBytes)
|
|
if err != nil {
|
|
log.Fatalf("Invalid MinFreeBytes: %v", err)
|
|
}
|
|
if err := checkStorageSpace(path, minFreeBytes); err != nil {
|
|
log.Warnf("Free space check failed (attempt %d/%d): %v", i+1, retries, err)
|
|
time.Sleep(delay)
|
|
continue
|
|
}
|
|
return nil
|
|
}
|
|
return fmt.Errorf("insufficient free space after %d attempts", retries)
|
|
}
|
|
|
|
func handleFileCleanup(conf *Config) {
|
|
if !conf.Server.FileTTLEnabled {
|
|
log.Println("File TTL is disabled.")
|
|
return
|
|
}
|
|
|
|
ttlDuration, err := parseTTL(conf.Server.FileTTL)
|
|
if err != nil {
|
|
log.Fatalf("Invalid TTL configuration: %v", err)
|
|
}
|
|
|
|
log.Printf("TTL cleanup enabled. Files older than %v will be deleted.", ttlDuration)
|
|
ticker := time.NewTicker(24 * time.Hour)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
deleteOldFiles(conf, ttlDuration)
|
|
}
|
|
}
|
|
|
|
// computeSHA256 returns the hex-encoded SHA-256 digest of the file at
// filePath. The context is honored: hashing is aborted between chunks if
// ctx is cancelled. An empty path is rejected.
func computeSHA256(ctx context.Context, filePath string) (string, error) {
	if filePath == "" {
		return "", fmt.Errorf("file path is empty")
	}
	file, err := os.Open(filePath)
	if err != nil {
		return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
	}
	defer file.Close()

	// Copy in chunks so cancellation is observed even for large files.
	// The original accepted ctx but never consulted it.
	hasher := sha256.New()
	buf := make([]byte, 64*1024)
	for {
		if err := ctx.Err(); err != nil {
			return "", err
		}
		n, readErr := file.Read(buf)
		if n > 0 {
			hasher.Write(buf[:n]) // hash.Hash.Write never returns an error
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return "", fmt.Errorf("failed to hash file: %w", readErr)
		}
	}
	return hex.EncodeToString(hasher.Sum(nil)), nil
}
func handleDeduplication(ctx context.Context, absFilename string) error {
|
|
checksum, err := computeSHA256(ctx, absFilename)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
dedupDir := conf.Deduplication.Directory
|
|
if dedupDir == "" {
|
|
return fmt.Errorf("deduplication directory not configured")
|
|
}
|
|
|
|
dedupPath := filepath.Join(dedupDir, checksum)
|
|
if err := os.MkdirAll(dedupPath, os.ModePerm); err != nil {
|
|
return err
|
|
}
|
|
|
|
existingPath := filepath.Join(dedupPath, filepath.Base(absFilename))
|
|
if _, err := os.Stat(existingPath); err == nil {
|
|
return os.Link(existingPath, absFilename)
|
|
}
|
|
|
|
if err := os.Rename(absFilename, existingPath); err != nil {
|
|
return err
|
|
}
|
|
|
|
return os.Link(existingPath, absFilename)
|
|
}
|
|
|
|
func handleISOContainer(absFilename string) error {
|
|
isoPath := filepath.Join(conf.ISO.MountPoint, "container.iso")
|
|
if err := CreateISOContainer([]string{absFilename}, isoPath, conf.ISO.Size, conf.ISO.Charset); err != nil {
|
|
return err
|
|
}
|
|
if err := MountISOContainer(isoPath, conf.ISO.MountPoint); err != nil {
|
|
return err
|
|
}
|
|
return UnmountISOContainer(conf.ISO.MountPoint)
|
|
}
|
|
|
|
// sanitizeFilePath resolves filePath relative to baseDir and returns the
// absolute result, rejecting any path that escapes baseDir (path
// traversal). The error message format is unchanged for callers that
// surface it.
func sanitizeFilePath(baseDir, filePath string) (string, error) {
	absBaseDir, err := filepath.Abs(baseDir)
	if err != nil {
		return "", err
	}
	absFilePath, err := filepath.Abs(filepath.Join(absBaseDir, filePath))
	if err != nil {
		return "", err
	}
	// Use filepath.Rel instead of a raw prefix check: HasPrefix alone
	// accepts sibling directories that merely share a string prefix, e.g.
	// base "/srv/base" would match "/srv/base-evil/x".
	rel, err := filepath.Rel(absBaseDir, absFilePath)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
		return "", fmt.Errorf("invalid file path: %s", filePath)
	}
	return absFilePath, nil
}
// formatBytes renders a byte count as a human-readable string using binary
// units: "512 B", "1.5 KiB", "2.0 MiB", and so on up to EiB.
func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	// Scale down by 1024 until the value fits under one unit; division by a
	// power of two is exact in float64, so this matches integer scaling.
	value := float64(bytes) / unit
	exp := 0
	for value >= unit {
		value /= unit
		exp++
	}
	return fmt.Sprintf("%.1f %ciB", value, "KMGTPE"[exp])
}
// Stub for deleteOldFiles
|
|
func deleteOldFiles(conf *Config, ttlDuration time.Duration) {
|
|
// TODO: Implement actual file deletion logic based on TTL
|
|
log.Infof("deleteOldFiles is a stub and not yet implemented. It would check for files older than %v.", ttlDuration)
|
|
}
|
|
|
|
// Stub for CreateISOContainer
|
|
func CreateISOContainer(files []string, isoPath, size, charset string) error {
|
|
// TODO: Implement actual ISO container creation logic
|
|
log.Infof("CreateISOContainer is a stub and not yet implemented. It would create an ISO at %s.", isoPath)
|
|
return nil
|
|
}
|
|
|
|
// Stub for MountISOContainer
|
|
func MountISOContainer(isoPath, mountPoint string) error {
|
|
// TODO: Implement actual ISO container mounting logic
|
|
log.Infof("MountISOContainer is a stub and not yet implemented. It would mount %s to %s.", isoPath, mountPoint)
|
|
return nil
|
|
}
|
|
|
|
// Stub for UnmountISOContainer
|
|
func UnmountISOContainer(mountPoint string) error {
|
|
// TODO: Implement actual ISO container unmounting logic
|
|
log.Infof("UnmountISOContainer is a stub and not yet implemented. It would unmount %s.", mountPoint)
|
|
return nil
|
|
}
|
|
|
|
// checkStorageSpace reports an error when the filesystem containing
// storagePath has fewer than minFreeBytes bytes available to unprivileged
// users (Statfs Bavail, not Bfree).
func checkStorageSpace(storagePath string, minFreeBytes int64) error {
	var fsStat syscall.Statfs_t
	if err := syscall.Statfs(storagePath, &fsStat); err != nil {
		return err
	}
	free := fsStat.Bavail * uint64(fsStat.Bsize)
	if int64(free) < minFreeBytes {
		return fmt.Errorf("not enough space: available %d < required %d", free, minFreeBytes)
	}
	return nil
}
// setupLogging initializes logging configuration
|
|
func setupLogging() {
|
|
log.Infof("DEBUG: Starting setupLogging function")
|
|
if conf.Logging.File != "" {
|
|
log.Infof("DEBUG: Setting up file logging to: %s", conf.Logging.File)
|
|
log.SetOutput(&lumberjack.Logger{
|
|
Filename: conf.Logging.File,
|
|
MaxSize: conf.Logging.MaxSize,
|
|
MaxBackups: conf.Logging.MaxBackups,
|
|
MaxAge: conf.Logging.MaxAge,
|
|
Compress: conf.Logging.Compress,
|
|
})
|
|
log.Infof("Logging configured to file: %s", conf.Logging.File)
|
|
}
|
|
log.Infof("DEBUG: setupLogging function completed")
|
|
}
|
|
|
|
// logSystemInfo logs system information
|
|
func logSystemInfo() {
|
|
memStats, err := mem.VirtualMemory()
|
|
if err != nil {
|
|
log.Warnf("Failed to get memory stats: %v", err)
|
|
} else {
|
|
log.Infof("System Memory: Total=%s, Available=%s, Used=%.1f%%",
|
|
formatBytes(int64(memStats.Total)),
|
|
formatBytes(int64(memStats.Available)),
|
|
memStats.UsedPercent)
|
|
}
|
|
|
|
cpuStats, err := cpu.Info()
|
|
if err != nil {
|
|
log.Warnf("Failed to get CPU stats: %v", err)
|
|
} else if len(cpuStats) > 0 {
|
|
log.Infof("CPU: %s, Cores=%d", cpuStats[0].ModelName, len(cpuStats))
|
|
}
|
|
|
|
log.Infof("Go Runtime: Version=%s, NumCPU=%d, NumGoroutine=%d",
|
|
runtime.Version(), runtime.NumCPU(), runtime.NumGoroutine())
|
|
}
|
|
|
|
// initMetrics initializes Prometheus metrics
|
|
func initMetrics() {
|
|
uploadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{
|
|
Name: "upload_duration_seconds",
|
|
Help: "Duration of upload operations in seconds",
|
|
})
|
|
|
|
uploadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "upload_errors_total",
|
|
Help: "Total number of upload errors",
|
|
})
|
|
|
|
uploadsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "uploads_total",
|
|
Help: "Total number of uploads",
|
|
})
|
|
|
|
downloadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{
|
|
Name: "download_duration_seconds",
|
|
Help: "Duration of download operations in seconds",
|
|
})
|
|
|
|
downloadsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "downloads_total",
|
|
Help: "Total number of downloads",
|
|
})
|
|
|
|
downloadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "download_errors_total",
|
|
Help: "Total number of download errors",
|
|
})
|
|
|
|
memoryUsage = prometheus.NewGauge(prometheus.GaugeOpts{
|
|
Name: "memory_usage_percent",
|
|
Help: "Current memory usage percentage",
|
|
})
|
|
|
|
cpuUsage = prometheus.NewGauge(prometheus.GaugeOpts{
|
|
Name: "cpu_usage_percent",
|
|
Help: "Current CPU usage percentage",
|
|
})
|
|
|
|
activeConnections = prometheus.NewGauge(prometheus.GaugeOpts{
|
|
Name: "active_connections_total",
|
|
Help: "Number of active connections",
|
|
})
|
|
|
|
requestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
|
|
Name: "requests_total",
|
|
Help: "Total number of requests",
|
|
}, []string{"method", "status"})
|
|
|
|
goroutines = prometheus.NewGauge(prometheus.GaugeOpts{
|
|
Name: "goroutines_total",
|
|
Help: "Number of goroutines",
|
|
})
|
|
|
|
uploadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
|
|
Name: "upload_size_bytes",
|
|
Help: "Size of uploaded files in bytes",
|
|
})
|
|
|
|
downloadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
|
|
Name: "download_size_bytes",
|
|
Help: "Size of downloaded files in bytes",
|
|
})
|
|
|
|
filesDeduplicatedTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "files_deduplicated_total",
|
|
Help: "Total number of deduplicated files",
|
|
})
|
|
|
|
deduplicationErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "deduplication_errors_total",
|
|
Help: "Total number of deduplication errors",
|
|
})
|
|
|
|
isoContainersCreatedTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "iso_containers_created_total",
|
|
Help: "Total number of ISO containers created",
|
|
})
|
|
|
|
isoCreationErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "iso_creation_errors_total",
|
|
Help: "Total number of ISO creation errors",
|
|
})
|
|
|
|
isoContainersMountedTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "iso_containers_mounted_total",
|
|
Help: "Total number of ISO containers mounted",
|
|
})
|
|
|
|
isoMountErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "iso_mount_errors_total",
|
|
Help: "Total number of ISO mount errors",
|
|
})
|
|
|
|
workerAdjustmentsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "worker_adjustments_total",
|
|
Help: "Total number of worker adjustments",
|
|
})
|
|
|
|
workerReAdjustmentsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "worker_readjustments_total",
|
|
Help: "Total number of worker readjustments",
|
|
})
|
|
|
|
// Register all metrics
|
|
prometheus.MustRegister(
|
|
uploadDuration, uploadErrorsTotal, uploadsTotal,
|
|
downloadDuration, downloadsTotal, downloadErrorsTotal,
|
|
memoryUsage, cpuUsage, activeConnections, requestsTotal,
|
|
goroutines, uploadSizeBytes, downloadSizeBytes,
|
|
filesDeduplicatedTotal, deduplicationErrorsTotal,
|
|
isoContainersCreatedTotal, isoCreationErrorsTotal,
|
|
isoContainersMountedTotal, isoMountErrorsTotal,
|
|
workerAdjustmentsTotal, workerReAdjustmentsTotal,
|
|
)
|
|
|
|
log.Info("Prometheus metrics initialized successfully")
|
|
}
|
|
|
|
// scanFileWithClamAV scans a file using ClamAV
|
|
func scanFileWithClamAV(filename string) error {
|
|
if clamClient == nil {
|
|
return fmt.Errorf("ClamAV client not initialized")
|
|
}
|
|
|
|
result, err := clamClient.ScanFile(filename)
|
|
if err != nil {
|
|
return fmt.Errorf("ClamAV scan failed: %w", err)
|
|
}
|
|
|
|
// Handle the result channel
|
|
if result != nil {
|
|
select {
|
|
case scanResult := <-result:
|
|
if scanResult != nil && scanResult.Status != "OK" {
|
|
return fmt.Errorf("virus detected in %s: %s", filename, scanResult.Status)
|
|
}
|
|
case <-time.After(30 * time.Second):
|
|
return fmt.Errorf("ClamAV scan timeout for file: %s", filename)
|
|
}
|
|
}
|
|
|
|
log.Debugf("File %s passed ClamAV scan", filename)
|
|
return nil
|
|
}
|
|
|
|
// initClamAV initializes ClamAV client
|
|
func initClamAV(socketPath string) (*clamd.Clamd, error) {
|
|
if socketPath == "" {
|
|
socketPath = "/var/run/clamav/clamd.ctl"
|
|
}
|
|
|
|
client := clamd.NewClamd(socketPath)
|
|
|
|
// Test connection
|
|
err := client.Ping()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to ping ClamAV daemon: %w", err)
|
|
}
|
|
|
|
log.Infof("ClamAV client initialized with socket: %s", socketPath)
|
|
return client, nil
|
|
}
|
|
|
|
// initRedis initializes Redis client
|
|
func initRedis() {
|
|
redisClient = redis.NewClient(&redis.Options{
|
|
Addr: conf.Redis.RedisAddr,
|
|
Password: conf.Redis.RedisPassword,
|
|
DB: conf.Redis.RedisDBIndex,
|
|
})
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
_, err := redisClient.Ping(ctx).Result()
|
|
if err != nil {
|
|
log.Warnf("Failed to connect to Redis: %v", err)
|
|
redisConnected = false
|
|
} else {
|
|
log.Info("Redis client initialized successfully")
|
|
redisConnected = true
|
|
}
|
|
}
|
|
|
|
// monitorNetwork monitors network events
|
|
func monitorNetwork(ctx context.Context) {
|
|
log.Info("Starting network monitoring")
|
|
ticker := time.NewTicker(30 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Info("Network monitoring stopped")
|
|
return
|
|
case <-ticker.C:
|
|
// Simple network monitoring - check interface status
|
|
interfaces, err := net.Interfaces()
|
|
if err != nil {
|
|
log.Warnf("Failed to get network interfaces: %v", err)
|
|
continue
|
|
}
|
|
|
|
for _, iface := range interfaces {
|
|
if iface.Flags&net.FlagUp != 0 && iface.Flags&net.FlagLoopback == 0 {
|
|
select {
|
|
case networkEvents <- NetworkEvent{
|
|
Type: "interface_up",
|
|
Details: fmt.Sprintf("Interface %s is up", iface.Name),
|
|
}:
|
|
default:
|
|
// Channel full, skip
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleNetworkEvents handles network events
|
|
func handleNetworkEvents(ctx context.Context) {
|
|
log.Info("Starting network event handler")
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Info("Network event handler stopped")
|
|
return
|
|
case event := <-networkEvents:
|
|
log.Debugf("Network event: %s - %s", event.Type, event.Details)
|
|
}
|
|
}
|
|
}
|
|
|
|
// updateSystemMetrics updates system metrics
|
|
func updateSystemMetrics(ctx context.Context) {
|
|
ticker := time.NewTicker(15 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
// Update memory metrics
|
|
if memStats, err := mem.VirtualMemory(); err == nil {
|
|
memoryUsage.Set(memStats.UsedPercent)
|
|
}
|
|
|
|
// Update CPU metrics
|
|
if cpuPercents, err := cpu.Percent(time.Second, false); err == nil && len(cpuPercents) > 0 {
|
|
cpuUsage.Set(cpuPercents[0])
|
|
}
|
|
|
|
// Update goroutine count
|
|
goroutines.Set(float64(runtime.NumGoroutine()))
|
|
}
|
|
}
|
|
}
|
|
|
|
// setupRouter sets up HTTP routes
|
|
func setupRouter() *http.ServeMux {
|
|
mux := http.NewServeMux()
|
|
|
|
mux.HandleFunc("/upload", handleUpload)
|
|
mux.HandleFunc("/download/", handleDownload)
|
|
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
})
|
|
|
|
if conf.Server.MetricsEnabled {
|
|
mux.Handle("/metrics", promhttp.Handler())
|
|
}
|
|
|
|
// Catch-all handler for all upload protocols (v, v2, token, v3)
|
|
// This must be added last as it matches all paths
|
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
// Handle PUT requests for all upload protocols
|
|
if r.Method == http.MethodPut {
|
|
query := r.URL.Query()
|
|
|
|
// Check if this is a v3 request (mod_http_upload_external)
|
|
if query.Get("v3") != "" && query.Get("expires") != "" {
|
|
handleV3Upload(w, r)
|
|
return
|
|
}
|
|
|
|
// Check if this is a legacy protocol request (v, v2, token)
|
|
if query.Get("v") != "" || query.Get("v2") != "" || query.Get("token") != "" {
|
|
handleLegacyUpload(w, r)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Handle GET/HEAD requests for downloads
|
|
if r.Method == http.MethodGet || r.Method == http.MethodHead {
|
|
// Only handle download requests if the path looks like a file
|
|
path := strings.TrimPrefix(r.URL.Path, "/")
|
|
if path != "" && !strings.HasSuffix(path, "/") {
|
|
handleLegacyDownload(w, r)
|
|
return
|
|
}
|
|
}
|
|
|
|
// For all other requests, return 404
|
|
http.NotFound(w, r)
|
|
})
|
|
|
|
log.Info("HTTP router configured successfully with full protocol support (v, v2, token, v3)")
|
|
return mux
|
|
}
|
|
|
|
// setupGracefulShutdown sets up graceful shutdown
|
|
func setupGracefulShutdown(server *http.Server, cancel context.CancelFunc) {
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
go func() {
|
|
<-sigChan
|
|
log.Info("Received shutdown signal, initiating graceful shutdown...")
|
|
|
|
// Cancel context
|
|
cancel()
|
|
|
|
// Shutdown server with timeout
|
|
ctx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer shutdownCancel()
|
|
|
|
if err := server.Shutdown(ctx); err != nil {
|
|
log.Errorf("Server shutdown error: %v", err)
|
|
} else {
|
|
log.Info("Server shutdown completed")
|
|
}
|
|
|
|
// Clean up PID file
|
|
if conf.Server.CleanUponExit {
|
|
removePIDFile(conf.Server.PIDFilePath)
|
|
}
|
|
|
|
// Stop worker pool if it exists
|
|
if workerPool != nil {
|
|
workerPool.Stop()
|
|
log.Info("Worker pool stopped")
|
|
}
|
|
|
|
os.Exit(0)
|
|
}()
|
|
}
|
|
|
|
// ProgressWriter wraps an io.Writer to provide upload progress reporting
type ProgressWriter struct {
	dst        io.Writer // destination all writes are forwarded to
	total      int64     // expected total size in bytes (0 when unknown)
	written    int64     // bytes successfully written so far
	filename   string    // file name used in progress log messages
	// onProgress is invoked with (written, total, filename) whenever a
	// progress report is due; set by NewProgressWriter.
	onProgress func(written, total int64, filename string)
	lastReport time.Time // time of the most recent progress report
}
// NewProgressWriter creates a new ProgressWriter
|
|
func NewProgressWriter(dst io.Writer, total int64, filename string) *ProgressWriter {
|
|
return &ProgressWriter{
|
|
dst: dst,
|
|
total: total,
|
|
filename: filename,
|
|
onProgress: func(written, total int64, filename string) {
|
|
if total > 0 {
|
|
percentage := float64(written) / float64(total) * 100
|
|
sizeMiB := float64(written) / (1024 * 1024)
|
|
totalMiB := float64(total) / (1024 * 1024)
|
|
log.Infof("Upload progress for %s: %.1f%% (%.1f/%.1f MiB)",
|
|
filepath.Base(filename), percentage, sizeMiB, totalMiB)
|
|
}
|
|
},
|
|
lastReport: time.Now(),
|
|
}
|
|
}
|
|
|
|
// Write implements io.Writer interface with progress reporting
|
|
func (pw *ProgressWriter) Write(p []byte) (int, error) {
|
|
n, err := pw.dst.Write(p)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
|
|
pw.written += int64(n)
|
|
|
|
// Report progress every 30 seconds or every 50MB for large files
|
|
now := time.Now()
|
|
shouldReport := false
|
|
|
|
if pw.total > 100*1024*1024 { // Files larger than 100MB
|
|
shouldReport = now.Sub(pw.lastReport) > 30*time.Second ||
|
|
(pw.written%(50*1024*1024) == 0 && pw.written > 0)
|
|
} else if pw.total > 10*1024*1024 { // Files larger than 10MB
|
|
shouldReport = now.Sub(pw.lastReport) > 10*time.Second ||
|
|
(pw.written%(10*1024*1024) == 0 && pw.written > 0)
|
|
}
|
|
|
|
if shouldReport && pw.onProgress != nil {
|
|
pw.onProgress(pw.written, pw.total, pw.filename)
|
|
pw.lastReport = now
|
|
}
|
|
|
|
return n, err
|
|
}
|
|
|
|
// copyWithProgress copies data from src to dst with progress reporting
|
|
func copyWithProgress(dst io.Writer, src io.Reader, total int64, filename string) (int64, error) {
|
|
progressWriter := NewProgressWriter(dst, total, filename)
|
|
|
|
// Use a pooled buffer for efficient copying
|
|
bufPtr := bufferPool.Get().(*[]byte)
|
|
defer bufferPool.Put(bufPtr)
|
|
buf := *bufPtr
|
|
|
|
return io.CopyBuffer(progressWriter, src, buf)
|
|
}
|