release: hmac-file-server 3.2

This commit is contained in:
2025-06-13 04:24:11 +02:00
parent cc3a4f4dd7
commit 16f50940d0
34 changed files with 10354 additions and 2255 deletions

1050
cmd/monitor/monitor.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,67 +0,0 @@
# Server Settings
[server]
ListenPort = "8080"
UnixSocket = false
StoreDir = "./testupload"
LogLevel = "info"
LogFile = "./hmac-file-server.log"
MetricsEnabled = true
MetricsPort = "9090"
FileTTL = "8760h"
# Workers and Connections
[workers]
NumWorkers = 2
UploadQueueSize = 500
# Timeout Settings
[timeouts]
ReadTimeout = "600s"
WriteTimeout = "600s"
IdleTimeout = "600s"
# Security Settings
# NOTE(review): this secret is committed to version control — rotate it and
# load it from an environment variable or secret store instead.
[security]
Secret = "a-orc-and-a-humans-is-drinking-ale"
# Versioning Settings
[versioning]
EnableVersioning = false
MaxVersions = 1
# Upload/Download Settings
[uploads]
ResumableUploadsEnabled = true
ChunkedUploadsEnabled = true
ChunkSize = 16777216
AllowedExtensions = [
# Document formats
".txt", ".pdf",
# Image formats
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp",
# Video formats (NOTE(review): ".wav" below is an audio format, not video —
# consider moving it to the audio group)
".wav", ".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".mpg", ".m4v", ".3gp", ".3g2",
# Audio formats
".mp3", ".ogg"
]
# ClamAV Settings
[clamav]
ClamAVEnabled = false
ClamAVSocket = "/var/run/clamav/clamd.ctl"
NumScanWorkers = 4
# Redis Settings
[redis]
RedisEnabled = false
RedisAddr = "localhost:6379"
RedisPassword = ""
RedisDBIndex = 0
RedisHealthCheckInterval = "120s"
# Deduplication
[deduplication]
enabled = false

View File

