// main.go package main import ( "bufio" "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "flag" "fmt" "io" "mime" "net" "net/http" "net/url" "os" "os/signal" "path/filepath" "runtime" "strconv" "strings" "syscall" "time" "sync" "github.com/dutchcoders/go-clamd" // ClamAV integration "github.com/go-redis/redis/v8" // Redis integration "github.com/patrickmn/go-cache" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/shirou/gopsutil/cpu" "github.com/shirou/gopsutil/disk" "github.com/shirou/gopsutil/host" "github.com/shirou/gopsutil/mem" "github.com/sirupsen/logrus" "github.com/spf13/viper" ) // Configuration structure type ServerConfig struct { ListenPort string `mapstructure:"ListenPort"` UnixSocket bool `mapstructure:"UnixSocket"` StoragePath string `mapstructure:"StoragePath"` LogLevel string `mapstructure:"LogLevel"` LogFile string `mapstructure:"LogFile"` MetricsEnabled bool `mapstructure:"MetricsEnabled"` MetricsPort string `mapstructure:"MetricsPort"` FileTTL string `mapstructure:"FileTTL"` MinFreeBytes int64 `mapstructure:"MinFreeBytes"` // Minimum free bytes required DeduplicationEnabled bool `mapstructure:"DeduplicationEnabled"` } type TimeoutConfig struct { ReadTimeout string `mapstructure:"ReadTimeout"` WriteTimeout string `mapstructure:"WriteTimeout"` IdleTimeout string `mapstructure:"IdleTimeout"` } type SecurityConfig struct { Secret string `mapstructure:"Secret"` } type VersioningConfig struct { EnableVersioning bool `mapstructure:"EnableVersioning"` MaxVersions int `mapstructure:"MaxVersions"` } type UploadsConfig struct { ResumableUploadsEnabled bool `mapstructure:"ResumableUploadsEnabled"` ChunkedUploadsEnabled bool `mapstructure:"ChunkedUploadsEnabled"` ChunkSize int64 `mapstructure:"ChunkSize"` AllowedExtensions []string `mapstructure:"AllowedExtensions"` } type ClamAVConfig struct { ClamAVEnabled bool `mapstructure:"ClamAVEnabled"` ClamAVSocket string 
`mapstructure:"ClamAVSocket"` NumScanWorkers int `mapstructure:"NumScanWorkers"` } type RedisConfig struct { RedisEnabled bool `mapstructure:"RedisEnabled"` RedisDBIndex int `mapstructure:"RedisDBIndex"` RedisAddr string `mapstructure:"RedisAddr"` RedisPassword string `mapstructure:"RedisPassword"` RedisHealthCheckInterval string `mapstructure:"RedisHealthCheckInterval"` } type WorkersConfig struct { NumWorkers int `mapstructure:"NumWorkers"` UploadQueueSize int `mapstructure:"UploadQueueSize"` } type FileConfig struct { FileRevision int `mapstructure:"FileRevision"` } type Config struct { Server ServerConfig `mapstructure:"server"` Timeouts TimeoutConfig `mapstructure:"timeouts"` Security SecurityConfig `mapstructure:"security"` Versioning VersioningConfig `mapstructure:"versioning"` Uploads UploadsConfig `mapstructure:"uploads"` ClamAV ClamAVConfig `mapstructure:"clamav"` Redis RedisConfig `mapstructure:"redis"` Workers WorkersConfig `mapstructure:"workers"` File FileConfig `mapstructure:"file"` } // UploadTask represents a file upload task type UploadTask struct { AbsFilename string Request *http.Request Result chan error } // ScanTask represents a file scan task type ScanTask struct { AbsFilename string Result chan error } // NetworkEvent represents a network-related event type NetworkEvent struct { Type string Details string } var ( conf Config versionString string = "v2.0-dev" log = logrus.New() uploadQueue chan UploadTask networkEvents chan NetworkEvent fileInfoCache *cache.Cache clamClient *clamd.Clamd // Added for ClamAV integration redisClient *redis.Client // Redis client redisConnected bool // Redis connection status mu sync.RWMutex // Prometheus metrics uploadDuration prometheus.Histogram uploadErrorsTotal prometheus.Counter uploadsTotal prometheus.Counter downloadDuration prometheus.Histogram downloadsTotal prometheus.Counter downloadErrorsTotal prometheus.Counter memoryUsage prometheus.Gauge cpuUsage prometheus.Gauge activeConnections 
prometheus.Gauge requestsTotal *prometheus.CounterVec goroutines prometheus.Gauge uploadSizeBytes prometheus.Histogram downloadSizeBytes prometheus.Histogram // Constants for worker pool MinWorkers = 5 // Increased from 10 to 20 for better concurrency UploadQueueSize = 10000 // Increased from 5000 to 10000 // Channels scanQueue chan ScanTask ScanWorkers = 5 // Number of ClamAV scan workers ) func main() { // Set default configuration values setDefaults() // Flags for configuration file var configFile string flag.StringVar(&configFile, "config", "./config.toml", "Path to configuration file \"config.toml\".") flag.Parse() // Load configuration err := readConfig(configFile, &conf) if err != nil { log.Fatalf("Error reading config: %v", err) // Fatal: application cannot proceed } log.Info("Configuration loaded successfully.") // Initialize file info cache fileInfoCache = cache.New(5*time.Minute, 10*time.Minute) // Create store directory err = os.MkdirAll(conf.Server.StoragePath, os.ModePerm) if err != nil { log.Fatalf("Error creating store directory: %v", err) } log.WithField("directory", conf.Server.StoragePath).Info("Store directory is ready") // Setup logging setupLogging() // Log system information logSystemInfo() // Initialize Prometheus metrics initMetrics() log.Info("Prometheus metrics initialized.") // Initialize upload and scan queues uploadQueue = make(chan UploadTask, conf.Workers.UploadQueueSize) scanQueue = make(chan ScanTask, conf.Workers.UploadQueueSize) networkEvents = make(chan NetworkEvent, 100) log.Info("Upload, scan, and network event channels initialized.") // Context for goroutines ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Start network monitoring go monitorNetwork(ctx) go handleNetworkEvents(ctx) // Update system metrics go updateSystemMetrics(ctx) // Initialize ClamAV client if enabled if conf.ClamAV.ClamAVEnabled { clamClient, err = initClamAV(conf.ClamAV.ClamAVSocket) if err != nil { log.WithFields(logrus.Fields{ 
"error": err.Error(), }).Warn("ClamAV client initialization failed. Continuing without ClamAV.") } else { log.Info("ClamAV client initialized successfully.") } } // Initialize Redis client if enabled if conf.Redis.RedisEnabled { initRedis() } // Redis Initialization initRedis() log.Info("Redis client initialized and connected successfully.") // ClamAV Initialization if conf.ClamAV.ClamAVEnabled { clamClient, err = initClamAV(conf.ClamAV.ClamAVSocket) if err != nil { log.WithFields(logrus.Fields{ "error": err.Error(), }).Warn("ClamAV client initialization failed. Continuing without ClamAV.") } else { log.Info("ClamAV client initialized successfully.") } } // Initialize worker pools initializeUploadWorkerPool(ctx) if conf.ClamAV.ClamAVEnabled && clamClient != nil { initializeScanWorkerPool(ctx) } // Start Redis health monitor if Redis is enabled if conf.Redis.RedisEnabled && redisClient != nil { go MonitorRedisHealth(ctx, redisClient, parseDuration(conf.Redis.RedisHealthCheckInterval)) } // Setup router router := setupRouter() // Start file cleaner fileTTL, err := time.ParseDuration(conf.Server.FileTTL) if err != nil { log.Fatalf("Invalid FileTTL: %v", err) } go runFileCleaner(ctx, conf.Server.StoragePath, fileTTL) // Parse timeout durations readTimeout, err := time.ParseDuration(conf.Timeouts.ReadTimeout) if err != nil { log.Fatalf("Invalid ReadTimeout: %v", err) } writeTimeout, err := time.ParseDuration(conf.Timeouts.WriteTimeout) if err != nil { log.Fatalf("Invalid WriteTimeout: %v", err) } idleTimeout, err := time.ParseDuration(conf.Timeouts.IdleTimeout) if err != nil { log.Fatalf("Invalid IdleTimeout: %v", err) } // Configure HTTP server server := &http.Server{ Addr: ":" + conf.Server.ListenPort, // Prepend colon to ListenPort Handler: router, ReadTimeout: readTimeout, WriteTimeout: writeTimeout, IdleTimeout: idleTimeout, } // Start metrics server if enabled if conf.Server.MetricsEnabled { go func() { http.Handle("/metrics", promhttp.Handler()) 
log.Infof("Metrics server started on port %s", conf.Server.MetricsPort) if err := http.ListenAndServe(":"+conf.Server.MetricsPort, nil); err != nil { log.Fatalf("Metrics server failed: %v", err) } }() } // Setup graceful shutdown setupGracefulShutdown(server, cancel) // Start server log.Infof("Starting HMAC file server %s...", versionString) if conf.Server.UnixSocket { // Listen on Unix socket if err := os.RemoveAll(conf.Server.ListenPort); err != nil { log.Fatalf("Failed to remove existing Unix socket: %v", err) } listener, err := net.Listen("unix", conf.Server.ListenPort) if err != nil { log.Fatalf("Failed to listen on Unix socket %s: %v", conf.Server.ListenPort, err) } defer listener.Close() if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { log.Fatalf("Server failed: %v", err) } } else { // Listen on TCP port if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("Server failed: %v", err) } } } // Function to load configuration using Viper func readConfig(configFilename string, conf *Config) error { viper.SetConfigFile(configFilename) viper.SetConfigType("toml") // Read in environment variables that match viper.AutomaticEnv() viper.SetEnvPrefix("HMAC") // Prefix for environment variables // Read the config file if err := viper.ReadInConfig(); err != nil { return fmt.Errorf("error reading config file: %w", err) } // Unmarshal the config into the Config struct if err := viper.Unmarshal(conf); err != nil { return fmt.Errorf("unable to decode into struct: %w", err) } // Debug log the loaded configuration log.Debugf("Loaded Configuration: %+v", conf.Server) // Validate the configuration if err := validateConfig(conf); err != nil { return fmt.Errorf("configuration validation failed: %w", err) } // Set Deduplication Enabled conf.Server.DeduplicationEnabled = viper.GetBool("deduplication.Enabled") return nil } // Set default configuration values func setDefaults() { // Server defaults 
viper.SetDefault("server.ListenPort", "8080") viper.SetDefault("server.UnixSocket", false) viper.SetDefault("server.StoragePath", "./uploads") viper.SetDefault("server.LogLevel", "info") viper.SetDefault("server.LogFile", "") viper.SetDefault("server.MetricsEnabled", true) viper.SetDefault("server.MetricsPort", "9090") viper.SetDefault("server.FileTTL", "8760h") // 365d -> 8760h viper.SetDefault("server.MinFreeBytes", 100<<20) // 100 MB // Timeout defaults viper.SetDefault("timeouts.ReadTimeout", "4800s") // supports 's' viper.SetDefault("timeouts.WriteTimeout", "4800s") viper.SetDefault("timeouts.IdleTimeout", "4800s") // Security defaults viper.SetDefault("security.Secret", "changeme") // Versioning defaults viper.SetDefault("versioning.EnableVersioning", false) viper.SetDefault("versioning.MaxVersions", 1) // Uploads defaults viper.SetDefault("uploads.ResumableUploadsEnabled", true) viper.SetDefault("uploads.ChunkedUploadsEnabled", true) viper.SetDefault("uploads.ChunkSize", 8192) viper.SetDefault("uploads.AllowedExtensions", []string{ ".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp", ".wav", ".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".mpg", ".m4v", ".3gp", ".3g2", ".mp3", ".ogg", }) // ClamAV defaults viper.SetDefault("clamav.ClamAVEnabled", true) viper.SetDefault("clamav.ClamAVSocket", "/var/run/clamav/clamd.ctl") viper.SetDefault("clamav.NumScanWorkers", 2) // Redis defaults viper.SetDefault("redis.RedisEnabled", true) viper.SetDefault("redis.RedisAddr", "localhost:6379") viper.SetDefault("redis.RedisPassword", "") viper.SetDefault("redis.RedisDBIndex", 0) viper.SetDefault("redis.RedisHealthCheckInterval", "120s") // Workers defaults viper.SetDefault("workers.NumWorkers", 2) viper.SetDefault("workers.UploadQueueSize", 50) // Deduplication defaults viper.SetDefault("deduplication.Enabled", true) } // Validate configuration fields func validateConfig(conf *Config) error { if conf.Server.ListenPort == "" { 
return fmt.Errorf("ListenPort must be set") } if conf.Security.Secret == "" { return fmt.Errorf("secret must be set") } if conf.Server.StoragePath == "" { return fmt.Errorf("StoragePath must be set") } if conf.Server.FileTTL == "" { return fmt.Errorf("FileTTL must be set") } // Validate timeouts if _, err := time.ParseDuration(conf.Timeouts.ReadTimeout); err != nil { return fmt.Errorf("invalid ReadTimeout: %v", err) } if _, err := time.ParseDuration(conf.Timeouts.WriteTimeout); err != nil { return fmt.Errorf("invalid WriteTimeout: %v", err) } if _, err := time.ParseDuration(conf.Timeouts.IdleTimeout); err != nil { return fmt.Errorf("invalid IdleTimeout: %v", err) } // Validate Redis configuration if enabled if conf.Redis.RedisEnabled { if conf.Redis.RedisAddr == "" { return fmt.Errorf("RedisAddr must be set when Redis is enabled") } } // Add more validations as needed return nil } // Setup logging func setupLogging() { level, err := logrus.ParseLevel(conf.Server.LogLevel) if err != nil { log.Fatalf("Invalid log level: %s", conf.Server.LogLevel) } log.SetLevel(level) if conf.Server.LogFile != "" { logFile, err := os.OpenFile(conf.Server.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) if err != nil { log.Fatalf("Failed to open log file: %v", err) } log.SetOutput(io.MultiWriter(os.Stdout, logFile)) } else { log.SetOutput(os.Stdout) } // Use Text formatter for human-readable logs log.SetFormatter(&logrus.TextFormatter{ FullTimestamp: true, // You can customize the format further if needed }) } // Log system information func logSystemInfo() { log.Info("========================================") log.Infof(" HMAC File Server - %s ", versionString) log.Info(" Secure File Handling with HMAC Auth ") log.Info("========================================") log.Info("Features: Prometheus Metrics, Chunked Uploads, ClamAV Scanning") log.Info("Build Date: 2024-10-28") log.Infof("Operating System: %s", runtime.GOOS) log.Infof("Architecture: %s", runtime.GOARCH) log.Infof("Number 
of CPUs: %d", runtime.NumCPU()) log.Infof("Go Version: %s", runtime.Version()) v, _ := mem.VirtualMemory() log.Infof("Total Memory: %v MB", v.Total/1024/1024) log.Infof("Free Memory: %v MB", v.Free/1024/1024) log.Infof("Used Memory: %v MB", v.Used/1024/1024) cpuInfo, _ := cpu.Info() for _, info := range cpuInfo { log.Infof("CPU Model: %s, Cores: %d, Mhz: %f", info.ModelName, info.Cores, info.Mhz) } partitions, _ := disk.Partitions(false) for _, partition := range partitions { usage, _ := disk.Usage(partition.Mountpoint) log.Infof("Disk Mountpoint: %s, Total: %v GB, Free: %v GB, Used: %v GB", partition.Mountpoint, usage.Total/1024/1024/1024, usage.Free/1024/1024/1024, usage.Used/1024/1024/1024) } hInfo, _ := host.Info() log.Infof("Hostname: %s", hInfo.Hostname) log.Infof("Uptime: %v seconds", hInfo.Uptime) log.Infof("Boot Time: %v", time.Unix(int64(hInfo.BootTime), 0)) log.Infof("Platform: %s", hInfo.Platform) log.Infof("Platform Family: %s", hInfo.PlatformFamily) log.Infof("Platform Version: %s", hInfo.PlatformVersion) log.Infof("Kernel Version: %s", hInfo.KernelVersion) } // Initialize Prometheus metrics // Duplicate initMetrics function removed func initMetrics() { uploadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{Namespace: "hmac", Name: "file_server_upload_duration_seconds", Help: "Histogram of file upload duration in seconds.", Buckets: prometheus.DefBuckets}) uploadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_upload_errors_total", Help: "Total number of file upload errors."}) uploadsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_uploads_total", Help: "Total number of successful file uploads."}) downloadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{Namespace: "hmac", Name: "file_server_download_duration_seconds", Help: "Histogram of file download duration in seconds.", Buckets: prometheus.DefBuckets}) downloadsTotal = 
prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_downloads_total", Help: "Total number of successful file downloads."}) downloadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_download_errors_total", Help: "Total number of file download errors."}) memoryUsage = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "memory_usage_bytes", Help: "Current memory usage in bytes."}) cpuUsage = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "cpu_usage_percent", Help: "Current CPU usage as a percentage."}) activeConnections = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "active_connections_total", Help: "Total number of active connections."}) requestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{Namespace: "hmac", Name: "http_requests_total", Help: "Total number of HTTP requests received, labeled by method and path."}, []string{"method", "path"}) goroutines = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "goroutines_count", Help: "Current number of goroutines."}) uploadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{ Namespace: "hmac", Name: "file_server_upload_size_bytes", Help: "Histogram of uploaded file sizes in bytes.", Buckets: prometheus.ExponentialBuckets(100, 10, 8), }) downloadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{ Namespace: "hmac", Name: "file_server_download_size_bytes", Help: "Histogram of downloaded file sizes in bytes.", Buckets: prometheus.ExponentialBuckets(100, 10, 8), }) if conf.Server.MetricsEnabled { prometheus.MustRegister(uploadDuration, uploadErrorsTotal, uploadsTotal) prometheus.MustRegister(downloadDuration, downloadsTotal, downloadErrorsTotal) prometheus.MustRegister(memoryUsage, cpuUsage, activeConnections, requestsTotal, goroutines) prometheus.MustRegister(uploadSizeBytes, downloadSizeBytes) } } // Update system metrics func 
updateSystemMetrics(ctx context.Context) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() for { select { case <-ctx.Done(): log.Info("Stopping system metrics updater.") return case <-ticker.C: v, _ := mem.VirtualMemory() memoryUsage.Set(float64(v.Used)) cpuPercent, _ := cpu.Percent(0, false) if len(cpuPercent) > 0 { cpuUsage.Set(cpuPercent[0]) } goroutines.Set(float64(runtime.NumGoroutine())) } } } // Function to check if a file exists and return its size func fileExists(filePath string) (bool, int64) { if cachedInfo, found := fileInfoCache.Get(filePath); found { if info, ok := cachedInfo.(os.FileInfo); ok { return !info.IsDir(), info.Size() } } fileInfo, err := os.Stat(filePath) if os.IsNotExist(err) { return false, 0 } else if err != nil { log.Error("Error checking file existence:", err) return false, 0 } fileInfoCache.Set(filePath, fileInfo, cache.DefaultExpiration) return !fileInfo.IsDir(), fileInfo.Size() } // Function to check file extension func isExtensionAllowed(filename string) bool { if len(conf.Uploads.AllowedExtensions) == 0 { return true // No restrictions if the list is empty } ext := strings.ToLower(filepath.Ext(filename)) for _, allowedExt := range conf.Uploads.AllowedExtensions { if strings.ToLower(allowedExt) == ext { return true } } return false } // Version the file by moving the existing file to a versioned directory func versionFile(absFilename string) error { versionDir := absFilename + "_versions" err := os.MkdirAll(versionDir, os.ModePerm) if err != nil { return fmt.Errorf("failed to create version directory: %v", err) } timestamp := time.Now().Format("20060102-150405") versionedFilename := filepath.Join(versionDir, filepath.Base(absFilename)+"."+timestamp) err = os.Rename(absFilename, versionedFilename) if err != nil { return fmt.Errorf("failed to version the file: %v", err) } log.WithFields(logrus.Fields{ "original": absFilename, "versioned_as": versionedFilename, }).Info("Versioned old file") return 
cleanupOldVersions(versionDir) } // Clean up older versions if they exceed the maximum allowed func cleanupOldVersions(versionDir string) error { files, err := os.ReadDir(versionDir) if err != nil { return fmt.Errorf("failed to list version files: %v", err) } if conf.Versioning.MaxVersions > 0 && len(files) > conf.Versioning.MaxVersions { excessFiles := len(files) - conf.Versioning.MaxVersions for i := 0; i < excessFiles; i++ { err := os.Remove(filepath.Join(versionDir, files[i].Name())) if err != nil { return fmt.Errorf("failed to remove old version: %v", err) } log.WithField("file", files[i].Name()).Info("Removed old version") } } return nil } // Process the upload task func processUpload(task UploadTask) error { absFilename := task.AbsFilename tempFilename := absFilename + ".tmp" r := task.Request log.Infof("Processing upload for file: %s", absFilename) startTime := time.Now() // Handle uploads and write to a temporary file if conf.Uploads.ChunkedUploadsEnabled { log.Debugf("Chunked uploads enabled. 
Handling chunked upload for %s", tempFilename) err := handleChunkedUpload(tempFilename, r) if err != nil { uploadDuration.Observe(time.Since(startTime).Seconds()) log.WithFields(logrus.Fields{ "file": tempFilename, "error": err, }).Error("Failed to handle chunked upload") return err } } else { log.Debugf("Handling standard upload for %s", tempFilename) err := createFile(tempFilename, r) if err != nil { log.WithFields(logrus.Fields{ "file": tempFilename, "error": err, }).Error("Error creating file") uploadDuration.Observe(time.Since(startTime).Seconds()) return err } } // Perform ClamAV scan on the temporary file if clamClient != nil { log.Debugf("Scanning %s with ClamAV", tempFilename) err := scanFileWithClamAV(tempFilename) if err != nil { log.WithFields(logrus.Fields{ "file": tempFilename, "error": err, }).Warn("ClamAV detected a virus or scan failed") os.Remove(tempFilename) uploadErrorsTotal.Inc() return err } log.Infof("ClamAV scan passed for file: %s", tempFilename) } // Handle file versioning if enabled if conf.Versioning.EnableVersioning { existing, _ := fileExists(absFilename) if existing { log.Infof("File %s exists. Initiating versioning.", absFilename) err := versionFile(absFilename) if err != nil { log.WithFields(logrus.Fields{ "file": absFilename, "error": err, }).Error("Error versioning file") os.Remove(tempFilename) return err } log.Infof("File versioned successfully: %s", absFilename) } } // Rename temporary file to final destination err := os.Rename(tempFilename, absFilename) if err != nil { log.WithFields(logrus.Fields{ "temp_file": tempFilename, "final_file": absFilename, "error": err, }).Error("Failed to move file to final destination") os.Remove(tempFilename) return err } log.Infof("File moved to final destination: %s", absFilename) // Handle deduplication if enabled if conf.Server.DeduplicationEnabled { log.Debugf("Deduplication enabled. 
Checking duplicates for %s", absFilename) err = handleDeduplication(context.Background(), absFilename) if err != nil { log.WithError(err).Error("Deduplication failed") uploadErrorsTotal.Inc() return err } log.Infof("Deduplication handled successfully for file: %s", absFilename) } log.WithFields(logrus.Fields{ "file": absFilename, }).Info("File uploaded and processed successfully") uploadDuration.Observe(time.Since(startTime).Seconds()) uploadsTotal.Inc() return nil } // uploadWorker processes upload tasks from the uploadQueue func uploadWorker(ctx context.Context, workerID int) { log.Infof("Upload worker %d started.", workerID) defer log.Infof("Upload worker %d stopped.", workerID) for { select { case <-ctx.Done(): return case task, ok := <-uploadQueue: if !ok { log.Warnf("Upload queue closed. Worker %d exiting.", workerID) return } log.Infof("Worker %d processing upload for file: %s", workerID, task.AbsFilename) err := processUpload(task) if err != nil { log.Errorf("Worker %d failed to process upload for %s: %v", workerID, task.AbsFilename, err) uploadErrorsTotal.Inc() } else { log.Infof("Worker %d successfully processed upload for %s", workerID, task.AbsFilename) } task.Result <- err close(task.Result) } } } // Initialize upload worker pool func initializeUploadWorkerPool(ctx context.Context) { for i := 0; i < MinWorkers; i++ { go uploadWorker(ctx, i) } log.Infof("Initialized %d upload workers", MinWorkers) } // Worker function to process scan tasks func scanWorker(ctx context.Context, workerID int) { log.WithField("worker_id", workerID).Info("Scan worker started") for { select { case <-ctx.Done(): log.WithField("worker_id", workerID).Info("Scan worker stopping") return case task, ok := <-scanQueue: if !ok { log.WithField("worker_id", workerID).Info("Scan queue closed") return } log.WithFields(logrus.Fields{ "worker_id": workerID, "file": task.AbsFilename, }).Info("Processing scan task") err := scanFileWithClamAV(task.AbsFilename) if err != nil { 
log.WithFields(logrus.Fields{ "worker_id": workerID, "file": task.AbsFilename, "error": err, }).Error("Failed to scan file") } else { log.WithFields(logrus.Fields{ "worker_id": workerID, "file": task.AbsFilename, }).Info("Successfully scanned file") } task.Result <- err close(task.Result) } } } // Initialize scan worker pool func initializeScanWorkerPool(ctx context.Context) { for i := 0; i < ScanWorkers; i++ { go scanWorker(ctx, i) } log.Infof("Initialized %d scan workers", ScanWorkers) } // Setup router with middleware func setupRouter() http.Handler { mux := http.NewServeMux() mux.HandleFunc("/", handleRequest) if conf.Server.MetricsEnabled { mux.Handle("/metrics", promhttp.Handler()) } // Apply middleware handler := loggingMiddleware(mux) handler = recoveryMiddleware(handler) handler = corsMiddleware(handler) return handler } // Middleware for logging func loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestsTotal.WithLabelValues(r.Method, r.URL.Path).Inc() next.ServeHTTP(w, r) }) } // Middleware for panic recovery func recoveryMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if rec := recover(); rec != nil { log.WithFields(logrus.Fields{ "method": r.Method, "url": r.URL.String(), "error": rec, }).Error("Panic recovered in HTTP handler") http.Error(w, "Internal Server Error", http.StatusInternalServerError) } }() next.ServeHTTP(w, r) }) } // corsMiddleware handles CORS by setting appropriate headers func corsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Set CORS headers w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-File-MAC") w.Header().Set("Access-Control-Max-Age", 
"86400") // Cache preflight response for 1 day // Handle preflight OPTIONS request if r.Method == http.MethodOptions { w.WriteHeader(http.StatusOK) return } // Proceed to the next handler next.ServeHTTP(w, r) }) } // Handle file uploads and downloads func handleRequest(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost && strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") { absFilename, err := sanitizeFilePath(conf.Server.StoragePath, strings.TrimPrefix(r.URL.Path, "/")) if err != nil { log.WithError(err).Error("Invalid file path") http.Error(w, "Invalid file path", http.StatusBadRequest) return } err = handleMultipartUpload(w, r, absFilename) if err != nil { log.WithError(err).Error("Failed to handle multipart upload") http.Error(w, "Failed to handle multipart upload", http.StatusInternalServerError) return } w.WriteHeader(http.StatusCreated) return } // Get client IP address clientIP := r.Header.Get("X-Real-IP") if clientIP == "" { clientIP = r.Header.Get("X-Forwarded-For") } if clientIP == "" { // Fallback to RemoteAddr host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { log.WithError(err).Warn("Failed to parse RemoteAddr") clientIP = r.RemoteAddr } else { clientIP = host } } // Log the request with the client IP log.WithFields(logrus.Fields{ "method": r.Method, "url": r.URL.String(), "remote": clientIP, }).Info("Incoming request") // Parse URL and query parameters p := r.URL.Path a, err := url.ParseQuery(r.URL.RawQuery) if err != nil { log.Warn("Failed to parse query parameters") http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } fileStorePath := strings.TrimPrefix(p, "/") if fileStorePath == "" || fileStorePath == "/" { log.Warn("Access to root directory is forbidden") http.Error(w, "Forbidden", http.StatusForbidden) return } else if fileStorePath[0] == '/' { fileStorePath = fileStorePath[1:] } absFilename, err := sanitizeFilePath(conf.Server.StoragePath, fileStorePath) if err != 
nil {
		log.WithFields(logrus.Fields{
			"file":  fileStorePath,
			"error": err,
		}).Warn("Invalid file path")
		http.Error(w, "Invalid file path", http.StatusBadRequest)
		return
	}
	switch r.Method {
	case http.MethodPut:
		handleUpload(w, r, absFilename, fileStorePath, a)
	case http.MethodHead, http.MethodGet:
		handleDownload(w, r, absFilename, fileStorePath)
	case http.MethodOptions:
		// Handled by NGINX; no action needed
		w.Header().Set("Allow", "OPTIONS, GET, PUT, HEAD")
		return
	default:
		log.WithField("method", r.Method).Warn("Invalid HTTP method for upload directory")
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}
}

