release: hmac-file-server 3.2

This commit is contained in:
2025-06-13 04:24:11 +02:00
parent cc3a4f4dd7
commit 16f50940d0
34 changed files with 10354 additions and 2255 deletions

View File

@@ -8,140 +8,285 @@ import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"mime"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
"time"
"sync"
"github.com/dutchcoders/go-clamd" // ClamAV integration
"github.com/go-redis/redis/v8" // Redis integration
jwt "github.com/golang-jwt/jwt/v5"
"github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/shirou/gopsutil/cpu"
"github.com/shirou/gopsutil/disk"
"github.com/shirou/gopsutil/host"
"github.com/shirou/gopsutil/mem"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
)
// Configuration structure
// parseSize converts a human-readable size string to bytes
func parseSize(sizeStr string) (int64, error) {
sizeStr = strings.TrimSpace(sizeStr)
if len(sizeStr) < 2 {
return 0, fmt.Errorf("invalid size string: %s", sizeStr)
}
unit := strings.ToUpper(sizeStr[len(sizeStr)-2:])
valueStr := sizeStr[:len(sizeStr)-2]
value, err := strconv.Atoi(valueStr)
if err != nil {
return 0, fmt.Errorf("invalid size value: %v", err)
}
switch unit {
case "KB":
return int64(value) * 1024, nil
case "MB":
return int64(value) * 1024 * 1024, nil
case "GB":
return int64(value) * 1024 * 1024 * 1024, nil
default:
return 0, fmt.Errorf("unknown size unit: %s", unit)
}
}
// parseTTL converts a human-readable TTL string to a time.Duration
func parseTTL(ttlStr string) (time.Duration, error) {
ttlStr = strings.ToLower(strings.TrimSpace(ttlStr))
if ttlStr == "" {
return 0, fmt.Errorf("TTL string cannot be empty")
}
var valueStr string
var unit rune
for _, r := range ttlStr {
if r >= '0' && r <= '9' {
valueStr += string(r)
} else {
unit = r
break
}
}
val, err := strconv.Atoi(valueStr)
if err != nil {
return 0, fmt.Errorf("invalid TTL value: %v", err)
}
switch unit {
case 's':
return time.Duration(val) * time.Second, nil
case 'm':
return time.Duration(val) * time.Minute, nil
case 'h':
return time.Duration(val) * time.Hour, nil
case 'd':
return time.Duration(val) * 24 * time.Hour, nil
case 'w':
return time.Duration(val) * 7 * 24 * time.Hour, nil
case 'y':
return time.Duration(val) * 365 * 24 * time.Hour, nil
default:
return 0, fmt.Errorf("unknown TTL unit: %c", unit)
}
}
// Configuration structures
type ServerConfig struct {
ListenPort string `mapstructure:"ListenPort"`
UnixSocket bool `mapstructure:"UnixSocket"`
StoragePath string `mapstructure:"StoragePath"`
LogLevel string `mapstructure:"LogLevel"`
LogFile string `mapstructure:"LogFile"`
MetricsEnabled bool `mapstructure:"MetricsEnabled"`
MetricsPort string `mapstructure:"MetricsPort"`
FileTTL string `mapstructure:"FileTTL"`
MinFreeBytes int64 `mapstructure:"MinFreeBytes"` // Minimum free bytes required
DeduplicationEnabled bool `mapstructure:"DeduplicationEnabled"`
}
type TimeoutConfig struct {
ReadTimeout string `mapstructure:"ReadTimeout"`
WriteTimeout string `mapstructure:"WriteTimeout"`
IdleTimeout string `mapstructure:"IdleTimeout"`
}
type SecurityConfig struct {
Secret string `mapstructure:"Secret"`
}
type VersioningConfig struct {
EnableVersioning bool `mapstructure:"EnableVersioning"`
MaxVersions int `mapstructure:"MaxVersions"`
ListenAddress string `toml:"listenport" mapstructure:"listenport"` // Fixed to match config file field
StoragePath string `toml:"storagepath" mapstructure:"storagepath"` // Fixed to match config
MetricsEnabled bool `toml:"metricsenabled" mapstructure:"metricsenabled"` // Fixed to match config
MetricsPath string `toml:"metrics_path" mapstructure:"metrics_path"`
PidFile string `toml:"pid_file" mapstructure:"pid_file"`
MaxUploadSize string `toml:"max_upload_size" mapstructure:"max_upload_size"`
MaxHeaderBytes int `toml:"max_header_bytes" mapstructure:"max_header_bytes"`
CleanupInterval string `toml:"cleanup_interval" mapstructure:"cleanup_interval"`
MaxFileAge string `toml:"max_file_age" mapstructure:"max_file_age"`
PreCache bool `toml:"pre_cache" mapstructure:"pre_cache"`
PreCacheWorkers int `toml:"pre_cache_workers" mapstructure:"pre_cache_workers"`
PreCacheInterval string `toml:"pre_cache_interval" mapstructure:"pre_cache_interval"`
GlobalExtensions []string `toml:"global_extensions" mapstructure:"global_extensions"`
DeduplicationEnabled bool `toml:"deduplication_enabled" mapstructure:"deduplication_enabled"`
MinFreeBytes string `toml:"min_free_bytes" mapstructure:"min_free_bytes"`
FileNaming string `toml:"file_naming" mapstructure:"file_naming"`
ForceProtocol string `toml:"force_protocol" mapstructure:"force_protocol"`
EnableDynamicWorkers bool `toml:"enable_dynamic_workers" mapstructure:"enable_dynamic_workers"`
WorkerScaleUpThresh int `toml:"worker_scale_up_thresh" mapstructure:"worker_scale_up_thresh"`
WorkerScaleDownThresh int `toml:"worker_scale_down_thresh" mapstructure:"worker_scale_down_thresh"`
UnixSocket bool `toml:"unixsocket" mapstructure:"unixsocket"` // Added missing field from example/logs
MetricsPort string `toml:"metricsport" mapstructure:"metricsport"` // Fixed to match config
FileTTL string `toml:"filettl" mapstructure:"filettl"` // Fixed to match config
FileTTLEnabled bool `toml:"filettlenabled" mapstructure:"filettlenabled"` // Fixed to match config
AutoAdjustWorkers bool `toml:"autoadjustworkers" mapstructure:"autoadjustworkers"` // Fixed to match config
NetworkEvents bool `toml:"networkevents" mapstructure:"networkevents"` // Fixed to match config
PIDFilePath string `toml:"pidfilepath" mapstructure:"pidfilepath"` // Fixed to match config
CleanUponExit bool `toml:"clean_upon_exit" mapstructure:"clean_upon_exit"` // Added missing field
PreCaching bool `toml:"precaching" mapstructure:"precaching"` // Fixed to match config
BindIP string `toml:"bind_ip" mapstructure:"bind_ip"` // Added missing field
}
type UploadsConfig struct {
ResumableUploadsEnabled bool `mapstructure:"ResumableUploadsEnabled"`
ChunkedUploadsEnabled bool `mapstructure:"ChunkedUploadsEnabled"`
ChunkSize int64 `mapstructure:"ChunkSize"`
AllowedExtensions []string `mapstructure:"AllowedExtensions"`
AllowedExtensions []string `toml:"allowedextensions" mapstructure:"allowedextensions"`
ChunkedUploadsEnabled bool `toml:"chunkeduploadsenabled" mapstructure:"chunkeduploadsenabled"`
ChunkSize string `toml:"chunksize" mapstructure:"chunksize"`
ResumableUploadsEnabled bool `toml:"resumableuploadsenabled" mapstructure:"resumableuploadsenabled"`
MaxResumableAge string `toml:"max_resumable_age" mapstructure:"max_resumable_age"`
}
type DownloadsConfig struct {
AllowedExtensions []string `toml:"allowedextensions" mapstructure:"allowedextensions"`
ChunkedDownloadsEnabled bool `toml:"chunkeddownloadsenabled" mapstructure:"chunkeddownloadsenabled"`
ChunkSize string `toml:"chunksize" mapstructure:"chunksize"`
ResumableDownloadsEnabled bool `toml:"resumable_downloads_enabled" mapstructure:"resumable_downloads_enabled"`
}
type SecurityConfig struct {
Secret string `toml:"secret" mapstructure:"secret"`
EnableJWT bool `toml:"enablejwt" mapstructure:"enablejwt"` // Added EnableJWT field
JWTSecret string `toml:"jwtsecret" mapstructure:"jwtsecret"`
JWTAlgorithm string `toml:"jwtalgorithm" mapstructure:"jwtalgorithm"`
JWTExpiration string `toml:"jwtexpiration" mapstructure:"jwtexpiration"`
}
type LoggingConfig struct {
Level string `mapstructure:"level"`
File string `mapstructure:"file"`
MaxSize int `mapstructure:"max_size"`
MaxBackups int `mapstructure:"max_backups"`
MaxAge int `mapstructure:"max_age"`
Compress bool `mapstructure:"compress"`
}
type DeduplicationConfig struct {
Enabled bool `mapstructure:"enabled"`
Directory string `mapstructure:"directory"`
}
type ISOConfig struct {
Enabled bool `mapstructure:"enabled"`
MountPoint string `mapstructure:"mountpoint"`
Size string `mapstructure:"size"`
Charset string `mapstructure:"charset"`
ContainerFile string `mapstructure:"containerfile"` // Added missing field
}
type TimeoutConfig struct {
Read string `mapstructure:"readtimeout" toml:"readtimeout"`
Write string `mapstructure:"writetimeout" toml:"writetimeout"`
Idle string `mapstructure:"idletimeout" toml:"idletimeout"`
Shutdown string `mapstructure:"shutdown" toml:"shutdown"`
}
type VersioningConfig struct {
Enabled bool `mapstructure:"enableversioning" toml:"enableversioning"` // Corrected to match example config
Backend string `mapstructure:"backend" toml:"backend"`
MaxRevs int `mapstructure:"maxversions" toml:"maxversions"` // Corrected to match example config
}
type ClamAVConfig struct {
ClamAVEnabled bool `mapstructure:"ClamAVEnabled"`
ClamAVSocket string `mapstructure:"ClamAVSocket"`
NumScanWorkers int `mapstructure:"NumScanWorkers"`
ClamAVEnabled bool `mapstructure:"clamavenabled"`
ClamAVSocket string `mapstructure:"clamavsocket"`
NumScanWorkers int `mapstructure:"numscanworkers"`
ScanFileExtensions []string `mapstructure:"scanfileextensions"`
}
type RedisConfig struct {
RedisEnabled bool `mapstructure:"RedisEnabled"`
RedisDBIndex int `mapstructure:"RedisDBIndex"`
RedisAddr string `mapstructure:"RedisAddr"`
RedisPassword string `mapstructure:"RedisPassword"`
RedisHealthCheckInterval string `mapstructure:"RedisHealthCheckInterval"`
RedisEnabled bool `mapstructure:"redisenabled"`
RedisDBIndex int `mapstructure:"redisdbindex"`
RedisAddr string `mapstructure:"redisaddr"`
RedisPassword string `mapstructure:"redispassword"`
RedisHealthCheckInterval string `mapstructure:"redishealthcheckinterval"`
}
type WorkersConfig struct {
NumWorkers int `mapstructure:"NumWorkers"`
UploadQueueSize int `mapstructure:"UploadQueueSize"`
NumWorkers int `mapstructure:"numworkers"`
UploadQueueSize int `mapstructure:"uploadqueuesize"`
}
type FileConfig struct {
FileRevision int `mapstructure:"FileRevision"`
}
type BuildConfig struct {
Version string `mapstructure:"version"` // Updated version
}
// This is the main Config struct to be used
type Config struct {
Server ServerConfig `mapstructure:"server"`
Timeouts TimeoutConfig `mapstructure:"timeouts"`
Security SecurityConfig `mapstructure:"security"`
Versioning VersioningConfig `mapstructure:"versioning"`
Uploads UploadsConfig `mapstructure:"uploads"`
ClamAV ClamAVConfig `mapstructure:"clamav"`
Redis RedisConfig `mapstructure:"redis"`
Workers WorkersConfig `mapstructure:"workers"`
File FileConfig `mapstructure:"file"`
Server ServerConfig `mapstructure:"server"`
Logging LoggingConfig `mapstructure:"logging"`
Deduplication DeduplicationConfig `mapstructure:"deduplication"` // Added
ISO ISOConfig `mapstructure:"iso"` // Added
Timeouts TimeoutConfig `mapstructure:"timeouts"` // Added
Security SecurityConfig `mapstructure:"security"`
Versioning VersioningConfig `mapstructure:"versioning"` // Added
Uploads UploadsConfig `mapstructure:"uploads"`
Downloads DownloadsConfig `mapstructure:"downloads"`
ClamAV ClamAVConfig `mapstructure:"clamav"`
Redis RedisConfig `mapstructure:"redis"`
Workers WorkersConfig `mapstructure:"workers"`
File FileConfig `mapstructure:"file"`
Build BuildConfig `mapstructure:"build"`
}
// UploadTask represents a file upload task
type UploadTask struct {
AbsFilename string
Request *http.Request
Result chan error
}
// ScanTask represents a file scan task
type ScanTask struct {
AbsFilename string
Result chan error
}
// NetworkEvent represents a network-related event
type NetworkEvent struct {
Type string
Details string
}
var (
conf Config
versionString string = "v2.0-dev"
log = logrus.New()
uploadQueue chan UploadTask
networkEvents chan NetworkEvent
fileInfoCache *cache.Cache
clamClient *clamd.Clamd // Added for ClamAV integration
redisClient *redis.Client // Redis client
redisConnected bool // Redis connection status
mu sync.RWMutex
// Add a new field to store the creation date of files
type FileMetadata struct {
CreationDate time.Time
}
// processScan processes a scan task
func processScan(task ScanTask) error {
log.Infof("Started processing scan for file: %s", task.AbsFilename)
semaphore <- struct{}{}
defer func() { <-semaphore }()
err := scanFileWithClamAV(task.AbsFilename)
if err != nil {
log.WithFields(logrus.Fields{"file": task.AbsFilename, "error": err}).Error("Failed to scan file")
return err
}
log.Infof("Finished processing scan for file: %s", task.AbsFilename)
return nil
}
var (
conf Config
versionString string
log = logrus.New()
fileInfoCache *cache.Cache
fileMetadataCache *cache.Cache
clamClient *clamd.Clamd
redisClient *redis.Client
redisConnected bool
confMutex sync.RWMutex // Protects the global 'conf' variable and related critical sections.
// Use RLock() for reading, Lock() for writing.
// Prometheus metrics
uploadDuration prometheus.Histogram
uploadErrorsTotal prometheus.Counter
uploadsTotal prometheus.Counter
@@ -156,1768 +301,1692 @@ var (
uploadSizeBytes prometheus.Histogram
downloadSizeBytes prometheus.Histogram
// Constants for worker pool
MinWorkers = 5 // Increased from 10 to 20 for better concurrency
UploadQueueSize = 10000 // Increased from 5000 to 10000
filesDeduplicatedTotal prometheus.Counter
deduplicationErrorsTotal prometheus.Counter
isoContainersCreatedTotal prometheus.Counter
isoCreationErrorsTotal prometheus.Counter
isoContainersMountedTotal prometheus.Counter
isoMountErrorsTotal prometheus.Counter
// Channels
scanQueue chan ScanTask
ScanWorkers = 5 // Number of ClamAV scan workers
workerPool *WorkerPool
networkEvents chan NetworkEvent
workerAdjustmentsTotal prometheus.Counter
workerReAdjustmentsTotal prometheus.Counter
)
func main() {
// Set default configuration values
setDefaults()
var bufferPool = sync.Pool{
New: func() interface{} {
buf := make([]byte, 32*1024)
return &buf
},
}
const maxConcurrentOperations = 10
var semaphore = make(chan struct{}, maxConcurrentOperations)
var logMessages []string
var logMu sync.Mutex
func flushLogMessages() {
logMu.Lock()
defer logMu.Unlock()
for _, msg := range logMessages {
log.Info(msg)
}
logMessages = []string{}
}
// writePIDFile writes the current process ID to the specified pid file
func writePIDFile(pidPath string) error {
pid := os.Getpid()
pidStr := strconv.Itoa(pid)
err := os.WriteFile(pidPath, []byte(pidStr), 0644)
if err != nil {
log.Errorf("Failed to write PID file: %v", err) // Improved error logging
return err
}
log.Infof("PID %d written to %s", pid, pidPath)
return nil
}
// removePIDFile removes the PID file
func removePIDFile(pidPath string) {
err := os.Remove(pidPath)
if err != nil {
log.Errorf("Failed to remove PID file: %v", err) // Improved error logging
} else {
log.Infof("PID file %s removed successfully", pidPath)
}
}
// createAndMountISO creates an ISO container and mounts it to the specified mount point
func createAndMountISO(size, mountpoint, charset string) error {
isoPath := conf.ISO.ContainerFile
// Create an empty ISO file
cmd := exec.Command("dd", "if=/dev/zero", fmt.Sprintf("of=%s", isoPath), fmt.Sprintf("bs=%s", size), "count=1")
if err := cmd.Run(); err != nil {
isoCreationErrorsTotal.Inc()
return fmt.Errorf("failed to create ISO file: %w", err)
}
// Format the ISO file with a filesystem
cmd = exec.Command("mkfs", "-t", "iso9660", "-input-charset", charset, isoPath)
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to format ISO file: %w", err)
}
// Create the mount point directory if it doesn't exist
if err := os.MkdirAll(mountpoint, os.ModePerm); err != nil {
return fmt.Errorf("failed to create mount point: %w", err)
}
// Mount the ISO file
cmd = exec.Command("mount", "-o", "loop", isoPath, mountpoint)
if err := cmd.Run(); err != nil {
isoMountErrorsTotal.Inc()
return fmt.Errorf("failed to mount ISO file: %w", err)
}
isoContainersCreatedTotal.Inc()
isoContainersMountedTotal.Inc()
return nil
}
func initializeNetworkProtocol(forceProtocol string) (*net.Dialer, error) {
switch forceProtocol {
case "ipv4":
return &net.Dialer{
Timeout: 5 * time.Second,
DualStack: false,
Control: func(network, address string, c syscall.RawConn) error {
if network == "tcp6" {
return fmt.Errorf("IPv6 is disabled by forceprotocol setting")
}
return nil
},
}, nil
case "ipv6":
return &net.Dialer{
Timeout: 5 * time.Second,
DualStack: false,
Control: func(network, address string, c syscall.RawConn) error {
if network == "tcp4" {
return fmt.Errorf("IPv4 is disabled by forceprotocol setting")
}
return nil
},
}, nil
case "auto":
return &net.Dialer{
Timeout: 5 * time.Second,
DualStack: true,
}, nil
default:
return nil, fmt.Errorf("invalid forceprotocol value: %s", forceProtocol)
}
}
var dualStackClient *http.Client
func main() {
setDefaults() // Call setDefaults before parsing flags or reading config
// Flags for configuration file
var configFile string
flag.StringVar(&configFile, "config", "./config.toml", "Path to configuration file \"config.toml\".")
var genConfig bool
var genConfigPath string
var validateOnly bool
var runConfigTests bool
var validateQuiet bool
var validateVerbose bool
var validateFixable bool
var validateSecurity bool
var validatePerformance bool
var validateConnectivity bool
var listValidationChecks bool
var showVersion bool
flag.BoolVar(&genConfig, "genconfig", false, "Print example configuration and exit.")
flag.StringVar(&genConfigPath, "genconfig-path", "", "Write example configuration to the given file and exit.")
flag.BoolVar(&validateOnly, "validate-config", false, "Validate configuration and exit without starting server.")
flag.BoolVar(&runConfigTests, "test-config", false, "Run configuration validation test scenarios and exit.")
flag.BoolVar(&validateQuiet, "validate-quiet", false, "Only show errors during validation (suppress warnings and info).")
flag.BoolVar(&validateVerbose, "validate-verbose", false, "Show detailed validation information including system checks.")
flag.BoolVar(&validateFixable, "check-fixable", false, "Only show validation issues that can be automatically fixed.")
flag.BoolVar(&validateSecurity, "check-security", false, "Run only security-related validation checks.")
flag.BoolVar(&validatePerformance, "check-performance", false, "Run only performance-related validation checks.")
flag.BoolVar(&validateConnectivity, "check-connectivity", false, "Run only network connectivity validation checks.")
flag.BoolVar(&listValidationChecks, "list-checks", false, "List all available validation checks and exit.")
flag.BoolVar(&showVersion, "version", false, "Show version information and exit.")
flag.Parse()
// Load configuration
if showVersion {
fmt.Printf("HMAC File Server v3.2\n")
os.Exit(0)
}
if listValidationChecks {
printValidationChecks()
os.Exit(0)
}
if genConfig {
printExampleConfig()
os.Exit(0)
}
if genConfigPath != "" {
f, err := os.Create(genConfigPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to create file: %v\n", err)
os.Exit(1)
}
defer f.Close()
w := bufio.NewWriter(f)
fmt.Fprint(w, getExampleConfigString())
w.Flush()
fmt.Printf("Example config written to %s\n", genConfigPath)
os.Exit(0)
}
if runConfigTests {
RunConfigTests()
os.Exit(0)
}
// Initialize Viper
viper.SetConfigType("toml")
// Set default config path
defaultConfigPath := "/etc/hmac-file-server/config.toml"
// Attempt to load the default config
viper.SetConfigFile(defaultConfigPath)
if err := viper.ReadInConfig(); err != nil {
// If default config not found, fallback to parent directory
parentDirConfig := "../config.toml"
viper.SetConfigFile(parentDirConfig)
if err := viper.ReadInConfig(); err != nil {
// If still not found and -config is provided, use it
if configFile != "" {
viper.SetConfigFile(configFile)
if err := viper.ReadInConfig(); err != nil {
fmt.Printf("Error loading config file: %v\n", err)
os.Exit(1)
}
} else {
fmt.Println("No configuration file found. Please create a config file with the following content:")
printExampleConfig()
os.Exit(1)
}
}
}
err := readConfig(configFile, &conf)
if err != nil {
log.Fatalf("Error reading config: %v", err) // Fatal: application cannot proceed
log.Fatalf("Failed to load configuration: %v\nPlease ensure your config.toml is present at one of the following paths:\n%v", err, []string{
"/etc/hmac-file-server/config.toml",
"../config.toml",
"./config.toml",
})
}
log.Info("Configuration loaded successfully.")
// Initialize file info cache
fileInfoCache = cache.New(5*time.Minute, 10*time.Minute)
// Create store directory
err = os.MkdirAll(conf.Server.StoragePath, os.ModePerm)
err = validateConfig(&conf)
if err != nil {
log.Fatalf("Error creating store directory: %v", err)
log.Fatalf("Configuration validation failed: %v", err)
}
log.WithField("directory", conf.Server.StoragePath).Info("Store directory is ready")
log.Info("Configuration validated successfully.")
// Perform comprehensive configuration validation
validationResult := ValidateConfigComprehensive(&conf)
PrintValidationResults(validationResult)
if validationResult.HasErrors() {
log.Fatal("Cannot start server due to configuration errors. Please fix the above issues and try again.")
}
// Handle specialized validation flags
if validateSecurity || validatePerformance || validateConnectivity || validateQuiet || validateVerbose || validateFixable {
runSpecializedValidation(&conf, validateSecurity, validatePerformance, validateConnectivity, validateQuiet, validateVerbose, validateFixable)
os.Exit(0)
}
// If only validation was requested, exit now
if validateOnly {
if validationResult.HasErrors() {
log.Error("Configuration validation failed with errors. Review the errors above.")
os.Exit(1)
} else if validationResult.HasWarnings() {
log.Info("Configuration is valid but has warnings. Review the warnings above.")
os.Exit(0)
} else {
log.Info("Configuration validation completed successfully!")
os.Exit(0)
}
}
// Set log level based on configuration
level, err := logrus.ParseLevel(conf.Logging.Level)
if err != nil {
log.Warnf("Invalid log level '%s', defaulting to 'info'", conf.Logging.Level)
level = logrus.InfoLevel
}
log.SetLevel(level)
log.Infof("Log level set to: %s", level.String())
// Log configuration settings using [logging] section
log.Infof("Server ListenAddress: %s", conf.Server.ListenAddress) // Corrected field name
log.Infof("Server UnixSocket: %v", conf.Server.UnixSocket)
log.Infof("Server StoragePath: %s", conf.Server.StoragePath)
log.Infof("Logging Level: %s", conf.Logging.Level)
log.Infof("Logging File: %s", conf.Logging.File)
log.Infof("Server MetricsEnabled: %v", conf.Server.MetricsEnabled)
log.Infof("Server MetricsPort: %s", conf.Server.MetricsPort) // Corrected field name
log.Infof("Server FileTTL: %s", conf.Server.FileTTL) // Corrected field name
log.Infof("Server MinFreeBytes: %s", conf.Server.MinFreeBytes)
log.Infof("Server AutoAdjustWorkers: %v", conf.Server.AutoAdjustWorkers) // Corrected field name
log.Infof("Server NetworkEvents: %v", conf.Server.NetworkEvents) // Corrected field name
log.Infof("Server PIDFilePath: %s", conf.Server.PIDFilePath) // Corrected field name
log.Infof("Server CleanUponExit: %v", conf.Server.CleanUponExit) // Corrected field name
log.Infof("Server PreCaching: %v", conf.Server.PreCaching) // Corrected field name
log.Infof("Server FileTTLEnabled: %v", conf.Server.FileTTLEnabled) // Corrected field name
log.Infof("Server DeduplicationEnabled: %v", conf.Server.DeduplicationEnabled)
log.Infof("Server BindIP: %s", conf.Server.BindIP) // Corrected field name
log.Infof("Server FileNaming: %s", conf.Server.FileNaming)
log.Infof("Server ForceProtocol: %s", conf.Server.ForceProtocol)
err = writePIDFile(conf.Server.PIDFilePath) // Corrected field name
if err != nil {
log.Fatalf("Error writing PID file: %v", err)
}
log.Debug("DEBUG: PID file written successfully")
log.Debugf("DEBUG: Config logging file: %s", conf.Logging.File)
// Setup logging
setupLogging()
log.Debug("DEBUG: Logging setup completed")
// Log system information
logSystemInfo()
log.Debug("DEBUG: System info logged")
// Initialize Prometheus metrics
// Initialize metrics before using any Prometheus counters
initMetrics()
log.Info("Prometheus metrics initialized.")
log.Debug("DEBUG: Metrics initialized")
// Initialize upload and scan queues
uploadQueue = make(chan UploadTask, conf.Workers.UploadQueueSize)
scanQueue = make(chan ScanTask, conf.Workers.UploadQueueSize)
networkEvents = make(chan NetworkEvent, 100)
log.Info("Upload, scan, and network event channels initialized.")
initializeWorkerSettings(&conf.Server, &conf.Workers, &conf.ClamAV)
log.Debug("DEBUG: Worker settings initialized")
// Context for goroutines
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start network monitoring
go monitorNetwork(ctx)
go handleNetworkEvents(ctx)
// Update system metrics
go updateSystemMetrics(ctx)
// Initialize ClamAV client if enabled
if conf.ClamAV.ClamAVEnabled {
clamClient, err = initClamAV(conf.ClamAV.ClamAVSocket)
if conf.ISO.Enabled {
err := createAndMountISO(conf.ISO.Size, conf.ISO.MountPoint, conf.ISO.Charset)
if err != nil {
log.WithFields(logrus.Fields{
"error": err.Error(),
}).Warn("ClamAV client initialization failed. Continuing without ClamAV.")
} else {
log.Info("ClamAV client initialized successfully.")
log.Fatalf("Failed to create and mount ISO container: %v", err)
}
log.Infof("ISO container mounted at %s", conf.ISO.MountPoint)
}
// Initialize Redis client if enabled
if conf.Redis.RedisEnabled {
initRedis()
// Set storage path to ISO mount point if ISO is enabled
storagePath := conf.Server.StoragePath
if conf.ISO.Enabled {
storagePath = conf.ISO.MountPoint
}
// Redis Initialization
initRedis()
log.Info("Redis client initialized and connected successfully.")
fileInfoCache = cache.New(5*time.Minute, 10*time.Minute)
fileMetadataCache = cache.New(5*time.Minute, 10*time.Minute)
// ClamAV Initialization
if conf.ClamAV.ClamAVEnabled {
clamClient, err = initClamAV(conf.ClamAV.ClamAVSocket)
if err != nil {
log.WithFields(logrus.Fields{
"error": err.Error(),
}).Warn("ClamAV client initialization failed. Continuing without ClamAV.")
} else {
log.Info("ClamAV client initialized successfully.")
}
}
// Initialize worker pools
initializeUploadWorkerPool(ctx)
if conf.ClamAV.ClamAVEnabled && clamClient != nil {
initializeScanWorkerPool(ctx)
}
// Start Redis health monitor if Redis is enabled
if conf.Redis.RedisEnabled && redisClient != nil {
go MonitorRedisHealth(ctx, redisClient, parseDuration(conf.Redis.RedisHealthCheckInterval))
}
// Setup router
router := setupRouter()
// Start file cleaner
fileTTL, err := time.ParseDuration(conf.Server.FileTTL)
if err != nil {
log.Fatalf("Invalid FileTTL: %v", err)
}
go runFileCleaner(ctx, conf.Server.StoragePath, fileTTL)
// Parse timeout durations
readTimeout, err := time.ParseDuration(conf.Timeouts.ReadTimeout)
if err != nil {
log.Fatalf("Invalid ReadTimeout: %v", err)
}
writeTimeout, err := time.ParseDuration(conf.Timeouts.WriteTimeout)
if err != nil {
log.Fatalf("Invalid WriteTimeout: %v", err)
}
idleTimeout, err := time.ParseDuration(conf.Timeouts.IdleTimeout)
if err != nil {
log.Fatalf("Invalid IdleTimeout: %v", err)
}
// Configure HTTP server
server := &http.Server{
Addr: ":" + conf.Server.ListenPort, // Prepend colon to ListenPort
Handler: router,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
IdleTimeout: idleTimeout,
}
// Start metrics server if enabled
if conf.Server.MetricsEnabled {
if conf.Server.PreCaching { // Corrected field name
go func() {
http.Handle("/metrics", promhttp.Handler())
log.Infof("Metrics server started on port %s", conf.Server.MetricsPort)
if err := http.ListenAndServe(":"+conf.Server.MetricsPort, nil); err != nil {
log.Fatalf("Metrics server failed: %v", err)
log.Info("Starting pre-caching of storage path...")
// Use helper function
err := precacheStoragePath(storagePath)
if err != nil {
log.Warnf("Pre-caching storage path failed: %v", err)
} else {
log.Info("Pre-cached all files in the storage path.")
log.Info("Pre-caching status: complete.")
}
}()
}
// Setup graceful shutdown
setupGracefulShutdown(server, cancel)
err = os.MkdirAll(storagePath, os.ModePerm)
if err != nil {
log.Fatalf("Error creating store directory: %v", err)
}
log.WithField("directory", storagePath).Info("Store directory is ready")
// Use helper function
err = checkFreeSpaceWithRetry(storagePath, 3, 5*time.Second)
if err != nil {
log.Fatalf("Insufficient free space: %v", err)
}
initializeWorkerSettings(&conf.Server, &conf.Workers, &conf.ClamAV)
log.Info("Prometheus metrics initialized.")
networkEvents = make(chan NetworkEvent, 100)
log.Info("Network event channel initialized.")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if conf.Server.NetworkEvents { // Corrected field name
go monitorNetwork(ctx) // Assuming monitorNetwork is defined in helpers.go or elsewhere
go handleNetworkEvents(ctx) // Assuming handleNetworkEvents is defined in helpers.go or elsewhere
}
go updateSystemMetrics(ctx)
if conf.ClamAV.ClamAVEnabled {
var clamErr error
clamClient, clamErr = initClamAV(conf.ClamAV.ClamAVSocket) // Assuming initClamAV is defined in helpers.go or elsewhere
if clamErr != nil {
log.WithError(clamErr).Warn("ClamAV client initialization failed. Continuing without ClamAV.")
} else {
log.Info("ClamAV client initialized successfully.")
}
}
if conf.Redis.RedisEnabled {
initRedis() // Assuming initRedis is defined in helpers.go or elsewhere
}
router := setupRouter() // Assuming setupRouter is defined (likely in this file or router.go
go handleFileCleanup(&conf) // Directly call handleFileCleanup
readTimeout, err := time.ParseDuration(conf.Timeouts.Read) // Corrected field name
if err != nil {
log.Fatalf("Invalid ReadTimeout: %v", err)
}
writeTimeout, err := time.ParseDuration(conf.Timeouts.Write) // Corrected field name
if err != nil {
log.Fatalf("Invalid WriteTimeout: %v", err)
}
idleTimeout, err := time.ParseDuration(conf.Timeouts.Idle) // Corrected field name
if err != nil {
log.Fatalf("Invalid IdleTimeout: %v", err)
}
// Initialize network protocol based on forceprotocol setting
dialer, err := initializeNetworkProtocol(conf.Server.ForceProtocol)
if err != nil {
log.Fatalf("Failed to initialize network protocol: %v", err)
}
// Enhanced dual-stack HTTP client for robust IPv4/IPv6 and resource management
// See: https://pkg.go.dev/net/http#Transport for details on these settings
dualStackClient = &http.Client{
Transport: &http.Transport{
DialContext: dialer.DialContext,
IdleConnTimeout: 90 * time.Second, // Close idle connections after 90s
MaxIdleConns: 100, // Max idle connections across all hosts
MaxIdleConnsPerHost: 10, // Max idle connections per host
TLSHandshakeTimeout: 10 * time.Second, // Timeout for TLS handshake
ResponseHeaderTimeout: 15 * time.Second, // Timeout for reading response headers
},
}
server := &http.Server{
Addr: conf.Server.BindIP + ":" + conf.Server.ListenAddress, // Use BindIP + ListenAddress (port)
Handler: router,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
IdleTimeout: idleTimeout,
MaxHeaderBytes: 1 << 20, // 1 MB
}
if conf.Server.MetricsEnabled {
var wg sync.WaitGroup
go func() {
http.Handle("/metrics", promhttp.Handler())
log.Infof("Metrics server started on port %s", conf.Server.MetricsPort) // Corrected field name
if err := http.ListenAndServe(":"+conf.Server.MetricsPort, nil); err != nil { // Corrected field name
log.Fatalf("Metrics server failed: %v", err)
}
wg.Wait()
}()
}
setupGracefulShutdown(server, cancel) // Assuming setupGracefulShutdown is defined
if conf.Server.AutoAdjustWorkers { // Corrected field name
go monitorWorkerPerformance(ctx, &conf.Server, &conf.Workers, &conf.ClamAV)
}
versionString = "3.2" // Set a default version for now
if conf.Build.Version != "" {
versionString = conf.Build.Version
}
log.Infof("Running version: %s", versionString)
// Start server
log.Infof("Starting HMAC file server %s...", versionString)
if conf.Server.UnixSocket {
// Listen on Unix socket
if err := os.RemoveAll(conf.Server.ListenPort); err != nil {
socketPath := "/tmp/hmac-file-server.sock" // Use a default socket path since ListenAddress is now a port
if err := os.RemoveAll(socketPath); err != nil {
log.Fatalf("Failed to remove existing Unix socket: %v", err)
}
listener, err := net.Listen("unix", conf.Server.ListenPort)
listener, err := net.Listen("unix", socketPath)
if err != nil {
log.Fatalf("Failed to listen on Unix socket %s: %v", conf.Server.ListenPort, err)
log.Fatalf("Failed to listen on Unix socket %s: %v", socketPath, err)
}
defer listener.Close()
log.Infof("Server listening on Unix socket: %s", socketPath)
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server failed: %v", err)
}
} else {
// Listen on TCP port
if conf.Server.BindIP == "0.0.0.0" {
log.Info("Binding to 0.0.0.0. Any net/http logs you see are normal for this universal address.")
}
log.Infof("Server listening on %s", server.Addr)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server failed: %v", err)
}
}
// Start file cleanup in a separate goroutine
// Use helper function
go handleFileCleanup(&conf)
}
// Function to load configuration using Viper
func readConfig(configFilename string, conf *Config) error {
viper.SetConfigFile(configFilename)
viper.SetConfigType("toml")
func printExampleConfig() {
fmt.Print(`
[server]
bind_ip = "0.0.0.0"
listenport = "8080"
unixsocket = false
storagepath = "./uploads"
logfile = "/var/log/hmac-file-server.log"
metricsenabled = true
metricsport = "9090"
minfreebytes = "100MB"
filettl = "8760h"
filettlenabled = true
autoadjustworkers = true
networkevents = true
pidfilepath = "/var/run/hmacfileserver.pid"
cleanuponexit = true
precaching = true
deduplicationenabled = true
globalextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
# FileNaming options: "HMAC", "None"
filenaming = "HMAC"
forceprotocol = "auto"
// Read in environment variables that match
viper.AutomaticEnv()
viper.SetEnvPrefix("HMAC") // Prefix for environment variables
[logging]
level = "info"
file = "/var/log/hmac-file-server.log"
max_size = 100
max_backups = 7
max_age = 30
compress = true
// Read the config file
if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("error reading config file: %w", err)
}
[deduplication]
enabled = true
directory = "./deduplication"
// Unmarshal the config into the Config struct
if err := viper.Unmarshal(conf); err != nil {
return fmt.Errorf("unable to decode into struct: %w", err)
}
[iso]
enabled = true
size = "1GB"
mountpoint = "/mnt/iso"
charset = "utf-8"
containerfile = "/mnt/iso/container.iso"
// Debug log the loaded configuration
log.Debugf("Loaded Configuration: %+v", conf.Server)
[timeouts]
readtimeout = "4800s"
writetimeout = "4800s"
idletimeout = "4800s"
// Validate the configuration
if err := validateConfig(conf); err != nil {
return fmt.Errorf("configuration validation failed: %w", err)
}
[security]
secret = "changeme"
enablejwt = false
jwtsecret = "anothersecretkey"
jwtalgorithm = "HS256"
jwtexpiration = "24h"
// Set Deduplication Enabled
conf.Server.DeduplicationEnabled = viper.GetBool("deduplication.Enabled")
[versioning]
enableversioning = false
maxversions = 1
return nil
[uploads]
resumableuploadsenabled = true
chunkeduploadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[downloads]
resumabledownloadsenabled = true
chunkeddownloadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[clamav]
clamavenabled = true
clamavsocket = "/var/run/clamav/clamd.ctl"
numscanworkers = 2
scanfileextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[redis]
redisenabled = true
redisdbindex = 0
redisaddr = "localhost:6379"
redispassword = ""
redishealthcheckinterval = "120s"
[workers]
numworkers = 4
uploadqueuesize = 50
[file]
# Add file-specific configurations here
[build]
version = "3.2"
`)
}
// Set default configuration values
func setDefaults() {
// Server defaults
viper.SetDefault("server.ListenPort", "8080")
viper.SetDefault("server.UnixSocket", false)
viper.SetDefault("server.StoragePath", "./uploads")
viper.SetDefault("server.LogLevel", "info")
viper.SetDefault("server.LogFile", "")
viper.SetDefault("server.MetricsEnabled", true)
viper.SetDefault("server.MetricsPort", "9090")
viper.SetDefault("server.FileTTL", "8760h") // 365d -> 8760h
viper.SetDefault("server.MinFreeBytes", 100<<20) // 100 MB
func getExampleConfigString() string {
return `[server]
listen_address = ":8080"
storage_path = "/srv/hmac-file-server/uploads"
metrics_enabled = true
metrics_path = "/metrics"
pid_file = "/var/run/hmac-file-server.pid"
max_upload_size = "10GB" # Supports B, KB, MB, GB, TB
max_header_bytes = 1048576 # 1MB
cleanup_interval = "24h"
max_file_age = "720h" # 30 days
pre_cache = true
pre_cache_workers = 4
pre_cache_interval = "1h"
global_extensions = [".txt", ".dat", ".iso"] # If set, overrides upload/download extensions
deduplication_enabled = true
min_free_bytes = "1GB" # Minimum free space required for uploads
file_naming = "original" # Options: "original", "HMAC"
force_protocol = "" # Options: "http", "https" - if set, redirects to this protocol
enable_dynamic_workers = true # Enable dynamic worker scaling
worker_scale_up_thresh = 50 # Queue length to scale up workers
worker_scale_down_thresh = 10 # Queue length to scale down workers
// Timeout defaults
viper.SetDefault("timeouts.ReadTimeout", "4800s") // supports 's'
viper.SetDefault("timeouts.WriteTimeout", "4800s")
viper.SetDefault("timeouts.IdleTimeout", "4800s")
[uploads]
allowed_extensions = [".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"]
chunked_uploads_enabled = true
chunk_size = "10MB"
resumable_uploads_enabled = true
max_resumable_age = "48h"
// Security defaults
viper.SetDefault("security.Secret", "changeme")
[downloads]
allowed_extensions = [".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"]
chunked_downloads_enabled = true
chunk_size = "10MB"
resumable_downloads_enabled = true
// Versioning defaults
viper.SetDefault("versioning.EnableVersioning", false)
viper.SetDefault("versioning.MaxVersions", 1)
[security]
secret = "your-very-secret-hmac-key"
enablejwt = false
jwtsecret = "anothersecretkey"
jwtalgorithm = "HS256"
jwtexpiration = "24h"
// Uploads defaults
viper.SetDefault("uploads.ResumableUploadsEnabled", true)
viper.SetDefault("uploads.ChunkedUploadsEnabled", true)
viper.SetDefault("uploads.ChunkSize", 8192)
viper.SetDefault("uploads.AllowedExtensions", []string{
".txt", ".pdf",
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp",
".wav", ".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".mpg", ".m4v", ".3gp", ".3g2",
".mp3", ".ogg",
})
[logging]
level = "info"
file = "/var/log/hmac-file-server.log"
max_size = 100
max_backups = 7
max_age = 30
compress = true
// ClamAV defaults
viper.SetDefault("clamav.ClamAVEnabled", true)
viper.SetDefault("clamav.ClamAVSocket", "/var/run/clamav/clamd.ctl")
viper.SetDefault("clamav.NumScanWorkers", 2)
[deduplication]
enabled = true
directory = "./deduplication"
// Redis defaults
viper.SetDefault("redis.RedisEnabled", true)
viper.SetDefault("redis.RedisAddr", "localhost:6379")
viper.SetDefault("redis.RedisPassword", "")
viper.SetDefault("redis.RedisDBIndex", 0)
viper.SetDefault("redis.RedisHealthCheckInterval", "120s")
[iso]
enabled = true
size = "1GB"
mountpoint = "/mnt/iso"
charset = "utf-8"
containerfile = "/mnt/iso/container.iso"
// Workers defaults
viper.SetDefault("workers.NumWorkers", 2)
viper.SetDefault("workers.UploadQueueSize", 50)
[timeouts]
readtimeout = "4800s"
writetimeout = "4800s"
idletimeout = "4800s"
// Deduplication defaults
viper.SetDefault("deduplication.Enabled", true)
[security]
secret = "changeme"
enablejwt = false
jwtsecret = "anothersecretkey"
jwtalgorithm = "HS256"
jwtexpiration = "24h"
[versioning]
enableversioning = false
maxversions = 1
[uploads]
resumableuploadsenabled = true
chunkeduploadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[downloads]
resumabledownloadsenabled = true
chunkeddownloadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[clamav]
clamavenabled = true
clamavsocket = "/var/run/clamav/clamd.ctl"
numscanworkers = 2
scanfileextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[redis]
redisenabled = true
redisdbindex = 0
redisaddr = "localhost:6379"
redispassword = ""
redishealthcheckinterval = "120s"
[workers]
numworkers = 4
uploadqueuesize = 50
[file]
# Add file-specific configurations here
[build]
version = "3.2"
`
}
// Validate configuration fields
func validateConfig(conf *Config) error {
if conf.Server.ListenPort == "" {
return fmt.Errorf("ListenPort must be set")
func max(a, b int) int {
if a > b {
return a
}
if conf.Security.Secret == "" {
return fmt.Errorf("secret must be set")
}
if conf.Server.StoragePath == "" {
return fmt.Errorf("StoragePath must be set")
}
if conf.Server.FileTTL == "" {
return fmt.Errorf("FileTTL must be set")
}
// Validate timeouts
if _, err := time.ParseDuration(conf.Timeouts.ReadTimeout); err != nil {
return fmt.Errorf("invalid ReadTimeout: %v", err)
}
if _, err := time.ParseDuration(conf.Timeouts.WriteTimeout); err != nil {
return fmt.Errorf("invalid WriteTimeout: %v", err)
}
if _, err := time.ParseDuration(conf.Timeouts.IdleTimeout); err != nil {
return fmt.Errorf("invalid IdleTimeout: %v", err)
}
// Validate Redis configuration if enabled
if conf.Redis.RedisEnabled {
if conf.Redis.RedisAddr == "" {
return fmt.Errorf("RedisAddr must be set when Redis is enabled")
}
}
// Add more validations as needed
return nil
return b
}
// Setup logging
func setupLogging() {
level, err := logrus.ParseLevel(conf.Server.LogLevel)
if err != nil {
log.Fatalf("Invalid log level: %s", conf.Server.LogLevel)
}
log.SetLevel(level)
if conf.Server.LogFile != "" {
logFile, err := os.OpenFile(conf.Server.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
log.Fatalf("Failed to open log file: %v", err)
}
log.SetOutput(io.MultiWriter(os.Stdout, logFile))
} else {
log.SetOutput(os.Stdout)
}
// Use Text formatter for human-readable logs
log.SetFormatter(&logrus.TextFormatter{
FullTimestamp: true,
// You can customize the format further if needed
})
}
// Log system information
func logSystemInfo() {
log.Info("========================================")
log.Infof(" HMAC File Server - %s ", versionString)
log.Info(" Secure File Handling with HMAC Auth ")
log.Info("========================================")
log.Info("Features: Prometheus Metrics, Chunked Uploads, ClamAV Scanning")
log.Info("Build Date: 2024-10-28")
log.Infof("Operating System: %s", runtime.GOOS)
log.Infof("Architecture: %s", runtime.GOARCH)
log.Infof("Number of CPUs: %d", runtime.NumCPU())
log.Infof("Go Version: %s", runtime.Version())
func autoAdjustWorkers() (int, int) {
v, _ := mem.VirtualMemory()
log.Infof("Total Memory: %v MB", v.Total/1024/1024)
log.Infof("Free Memory: %v MB", v.Free/1024/1024)
log.Infof("Used Memory: %v MB", v.Used/1024/1024)
cpuCores, _ := cpu.Counts(true)
cpuInfo, _ := cpu.Info()
for _, info := range cpuInfo {
log.Infof("CPU Model: %s, Cores: %d, Mhz: %f", info.ModelName, info.Cores, info.Mhz)
numWorkers := cpuCores * 2
if v.Available < 4*1024*1024*1024 { // Less than 4GB available
numWorkers = max(numWorkers/2, 1)
} else if v.Available < 8*1024*1024*1024 { // Less than 8GB available
numWorkers = max(numWorkers*3/4, 1)
}
queueSize := numWorkers * 10
partitions, _ := disk.Partitions(false)
for _, partition := range partitions {
usage, _ := disk.Usage(partition.Mountpoint)
log.Infof("Disk Mountpoint: %s, Total: %v GB, Free: %v GB, Used: %v GB",
partition.Mountpoint, usage.Total/1024/1024/1024, usage.Free/1024/1024/1024, usage.Used/1024/1024/1024)
}
hInfo, _ := host.Info()
log.Infof("Hostname: %s", hInfo.Hostname)
log.Infof("Uptime: %v seconds", hInfo.Uptime)
log.Infof("Boot Time: %v", time.Unix(int64(hInfo.BootTime), 0))
log.Infof("Platform: %s", hInfo.Platform)
log.Infof("Platform Family: %s", hInfo.PlatformFamily)
log.Infof("Platform Version: %s", hInfo.PlatformVersion)
log.Infof("Kernel Version: %s", hInfo.KernelVersion)
log.Infof("Auto-adjusting workers: NumWorkers=%d, UploadQueueSize=%d", numWorkers, queueSize)
workerAdjustmentsTotal.Inc()
return numWorkers, queueSize
}
// Initialize Prometheus metrics
// Duplicate initMetrics function removed
func initMetrics() {
uploadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{Namespace: "hmac", Name: "file_server_upload_duration_seconds", Help: "Histogram of file upload duration in seconds.", Buckets: prometheus.DefBuckets})
uploadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_upload_errors_total", Help: "Total number of file upload errors."})
uploadsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_uploads_total", Help: "Total number of successful file uploads."})
downloadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{Namespace: "hmac", Name: "file_server_download_duration_seconds", Help: "Histogram of file download duration in seconds.", Buckets: prometheus.DefBuckets})
downloadsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_downloads_total", Help: "Total number of successful file downloads."})
downloadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_download_errors_total", Help: "Total number of file download errors."})
memoryUsage = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "memory_usage_bytes", Help: "Current memory usage in bytes."})
cpuUsage = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "cpu_usage_percent", Help: "Current CPU usage as a percentage."})
activeConnections = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "active_connections_total", Help: "Total number of active connections."})
requestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{Namespace: "hmac", Name: "http_requests_total", Help: "Total number of HTTP requests received, labeled by method and path."}, []string{"method", "path"})
goroutines = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "goroutines_count", Help: "Current number of goroutines."})
uploadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: "hmac",
Name: "file_server_upload_size_bytes",
Help: "Histogram of uploaded file sizes in bytes.",
Buckets: prometheus.ExponentialBuckets(100, 10, 8),
})
downloadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: "hmac",
Name: "file_server_download_size_bytes",
Help: "Histogram of downloaded file sizes in bytes.",
Buckets: prometheus.ExponentialBuckets(100, 10, 8),
})
func initializeWorkerSettings(server *ServerConfig, workers *WorkersConfig, clamav *ClamAVConfig) {
if server.AutoAdjustWorkers {
numWorkers, queueSize := autoAdjustWorkers()
workers.NumWorkers = numWorkers
workers.UploadQueueSize = queueSize
clamav.NumScanWorkers = max(numWorkers/2, 1)
if conf.Server.MetricsEnabled {
prometheus.MustRegister(uploadDuration, uploadErrorsTotal, uploadsTotal)
prometheus.MustRegister(downloadDuration, downloadsTotal, downloadErrorsTotal)
prometheus.MustRegister(memoryUsage, cpuUsage, activeConnections, requestsTotal, goroutines)
prometheus.MustRegister(uploadSizeBytes, downloadSizeBytes)
log.Infof("AutoAdjustWorkers enabled: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
workers.NumWorkers, workers.UploadQueueSize, clamav.NumScanWorkers)
} else {
log.Infof("Manual configuration in effect: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
workers.NumWorkers, workers.UploadQueueSize, clamav.NumScanWorkers)
}
}
// Update system metrics
func updateSystemMetrics(ctx context.Context) {
ticker := time.NewTicker(10 * time.Second)
func monitorWorkerPerformance(ctx context.Context, server *ServerConfig, w *WorkersConfig, clamav *ClamAVConfig) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Info("Stopping system metrics updater.")
log.Info("Stopping worker performance monitor.")
return
case <-ticker.C:
v, _ := mem.VirtualMemory()
memoryUsage.Set(float64(v.Used))
if server.AutoAdjustWorkers {
numWorkers, queueSize := autoAdjustWorkers()
w.NumWorkers = numWorkers
w.UploadQueueSize = queueSize
clamav.NumScanWorkers = max(numWorkers/2, 1)
cpuPercent, _ := cpu.Percent(0, false)
if len(cpuPercent) > 0 {
cpuUsage.Set(cpuPercent[0])
log.Infof("Re-adjusted workers: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
w.NumWorkers, w.UploadQueueSize, clamav.NumScanWorkers)
workerReAdjustmentsTotal.Inc()
}
goroutines.Set(float64(runtime.NumGoroutine()))
}
}
}
// Function to check if a file exists and return its size
func fileExists(filePath string) (bool, int64) {
if cachedInfo, found := fileInfoCache.Get(filePath); found {
if info, ok := cachedInfo.(os.FileInfo); ok {
return !info.IsDir(), info.Size()
}
}
fileInfo, err := os.Stat(filePath)
if os.IsNotExist(err) {
return false, 0
} else if err != nil {
log.Error("Error checking file existence:", err)
return false, 0
}
fileInfoCache.Set(filePath, fileInfo, cache.DefaultExpiration)
return !fileInfo.IsDir(), fileInfo.Size()
}
// Function to check file extension
func isExtensionAllowed(filename string) bool {
if len(conf.Uploads.AllowedExtensions) == 0 {
return true // No restrictions if the list is empty
}
ext := strings.ToLower(filepath.Ext(filename))
for _, allowedExt := range conf.Uploads.AllowedExtensions {
if strings.ToLower(allowedExt) == ext {
return true
}
}
return false
}
// Version the file by moving the existing file to a versioned directory
func versionFile(absFilename string) error {
versionDir := absFilename + "_versions"
err := os.MkdirAll(versionDir, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create version directory: %v", err)
}
timestamp := time.Now().Format("20060102-150405")
versionedFilename := filepath.Join(versionDir, filepath.Base(absFilename)+"."+timestamp)
err = os.Rename(absFilename, versionedFilename)
if err != nil {
return fmt.Errorf("failed to version the file: %v", err)
}
log.WithFields(logrus.Fields{
"original": absFilename,
"versioned_as": versionedFilename,
}).Info("Versioned old file")
return cleanupOldVersions(versionDir)
}
// Clean up older versions if they exceed the maximum allowed
func cleanupOldVersions(versionDir string) error {
files, err := os.ReadDir(versionDir)
if err != nil {
return fmt.Errorf("failed to list version files: %v", err)
}
if conf.Versioning.MaxVersions > 0 && len(files) > conf.Versioning.MaxVersions {
excessFiles := len(files) - conf.Versioning.MaxVersions
for i := 0; i < excessFiles; i++ {
err := os.Remove(filepath.Join(versionDir, files[i].Name()))
if err != nil {
return fmt.Errorf("failed to remove old version: %v", err)
}
log.WithField("file", files[i].Name()).Info("Removed old version")
}
}
return nil
}
// Process the upload task
func processUpload(task UploadTask) error {
absFilename := task.AbsFilename
tempFilename := absFilename + ".tmp"
r := task.Request
log.Infof("Processing upload for file: %s", absFilename)
startTime := time.Now()
// Handle uploads and write to a temporary file
if conf.Uploads.ChunkedUploadsEnabled {
log.Debugf("Chunked uploads enabled. Handling chunked upload for %s", tempFilename)
err := handleChunkedUpload(tempFilename, r)
if err != nil {
uploadDuration.Observe(time.Since(startTime).Seconds())
log.WithFields(logrus.Fields{
"file": tempFilename,
"error": err,
}).Error("Failed to handle chunked upload")
return err
}
} else {
log.Debugf("Handling standard upload for %s", tempFilename)
err := createFile(tempFilename, r)
if err != nil {
log.WithFields(logrus.Fields{
"file": tempFilename,
"error": err,
}).Error("Error creating file")
uploadDuration.Observe(time.Since(startTime).Seconds())
return err
}
}
// Perform ClamAV scan on the temporary file
if clamClient != nil {
log.Debugf("Scanning %s with ClamAV", tempFilename)
err := scanFileWithClamAV(tempFilename)
if err != nil {
log.WithFields(logrus.Fields{
"file": tempFilename,
"error": err,
}).Warn("ClamAV detected a virus or scan failed")
os.Remove(tempFilename)
uploadErrorsTotal.Inc()
return err
}
log.Infof("ClamAV scan passed for file: %s", tempFilename)
}
// Handle file versioning if enabled
if conf.Versioning.EnableVersioning {
existing, _ := fileExists(absFilename)
if existing {
log.Infof("File %s exists. Initiating versioning.", absFilename)
err := versionFile(absFilename)
if err != nil {
log.WithFields(logrus.Fields{
"file": absFilename,
"error": err,
}).Error("Error versioning file")
os.Remove(tempFilename)
return err
}
log.Infof("File versioned successfully: %s", absFilename)
}
}
// Rename temporary file to final destination
err := os.Rename(tempFilename, absFilename)
if err != nil {
log.WithFields(logrus.Fields{
"temp_file": tempFilename,
"final_file": absFilename,
"error": err,
}).Error("Failed to move file to final destination")
os.Remove(tempFilename)
func readConfig(configFilename string, conf *Config) error {
viper.SetConfigFile(configFilename)
if err := viper.ReadInConfig(); err != nil {
log.WithError(err).Errorf("Unable to read config from %s", configFilename)
return err
}
log.Infof("File moved to final destination: %s", absFilename)
// Handle deduplication if enabled
if conf.Server.DeduplicationEnabled {
log.Debugf("Deduplication enabled. Checking duplicates for %s", absFilename)
err = handleDeduplication(context.Background(), absFilename)
if err != nil {
log.WithError(err).Error("Deduplication failed")
uploadErrorsTotal.Inc()
return err
}
log.Infof("Deduplication handled successfully for file: %s", absFilename)
if err := viper.Unmarshal(conf); err != nil {
return fmt.Errorf("unable to decode config into struct: %v", err)
}
log.WithFields(logrus.Fields{
"file": absFilename,
}).Info("File uploaded and processed successfully")
uploadDuration.Observe(time.Since(startTime).Seconds())
uploadsTotal.Inc()
return nil
}
// uploadWorker processes upload tasks from the uploadQueue
func uploadWorker(ctx context.Context, workerID int) {
log.Infof("Upload worker %d started.", workerID)
defer log.Infof("Upload worker %d stopped.", workerID)
for {
select {
case <-ctx.Done():
return
case task, ok := <-uploadQueue:
if !ok {
log.Warnf("Upload queue closed. Worker %d exiting.", workerID)
return
}
log.Infof("Worker %d processing upload for file: %s", workerID, task.AbsFilename)
err := processUpload(task)
if err != nil {
log.Errorf("Worker %d failed to process upload for %s: %v", workerID, task.AbsFilename, err)
uploadErrorsTotal.Inc()
} else {
log.Infof("Worker %d successfully processed upload for %s", workerID, task.AbsFilename)
}
task.Result <- err
close(task.Result)
func setDefaults() {
viper.SetDefault("server.listen_address", ":8080")
viper.SetDefault("server.storage_path", "./uploads")
viper.SetDefault("server.metrics_enabled", true)
viper.SetDefault("server.metrics_path", "/metrics")
viper.SetDefault("server.pid_file", "/var/run/hmac-file-server.pid")
viper.SetDefault("server.max_upload_size", "10GB")
viper.SetDefault("server.max_header_bytes", 1048576) // 1MB
viper.SetDefault("server.cleanup_interval", "24h")
viper.SetDefault("server.max_file_age", "720h") // 30 days
viper.SetDefault("server.pre_cache", true)
viper.SetDefault("server.pre_cache_workers", 4)
viper.SetDefault("server.pre_cache_interval", "1h")
viper.SetDefault("server.global_extensions", []string{})
viper.SetDefault("server.deduplication_enabled", true)
viper.SetDefault("server.min_free_bytes", "1GB")
viper.SetDefault("server.file_naming", "original")
viper.SetDefault("server.force_protocol", "")
viper.SetDefault("server.enable_dynamic_workers", true)
viper.SetDefault("server.worker_scale_up_thresh", 50)
viper.SetDefault("server.worker_scale_down_thresh", 10)
viper.SetDefault("uploads.allowed_extensions", []string{".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"})
viper.SetDefault("uploads.chunked_uploads_enabled", true)
viper.SetDefault("uploads.chunk_size", "10MB")
viper.SetDefault("uploads.resumable_uploads_enabled", true)
viper.SetDefault("uploads.max_resumable_age", "48h")
viper.SetDefault("downloads.allowed_extensions", []string{".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"})
viper.SetDefault("downloads.chunked_downloads_enabled", true)
viper.SetDefault("downloads.chunk_size", "10MB")
viper.SetDefault("downloads.resumable_downloads_enabled", true)
viper.SetDefault("security.secret", "your-very-secret-hmac-key")
viper.SetDefault("security.enablejwt", false)
viper.SetDefault("security.jwtsecret", "your-256-bit-secret")
viper.SetDefault("security.jwtalgorithm", "HS256")
viper.SetDefault("security.jwtexpiration", "24h")
// Logging defaults
viper.SetDefault("logging.level", "info")
viper.SetDefault("logging.file", "/var/log/hmac-file-server.log")
viper.SetDefault("logging.max_size", 100)
viper.SetDefault("logging.max_backups", 7)
viper.SetDefault("logging.max_age", 30)
viper.SetDefault("logging.compress", true)
// Deduplication defaults
viper.SetDefault("deduplication.enabled", false)
viper.SetDefault("deduplication.directory", "./dedup_store")
// ISO defaults
viper.SetDefault("iso.enabled", false)
viper.SetDefault("iso.mount_point", "/mnt/hmac_iso")
viper.SetDefault("iso.size", "1GB")
viper.SetDefault("iso.charset", "utf-8")
viper.SetDefault("iso.containerfile", "/var/lib/hmac-file-server/data.iso")
// Timeouts defaults
viper.SetDefault("timeouts.read", "60s")
viper.SetDefault("timeouts.write", "60s")
viper.SetDefault("timeouts.idle", "120s")
viper.SetDefault("timeouts.shutdown", "30s")
// Versioning defaults
viper.SetDefault("versioning.enabled", false)
viper.SetDefault("versioning.backend", "simple")
viper.SetDefault("versioning.max_revisions", 5)
// ... other defaults for Uploads, Downloads, ClamAV, Redis, Workers, File, Build
viper.SetDefault("build.version", "dev")
}
func validateConfig(c *Config) error {
if c.Server.ListenAddress == "" { // Corrected field name
return errors.New("server.listen_address is required")
}
if c.Server.FileTTL == "" && c.Server.FileTTLEnabled { // Corrected field names
return errors.New("server.file_ttl is required when server.file_ttl_enabled is true")
}
if _, err := time.ParseDuration(c.Timeouts.Read); err != nil { // Corrected field name
return fmt.Errorf("invalid timeouts.read: %v", err)
}
if _, err := time.ParseDuration(c.Timeouts.Write); err != nil { // Corrected field name
return fmt.Errorf("invalid timeouts.write: %v", err)
}
if _, err := time.ParseDuration(c.Timeouts.Idle); err != nil { // Corrected field name
return fmt.Errorf("invalid timeouts.idle: %v", err)
}
// Corrected VersioningConfig field access
if c.Versioning.Enabled { // Use the Go struct field name 'Enabled'
if c.Versioning.MaxRevs <= 0 { // Use the Go struct field name 'MaxRevs'
return errors.New("versioning.max_revisions must be positive if versioning is enabled")
}
}
}
// Initialize upload worker pool
func initializeUploadWorkerPool(ctx context.Context) {
for i := 0; i < MinWorkers; i++ {
go uploadWorker(ctx, i)
}
log.Infof("Initialized %d upload workers", MinWorkers)
}
// Worker function to process scan tasks
func scanWorker(ctx context.Context, workerID int) {
log.WithField("worker_id", workerID).Info("Scan worker started")
for {
select {
case <-ctx.Done():
log.WithField("worker_id", workerID).Info("Scan worker stopping")
return
case task, ok := <-scanQueue:
if !ok {
log.WithField("worker_id", workerID).Info("Scan queue closed")
return
}
log.WithFields(logrus.Fields{
"worker_id": workerID,
"file": task.AbsFilename,
}).Info("Processing scan task")
err := scanFileWithClamAV(task.AbsFilename)
if err != nil {
log.WithFields(logrus.Fields{
"worker_id": workerID,
"file": task.AbsFilename,
"error": err,
}).Error("Failed to scan file")
} else {
log.WithFields(logrus.Fields{
"worker_id": workerID,
"file": task.AbsFilename,
}).Info("Successfully scanned file")
}
task.Result <- err
close(task.Result)
}
}
}
// Initialize scan worker pool
func initializeScanWorkerPool(ctx context.Context) {
for i := 0; i < ScanWorkers; i++ {
go scanWorker(ctx, i)
}
log.Infof("Initialized %d scan workers", ScanWorkers)
}
// Setup router with middleware
func setupRouter() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/", handleRequest)
if conf.Server.MetricsEnabled {
mux.Handle("/metrics", promhttp.Handler())
// Validate JWT secret if JWT is enabled
if c.Security.EnableJWT && strings.TrimSpace(c.Security.JWTSecret) == "" {
return errors.New("security.jwtsecret is required when security.enablejwt is true")
}
// Apply middleware
handler := loggingMiddleware(mux)
handler = recoveryMiddleware(handler)
handler = corsMiddleware(handler)
return handler
}
// Middleware for logging
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestsTotal.WithLabelValues(r.Method, r.URL.Path).Inc()
next.ServeHTTP(w, r)
})
}
// Middleware for panic recovery
func recoveryMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rec := recover(); rec != nil {
log.WithFields(logrus.Fields{
"method": r.Method,
"url": r.URL.String(),
"error": rec,
}).Error("Panic recovered in HTTP handler")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}
// corsMiddleware handles CORS by setting appropriate headers
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-File-MAC")
w.Header().Set("Access-Control-Max-Age", "86400") // Cache preflight response for 1 day
// Handle preflight OPTIONS request
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
// Proceed to the next handler
next.ServeHTTP(w, r)
})
}
// Handle file uploads and downloads
func handleRequest(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost && strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
absFilename, err := sanitizeFilePath(conf.Server.StoragePath, strings.TrimPrefix(r.URL.Path, "/"))
if err != nil {
log.WithError(err).Error("Invalid file path")
http.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
err = handleMultipartUpload(w, r, absFilename)
if err != nil {
log.WithError(err).Error("Failed to handle multipart upload")
http.Error(w, "Failed to handle multipart upload", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusCreated)
return
// Validate HMAC secret if JWT is not enabled (as it's the fallback)
if !c.Security.EnableJWT && strings.TrimSpace(c.Security.Secret) == "" {
return errors.New("security.secret is required for HMAC authentication (when JWT is disabled)")
}
// Get client IP address
clientIP := r.Header.Get("X-Real-IP")
if clientIP == "" {
clientIP = r.Header.Get("X-Forwarded-For")
}
if clientIP == "" {
// Fallback to RemoteAddr
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
log.WithError(err).Warn("Failed to parse RemoteAddr")
clientIP = r.RemoteAddr
return nil
}
// validateJWTFromRequest extracts and validates a JWT from the request.
func validateJWTFromRequest(r *http.Request, secret string) (*jwt.Token, error) {
authHeader := r.Header.Get("Authorization")
tokenString := ""
if authHeader != "" {
splitToken := strings.Split(authHeader, "Bearer ")
if len(splitToken) == 2 {
tokenString = splitToken[1]
} else {
clientIP = host
return nil, errors.New("invalid Authorization header format")
}
} else {
// Fallback to checking 'token' query parameter
tokenString = r.URL.Query().Get("token")
if tokenString == "" {
return nil, errors.New("missing JWT in Authorization header or 'token' query parameter")
}
}
// Log the request with the client IP
log.WithFields(logrus.Fields{
"method": r.Method,
"url": r.URL.String(),
"remote": clientIP,
}).Info("Incoming request")
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(secret), nil
})
// Parse URL and query parameters
p := r.URL.Path
a, err := url.ParseQuery(r.URL.RawQuery)
if err != nil {
log.Warn("Failed to parse query parameters")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
return nil, fmt.Errorf("JWT validation failed: %w", err)
}
fileStorePath := strings.TrimPrefix(p, "/")
if fileStorePath == "" || fileStorePath == "/" {
log.Warn("Access to root directory is forbidden")
http.Error(w, "Forbidden", http.StatusForbidden)
return
} else if fileStorePath[0] == '/' {
fileStorePath = fileStorePath[1:]
if !token.Valid {
return nil, errors.New("invalid JWT")
}
absFilename, err := sanitizeFilePath(conf.Server.StoragePath, fileStorePath)
if err != nil {
log.WithFields(logrus.Fields{
"file": fileStorePath,
"error": err,
}).Warn("Invalid file path")
http.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
switch r.Method {
case http.MethodPut:
handleUpload(w, r, absFilename, fileStorePath, a)
case http.MethodHead, http.MethodGet:
handleDownload(w, r, absFilename, fileStorePath)
case http.MethodOptions:
// Handled by NGINX; no action needed
w.Header().Set("Allow", "OPTIONS, GET, PUT, HEAD")
return
default:
log.WithField("method", r.Method).Warn("Invalid HTTP method for upload directory")
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return
}
return token, nil
}
// Handle file uploads with extension restrictions and HMAC validation
func handleUpload(w http.ResponseWriter, r *http.Request, absFilename, fileStorePath string, a url.Values) {
// Log the storage path being used
log.Infof("Using storage path: %s", conf.Server.StoragePath)
// validateHMAC validates the HMAC signature of the request for legacy protocols and POST uploads.
func validateHMAC(r *http.Request, secret string) error {
log.Debugf("validateHMAC: Validating request to %s with query: %s", r.URL.Path, r.URL.RawQuery)
// Check for X-Signature header (for POST uploads)
signature := r.Header.Get("X-Signature")
if signature != "" {
// This is a POST upload with X-Signature header
message := r.URL.Path
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(message))
expectedSignature := hex.EncodeToString(h.Sum(nil))
// Determine protocol version based on query parameters
var protocolVersion string
if a.Get("v2") != "" {
protocolVersion = "v2"
} else if a.Get("token") != "" {
protocolVersion = "token"
} else if a.Get("v") != "" {
protocolVersion = "v"
} else {
log.Warn("No HMAC attached to URL. Expecting 'v', 'v2', or 'token' parameter as MAC")
http.Error(w, "No HMAC attached to URL. Expecting 'v', 'v2', or 'token' parameter as MAC", http.StatusForbidden)
return
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
return errors.New("invalid HMAC signature in X-Signature header")
}
return nil
}
log.Debugf("Protocol version determined: %s", protocolVersion)
// Initialize HMAC
mac := hmac.New(sha256.New, []byte(conf.Security.Secret))
// Check for legacy URL-based HMAC protocols (v, v2, token)
query := r.URL.Query()
var protocolVersion string
var providedMACHex string
if query.Get("v2") != "" {
protocolVersion = "v2"
providedMACHex = query.Get("v2")
} else if query.Get("token") != "" {
protocolVersion = "token"
providedMACHex = query.Get("token")
} else if query.Get("v") != "" {
protocolVersion = "v"
providedMACHex = query.Get("v")
} else {
return errors.New("no HMAC signature found (missing X-Signature header or v/v2/token query parameter)")
}
// Extract file path from URL
fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
// Calculate HMAC based on protocol version (matching legacy behavior)
mac := hmac.New(sha256.New, []byte(secret))
// Calculate MAC based on protocolVersion
if protocolVersion == "v" {
mac.Write([]byte(fileStorePath + "\x20" + strconv.FormatInt(r.ContentLength, 10)))
} else if protocolVersion == "v2" || protocolVersion == "token" {
// Legacy v protocol: fileStorePath + "\x20" + contentLength
message := fileStorePath + "\x20" + strconv.FormatInt(r.ContentLength, 10)
mac.Write([]byte(message))
} else {
// v2 and token protocols: fileStorePath + "\x00" + contentLength + "\x00" + contentType
contentType := mime.TypeByExtension(filepath.Ext(fileStorePath))
if contentType == "" {
contentType = "application/octet-stream"
}
mac.Write([]byte(fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) + "\x00" + contentType))
message := fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) + "\x00" + contentType
log.Debugf("validateHMAC: %s protocol message: %q (len=%d)", protocolVersion, message, len(message))
mac.Write([]byte(message))
}
calculatedMAC := mac.Sum(nil)
log.Debugf("Calculated MAC: %x", calculatedMAC)
calculatedMACHex := hex.EncodeToString(calculatedMAC)
// Decode provided MAC from hex
providedMACHex := a.Get(protocolVersion)
// Decode provided MAC
providedMAC, err := hex.DecodeString(providedMACHex)
if err != nil {
log.Warn("Invalid MAC encoding")
http.Error(w, "Invalid MAC encoding", http.StatusForbidden)
return
return fmt.Errorf("invalid MAC encoding for %s protocol: %v", protocolVersion, err)
}
log.Debugf("Provided MAC: %x", providedMAC)
// Validate the HMAC
log.Debugf("validateHMAC: %s protocol - calculated: %s, provided: %s", protocolVersion, calculatedMACHex, providedMACHex)
// Compare MACs
if !hmac.Equal(calculatedMAC, providedMAC) {
log.Warn("Invalid MAC")
http.Error(w, "Invalid MAC", http.StatusForbidden)
return
}
log.Debug("HMAC validation successful")
// Validate file extension
if !isExtensionAllowed(fileStorePath) {
log.WithFields(logrus.Fields{
// No need to sanitize and validate the file path here since absFilename is already sanitized in handleRequest
"file": fileStorePath,
"error": err,
}).Warn("Invalid file path")
http.Error(w, "Invalid file path", http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
// absFilename = sanitizedFilename
// Check if there is enough free space
err = checkStorageSpace(conf.Server.StoragePath, conf.Server.MinFreeBytes)
if err != nil {
log.WithFields(logrus.Fields{
"storage_path": conf.Server.StoragePath,
"error": err,
}).Warn("Not enough free space")
http.Error(w, "Not enough free space", http.StatusInsufficientStorage)
uploadErrorsTotal.Inc()
return
return fmt.Errorf("invalid MAC for %s protocol", protocolVersion)
}
// Create an UploadTask with a result channel
result := make(chan error)
task := UploadTask{
AbsFilename: absFilename,
Request: r,
Result: result,
}
// Submit task to the upload queue
select {
case uploadQueue <- task:
// Successfully added to the queue
log.Debug("Upload task enqueued successfully")
default:
// Queue is full
log.Warn("Upload queue is full. Rejecting upload")
http.Error(w, "Server busy. Try again later.", http.StatusServiceUnavailable)
uploadErrorsTotal.Inc()
return
}
// Wait for the worker to process the upload
err = <-result
if err != nil {
// The worker has already logged the error; send an appropriate HTTP response
http.Error(w, fmt.Sprintf("Upload failed: %v", err), http.StatusInternalServerError)
return
}
// Upload was successful
w.WriteHeader(http.StatusCreated)
log.Debugf("%s HMAC authentication successful for request: %s", protocolVersion, r.URL.Path)
return nil
}
// Handle file downloads
func handleDownload(w http.ResponseWriter, r *http.Request, absFilename, fileStorePath string) {
fileInfo, err := getFileInfo(absFilename)
// validateV3HMAC validates the HMAC signature for v3 protocol (mod_http_upload_external).
func validateV3HMAC(r *http.Request, secret string) error {
query := r.URL.Query()
// Extract v3 signature and expires from query parameters
signature := query.Get("v3")
expiresStr := query.Get("expires")
if signature == "" {
return errors.New("missing v3 signature parameter")
}
if expiresStr == "" {
return errors.New("missing expires parameter")
}
// Parse expires timestamp
expires, err := strconv.ParseInt(expiresStr, 10, 64)
if err != nil {
log.WithError(err).Error("Failed to get file information")
http.Error(w, "Not Found", http.StatusNotFound)
downloadErrorsTotal.Inc()
return fmt.Errorf("invalid expires parameter: %v", err)
}
// Check if signature has expired
now := time.Now().Unix()
if now > expires {
return errors.New("signature has expired")
}
// Construct message for HMAC verification
// Format: METHOD\nEXPIRES\nPATH
message := fmt.Sprintf("%s\n%s\n%s", r.Method, expiresStr, r.URL.Path)
// Calculate expected HMAC signature
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(message))
expectedSignature := hex.EncodeToString(h.Sum(nil))
// Compare signatures
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
return errors.New("invalid v3 HMAC signature")
}
return nil
}
// handleUpload handles file uploads.
func handleUpload(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
activeConnections.Inc()
defer activeConnections.Dec()
// Only allow POST method
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
uploadErrorsTotal.Inc()
return
} else if fileInfo.IsDir() {
log.Warn("Directory listing forbidden")
http.Error(w, "Forbidden", http.StatusForbidden)
}
// Authentication
if conf.Security.EnableJWT {
_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
if err != nil {
http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
uploadErrorsTotal.Inc()
return
}
log.Debugf("JWT authentication successful for upload request: %s", r.URL.Path)
} else {
err := validateHMAC(r, conf.Security.Secret)
if err != nil {
http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
uploadErrorsTotal.Inc()
return
}
log.Debugf("HMAC authentication successful for upload request: %s", r.URL.Path)
}
// Parse multipart form
err := r.ParseMultipartForm(32 << 20) // 32MB max memory
if err != nil {
http.Error(w, fmt.Sprintf("Error parsing multipart form: %v", err), http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
// Get file from form
file, header, err := r.FormFile("file")
if err != nil {
http.Error(w, fmt.Sprintf("Error getting file from form: %v", err), http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
defer file.Close()
// Validate file extension if configured
if len(conf.Uploads.AllowedExtensions) > 0 {
ext := strings.ToLower(filepath.Ext(header.Filename))
allowed := false
for _, allowedExt := range conf.Uploads.AllowedExtensions {
if ext == allowedExt {
allowed = true
break
}
}
if !allowed {
http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
}
// Generate filename based on configuration
var filename string
switch conf.Server.FileNaming {
case "HMAC":
// Generate HMAC-based filename
h := hmac.New(sha256.New, []byte(conf.Security.Secret))
h.Write([]byte(header.Filename + time.Now().String()))
filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(header.Filename)
default: // "original" or "None"
filename = header.Filename
}
// Create full file path
storagePath := conf.Server.StoragePath
if conf.ISO.Enabled {
storagePath = conf.ISO.MountPoint
}
absFilename := filepath.Join(storagePath, filename)
// Create the file
dst, err := os.Create(absFilename)
if err != nil {
http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
uploadErrorsTotal.Inc()
return
}
defer dst.Close()
// Copy file content
written, err := io.Copy(dst, file)
if err != nil {
http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
uploadErrorsTotal.Inc()
// Clean up partial file
os.Remove(absFilename)
return
}
// Handle deduplication if enabled
if conf.Server.DeduplicationEnabled {
ctx := context.Background()
err = handleDeduplication(ctx, absFilename)
if err != nil {
log.Warnf("Deduplication failed for %s: %v", absFilename, err)
}
}
// Update metrics
duration := time.Since(startTime)
uploadDuration.Observe(duration.Seconds())
uploadsTotal.Inc()
uploadSizeBytes.Observe(float64(written))
// Return success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
response := map[string]interface{}{
"success": true,
"filename": filename,
"size": written,
"duration": duration.String(),
}
// Create JSON response
if jsonBytes, err := json.Marshal(response); err == nil {
w.Write(jsonBytes)
} else {
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d}`, filename, written)
}
log.Infof("Successfully uploaded %s (%s) in %s", filename, formatBytes(written), duration)
}
// handleDownload handles file downloads.
func handleDownload(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
activeConnections.Inc()
defer activeConnections.Dec()
// Authentication
if conf.Security.EnableJWT {
_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
if err != nil {
http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
downloadErrorsTotal.Inc()
return
}
log.Debugf("JWT authentication successful for download request: %s", r.URL.Path)
} else {
err := validateHMAC(r, conf.Security.Secret)
if err != nil {
http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
downloadErrorsTotal.Inc()
return
}
log.Debugf("HMAC authentication successful for download request: %s", r.URL.Path)
}
filename := strings.TrimPrefix(r.URL.Path, "/download/")
if filename == "" {
http.Error(w, "Filename not specified", http.StatusBadRequest)
downloadErrorsTotal.Inc()
return
}
absFilename, err := sanitizeFilePath(conf.Server.StoragePath, filename) // Use sanitizeFilePath from helpers.go
if err != nil {
http.Error(w, fmt.Sprintf("Invalid file path: %v", err), http.StatusBadRequest)
downloadErrorsTotal.Inc()
return
}
fileInfo, err := os.Stat(absFilename)
if os.IsNotExist(err) {
http.Error(w, "File not found", http.StatusNotFound)
downloadErrorsTotal.Inc()
return
}
if err != nil {
http.Error(w, fmt.Sprintf("Error accessing file: %v", err), http.StatusInternalServerError)
downloadErrorsTotal.Inc()
return
}
if fileInfo.IsDir() {
http.Error(w, "Cannot download a directory", http.StatusBadRequest)
downloadErrorsTotal.Inc()
return
}
file, err := os.Open(absFilename)
if err != nil {
http.Error(w, fmt.Sprintf("Error opening file: %v", err), http.StatusInternalServerError)
downloadErrorsTotal.Inc()
return
}
defer file.Close()
w.Header().Set("Content-Disposition", "attachment; filename=\""+filepath.Base(absFilename)+"\"")
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
// Use a pooled buffer for copying
bufPtr := bufferPool.Get().(*[]byte)
defer bufferPool.Put(bufPtr)
buf := *bufPtr
n, err := io.CopyBuffer(w, file, buf)
if err != nil {
log.Errorf("Error during download of %s: %v", absFilename, err)
// Don't write http.Error here if headers already sent
downloadErrorsTotal.Inc()
return // Ensure we don't try to record metrics if there was an error during copy
}
duration := time.Since(startTime)
downloadDuration.Observe(duration.Seconds())
downloadsTotal.Inc()
downloadSizeBytes.Observe(float64(n))
log.Infof("Successfully downloaded %s (%s) in %s", absFilename, formatBytes(n), duration)
}
// handleV3Upload handles PUT requests for v3 protocol (mod_http_upload_external).
func handleV3Upload(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
activeConnections.Inc()
defer activeConnections.Dec()
// Only allow PUT method for v3
if r.Method != http.MethodPut {
http.Error(w, "Method not allowed for v3 uploads", http.StatusMethodNotAllowed)
uploadErrorsTotal.Inc()
return
}
// Validate v3 HMAC signature
err := validateV3HMAC(r, conf.Security.Secret)
if err != nil {
http.Error(w, fmt.Sprintf("v3 Authentication failed: %v", err), http.StatusUnauthorized)
uploadErrorsTotal.Inc()
return
}
log.Debugf("v3 HMAC authentication successful for upload request: %s", r.URL.Path)
// Extract filename from the URL path
// Path format: /uuid/subdir/filename.ext
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
if len(pathParts) < 1 {
http.Error(w, "Invalid upload path", http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
// Use the last part as filename
originalFilename := pathParts[len(pathParts)-1]
if originalFilename == "" {
http.Error(w, "No filename specified", http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
// Validate file extension if configured
if len(conf.Uploads.AllowedExtensions) > 0 {
ext := strings.ToLower(filepath.Ext(originalFilename))
allowed := false
for _, allowedExt := range conf.Uploads.AllowedExtensions {
if ext == allowedExt {
allowed = true
break
}
}
if !allowed {
http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
}
// Generate filename based on configuration
var filename string
switch conf.Server.FileNaming {
case "HMAC":
// Generate HMAC-based filename
h := hmac.New(sha256.New, []byte(conf.Security.Secret))
h.Write([]byte(originalFilename + time.Now().String()))
filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(originalFilename)
default: // "original" or "None"
filename = originalFilename
}
// Create full file path
storagePath := conf.Server.StoragePath
if conf.ISO.Enabled {
storagePath = conf.ISO.MountPoint
}
absFilename := filepath.Join(storagePath, filename)
// Create the file
dst, err := os.Create(absFilename)
if err != nil {
http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
uploadErrorsTotal.Inc()
return
}
defer dst.Close()
// Copy file content from request body
written, err := io.Copy(dst, r.Body)
if err != nil {
http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
uploadErrorsTotal.Inc()
// Clean up partial file
os.Remove(absFilename)
return
}
// Handle deduplication if enabled
if conf.Server.DeduplicationEnabled {
ctx := context.Background()
err = handleDeduplication(ctx, absFilename)
if err != nil {
log.Warnf("Deduplication failed for %s: %v", absFilename, err)
}
}
// Update metrics
duration := time.Since(startTime)
uploadDuration.Observe(duration.Seconds())
uploadsTotal.Inc()
uploadSizeBytes.Observe(float64(written))
// Return success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
response := map[string]interface{}{
"success": true,
"filename": filename,
"size": written,
"duration": duration.String(),
}
// Create JSON response
if jsonBytes, err := json.Marshal(response); err == nil {
w.Write(jsonBytes)
} else {
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d}`, filename, written)
}
log.Infof("Successfully uploaded %s via v3 protocol (%s) in %s", filename, formatBytes(written), duration)
}
// handleLegacyUpload handles PUT requests for legacy protocols (v, v2, token).
func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
activeConnections.Inc()
defer activeConnections.Dec()
log.Debugf("handleLegacyUpload: Processing request to %s with query: %s", r.URL.Path, r.URL.RawQuery)
// Only allow PUT method for legacy uploads
if r.Method != http.MethodPut {
http.Error(w, "Method not allowed for legacy uploads", http.StatusMethodNotAllowed)
uploadErrorsTotal.Inc()
return
}
// Validate legacy HMAC signature
err := validateHMAC(r, conf.Security.Secret)
if err != nil {
http.Error(w, fmt.Sprintf("Legacy Authentication failed: %v", err), http.StatusUnauthorized)
uploadErrorsTotal.Inc()
return
}
// Extract filename from the URL path
fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
if fileStorePath == "" {
http.Error(w, "No filename specified", http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
// Validate file extension if configured
if len(conf.Uploads.AllowedExtensions) > 0 {
ext := strings.ToLower(filepath.Ext(fileStorePath))
allowed := false
for _, allowedExt := range conf.Uploads.AllowedExtensions {
if ext == allowedExt {
allowed = true
break
}
}
if !allowed {
http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
}
// Create full file path
storagePath := conf.Server.StoragePath
if conf.ISO.Enabled {
storagePath = conf.ISO.MountPoint
}
// Generate filename based on configuration
var absFilename string
var filename string
switch conf.Server.FileNaming {
case "HMAC":
// Generate HMAC-based filename
h := hmac.New(sha256.New, []byte(conf.Security.Secret))
h.Write([]byte(fileStorePath + time.Now().String()))
filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(fileStorePath)
absFilename = filepath.Join(storagePath, filename)
default: // "original" or "None"
// Preserve full directory structure for legacy XMPP compatibility
var sanitizeErr error
absFilename, sanitizeErr = sanitizeFilePath(storagePath, fileStorePath)
if sanitizeErr != nil {
http.Error(w, fmt.Sprintf("Invalid file path: %v", sanitizeErr), http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
filename = filepath.Base(fileStorePath) // For logging purposes
}
// Create directory structure if it doesn't exist
if err := os.MkdirAll(filepath.Dir(absFilename), 0755); err != nil {
http.Error(w, fmt.Sprintf("Error creating directory: %v", err), http.StatusInternalServerError)
uploadErrorsTotal.Inc()
return
}
// Create the file
dst, err := os.Create(absFilename)
if err != nil {
http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
uploadErrorsTotal.Inc()
return
}
defer dst.Close()
// Log upload start for large files
if r.ContentLength > 10*1024*1024 { // Log for files > 10MB
log.Infof("Starting upload of %s (%.1f MiB)", filename, float64(r.ContentLength)/(1024*1024))
}
// Copy file content from request body with progress reporting
written, err := copyWithProgress(dst, r.Body, r.ContentLength, filename)
if err != nil {
http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
uploadErrorsTotal.Inc()
// Clean up partial file
os.Remove(absFilename)
return
}
// Handle deduplication if enabled
if conf.Server.DeduplicationEnabled {
ctx := context.Background()
err = handleDeduplication(ctx, absFilename)
if err != nil {
log.Warnf("Deduplication failed for %s: %v", absFilename, err)
}
}
// Update metrics
duration := time.Since(startTime)
uploadDuration.Observe(duration.Seconds())
uploadsTotal.Inc()
uploadSizeBytes.Observe(float64(written))
// Return success response (201 Created for legacy compatibility)
w.WriteHeader(http.StatusCreated)
log.Infof("Successfully uploaded %s via legacy protocol (%s) in %s", filename, formatBytes(written), duration)
}
// handleLegacyDownload handles GET/HEAD requests for legacy downloads.
func handleLegacyDownload(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
activeConnections.Inc()
defer activeConnections.Dec()
// Extract filename from the URL path
fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
if fileStorePath == "" {
http.Error(w, "No filename specified", http.StatusBadRequest)
downloadErrorsTotal.Inc()
return
}
// Create full file path
storagePath := conf.Server.StoragePath
if conf.ISO.Enabled {
storagePath = conf.ISO.MountPoint
}
absFilename := filepath.Join(storagePath, fileStorePath)
fileInfo, err := os.Stat(absFilename)
if os.IsNotExist(err) {
http.Error(w, "File not found", http.StatusNotFound)
downloadErrorsTotal.Inc()
return
}
if err != nil {
http.Error(w, fmt.Sprintf("Error accessing file: %v", err), http.StatusInternalServerError)
downloadErrorsTotal.Inc()
return
}
if fileInfo.IsDir() {
http.Error(w, "Cannot download a directory", http.StatusBadRequest)
downloadErrorsTotal.Inc()
return
}
// Set appropriate headers
contentType := mime.TypeByExtension(filepath.Ext(fileStorePath))
if contentType == "" {
contentType = "application/octet-stream"
}
w.Header().Set("Content-Type", contentType)
w.Header().Set("Content-Length", strconv.FormatInt(fileInfo.Size(), 10))
// Handle resumable downloads
if conf.Uploads.ResumableUploadsEnabled {
handleResumableDownload(absFilename, w, r, fileInfo.Size())
return
}
// For HEAD requests, only send headers
if r.Method == http.MethodHead {
w.Header().Set("Content-Length", strconv.FormatInt(fileInfo.Size(), 10))
downloadsTotal.Inc()
return
} else {
// Measure download duration
startTime := time.Now()
log.Infof("Initiating download for file: %s", absFilename)
http.ServeFile(w, r, absFilename)
downloadDuration.Observe(time.Since(startTime).Seconds())
downloadSizeBytes.Observe(float64(fileInfo.Size()))
downloadsTotal.Inc()
log.Infof("File downloaded successfully: %s", absFilename)
return
}
}
// Create the file for upload with buffered Writer
func createFile(tempFilename string, r *http.Request) error {
absDirectory := filepath.Dir(tempFilename)
err := os.MkdirAll(absDirectory, os.ModePerm)
if err != nil {
log.WithError(err).Errorf("Failed to create directory %s", absDirectory)
return fmt.Errorf("failed to create directory %s: %w", absDirectory, err)
}
// Open the file for writing
targetFile, err := os.OpenFile(tempFilename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
log.WithError(err).Errorf("Failed to create file %s", tempFilename)
return fmt.Errorf("failed to create file %s: %w", tempFilename, err)
}
defer targetFile.Close()
// Use a large buffer for efficient file writing
bufferSize := 4 * 1024 * 1024 // 4 MB buffer
writer := bufio.NewWriterSize(targetFile, bufferSize)
buffer := make([]byte, bufferSize)
totalBytes := int64(0)
for {
n, readErr := r.Body.Read(buffer)
if n > 0 {
totalBytes += int64(n)
_, writeErr := writer.Write(buffer[:n])
if writeErr != nil {
log.WithError(writeErr).Errorf("Failed to write to file %s", tempFilename)
return fmt.Errorf("failed to write to file %s: %w", tempFilename, writeErr)
}
}
if readErr != nil {
if readErr == io.EOF {
break
}
log.WithError(readErr).Error("Failed to read request body")
return fmt.Errorf("failed to read request body: %w", readErr)
}
}
err = writer.Flush()
if err != nil {
log.WithError(err).Errorf("Failed to flush buffer to file %s", tempFilename)
return fmt.Errorf("failed to flush buffer to file %s: %w", tempFilename, err)
}
log.WithFields(logrus.Fields{
"temp_file": tempFilename,
"total_bytes": totalBytes,
}).Info("File uploaded successfully")
uploadSizeBytes.Observe(float64(totalBytes))
return nil
}
// Scan the uploaded file with ClamAV (Optional)
func scanFileWithClamAV(filePath string) error {
log.WithField("file", filePath).Info("Scanning file with ClamAV")
scanResultChan, err := clamClient.ScanFile(filePath)
if err != nil {
log.WithError(err).Error("Failed to initiate ClamAV scan")
return fmt.Errorf("failed to initiate ClamAV scan: %w", err)
}
// Receive scan result
scanResult := <-scanResultChan
if scanResult == nil {
log.Error("Failed to receive scan result from ClamAV")
return fmt.Errorf("failed to receive scan result from ClamAV")
}
// Handle scan result
switch scanResult.Status {
case clamd.RES_OK:
log.WithField("file", filePath).Info("ClamAV scan passed")
return nil
case clamd.RES_FOUND:
log.WithFields(logrus.Fields{
"file": filePath,
"description": scanResult.Description,
}).Warn("ClamAV detected a virus")
return fmt.Errorf("virus detected: %s", scanResult.Description)
default:
log.WithFields(logrus.Fields{
"file": filePath,
"status": scanResult.Status,
"description": scanResult.Description,
}).Warn("ClamAV scan returned unexpected status")
return fmt.Errorf("ClamAV scan returned unexpected status: %s", scanResult.Description)
}
}
// initClamAV initializes the ClamAV client and logs the status
func initClamAV(socket string) (*clamd.Clamd, error) {
if socket == "" {
log.Error("ClamAV socket path is not configured.")
return nil, fmt.Errorf("ClamAV socket path is not configured")
}
clamClient := clamd.NewClamd("unix:" + socket)
err := clamClient.Ping()
if err != nil {
log.Errorf("Failed to connect to ClamAV at %s: %v", socket, err)
return nil, fmt.Errorf("failed to connect to ClamAV: %w", err)
}
log.Info("Connected to ClamAV successfully.")
return clamClient, nil
}
// Handle resumable downloads
func handleResumableDownload(absFilename string, w http.ResponseWriter, r *http.Request, fileSize int64) {
rangeHeader := r.Header.Get("Range")
if rangeHeader == "" {
// If no Range header, serve the full file
startTime := time.Now()
http.ServeFile(w, r, absFilename)
downloadDuration.Observe(time.Since(startTime).Seconds())
downloadSizeBytes.Observe(float64(fileSize))
w.WriteHeader(http.StatusOK)
downloadsTotal.Inc()
return
}
// Parse Range header
ranges := strings.Split(strings.TrimPrefix(rangeHeader, "bytes="), "-")
if len(ranges) != 2 {
http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
downloadErrorsTotal.Inc()
return
}
start, err := strconv.ParseInt(ranges[0], 10, 64)
if err != nil {
http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
downloadErrorsTotal.Inc()
return
}
// Calculate end byte
end := fileSize - 1
if ranges[1] != "" {
end, err = strconv.ParseInt(ranges[1], 10, 64)
if err != nil || end >= fileSize {
http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
downloadErrorsTotal.Inc()
return
}
}
// Set response headers for partial content
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, fileSize))
w.Header().Set("Content-Length", strconv.FormatInt(end-start+1, 10))
w.Header().Set("Accept-Ranges", "bytes")
w.WriteHeader(http.StatusPartialContent)
// Serve the requested byte range
// For GET requests, serve the file
file, err := os.Open(absFilename)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("Error opening file: %v", err), http.StatusInternalServerError)
downloadErrorsTotal.Inc()
return
}
defer file.Close()
// Seek to the start byte
_, err = file.Seek(start, 0)
// Use a pooled buffer for copying
bufPtr := bufferPool.Get().(*[]byte)
defer bufferPool.Put(bufPtr)
buf := *bufPtr
n, err := io.CopyBuffer(w, file, buf)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
log.Errorf("Error during download of %s: %v", absFilename, err)
downloadErrorsTotal.Inc()
return
}
// Create a buffer and copy the specified range to the response writer
buffer := make([]byte, 32*1024) // 32KB buffer
remaining := end - start + 1
startTime := time.Now()
for remaining > 0 {
if int64(len(buffer)) > remaining {
buffer = buffer[:remaining]
}
n, err := file.Read(buffer)
if n > 0 {
if _, writeErr := w.Write(buffer[:n]); writeErr != nil {
log.WithError(writeErr).Error("Failed to write to response")
downloadErrorsTotal.Inc()
return
}
remaining -= int64(n)
}
if err != nil {
if err != io.EOF {
log.WithError(err).Error("Error reading file during resumable download")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
downloadErrorsTotal.Inc()
}
break
}
}
downloadDuration.Observe(time.Since(startTime).Seconds())
downloadSizeBytes.Observe(float64(end - start + 1))
duration := time.Since(startTime)
downloadDuration.Observe(duration.Seconds())
downloadsTotal.Inc()
downloadSizeBytes.Observe(float64(n))
log.Infof("Successfully downloaded %s (%s) in %s", absFilename, formatBytes(n), duration)
}
// Handle chunked uploads with bufio.Writer
func handleChunkedUpload(tempFilename string, r *http.Request) error {
log.WithField("file", tempFilename).Info("Handling chunked upload to temporary file")
// printValidationChecks prints all available validation checks
func printValidationChecks() {
fmt.Println("HMAC File Server Configuration Validation Checks")
fmt.Println("=================================================")
fmt.Println()
// Ensure the directory exists
absDirectory := filepath.Dir(tempFilename)
err := os.MkdirAll(absDirectory, os.ModePerm)
if err != nil {
log.WithError(err).Errorf("Failed to create directory %s for chunked upload", absDirectory)
return fmt.Errorf("failed to create directory %s: %w", absDirectory, err)
}
fmt.Println("🔍 CORE VALIDATION CHECKS:")
fmt.Println(" ✓ server.* - Server configuration (ports, paths, protocols)")
fmt.Println(" ✓ security.* - Security settings (secrets, JWT, authentication)")
fmt.Println(" ✓ logging.* - Logging configuration (levels, files, rotation)")
fmt.Println(" ✓ timeouts.* - Timeout settings (read, write, idle)")
fmt.Println(" ✓ uploads.* - Upload configuration (extensions, chunk size)")
fmt.Println(" ✓ downloads.* - Download configuration (extensions, chunk size)")
fmt.Println(" ✓ workers.* - Worker pool configuration (count, queue size)")
fmt.Println(" ✓ redis.* - Redis configuration (address, credentials)")
fmt.Println(" ✓ clamav.* - ClamAV antivirus configuration")
fmt.Println(" ✓ versioning.* - File versioning configuration")
fmt.Println(" ✓ deduplication.* - File deduplication configuration")
fmt.Println(" ✓ iso.* - ISO filesystem configuration")
fmt.Println()
targetFile, err := os.OpenFile(tempFilename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
log.WithError(err).Error("Failed to open temporary file for chunked upload")
return err
}
defer targetFile.Close()
fmt.Println("🔐 SECURITY CHECKS:")
fmt.Println(" ✓ Secret strength analysis (length, entropy, patterns)")
fmt.Println(" ✓ Default/example value detection")
fmt.Println(" ✓ JWT algorithm security recommendations")
fmt.Println(" ✓ Network binding security (0.0.0.0 warnings)")
fmt.Println(" ✓ File permission analysis")
fmt.Println(" ✓ Debug logging security implications")
fmt.Println()
writer := bufio.NewWriterSize(targetFile, int(conf.Uploads.ChunkSize))
buffer := make([]byte, conf.Uploads.ChunkSize)
fmt.Println("⚡ PERFORMANCE CHECKS:")
fmt.Println(" ✓ Worker count vs CPU cores optimization")
fmt.Println(" ✓ Queue size vs memory usage analysis")
fmt.Println(" ✓ Timeout configuration balance")
fmt.Println(" ✓ Large file handling preparation")
fmt.Println(" ✓ Memory-intensive configuration detection")
fmt.Println()
totalBytes := int64(0)
for {
n, err := r.Body.Read(buffer)
if n > 0 {
totalBytes += int64(n)
_, writeErr := writer.Write(buffer[:n])
if writeErr != nil {
log.WithError(writeErr).Error("Failed to write chunk to temporary file")
return writeErr
}
}
if err != nil {
if err == io.EOF {
break // Finished reading the body
}
log.WithError(err).Error("Error reading from request body")
return err
}
}
fmt.Println("🌐 CONNECTIVITY CHECKS:")
fmt.Println(" ✓ Redis server connectivity testing")
fmt.Println(" ✓ ClamAV socket accessibility")
fmt.Println(" ✓ Network address format validation")
fmt.Println(" ✓ DNS resolution testing")
fmt.Println()
err = writer.Flush()
if err != nil {
log.WithError(err).Error("Failed to flush buffer to temporary file")
return err
}
fmt.Println("💾 SYSTEM RESOURCE CHECKS:")
fmt.Println(" ✓ CPU core availability analysis")
fmt.Println(" ✓ Memory usage monitoring")
fmt.Println(" ✓ Disk space validation")
fmt.Println(" ✓ Directory write permissions")
fmt.Println(" ✓ Goroutine count analysis")
fmt.Println()
log.WithFields(logrus.Fields{
"temp_file": tempFilename,
"total_bytes": totalBytes,
}).Info("Chunked upload completed successfully")
fmt.Println("🔄 CROSS-SECTION VALIDATION:")
fmt.Println(" ✓ Path conflict detection")
fmt.Println(" ✓ Extension compatibility checks")
fmt.Println(" ✓ Configuration consistency validation")
fmt.Println()
uploadSizeBytes.Observe(float64(totalBytes))
return nil
}
// Get file information with caching
func getFileInfo(absFilename string) (os.FileInfo, error) {
if cachedInfo, found := fileInfoCache.Get(absFilename); found {
if info, ok := cachedInfo.(os.FileInfo); ok {
return info, nil
}
}
fileInfo, err := os.Stat(absFilename)
if err != nil {
return nil, err
}
fileInfoCache.Set(absFilename, fileInfo, cache.DefaultExpiration)
return fileInfo, nil
}
// Monitor network changes
func monitorNetwork(ctx context.Context) {
currentIP := getCurrentIPAddress() // Placeholder for the current IP address
for {
select {
case <-ctx.Done():
log.Info("Stopping network monitor.")
return
case <-time.After(10 * time.Second):
newIP := getCurrentIPAddress()
if newIP != currentIP && newIP != "" {
currentIP = newIP
select {
case networkEvents <- NetworkEvent{Type: "IP_CHANGE", Details: currentIP}:
log.WithField("new_ip", currentIP).Info("Queued IP_CHANGE event")
default:
log.Warn("Network event channel is full. Dropping IP_CHANGE event.")
}
}
}
}
}
// Handle network events
func handleNetworkEvents(ctx context.Context) {
for {
select {
case <-ctx.Done():
log.Info("Stopping network event handler.")
return
case event, ok := <-networkEvents:
if !ok {
log.Info("Network events channel closed.")
return
}
switch event.Type {
case "IP_CHANGE":
log.WithField("new_ip", event.Details).Info("Network change detected")
// Example: Update Prometheus gauge or trigger alerts
// activeConnections.Set(float64(getActiveConnections()))
}
// Additional event types can be handled here
}
}
}
// Get current IP address (example)
func getCurrentIPAddress() string {
interfaces, err := net.Interfaces()
if err != nil {
log.WithError(err).Error("Failed to get network interfaces")
return ""
}
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
continue // Skip interfaces that are down or loopback
}
addrs, err := iface.Addrs()
if err != nil {
log.WithError(err).Errorf("Failed to get addresses for interface %s", iface.Name)
continue
}
for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP.IsGlobalUnicast() && ipnet.IP.To4() != nil {
return ipnet.IP.String()
}
}
}
return ""
}
// setupGracefulShutdown sets up handling for graceful server shutdown
func setupGracefulShutdown(server *http.Server, cancel context.CancelFunc) {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-quit
log.Infof("Received signal %s. Initiating shutdown...", sig)
// Create a deadline to wait for.
ctxShutdown, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel()
// Attempt graceful shutdown
if err := server.Shutdown(ctxShutdown); err != nil {
log.Errorf("Server shutdown failed: %v", err)
} else {
log.Info("Server shutdown gracefully.")
}
// Signal other goroutines to stop
cancel()
// Close the upload, scan, and network event channels
close(uploadQueue)
log.Info("Upload queue closed.")
close(scanQueue)
log.Info("Scan queue closed.")
close(networkEvents)
log.Info("Network events channel closed.")
log.Info("Shutdown process completed. Exiting application.")
os.Exit(0)
}()
}
// Initialize Redis client
func initRedis() {
if !conf.Redis.RedisEnabled {
log.Info("Redis is disabled in configuration.")
return
}
redisClient = redis.NewClient(&redis.Options{
Addr: conf.Redis.RedisAddr,
Password: conf.Redis.RedisPassword,
DB: conf.Redis.RedisDBIndex,
})
// Test the Redis connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err := redisClient.Ping(ctx).Result()
if err != nil {
log.Fatalf("Failed to connect to Redis: %v", err)
}
log.Info("Connected to Redis successfully")
// Set initial connection status
mu.Lock()
redisConnected = true
mu.Unlock()
// Start monitoring Redis health
go MonitorRedisHealth(context.Background(), redisClient, parseDuration(conf.Redis.RedisHealthCheckInterval))
}
// MonitorRedisHealth periodically checks Redis connectivity and updates redisConnected status.
func MonitorRedisHealth(ctx context.Context, client *redis.Client, checkInterval time.Duration) {
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Info("Stopping Redis health monitor.")
return
case <-ticker.C:
err := client.Ping(ctx).Err()
mu.Lock()
if err != nil {
if redisConnected {
log.Errorf("Redis health check failed: %v", err)
}
redisConnected = false
} else {
if !redisConnected {
log.Info("Redis reconnected successfully")
}
redisConnected = true
log.Debug("Redis health check succeeded.")
}
mu.Unlock()
}
}
}
// Helper function to parse duration strings
func parseDuration(durationStr string) time.Duration {
duration, err := time.ParseDuration(durationStr)
if err != nil {
log.WithError(err).Warn("Invalid duration format, using default 30s")
return 30 * time.Second
}
return duration
}
// RunFileCleaner periodically deletes files that exceed the FileTTL duration.
func runFileCleaner(ctx context.Context, storeDir string, ttl time.Duration) {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Info("Stopping file cleaner.")
return
case <-ticker.C:
now := time.Now()
err := filepath.Walk(storeDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
if now.Sub(info.ModTime()) > ttl {
err := os.Remove(path)
if err != nil {
log.WithError(err).Errorf("Failed to remove expired file: %s", path)
} else {
log.Infof("Removed expired file: %s", path)
}
}
return nil
})
if err != nil {
log.WithError(err).Error("Error walking store directory for file cleaning")
}
}
}
}
// DeduplicateFiles scans the store directory and removes duplicate files based on SHA256 hash.
// It retains one copy of each unique file and replaces duplicates with hard links.
func DeduplicateFiles(storeDir string) error {
hashMap := make(map[string]string) // map[hash]filepath
var mu sync.Mutex
var wg sync.WaitGroup
fileChan := make(chan string, 100)
// Worker to process files
numWorkers := 10
for i := 0; i < numWorkers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for filePath := range fileChan {
hash, err := computeFileHash(filePath)
if err != nil {
logrus.WithError(err).Errorf("Failed to compute hash for %s", filePath)
continue
}
mu.Lock()
original, exists := hashMap[hash]
if !exists {
hashMap[hash] = filePath
mu.Unlock()
continue
}
mu.Unlock()
// Duplicate found
err = os.Remove(filePath)
if err != nil {
logrus.WithError(err).Errorf("Failed to remove duplicate file %s", filePath)
continue
}
// Create hard link to the original file
err = os.Link(original, filePath)
if err != nil {
logrus.WithError(err).Errorf("Failed to create hard link from %s to %s", original, filePath)
continue
}
logrus.Infof("Removed duplicate %s and linked to %s", filePath, original)
}
}()
}
// Walk through the store directory
err := filepath.Walk(storeDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
logrus.WithError(err).Errorf("Error accessing path %s", path)
return nil
}
if !info.Mode().IsRegular() {
return nil
}
fileChan <- path
return nil
})
if err != nil {
return fmt.Errorf("error walking the path %s: %w", storeDir, err)
}
close(fileChan)
wg.Wait()
return nil
}
// computeFileHash computes the SHA256 hash of the given file.
func computeFileHash(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", fmt.Errorf("unable to open file %s: %w", filePath, err)
}
defer file.Close()
hasher := sha256.New()
if _, err := io.Copy(hasher, file); err != nil {
return "", fmt.Errorf("error hashing file %s: %w", filePath, err)
}
return hex.EncodeToString(hasher.Sum(nil)), nil
}
// Handle multipart uploads
func handleMultipartUpload(w http.ResponseWriter, r *http.Request, absFilename string) error {
err := r.ParseMultipartForm(32 << 20) // 32MB is the default used by FormFile
if err != nil {
log.WithError(err).Error("Failed to parse multipart form")
http.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
return err
}
file, handler, err := r.FormFile("file")
if err != nil {
log.WithError(err).Error("Failed to retrieve file from form data")
http.Error(w, "Failed to retrieve file from form data", http.StatusBadRequest)
return err
}
defer file.Close()
// Validate file extension
if !isExtensionAllowed(handler.Filename) {
log.WithFields(logrus.Fields{
"filename": handler.Filename,
"extension": filepath.Ext(handler.Filename),
}).Warn("Attempted upload with disallowed file extension")
http.Error(w, "Disallowed file extension. Allowed extensions are: "+strings.Join(conf.Uploads.AllowedExtensions, ", "), http.StatusForbidden)
uploadErrorsTotal.Inc()
return fmt.Errorf("disallowed file extension")
}
// Create a temporary file
tempFilename := absFilename + ".tmp"
tempFile, err := os.OpenFile(tempFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
if err != nil {
log.WithError(err).Error("Failed to create temporary file")
http.Error(w, "Failed to create temporary file", http.StatusInternalServerError)
return err
}
defer tempFile.Close()
// Copy the uploaded file to the temporary file
_, err = io.Copy(tempFile, file)
if err != nil {
log.WithError(err).Error("Failed to copy uploaded file to temporary file")
http.Error(w, "Failed to copy uploaded file", http.StatusInternalServerError)
return err
}
// Perform ClamAV scan on the temporary file
if clamClient != nil {
err := scanFileWithClamAV(tempFilename)
if err != nil {
log.WithFields(logrus.Fields{
"file": tempFilename,
"error": err,
}).Warn("ClamAV detected a virus or scan failed")
os.Remove(tempFilename)
uploadErrorsTotal.Inc()
return err
}
}
// Handle file versioning if enabled
if conf.Versioning.EnableVersioning {
existing, _ := fileExists(absFilename)
if existing {
err := versionFile(absFilename)
if err != nil {
log.WithFields(logrus.Fields{
"file": absFilename,
"error": err,
}).Error("Error versioning file")
os.Remove(tempFilename)
return err
}
}
}
// Move the temporary file to the final destination
err = os.Rename(tempFilename, absFilename)
if err != nil {
log.WithFields(logrus.Fields{
"temp_file": tempFilename,
"final_file": absFilename,
"error": err,
}).Error("Failed to move file to final destination")
os.Remove(tempFilename)
return err
}
log.WithFields(logrus.Fields{
"file": absFilename,
}).Info("File uploaded and scanned successfully")
uploadsTotal.Inc()
return nil
}
// sanitizeFilePath ensures that the file path is within the designated storage directory
func sanitizeFilePath(baseDir, filePath string) (string, error) {
// Resolve the absolute path
absBaseDir, err := filepath.Abs(baseDir)
if err != nil {
return "", fmt.Errorf("failed to resolve base directory: %w", err)
}
absFilePath, err := filepath.Abs(filepath.Join(absBaseDir, filePath))
if err != nil {
return "", fmt.Errorf("failed to resolve file path: %w", err)
}
// Check if the resolved file path is within the base directory
if !strings.HasPrefix(absFilePath, absBaseDir) {
return "", fmt.Errorf("invalid file path: %s", filePath)
}
return absFilePath, nil
}
// checkStorageSpace ensures that there is enough free space in the storage path
func checkStorageSpace(storagePath string, minFreeBytes int64) error {
var stat syscall.Statfs_t
err := syscall.Statfs(storagePath, &stat)
if err != nil {
return fmt.Errorf("failed to get filesystem stats: %w", err)
}
// Calculate available bytes
availableBytes := stat.Bavail * uint64(stat.Bsize)
if int64(availableBytes) < minFreeBytes {
return fmt.Errorf("not enough free space: %d bytes available, %d bytes required", availableBytes, minFreeBytes)
}
return nil
}
// Function to compute SHA256 checksum of a file
func computeSHA256(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", fmt.Errorf("failed to open file for checksum: %w", err)
}
defer file.Close()
hasher := sha256.New()
if _, err := io.Copy(hasher, file); err != nil {
return "", fmt.Errorf("failed to compute checksum: %w", err)
}
return hex.EncodeToString(hasher.Sum(nil)), nil
}
// handleDeduplication handles file deduplication using SHA256 checksum and hard links
func handleDeduplication(ctx context.Context, absFilename string) error {
// Compute checksum of the uploaded file
checksum, err := computeSHA256(absFilename)
if err != nil {
log.Errorf("Failed to compute SHA256 for %s: %v", absFilename, err)
return fmt.Errorf("checksum computation failed: %w", err)
}
log.Debugf("Computed checksum for %s: %s", absFilename, checksum)
// Check Redis for existing checksum
existingPath, err := redisClient.Get(ctx, checksum).Result()
if err != nil && err != redis.Nil {
log.Errorf("Redis error while fetching checksum %s: %v", checksum, err)
return fmt.Errorf("redis error: %w", err)
}
if err != redis.Nil {
// Duplicate found, create hard link
log.Infof("Duplicate detected: %s already exists at %s", absFilename, existingPath)
err = os.Link(existingPath, absFilename)
if err != nil {
log.Errorf("Failed to create hard link from %s to %s: %v", existingPath, absFilename, err)
return fmt.Errorf("failed to create hard link: %w", err)
}
log.Infof("Created hard link from %s to %s", existingPath, absFilename)
return nil
}
// No duplicate found, store checksum in Redis
err = redisClient.Set(ctx, checksum, absFilename, 0).Err()
if err != nil {
log.Errorf("Failed to store checksum %s in Redis: %v", checksum, err)
return fmt.Errorf("failed to store checksum in Redis: %w", err)
}
log.Infof("Stored new file checksum in Redis: %s -> %s", checksum, absFilename)
return nil
fmt.Println("📋 USAGE EXAMPLES:")
fmt.Println(" hmac-file-server --validate-config # Full validation")
fmt.Println(" hmac-file-server --check-security # Security checks only")
fmt.Println(" hmac-file-server --check-performance # Performance checks only")
fmt.Println(" hmac-file-server --check-connectivity # Network checks only")
fmt.Println(" hmac-file-server --validate-quiet # Errors only")
fmt.Println(" hmac-file-server --validate-verbose # Detailed output")
fmt.Println(" hmac-file-server --check-fixable # Auto-fixable issues")
fmt.Println()
}