@ -0,0 +1,294 @@
// config_test_scenarios.go
package main
import (
"fmt"
"os"
"path/filepath"
)
// ConfigTestScenario represents a test scenario for configuration validation
type ConfigTestScenario struct {
	Name             string   // human-readable scenario name, printed by RunConfigTests
	Config           Config   // configuration under test
	ShouldPass       bool     // whether validation is expected to succeed
	ExpectedErrors   []string // error messages that must appear when ShouldPass is false
	ExpectedWarnings []string // warning messages expected (checked even when validation passes)
}
// GetConfigTestScenarios returns a set of test scenarios for configuration validation.
// Each scenario clones a known-good base config and perturbs one field, so a
// failure isolates exactly one validation rule.
func GetConfigTestScenarios() []ConfigTestScenario {
	// Baseline that is expected to validate cleanly; individual scenarios
	// copy it (struct copy — note slices inside are shared) and mutate one field.
	baseValidConfig := Config{
		Server: ServerConfig{
			ListenAddress:  "8080",
			BindIP:         "0.0.0.0",
			StoragePath:    "/tmp/test-storage",
			MetricsEnabled: true,
			MetricsPort:    "9090",
			FileTTLEnabled: true,
			FileTTL:        "24h",
			MinFreeBytes:   "1GB",
			FileNaming:     "HMAC",
			ForceProtocol:  "auto",
			PIDFilePath:    "/tmp/test.pid",
		},
		Security: SecurityConfig{
			Secret:    "test-secret-key-32-characters",
			EnableJWT: false,
		},
		Logging: LoggingConfig{
			Level:      "info",
			File:       "/tmp/test.log",
			MaxSize:    100,
			MaxBackups: 3,
			MaxAge:     30,
		},
		Timeouts: TimeoutConfig{
			Read:  "30s",
			Write: "30s",
			Idle:  "60s",
		},
		Workers: WorkersConfig{
			NumWorkers:      4,
			UploadQueueSize: 50,
		},
		Uploads: UploadsConfig{
			AllowedExtensions: []string{".txt", ".pdf", ".jpg"},
			ChunkSize:         "10MB",
		},
		Downloads: DownloadsConfig{
			AllowedExtensions: []string{".txt", ".pdf", ".jpg"},
			ChunkSize:         "10MB",
		},
	}
	return []ConfigTestScenario{
		{
			Name:       "Valid Basic Configuration",
			Config:     baseValidConfig,
			ShouldPass: true,
		},
		{
			Name: "Missing Listen Address",
			Config: func() Config {
				c := baseValidConfig
				c.Server.ListenAddress = ""
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"server.listen_address is required"},
		},
		{
			Name: "Invalid Port Number",
			Config: func() Config {
				c := baseValidConfig
				c.Server.ListenAddress = "99999"
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"invalid port number"},
		},
		{
			Name: "Invalid IP Address",
			Config: func() Config {
				c := baseValidConfig
				c.Server.BindIP = "999.999.999.999"
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"invalid IP address format"},
		},
		{
			Name: "Same Port for Server and Metrics",
			Config: func() Config {
				c := baseValidConfig
				c.Server.ListenAddress = "8080"
				c.Server.MetricsPort = "8080"
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"metrics port cannot be the same as main listen port"},
		},
		{
			Name: "JWT Enabled Without Secret",
			Config: func() Config {
				c := baseValidConfig
				c.Security.EnableJWT = true
				c.Security.JWTSecret = ""
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"JWT secret is required when JWT is enabled"},
		},
		{
			// Short secret is a warning, not an error — the scenario still passes.
			Name: "Short JWT Secret",
			Config: func() Config {
				c := baseValidConfig
				c.Security.EnableJWT = true
				c.Security.JWTSecret = "short"
				c.Security.JWTAlgorithm = "HS256"
				return c
			}(),
			ShouldPass:       true,
			ExpectedWarnings: []string{"JWT secret should be at least 32 characters"},
		},
		{
			Name: "Invalid Log Level",
			Config: func() Config {
				c := baseValidConfig
				c.Logging.Level = "invalid"
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"invalid log level"},
		},
		{
			Name: "Invalid Timeout Format",
			Config: func() Config {
				c := baseValidConfig
				c.Timeouts.Read = "invalid"
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"invalid read timeout format"},
		},
		{
			Name: "Negative Worker Count",
			Config: func() Config {
				c := baseValidConfig
				c.Workers.NumWorkers = -1
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"number of workers must be positive"},
		},
		{
			Name: "Extensions Without Dots",
			Config: func() Config {
				c := baseValidConfig
				c.Uploads.AllowedExtensions = []string{"txt", "pdf"}
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"file extensions must start with a dot"},
		},
		{
			Name: "High Worker Count Warning",
			Config: func() Config {
				c := baseValidConfig
				c.Workers.NumWorkers = 100
				return c
			}(),
			ShouldPass:       true,
			ExpectedWarnings: []string{"very high worker count may impact performance"},
		},
		{
			Name: "Deduplication Without Directory",
			Config: func() Config {
				c := baseValidConfig
				c.Deduplication.Enabled = true
				c.Deduplication.Directory = ""
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"deduplication directory is required"},
		},
	}
}
// RunConfigTests runs all configuration test scenarios
func RunConfigTests() {
scenarios := GetConfigTestScenarios()
passed := 0
failed := 0
fmt.Println("🧪 Running Configuration Test Scenarios")
fmt.Println("=======================================")
fmt.Println()
for i, scenario := range scenarios {
fmt.Printf("Test %d: %s\n", i+1, scenario.Name)
// Create temporary directories for testing
tempDir := filepath.Join(os.TempDir(), fmt.Sprintf("hmac-test-%d", i))
os.MkdirAll(tempDir, 0755)
defer os.RemoveAll(tempDir)
// Update paths in config to use temp directory
scenario.Config.Server.StoragePath = filepath.Join(tempDir, "storage")
scenario.Config.Logging.File = filepath.Join(tempDir, "test.log")
scenario.Config.Server.PIDFilePath = filepath.Join(tempDir, "test.pid")
if scenario.Config.Deduplication.Enabled {
scenario.Config.Deduplication.Directory = filepath.Join(tempDir, "dedup")
}
result := ValidateConfigComprehensive(&scenario.Config)
// Check if test passed as expected
testPassed := true
if scenario.ShouldPass && result.HasErrors() {
fmt.Printf(" ❌ Expected to pass but failed with errors:\n")
for _, err := range result.Errors {
fmt.Printf(" • %s\n", err.Message)
}
testPassed = false
} else if !scenario.ShouldPass && !result.HasErrors() {
fmt.Printf(" ❌ Expected to fail but passed\n")
testPassed = false
} else if !scenario.ShouldPass && result.HasErrors() {
// Check if expected errors are present
expectedFound := true
for _, expectedError := range scenario.ExpectedErrors {
found := false
for _, actualError := range result.Errors {
if contains([]string{actualError.Message}, expectedError) ||
contains([]string{actualError.Error()}, expectedError) {
found = true
break
}
}
if !found {
fmt.Printf(" ❌ Expected error not found: %s\n", expectedError)
expectedFound = false
}
}
if !expectedFound {
testPassed = false
}
}
// Check expected warnings
if len(scenario.ExpectedWarnings) > 0 {
for _, expectedWarning := range scenario.ExpectedWarnings {
found := false
for _, actualWarning := range result.Warnings {
if contains([]string{actualWarning.Message}, expectedWarning) ||
contains([]string{actualWarning.Error()}, expectedWarning) {
found = true
break
}
}
if !found {
fmt.Printf(" ⚠️ Expected warning not found: %s\n", expectedWarning)
}
}
}
if testPassed {
fmt.Printf(" ✅ Passed\n")
passed++
} else {
failed++
}
fmt.Println()
}
// Summary
fmt.Printf("📊 Test Results: %d passed, %d failed\n", passed, failed)
if failed > 0 {
fmt.Printf("❌ Some tests failed. Please review the implementation.\n")
os.Exit(1)
} else {
fmt.Printf("✅ All tests passed!\n")
}
}