// handleUpload authenticates an upload via HMAC (XEP-0363-style "v", "v2" or
// "token" query parameter), validates the file extension and free disk space,
// then hands the request to the worker pool via uploadQueue and waits for the
// result. The HTTP status reported to the client is:
//   201 on success, 403 on auth/extension failure, 507 on low disk space,
//   503 when the queue is full, 500 when the worker fails.
func handleUpload(w http.ResponseWriter, r *http.Request, absFilename, fileStorePath string, a url.Values) {
	// Log the storage path being used
	log.Infof("Using storage path: %s", conf.Server.StoragePath)

	// Determine protocol version based on query parameters.
	var protocolVersion string
	if a.Get("v2") != "" {
		protocolVersion = "v2"
	} else if a.Get("token") != "" {
		protocolVersion = "token"
	} else if a.Get("v") != "" {
		protocolVersion = "v"
	} else {
		log.Warn("No HMAC attached to URL. Expecting 'v', 'v2', or 'token' parameter as MAC")
		http.Error(w, "No HMAC attached to URL. Expecting 'v', 'v2', or 'token' parameter as MAC", http.StatusForbidden)
		return
	}
	log.Debugf("Protocol version determined: %s", protocolVersion)

	// Initialize HMAC keyed with the shared secret.
	mac := hmac.New(sha256.New, []byte(conf.Security.Secret))

	// Calculate MAC based on protocolVersion. "v" signs path + space + length;
	// "v2"/"token" sign path, length and content type, NUL-separated.
	if protocolVersion == "v" {
		mac.Write([]byte(fileStorePath + "\x20" + strconv.FormatInt(r.ContentLength, 10)))
	} else if protocolVersion == "v2" || protocolVersion == "token" {
		contentType := mime.TypeByExtension(filepath.Ext(fileStorePath))
		if contentType == "" {
			contentType = "application/octet-stream"
		}
		mac.Write([]byte(fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) + "\x00" + contentType))
	}
	calculatedMAC := mac.Sum(nil)
	log.Debugf("Calculated MAC: %x", calculatedMAC)

	// Decode provided MAC from hex.
	providedMACHex := a.Get(protocolVersion)
	providedMAC, err := hex.DecodeString(providedMACHex)
	if err != nil {
		log.Warn("Invalid MAC encoding")
		http.Error(w, "Invalid MAC encoding", http.StatusForbidden)
		return
	}
	log.Debugf("Provided MAC: %x", providedMAC)

	// Constant-time comparison of the HMACs.
	if !hmac.Equal(calculatedMAC, providedMAC) {
		log.Warn("Invalid MAC")
		http.Error(w, "Invalid MAC", http.StatusForbidden)
		return
	}
	log.Debug("HMAC validation successful")

	// Validate file extension. absFilename was already sanitized in
	// handleRequest, so only the extension needs checking here.
	// FIX: the original branch was a copy-paste of the path-validation branch —
	// it logged a stale (nil) err under "Invalid file path" and answered 400.
	// Report the real reason and use 403, consistent with handleMultipartUpload.
	if !isExtensionAllowed(fileStorePath) {
		log.WithFields(logrus.Fields{
			"file":      fileStorePath,
			"extension": filepath.Ext(fileStorePath),
		}).Warn("Disallowed file extension")
		http.Error(w, "Disallowed file extension", http.StatusForbidden)
		uploadErrorsTotal.Inc()
		return
	}

	// Check if there is enough free space.
	err = checkStorageSpace(conf.Server.StoragePath, conf.Server.MinFreeBytes)
	if err != nil {
		log.WithFields(logrus.Fields{
			"storage_path": conf.Server.StoragePath,
			"error":        err,
		}).Warn("Not enough free space")
		http.Error(w, "Not enough free space", http.StatusInsufficientStorage)
		uploadErrorsTotal.Inc()
		return
	}

	// Create an UploadTask with a result channel.
	result := make(chan error)
	task := UploadTask{
		AbsFilename: absFilename,
		Request:     r,
		Result:      result,
	}

	// Submit task to the upload queue; reject immediately if it is full.
	select {
	case uploadQueue <- task:
		log.Debug("Upload task enqueued successfully")
	default:
		log.Warn("Upload queue is full. Rejecting upload")
		http.Error(w, "Server busy. Try again later.", http.StatusServiceUnavailable)
		uploadErrorsTotal.Inc()
		return
	}

	// Wait for the worker to process the upload.
	err = <-result
	if err != nil {
		// The worker has already logged the error; send an appropriate HTTP response.
		http.Error(w, fmt.Sprintf("Upload failed: %v", err), http.StatusInternalServerError)
		return
	}

	// Upload was successful.
	w.WriteHeader(http.StatusCreated)
}