File diff suppressed because it is too large Load Diff

713
cmd/server/helpers.go Normal file
View File

@ -0,0 +1,713 @@
package main
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
"github.com/dutchcoders/go-clamd"
"github.com/go-redis/redis/v8"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/shirou/gopsutil/cpu"
"github.com/shirou/gopsutil/mem"
"gopkg.in/natefinch/lumberjack.v2"
)
// WorkerPool represents a pool of workers
// that drain upload and scan tasks from two bounded queues until cancelled.
type WorkerPool struct {
	workers   int                // number of goroutines launched by Start
	taskQueue chan UploadTask    // buffered upload queue
	scanQueue chan ScanTask      // buffered scan queue
	ctx       context.Context    // cancelled by Stop to terminate workers
	cancel    context.CancelFunc // paired with ctx
}

// NewWorkerPool creates a new worker pool
// with the given worker count and per-queue capacity.
func NewWorkerPool(workers int, queueSize int) *WorkerPool {
	ctx, cancel := context.WithCancel(context.Background())
	return &WorkerPool{
		workers:   workers,
		taskQueue: make(chan UploadTask, queueSize),
		scanQueue: make(chan ScanTask, queueSize),
		ctx:       ctx,
		cancel:    cancel,
	}
}

// Start starts the worker pool
// by launching `workers` goroutines. There is no mechanism to wait for them
// to finish; Stop only signals cancellation.
func (wp *WorkerPool) Start() {
	for i := 0; i < wp.workers; i++ {
		go wp.worker()
	}
}

// Stop stops the worker pool.
// NOTE(review): it closes both queues while workers may still be selecting on
// them — a worker can then receive zero-value tasks from a closed channel, and
// any producer that sends after Stop will panic. Confirm no sends occur after
// Stop and consider letting ctx cancellation alone terminate the workers.
func (wp *WorkerPool) Stop() {
	wp.cancel()
	close(wp.taskQueue)
	close(wp.scanQueue)
}

// worker is the worker function;
// it loops until the pool context is cancelled, acknowledging upload tasks
// and running scans.
func (wp *WorkerPool) worker() {
	for {
		select {
		case <-wp.ctx.Done():
			return
		case task := <-wp.taskQueue:
			// Upload tasks are acknowledged without doing real work here.
			if task.Result != nil {
				task.Result <- nil // Simple implementation
			}
		case scanTask := <-wp.scanQueue:
			err := processScan(scanTask)
			if scanTask.Result != nil {
				scanTask.Result <- err
			}
		}
	}
}
// precacheStoragePath is a stub that will eventually warm a cache with file
// metadata found under storagePath. It currently only logs and returns nil.
func precacheStoragePath(storagePath string) error {
	// TODO: Implement actual pre-caching logic
	// This would typically involve walking the storagePath
	// and loading file information into a cache.
	log.Infof("Pre-caching for storage path '%s' is a stub and not yet implemented.", storagePath)
	return nil
}
// checkFreeSpaceWithRetry verifies that path has at least the configured
// conf.Server.MinFreeBytes of free space, retrying up to `retries` times with
// `delay` between attempts. It returns an error once every attempt has failed.
// The process exits (log.Fatalf) if MinFreeBytes cannot be parsed.
func checkFreeSpaceWithRetry(path string, retries int, delay time.Duration) error {
	// Parse the threshold once up front — it is constant across attempts.
	// (The original re-parsed, and could Fatalf, on every loop iteration.)
	minFreeBytes, err := parseSize(conf.Server.MinFreeBytes)
	if err != nil {
		log.Fatalf("Invalid MinFreeBytes: %v", err)
	}
	for attempt := 1; attempt <= retries; attempt++ {
		if err := checkStorageSpace(path, minFreeBytes); err != nil {
			log.Warnf("Free space check failed (attempt %d/%d): %v", attempt, retries, err)
			time.Sleep(delay)
			continue
		}
		return nil
	}
	return fmt.Errorf("insufficient free space after %d attempts", retries)
}
// handleFileCleanup runs the TTL-based cleanup loop. When FileTTL is enabled
// it blocks indefinitely, triggering one cleanup pass per day, so it is meant
// to run in its own goroutine. It exits the process if the TTL is unparsable.
func handleFileCleanup(conf *Config) {
	if !conf.Server.FileTTLEnabled {
		log.Println("File TTL is disabled.")
		return
	}
	ttlDuration, err := parseTTL(conf.Server.FileTTL)
	if err != nil {
		log.Fatalf("Invalid TTL configuration: %v", err)
	}
	log.Printf("TTL cleanup enabled. Files older than %v will be deleted.", ttlDuration)
	// One pass per day; deleteOldFiles is currently a stub.
	ticker := time.NewTicker(24 * time.Hour)
	defer ticker.Stop()
	for range ticker.C {
		deleteOldFiles(conf, ttlDuration)
	}
}
func computeSHA256(ctx context.Context, filePath string) (string, error) {
if filePath == "" {
return "", fmt.Errorf("file path is empty")
}
file, err := os.Open(filePath)
if err != nil {
return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
}
defer file.Close()
hasher := sha256.New()
if _, err := io.Copy(hasher, file); err != nil {
return "", fmt.Errorf("failed to hash file: %w", err)
}
return hex.EncodeToString(hasher.Sum(nil)), nil
}
// handleDeduplication replaces the uploaded file at absFilename with a hard
// link to a content-addressed canonical copy under the configured
// deduplication directory (keyed by the file's SHA-256). If no canonical copy
// exists yet, the upload itself is moved there and linked back.
func handleDeduplication(ctx context.Context, absFilename string) error {
	checksum, err := computeSHA256(ctx, absFilename)
	if err != nil {
		return err
	}
	dedupDir := conf.Deduplication.Directory
	if dedupDir == "" {
		return fmt.Errorf("deduplication directory not configured")
	}
	dedupPath := filepath.Join(dedupDir, checksum)
	if err := os.MkdirAll(dedupPath, os.ModePerm); err != nil {
		return err
	}
	existingPath := filepath.Join(dedupPath, filepath.Base(absFilename))
	if _, err := os.Stat(existingPath); err == nil {
		// A canonical copy already exists. The upload must be removed before
		// linking — os.Link fails with EEXIST when the new name still exists
		// (the original called Link directly and always failed here).
		if err := os.Remove(absFilename); err != nil {
			return err
		}
		return os.Link(existingPath, absFilename)
	}
	// First time we see this content: move the upload into the dedup store
	// and link it back into place.
	if err := os.Rename(absFilename, existingPath); err != nil {
		return err
	}
	return os.Link(existingPath, absFilename)
}
// handleISOContainer packs absFilename into an ISO image under the configured
// mount point, mounts it, and then immediately unmounts it.
// NOTE(review): the mount is undone right after it succeeds, so nothing stays
// mounted — confirm this sequence is intentional (all three helpers it calls
// are currently stubs).
func handleISOContainer(absFilename string) error {
	isoPath := filepath.Join(conf.ISO.MountPoint, "container.iso")
	if err := CreateISOContainer([]string{absFilename}, isoPath, conf.ISO.Size, conf.ISO.Charset); err != nil {
		return err
	}
	if err := MountISOContainer(isoPath, conf.ISO.MountPoint); err != nil {
		return err
	}
	return UnmountISOContainer(conf.ISO.MountPoint)
}
// sanitizeFilePath resolves filePath relative to baseDir and returns the
// absolute result, rejecting any path that escapes baseDir (path-traversal
// protection for untrusted client-supplied names).
func sanitizeFilePath(baseDir, filePath string) (string, error) {
	absBaseDir, err := filepath.Abs(baseDir)
	if err != nil {
		return "", err
	}
	absFilePath, err := filepath.Abs(filepath.Join(absBaseDir, filePath))
	if err != nil {
		return "", err
	}
	// A bare prefix check is insufficient: "/base-evil/x" has the string
	// prefix "/base". Accept only the base directory itself or paths strictly
	// below it (prefix followed by a path separator).
	if absFilePath != absBaseDir &&
		!strings.HasPrefix(absFilePath, absBaseDir+string(filepath.Separator)) {
		return "", fmt.Errorf("invalid file path: %s", filePath)
	}
	return absFilePath, nil
}
// formatBytes renders a byte count as a human-readable IEC string:
// "512 B", "1.5 KiB", "2.0 MiB", and so on up through EiB.
func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	// Find the largest power of 1024 that keeps the mantissa below 1024.
	exp := 0
	for v := bytes / unit; v >= unit; v /= unit {
		exp++
	}
	mantissa := float64(bytes) / float64(int64(1)<<(10*(exp+1)))
	return fmt.Sprintf("%.1f %ciB", mantissa, "KMGTPE"[exp])
}
// deleteOldFiles is a stub for TTL cleanup: it will eventually delete files
// under the configured storage path older than ttlDuration. Currently it only
// logs and returns.
func deleteOldFiles(conf *Config, ttlDuration time.Duration) {
	// TODO: Implement actual file deletion logic based on TTL
	log.Infof("deleteOldFiles is a stub and not yet implemented. It would check for files older than %v.", ttlDuration)
}
// CreateISOContainer is a stub: it will eventually build an ISO image at
// isoPath containing the given files. Currently it only logs and returns nil.
func CreateISOContainer(files []string, isoPath, size, charset string) error {
	// TODO: Implement actual ISO container creation logic
	log.Infof("CreateISOContainer is a stub and not yet implemented. It would create an ISO at %s.", isoPath)
	return nil
}
// MountISOContainer is a stub: it will eventually mount the ISO at isoPath
// onto mountPoint. Currently it only logs and returns nil.
func MountISOContainer(isoPath, mountPoint string) error {
	// TODO: Implement actual ISO container mounting logic
	log.Infof("MountISOContainer is a stub and not yet implemented. It would mount %s to %s.", isoPath, mountPoint)
	return nil
}
// UnmountISOContainer is a stub: it will eventually unmount mountPoint.
// Currently it only logs and returns nil.
func UnmountISOContainer(mountPoint string) error {
	// TODO: Implement actual ISO container unmounting logic
	log.Infof("UnmountISOContainer is a stub and not yet implemented. It would unmount %s.", mountPoint)
	return nil
}
// checkStorageSpace returns an error when the filesystem containing
// storagePath has fewer than minFreeBytes bytes available (as reported by
// statfs for unprivileged users), or when statfs itself fails.
func checkStorageSpace(storagePath string, minFreeBytes int64) error {
	var fsStat syscall.Statfs_t
	if err := syscall.Statfs(storagePath, &fsStat); err != nil {
		return err
	}
	// Bavail counts blocks available to non-root; multiply by block size.
	available := fsStat.Bavail * uint64(fsStat.Bsize)
	if int64(available) >= minFreeBytes {
		return nil
	}
	return fmt.Errorf("not enough space: available %d < required %d", available, minFreeBytes)
}
// setupLogging initializes logging configuration.
// When conf.Logging.File is set, log output is redirected to a size-rotated
// file via lumberjack; otherwise the logger keeps its existing output.
func setupLogging() {
	log.Infof("DEBUG: Starting setupLogging function")
	if conf.Logging.File != "" {
		log.Infof("DEBUG: Setting up file logging to: %s", conf.Logging.File)
		log.SetOutput(&lumberjack.Logger{
			Filename:   conf.Logging.File,
			MaxSize:    conf.Logging.MaxSize,    // megabytes per log file
			MaxBackups: conf.Logging.MaxBackups, // rotated files to keep
			MaxAge:     conf.Logging.MaxAge,     // days to retain
			Compress:   conf.Logging.Compress,
		})
		log.Infof("Logging configured to file: %s", conf.Logging.File)
	}
	log.Infof("DEBUG: setupLogging function completed")
}
// logSystemInfo logs system information:
// memory totals, CPU model/core count (via gopsutil), and Go runtime stats.
// Failures to read system stats are logged as warnings, never fatal.
func logSystemInfo() {
	memStats, err := mem.VirtualMemory()
	if err != nil {
		log.Warnf("Failed to get memory stats: %v", err)
	} else {
		log.Infof("System Memory: Total=%s, Available=%s, Used=%.1f%%",
			formatBytes(int64(memStats.Total)),
			formatBytes(int64(memStats.Available)),
			memStats.UsedPercent)
	}
	cpuStats, err := cpu.Info()
	if err != nil {
		log.Warnf("Failed to get CPU stats: %v", err)
	} else if len(cpuStats) > 0 {
		// One entry per logical CPU; use the first for the model name.
		log.Infof("CPU: %s, Cores=%d", cpuStats[0].ModelName, len(cpuStats))
	}
	log.Infof("Go Runtime: Version=%s, NumCPU=%d, NumGoroutine=%d",
		runtime.Version(), runtime.NumCPU(), runtime.NumGoroutine())
}
// initMetrics initializes Prometheus metrics.
// It assigns every package-level metric variable and registers them all with
// the default registry. Must be called exactly once — MustRegister panics on
// duplicate registration.
func initMetrics() {
	// Upload/download throughput and error counters.
	uploadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{
		Name: "upload_duration_seconds",
		Help: "Duration of upload operations in seconds",
	})
	uploadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "upload_errors_total",
		Help: "Total number of upload errors",
	})
	uploadsTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "uploads_total",
		Help: "Total number of uploads",
	})
	downloadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{
		Name: "download_duration_seconds",
		Help: "Duration of download operations in seconds",
	})
	downloadsTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "downloads_total",
		Help: "Total number of downloads",
	})
	downloadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "download_errors_total",
		Help: "Total number of download errors",
	})
	// Process/system gauges (updated by updateSystemMetrics).
	memoryUsage = prometheus.NewGauge(prometheus.GaugeOpts{
		Name: "memory_usage_percent",
		Help: "Current memory usage percentage",
	})
	cpuUsage = prometheus.NewGauge(prometheus.GaugeOpts{
		Name: "cpu_usage_percent",
		Help: "Current CPU usage percentage",
	})
	activeConnections = prometheus.NewGauge(prometheus.GaugeOpts{
		Name: "active_connections_total",
		Help: "Number of active connections",
	})
	requestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "requests_total",
		Help: "Total number of requests",
	}, []string{"method", "status"})
	goroutines = prometheus.NewGauge(prometheus.GaugeOpts{
		Name: "goroutines_total",
		Help: "Number of goroutines",
	})
	uploadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
		Name: "upload_size_bytes",
		Help: "Size of uploaded files in bytes",
	})
	downloadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
		Name: "download_size_bytes",
		Help: "Size of downloaded files in bytes",
	})
	// Deduplication counters.
	filesDeduplicatedTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "files_deduplicated_total",
		Help: "Total number of deduplicated files",
	})
	deduplicationErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "deduplication_errors_total",
		Help: "Total number of deduplication errors",
	})
	// ISO container lifecycle counters.
	isoContainersCreatedTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "iso_containers_created_total",
		Help: "Total number of ISO containers created",
	})
	isoCreationErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "iso_creation_errors_total",
		Help: "Total number of ISO creation errors",
	})
	isoContainersMountedTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "iso_containers_mounted_total",
		Help: "Total number of ISO containers mounted",
	})
	isoMountErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "iso_mount_errors_total",
		Help: "Total number of ISO mount errors",
	})
	// Worker pool scaling counters.
	workerAdjustmentsTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "worker_adjustments_total",
		Help: "Total number of worker adjustments",
	})
	workerReAdjustmentsTotal = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "worker_readjustments_total",
		Help: "Total number of worker readjustments",
	})
	// Register all metrics
	prometheus.MustRegister(
		uploadDuration, uploadErrorsTotal, uploadsTotal,
		downloadDuration, downloadsTotal, downloadErrorsTotal,
		memoryUsage, cpuUsage, activeConnections, requestsTotal,
		goroutines, uploadSizeBytes, downloadSizeBytes,
		filesDeduplicatedTotal, deduplicationErrorsTotal,
		isoContainersCreatedTotal, isoCreationErrorsTotal,
		isoContainersMountedTotal, isoMountErrorsTotal,
		workerAdjustmentsTotal, workerReAdjustmentsTotal,
	)
	log.Info("Prometheus metrics initialized successfully")
}
// scanFileWithClamAV scans a file using ClamAV
// via the package-level clamClient. It waits up to 30 seconds for the scan
// result and returns an error for any status other than "OK".
func scanFileWithClamAV(filename string) error {
	if clamClient == nil {
		return fmt.Errorf("ClamAV client not initialized")
	}
	result, err := clamClient.ScanFile(filename)
	if err != nil {
		return fmt.Errorf("ClamAV scan failed: %w", err)
	}
	// Handle the result channel — ScanFile reports asynchronously.
	if result != nil {
		select {
		case scanResult := <-result:
			// A nil result or "OK" status is treated as clean.
			if scanResult != nil && scanResult.Status != "OK" {
				return fmt.Errorf("virus detected in %s: %s", filename, scanResult.Status)
			}
		case <-time.After(30 * time.Second):
			return fmt.Errorf("ClamAV scan timeout for file: %s", filename)
		}
	}
	log.Debugf("File %s passed ClamAV scan", filename)
	return nil
}
// initClamAV initializes ClamAV client
// for the given Unix socket (defaulting to the standard clamd socket path)
// and verifies connectivity with a ping before returning it.
func initClamAV(socketPath string) (*clamd.Clamd, error) {
	if socketPath == "" {
		socketPath = "/var/run/clamav/clamd.ctl"
	}
	client := clamd.NewClamd(socketPath)
	// Test connection
	err := client.Ping()
	if err != nil {
		return nil, fmt.Errorf("failed to ping ClamAV daemon: %w", err)
	}
	log.Infof("ClamAV client initialized with socket: %s", socketPath)
	return client, nil
}
// initRedis initializes Redis client.
// It sets the package-level redisClient from configuration and probes the
// server with a 5s ping, recording the outcome in redisConnected. A failed
// connection is a warning, not fatal — the server runs without Redis.
func initRedis() {
	redisClient = redis.NewClient(&redis.Options{
		Addr:     conf.Redis.RedisAddr,
		Password: conf.Redis.RedisPassword,
		DB:       conf.Redis.RedisDBIndex,
	})
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	_, err := redisClient.Ping(ctx).Result()
	if err != nil {
		log.Warnf("Failed to connect to Redis: %v", err)
		redisConnected = false
	} else {
		log.Info("Redis client initialized successfully")
		redisConnected = true
	}
}
// monitorNetwork monitors network events.
// Every 30 seconds it lists network interfaces and emits an "interface_up"
// event for each up, non-loopback interface. Note it reports current state on
// every tick (no previous-state tracking), so the same event repeats while an
// interface stays up. Sends are non-blocking; events are dropped when the
// networkEvents channel is full. Runs until ctx is cancelled.
func monitorNetwork(ctx context.Context) {
	log.Info("Starting network monitoring")
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			log.Info("Network monitoring stopped")
			return
		case <-ticker.C:
			// Simple network monitoring - check interface status
			interfaces, err := net.Interfaces()
			if err != nil {
				log.Warnf("Failed to get network interfaces: %v", err)
				continue
			}
			for _, iface := range interfaces {
				if iface.Flags&net.FlagUp != 0 && iface.Flags&net.FlagLoopback == 0 {
					select {
					case networkEvents <- NetworkEvent{
						Type:    "interface_up",
						Details: fmt.Sprintf("Interface %s is up", iface.Name),
					}:
					default:
						// Channel full, skip
					}
				}
			}
		}
	}
}
// handleNetworkEvents handles network events.
// It drains the package-level networkEvents channel, logging each event at
// debug level, until ctx is cancelled.
func handleNetworkEvents(ctx context.Context) {
	log.Info("Starting network event handler")
	for {
		select {
		case <-ctx.Done():
			log.Info("Network event handler stopped")
			return
		case event := <-networkEvents:
			log.Debugf("Network event: %s - %s", event.Type, event.Details)
		}
	}
}
// updateSystemMetrics updates system metrics.
// Every 15 seconds it refreshes the memory, CPU and goroutine gauges until
// ctx is cancelled. Note cpu.Percent(time.Second, false) blocks for about one
// second per tick while sampling.
func updateSystemMetrics(ctx context.Context) {
	ticker := time.NewTicker(15 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// Update memory metrics
			if memStats, err := mem.VirtualMemory(); err == nil {
				memoryUsage.Set(memStats.UsedPercent)
			}
			// Update CPU metrics
			if cpuPercents, err := cpu.Percent(time.Second, false); err == nil && len(cpuPercents) > 0 {
				cpuUsage.Set(cpuPercents[0])
			}
			// Update goroutine count
			goroutines.Set(float64(runtime.NumGoroutine()))
		}
	}
}
// setupRouter sets up HTTP routes:
// /upload, /download/, /health, optionally /metrics, plus a catch-all that
// dispatches XMPP-style upload protocols (v, v2, token, v3) on PUT and legacy
// downloads on GET/HEAD.
func setupRouter() *http.ServeMux {
	mux := http.NewServeMux()
	mux.HandleFunc("/upload", handleUpload)
	mux.HandleFunc("/download/", handleDownload)
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	})
	if conf.Server.MetricsEnabled {
		mux.Handle("/metrics", promhttp.Handler())
	}
	// Catch-all handler for all upload protocols (v, v2, token, v3)
	// This must be added last as it matches all paths
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		// Handle PUT requests for all upload protocols
		if r.Method == http.MethodPut {
			query := r.URL.Query()
			// Check if this is a v3 request (mod_http_upload_external)
			if query.Get("v3") != "" && query.Get("expires") != "" {
				handleV3Upload(w, r)
				return
			}
			// Check if this is a legacy protocol request (v, v2, token)
			if query.Get("v") != "" || query.Get("v2") != "" || query.Get("token") != "" {
				handleLegacyUpload(w, r)
				return
			}
		}
		// Handle GET/HEAD requests for downloads
		if r.Method == http.MethodGet || r.Method == http.MethodHead {
			// Only handle download requests if the path looks like a file
			// (non-empty and not ending in a slash).
			path := strings.TrimPrefix(r.URL.Path, "/")
			if path != "" && !strings.HasSuffix(path, "/") {
				handleLegacyDownload(w, r)
				return
			}
		}
		// For all other requests, return 404
		http.NotFound(w, r)
	})
	log.Info("HTTP router configured successfully with full protocol support (v, v2, token, v3)")
	return mux
}
// setupGracefulShutdown sets up graceful shutdown.
// On SIGINT/SIGTERM it cancels the application context, shuts the HTTP server
// down with a 30s deadline, optionally removes the PID file, stops the worker
// pool, and terminates the process with os.Exit(0).
func setupGracefulShutdown(server *http.Server, cancel context.CancelFunc) {
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-sigChan
		log.Info("Received shutdown signal, initiating graceful shutdown...")
		// Cancel context
		cancel()
		// Shutdown server with timeout
		ctx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer shutdownCancel()
		if err := server.Shutdown(ctx); err != nil {
			log.Errorf("Server shutdown error: %v", err)
		} else {
			log.Info("Server shutdown completed")
		}
		// Clean up PID file
		if conf.Server.CleanUponExit {
			removePIDFile(conf.Server.PIDFilePath)
		}
		// Stop worker pool if it exists
		if workerPool != nil {
			workerPool.Stop()
			log.Info("Worker pool stopped")
		}
		os.Exit(0)
	}()
}
// ProgressWriter wraps an io.Writer to provide upload progress reporting
// for medium and large files via an onProgress callback.
type ProgressWriter struct {
	dst        io.Writer
	total      int64 // expected total bytes; 0 or unknown disables percentage
	written    int64 // bytes forwarded so far
	filename   string
	onProgress func(written, total int64, filename string)
	lastReport time.Time
	// lastReportBytes records written at the previous report so size-based
	// reporting fires on each threshold crossing. The original used
	// `written % chunk == 0`, which almost never holds for arbitrary write
	// sizes, so size-based reports effectively never happened.
	lastReportBytes int64
}