// handleDownload serves GET/HEAD requests for a stored file. Directories are
// forbidden; when resumable downloads are enabled all requests are delegated
// to handleResumableDownload (which falls back to ServeFile without a Range
// header).
func handleDownload(w http.ResponseWriter, r *http.Request, absFilename, fileStorePath string) {
	fileInfo, err := getFileInfo(absFilename)
	if err != nil {
		log.WithError(err).Error("Failed to get file information")
		http.Error(w, "Not Found", http.StatusNotFound)
		downloadErrorsTotal.Inc()
		return
	}
	if fileInfo.IsDir() {
		log.Warn("Directory listing forbidden")
		http.Error(w, "Forbidden", http.StatusForbidden)
		downloadErrorsTotal.Inc()
		return
	}
	contentType := mime.TypeByExtension(filepath.Ext(fileStorePath))
	if contentType == "" {
		contentType = "application/octet-stream"
	}
	w.Header().Set("Content-Type", contentType)

	// Handle resumable downloads.
	if conf.Uploads.ResumableUploadsEnabled {
		handleResumableDownload(absFilename, w, r, fileInfo.Size())
		return
	}
	if r.Method == http.MethodHead {
		w.Header().Set("Content-Length", strconv.FormatInt(fileInfo.Size(), 10))
		downloadsTotal.Inc()
		return
	}
	// Measure download duration.
	startTime := time.Now()
	log.Infof("Initiating download for file: %s", absFilename)
	http.ServeFile(w, r, absFilename)
	downloadDuration.Observe(time.Since(startTime).Seconds())
	downloadSizeBytes.Observe(float64(fileInfo.Size()))
	downloadsTotal.Inc()
	log.Infof("File downloaded successfully: %s", absFilename)
}

// createFile streams the request body into tempFilename through a 4 MB
// buffered writer, creating parent directories as needed. Returns a wrapped
// error on any directory/file/IO failure; records the byte count in the
// uploadSizeBytes histogram on success.
func createFile(tempFilename string, r *http.Request) error {
	absDirectory := filepath.Dir(tempFilename)
	err := os.MkdirAll(absDirectory, os.ModePerm)
	if err != nil {
		log.WithError(err).Errorf("Failed to create directory %s", absDirectory)
		return fmt.Errorf("failed to create directory %s: %w", absDirectory, err)
	}

	// Open the file for writing (truncating any previous content).
	targetFile, err := os.OpenFile(tempFilename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		log.WithError(err).Errorf("Failed to create file %s", tempFilename)
		return fmt.Errorf("failed to create file %s: %w", tempFilename, err)
	}
	defer targetFile.Close()

	// Use a large buffer for efficient file writing.
	bufferSize := 4 * 1024 * 1024 // 4 MB buffer
	writer := bufio.NewWriterSize(targetFile, bufferSize)
	buffer := make([]byte, bufferSize)

	totalBytes := int64(0)
	for {
		n, readErr := r.Body.Read(buffer)
		if n > 0 {
			totalBytes += int64(n)
			_, writeErr := writer.Write(buffer[:n])
			if writeErr != nil {
				log.WithError(writeErr).Errorf("Failed to write to file %s", tempFilename)
				return fmt.Errorf("failed to write to file %s: %w", tempFilename, writeErr)
			}
		}
		if readErr != nil {
			if readErr == io.EOF {
				break
			}
			log.WithError(readErr).Error("Failed to read request body")
			return fmt.Errorf("failed to read request body: %w", readErr)
		}
	}

	err = writer.Flush()
	if err != nil {
		log.WithError(err).Errorf("Failed to flush buffer to file %s", tempFilename)
		return fmt.Errorf("failed to flush buffer to file %s: %w", tempFilename, err)
	}

	log.WithFields(logrus.Fields{
		"temp_file":   tempFilename,
		"total_bytes": totalBytes,
	}).Info("File uploaded successfully")
	uploadSizeBytes.Observe(float64(totalBytes))
	return nil
}