// NewProgressWriter creates a new ProgressWriter that logs percentage
// progress for the named file.
func NewProgressWriter(dst io.Writer, total int64, filename string) *ProgressWriter {
	return &ProgressWriter{
		dst:      dst,
		total:    total,
		filename: filename,
		onProgress: func(written, total int64, filename string) {
			if total > 0 {
				percentage := float64(written) / float64(total) * 100
				sizeMiB := float64(written) / (1024 * 1024)
				totalMiB := float64(total) / (1024 * 1024)
				log.Infof("Upload progress for %s: %.1f%% (%.1f/%.1f MiB)",
					filepath.Base(filename), percentage, sizeMiB, totalMiB)
			}
		},
		lastReport: time.Now(),
	}
}

// Write implements io.Writer with progress reporting: data is forwarded to
// dst, and a progress callback fires on a time or byte-count cadence that
// scales with the expected total size. Files of 10MB or less never report.
func (pw *ProgressWriter) Write(p []byte) (int, error) {
	n, err := pw.dst.Write(p)
	if err != nil {
		return n, err
	}
	pw.written += int64(n)
	now := time.Now()
	shouldReport := false
	switch {
	case pw.total > 100*1024*1024: // >100MB: every 30s or every 50MB written
		shouldReport = now.Sub(pw.lastReport) > 30*time.Second ||
			pw.written-pw.lastReportBytes >= 50*1024*1024
	case pw.total > 10*1024*1024: // >10MB: every 10s or every 10MB written
		shouldReport = now.Sub(pw.lastReport) > 10*time.Second ||
			pw.written-pw.lastReportBytes >= 10*1024*1024
	}
	if shouldReport && pw.onProgress != nil {
		pw.onProgress(pw.written, pw.total, pw.filename)
		pw.lastReport = now
		pw.lastReportBytes = pw.written
	}
	return n, err
}
// copyWithProgress copies data from src to dst with progress reporting,
// wrapping dst in a ProgressWriter and reusing a buffer from the
// package-level bufferPool to avoid io.CopyBuffer's per-call allocation.
// Returns the number of bytes copied and the first error encountered.
func copyWithProgress(dst io.Writer, src io.Reader, total int64, filename string) (int64, error) {
	progressWriter := NewProgressWriter(dst, total, filename)
	// Use a pooled buffer for efficient copying
	bufPtr := bufferPool.Get().(*[]byte)
	defer bufferPool.Put(bufPtr)
	buf := *bufPtr
	return io.CopyBuffer(progressWriter, src, buf)
}

File diff suppressed because it is too large Load Diff