// scanFileWithClamAV asks the ClamAV daemon (clamClient must be non-nil) to
// scan filePath. Returns nil on a clean result, an error describing the virus
// on RES_FOUND, and a generic error on any other status or transport failure.
func scanFileWithClamAV(filePath string) error {
	log.WithField("file", filePath).Info("Scanning file with ClamAV")

	scanResultChan, err := clamClient.ScanFile(filePath)
	if err != nil {
		log.WithError(err).Error("Failed to initiate ClamAV scan")
		return fmt.Errorf("failed to initiate ClamAV scan: %w", err)
	}

	// Receive scan result (nil means the daemon closed the channel early).
	scanResult := <-scanResultChan
	if scanResult == nil {
		log.Error("Failed to receive scan result from ClamAV")
		return fmt.Errorf("failed to receive scan result from ClamAV")
	}

	// Handle scan result.
	switch scanResult.Status {
	case clamd.RES_OK:
		log.WithField("file", filePath).Info("ClamAV scan passed")
		return nil
	case clamd.RES_FOUND:
		log.WithFields(logrus.Fields{
			"file":        filePath,
			"description": scanResult.Description,
		}).Warn("ClamAV detected a virus")
		return fmt.Errorf("virus detected: %s", scanResult.Description)
	default:
		log.WithFields(logrus.Fields{
			"file":        filePath,
			"status":      scanResult.Status,
			"description": scanResult.Description,
		}).Warn("ClamAV scan returned unexpected status")
		return fmt.Errorf("ClamAV scan returned unexpected status: %s", scanResult.Description)
	}
}

// initClamAV connects to the ClamAV daemon over the given unix socket and
// verifies the connection with a ping. The local variable is deliberately not
// named clamClient to avoid shadowing the package-level global of that name.
func initClamAV(socket string) (*clamd.Clamd, error) {
	if socket == "" {
		log.Error("ClamAV socket path is not configured.")
		return nil, fmt.Errorf("ClamAV socket path is not configured")
	}
	client := clamd.NewClamd("unix:" + socket)
	err := client.Ping()
	if err != nil {
		log.Errorf("Failed to connect to ClamAV at %s: %v", socket, err)
		return nil, fmt.Errorf("failed to connect to ClamAV: %w", err)
	}
	log.Info("Connected to ClamAV successfully.")
	return client, nil
}

// handleResumableDownload serves byte-range requests (RFC 7233 single range of
// the form "bytes=start-end"); without a Range header it serves the whole file.
// Invalid or unsatisfiable ranges answer 416.
func handleResumableDownload(absFilename string, w http.ResponseWriter, r *http.Request, fileSize int64) {
	rangeHeader := r.Header.Get("Range")
	if rangeHeader == "" {
		// If no Range header, serve the full file.
		startTime := time.Now()
		http.ServeFile(w, r, absFilename)
		downloadDuration.Observe(time.Since(startTime).Seconds())
		downloadSizeBytes.Observe(float64(fileSize))
		downloadsTotal.Inc()
		return
	}

	// Parse Range header ("bytes=start-end"; suffix ranges are rejected).
	ranges := strings.Split(strings.TrimPrefix(rangeHeader, "bytes="), "-")
	if len(ranges) != 2 {
		http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
		downloadErrorsTotal.Inc()
		return
	}
	// FIX: also reject negative starts and starts beyond EOF — the original
	// only checked for a parse error, allowing negative Content-Length values.
	start, err := strconv.ParseInt(ranges[0], 10, 64)
	if err != nil || start < 0 || start >= fileSize {
		http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
		downloadErrorsTotal.Inc()
		return
	}

	// Calculate end byte (defaults to EOF). FIX: reject end < start too.
	end := fileSize - 1
	if ranges[1] != "" {
		end, err = strconv.ParseInt(ranges[1], 10, 64)
		if err != nil || end >= fileSize || end < start {
			http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
			downloadErrorsTotal.Inc()
			return
		}
	}

	// Set response headers for partial content.
	w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, fileSize))
	w.Header().Set("Content-Length", strconv.FormatInt(end-start+1, 10))
	w.Header().Set("Accept-Ranges", "bytes")
	w.WriteHeader(http.StatusPartialContent)

	// Serve the requested byte range.
	file, err := os.Open(absFilename)
	if err != nil {
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		downloadErrorsTotal.Inc()
		return
	}
	defer file.Close()

	// Seek to the start byte.
	_, err = file.Seek(start, 0)
	if err != nil {
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		downloadErrorsTotal.Inc()
		return
	}

	// Copy the requested range to the response writer in 32 KB chunks.
	buffer := make([]byte, 32*1024)
	remaining := end - start + 1
	startTime := time.Now()
	for remaining > 0 {
		if int64(len(buffer)) > remaining {
			buffer = buffer[:remaining]
		}
		n, err := file.Read(buffer)
		if n > 0 {
			if _, writeErr := w.Write(buffer[:n]); writeErr != nil {
				log.WithError(writeErr).Error("Failed to write to response")
				downloadErrorsTotal.Inc()
				return
			}
			remaining -= int64(n)
		}
		if err != nil {
			if err != io.EOF {
				log.WithError(err).Error("Error reading file during resumable download")
				// NOTE: headers were already sent, so this http.Error is best-effort.
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				downloadErrorsTotal.Inc()
			}
			break
		}
	}
	downloadDuration.Observe(time.Since(startTime).Seconds())
	downloadSizeBytes.Observe(float64(end - start + 1))
	downloadsTotal.Inc()
}

// handleChunkedUpload streams the request body into tempFilename using
// ChunkSize-sized reads and a buffered writer of the same size.
func handleChunkedUpload(tempFilename string, r *http.Request) error {
	log.WithField("file", tempFilename).Info("Handling chunked upload to temporary file")

	// Ensure the directory exists.
	absDirectory := filepath.Dir(tempFilename)
	err := os.MkdirAll(absDirectory, os.ModePerm)
	if err != nil {
		log.WithError(err).Errorf("Failed to create directory %s for chunked upload", absDirectory)
		return fmt.Errorf("failed to create directory %s: %w", absDirectory, err)
	}

	targetFile, err := os.OpenFile(tempFilename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		log.WithError(err).Error("Failed to open temporary file for chunked upload")
		return err
	}
	defer targetFile.Close()

	writer := bufio.NewWriterSize(targetFile, int(conf.Uploads.ChunkSize))
	buffer := make([]byte, conf.Uploads.ChunkSize)

	totalBytes := int64(0)
	for {
		n, err := r.Body.Read(buffer)
		if n > 0 {
			totalBytes += int64(n)
			_, writeErr := writer.Write(buffer[:n])
			if writeErr != nil {
				log.WithError(writeErr).Error("Failed to write chunk to temporary file")
				return writeErr
			}
		}
		if err != nil {
			if err == io.EOF {
				break // Finished reading the body
			}
			log.WithError(err).Error("Error reading from request body")
			return err
		}
	}

	err = writer.Flush()
	if err != nil {
		log.WithError(err).Error("Failed to flush buffer to temporary file")
		return err
	}

	log.WithFields(logrus.Fields{
		"temp_file":   tempFilename,
		"total_bytes": totalBytes,
	}).Info("Chunked upload completed successfully")
	uploadSizeBytes.Observe(float64(totalBytes))
	return nil
}

// getFileInfo returns os.Stat results for absFilename, memoized in
// fileInfoCache with the cache's default expiration.
// NOTE(review): cached entries can go stale if the file changes before the
// cache entry expires — acceptable for this server's write-once files.
func getFileInfo(absFilename string) (os.FileInfo, error) {
	if cachedInfo, found := fileInfoCache.Get(absFilename); found {
		if info, ok := cachedInfo.(os.FileInfo); ok {
			return info, nil
		}
	}
	fileInfo, err := os.Stat(absFilename)
	if err != nil {
		return nil, err
	}
	fileInfoCache.Set(absFilename, fileInfo, cache.DefaultExpiration)
	return fileInfo, nil
}

// monitorNetwork polls the primary IPv4 address every 10 seconds and emits an
// IP_CHANGE event on networkEvents when it changes (non-blocking send; events
// are dropped when the channel is full).
func monitorNetwork(ctx context.Context) {
	currentIP := getCurrentIPAddress()
	for {
		select {
		case <-ctx.Done():
			log.Info("Stopping network monitor.")
			return
		case <-time.After(10 * time.Second):
			newIP := getCurrentIPAddress()
			if newIP != currentIP && newIP != "" {
				currentIP = newIP
				select {
				case networkEvents <- NetworkEvent{Type: "IP_CHANGE", Details: currentIP}:
					log.WithField("new_ip", currentIP).Info("Queued IP_CHANGE event")
				default:
					log.Warn("Network event channel is full. Dropping IP_CHANGE event.")
				}
			}
		}
	}
}

// handleNetworkEvents consumes networkEvents until the context is cancelled or
// the channel is closed; currently IP_CHANGE events are only logged.
func handleNetworkEvents(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			log.Info("Stopping network event handler.")
			return
		case event, ok := <-networkEvents:
			if !ok {
				log.Info("Network events channel closed.")
				return
			}
			switch event.Type {
			case "IP_CHANGE":
				log.WithField("new_ip", event.Details).Info("Network change detected")
				// Example: Update Prometheus gauge or trigger alerts
				// activeConnections.Set(float64(getActiveConnections()))
			}
			// Additional event types can be handled here.
		}
	}
}

// getCurrentIPAddress returns the first global-unicast IPv4 address found on a
// non-loopback interface that is up, or "" when none is found.
func getCurrentIPAddress() string {
	interfaces, err := net.Interfaces()
	if err != nil {
		log.WithError(err).Error("Failed to get network interfaces")
		return ""
	}
	for _, iface := range interfaces {
		if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
			continue // Skip interfaces that are down or loopback
		}
		addrs, err := iface.Addrs()
		if err != nil {
			log.WithError(err).Errorf("Failed to get addresses for interface %s", iface.Name)
			continue
		}
		for _, addr := range addrs {
			if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP.IsGlobalUnicast() && ipnet.IP.To4() != nil {
				return ipnet.IP.String()
			}
		}
	}
	return ""
}

// setupGracefulShutdown installs a SIGINT/SIGTERM handler that drains the HTTP
// server (30s deadline), cancels the shared context, closes the worker
// channels, and exits the process.
// NOTE(review): closing uploadQueue/scanQueue/networkEvents here can panic if
// a producer is still sending; os.Exit(0) immediately after makes the window
// tiny, but a WaitGroup over producers would be the robust fix — TODO confirm.
func setupGracefulShutdown(server *http.Server, cancel context.CancelFunc) {
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		sig := <-quit
		log.Infof("Received signal %s. Initiating shutdown...", sig)

		// Create a deadline to wait for.
		ctxShutdown, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer shutdownCancel()

		// Attempt graceful shutdown.
		if err := server.Shutdown(ctxShutdown); err != nil {
			log.Errorf("Server shutdown failed: %v", err)
		} else {
			log.Info("Server shutdown gracefully.")
		}

		// Signal other goroutines to stop.
		cancel()

		// Close the upload, scan, and network event channels.
		close(uploadQueue)
		log.Info("Upload queue closed.")
		close(scanQueue)
		log.Info("Scan queue closed.")
		close(networkEvents)
		log.Info("Network events channel closed.")

		log.Info("Shutdown process completed. Exiting application.")
		os.Exit(0)
	}()
}

// initRedis connects the package-level redisClient using the configured
// address/password/DB, verifies connectivity with a 5s ping (fatal on
// failure), and starts the background health monitor.
func initRedis() {
	if !conf.Redis.RedisEnabled {
		log.Info("Redis is disabled in configuration.")
		return
	}
	redisClient = redis.NewClient(&redis.Options{
		Addr:     conf.Redis.RedisAddr,
		Password: conf.Redis.RedisPassword,
		DB:       conf.Redis.RedisDBIndex,
	})

	// Test the Redis connection.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	_, err := redisClient.Ping(ctx).Result()
	if err != nil {
		log.Fatalf("Failed to connect to Redis: %v", err)
	}
	log.Info("Connected to Redis successfully")

	// Set initial connection status.
	mu.Lock()
	redisConnected = true
	mu.Unlock()

	// Start monitoring Redis health.
	go MonitorRedisHealth(context.Background(), redisClient, parseDuration(conf.Redis.RedisHealthCheckInterval))
}

// MonitorRedisHealth periodically checks Redis connectivity and updates redisConnected status.
func MonitorRedisHealth(ctx context.Context, client *redis.Client, checkInterval time.Duration) { ticker := time.NewTicker(checkInterval) defer ticker.Stop() for { select { case <-ctx.Done(): log.Info("Stopping Redis health monitor.") return case <-ticker.C: err := client.Ping(ctx).Err() mu.Lock() if err != nil { if redisConnected { log.Errorf("Redis health check failed: %v", err) } redisConnected = false } else { if !redisConnected { log.Info("Redis reconnected successfully") } redisConnected = true log.Debug("Redis health check succeeded.") } mu.Unlock() } } } // Helper function to parse duration strings func parseDuration(durationStr string) time.Duration { duration, err := time.ParseDuration(durationStr) if err != nil { log.WithError(err).Warn("Invalid duration format, using default 30s") return 30 * time.Second } return duration } // RunFileCleaner periodically deletes files that exceed the FileTTL duration. func runFileCleaner(ctx context.Context, storeDir string, ttl time.Duration) { ticker := time.NewTicker(1 * time.Hour) defer ticker.Stop() for { select { case <-ctx.Done(): log.Info("Stopping file cleaner.") return case <-ticker.C: now := time.Now() err := filepath.Walk(storeDir, func(path string, info os.FileInfo, err error) error { if err != nil { return err } if info.IsDir() { return nil } if now.Sub(info.ModTime()) > ttl { err := os.Remove(path) if err != nil { log.WithError(err).Errorf("Failed to remove expired file: %s", path) } else { log.Infof("Removed expired file: %s", path) } } return nil }) if err != nil { log.WithError(err).Error("Error walking store directory for file cleaning") } } } } // DeduplicateFiles scans the store directory and removes duplicate files based on SHA256 hash. // It retains one copy of each unique file and replaces duplicates with hard links. 
func DeduplicateFiles(storeDir string) error { hashMap := make(map[string]string) // map[hash]filepath var mu sync.Mutex var wg sync.WaitGroup fileChan := make(chan string, 100) // Worker to process files numWorkers := 10 for i := 0; i < numWorkers; i++ { wg.Add(1) go func() { defer wg.Done() for filePath := range fileChan { hash, err := computeFileHash(filePath) if err != nil { logrus.WithError(err).Errorf("Failed to compute hash for %s", filePath) continue } mu.Lock() original, exists := hashMap[hash] if !exists { hashMap[hash] = filePath mu.Unlock() continue } mu.Unlock() // Duplicate found err = os.Remove(filePath) if err != nil { logrus.WithError(err).Errorf("Failed to remove duplicate file %s", filePath) continue } // Create hard link to the original file err = os.Link(original, filePath) if err != nil { logrus.WithError(err).Errorf("Failed to create hard link from %s to %s", original, filePath) continue } logrus.Infof("Removed duplicate %s and linked to %s", filePath, original) } }() } // Walk through the store directory err := filepath.Walk(storeDir, func(path string, info os.FileInfo, err error) error { if err != nil { logrus.WithError(err).Errorf("Error accessing path %s", path) return nil } if !info.Mode().IsRegular() { return nil } fileChan <- path return nil }) if err != nil { return fmt.Errorf("error walking the path %s: %w", storeDir, err) } close(fileChan) wg.Wait() return nil } // computeFileHash computes the SHA256 hash of the given file. 
func computeFileHash(filePath string) (string, error) { file, err := os.Open(filePath) if err != nil { return "", fmt.Errorf("unable to open file %s: %w", filePath, err) } defer file.Close() hasher := sha256.New() if _, err := io.Copy(hasher, file); err != nil { return "", fmt.Errorf("error hashing file %s: %w", filePath, err) } return hex.EncodeToString(hasher.Sum(nil)), nil } // Handle multipart uploads func handleMultipartUpload(w http.ResponseWriter, r *http.Request, absFilename string) error { err := r.ParseMultipartForm(32 << 20) // 32MB is the default used by FormFile if err != nil { log.WithError(err).Error("Failed to parse multipart form") http.Error(w, "Failed to parse multipart form", http.StatusBadRequest) return err } file, handler, err := r.FormFile("file") if err != nil { log.WithError(err).Error("Failed to retrieve file from form data") http.Error(w, "Failed to retrieve file from form data", http.StatusBadRequest) return err } defer file.Close() // Validate file extension if !isExtensionAllowed(handler.Filename) { log.WithFields(logrus.Fields{ "filename": handler.Filename, "extension": filepath.Ext(handler.Filename), }).Warn("Attempted upload with disallowed file extension") http.Error(w, "Disallowed file extension. 
Allowed extensions are: "+strings.Join(conf.Uploads.AllowedExtensions, ", "), http.StatusForbidden) uploadErrorsTotal.Inc() return fmt.Errorf("disallowed file extension") } // Create a temporary file tempFilename := absFilename + ".tmp" tempFile, err := os.OpenFile(tempFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) if err != nil { log.WithError(err).Error("Failed to create temporary file") http.Error(w, "Failed to create temporary file", http.StatusInternalServerError) return err } defer tempFile.Close() // Copy the uploaded file to the temporary file _, err = io.Copy(tempFile, file) if err != nil { log.WithError(err).Error("Failed to copy uploaded file to temporary file") http.Error(w, "Failed to copy uploaded file", http.StatusInternalServerError) return err } // Perform ClamAV scan on the temporary file if clamClient != nil { err := scanFileWithClamAV(tempFilename) if err != nil { log.WithFields(logrus.Fields{ "file": tempFilename, "error": err, }).Warn("ClamAV detected a virus or scan failed") os.Remove(tempFilename) uploadErrorsTotal.Inc() return err } } // Handle file versioning if enabled if conf.Versioning.EnableVersioning { existing, _ := fileExists(absFilename) if existing { err := versionFile(absFilename) if err != nil { log.WithFields(logrus.Fields{ "file": absFilename, "error": err, }).Error("Error versioning file") os.Remove(tempFilename) return err } } } // Move the temporary file to the final destination err = os.Rename(tempFilename, absFilename) if err != nil { log.WithFields(logrus.Fields{ "temp_file": tempFilename, "final_file": absFilename, "error": err, }).Error("Failed to move file to final destination") os.Remove(tempFilename) return err } log.WithFields(logrus.Fields{ "file": absFilename, }).Info("File uploaded and scanned successfully") uploadsTotal.Inc() return nil } // sanitizeFilePath ensures that the file path is within the designated storage directory func sanitizeFilePath(baseDir, filePath string) (string, error) { // Resolve the 
absolute path absBaseDir, err := filepath.Abs(baseDir) if err != nil { return "", fmt.Errorf("failed to resolve base directory: %w", err) } absFilePath, err := filepath.Abs(filepath.Join(absBaseDir, filePath)) if err != nil { return "", fmt.Errorf("failed to resolve file path: %w", err) } // Check if the resolved file path is within the base directory if !strings.HasPrefix(absFilePath, absBaseDir) { return "", fmt.Errorf("invalid file path: %s", filePath) } return absFilePath, nil } // checkStorageSpace ensures that there is enough free space in the storage path func checkStorageSpace(storagePath string, minFreeBytes int64) error { var stat syscall.Statfs_t err := syscall.Statfs(storagePath, &stat) if err != nil { return fmt.Errorf("failed to get filesystem stats: %w", err) } // Calculate available bytes availableBytes := stat.Bavail * uint64(stat.Bsize) if int64(availableBytes) < minFreeBytes { return fmt.Errorf("not enough free space: %d bytes available, %d bytes required", availableBytes, minFreeBytes) } return nil } // Function to compute SHA256 checksum of a file func computeSHA256(filePath string) (string, error) { file, err := os.Open(filePath) if err != nil { return "", fmt.Errorf("failed to open file for checksum: %w", err) } defer file.Close() hasher := sha256.New() if _, err := io.Copy(hasher, file); err != nil { return "", fmt.Errorf("failed to compute checksum: %w", err) } return hex.EncodeToString(hasher.Sum(nil)), nil } // handleDeduplication handles file deduplication using SHA256 checksum and hard links func handleDeduplication(ctx context.Context, absFilename string) error { // Compute checksum of the uploaded file checksum, err := computeSHA256(absFilename) if err != nil { log.Errorf("Failed to compute SHA256 for %s: %v", absFilename, err) return fmt.Errorf("checksum computation failed: %w", err) } log.Debugf("Computed checksum for %s: %s", absFilename, checksum) // Check Redis for existing checksum existingPath, err := redisClient.Get(ctx, 
checksum).Result() if err != nil && err != redis.Nil { log.Errorf("Redis error while fetching checksum %s: %v", checksum, err) return fmt.Errorf("redis error: %w", err) } if err != redis.Nil { // Duplicate found, create hard link log.Infof("Duplicate detected: %s already exists at %s", absFilename, existingPath) err = os.Link(existingPath, absFilename) if err != nil { log.Errorf("Failed to create hard link from %s to %s: %v", existingPath, absFilename, err) return fmt.Errorf("failed to create hard link: %w", err) } log.Infof("Created hard link from %s to %s", existingPath, absFilename) return nil } // No duplicate found, store checksum in Redis err = redisClient.Set(ctx, checksum, absFilename, 0).Err() if err != nil { log.Errorf("Failed to store checksum %s in Redis: %v", checksum, err) return fmt.Errorf("failed to store checksum in Redis: %w", err) } log.Infof("Stored new file checksum in Redis: %s -> %s", checksum, absFilename) return nil }