release: hmac-file-server 3.2

This commit is contained in:
2025-06-13 04:24:11 +02:00
parent cc3a4f4dd7
commit 16f50940d0
34 changed files with 10354 additions and 2255 deletions

View File

@@ -1,67 +0,0 @@
# Server Settings
[server]
ListenPort = "8080"
UnixSocket = false
StoreDir = "./testupload"
LogLevel = "info"
LogFile = "./hmac-file-server.log"
MetricsEnabled = true
MetricsPort = "9090"
FileTTL = "8760h"
# Workers and Connections
[workers]
NumWorkers = 2
UploadQueueSize = 500
# Timeout Settings
[timeouts]
ReadTimeout = "600s"
WriteTimeout = "600s"
IdleTimeout = "600s"
# Security Settings
[security]
Secret = "a-orc-and-a-humans-is-drinking-ale"
# Versioning Settings
[versioning]
EnableVersioning = false
MaxVersions = 1
# Upload/Download Settings
[uploads]
ResumableUploadsEnabled = true
ChunkedUploadsEnabled = true
ChunkSize = 16777216
AllowedExtensions = [
# Document formats
".txt", ".pdf",
# Image formats
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp",
# Video formats
".wav", ".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".mpg", ".m4v", ".3gp", ".3g2",
# Audio formats
".mp3", ".ogg"
]
# ClamAV Settings
[clamav]
ClamAVEnabled = false
ClamAVSocket = "/var/run/clamav/clamd.ctl"
NumScanWorkers = 4
# Redis Settings
[redis]
RedisEnabled = false
RedisAddr = "localhost:6379"
RedisPassword = ""
RedisDBIndex = 0
RedisHealthCheckInterval = "120s"
# Deduplication
[deduplication]
enabled = false

View File

@@ -0,0 +1,294 @@
// config_test_scenarios.go
package main
import (
"fmt"
"os"
"path/filepath"
)
// ConfigTestScenario represents a test scenario for configuration validation.
// A scenario is a complete Config plus the expected validation outcome.
type ConfigTestScenario struct {
	Name             string   // human-readable title printed by the test runner
	Config           Config   // fully-populated configuration under test
	ShouldPass       bool     // true if validation is expected to produce no errors
	ExpectedErrors   []string // error messages expected in the validation result
	ExpectedWarnings []string // warning messages expected in the validation result
}
// GetConfigTestScenarios returns a set of test scenarios for configuration
// validation. Each scenario starts from a known-good base configuration and
// mutates a single aspect, recording whether ValidateConfigComprehensive
// should pass and which error/warning messages it is expected to report.
func GetConfigTestScenarios() []ConfigTestScenario {
	// Baseline that passes every validator; scenarios copy and tweak it.
	baseValidConfig := Config{
		Server: ServerConfig{
			ListenAddress:  "8080",
			BindIP:         "0.0.0.0",
			StoragePath:    "/tmp/test-storage",
			MetricsEnabled: true,
			MetricsPort:    "9090",
			FileTTLEnabled: true,
			FileTTL:        "24h",
			MinFreeBytes:   "1GB",
			FileNaming:     "HMAC",
			ForceProtocol:  "auto",
			PIDFilePath:    "/tmp/test.pid",
		},
		Security: SecurityConfig{
			Secret:    "test-secret-key-32-characters",
			EnableJWT: false,
		},
		Logging: LoggingConfig{
			Level:      "info",
			File:       "/tmp/test.log",
			MaxSize:    100,
			MaxBackups: 3,
			MaxAge:     30,
		},
		Timeouts: TimeoutConfig{
			Read:  "30s",
			Write: "30s",
			Idle:  "60s",
		},
		Workers: WorkersConfig{
			NumWorkers:      4,
			UploadQueueSize: 50,
		},
		Uploads: UploadsConfig{
			AllowedExtensions: []string{".txt", ".pdf", ".jpg"},
			ChunkSize:         "10MB",
		},
		Downloads: DownloadsConfig{
			AllowedExtensions: []string{".txt", ".pdf", ".jpg"},
			ChunkSize:         "10MB",
		},
	}

	return []ConfigTestScenario{
		{
			Name:       "Valid Basic Configuration",
			Config:     baseValidConfig,
			ShouldPass: true,
		},
		{
			Name: "Missing Listen Address",
			Config: func() Config {
				c := baseValidConfig
				c.Server.ListenAddress = ""
				return c
			}(),
			ShouldPass: false,
			// Must match the message actually emitted by validateServerConfig
			// ("listen address/port is required"). The previous expectation
			// ("server.listen_address is required") matched neither the
			// message nor the Error() string, so this scenario could never pass.
			ExpectedErrors: []string{"listen address/port is required"},
		},
		{
			Name: "Invalid Port Number",
			Config: func() Config {
				c := baseValidConfig
				c.Server.ListenAddress = "99999"
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"invalid port number"},
		},
		{
			Name: "Invalid IP Address",
			Config: func() Config {
				c := baseValidConfig
				c.Server.BindIP = "999.999.999.999"
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"invalid IP address format"},
		},
		{
			Name: "Same Port for Server and Metrics",
			Config: func() Config {
				c := baseValidConfig
				c.Server.ListenAddress = "8080"
				c.Server.MetricsPort = "8080"
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"metrics port cannot be the same as main listen port"},
		},
		{
			Name: "JWT Enabled Without Secret",
			Config: func() Config {
				c := baseValidConfig
				c.Security.EnableJWT = true
				c.Security.JWTSecret = ""
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"JWT secret is required when JWT is enabled"},
		},
		{
			Name: "Short JWT Secret",
			Config: func() Config {
				c := baseValidConfig
				c.Security.EnableJWT = true
				c.Security.JWTSecret = "short"
				c.Security.JWTAlgorithm = "HS256"
				return c
			}(),
			ShouldPass:       true,
			ExpectedWarnings: []string{"JWT secret should be at least 32 characters"},
		},
		{
			Name: "Invalid Log Level",
			Config: func() Config {
				c := baseValidConfig
				c.Logging.Level = "invalid"
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"invalid log level"},
		},
		{
			Name: "Invalid Timeout Format",
			Config: func() Config {
				c := baseValidConfig
				c.Timeouts.Read = "invalid"
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"invalid read timeout format"},
		},
		{
			Name: "Negative Worker Count",
			Config: func() Config {
				c := baseValidConfig
				c.Workers.NumWorkers = -1
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"number of workers must be positive"},
		},
		{
			Name: "Extensions Without Dots",
			Config: func() Config {
				c := baseValidConfig
				c.Uploads.AllowedExtensions = []string{"txt", "pdf"}
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"file extensions must start with a dot"},
		},
		{
			Name: "High Worker Count Warning",
			Config: func() Config {
				c := baseValidConfig
				c.Workers.NumWorkers = 100
				return c
			}(),
			ShouldPass:       true,
			ExpectedWarnings: []string{"very high worker count may impact performance"},
		},
		{
			Name: "Deduplication Without Directory",
			Config: func() Config {
				c := baseValidConfig
				c.Deduplication.Enabled = true
				c.Deduplication.Directory = ""
				return c
			}(),
			ShouldPass:     false,
			ExpectedErrors: []string{"deduplication directory is required"},
		},
	}
}
// RunConfigTests runs all configuration test scenarios
func RunConfigTests() {
scenarios := GetConfigTestScenarios()
passed := 0
failed := 0
fmt.Println("🧪 Running Configuration Test Scenarios")
fmt.Println("=======================================")
fmt.Println()
for i, scenario := range scenarios {
fmt.Printf("Test %d: %s\n", i+1, scenario.Name)
// Create temporary directories for testing
tempDir := filepath.Join(os.TempDir(), fmt.Sprintf("hmac-test-%d", i))
os.MkdirAll(tempDir, 0755)
defer os.RemoveAll(tempDir)
// Update paths in config to use temp directory
scenario.Config.Server.StoragePath = filepath.Join(tempDir, "storage")
scenario.Config.Logging.File = filepath.Join(tempDir, "test.log")
scenario.Config.Server.PIDFilePath = filepath.Join(tempDir, "test.pid")
if scenario.Config.Deduplication.Enabled {
scenario.Config.Deduplication.Directory = filepath.Join(tempDir, "dedup")
}
result := ValidateConfigComprehensive(&scenario.Config)
// Check if test passed as expected
testPassed := true
if scenario.ShouldPass && result.HasErrors() {
fmt.Printf(" ❌ Expected to pass but failed with errors:\n")
for _, err := range result.Errors {
fmt.Printf(" • %s\n", err.Message)
}
testPassed = false
} else if !scenario.ShouldPass && !result.HasErrors() {
fmt.Printf(" ❌ Expected to fail but passed\n")
testPassed = false
} else if !scenario.ShouldPass && result.HasErrors() {
// Check if expected errors are present
expectedFound := true
for _, expectedError := range scenario.ExpectedErrors {
found := false
for _, actualError := range result.Errors {
if contains([]string{actualError.Message}, expectedError) ||
contains([]string{actualError.Error()}, expectedError) {
found = true
break
}
}
if !found {
fmt.Printf(" ❌ Expected error not found: %s\n", expectedError)
expectedFound = false
}
}
if !expectedFound {
testPassed = false
}
}
// Check expected warnings
if len(scenario.ExpectedWarnings) > 0 {
for _, expectedWarning := range scenario.ExpectedWarnings {
found := false
for _, actualWarning := range result.Warnings {
if contains([]string{actualWarning.Message}, expectedWarning) ||
contains([]string{actualWarning.Error()}, expectedWarning) {
found = true
break
}
}
if !found {
fmt.Printf(" ⚠️ Expected warning not found: %s\n", expectedWarning)
}
}
}
if testPassed {
fmt.Printf(" ✅ Passed\n")
passed++
} else {
failed++
}
fmt.Println()
}
// Summary
fmt.Printf("📊 Test Results: %d passed, %d failed\n", passed, failed)
if failed > 0 {
fmt.Printf("❌ Some tests failed. Please review the implementation.\n")
os.Exit(1)
} else {
fmt.Printf("✅ All tests passed!\n")
}
}

View File

@@ -0,0 +1,1131 @@
// config_validator.go
package main
import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"time"
)
// ConfigValidationError represents a configuration validation error
// (or warning — the same type carries both).
type ConfigValidationError struct {
	Field   string      // dotted config key, e.g. "server.listenport"
	Value   interface{} // offending value; secret-bearing callers pass "[REDACTED]"
	Message string      // human-readable description of the problem
}

// Error implements the error interface, rendering field, message and value.
func (e ConfigValidationError) Error() string {
	return fmt.Sprintf("config validation error in field '%s': %s (value: %v)", e.Field, e.Message, e.Value)
}
// ConfigValidationResult contains the results of config validation.
// Errors are fatal problems; Warnings are advisory and do not affect Valid.
type ConfigValidationResult struct {
	Errors   []ConfigValidationError // fatal findings; non-empty means invalid
	Warnings []ConfigValidationError // advisory findings
	Valid    bool                    // set false by AddError; callers should initialize to true
}
// AddError records a fatal validation finding for the given field and
// marks the whole result as invalid.
func (r *ConfigValidationResult) AddError(field string, value interface{}, message string) {
	entry := ConfigValidationError{
		Field:   field,
		Value:   value,
		Message: message,
	}
	r.Valid = false
	r.Errors = append(r.Errors, entry)
}
// AddWarning records an advisory finding; warnings never flip Valid.
func (r *ConfigValidationResult) AddWarning(field string, value interface{}, message string) {
	entry := ConfigValidationError{
		Field:   field,
		Value:   value,
		Message: message,
	}
	r.Warnings = append(r.Warnings, entry)
}
// HasErrors reports whether any fatal validation findings were recorded.
func (r *ConfigValidationResult) HasErrors() bool {
	return len(r.Errors) != 0
}
// HasWarnings reports whether any advisory findings were recorded.
func (r *ConfigValidationResult) HasWarnings() bool {
	return len(r.Warnings) != 0
}
// ValidateConfigComprehensive performs comprehensive configuration validation.
// It runs every section validator, then cross-section, system-resource,
// connectivity, performance and security-hardening checks, and finally
// disk-space checks for the storage and deduplication paths. It never
// stops early: all findings are accumulated into the returned result.
// The call order below determines the order of errors/warnings in the result.
func ValidateConfigComprehensive(c *Config) *ConfigValidationResult {
	result := &ConfigValidationResult{Valid: true}
	// Validate each section
	validateServerConfig(&c.Server, result)
	validateSecurityConfig(&c.Security, result)
	validateLoggingConfig(&c.Logging, result)
	validateTimeoutConfig(&c.Timeouts, result)
	validateUploadsConfig(&c.Uploads, result)
	validateDownloadsConfig(&c.Downloads, result)
	validateClamAVConfig(&c.ClamAV, result)
	validateRedisConfig(&c.Redis, result)
	validateWorkersConfig(&c.Workers, result)
	validateVersioningConfig(&c.Versioning, result)
	validateDeduplicationConfig(&c.Deduplication, result)
	validateISOConfig(&c.ISO, result)
	// Cross-section validations (conflicting paths/ports across sections)
	validateCrossSection(c, result)
	// Enhanced validations (environment rather than config syntax)
	validateSystemResources(result)
	validateNetworkConnectivity(c, result)
	validatePerformanceSettings(c, result)
	validateSecurityHardening(c, result)
	// Check disk space for storage paths
	if c.Server.StoragePath != "" {
		checkDiskSpace(c.Server.StoragePath, result)
	}
	if c.Deduplication.Enabled && c.Deduplication.Directory != "" {
		checkDiskSpace(c.Deduplication.Directory, result)
	}
	return result
}
// validateServerConfig validates the [server] section: listen/metrics
// ports, bind IP, storage path, size/TTL formats, file-naming and
// protocol enums, PID file location, dynamic-worker thresholds and
// global extension syntax.
func validateServerConfig(server *ServerConfig, result *ConfigValidationResult) {
	// ListenAddress validation — the value is validated as a bare port
	// number (see isValidPort), not a host:port pair.
	if server.ListenAddress == "" {
		result.AddError("server.listenport", server.ListenAddress, "listen address/port is required")
	} else {
		if !isValidPort(server.ListenAddress) {
			result.AddError("server.listenport", server.ListenAddress, "invalid port number (must be 1-65535)")
		}
	}
	// BindIP validation (optional field; only checked when set)
	if server.BindIP != "" {
		if ip := net.ParseIP(server.BindIP); ip == nil {
			result.AddError("server.bind_ip", server.BindIP, "invalid IP address format")
		}
	}
	// StoragePath validation: required; directory usability is delegated
	// to validateDirectoryPath (second arg true here, false for PID dir —
	// presumably "create if missing"; confirm against its definition).
	if server.StoragePath == "" {
		result.AddError("server.storagepath", server.StoragePath, "storage path is required")
	} else {
		if err := validateDirectoryPath(server.StoragePath, true); err != nil {
			result.AddError("server.storagepath", server.StoragePath, err.Error())
		}
	}
	// MetricsPort validation: only when metrics are enabled; must be a
	// valid port and must not collide with the main listen port.
	if server.MetricsEnabled && server.MetricsPort != "" {
		if !isValidPort(server.MetricsPort) {
			result.AddError("server.metricsport", server.MetricsPort, "invalid metrics port number")
		}
		if server.MetricsPort == server.ListenAddress {
			result.AddError("server.metricsport", server.MetricsPort, "metrics port cannot be the same as main listen port")
		}
	}
	// Size validations: human-readable sizes parsed by parseSize.
	if server.MaxUploadSize != "" {
		if _, err := parseSize(server.MaxUploadSize); err != nil {
			result.AddError("server.max_upload_size", server.MaxUploadSize, "invalid size format")
		}
	}
	if server.MinFreeBytes != "" {
		if _, err := parseSize(server.MinFreeBytes); err != nil {
			result.AddError("server.min_free_bytes", server.MinFreeBytes, "invalid size format")
		}
	}
	// TTL validation: a TTL value is mandatory once TTL is switched on.
	if server.FileTTLEnabled {
		if server.FileTTL == "" {
			result.AddError("server.filettl", server.FileTTL, "file TTL is required when TTL is enabled")
		} else {
			if _, err := parseTTL(server.FileTTL); err != nil {
				result.AddError("server.filettl", server.FileTTL, "invalid TTL format")
			}
		}
	}
	// File naming validation (closed enum)
	validFileNaming := []string{"HMAC", "original", "None"}
	if !contains(validFileNaming, server.FileNaming) {
		result.AddError("server.file_naming", server.FileNaming, "must be one of: HMAC, original, None")
	}
	// Protocol validation (closed enum; empty string means unspecified)
	validProtocols := []string{"ipv4", "ipv6", "auto", ""}
	if !contains(validProtocols, server.ForceProtocol) {
		result.AddError("server.force_protocol", server.ForceProtocol, "must be one of: ipv4, ipv6, auto, or empty")
	}
	// PID file validation: only the containing directory is checked here.
	if server.PIDFilePath != "" {
		dir := filepath.Dir(server.PIDFilePath)
		if err := validateDirectoryPath(dir, false); err != nil {
			result.AddError("server.pidfilepath", server.PIDFilePath, fmt.Sprintf("PID file directory invalid: %v", err))
		}
	}
	// Worker threshold validation: both thresholds must be positive;
	// scale-down below scale-up is only a warning.
	if server.EnableDynamicWorkers {
		if server.WorkerScaleUpThresh <= 0 {
			result.AddError("server.worker_scale_up_thresh", server.WorkerScaleUpThresh, "must be positive when dynamic workers are enabled")
		}
		if server.WorkerScaleDownThresh <= 0 {
			result.AddError("server.worker_scale_down_thresh", server.WorkerScaleDownThresh, "must be positive when dynamic workers are enabled")
		}
		if server.WorkerScaleDownThresh >= server.WorkerScaleUpThresh {
			result.AddWarning("server.worker_scale_down_thresh", server.WorkerScaleDownThresh, "scale down threshold should be lower than scale up threshold")
		}
	}
	// Extensions validation: every global extension must carry its dot.
	for _, ext := range server.GlobalExtensions {
		if !strings.HasPrefix(ext, ".") {
			result.AddError("server.global_extensions", ext, "file extensions must start with a dot")
		}
	}
}
// validateSecurityConfig checks the authentication settings: JWT secret,
// algorithm and expiration when JWT is enabled; HMAC secret otherwise.
// Short secrets are warnings; missing secrets are errors.
func validateSecurityConfig(security *SecurityConfig, result *ConfigValidationResult) {
	if !security.EnableJWT {
		// HMAC mode: a non-blank shared secret is mandatory.
		if strings.TrimSpace(security.Secret) == "" {
			result.AddError("security.secret", security.Secret, "HMAC secret is required when JWT is disabled")
		} else if len(security.Secret) < 16 {
			result.AddWarning("security.secret", "[REDACTED]", "HMAC secret should be at least 16 characters for security")
		}
		return
	}

	// JWT mode.
	if strings.TrimSpace(security.JWTSecret) == "" {
		result.AddError("security.jwtsecret", security.JWTSecret, "JWT secret is required when JWT is enabled")
	} else if len(security.JWTSecret) < 32 {
		result.AddWarning("security.jwtsecret", "[REDACTED]", "JWT secret should be at least 32 characters for security")
	}

	supported := []string{"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}
	if !contains(supported, security.JWTAlgorithm) {
		result.AddError("security.jwtalgorithm", security.JWTAlgorithm, "unsupported JWT algorithm")
	}

	if security.JWTExpiration != "" {
		if _, err := time.ParseDuration(security.JWTExpiration); err != nil {
			result.AddError("security.jwtexpiration", security.JWTExpiration, "invalid JWT expiration format")
		}
	}
}
// validateLoggingConfig checks the log level against the recognized names
// (case-insensitive), verifies the log file's parent directory, and warns
// on non-sensible rotation settings.
func validateLoggingConfig(logging *LoggingConfig, result *ConfigValidationResult) {
	switch strings.ToLower(logging.Level) {
	case "panic", "fatal", "error", "warn", "warning", "info", "debug", "trace":
		// recognized level
	default:
		result.AddError("logging.level", logging.Level, "invalid log level")
	}

	if logging.File != "" {
		if err := validateDirectoryPath(filepath.Dir(logging.File), false); err != nil {
			result.AddError("logging.file", logging.File, fmt.Sprintf("log file directory invalid: %v", err))
		}
	}

	// Rotation settings: advisory only.
	if logging.MaxSize <= 0 {
		result.AddWarning("logging.max_size", logging.MaxSize, "max size should be positive")
	}
	if logging.MaxBackups < 0 {
		result.AddWarning("logging.max_backups", logging.MaxBackups, "max backups should be non-negative")
	}
	if logging.MaxAge < 0 {
		result.AddWarning("logging.max_age", logging.MaxAge, "max age should be non-negative")
	}
}
// validateTimeoutConfig validates the read/write/idle/shutdown timeout
// strings. Empty values are skipped (treated as "use the default").
// The four identical copy-pasted checks are folded into one helper.
func validateTimeoutConfig(timeouts *TimeoutConfig, result *ConfigValidationResult) {
	checkTimeoutField("timeouts.read", timeouts.Read, "read", result)
	checkTimeoutField("timeouts.write", timeouts.Write, "write", result)
	checkTimeoutField("timeouts.idle", timeouts.Idle, "idle", result)
	checkTimeoutField("timeouts.shutdown", timeouts.Shutdown, "shutdown", result)
}

// checkTimeoutField records an error when value is a malformed or
// non-positive Go duration; empty values are ignored. The label is the
// human name used in messages ("read", "write", ...).
func checkTimeoutField(field, value, label string, result *ConfigValidationResult) {
	if value == "" {
		return
	}
	duration, err := time.ParseDuration(value)
	if err != nil {
		result.AddError(field, value, fmt.Sprintf("invalid %s timeout format", label))
		return
	}
	if duration <= 0 {
		result.AddError(field, value, fmt.Sprintf("%s timeout must be positive", label))
	}
}
// validateUploadsConfig checks upload settings: extension syntax, chunk
// size format, and the resumable-upload age window.
func validateUploadsConfig(uploads *UploadsConfig, result *ConfigValidationResult) {
	// Every allowed extension must include its leading dot (".txt", not "txt").
	for _, ext := range uploads.AllowedExtensions {
		if !strings.HasPrefix(ext, ".") {
			result.AddError("uploads.allowed_extensions", ext, "file extensions must start with a dot")
		}
	}

	if size := uploads.ChunkSize; size != "" {
		if _, err := parseSize(size); err != nil {
			result.AddError("uploads.chunk_size", size, "invalid chunk size format")
		}
	}

	if age := uploads.MaxResumableAge; age != "" {
		if _, err := time.ParseDuration(age); err != nil {
			result.AddError("uploads.max_resumable_age", age, "invalid resumable age format")
		}
	}
}
// validateDownloadsConfig checks download settings: extension syntax and
// chunk size format.
func validateDownloadsConfig(downloads *DownloadsConfig, result *ConfigValidationResult) {
	// Every allowed extension must include its leading dot.
	for _, ext := range downloads.AllowedExtensions {
		if !strings.HasPrefix(ext, ".") {
			result.AddError("downloads.allowed_extensions", ext, "file extensions must start with a dot")
		}
	}

	if size := downloads.ChunkSize; size != "" {
		if _, err := parseSize(size); err != nil {
			result.AddError("downloads.chunk_size", size, "invalid chunk size format")
		}
	}
}
// validateClamAVConfig checks ClamAV settings when scanning is enabled:
// socket presence (advisory), positive worker count, and extension syntax.
func validateClamAVConfig(clamav *ClamAVConfig, result *ConfigValidationResult) {
	if !clamav.ClamAVEnabled {
		return
	}

	if clamav.ClamAVSocket == "" {
		result.AddWarning("clamav.clamavsocket", clamav.ClamAVSocket, "ClamAV socket path not specified, using default")
	} else if _, err := os.Stat(clamav.ClamAVSocket); os.IsNotExist(err) {
		// Missing socket is only a warning — the daemon may start later.
		result.AddWarning("clamav.clamavsocket", clamav.ClamAVSocket, "ClamAV socket file does not exist")
	}

	if clamav.NumScanWorkers <= 0 {
		result.AddError("clamav.numscanworkers", clamav.NumScanWorkers, "number of scan workers must be positive")
	}

	// Scan extensions must carry their leading dot.
	for _, ext := range clamav.ScanFileExtensions {
		if !strings.HasPrefix(ext, ".") {
			result.AddError("clamav.scanfileextensions", ext, "file extensions must start with a dot")
		}
	}
}
// validateRedisConfig checks Redis settings when Redis is enabled:
// address presence and host:port shape, DB index range (advisory), and
// health-check interval format.
func validateRedisConfig(redis *RedisConfig, result *ConfigValidationResult) {
	if !redis.RedisEnabled {
		return
	}

	if redis.RedisAddr == "" {
		result.AddError("redis.redisaddr", redis.RedisAddr, "Redis address is required when Redis is enabled")
	} else if !isValidHostPort(redis.RedisAddr) {
		result.AddError("redis.redisaddr", redis.RedisAddr, "invalid Redis address format (should be host:port)")
	}

	if redis.RedisDBIndex < 0 || redis.RedisDBIndex > 15 {
		result.AddWarning("redis.redisdbindex", redis.RedisDBIndex, "Redis DB index is typically 0-15")
	}

	if interval := redis.RedisHealthCheckInterval; interval != "" {
		if _, err := time.ParseDuration(interval); err != nil {
			result.AddError("redis.redishealthcheckinterval", interval, "invalid health check interval format")
		}
	}
}
// validateWorkersConfig checks the worker pool: both the worker count and
// the queue size must be positive, with advisory warnings for unusually
// large values.
func validateWorkersConfig(workers *WorkersConfig, result *ConfigValidationResult) {
	switch {
	case workers.NumWorkers <= 0:
		result.AddError("workers.numworkers", workers.NumWorkers, "number of workers must be positive")
	case workers.NumWorkers > 50:
		result.AddWarning("workers.numworkers", workers.NumWorkers, "very high worker count may impact performance")
	}

	switch {
	case workers.UploadQueueSize <= 0:
		result.AddError("workers.uploadqueuesize", workers.UploadQueueSize, "upload queue size must be positive")
	case workers.UploadQueueSize > 1000:
		result.AddWarning("workers.uploadqueuesize", workers.UploadQueueSize, "very large queue size may impact memory usage")
	}
}
// validateVersioningConfig checks versioning settings when enabled:
// positive revision cap and a recognized backend name (advisory).
func validateVersioningConfig(versioning *VersioningConfig, result *ConfigValidationResult) {
	if !versioning.Enabled {
		return
	}
	if versioning.MaxRevs <= 0 {
		result.AddError("versioning.maxversions", versioning.MaxRevs, "max versions must be positive when versioning is enabled")
	}
	switch versioning.Backend {
	case "filesystem", "database", "s3", "":
		// recognized backend (empty means unset)
	default:
		result.AddWarning("versioning.backend", versioning.Backend, "unknown versioning backend")
	}
}
// validateDeduplicationConfig checks that deduplication, when enabled,
// has a usable target directory.
func validateDeduplicationConfig(dedup *DeduplicationConfig, result *ConfigValidationResult) {
	if !dedup.Enabled {
		return
	}
	if dedup.Directory == "" {
		result.AddError("deduplication.directory", dedup.Directory, "deduplication directory is required when deduplication is enabled")
		return
	}
	if err := validateDirectoryPath(dedup.Directory, true); err != nil {
		result.AddError("deduplication.directory", dedup.Directory, err.Error())
	}
}
// validateISOConfig checks ISO container settings when the feature is
// enabled: mount point presence, size format, container file (advisory)
// and charset (advisory).
func validateISOConfig(iso *ISOConfig, result *ConfigValidationResult) {
	if !iso.Enabled {
		return
	}
	if iso.MountPoint == "" {
		result.AddError("iso.mount_point", iso.MountPoint, "mount point is required when ISO is enabled")
	}
	if iso.Size != "" {
		if _, err := parseSize(iso.Size); err != nil {
			result.AddError("iso.size", iso.Size, "invalid ISO size format")
		}
	}
	if iso.ContainerFile == "" {
		result.AddWarning("iso.containerfile", iso.ContainerFile, "container file path not specified")
	}
	switch strings.ToLower(iso.Charset) {
	case "utf-8", "iso-8859-1", "ascii", "":
		// common charsets, accepted silently
	default:
		result.AddWarning("iso.charset", iso.Charset, "uncommon charset specified")
	}
}
// validateCrossSection performs validations that span multiple config
// sections: path collisions, upload/download extension overlap, and the
// global-extensions override.
func validateCrossSection(c *Config, result *ConfigValidationResult) {
	// Deduplication must not share the storage directory.
	if c.Deduplication.Enabled && c.Server.StoragePath == c.Deduplication.Directory {
		result.AddError("deduplication.directory", c.Deduplication.Directory, "deduplication directory cannot be the same as storage path")
	}

	// ISO mount point colliding with storage is suspicious but tolerated.
	if c.ISO.Enabled && c.Server.StoragePath == c.ISO.MountPoint {
		result.AddWarning("iso.mount_point", c.ISO.MountPoint, "ISO mount point is the same as storage path")
	}

	// If both lists are restricted, at least one extension should be
	// shared — otherwise nothing uploaded could ever be downloaded.
	if len(c.Uploads.AllowedExtensions) > 0 && len(c.Downloads.AllowedExtensions) > 0 {
		uploadSet := make(map[string]struct{}, len(c.Uploads.AllowedExtensions))
		for _, ext := range c.Uploads.AllowedExtensions {
			uploadSet[ext] = struct{}{}
		}
		overlap := false
		for _, ext := range c.Downloads.AllowedExtensions {
			if _, ok := uploadSet[ext]; ok {
				overlap = true
				break
			}
		}
		if !overlap {
			result.AddWarning("uploads/downloads.allowed_extensions", "", "no common extensions between uploads and downloads - files may not be downloadable")
		}
	}

	// Global extensions take precedence over the per-direction lists.
	if len(c.Server.GlobalExtensions) > 0 && (len(c.Uploads.AllowedExtensions) > 0 || len(c.Downloads.AllowedExtensions) > 0) {
		result.AddWarning("server.global_extensions", c.Server.GlobalExtensions, "global extensions will override upload/download extension settings")
	}
}
// Enhanced Security Validation Functions
// checkSecretStrength scores a secret on length, character variety,
// known-weak substrings and character repetition. Higher scores are
// stronger (length contributes up to 3, variety up to 4; weak patterns
// subtract 2 each, repetition subtracts 1; floor at 0). issues lists the
// human-readable weaknesses found.
func checkSecretStrength(secret string) (score int, issues []string) {
	if secret == "" {
		return 0, []string{"secret is empty"}
	}
	issues = []string{}

	// Length scoring.
	switch {
	case len(secret) >= 32:
		score += 3
	case len(secret) >= 16:
		score += 2
	case len(secret) >= 8:
		score++
	default:
		issues = append(issues, "secret is too short")
	}

	// One point per character class present (lower/upper/digit/special).
	var classes [4]bool
	for _, r := range secret {
		switch {
		case r >= 'a' && r <= 'z':
			classes[0] = true
		case r >= 'A' && r <= 'Z':
			classes[1] = true
		case r >= '0' && r <= '9':
			classes[2] = true
		case strings.ContainsRune("!@#$%^&*()_+-=[]{}|;:,.<>?", r):
			classes[3] = true
		}
	}
	variety := 0
	for _, present := range classes {
		if present {
			variety++
		}
	}
	score += variety
	if variety < 3 {
		issues = append(issues, "secret should contain uppercase, lowercase, numbers, and special characters")
	}

	// Penalize well-known weak substrings (case-insensitive), 2 points each.
	lowered := strings.ToLower(secret)
	weakPatterns := []string{
		"password", "123456", "qwerty", "admin", "root", "test", "guest",
		"secret", "hmac", "server", "default", "changeme", "example",
		"demo", "temp", "temporary", "fileserver", "upload", "download",
	}
	for _, weak := range weakPatterns {
		if strings.Contains(lowered, weak) {
			issues = append(issues, fmt.Sprintf("contains common weak pattern: %s", weak))
			score -= 2
		}
	}

	// Penalize runs of three or more identical characters.
	if hasRepeatedChars(secret) {
		issues = append(issues, "contains too many repeated characters")
		score--
	}

	// Clamp: penalties never drive the score below zero.
	if score < 0 {
		score = 0
	}
	return score, issues
}
// hasRepeatedChars reports whether s (of length >= 4) contains a run of
// at least three identical bytes. Strings shorter than 4 always pass.
func hasRepeatedChars(s string) bool {
	if len(s) < 4 {
		return false
	}
	run := 1
	for i := 1; i < len(s); i++ {
		if s[i] == s[i-1] {
			run++
			if run >= 3 {
				return true
			}
		} else {
			run = 1
		}
	}
	return false
}
// isDefaultOrExampleSecret reports whether a secret looks like a
// placeholder shipped in docs or sample configs rather than a real,
// operator-chosen value. Matching is case-insensitive and substring-based.
func isDefaultOrExampleSecret(secret string) bool {
	lowered := strings.ToLower(strings.TrimSpace(secret))

	// Known placeholder values.
	placeholders := []string{
		"your-secret-key-here",
		"change-this-secret",
		"example-secret",
		"default-secret",
		"test-secret",
		"demo-secret",
		"sample-secret",
		"placeholder",
		"PUT_YOUR_SECRET_HERE",
		"CHANGE_ME",
		"YOUR_JWT_SECRET",
		"your-hmac-secret",
		"supersecret",
		"secretkey",
		"myverysecuresecret",
	}
	for _, p := range placeholders {
		if strings.Contains(lowered, strings.ToLower(p)) {
			return true
		}
	}

	// Generic "fill me in" markers.
	for _, marker := range []string{"example", "default", "change", "replace", "todo", "fixme"} {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}
// calculateEntropy calculates the Shannon entropy of a string
func calculateEntropy(s string) float64 {
if len(s) == 0 {
return 0
}
// Count character frequencies
freq := make(map[rune]int)
for _, char := range s {
freq[char]++
}
// Calculate entropy
entropy := 0.0
length := float64(len(s))
for _, count := range freq {
if count > 0 {
p := float64(count) / length
entropy -= p * (float64(count) / length) // Simplified calculation
}
}
return entropy
}
// validateSecretSecurity runs the full battery of secret checks on one
// field: default-value detection, strength scoring, entropy and length.
// The secret value itself is never echoed into results ("[REDACTED]").
func validateSecretSecurity(fieldName, secret string, result *ConfigValidationResult) {
	if secret == "" {
		// Emptiness is reported by the section-specific validators.
		return
	}

	// A placeholder secret is an outright error; nothing else matters.
	if isDefaultOrExampleSecret(secret) {
		result.AddError(fieldName, "[REDACTED]", "appears to be a default or example secret - must be changed")
		return
	}

	// Strength: low scores are errors, middling scores are warnings.
	score, issues := checkSecretStrength(secret)
	switch {
	case score < 3:
		for _, issue := range issues {
			result.AddError(fieldName, "[REDACTED]", fmt.Sprintf("weak secret: %s", issue))
		}
	case score < 6:
		for _, issue := range issues {
			result.AddWarning(fieldName, "[REDACTED]", fmt.Sprintf("secret could be stronger: %s", issue))
		}
	}

	// Entropy check (simplified).
	if calculateEntropy(secret) < 3.0 {
		result.AddWarning(fieldName, "[REDACTED]", "secret has low entropy - consider using more varied characters")
	}

	// Length-specific warnings.
	if len(secret) > 256 {
		result.AddWarning(fieldName, "[REDACTED]", "secret is very long - may impact performance")
	}
}
// validateSystemResources inspects the host environment (CPU cores, Go
// heap usage, goroutine count) and records advisory warnings only.
func validateSystemResources(result *ConfigValidationResult) {
	// CPU: warn below the recommended core counts.
	switch cores := runtime.NumCPU(); {
	case cores < 2:
		result.AddWarning("system.cpu", cores, "minimum 2 CPU cores recommended for optimal performance")
	case cores < 4:
		result.AddWarning("system.cpu", cores, "4+ CPU cores recommended for high-load environments")
	}

	// Memory: the current Go heap allocation is used as a rough proxy for
	// pressure (not true system memory).
	var memStats runtime.MemStats
	runtime.ReadMemStats(&memStats)
	if allocMB := float64(memStats.Alloc) / 1024 / 1024; allocMB > 512 {
		result.AddWarning("system.memory", allocMB, "current memory usage is high - ensure adequate system memory")
	}

	// Goroutines: a very high count hints at leaks or overload.
	if n := runtime.NumGoroutine(); n > 1000 {
		result.AddWarning("system.goroutines", n, "high goroutine count may indicate resource constraints")
	}
}
// validateNetworkConnectivity probes the external services referenced by
// the config (Redis, ClamAV). All findings are warnings, not errors,
// because a service may simply not be running yet at validation time.
func validateNetworkConnectivity(c *Config, result *ConfigValidationResult) {
	// Redis: best-effort TCP dial.
	if c.Redis.RedisEnabled && c.Redis.RedisAddr != "" {
		if err := testNetworkConnection("tcp", c.Redis.RedisAddr, 5*time.Second); err != nil {
			result.AddWarning("redis.connectivity", c.Redis.RedisAddr, fmt.Sprintf("cannot connect to Redis: %v", err))
		}
	}

	if !c.ClamAV.ClamAVEnabled || c.ClamAV.ClamAVSocket == "" {
		return
	}
	if strings.HasPrefix(c.ClamAV.ClamAVSocket, "/") {
		// Absolute path: treat as a Unix socket and stat it.
		stat, err := os.Stat(c.ClamAV.ClamAVSocket)
		switch {
		case err != nil:
			result.AddWarning("clamav.connectivity", c.ClamAV.ClamAVSocket, fmt.Sprintf("ClamAV socket not accessible: %v", err))
		case stat.Mode()&os.ModeSocket == 0:
			result.AddWarning("clamav.connectivity", c.ClamAV.ClamAVSocket, "specified path is not a socket file")
		}
		return
	}
	// Otherwise assume host:port and try a TCP dial.
	if err := testNetworkConnection("tcp", c.ClamAV.ClamAVSocket, 5*time.Second); err != nil {
		result.AddWarning("clamav.connectivity", c.ClamAV.ClamAVSocket, fmt.Sprintf("cannot connect to ClamAV: %v", err))
	}
}
// testNetworkConnection attempts to dial the given address within the
// timeout and reports any dial error. A successful connection is closed
// immediately; close errors are deliberately ignored (the probe succeeded).
func testNetworkConnection(network, address string, timeout time.Duration) error {
	conn, err := net.DialTimeout(network, address, timeout)
	if err != nil {
		return err
	}
	_ = conn.Close()
	return nil
}
// validatePerformanceSettings analyzes configuration for performance implications
// and records advisory warnings (never errors) for: worker counts far above
// the CPU core count, ClamAV scan workers above the core count, very long
// read timeouts, very large maximum upload sizes, and large upload queues
// combined with many workers.
func validatePerformanceSettings(c *Config, result *ConfigValidationResult) {
	// Check worker configuration against system resources
	cpuCores := runtime.NumCPU()
	if c.Workers.NumWorkers > cpuCores*4 {
		result.AddWarning("workers.performance", c.Workers.NumWorkers,
			fmt.Sprintf("worker count (%d) significantly exceeds CPU cores (%d) - may cause context switching overhead",
				c.Workers.NumWorkers, cpuCores))
	}
	// Check ClamAV scan workers (only relevant when scanning is enabled)
	if c.ClamAV.ClamAVEnabled && c.ClamAV.NumScanWorkers > cpuCores {
		result.AddWarning("clamav.performance", c.ClamAV.NumScanWorkers,
			fmt.Sprintf("scan workers (%d) exceed CPU cores (%d) - may impact scanning performance",
				c.ClamAV.NumScanWorkers, cpuCores))
	}
	// Check timeout configurations for performance balance. Unparsable
	// values are silently skipped here; format errors are the concern of
	// the basic validators.
	if c.Timeouts.Read != "" {
		if duration, err := time.ParseDuration(c.Timeouts.Read); err == nil {
			if duration > 300*time.Second {
				result.AddWarning("timeouts.performance", c.Timeouts.Read, "very long read timeout may impact server responsiveness")
			}
		}
	}
	// Check upload size vs available resources
	if c.Server.MaxUploadSize != "" {
		if size, err := parseSize(c.Server.MaxUploadSize); err == nil {
			if size > 10*1024*1024*1024 { // 10GB
				result.AddWarning("server.performance", c.Server.MaxUploadSize, "very large max upload size requires adequate disk space and memory")
			}
		}
	}
	// Check for potential memory-intensive configurations
	if c.Workers.UploadQueueSize > 500 && c.Workers.NumWorkers > 20 {
		result.AddWarning("workers.memory", fmt.Sprintf("queue:%d workers:%d", c.Workers.UploadQueueSize, c.Workers.NumWorkers),
			"high queue size with many workers may consume significant memory")
	}
}
// validateSecurityHardening performs advanced security validation: default
// secrets, weak JWT algorithms, wide-open bind addresses, verbose logging,
// and permissive storage-directory permissions. Default secrets are hard
// errors; everything else is a warning.
func validateSecurityHardening(c *Config, result *ConfigValidationResult) {
	// Check for default or weak configurations
	if c.Security.EnableJWT {
		// Secret values are reported as "[REDACTED]" so they never reach logs.
		if c.Security.JWTSecret == "your-secret-key-here" || c.Security.JWTSecret == "changeme" {
			result.AddError("security.jwtsecret", "[REDACTED]", "JWT secret appears to be a default value - change immediately")
		}
		// Check JWT algorithm strength
		weakAlgorithms := []string{"HS256"} // HS256 is considered less secure than RS256
		if contains(weakAlgorithms, c.Security.JWTAlgorithm) {
			result.AddWarning("security.jwtalgorithm", c.Security.JWTAlgorithm, "consider using RS256 or ES256 for enhanced security")
		}
	} else {
		// HMAC mode: the shared secret is the sole authentication factor.
		if c.Security.Secret == "your-secret-key-here" || c.Security.Secret == "changeme" || c.Security.Secret == "secret" {
			result.AddError("security.secret", "[REDACTED]", "HMAC secret appears to be a default value - change immediately")
		}
	}
	// Check for insecure bind configurations
	if c.Server.BindIP == "0.0.0.0" {
		result.AddWarning("server.bind_ip", c.Server.BindIP, "binding to 0.0.0.0 exposes service to all interfaces - ensure firewall protection")
	}
	// Check for development/debug settings in production
	if c.Logging.Level == "debug" || c.Logging.Level == "trace" {
		result.AddWarning("logging.security", c.Logging.Level, "debug/trace logging may expose sensitive information - use 'info' or 'warn' in production")
	}
	// Check file permissions for sensitive paths; stat errors are ignored
	// here because path existence is validated elsewhere.
	if c.Server.StoragePath != "" {
		if stat, err := os.Stat(c.Server.StoragePath); err == nil {
			mode := stat.Mode().Perm()
			if mode&0077 != 0 { // any group/world permission bits set
				result.AddWarning("server.storagepath.permissions", c.Server.StoragePath, "storage directory permissions allow group/world access - consider restricting to owner-only")
			}
		}
	}
}
// checkDiskSpace validates available disk space for a storage path by
// probing writability. This is a simplified check - in production, use
// syscall.Statfs on Unix or similar for a real free-space figure.
func checkDiskSpace(path string, result *ConfigValidationResult) {
	// Skip silently when the path is missing or not a directory; other
	// validators report those problems.
	stat, err := os.Stat(path)
	if err != nil || !stat.IsDir() {
		return
	}
	testFile := filepath.Join(path, ".disk_space_test")
	// First pass: can we create a file at all?
	f, err := os.Create(testFile)
	if err != nil {
		result.AddWarning("system.disk_space", path, fmt.Sprintf("cannot write to storage directory: %v", err))
		return
	}
	f.Close()
	os.Remove(testFile)
	// Second pass: write a 1MB probe to roughly estimate headroom.
	const testSize = 1024 * 1024 // 1MB
	probe, err := os.Create(testFile)
	if err != nil {
		return
	}
	if _, err := probe.Write(make([]byte, testSize)); err != nil {
		result.AddWarning("system.disk_space", path, "low disk space detected - ensure adequate storage for operations")
	}
	probe.Close()
	os.Remove(testFile)
}
// isValidPort checks if a string represents a valid port number
// (an integer in the inclusive range 1-65535).
func isValidPort(port string) bool {
	p, err := strconv.Atoi(port)
	return err == nil && p >= 1 && p <= 65535
}
// isValidHostPort checks if a string is a valid host:port combination.
// The host may be an IP address, an RFC 1123-style hostname, or empty
// (meaning localhost); the port must be 1-65535.
func isValidHostPort(hostPort string) bool {
	host, port, err := net.SplitHostPort(hostPort)
	if err != nil || !isValidPort(port) {
		return false
	}
	// Empty host (":8080") or a parseable IP address is always acceptable.
	if host == "" || net.ParseIP(host) != nil {
		return true
	}
	// Not an IP: fall back to a hostname pattern (labels of 1-63 chars,
	// alphanumeric with interior hyphens, dot-separated).
	matched, _ := regexp.MatchString(`^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$`, host)
	return matched
}
// validateDirectoryPath validates a directory path: it must name an
// existing, writable directory. When createIfMissing is true a missing
// directory is created (permissions 0755) instead of being reported.
func validateDirectoryPath(path string, createIfMissing bool) error {
	if path == "" {
		return errors.New("directory path cannot be empty")
	}
	stat, err := os.Stat(path)
	switch {
	case os.IsNotExist(err):
		if !createIfMissing {
			return fmt.Errorf("directory does not exist: %s", path)
		}
		if mkErr := os.MkdirAll(path, 0755); mkErr != nil {
			return fmt.Errorf("cannot create directory: %v", mkErr)
		}
	case err != nil:
		return fmt.Errorf("cannot access directory: %v", err)
	case !stat.IsDir():
		return fmt.Errorf("path exists but is not a directory: %s", path)
	}
	// Probe writability by creating and removing a temporary marker file.
	marker := filepath.Join(path, ".write_test")
	f, err := os.Create(marker)
	if err != nil {
		return fmt.Errorf("directory is not writable: %v", err)
	}
	f.Close()
	os.Remove(marker)
	return nil
}
// contains reports whether slice holds an element equal to item.
func contains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}
// PrintValidationResults prints the validation results in a user-friendly
// format: errors first, then warnings, or a success line when the result
// is completely clean.
func PrintValidationResults(result *ConfigValidationResult) {
	hasErrs, hasWarns := result.HasErrors(), result.HasWarnings()
	if hasErrs {
		log.Error("❌ Configuration validation failed with the following errors:")
		for _, e := range result.Errors {
			log.Errorf("  • %s", e.Error())
		}
		fmt.Println()
	}
	if hasWarns {
		log.Warn("⚠️  Configuration validation completed with warnings:")
		for _, w := range result.Warnings {
			log.Warnf("  • %s", w.Error())
		}
		fmt.Println()
	}
	if !hasErrs && !hasWarns {
		log.Info("✅ Configuration validation passed successfully!")
	}
}
// runSpecializedValidation performs targeted validation based on CLI flags.
// security/performance/connectivity each select a subset of checks; when
// none is set the full comprehensive validation runs instead. fixable
// filters the report down to issues with known remediations, and
// quiet/verbose select the output format. Exits the process with status 1
// when errors remain after filtering.
func runSpecializedValidation(c *Config, security, performance, connectivity, quiet, verbose, fixable bool) {
	result := &ConfigValidationResult{Valid: true}
	if verbose {
		log.Info("Running specialized validation with detailed output...")
		fmt.Println()
	}
	// Run only the requested validation types; the flags are not mutually
	// exclusive, so several suites may accumulate into the same result.
	if security {
		if verbose {
			log.Info("🔐 Running security validation checks...")
		}
		validateSecurityConfig(&c.Security, result)
		validateSecurityHardening(c, result)
	}
	if performance {
		if verbose {
			log.Info("⚡ Running performance validation checks...")
		}
		validatePerformanceSettings(c, result)
		validateSystemResources(result)
	}
	if connectivity {
		if verbose {
			log.Info("🌐 Running connectivity validation checks...")
		}
		validateNetworkConnectivity(c, result)
	}
	// If no specific type is requested, run basic validation
	// (this replaces the accumulated result entirely).
	if !security && !performance && !connectivity {
		if verbose {
			log.Info("🔍 Running comprehensive validation...")
		}
		result = ValidateConfigComprehensive(c)
	}
	// Filter results based on flags
	if fixable {
		filterFixableIssues(result)
	}
	// Output results based on verbosity
	if quiet {
		printQuietValidationResults(result)
	} else if verbose {
		printVerboseValidationResults(result)
	} else {
		PrintValidationResults(result)
	}
	// Exit with appropriate code so scripts can rely on the status.
	if result.HasErrors() {
		os.Exit(1)
	}
}
// filterFixableIssues removes non-fixable issues from the result, keeping
// only entries whose message mentions one of the known auto-fixable
// categories, then recomputes Valid from the surviving errors.
//
// Fix: messages are compared after ToLower, so every pattern must be
// lowercase as well. The previous pattern "IP address" (uppercase "IP")
// could never match a lowercased message, silently dropping IP-address
// issues from the fixable report.
func filterFixableIssues(result *ConfigValidationResult) {
	fixablePatterns := []string{
		"permissions",
		"directory",
		"default value",
		"debug logging",
		"size format",
		"timeout format",
		"port number",
		"ip address",
	}
	// isFixable reports whether a message matches any fixable pattern.
	isFixable := func(msg string) bool {
		lower := strings.ToLower(msg)
		for _, pattern := range fixablePatterns {
			if strings.Contains(lower, pattern) {
				return true
			}
		}
		return false
	}
	var fixableErrors []ConfigValidationError
	for _, err := range result.Errors {
		if isFixable(err.Message) {
			fixableErrors = append(fixableErrors, err)
		}
	}
	var fixableWarnings []ConfigValidationError
	for _, warn := range result.Warnings {
		if isFixable(warn.Message) {
			fixableWarnings = append(fixableWarnings, warn)
		}
	}
	result.Errors = fixableErrors
	result.Warnings = fixableWarnings
	result.Valid = len(fixableErrors) == 0
}
// printQuietValidationResults prints only errors, one per line, for
// script-friendly output; warnings and success are suppressed.
func printQuietValidationResults(result *ConfigValidationResult) {
	if !result.HasErrors() {
		return
	}
	for _, e := range result.Errors {
		fmt.Printf("ERROR: %s\n", e.Error())
	}
}
// printVerboseValidationResults prints a detailed validation report:
// runtime system stats, a pass/warn/error summary, then each issue with
// its field, message, and offending value.
func printVerboseValidationResults(result *ConfigValidationResult) {
	fmt.Println("📊 DETAILED VALIDATION REPORT")
	fmt.Println("============================")
	fmt.Println()
	// System information (current Go process view, not whole-machine stats)
	fmt.Printf("🖥️  System: %d CPU cores, %d goroutines\n", runtime.NumCPU(), runtime.NumGoroutine())
	var memStats runtime.MemStats
	runtime.ReadMemStats(&memStats)
	fmt.Printf("💾 Memory: %.2f MB allocated\n", float64(memStats.Alloc)/1024/1024)
	fmt.Println()
	// Validation summary; "passed" is an estimate, see countPassedChecks.
	fmt.Printf("✅ Checks passed: %d\n", countPassedChecks(result))
	fmt.Printf("⚠️  Warnings: %d\n", len(result.Warnings))
	fmt.Printf("❌ Errors: %d\n", len(result.Errors))
	fmt.Println()
	// Detailed results
	if result.HasErrors() {
		fmt.Println("🚨 CONFIGURATION ERRORS:")
		for i, err := range result.Errors {
			fmt.Printf("  %d. Field: %s\n", i+1, err.Field)
			fmt.Printf("     Issue: %s\n", err.Message)
			fmt.Printf("     Value: %v\n", err.Value)
			fmt.Println()
		}
	}
	if result.HasWarnings() {
		fmt.Println("⚠️  CONFIGURATION WARNINGS:")
		for i, warn := range result.Warnings {
			fmt.Printf("  %d. Field: %s\n", i+1, warn.Field)
			fmt.Printf("     Issue: %s\n", warn.Message)
			fmt.Printf("     Value: %v\n", warn.Value)
			fmt.Println()
		}
	}
	if !result.HasErrors() && !result.HasWarnings() {
		fmt.Println("🎉 All validation checks passed successfully!")
	}
}
// countPassedChecks estimates the number of successful validation checks.
// The total is a rough constant; the result is clamped at zero so callers
// never display a negative "passed" count when errors and warnings exceed
// the estimate.
func countPassedChecks(result *ConfigValidationResult) int {
	const totalPossibleChecks = 50 // Approximate number of validation checks
	passed := totalPossibleChecks - len(result.Errors) - len(result.Warnings)
	if passed < 0 {
		return 0
	}
	return passed
}

713
cmd/server/helpers.go Normal file
View File

@@ -0,0 +1,713 @@
package main
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
"github.com/dutchcoders/go-clamd"
"github.com/go-redis/redis/v8"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/shirou/gopsutil/cpu"
"github.com/shirou/gopsutil/mem"
"gopkg.in/natefinch/lumberjack.v2"
)
// WorkerPool represents a pool of worker goroutines that consume upload
// and scan tasks from two buffered queues until the pool context is
// cancelled via Stop.
type WorkerPool struct {
	workers   int                // number of goroutines launched by Start
	taskQueue chan UploadTask    // buffered queue of pending upload tasks
	scanQueue chan ScanTask      // buffered queue of pending scan tasks
	ctx       context.Context    // cancelled by Stop to terminate workers
	cancel    context.CancelFunc // cancels ctx
}
// NewWorkerPool creates a worker pool with the given number of workers
// and per-queue buffer capacity. Call Start to launch the workers and
// Stop to shut them down.
func NewWorkerPool(workers int, queueSize int) *WorkerPool {
	poolCtx, cancelFn := context.WithCancel(context.Background())
	pool := &WorkerPool{
		workers:   workers,
		taskQueue: make(chan UploadTask, queueSize),
		scanQueue: make(chan ScanTask, queueSize),
		ctx:       poolCtx,
		cancel:    cancelFn,
	}
	return pool
}
// Start launches the configured number of worker goroutines.
func (wp *WorkerPool) Start() {
	count := wp.workers
	for n := 0; n < count; n++ {
		go wp.worker()
	}
}
// Stop stops the worker pool by cancelling its context; running workers
// observe the cancellation and exit. The task and scan queues are
// deliberately NOT closed here: producers may still hold references (a
// send on a closed channel panics), and a worker selecting on a closed
// channel would otherwise receive an endless stream of zero-value tasks.
func (wp *WorkerPool) Stop() {
	wp.cancel()
}
// worker is the main loop of a single pool goroutine. It exits when the
// pool context is cancelled or when a queue it receives from is closed
// (the `, ok` form prevents the previous behavior of treating a closed
// queue as an endless stream of zero-value tasks).
func (wp *WorkerPool) worker() {
	for {
		select {
		case <-wp.ctx.Done():
			return
		case task, ok := <-wp.taskQueue:
			if !ok {
				return
			}
			if task.Result != nil {
				task.Result <- nil // Simple implementation
			}
		case scanTask, ok := <-wp.scanQueue:
			if !ok {
				return
			}
			err := processScan(scanTask)
			if scanTask.Result != nil {
				scanTask.Result <- err
			}
		}
	}
}
// Stub for precacheStoragePath.
// TODO: Implement actual pre-caching logic. This would typically involve
// walking storagePath and loading file information into a cache.
// Currently it only logs that it is unimplemented and returns nil.
func precacheStoragePath(storagePath string) error {
	log.Infof("Pre-caching for storage path '%s' is a stub and not yet implemented.", storagePath)
	return nil
}
// checkFreeSpaceWithRetry verifies that the configured minimum free space
// is available at path, retrying up to `retries` times with `delay`
// between attempts. The MinFreeBytes setting is parsed once up front (it
// is loop-invariant; previously it was re-parsed on every attempt); an
// invalid value remains fatal, matching prior behavior.
func checkFreeSpaceWithRetry(path string, retries int, delay time.Duration) error {
	minFreeBytes, err := parseSize(conf.Server.MinFreeBytes)
	if err != nil {
		log.Fatalf("Invalid MinFreeBytes: %v", err)
	}
	for i := 0; i < retries; i++ {
		if err := checkStorageSpace(path, minFreeBytes); err != nil {
			log.Warnf("Free space check failed (attempt %d/%d): %v", i+1, retries, err)
			time.Sleep(delay)
			continue
		}
		return nil
	}
	return fmt.Errorf("insufficient free space after %d attempts", retries)
}
// handleFileCleanup runs the periodic TTL-based cleanup loop. It returns
// immediately when the TTL feature is disabled; otherwise it blocks
// forever (run it in its own goroutine) and triggers a sweep every 24h.
func handleFileCleanup(conf *Config) {
	if !conf.Server.FileTTLEnabled {
		log.Println("File TTL is disabled.")
		return
	}
	ttlDuration, err := parseTTL(conf.Server.FileTTL)
	if err != nil {
		log.Fatalf("Invalid TTL configuration: %v", err)
	}
	log.Printf("TTL cleanup enabled. Files older than %v will be deleted.", ttlDuration)
	ticker := time.NewTicker(24 * time.Hour)
	defer ticker.Stop()
	for {
		<-ticker.C
		deleteOldFiles(conf, ttlDuration)
	}
}
// computeSHA256 returns the hex-encoded SHA-256 digest of the file at
// filePath. The ctx parameter is now honored before starting the
// (potentially large) read; previously it was accepted but ignored. A nil
// ctx is tolerated for backward compatibility.
func computeSHA256(ctx context.Context, filePath string) (string, error) {
	if ctx != nil {
		if err := ctx.Err(); err != nil {
			return "", err
		}
	}
	if filePath == "" {
		return "", fmt.Errorf("file path is empty")
	}
	file, err := os.Open(filePath)
	if err != nil {
		return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
	}
	defer file.Close()
	hasher := sha256.New()
	if _, err := io.Copy(hasher, file); err != nil {
		return "", fmt.Errorf("failed to hash file: %w", err)
	}
	return hex.EncodeToString(hasher.Sum(nil)), nil
}
// handleDeduplication moves the uploaded file into the content-addressed
// deduplication store (keyed by SHA-256) and replaces the original path
// with a hard link to the stored copy. When an identical file is already
// stored, the upload is re-linked to it.
//
// Fix: in the already-exists branch the original file must be removed
// before linking - os.Link fails with "file exists" when the target path
// is still present, which made every duplicate upload error out.
func handleDeduplication(ctx context.Context, absFilename string) error {
	checksum, err := computeSHA256(ctx, absFilename)
	if err != nil {
		return err
	}
	dedupDir := conf.Deduplication.Directory
	if dedupDir == "" {
		return fmt.Errorf("deduplication directory not configured")
	}
	dedupPath := filepath.Join(dedupDir, checksum)
	if err := os.MkdirAll(dedupPath, os.ModePerm); err != nil {
		return err
	}
	existingPath := filepath.Join(dedupPath, filepath.Base(absFilename))
	if _, err := os.Stat(existingPath); err == nil {
		// A stored copy already exists: replace the upload with a hard link.
		if err := os.Remove(absFilename); err != nil {
			return err
		}
		return os.Link(existingPath, absFilename)
	}
	// First occurrence: move the upload into the store, then link back.
	if err := os.Rename(absFilename, existingPath); err != nil {
		return err
	}
	return os.Link(existingPath, absFilename)
}
// handleISOContainer packs absFilename into an ISO image at the configured
// mount point, mounts the image to verify it, and unmounts it again,
// returning the first error encountered.
func handleISOContainer(absFilename string) error {
	mountPoint := conf.ISO.MountPoint
	isoPath := filepath.Join(mountPoint, "container.iso")
	steps := []func() error{
		func() error {
			return CreateISOContainer([]string{absFilename}, isoPath, conf.ISO.Size, conf.ISO.Charset)
		},
		func() error { return MountISOContainer(isoPath, mountPoint) },
		func() error { return UnmountISOContainer(mountPoint) },
	}
	for _, step := range steps {
		if err := step(); err != nil {
			return err
		}
	}
	return nil
}
// sanitizeFilePath joins filePath onto baseDir and verifies the result
// stays inside baseDir, guarding against path traversal.
//
// Fix: the previous bare strings.HasPrefix check accepted sibling
// directories sharing a name prefix (base "/srv/files" matched
// "/srv/files-evil/..."). The check now requires either an exact match or
// a path-separator boundary after the base directory.
func sanitizeFilePath(baseDir, filePath string) (string, error) {
	absBaseDir, err := filepath.Abs(baseDir)
	if err != nil {
		return "", err
	}
	absFilePath, err := filepath.Abs(filepath.Join(absBaseDir, filePath))
	if err != nil {
		return "", err
	}
	if absFilePath != absBaseDir &&
		!strings.HasPrefix(absFilePath, absBaseDir+string(filepath.Separator)) {
		return "", fmt.Errorf("invalid file path: %s", filePath)
	}
	return absFilePath, nil
}
// formatBytes renders a byte count using binary (1024-based) units,
// e.g. 512 -> "512 B", 1536 -> "1.5 KiB", 1048576 -> "1.0 MiB".
func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	divisor, magnitude := int64(unit), 0
	for remaining := bytes / unit; remaining >= unit; remaining /= unit {
		divisor *= unit
		magnitude++
	}
	return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(divisor), "KMGTPE"[magnitude])
}
// Stub for deleteOldFiles.
// TODO: Implement actual file deletion logic based on TTL; intended to
// remove stored files older than ttlDuration. Currently only logs.
func deleteOldFiles(conf *Config, ttlDuration time.Duration) {
	log.Infof("deleteOldFiles is a stub and not yet implemented. It would check for files older than %v.", ttlDuration)
}
// Stub for CreateISOContainer.
// TODO: Implement actual ISO container creation logic packing `files`
// into an image at isoPath with the given size and charset. Currently
// only logs and reports success.
func CreateISOContainer(files []string, isoPath, size, charset string) error {
	log.Infof("CreateISOContainer is a stub and not yet implemented. It would create an ISO at %s.", isoPath)
	return nil
}
// Stub for MountISOContainer.
// TODO: Implement actual ISO container mounting logic attaching isoPath
// at mountPoint. Currently only logs and reports success.
func MountISOContainer(isoPath, mountPoint string) error {
	log.Infof("MountISOContainer is a stub and not yet implemented. It would mount %s to %s.", isoPath, mountPoint)
	return nil
}
// Stub for UnmountISOContainer.
// TODO: Implement actual ISO container unmounting logic detaching
// mountPoint. Currently only logs and reports success.
func UnmountISOContainer(mountPoint string) error {
	log.Infof("UnmountISOContainer is a stub and not yet implemented. It would unmount %s.", mountPoint)
	return nil
}
// checkStorageSpace returns an error when the filesystem containing
// storagePath has fewer than minFreeBytes bytes available to unprivileged
// users (Bavail, i.e. excluding the root-reserved blocks).
func checkStorageSpace(storagePath string, minFreeBytes int64) error {
	var fsStat syscall.Statfs_t
	if err := syscall.Statfs(storagePath, &fsStat); err != nil {
		return err
	}
	available := fsStat.Bavail * uint64(fsStat.Bsize)
	if int64(available) >= minFreeBytes {
		return nil
	}
	return fmt.Errorf("not enough space: available %d < required %d", available, minFreeBytes)
}
// setupLogging initializes logging configuration
func setupLogging() {
log.Infof("DEBUG: Starting setupLogging function")
if conf.Logging.File != "" {
log.Infof("DEBUG: Setting up file logging to: %s", conf.Logging.File)
log.SetOutput(&lumberjack.Logger{
Filename: conf.Logging.File,
MaxSize: conf.Logging.MaxSize,
MaxBackups: conf.Logging.MaxBackups,
MaxAge: conf.Logging.MaxAge,
Compress: conf.Logging.Compress,
})
log.Infof("Logging configured to file: %s", conf.Logging.File)
}
log.Infof("DEBUG: setupLogging function completed")
}
// logSystemInfo logs host memory, CPU, and Go runtime information at
// startup; failures to read system stats are logged as warnings and do
// not abort startup.
func logSystemInfo() {
	if vm, err := mem.VirtualMemory(); err != nil {
		log.Warnf("Failed to get memory stats: %v", err)
	} else {
		log.Infof("System Memory: Total=%s, Available=%s, Used=%.1f%%",
			formatBytes(int64(vm.Total)),
			formatBytes(int64(vm.Available)),
			vm.UsedPercent)
	}
	cpuInfo, err := cpu.Info()
	switch {
	case err != nil:
		log.Warnf("Failed to get CPU stats: %v", err)
	case len(cpuInfo) > 0:
		log.Infof("CPU: %s, Cores=%d", cpuInfo[0].ModelName, len(cpuInfo))
	}
	log.Infof("Go Runtime: Version=%s, NumCPU=%d, NumGoroutine=%d",
		runtime.Version(), runtime.NumCPU(), runtime.NumGoroutine())
}
// initMetrics initializes all Prometheus collectors used by the server and
// registers them with the default registry. Must be called exactly once:
// MustRegister panics on duplicate registration, surfacing accidental
// double initialization early.
func initMetrics() {
	// Local constructors remove the repetitive Opts boilerplate.
	histogram := func(name, help string) prometheus.Histogram {
		return prometheus.NewHistogram(prometheus.HistogramOpts{Name: name, Help: help})
	}
	counter := func(name, help string) prometheus.Counter {
		return prometheus.NewCounter(prometheus.CounterOpts{Name: name, Help: help})
	}
	gauge := func(name, help string) prometheus.Gauge {
		return prometheus.NewGauge(prometheus.GaugeOpts{Name: name, Help: help})
	}

	uploadDuration = histogram("upload_duration_seconds", "Duration of upload operations in seconds")
	uploadErrorsTotal = counter("upload_errors_total", "Total number of upload errors")
	uploadsTotal = counter("uploads_total", "Total number of uploads")
	downloadDuration = histogram("download_duration_seconds", "Duration of download operations in seconds")
	downloadsTotal = counter("downloads_total", "Total number of downloads")
	downloadErrorsTotal = counter("download_errors_total", "Total number of download errors")
	memoryUsage = gauge("memory_usage_percent", "Current memory usage percentage")
	cpuUsage = gauge("cpu_usage_percent", "Current CPU usage percentage")
	activeConnections = gauge("active_connections_total", "Number of active connections")
	requestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "requests_total",
		Help: "Total number of requests",
	}, []string{"method", "status"})
	goroutines = gauge("goroutines_total", "Number of goroutines")
	uploadSizeBytes = histogram("upload_size_bytes", "Size of uploaded files in bytes")
	downloadSizeBytes = histogram("download_size_bytes", "Size of downloaded files in bytes")
	filesDeduplicatedTotal = counter("files_deduplicated_total", "Total number of deduplicated files")
	deduplicationErrorsTotal = counter("deduplication_errors_total", "Total number of deduplication errors")
	isoContainersCreatedTotal = counter("iso_containers_created_total", "Total number of ISO containers created")
	isoCreationErrorsTotal = counter("iso_creation_errors_total", "Total number of ISO creation errors")
	isoContainersMountedTotal = counter("iso_containers_mounted_total", "Total number of ISO containers mounted")
	isoMountErrorsTotal = counter("iso_mount_errors_total", "Total number of ISO mount errors")
	workerAdjustmentsTotal = counter("worker_adjustments_total", "Total number of worker adjustments")
	workerReAdjustmentsTotal = counter("worker_readjustments_total", "Total number of worker readjustments")

	// Register all metrics
	prometheus.MustRegister(
		uploadDuration, uploadErrorsTotal, uploadsTotal,
		downloadDuration, downloadsTotal, downloadErrorsTotal,
		memoryUsage, cpuUsage, activeConnections, requestsTotal,
		goroutines, uploadSizeBytes, downloadSizeBytes,
		filesDeduplicatedTotal, deduplicationErrorsTotal,
		isoContainersCreatedTotal, isoCreationErrorsTotal,
		isoContainersMountedTotal, isoMountErrorsTotal,
		workerAdjustmentsTotal, workerReAdjustmentsTotal,
	)
	log.Info("Prometheus metrics initialized successfully")
}
// scanFileWithClamAV scans a file using the package-level ClamAV client.
// Returns nil when the scan reports "OK"; returns an error when the client
// is uninitialized, the scan request fails, any non-"OK" status is
// reported, or no result arrives within 30 seconds.
func scanFileWithClamAV(filename string) error {
	if clamClient == nil {
		return fmt.Errorf("ClamAV client not initialized")
	}
	result, err := clamClient.ScanFile(filename)
	if err != nil {
		return fmt.Errorf("ClamAV scan failed: %w", err)
	}
	// Handle the result channel: go-clamd delivers results asynchronously.
	if result != nil {
		select {
		case scanResult := <-result:
			// Any status other than "OK" is treated as an infection.
			if scanResult != nil && scanResult.Status != "OK" {
				return fmt.Errorf("virus detected in %s: %s", filename, scanResult.Status)
			}
		case <-time.After(30 * time.Second):
			// Guard against a daemon that accepted the request but never answers.
			return fmt.Errorf("ClamAV scan timeout for file: %s", filename)
		}
	}
	log.Debugf("File %s passed ClamAV scan", filename)
	return nil
}
// initClamAV creates a ClamAV client for the given Unix socket path
// (defaulting to /var/run/clamav/clamd.ctl when empty) and verifies
// connectivity with a ping before returning it.
func initClamAV(socketPath string) (*clamd.Clamd, error) {
	const defaultSocket = "/var/run/clamav/clamd.ctl"
	if socketPath == "" {
		socketPath = defaultSocket
	}
	client := clamd.NewClamd(socketPath)
	if err := client.Ping(); err != nil {
		return nil, fmt.Errorf("failed to ping ClamAV daemon: %w", err)
	}
	log.Infof("ClamAV client initialized with socket: %s", socketPath)
	return client, nil
}
// initRedis creates the package-level Redis client from configuration and
// records connection health in redisConnected after a ping bounded by a
// 5-second timeout. A failed ping is a warning, not a fatal error.
func initRedis() {
	redisClient = redis.NewClient(&redis.Options{
		Addr:     conf.Redis.RedisAddr,
		Password: conf.Redis.RedisPassword,
		DB:       conf.Redis.RedisDBIndex,
	})
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if _, err := redisClient.Ping(ctx).Result(); err != nil {
		log.Warnf("Failed to connect to Redis: %v", err)
		redisConnected = false
		return
	}
	log.Info("Redis client initialized successfully")
	redisConnected = true
}
// monitorNetwork polls the system's network interfaces every 30 seconds
// until ctx is cancelled, emitting an "interface_up" NetworkEvent for each
// non-loopback interface that is up. Events are sent non-blocking: when
// the networkEvents channel is full the event is dropped.
func monitorNetwork(ctx context.Context) {
	log.Info("Starting network monitoring")
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			log.Info("Network monitoring stopped")
			return
		case <-ticker.C:
			// Simple network monitoring - check interface status
			interfaces, err := net.Interfaces()
			if err != nil {
				log.Warnf("Failed to get network interfaces: %v", err)
				continue
			}
			for _, iface := range interfaces {
				// Only report interfaces that are up and not loopback.
				if iface.Flags&net.FlagUp != 0 && iface.Flags&net.FlagLoopback == 0 {
					select {
					case networkEvents <- NetworkEvent{
						Type:    "interface_up",
						Details: fmt.Sprintf("Interface %s is up", iface.Name),
					}:
					default:
						// Channel full, skip
					}
				}
			}
		}
	}
}
// handleNetworkEvents drains the networkEvents channel, logging each event
// at debug level, until ctx is cancelled.
func handleNetworkEvents(ctx context.Context) {
	log.Info("Starting network event handler")
	for {
		select {
		case <-ctx.Done():
			log.Info("Network event handler stopped")
			return
		case ev := <-networkEvents:
			log.Debugf("Network event: %s - %s", ev.Type, ev.Details)
		}
	}
}
// updateSystemMetrics refreshes the memory, CPU, and goroutine gauges
// every 15 seconds until ctx is cancelled. Note that cpu.Percent samples
// over one second, so each tick blocks briefly while measuring.
func updateSystemMetrics(ctx context.Context) {
	ticker := time.NewTicker(15 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if vm, err := mem.VirtualMemory(); err == nil {
				memoryUsage.Set(vm.UsedPercent)
			}
			if pct, err := cpu.Percent(time.Second, false); err == nil && len(pct) > 0 {
				cpuUsage.Set(pct[0])
			}
			goroutines.Set(float64(runtime.NumGoroutine()))
		}
	}
}
// setupRouter sets up HTTP routes: /upload, /download/, /health, an
// optional /metrics endpoint, and a catch-all that dispatches the XMPP
// HTTP-upload protocol variants (v, v2, token via query parameters; v3 for
// mod_http_upload_external) plus GET/HEAD downloads for file-like paths.
func setupRouter() *http.ServeMux {
	mux := http.NewServeMux()
	mux.HandleFunc("/upload", handleUpload)
	mux.HandleFunc("/download/", handleDownload)
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	})
	if conf.Server.MetricsEnabled {
		mux.Handle("/metrics", promhttp.Handler())
	}
	// Catch-all handler for all upload protocols (v, v2, token, v3)
	// This must be added last as it matches all paths
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		// Handle PUT requests for all upload protocols
		if r.Method == http.MethodPut {
			query := r.URL.Query()
			// Check if this is a v3 request (mod_http_upload_external);
			// v3 is distinguished by both "v3" and "expires" parameters.
			if query.Get("v3") != "" && query.Get("expires") != "" {
				handleV3Upload(w, r)
				return
			}
			// Check if this is a legacy protocol request (v, v2, token)
			if query.Get("v") != "" || query.Get("v2") != "" || query.Get("token") != "" {
				handleLegacyUpload(w, r)
				return
			}
		}
		// Handle GET/HEAD requests for downloads
		if r.Method == http.MethodGet || r.Method == http.MethodHead {
			// Only handle download requests if the path looks like a file
			// (non-empty and not ending in a slash).
			path := strings.TrimPrefix(r.URL.Path, "/")
			if path != "" && !strings.HasSuffix(path, "/") {
				handleLegacyDownload(w, r)
				return
			}
		}
		// For all other requests, return 404
		http.NotFound(w, r)
	})
	log.Info("HTTP router configured successfully with full protocol support (v, v2, token, v3)")
	return mux
}
// setupGracefulShutdown installs a SIGINT/SIGTERM handler that cancels the
// application context, shuts the HTTP server down with a 30-second grace
// period, removes the PID file when CleanUponExit is set, stops the worker
// pool, and finally exits the process with status 0.
func setupGracefulShutdown(server *http.Server, cancel context.CancelFunc) {
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-sigChan
		log.Info("Received shutdown signal, initiating graceful shutdown...")
		// Cancel context first so background goroutines begin stopping.
		cancel()
		// Shutdown server with timeout
		ctx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer shutdownCancel()
		if err := server.Shutdown(ctx); err != nil {
			log.Errorf("Server shutdown error: %v", err)
		} else {
			log.Info("Server shutdown completed")
		}
		// Clean up PID file
		if conf.Server.CleanUponExit {
			removePIDFile(conf.Server.PIDFilePath)
		}
		// Stop worker pool if it exists
		if workerPool != nil {
			workerPool.Stop()
			log.Info("Worker pool stopped")
		}
		os.Exit(0)
	}()
}
// ProgressWriter wraps an io.Writer to provide upload progress reporting
// through the onProgress callback, which Write invokes on a throttled
// schedule based on file size and elapsed time.
type ProgressWriter struct {
	dst        io.Writer // underlying destination writer
	total      int64     // expected total bytes (0 when unknown)
	written    int64     // cumulative bytes written so far
	filename   string    // file name used in progress log messages
	onProgress func(written, total int64, filename string) // invoked on throttled progress ticks
	lastReport time.Time // time of the most recent progress report
}
// NewProgressWriter creates a ProgressWriter that logs percentage progress
// for the named file as data is written to dst. total may be zero when
// the final size is unknown, in which case nothing is logged.
func NewProgressWriter(dst io.Writer, total int64, filename string) *ProgressWriter {
	report := func(written, total int64, filename string) {
		if total <= 0 {
			return
		}
		percentage := float64(written) / float64(total) * 100
		writtenMiB := float64(written) / (1024 * 1024)
		totalMiB := float64(total) / (1024 * 1024)
		log.Infof("Upload progress for %s: %.1f%% (%.1f/%.1f MiB)",
			filepath.Base(filename), percentage, writtenMiB, totalMiB)
	}
	return &ProgressWriter{
		dst:        dst,
		total:      total,
		filename:   filename,
		onProgress: report,
		lastReport: time.Now(),
	}
}
// Write implements io.Writer with throttled progress reporting. Data is
// forwarded to the wrapped writer first; on success the cumulative count
// is updated and a report may fire. For totals over 100MB a report fires
// at most every 30 seconds (or when the cumulative count lands exactly on
// a 50MB boundary); over 10MB, every 10 seconds (or an exact 10MB
// boundary); smaller files never report.
// NOTE(review): the modulo conditions only trigger when pw.written is an
// exact multiple of the boundary, which is unlikely with arbitrary chunk
// sizes - in practice the time-based condition dominates.
func (pw *ProgressWriter) Write(p []byte) (int, error) {
	n, err := pw.dst.Write(p)
	if err != nil {
		return n, err
	}
	pw.written += int64(n)
	// Report progress every 30 seconds or every 50MB for large files
	now := time.Now()
	shouldReport := false
	if pw.total > 100*1024*1024 { // Files larger than 100MB
		shouldReport = now.Sub(pw.lastReport) > 30*time.Second ||
			(pw.written%(50*1024*1024) == 0 && pw.written > 0)
	} else if pw.total > 10*1024*1024 { // Files larger than 10MB
		shouldReport = now.Sub(pw.lastReport) > 10*time.Second ||
			(pw.written%(10*1024*1024) == 0 && pw.written > 0)
	}
	if shouldReport && pw.onProgress != nil {
		pw.onProgress(pw.written, pw.total, pw.filename)
		pw.lastReport = now
	}
	return n, err
}
// copyWithProgress streams src into dst through a ProgressWriter so large
// uploads log periodic progress. A pooled buffer backs the copy, avoiding
// a fresh intermediate allocation on every call.
func copyWithProgress(dst io.Writer, src io.Reader, total int64, filename string) (int64, error) {
	pw := NewProgressWriter(dst, total, filename)
	bufPtr := bufferPool.Get().(*[]byte)
	defer bufferPool.Put(bufPtr)
	return io.CopyBuffer(pw, src, *bufPtr)
}

View File

@@ -8,140 +8,285 @@ import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"mime"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"os/signal"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/dutchcoders/go-clamd" // ClamAV integration
	"github.com/go-redis/redis/v8"    // Redis integration
	jwt "github.com/golang-jwt/jwt/v5"
	"github.com/patrickmn/go-cache"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/shirou/gopsutil/cpu"
	"github.com/shirou/gopsutil/disk"
	"github.com/shirou/gopsutil/host"
	"github.com/shirou/gopsutil/mem"
	"github.com/sirupsen/logrus"
	"github.com/spf13/viper"
)
// Configuration structure
// parseSize converts a human-readable size string (e.g. "100MB", "2GB",
// "512KB") to a byte count. Units are case-insensitive and use binary
// multiples (KB = 1024 bytes). Generalized backward-compatibly: whitespace
// between value and unit is now tolerated ("10 MB"), and TB is accepted
// in addition to KB/MB/GB.
func parseSize(sizeStr string) (int64, error) {
	sizeStr = strings.TrimSpace(sizeStr)
	if len(sizeStr) < 2 {
		return 0, fmt.Errorf("invalid size string: %s", sizeStr)
	}
	// The unit is always the last two characters; the remainder is the value.
	unit := strings.ToUpper(sizeStr[len(sizeStr)-2:])
	valueStr := strings.TrimSpace(sizeStr[:len(sizeStr)-2])
	value, err := strconv.Atoi(valueStr)
	if err != nil {
		return 0, fmt.Errorf("invalid size value: %v", err)
	}
	multipliers := map[string]int64{
		"KB": 1024,
		"MB": 1024 * 1024,
		"GB": 1024 * 1024 * 1024,
		"TB": 1024 * 1024 * 1024 * 1024,
	}
	multiplier, ok := multipliers[unit]
	if !ok {
		return 0, fmt.Errorf("unknown size unit: %s", unit)
	}
	return int64(value) * multiplier, nil
}
// parseTTL converts a TTL string such as "30s", "15m", "24h", "7d", "2w",
// or "1y" into a time.Duration. Parsing is case-insensitive and
// whitespace-tolerant.
//
// Fixes over the previous implementation: a missing unit ("100") now
// produces a clear error instead of "unknown TTL unit: \x00", and trailing
// text after the unit ("1h30m") is rejected instead of being silently
// discarded (previously parsed as just "1h").
func parseTTL(ttlStr string) (time.Duration, error) {
	ttlStr = strings.ToLower(strings.TrimSpace(ttlStr))
	if ttlStr == "" {
		return 0, fmt.Errorf("TTL string cannot be empty")
	}
	// Split into the leading digit run and the remainder.
	i := 0
	for i < len(ttlStr) && ttlStr[i] >= '0' && ttlStr[i] <= '9' {
		i++
	}
	valueStr, unitStr := ttlStr[:i], ttlStr[i:]
	if len(unitStr) != 1 {
		return 0, fmt.Errorf("invalid TTL format: %s", ttlStr)
	}
	val, err := strconv.Atoi(valueStr)
	if err != nil {
		return 0, fmt.Errorf("invalid TTL value: %v", err)
	}
	unitDurations := map[byte]time.Duration{
		's': time.Second,
		'm': time.Minute,
		'h': time.Hour,
		'd': 24 * time.Hour,
		'w': 7 * 24 * time.Hour,
		'y': 365 * 24 * time.Hour, // calendar-naive: exactly 365 days
	}
	unit, ok := unitDurations[unitStr[0]]
	if !ok {
		return 0, fmt.Errorf("unknown TTL unit: %c", unitStr[0])
	}
	return time.Duration(val) * unit, nil
}
// Configuration structures

// ServerConfig holds the [server] section of config.toml. Tags are lowercase
// to match the TOML keys used by the example configuration.
// NOTE(review): this declaration previously contained a duplicated legacy
// CamelCase-tagged field set plus interleaved old definitions of
// TimeoutConfig/SecurityConfig/VersioningConfig (diff residue, not valid Go);
// only the current field set is kept — the other types are defined below.
type ServerConfig struct {
	ListenAddress         string   `toml:"listenport" mapstructure:"listenport"` // TCP port (or socket path historically)
	StoragePath           string   `toml:"storagepath" mapstructure:"storagepath"`
	MetricsEnabled        bool     `toml:"metricsenabled" mapstructure:"metricsenabled"`
	MetricsPath           string   `toml:"metrics_path" mapstructure:"metrics_path"`
	PidFile               string   `toml:"pid_file" mapstructure:"pid_file"`
	MaxUploadSize         string   `toml:"max_upload_size" mapstructure:"max_upload_size"` // human-readable size (see parseSize)
	MaxHeaderBytes        int      `toml:"max_header_bytes" mapstructure:"max_header_bytes"`
	CleanupInterval       string   `toml:"cleanup_interval" mapstructure:"cleanup_interval"`
	MaxFileAge            string   `toml:"max_file_age" mapstructure:"max_file_age"`
	PreCache              bool     `toml:"pre_cache" mapstructure:"pre_cache"`
	PreCacheWorkers       int      `toml:"pre_cache_workers" mapstructure:"pre_cache_workers"`
	PreCacheInterval      string   `toml:"pre_cache_interval" mapstructure:"pre_cache_interval"`
	GlobalExtensions      []string `toml:"global_extensions" mapstructure:"global_extensions"`
	DeduplicationEnabled  bool     `toml:"deduplication_enabled" mapstructure:"deduplication_enabled"`
	MinFreeBytes          string   `toml:"min_free_bytes" mapstructure:"min_free_bytes"` // human-readable size
	FileNaming            string   `toml:"file_naming" mapstructure:"file_naming"`       // "HMAC" or "None"
	ForceProtocol         string   `toml:"force_protocol" mapstructure:"force_protocol"` // "ipv4", "ipv6" or "auto"
	EnableDynamicWorkers  bool     `toml:"enable_dynamic_workers" mapstructure:"enable_dynamic_workers"`
	WorkerScaleUpThresh   int      `toml:"worker_scale_up_thresh" mapstructure:"worker_scale_up_thresh"`
	WorkerScaleDownThresh int      `toml:"worker_scale_down_thresh" mapstructure:"worker_scale_down_thresh"`
	UnixSocket            bool     `toml:"unixsocket" mapstructure:"unixsocket"`
	MetricsPort           string   `toml:"metricsport" mapstructure:"metricsport"`
	FileTTL               string   `toml:"filettl" mapstructure:"filettl"`
	FileTTLEnabled        bool     `toml:"filettlenabled" mapstructure:"filettlenabled"`
	AutoAdjustWorkers     bool     `toml:"autoadjustworkers" mapstructure:"autoadjustworkers"`
	NetworkEvents         bool     `toml:"networkevents" mapstructure:"networkevents"`
	PIDFilePath           string   `toml:"pidfilepath" mapstructure:"pidfilepath"`
	CleanUponExit         bool     `toml:"clean_upon_exit" mapstructure:"clean_upon_exit"`
	PreCaching            bool     `toml:"precaching" mapstructure:"precaching"`
	BindIP                string   `toml:"bind_ip" mapstructure:"bind_ip"`
}
// UploadsConfig holds the [uploads] section of config.toml.
// NOTE(review): a duplicated legacy field set (CamelCase tags, int64
// ChunkSize) was removed — Go forbids duplicate struct fields (diff residue).
// ChunkSize is a human-readable size string (see parseSize).
type UploadsConfig struct {
	AllowedExtensions       []string `toml:"allowedextensions" mapstructure:"allowedextensions"`
	ChunkedUploadsEnabled   bool     `toml:"chunkeduploadsenabled" mapstructure:"chunkeduploadsenabled"`
	ChunkSize               string   `toml:"chunksize" mapstructure:"chunksize"`
	ResumableUploadsEnabled bool     `toml:"resumableuploadsenabled" mapstructure:"resumableuploadsenabled"`
	MaxResumableAge         string   `toml:"max_resumable_age" mapstructure:"max_resumable_age"`
}
// DownloadsConfig holds the [downloads] section of config.toml.
// ChunkSize is a human-readable size string (see parseSize).
type DownloadsConfig struct {
	AllowedExtensions         []string `toml:"allowedextensions" mapstructure:"allowedextensions"`
	ChunkedDownloadsEnabled   bool     `toml:"chunkeddownloadsenabled" mapstructure:"chunkeddownloadsenabled"`
	ChunkSize                 string   `toml:"chunksize" mapstructure:"chunksize"`
	ResumableDownloadsEnabled bool     `toml:"resumable_downloads_enabled" mapstructure:"resumable_downloads_enabled"`
}
// SecurityConfig holds the [security] section: the HMAC shared secret plus
// optional JWT settings (enable flag, secret, algorithm, expiration string).
type SecurityConfig struct {
	Secret        string `toml:"secret" mapstructure:"secret"`
	EnableJWT     bool   `toml:"enablejwt" mapstructure:"enablejwt"`
	JWTSecret     string `toml:"jwtsecret" mapstructure:"jwtsecret"`
	JWTAlgorithm  string `toml:"jwtalgorithm" mapstructure:"jwtalgorithm"`
	JWTExpiration string `toml:"jwtexpiration" mapstructure:"jwtexpiration"`
}
// LoggingConfig holds the [logging] section: level, log file path, and
// rotation-related settings. NOTE(review): units for MaxSize/MaxAge
// (presumably MB/days, lumberjack-style) are not evident from this file —
// confirm against the rotation library in use.
type LoggingConfig struct {
	Level      string `mapstructure:"level"`
	File       string `mapstructure:"file"`
	MaxSize    int    `mapstructure:"max_size"`
	MaxBackups int    `mapstructure:"max_backups"`
	MaxAge     int    `mapstructure:"max_age"`
	Compress   bool   `mapstructure:"compress"`
}
// DeduplicationConfig holds the [deduplication] section: whether content
// deduplication is enabled and the directory backing it.
type DeduplicationConfig struct {
	Enabled   bool   `mapstructure:"enabled"`
	Directory string `mapstructure:"directory"`
}
// ISOConfig holds the [iso] section for the optional loopback ISO-backed
// storage container (see createAndMountISO): container file location, size,
// mount point, and the mkfs input charset.
type ISOConfig struct {
	Enabled       bool   `mapstructure:"enabled"`
	MountPoint    string `mapstructure:"mountpoint"`
	Size          string `mapstructure:"size"`
	Charset       string `mapstructure:"charset"`
	ContainerFile string `mapstructure:"containerfile"`
}
// TimeoutConfig holds the [timeouts] section. Read/Write/Idle are Go duration
// strings parsed with time.ParseDuration (e.g. "4800s").
type TimeoutConfig struct {
	Read     string `mapstructure:"readtimeout" toml:"readtimeout"`
	Write    string `mapstructure:"writetimeout" toml:"writetimeout"`
	Idle     string `mapstructure:"idletimeout" toml:"idletimeout"`
	Shutdown string `mapstructure:"shutdown" toml:"shutdown"` // NOTE(review): consumer not visible here — confirm usage
}
// VersioningConfig holds the [versioning] section: whether file versioning is
// enabled, the backend identifier, and the maximum number of revisions kept.
// Tags match the example config keys (enableversioning/maxversions).
type VersioningConfig struct {
	Enabled bool   `mapstructure:"enableversioning" toml:"enableversioning"`
	Backend string `mapstructure:"backend" toml:"backend"`
	MaxRevs int    `mapstructure:"maxversions" toml:"maxversions"`
}
// ClamAVConfig holds the [clamav] section: whether scanning is enabled, the
// clamd socket path, the scan worker count, and which extensions to scan.
// NOTE(review): a duplicated field set with legacy CamelCase tags was removed
// — Go forbids duplicate struct fields (diff residue).
type ClamAVConfig struct {
	ClamAVEnabled      bool     `mapstructure:"clamavenabled"`
	ClamAVSocket       string   `mapstructure:"clamavsocket"`
	NumScanWorkers     int      `mapstructure:"numscanworkers"`
	ScanFileExtensions []string `mapstructure:"scanfileextensions"`
}
// RedisConfig holds the [redis] section: enable flag, address, credentials,
// DB index, and the health-check interval (duration string).
// NOTE(review): a duplicated field set with legacy CamelCase tags was removed
// — Go forbids duplicate struct fields (diff residue).
type RedisConfig struct {
	RedisEnabled             bool   `mapstructure:"redisenabled"`
	RedisDBIndex             int    `mapstructure:"redisdbindex"`
	RedisAddr                string `mapstructure:"redisaddr"`
	RedisPassword            string `mapstructure:"redispassword"`
	RedisHealthCheckInterval string `mapstructure:"redishealthcheckinterval"`
}
// WorkersConfig holds the [workers] section: worker count and the capacity of
// the upload queue.
// NOTE(review): a duplicated field set with legacy CamelCase tags was removed
// — Go forbids duplicate struct fields (diff residue).
type WorkersConfig struct {
	NumWorkers      int `mapstructure:"numworkers"`
	UploadQueueSize int `mapstructure:"uploadqueuesize"`
}
// FileConfig holds the [file] section.
type FileConfig struct {
	FileRevision int `mapstructure:"FileRevision"` // NOTE(review): purpose not evident from this file — confirm with consumers
}
// BuildConfig holds the [build] section; when set, Version overrides the
// reported server version string (see main).
type BuildConfig struct {
	Version string `mapstructure:"version"`
}
// Config is the top-level configuration, one field per TOML section.
// NOTE(review): nine sections were declared twice in this struct (diff
// residue); the duplicates were removed — Go forbids duplicate struct fields.
type Config struct {
	Server        ServerConfig        `mapstructure:"server"`
	Logging       LoggingConfig       `mapstructure:"logging"`
	Deduplication DeduplicationConfig `mapstructure:"deduplication"`
	ISO           ISOConfig           `mapstructure:"iso"`
	Timeouts      TimeoutConfig       `mapstructure:"timeouts"`
	Security      SecurityConfig      `mapstructure:"security"`
	Versioning    VersioningConfig    `mapstructure:"versioning"`
	Uploads       UploadsConfig       `mapstructure:"uploads"`
	Downloads     DownloadsConfig     `mapstructure:"downloads"`
	ClamAV        ClamAVConfig        `mapstructure:"clamav"`
	Redis         RedisConfig         `mapstructure:"redis"`
	Workers       WorkersConfig       `mapstructure:"workers"`
	File          FileConfig          `mapstructure:"file"`
	Build         BuildConfig         `mapstructure:"build"`
}
// UploadTask represents a single file-upload unit of work: the absolute
// destination filename, the originating HTTP request, and the channel on
// which the processing result (nil on success) is delivered.
type UploadTask struct {
	AbsFilename string
	Request     *http.Request
	Result      chan error
}
// ScanTask represents a ClamAV scan request for the file at AbsFilename; the
// outcome (nil when clean) is sent on Result.
type ScanTask struct {
	AbsFilename string
	Result      chan error
}
// NetworkEvent represents a network-related event: Type categorizes the
// event, Details carries free-form descriptive text.
type NetworkEvent struct {
	Type    string
	Details string
}
var (
conf Config
versionString string = "v2.0-dev"
log = logrus.New()
uploadQueue chan UploadTask
networkEvents chan NetworkEvent
fileInfoCache *cache.Cache
clamClient *clamd.Clamd // Added for ClamAV integration
redisClient *redis.Client // Redis client
redisConnected bool // Redis connection status
mu sync.RWMutex
// Add a new field to store the creation date of files
type FileMetadata struct {
CreationDate time.Time
}
// processScan performs the ClamAV scan for one task, bounded by the global
// concurrency semaphore, and returns the scan error (nil when the file is
// clean).
func processScan(task ScanTask) error {
	log.Infof("Started processing scan for file: %s", task.AbsFilename)

	// Acquire a semaphore slot so no more than maxConcurrentOperations
	// scans run at once; release it when done.
	semaphore <- struct{}{}
	defer func() { <-semaphore }()

	if err := scanFileWithClamAV(task.AbsFilename); err != nil {
		log.WithFields(logrus.Fields{"file": task.AbsFilename, "error": err}).Error("Failed to scan file")
		return err
	}

	log.Infof("Finished processing scan for file: %s", task.AbsFilename)
	return nil
}
var (
conf Config
versionString string
log = logrus.New()
fileInfoCache *cache.Cache
fileMetadataCache *cache.Cache
clamClient *clamd.Clamd
redisClient *redis.Client
redisConnected bool
confMutex sync.RWMutex // Protects the global 'conf' variable and related critical sections.
// Use RLock() for reading, Lock() for writing.
// Prometheus metrics
uploadDuration prometheus.Histogram
uploadErrorsTotal prometheus.Counter
uploadsTotal prometheus.Counter
@@ -156,1768 +301,1692 @@ var (
uploadSizeBytes prometheus.Histogram
downloadSizeBytes prometheus.Histogram
// Constants for worker pool
MinWorkers = 5 // Increased from 10 to 20 for better concurrency
UploadQueueSize = 10000 // Increased from 5000 to 10000
filesDeduplicatedTotal prometheus.Counter
deduplicationErrorsTotal prometheus.Counter
isoContainersCreatedTotal prometheus.Counter
isoCreationErrorsTotal prometheus.Counter
isoContainersMountedTotal prometheus.Counter
isoMountErrorsTotal prometheus.Counter
// Channels
scanQueue chan ScanTask
ScanWorkers = 5 // Number of ClamAV scan workers
workerPool *WorkerPool
networkEvents chan NetworkEvent
workerAdjustmentsTotal prometheus.Counter
workerReAdjustmentsTotal prometheus.Counter
)
// NOTE(review): an unterminated, stale duplicate of main() (diff residue —
// only "setDefaults()" of its body survived) was removed here. The canonical
// main() entry point is defined further below.
// bufferPool recycles 32 KiB scratch buffers (stored as *[]byte to avoid an
// extra allocation per Get/Put) used by copyWithProgress.
var bufferPool = sync.Pool{
	New: func() interface{} {
		buf := make([]byte, 32*1024)
		return &buf
	},
}
// maxConcurrentOperations bounds simultaneous scan operations (see processScan).
const maxConcurrentOperations = 10

// semaphore enforces maxConcurrentOperations; a send acquires a slot, a
// receive releases it.
var semaphore = make(chan struct{}, maxConcurrentOperations)

// logMessages buffers messages until flushLogMessages drains them; logMu
// guards both the buffer and the flush.
var logMessages []string
var logMu sync.Mutex
// flushLogMessages drains the buffered log messages into the logger and
// resets the buffer, all under logMu.
func flushLogMessages() {
	logMu.Lock()
	defer logMu.Unlock()
	for _, message := range logMessages {
		log.Info(message)
	}
	logMessages = []string{}
}
// writePIDFile records the current process ID at pidPath (mode 0644) and
// returns any write error after logging it.
func writePIDFile(pidPath string) error {
	pid := os.Getpid()
	if err := os.WriteFile(pidPath, []byte(strconv.Itoa(pid)), 0644); err != nil {
		log.Errorf("Failed to write PID file: %v", err)
		return err
	}
	log.Infof("PID %d written to %s", pid, pidPath)
	return nil
}
// removePIDFile deletes the PID file at pidPath, logging the outcome either way.
func removePIDFile(pidPath string) {
	if err := os.Remove(pidPath); err != nil {
		log.Errorf("Failed to remove PID file: %v", err)
		return
	}
	log.Infof("PID file %s removed successfully", pidPath)
}
// createAndMountISO provisions a loopback ISO container: it allocates a file
// of the requested size at conf.ISO.ContainerFile (via dd), formats it as
// iso9660 with the given input charset (via mkfs), and mounts it loopback at
// mountpoint. On success it increments the created/mounted counters.
// NOTE(review): shells out to dd/mkfs/mount — requires those binaries and
// (for mount) root privileges; confirm the deployment environment.
func createAndMountISO(size, mountpoint, charset string) error {
	isoPath := conf.ISO.ContainerFile

	// Allocate an empty container file of the requested size.
	cmd := exec.Command("dd", "if=/dev/zero", fmt.Sprintf("of=%s", isoPath), fmt.Sprintf("bs=%s", size), "count=1")
	if err := cmd.Run(); err != nil {
		isoCreationErrorsTotal.Inc()
		return fmt.Errorf("failed to create ISO file: %w", err)
	}

	// Format the container with an iso9660 filesystem.
	cmd = exec.Command("mkfs", "-t", "iso9660", "-input-charset", charset, isoPath)
	if err := cmd.Run(); err != nil {
		// Fix: count format failures as creation errors too — previously only
		// the dd step incremented this counter, under-reporting failures.
		isoCreationErrorsTotal.Inc()
		return fmt.Errorf("failed to format ISO file: %w", err)
	}

	// Create the mount point directory if it doesn't exist.
	if err := os.MkdirAll(mountpoint, os.ModePerm); err != nil {
		return fmt.Errorf("failed to create mount point: %w", err)
	}

	// Mount the ISO file loopback.
	cmd = exec.Command("mount", "-o", "loop", isoPath, mountpoint)
	if err := cmd.Run(); err != nil {
		isoMountErrorsTotal.Inc()
		return fmt.Errorf("failed to mount ISO file: %w", err)
	}

	isoContainersCreatedTotal.Inc()
	isoContainersMountedTotal.Inc()
	return nil
}
func initializeNetworkProtocol(forceProtocol string) (*net.Dialer, error) {
switch forceProtocol {
case "ipv4":
return &net.Dialer{
Timeout: 5 * time.Second,
DualStack: false,
Control: func(network, address string, c syscall.RawConn) error {
if network == "tcp6" {
return fmt.Errorf("IPv6 is disabled by forceprotocol setting")
}
return nil
},
}, nil
case "ipv6":
return &net.Dialer{
Timeout: 5 * time.Second,
DualStack: false,
Control: func(network, address string, c syscall.RawConn) error {
if network == "tcp4" {
return fmt.Errorf("IPv4 is disabled by forceprotocol setting")
}
return nil
},
}, nil
case "auto":
return &net.Dialer{
Timeout: 5 * time.Second,
DualStack: true,
}, nil
default:
return nil, fmt.Errorf("invalid forceprotocol value: %s", forceProtocol)
}
}
// dualStackClient is the shared HTTP client configured in main() according to
// the forceprotocol setting (see initializeNetworkProtocol).
var dualStackClient *http.Client
func main() {
setDefaults() // Call setDefaults before parsing flags or reading config
// Flags for configuration file
var configFile string
flag.StringVar(&configFile, "config", "./config.toml", "Path to configuration file \"config.toml\".")
var genConfig bool
var genConfigPath string
var validateOnly bool
var runConfigTests bool
var validateQuiet bool
var validateVerbose bool
var validateFixable bool
var validateSecurity bool
var validatePerformance bool
var validateConnectivity bool
var listValidationChecks bool
var showVersion bool
flag.BoolVar(&genConfig, "genconfig", false, "Print example configuration and exit.")
flag.StringVar(&genConfigPath, "genconfig-path", "", "Write example configuration to the given file and exit.")
flag.BoolVar(&validateOnly, "validate-config", false, "Validate configuration and exit without starting server.")
flag.BoolVar(&runConfigTests, "test-config", false, "Run configuration validation test scenarios and exit.")
flag.BoolVar(&validateQuiet, "validate-quiet", false, "Only show errors during validation (suppress warnings and info).")
flag.BoolVar(&validateVerbose, "validate-verbose", false, "Show detailed validation information including system checks.")
flag.BoolVar(&validateFixable, "check-fixable", false, "Only show validation issues that can be automatically fixed.")
flag.BoolVar(&validateSecurity, "check-security", false, "Run only security-related validation checks.")
flag.BoolVar(&validatePerformance, "check-performance", false, "Run only performance-related validation checks.")
flag.BoolVar(&validateConnectivity, "check-connectivity", false, "Run only network connectivity validation checks.")
flag.BoolVar(&listValidationChecks, "list-checks", false, "List all available validation checks and exit.")
flag.BoolVar(&showVersion, "version", false, "Show version information and exit.")
flag.Parse()
// Load configuration
if showVersion {
fmt.Printf("HMAC File Server v3.2\n")
os.Exit(0)
}
if listValidationChecks {
printValidationChecks()
os.Exit(0)
}
if genConfig {
printExampleConfig()
os.Exit(0)
}
if genConfigPath != "" {
f, err := os.Create(genConfigPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to create file: %v\n", err)
os.Exit(1)
}
defer f.Close()
w := bufio.NewWriter(f)
fmt.Fprint(w, getExampleConfigString())
w.Flush()
fmt.Printf("Example config written to %s\n", genConfigPath)
os.Exit(0)
}
if runConfigTests {
RunConfigTests()
os.Exit(0)
}
// Initialize Viper
viper.SetConfigType("toml")
// Set default config path
defaultConfigPath := "/etc/hmac-file-server/config.toml"
// Attempt to load the default config
viper.SetConfigFile(defaultConfigPath)
if err := viper.ReadInConfig(); err != nil {
// If default config not found, fallback to parent directory
parentDirConfig := "../config.toml"
viper.SetConfigFile(parentDirConfig)
if err := viper.ReadInConfig(); err != nil {
// If still not found and -config is provided, use it
if configFile != "" {
viper.SetConfigFile(configFile)
if err := viper.ReadInConfig(); err != nil {
fmt.Printf("Error loading config file: %v\n", err)
os.Exit(1)
}
} else {
fmt.Println("No configuration file found. Please create a config file with the following content:")
printExampleConfig()
os.Exit(1)
}
}
}
err := readConfig(configFile, &conf)
if err != nil {
log.Fatalf("Error reading config: %v", err) // Fatal: application cannot proceed
log.Fatalf("Failed to load configuration: %v\nPlease ensure your config.toml is present at one of the following paths:\n%v", err, []string{
"/etc/hmac-file-server/config.toml",
"../config.toml",
"./config.toml",
})
}
log.Info("Configuration loaded successfully.")
// Initialize file info cache
fileInfoCache = cache.New(5*time.Minute, 10*time.Minute)
// Create store directory
err = os.MkdirAll(conf.Server.StoragePath, os.ModePerm)
err = validateConfig(&conf)
if err != nil {
log.Fatalf("Error creating store directory: %v", err)
log.Fatalf("Configuration validation failed: %v", err)
}
log.WithField("directory", conf.Server.StoragePath).Info("Store directory is ready")
log.Info("Configuration validated successfully.")
// Perform comprehensive configuration validation
validationResult := ValidateConfigComprehensive(&conf)
PrintValidationResults(validationResult)
if validationResult.HasErrors() {
log.Fatal("Cannot start server due to configuration errors. Please fix the above issues and try again.")
}
// Handle specialized validation flags
if validateSecurity || validatePerformance || validateConnectivity || validateQuiet || validateVerbose || validateFixable {
runSpecializedValidation(&conf, validateSecurity, validatePerformance, validateConnectivity, validateQuiet, validateVerbose, validateFixable)
os.Exit(0)
}
// If only validation was requested, exit now
if validateOnly {
if validationResult.HasErrors() {
log.Error("Configuration validation failed with errors. Review the errors above.")
os.Exit(1)
} else if validationResult.HasWarnings() {
log.Info("Configuration is valid but has warnings. Review the warnings above.")
os.Exit(0)
} else {
log.Info("Configuration validation completed successfully!")
os.Exit(0)
}
}
// Set log level based on configuration
level, err := logrus.ParseLevel(conf.Logging.Level)
if err != nil {
log.Warnf("Invalid log level '%s', defaulting to 'info'", conf.Logging.Level)
level = logrus.InfoLevel
}
log.SetLevel(level)
log.Infof("Log level set to: %s", level.String())
// Log configuration settings using [logging] section
log.Infof("Server ListenAddress: %s", conf.Server.ListenAddress) // Corrected field name
log.Infof("Server UnixSocket: %v", conf.Server.UnixSocket)
log.Infof("Server StoragePath: %s", conf.Server.StoragePath)
log.Infof("Logging Level: %s", conf.Logging.Level)
log.Infof("Logging File: %s", conf.Logging.File)
log.Infof("Server MetricsEnabled: %v", conf.Server.MetricsEnabled)
log.Infof("Server MetricsPort: %s", conf.Server.MetricsPort) // Corrected field name
log.Infof("Server FileTTL: %s", conf.Server.FileTTL) // Corrected field name
log.Infof("Server MinFreeBytes: %s", conf.Server.MinFreeBytes)
log.Infof("Server AutoAdjustWorkers: %v", conf.Server.AutoAdjustWorkers) // Corrected field name
log.Infof("Server NetworkEvents: %v", conf.Server.NetworkEvents) // Corrected field name
log.Infof("Server PIDFilePath: %s", conf.Server.PIDFilePath) // Corrected field name
log.Infof("Server CleanUponExit: %v", conf.Server.CleanUponExit) // Corrected field name
log.Infof("Server PreCaching: %v", conf.Server.PreCaching) // Corrected field name
log.Infof("Server FileTTLEnabled: %v", conf.Server.FileTTLEnabled) // Corrected field name
log.Infof("Server DeduplicationEnabled: %v", conf.Server.DeduplicationEnabled)
log.Infof("Server BindIP: %s", conf.Server.BindIP) // Corrected field name
log.Infof("Server FileNaming: %s", conf.Server.FileNaming)
log.Infof("Server ForceProtocol: %s", conf.Server.ForceProtocol)
err = writePIDFile(conf.Server.PIDFilePath) // Corrected field name
if err != nil {
log.Fatalf("Error writing PID file: %v", err)
}
log.Debug("DEBUG: PID file written successfully")
log.Debugf("DEBUG: Config logging file: %s", conf.Logging.File)
// Setup logging
setupLogging()
log.Debug("DEBUG: Logging setup completed")
// Log system information
logSystemInfo()
log.Debug("DEBUG: System info logged")
// Initialize Prometheus metrics
// Initialize metrics before using any Prometheus counters
initMetrics()
log.Info("Prometheus metrics initialized.")
log.Debug("DEBUG: Metrics initialized")
// Initialize upload and scan queues
uploadQueue = make(chan UploadTask, conf.Workers.UploadQueueSize)
scanQueue = make(chan ScanTask, conf.Workers.UploadQueueSize)
networkEvents = make(chan NetworkEvent, 100)
log.Info("Upload, scan, and network event channels initialized.")
initializeWorkerSettings(&conf.Server, &conf.Workers, &conf.ClamAV)
log.Debug("DEBUG: Worker settings initialized")
// Context for goroutines
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start network monitoring
go monitorNetwork(ctx)
go handleNetworkEvents(ctx)
// Update system metrics
go updateSystemMetrics(ctx)
// Initialize ClamAV client if enabled
if conf.ClamAV.ClamAVEnabled {
clamClient, err = initClamAV(conf.ClamAV.ClamAVSocket)
if conf.ISO.Enabled {
err := createAndMountISO(conf.ISO.Size, conf.ISO.MountPoint, conf.ISO.Charset)
if err != nil {
log.WithFields(logrus.Fields{
"error": err.Error(),
}).Warn("ClamAV client initialization failed. Continuing without ClamAV.")
} else {
log.Info("ClamAV client initialized successfully.")
log.Fatalf("Failed to create and mount ISO container: %v", err)
}
log.Infof("ISO container mounted at %s", conf.ISO.MountPoint)
}
// Initialize Redis client if enabled
if conf.Redis.RedisEnabled {
initRedis()
// Set storage path to ISO mount point if ISO is enabled
storagePath := conf.Server.StoragePath
if conf.ISO.Enabled {
storagePath = conf.ISO.MountPoint
}
// Redis Initialization
initRedis()
log.Info("Redis client initialized and connected successfully.")
fileInfoCache = cache.New(5*time.Minute, 10*time.Minute)
fileMetadataCache = cache.New(5*time.Minute, 10*time.Minute)
// ClamAV Initialization
if conf.ClamAV.ClamAVEnabled {
clamClient, err = initClamAV(conf.ClamAV.ClamAVSocket)
if err != nil {
log.WithFields(logrus.Fields{
"error": err.Error(),
}).Warn("ClamAV client initialization failed. Continuing without ClamAV.")
} else {
log.Info("ClamAV client initialized successfully.")
}
}
// Initialize worker pools
initializeUploadWorkerPool(ctx)
if conf.ClamAV.ClamAVEnabled && clamClient != nil {
initializeScanWorkerPool(ctx)
}
// Start Redis health monitor if Redis is enabled
if conf.Redis.RedisEnabled && redisClient != nil {
go MonitorRedisHealth(ctx, redisClient, parseDuration(conf.Redis.RedisHealthCheckInterval))
}
// Setup router
router := setupRouter()
// Start file cleaner
fileTTL, err := time.ParseDuration(conf.Server.FileTTL)
if err != nil {
log.Fatalf("Invalid FileTTL: %v", err)
}
go runFileCleaner(ctx, conf.Server.StoragePath, fileTTL)
// Parse timeout durations
readTimeout, err := time.ParseDuration(conf.Timeouts.ReadTimeout)
if err != nil {
log.Fatalf("Invalid ReadTimeout: %v", err)
}
writeTimeout, err := time.ParseDuration(conf.Timeouts.WriteTimeout)
if err != nil {
log.Fatalf("Invalid WriteTimeout: %v", err)
}
idleTimeout, err := time.ParseDuration(conf.Timeouts.IdleTimeout)
if err != nil {
log.Fatalf("Invalid IdleTimeout: %v", err)
}
// Configure HTTP server
server := &http.Server{
Addr: ":" + conf.Server.ListenPort, // Prepend colon to ListenPort
Handler: router,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
IdleTimeout: idleTimeout,
}
// Start metrics server if enabled
if conf.Server.MetricsEnabled {
if conf.Server.PreCaching { // Corrected field name
go func() {
http.Handle("/metrics", promhttp.Handler())
log.Infof("Metrics server started on port %s", conf.Server.MetricsPort)
if err := http.ListenAndServe(":"+conf.Server.MetricsPort, nil); err != nil {
log.Fatalf("Metrics server failed: %v", err)
log.Info("Starting pre-caching of storage path...")
// Use helper function
err := precacheStoragePath(storagePath)
if err != nil {
log.Warnf("Pre-caching storage path failed: %v", err)
} else {
log.Info("Pre-cached all files in the storage path.")
log.Info("Pre-caching status: complete.")
}
}()
}
// Setup graceful shutdown
setupGracefulShutdown(server, cancel)
err = os.MkdirAll(storagePath, os.ModePerm)
if err != nil {
log.Fatalf("Error creating store directory: %v", err)
}
log.WithField("directory", storagePath).Info("Store directory is ready")
// Use helper function
err = checkFreeSpaceWithRetry(storagePath, 3, 5*time.Second)
if err != nil {
log.Fatalf("Insufficient free space: %v", err)
}
initializeWorkerSettings(&conf.Server, &conf.Workers, &conf.ClamAV)
log.Info("Prometheus metrics initialized.")
networkEvents = make(chan NetworkEvent, 100)
log.Info("Network event channel initialized.")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if conf.Server.NetworkEvents { // Corrected field name
go monitorNetwork(ctx) // Assuming monitorNetwork is defined in helpers.go or elsewhere
go handleNetworkEvents(ctx) // Assuming handleNetworkEvents is defined in helpers.go or elsewhere
}
go updateSystemMetrics(ctx)
if conf.ClamAV.ClamAVEnabled {
var clamErr error
clamClient, clamErr = initClamAV(conf.ClamAV.ClamAVSocket) // Assuming initClamAV is defined in helpers.go or elsewhere
if clamErr != nil {
log.WithError(clamErr).Warn("ClamAV client initialization failed. Continuing without ClamAV.")
} else {
log.Info("ClamAV client initialized successfully.")
}
}
if conf.Redis.RedisEnabled {
initRedis() // Assuming initRedis is defined in helpers.go or elsewhere
}
router := setupRouter() // Assuming setupRouter is defined (likely in this file or router.go
go handleFileCleanup(&conf) // Directly call handleFileCleanup
readTimeout, err := time.ParseDuration(conf.Timeouts.Read) // Corrected field name
if err != nil {
log.Fatalf("Invalid ReadTimeout: %v", err)
}
writeTimeout, err := time.ParseDuration(conf.Timeouts.Write) // Corrected field name
if err != nil {
log.Fatalf("Invalid WriteTimeout: %v", err)
}
idleTimeout, err := time.ParseDuration(conf.Timeouts.Idle) // Corrected field name
if err != nil {
log.Fatalf("Invalid IdleTimeout: %v", err)
}
// Initialize network protocol based on forceprotocol setting
dialer, err := initializeNetworkProtocol(conf.Server.ForceProtocol)
if err != nil {
log.Fatalf("Failed to initialize network protocol: %v", err)
}
// Enhanced dual-stack HTTP client for robust IPv4/IPv6 and resource management
// See: https://pkg.go.dev/net/http#Transport for details on these settings
dualStackClient = &http.Client{
Transport: &http.Transport{
DialContext: dialer.DialContext,
IdleConnTimeout: 90 * time.Second, // Close idle connections after 90s
MaxIdleConns: 100, // Max idle connections across all hosts
MaxIdleConnsPerHost: 10, // Max idle connections per host
TLSHandshakeTimeout: 10 * time.Second, // Timeout for TLS handshake
ResponseHeaderTimeout: 15 * time.Second, // Timeout for reading response headers
},
}
server := &http.Server{
Addr: conf.Server.BindIP + ":" + conf.Server.ListenAddress, // Use BindIP + ListenAddress (port)
Handler: router,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
IdleTimeout: idleTimeout,
MaxHeaderBytes: 1 << 20, // 1 MB
}
if conf.Server.MetricsEnabled {
var wg sync.WaitGroup
go func() {
http.Handle("/metrics", promhttp.Handler())
log.Infof("Metrics server started on port %s", conf.Server.MetricsPort) // Corrected field name
if err := http.ListenAndServe(":"+conf.Server.MetricsPort, nil); err != nil { // Corrected field name
log.Fatalf("Metrics server failed: %v", err)
}
wg.Wait()
}()
}
setupGracefulShutdown(server, cancel) // Assuming setupGracefulShutdown is defined
if conf.Server.AutoAdjustWorkers { // Corrected field name
go monitorWorkerPerformance(ctx, &conf.Server, &conf.Workers, &conf.ClamAV)
}
versionString = "3.2" // Set a default version for now
if conf.Build.Version != "" {
versionString = conf.Build.Version
}
log.Infof("Running version: %s", versionString)
// Start server
log.Infof("Starting HMAC file server %s...", versionString)
if conf.Server.UnixSocket {
// Listen on Unix socket
if err := os.RemoveAll(conf.Server.ListenPort); err != nil {
socketPath := "/tmp/hmac-file-server.sock" // Use a default socket path since ListenAddress is now a port
if err := os.RemoveAll(socketPath); err != nil {
log.Fatalf("Failed to remove existing Unix socket: %v", err)
}
listener, err := net.Listen("unix", conf.Server.ListenPort)
listener, err := net.Listen("unix", socketPath)
if err != nil {
log.Fatalf("Failed to listen on Unix socket %s: %v", conf.Server.ListenPort, err)
log.Fatalf("Failed to listen on Unix socket %s: %v", socketPath, err)
}
defer listener.Close()
log.Infof("Server listening on Unix socket: %s", socketPath)
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server failed: %v", err)
}
} else {
// Listen on TCP port
if conf.Server.BindIP == "0.0.0.0" {
log.Info("Binding to 0.0.0.0. Any net/http logs you see are normal for this universal address.")
}
log.Infof("Server listening on %s", server.Addr)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server failed: %v", err)
}
}
// Start file cleanup in a separate goroutine
// Use helper function
go handleFileCleanup(&conf)
}
// Function to load configuration using Viper
func readConfig(configFilename string, conf *Config) error {
viper.SetConfigFile(configFilename)
viper.SetConfigType("toml")
func printExampleConfig() {
fmt.Print(`
[server]
bind_ip = "0.0.0.0"
listenport = "8080"
unixsocket = false
storagepath = "./uploads"
logfile = "/var/log/hmac-file-server.log"
metricsenabled = true
metricsport = "9090"
minfreebytes = "100MB"
filettl = "8760h"
filettlenabled = true
autoadjustworkers = true
networkevents = true
pidfilepath = "/var/run/hmacfileserver.pid"
cleanuponexit = true
precaching = true
deduplicationenabled = true
globalextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
# FileNaming options: "HMAC", "None"
filenaming = "HMAC"
forceprotocol = "auto"
// Read in environment variables that match
viper.AutomaticEnv()
viper.SetEnvPrefix("HMAC") // Prefix for environment variables
[logging]
level = "info"
file = "/var/log/hmac-file-server.log"
max_size = 100
max_backups = 7
max_age = 30
compress = true
// Read the config file
if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("error reading config file: %w", err)
}
[deduplication]
enabled = true
directory = "./deduplication"
// Unmarshal the config into the Config struct
if err := viper.Unmarshal(conf); err != nil {
return fmt.Errorf("unable to decode into struct: %w", err)
}
[iso]
enabled = true
size = "1GB"
mountpoint = "/mnt/iso"
charset = "utf-8"
containerfile = "/mnt/iso/container.iso"
// Debug log the loaded configuration
log.Debugf("Loaded Configuration: %+v", conf.Server)
[timeouts]
readtimeout = "4800s"
writetimeout = "4800s"
idletimeout = "4800s"
// Validate the configuration
if err := validateConfig(conf); err != nil {
return fmt.Errorf("configuration validation failed: %w", err)
}
[security]
secret = "changeme"
enablejwt = false
jwtsecret = "anothersecretkey"
jwtalgorithm = "HS256"
jwtexpiration = "24h"
// Set Deduplication Enabled
conf.Server.DeduplicationEnabled = viper.GetBool("deduplication.Enabled")
[versioning]
enableversioning = false
maxversions = 1
return nil
[uploads]
resumableuploadsenabled = true
chunkeduploadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[downloads]
resumabledownloadsenabled = true
chunkeddownloadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[clamav]
clamavenabled = true
clamavsocket = "/var/run/clamav/clamd.ctl"
numscanworkers = 2
scanfileextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[redis]
redisenabled = true
redisdbindex = 0
redisaddr = "localhost:6379"
redispassword = ""
redishealthcheckinterval = "120s"
[workers]
numworkers = 4
uploadqueuesize = 50
[file]
# Add file-specific configurations here
[build]
version = "3.2"
`)
}
// Set default configuration values
func setDefaults() {
// Server defaults
viper.SetDefault("server.ListenPort", "8080")
viper.SetDefault("server.UnixSocket", false)
viper.SetDefault("server.StoragePath", "./uploads")
viper.SetDefault("server.LogLevel", "info")
viper.SetDefault("server.LogFile", "")
viper.SetDefault("server.MetricsEnabled", true)
viper.SetDefault("server.MetricsPort", "9090")
viper.SetDefault("server.FileTTL", "8760h") // 365d -> 8760h
viper.SetDefault("server.MinFreeBytes", 100<<20) // 100 MB
func getExampleConfigString() string {
return `[server]
listen_address = ":8080"
storage_path = "/srv/hmac-file-server/uploads"
metrics_enabled = true
metrics_path = "/metrics"
pid_file = "/var/run/hmac-file-server.pid"
max_upload_size = "10GB" # Supports B, KB, MB, GB, TB
max_header_bytes = 1048576 # 1MB
cleanup_interval = "24h"
max_file_age = "720h" # 30 days
pre_cache = true
pre_cache_workers = 4
pre_cache_interval = "1h"
global_extensions = [".txt", ".dat", ".iso"] # If set, overrides upload/download extensions
deduplication_enabled = true
min_free_bytes = "1GB" # Minimum free space required for uploads
file_naming = "original" # Options: "original", "HMAC"
force_protocol = "" # Options: "http", "https" - if set, redirects to this protocol
enable_dynamic_workers = true # Enable dynamic worker scaling
worker_scale_up_thresh = 50 # Queue length to scale up workers
worker_scale_down_thresh = 10 # Queue length to scale down workers
// Timeout defaults
viper.SetDefault("timeouts.ReadTimeout", "4800s") // supports 's'
viper.SetDefault("timeouts.WriteTimeout", "4800s")
viper.SetDefault("timeouts.IdleTimeout", "4800s")
[uploads]
allowed_extensions = [".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"]
chunked_uploads_enabled = true
chunk_size = "10MB"
resumable_uploads_enabled = true
max_resumable_age = "48h"
// Security defaults
viper.SetDefault("security.Secret", "changeme")
[downloads]
allowed_extensions = [".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"]
chunked_downloads_enabled = true
chunk_size = "10MB"
resumable_downloads_enabled = true
// Versioning defaults
viper.SetDefault("versioning.EnableVersioning", false)
viper.SetDefault("versioning.MaxVersions", 1)
[security]
secret = "your-very-secret-hmac-key"
enablejwt = false
jwtsecret = "anothersecretkey"
jwtalgorithm = "HS256"
jwtexpiration = "24h"
// Uploads defaults
viper.SetDefault("uploads.ResumableUploadsEnabled", true)
viper.SetDefault("uploads.ChunkedUploadsEnabled", true)
viper.SetDefault("uploads.ChunkSize", 8192)
viper.SetDefault("uploads.AllowedExtensions", []string{
".txt", ".pdf",
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp",
".wav", ".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".mpg", ".m4v", ".3gp", ".3g2",
".mp3", ".ogg",
})
[logging]
level = "info"
file = "/var/log/hmac-file-server.log"
max_size = 100
max_backups = 7
max_age = 30
compress = true
// ClamAV defaults
viper.SetDefault("clamav.ClamAVEnabled", true)
viper.SetDefault("clamav.ClamAVSocket", "/var/run/clamav/clamd.ctl")
viper.SetDefault("clamav.NumScanWorkers", 2)
[deduplication]
enabled = true
directory = "./deduplication"
// Redis defaults
viper.SetDefault("redis.RedisEnabled", true)
viper.SetDefault("redis.RedisAddr", "localhost:6379")
viper.SetDefault("redis.RedisPassword", "")
viper.SetDefault("redis.RedisDBIndex", 0)
viper.SetDefault("redis.RedisHealthCheckInterval", "120s")
[iso]
enabled = true
size = "1GB"
mountpoint = "/mnt/iso"
charset = "utf-8"
containerfile = "/mnt/iso/container.iso"
// Workers defaults
viper.SetDefault("workers.NumWorkers", 2)
viper.SetDefault("workers.UploadQueueSize", 50)
[timeouts]
readtimeout = "4800s"
writetimeout = "4800s"
idletimeout = "4800s"
// Deduplication defaults
viper.SetDefault("deduplication.Enabled", true)
[security]
secret = "changeme"
enablejwt = false
jwtsecret = "anothersecretkey"
jwtalgorithm = "HS256"
jwtexpiration = "24h"
[versioning]
enableversioning = false
maxversions = 1
[uploads]
resumableuploadsenabled = true
chunkeduploadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[downloads]
resumabledownloadsenabled = true
chunkeddownloadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[clamav]
clamavenabled = true
clamavsocket = "/var/run/clamav/clamd.ctl"
numscanworkers = 2
scanfileextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[redis]
redisenabled = true
redisdbindex = 0
redisaddr = "localhost:6379"
redispassword = ""
redishealthcheckinterval = "120s"
[workers]
numworkers = 4
uploadqueuesize = 50
[file]
# Add file-specific configurations here
[build]
version = "3.2"
`
}
// Validate configuration fields
func validateConfig(conf *Config) error {
if conf.Server.ListenPort == "" {
return fmt.Errorf("ListenPort must be set")
func max(a, b int) int {
if a > b {
return a
}
if conf.Security.Secret == "" {
return fmt.Errorf("secret must be set")
}
if conf.Server.StoragePath == "" {
return fmt.Errorf("StoragePath must be set")
}
if conf.Server.FileTTL == "" {
return fmt.Errorf("FileTTL must be set")
}
// Validate timeouts
if _, err := time.ParseDuration(conf.Timeouts.ReadTimeout); err != nil {
return fmt.Errorf("invalid ReadTimeout: %v", err)
}
if _, err := time.ParseDuration(conf.Timeouts.WriteTimeout); err != nil {
return fmt.Errorf("invalid WriteTimeout: %v", err)
}
if _, err := time.ParseDuration(conf.Timeouts.IdleTimeout); err != nil {
return fmt.Errorf("invalid IdleTimeout: %v", err)
}
// Validate Redis configuration if enabled
if conf.Redis.RedisEnabled {
if conf.Redis.RedisAddr == "" {
return fmt.Errorf("RedisAddr must be set when Redis is enabled")
}
}
// Add more validations as needed
return nil
return b
}
// setupLogging configures the global logrus logger from conf.Server:
// parses the configured log level (fatal on an invalid value), mirrors
// output to a log file when one is configured, and installs a
// human-readable text formatter.
func setupLogging() {
	parsedLevel, parseErr := logrus.ParseLevel(conf.Server.LogLevel)
	if parseErr != nil {
		log.Fatalf("Invalid log level: %s", conf.Server.LogLevel)
	}
	log.SetLevel(parsedLevel)

	output := io.Writer(os.Stdout)
	if path := conf.Server.LogFile; path != "" {
		logFile, openErr := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
		if openErr != nil {
			log.Fatalf("Failed to open log file: %v", openErr)
		}
		// Write to both stdout and the file.
		output = io.MultiWriter(os.Stdout, logFile)
	}
	log.SetOutput(output)

	// Use Text formatter for human-readable logs
	log.SetFormatter(&logrus.TextFormatter{
		FullTimestamp: true,
	})
}
// Log system information
func logSystemInfo() {
log.Info("========================================")
log.Infof(" HMAC File Server - %s ", versionString)
log.Info(" Secure File Handling with HMAC Auth ")
log.Info("========================================")
log.Info("Features: Prometheus Metrics, Chunked Uploads, ClamAV Scanning")
log.Info("Build Date: 2024-10-28")
log.Infof("Operating System: %s", runtime.GOOS)
log.Infof("Architecture: %s", runtime.GOARCH)
log.Infof("Number of CPUs: %d", runtime.NumCPU())
log.Infof("Go Version: %s", runtime.Version())
func autoAdjustWorkers() (int, int) {
v, _ := mem.VirtualMemory()
log.Infof("Total Memory: %v MB", v.Total/1024/1024)
log.Infof("Free Memory: %v MB", v.Free/1024/1024)
log.Infof("Used Memory: %v MB", v.Used/1024/1024)
cpuCores, _ := cpu.Counts(true)
cpuInfo, _ := cpu.Info()
for _, info := range cpuInfo {
log.Infof("CPU Model: %s, Cores: %d, Mhz: %f", info.ModelName, info.Cores, info.Mhz)
numWorkers := cpuCores * 2
if v.Available < 4*1024*1024*1024 { // Less than 4GB available
numWorkers = max(numWorkers/2, 1)
} else if v.Available < 8*1024*1024*1024 { // Less than 8GB available
numWorkers = max(numWorkers*3/4, 1)
}
queueSize := numWorkers * 10
partitions, _ := disk.Partitions(false)
for _, partition := range partitions {
usage, _ := disk.Usage(partition.Mountpoint)
log.Infof("Disk Mountpoint: %s, Total: %v GB, Free: %v GB, Used: %v GB",
partition.Mountpoint, usage.Total/1024/1024/1024, usage.Free/1024/1024/1024, usage.Used/1024/1024/1024)
}
hInfo, _ := host.Info()
log.Infof("Hostname: %s", hInfo.Hostname)
log.Infof("Uptime: %v seconds", hInfo.Uptime)
log.Infof("Boot Time: %v", time.Unix(int64(hInfo.BootTime), 0))
log.Infof("Platform: %s", hInfo.Platform)
log.Infof("Platform Family: %s", hInfo.PlatformFamily)
log.Infof("Platform Version: %s", hInfo.PlatformVersion)
log.Infof("Kernel Version: %s", hInfo.KernelVersion)
log.Infof("Auto-adjusting workers: NumWorkers=%d, UploadQueueSize=%d", numWorkers, queueSize)
workerAdjustmentsTotal.Inc()
return numWorkers, queueSize
}
// Initialize Prometheus metrics
// Duplicate initMetrics function removed
func initMetrics() {
uploadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{Namespace: "hmac", Name: "file_server_upload_duration_seconds", Help: "Histogram of file upload duration in seconds.", Buckets: prometheus.DefBuckets})
uploadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_upload_errors_total", Help: "Total number of file upload errors."})
uploadsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_uploads_total", Help: "Total number of successful file uploads."})
downloadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{Namespace: "hmac", Name: "file_server_download_duration_seconds", Help: "Histogram of file download duration in seconds.", Buckets: prometheus.DefBuckets})
downloadsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_downloads_total", Help: "Total number of successful file downloads."})
downloadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_download_errors_total", Help: "Total number of file download errors."})
memoryUsage = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "memory_usage_bytes", Help: "Current memory usage in bytes."})
cpuUsage = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "cpu_usage_percent", Help: "Current CPU usage as a percentage."})
activeConnections = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "active_connections_total", Help: "Total number of active connections."})
requestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{Namespace: "hmac", Name: "http_requests_total", Help: "Total number of HTTP requests received, labeled by method and path."}, []string{"method", "path"})
goroutines = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "goroutines_count", Help: "Current number of goroutines."})
uploadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: "hmac",
Name: "file_server_upload_size_bytes",
Help: "Histogram of uploaded file sizes in bytes.",
Buckets: prometheus.ExponentialBuckets(100, 10, 8),
})
downloadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: "hmac",
Name: "file_server_download_size_bytes",
Help: "Histogram of downloaded file sizes in bytes.",
Buckets: prometheus.ExponentialBuckets(100, 10, 8),
})
func initializeWorkerSettings(server *ServerConfig, workers *WorkersConfig, clamav *ClamAVConfig) {
if server.AutoAdjustWorkers {
numWorkers, queueSize := autoAdjustWorkers()
workers.NumWorkers = numWorkers
workers.UploadQueueSize = queueSize
clamav.NumScanWorkers = max(numWorkers/2, 1)
if conf.Server.MetricsEnabled {
prometheus.MustRegister(uploadDuration, uploadErrorsTotal, uploadsTotal)
prometheus.MustRegister(downloadDuration, downloadsTotal, downloadErrorsTotal)
prometheus.MustRegister(memoryUsage, cpuUsage, activeConnections, requestsTotal, goroutines)
prometheus.MustRegister(uploadSizeBytes, downloadSizeBytes)
log.Infof("AutoAdjustWorkers enabled: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
workers.NumWorkers, workers.UploadQueueSize, clamav.NumScanWorkers)
} else {
log.Infof("Manual configuration in effect: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
workers.NumWorkers, workers.UploadQueueSize, clamav.NumScanWorkers)
}
}
// Update system metrics
func updateSystemMetrics(ctx context.Context) {
ticker := time.NewTicker(10 * time.Second)
func monitorWorkerPerformance(ctx context.Context, server *ServerConfig, w *WorkersConfig, clamav *ClamAVConfig) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Info("Stopping system metrics updater.")
log.Info("Stopping worker performance monitor.")
return
case <-ticker.C:
v, _ := mem.VirtualMemory()
memoryUsage.Set(float64(v.Used))
if server.AutoAdjustWorkers {
numWorkers, queueSize := autoAdjustWorkers()
w.NumWorkers = numWorkers
w.UploadQueueSize = queueSize
clamav.NumScanWorkers = max(numWorkers/2, 1)
cpuPercent, _ := cpu.Percent(0, false)
if len(cpuPercent) > 0 {
cpuUsage.Set(cpuPercent[0])
log.Infof("Re-adjusted workers: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
w.NumWorkers, w.UploadQueueSize, clamav.NumScanWorkers)
workerReAdjustmentsTotal.Inc()
}
goroutines.Set(float64(runtime.NumGoroutine()))
}
}
}
// fileExists reports whether filePath names an existing non-directory
// entry and returns its size, consulting fileInfoCache before hitting
// the filesystem. Stat results are cached with the default expiration.
func fileExists(filePath string) (bool, int64) {
	if cached, hit := fileInfoCache.Get(filePath); hit {
		if info, isInfo := cached.(os.FileInfo); isInfo {
			return !info.IsDir(), info.Size()
		}
	}

	info, statErr := os.Stat(filePath)
	switch {
	case os.IsNotExist(statErr):
		return false, 0
	case statErr != nil:
		log.Error("Error checking file existence:", statErr)
		return false, 0
	}

	fileInfoCache.Set(filePath, info, cache.DefaultExpiration)
	return !info.IsDir(), info.Size()
}
// isExtensionAllowed reports whether filename's extension appears in
// conf.Uploads.AllowedExtensions (case-insensitive). An empty allow-list
// permits every extension.
func isExtensionAllowed(filename string) bool {
	allowed := conf.Uploads.AllowedExtensions
	if len(allowed) == 0 {
		return true // No restrictions if the list is empty
	}
	ext := strings.ToLower(filepath.Ext(filename))
	for _, candidate := range allowed {
		if ext == strings.ToLower(candidate) {
			return true
		}
	}
	return false
}
// versionFile archives the current file at absFilename into a sibling
// "<name>_versions" directory under a timestamped name, then prunes old
// versions via cleanupOldVersions.
func versionFile(absFilename string) error {
	versionDir := absFilename + "_versions"
	if mkErr := os.MkdirAll(versionDir, os.ModePerm); mkErr != nil {
		return fmt.Errorf("failed to create version directory: %v", mkErr)
	}

	stamp := time.Now().Format("20060102-150405")
	versionedFilename := filepath.Join(versionDir, filepath.Base(absFilename)+"."+stamp)
	if mvErr := os.Rename(absFilename, versionedFilename); mvErr != nil {
		return fmt.Errorf("failed to version the file: %v", mvErr)
	}

	log.WithFields(logrus.Fields{
		"original":     absFilename,
		"versioned_as": versionedFilename,
	}).Info("Versioned old file")
	return cleanupOldVersions(versionDir)
}
// cleanupOldVersions removes the oldest entries in versionDir whenever the
// entry count exceeds conf.Versioning.MaxVersions; a limit of zero or less
// disables pruning. Entries sort lexicographically by name (os.ReadDir),
// which for timestamp-suffixed names puts the oldest first.
func cleanupOldVersions(versionDir string) error {
	entries, readErr := os.ReadDir(versionDir)
	if readErr != nil {
		return fmt.Errorf("failed to list version files: %v", readErr)
	}

	limit := conf.Versioning.MaxVersions
	if limit <= 0 || len(entries) <= limit {
		return nil
	}

	for _, entry := range entries[:len(entries)-limit] {
		if rmErr := os.Remove(filepath.Join(versionDir, entry.Name())); rmErr != nil {
			return fmt.Errorf("failed to remove old version: %v", rmErr)
		}
		log.WithField("file", entry.Name()).Info("Removed old version")
	}
	return nil
}
// Process the upload task
func processUpload(task UploadTask) error {
absFilename := task.AbsFilename
tempFilename := absFilename + ".tmp"
r := task.Request
log.Infof("Processing upload for file: %s", absFilename)
startTime := time.Now()
// Handle uploads and write to a temporary file
if conf.Uploads.ChunkedUploadsEnabled {
log.Debugf("Chunked uploads enabled. Handling chunked upload for %s", tempFilename)
err := handleChunkedUpload(tempFilename, r)
if err != nil {
uploadDuration.Observe(time.Since(startTime).Seconds())
log.WithFields(logrus.Fields{
"file": tempFilename,
"error": err,
}).Error("Failed to handle chunked upload")
return err
}
} else {
log.Debugf("Handling standard upload for %s", tempFilename)
err := createFile(tempFilename, r)
if err != nil {
log.WithFields(logrus.Fields{
"file": tempFilename,
"error": err,
}).Error("Error creating file")
uploadDuration.Observe(time.Since(startTime).Seconds())
return err
}
}
// Perform ClamAV scan on the temporary file
if clamClient != nil {
log.Debugf("Scanning %s with ClamAV", tempFilename)
err := scanFileWithClamAV(tempFilename)
if err != nil {
log.WithFields(logrus.Fields{
"file": tempFilename,
"error": err,
}).Warn("ClamAV detected a virus or scan failed")
os.Remove(tempFilename)
uploadErrorsTotal.Inc()
return err
}
log.Infof("ClamAV scan passed for file: %s", tempFilename)
}
// Handle file versioning if enabled
if conf.Versioning.EnableVersioning {
existing, _ := fileExists(absFilename)
if existing {
log.Infof("File %s exists. Initiating versioning.", absFilename)
err := versionFile(absFilename)
if err != nil {
log.WithFields(logrus.Fields{
"file": absFilename,
"error": err,
}).Error("Error versioning file")
os.Remove(tempFilename)
return err
}
log.Infof("File versioned successfully: %s", absFilename)
}
}
// Rename temporary file to final destination
err := os.Rename(tempFilename, absFilename)
if err != nil {
log.WithFields(logrus.Fields{
"temp_file": tempFilename,
"final_file": absFilename,
"error": err,
}).Error("Failed to move file to final destination")
os.Remove(tempFilename)
func readConfig(configFilename string, conf *Config) error {
viper.SetConfigFile(configFilename)
if err := viper.ReadInConfig(); err != nil {
log.WithError(err).Errorf("Unable to read config from %s", configFilename)
return err
}
log.Infof("File moved to final destination: %s", absFilename)
// Handle deduplication if enabled
if conf.Server.DeduplicationEnabled {
log.Debugf("Deduplication enabled. Checking duplicates for %s", absFilename)
err = handleDeduplication(context.Background(), absFilename)
if err != nil {
log.WithError(err).Error("Deduplication failed")
uploadErrorsTotal.Inc()
return err
}
log.Infof("Deduplication handled successfully for file: %s", absFilename)
if err := viper.Unmarshal(conf); err != nil {
return fmt.Errorf("unable to decode config into struct: %v", err)
}
log.WithFields(logrus.Fields{
"file": absFilename,
}).Info("File uploaded and processed successfully")
uploadDuration.Observe(time.Since(startTime).Seconds())
uploadsTotal.Inc()
return nil
}
// uploadWorker processes upload tasks from the uploadQueue
func uploadWorker(ctx context.Context, workerID int) {
log.Infof("Upload worker %d started.", workerID)
defer log.Infof("Upload worker %d stopped.", workerID)
for {
select {
case <-ctx.Done():
return
case task, ok := <-uploadQueue:
if !ok {
log.Warnf("Upload queue closed. Worker %d exiting.", workerID)
return
}
log.Infof("Worker %d processing upload for file: %s", workerID, task.AbsFilename)
err := processUpload(task)
if err != nil {
log.Errorf("Worker %d failed to process upload for %s: %v", workerID, task.AbsFilename, err)
uploadErrorsTotal.Inc()
} else {
log.Infof("Worker %d successfully processed upload for %s", workerID, task.AbsFilename)
}
task.Result <- err
close(task.Result)
// setDefaults registers a Viper default for every configuration key in the
// snake_case schema, so a partially specified config file still yields a
// complete Config.
// NOTE(review): the file also contains an older setDefaults using
// CamelCase keys — one of the two duplicates must be removed before this
// file can compile.
func setDefaults() {
	// Server defaults
	viper.SetDefault("server.listen_address", ":8080")
	viper.SetDefault("server.storage_path", "./uploads")
	viper.SetDefault("server.metrics_enabled", true)
	viper.SetDefault("server.metrics_path", "/metrics")
	viper.SetDefault("server.pid_file", "/var/run/hmac-file-server.pid")
	viper.SetDefault("server.max_upload_size", "10GB")
	viper.SetDefault("server.max_header_bytes", 1048576) // 1MB
	viper.SetDefault("server.cleanup_interval", "24h")
	viper.SetDefault("server.max_file_age", "720h") // 30 days
	viper.SetDefault("server.pre_cache", true)
	viper.SetDefault("server.pre_cache_workers", 4)
	viper.SetDefault("server.pre_cache_interval", "1h")
	viper.SetDefault("server.global_extensions", []string{})
	viper.SetDefault("server.deduplication_enabled", true)
	viper.SetDefault("server.min_free_bytes", "1GB")
	viper.SetDefault("server.file_naming", "original")
	viper.SetDefault("server.force_protocol", "")
	viper.SetDefault("server.enable_dynamic_workers", true)
	viper.SetDefault("server.worker_scale_up_thresh", 50)
	viper.SetDefault("server.worker_scale_down_thresh", 10)
	// Uploads defaults
	viper.SetDefault("uploads.allowed_extensions", []string{".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"})
	viper.SetDefault("uploads.chunked_uploads_enabled", true)
	viper.SetDefault("uploads.chunk_size", "10MB")
	viper.SetDefault("uploads.resumable_uploads_enabled", true)
	viper.SetDefault("uploads.max_resumable_age", "48h")
	// Downloads defaults
	viper.SetDefault("downloads.allowed_extensions", []string{".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"})
	viper.SetDefault("downloads.chunked_downloads_enabled", true)
	viper.SetDefault("downloads.chunk_size", "10MB")
	viper.SetDefault("downloads.resumable_downloads_enabled", true)
	// Security defaults
	viper.SetDefault("security.secret", "your-very-secret-hmac-key")
	viper.SetDefault("security.enablejwt", false)
	viper.SetDefault("security.jwtsecret", "your-256-bit-secret")
	viper.SetDefault("security.jwtalgorithm", "HS256")
	viper.SetDefault("security.jwtexpiration", "24h")
	// Logging defaults
	viper.SetDefault("logging.level", "info")
	viper.SetDefault("logging.file", "/var/log/hmac-file-server.log")
	viper.SetDefault("logging.max_size", 100)
	viper.SetDefault("logging.max_backups", 7)
	viper.SetDefault("logging.max_age", 30)
	viper.SetDefault("logging.compress", true)
	// Deduplication defaults
	viper.SetDefault("deduplication.enabled", false)
	viper.SetDefault("deduplication.directory", "./dedup_store")
	// ISO defaults
	viper.SetDefault("iso.enabled", false)
	viper.SetDefault("iso.mount_point", "/mnt/hmac_iso")
	viper.SetDefault("iso.size", "1GB")
	viper.SetDefault("iso.charset", "utf-8")
	viper.SetDefault("iso.containerfile", "/var/lib/hmac-file-server/data.iso")
	// Timeouts defaults
	viper.SetDefault("timeouts.read", "60s")
	viper.SetDefault("timeouts.write", "60s")
	viper.SetDefault("timeouts.idle", "120s")
	viper.SetDefault("timeouts.shutdown", "30s")
	// Versioning defaults
	viper.SetDefault("versioning.enabled", false)
	viper.SetDefault("versioning.backend", "simple")
	viper.SetDefault("versioning.max_revisions", 5)
	// ... other defaults for Uploads, Downloads, ClamAV, Redis, Workers, File, Build
	viper.SetDefault("build.version", "dev")
}
// validateConfig sanity-checks the loaded snake_case-schema configuration:
// required listen address, FileTTL when TTL is enabled, timeout duration
// formats, versioning limits, and the authentication secrets. Returns the
// first problem found, or nil. The merge had dropped the final `return nil`
// and stranded the secret checks inside unrelated functions; both are
// restored here.
func validateConfig(c *Config) error {
	if c.Server.ListenAddress == "" {
		return errors.New("server.listen_address is required")
	}
	if c.Server.FileTTL == "" && c.Server.FileTTLEnabled {
		return errors.New("server.file_ttl is required when server.file_ttl_enabled is true")
	}
	if _, err := time.ParseDuration(c.Timeouts.Read); err != nil {
		return fmt.Errorf("invalid timeouts.read: %v", err)
	}
	if _, err := time.ParseDuration(c.Timeouts.Write); err != nil {
		return fmt.Errorf("invalid timeouts.write: %v", err)
	}
	if _, err := time.ParseDuration(c.Timeouts.Idle); err != nil {
		return fmt.Errorf("invalid timeouts.idle: %v", err)
	}
	if c.Versioning.Enabled {
		if c.Versioning.MaxRevs <= 0 {
			return errors.New("versioning.max_revisions must be positive if versioning is enabled")
		}
	}
	// Validate JWT secret if JWT is enabled
	if c.Security.EnableJWT && strings.TrimSpace(c.Security.JWTSecret) == "" {
		return errors.New("security.jwtsecret is required when security.enablejwt is true")
	}
	// Validate HMAC secret if JWT is not enabled (as it's the fallback)
	if !c.Security.EnableJWT && strings.TrimSpace(c.Security.Secret) == "" {
		return errors.New("security.secret is required for HMAC authentication (when JWT is disabled)")
	}
	return nil
}
// initializeUploadWorkerPool launches MinWorkers upload workers, each
// bound to ctx for shutdown.
func initializeUploadWorkerPool(ctx context.Context) {
	for id := 0; id < MinWorkers; id++ {
		go uploadWorker(ctx, id)
	}
	log.Infof("Initialized %d upload workers", MinWorkers)
}
// scanWorker consumes scanQueue until ctx is cancelled or the queue is
// closed, scanning each file with ClamAV and delivering the result on
// (then closing) task.Result.
func scanWorker(ctx context.Context, workerID int) {
	wlog := log.WithField("worker_id", workerID)
	wlog.Info("Scan worker started")
	for {
		select {
		case <-ctx.Done():
			wlog.Info("Scan worker stopping")
			return
		case task, ok := <-scanQueue:
			if !ok {
				wlog.Info("Scan queue closed")
				return
			}
			tlog := log.WithFields(logrus.Fields{
				"worker_id": workerID,
				"file":      task.AbsFilename,
			})
			tlog.Info("Processing scan task")
			err := scanFileWithClamAV(task.AbsFilename)
			if err != nil {
				log.WithFields(logrus.Fields{
					"worker_id": workerID,
					"file":      task.AbsFilename,
					"error":     err,
				}).Error("Failed to scan file")
			} else {
				tlog.Info("Successfully scanned file")
			}
			task.Result <- err
			close(task.Result)
		}
	}
}
// initializeScanWorkerPool launches ScanWorkers ClamAV scan workers, each
// bound to ctx for shutdown.
func initializeScanWorkerPool(ctx context.Context) {
	for id := 0; id < ScanWorkers; id++ {
		go scanWorker(ctx, id)
	}
	log.Infof("Initialized %d scan workers", ScanWorkers)
}
// Setup router with middleware
func setupRouter() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/", handleRequest)
if conf.Server.MetricsEnabled {
mux.Handle("/metrics", promhttp.Handler())
// Validate JWT secret if JWT is enabled
if c.Security.EnableJWT && strings.TrimSpace(c.Security.JWTSecret) == "" {
return errors.New("security.jwtsecret is required when security.enablejwt is true")
}
// Apply middleware
handler := loggingMiddleware(mux)
handler = recoveryMiddleware(handler)
handler = corsMiddleware(handler)
return handler
}
// loggingMiddleware increments the requestsTotal metric (labeled by
// method and path) for each request before delegating to next.
func loggingMiddleware(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		requestsTotal.WithLabelValues(r.Method, r.URL.Path).Inc()
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(fn)
}
// recoveryMiddleware converts panics in downstream handlers into a logged
// 500 response instead of crashing the server.
func recoveryMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			log.WithFields(logrus.Fields{
				"method": r.Method,
				"url":    r.URL.String(),
				"error":  rec,
			}).Error("Panic recovered in HTTP handler")
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		}()
		next.ServeHTTP(w, r)
	})
}
// corsMiddleware handles CORS by setting appropriate headers
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-File-MAC")
w.Header().Set("Access-Control-Max-Age", "86400") // Cache preflight response for 1 day
// Handle preflight OPTIONS request
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
// Proceed to the next handler
next.ServeHTTP(w, r)
})
}
// Handle file uploads and downloads
func handleRequest(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost && strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
absFilename, err := sanitizeFilePath(conf.Server.StoragePath, strings.TrimPrefix(r.URL.Path, "/"))
if err != nil {
log.WithError(err).Error("Invalid file path")
http.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
err = handleMultipartUpload(w, r, absFilename)
if err != nil {
log.WithError(err).Error("Failed to handle multipart upload")
http.Error(w, "Failed to handle multipart upload", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusCreated)
return
// Validate HMAC secret if JWT is not enabled (as it's the fallback)
if !c.Security.EnableJWT && strings.TrimSpace(c.Security.Secret) == "" {
return errors.New("security.secret is required for HMAC authentication (when JWT is disabled)")
}
// Get client IP address
clientIP := r.Header.Get("X-Real-IP")
if clientIP == "" {
clientIP = r.Header.Get("X-Forwarded-For")
}
if clientIP == "" {
// Fallback to RemoteAddr
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
log.WithError(err).Warn("Failed to parse RemoteAddr")
clientIP = r.RemoteAddr
return nil
}
// validateJWTFromRequest extracts and validates a JWT from the request.
func validateJWTFromRequest(r *http.Request, secret string) (*jwt.Token, error) {
authHeader := r.Header.Get("Authorization")
tokenString := ""
if authHeader != "" {
splitToken := strings.Split(authHeader, "Bearer ")
if len(splitToken) == 2 {
tokenString = splitToken[1]
} else {
clientIP = host
return nil, errors.New("invalid Authorization header format")
}
} else {
// Fallback to checking 'token' query parameter
tokenString = r.URL.Query().Get("token")
if tokenString == "" {
return nil, errors.New("missing JWT in Authorization header or 'token' query parameter")
}
}
// Log the request with the client IP
log.WithFields(logrus.Fields{
"method": r.Method,
"url": r.URL.String(),
"remote": clientIP,
}).Info("Incoming request")
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(secret), nil
})
// Parse URL and query parameters
p := r.URL.Path
a, err := url.ParseQuery(r.URL.RawQuery)
if err != nil {
log.Warn("Failed to parse query parameters")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
return nil, fmt.Errorf("JWT validation failed: %w", err)
}
fileStorePath := strings.TrimPrefix(p, "/")
if fileStorePath == "" || fileStorePath == "/" {
log.Warn("Access to root directory is forbidden")
http.Error(w, "Forbidden", http.StatusForbidden)
return
} else if fileStorePath[0] == '/' {
fileStorePath = fileStorePath[1:]
if !token.Valid {
return nil, errors.New("invalid JWT")
}
absFilename, err := sanitizeFilePath(conf.Server.StoragePath, fileStorePath)
if err != nil {
log.WithFields(logrus.Fields{
"file": fileStorePath,
"error": err,
}).Warn("Invalid file path")
http.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
switch r.Method {
case http.MethodPut:
handleUpload(w, r, absFilename, fileStorePath, a)
case http.MethodHead, http.MethodGet:
handleDownload(w, r, absFilename, fileStorePath)
case http.MethodOptions:
// Handled by NGINX; no action needed
w.Header().Set("Allow", "OPTIONS, GET, PUT, HEAD")
return
default:
log.WithField("method", r.Method).Warn("Invalid HTTP method for upload directory")
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return
}
return token, nil
}
// Handle file uploads with extension restrictions and HMAC validation
func handleUpload(w http.ResponseWriter, r *http.Request, absFilename, fileStorePath string, a url.Values) {
// Log the storage path being used
log.Infof("Using storage path: %s", conf.Server.StoragePath)
// validateHMAC validates the HMAC signature of the request for legacy protocols and POST uploads.
func validateHMAC(r *http.Request, secret string) error {
log.Debugf("validateHMAC: Validating request to %s with query: %s", r.URL.Path, r.URL.RawQuery)
// Check for X-Signature header (for POST uploads)
signature := r.Header.Get("X-Signature")
if signature != "" {
// This is a POST upload with X-Signature header
message := r.URL.Path
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(message))
expectedSignature := hex.EncodeToString(h.Sum(nil))
// Determine protocol version based on query parameters
var protocolVersion string
if a.Get("v2") != "" {
protocolVersion = "v2"
} else if a.Get("token") != "" {
protocolVersion = "token"
} else if a.Get("v") != "" {
protocolVersion = "v"
} else {
log.Warn("No HMAC attached to URL. Expecting 'v', 'v2', or 'token' parameter as MAC")
http.Error(w, "No HMAC attached to URL. Expecting 'v', 'v2', or 'token' parameter as MAC", http.StatusForbidden)
return
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
return errors.New("invalid HMAC signature in X-Signature header")
}
return nil
}
log.Debugf("Protocol version determined: %s", protocolVersion)
// Initialize HMAC
mac := hmac.New(sha256.New, []byte(conf.Security.Secret))
// Check for legacy URL-based HMAC protocols (v, v2, token)
query := r.URL.Query()
var protocolVersion string
var providedMACHex string
if query.Get("v2") != "" {
protocolVersion = "v2"
providedMACHex = query.Get("v2")
} else if query.Get("token") != "" {
protocolVersion = "token"
providedMACHex = query.Get("token")
} else if query.Get("v") != "" {
protocolVersion = "v"
providedMACHex = query.Get("v")
} else {
return errors.New("no HMAC signature found (missing X-Signature header or v/v2/token query parameter)")
}
// Extract file path from URL
fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
// Calculate HMAC based on protocol version (matching legacy behavior)
mac := hmac.New(sha256.New, []byte(secret))
// Calculate MAC based on protocolVersion
if protocolVersion == "v" {
mac.Write([]byte(fileStorePath + "\x20" + strconv.FormatInt(r.ContentLength, 10)))
} else if protocolVersion == "v2" || protocolVersion == "token" {
// Legacy v protocol: fileStorePath + "\x20" + contentLength
message := fileStorePath + "\x20" + strconv.FormatInt(r.ContentLength, 10)
mac.Write([]byte(message))
} else {
// v2 and token protocols: fileStorePath + "\x00" + contentLength + "\x00" + contentType
contentType := mime.TypeByExtension(filepath.Ext(fileStorePath))
if contentType == "" {
contentType = "application/octet-stream"
}
mac.Write([]byte(fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) + "\x00" + contentType))
message := fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) + "\x00" + contentType
log.Debugf("validateHMAC: %s protocol message: %q (len=%d)", protocolVersion, message, len(message))
mac.Write([]byte(message))
}
calculatedMAC := mac.Sum(nil)
log.Debugf("Calculated MAC: %x", calculatedMAC)
calculatedMACHex := hex.EncodeToString(calculatedMAC)
// Decode provided MAC from hex
providedMACHex := a.Get(protocolVersion)
// Decode provided MAC
providedMAC, err := hex.DecodeString(providedMACHex)
if err != nil {
log.Warn("Invalid MAC encoding")
http.Error(w, "Invalid MAC encoding", http.StatusForbidden)
return
return fmt.Errorf("invalid MAC encoding for %s protocol: %v", protocolVersion, err)
}
log.Debugf("Provided MAC: %x", providedMAC)
// Validate the HMAC
log.Debugf("validateHMAC: %s protocol - calculated: %s, provided: %s", protocolVersion, calculatedMACHex, providedMACHex)
// Compare MACs
if !hmac.Equal(calculatedMAC, providedMAC) {
log.Warn("Invalid MAC")
http.Error(w, "Invalid MAC", http.StatusForbidden)
return
}
log.Debug("HMAC validation successful")
// Validate file extension
if !isExtensionAllowed(fileStorePath) {
log.WithFields(logrus.Fields{
// No need to sanitize and validate the file path here since absFilename is already sanitized in handleRequest
"file": fileStorePath,
"error": err,
}).Warn("Invalid file path")
http.Error(w, "Invalid file path", http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
// absFilename = sanitizedFilename
// Check if there is enough free space
err = checkStorageSpace(conf.Server.StoragePath, conf.Server.MinFreeBytes)
if err != nil {
log.WithFields(logrus.Fields{
"storage_path": conf.Server.StoragePath,
"error": err,
}).Warn("Not enough free space")
http.Error(w, "Not enough free space", http.StatusInsufficientStorage)
uploadErrorsTotal.Inc()
return
return fmt.Errorf("invalid MAC for %s protocol", protocolVersion)
}
// Create an UploadTask with a result channel
result := make(chan error)
task := UploadTask{
AbsFilename: absFilename,
Request: r,
Result: result,
}
// Submit task to the upload queue
select {
case uploadQueue <- task:
// Successfully added to the queue
log.Debug("Upload task enqueued successfully")
default:
// Queue is full
log.Warn("Upload queue is full. Rejecting upload")
http.Error(w, "Server busy. Try again later.", http.StatusServiceUnavailable)
uploadErrorsTotal.Inc()
return
}
// Wait for the worker to process the upload
err = <-result
if err != nil {
// The worker has already logged the error; send an appropriate HTTP response
http.Error(w, fmt.Sprintf("Upload failed: %v", err), http.StatusInternalServerError)
return
}
// Upload was successful
w.WriteHeader(http.StatusCreated)
log.Debugf("%s HMAC authentication successful for request: %s", protocolVersion, r.URL.Path)
return nil
}
// Handle file downloads
func handleDownload(w http.ResponseWriter, r *http.Request, absFilename, fileStorePath string) {
fileInfo, err := getFileInfo(absFilename)
// validateV3HMAC validates the HMAC signature for v3 protocol (mod_http_upload_external).
func validateV3HMAC(r *http.Request, secret string) error {
query := r.URL.Query()
// Extract v3 signature and expires from query parameters
signature := query.Get("v3")
expiresStr := query.Get("expires")
if signature == "" {
return errors.New("missing v3 signature parameter")
}
if expiresStr == "" {
return errors.New("missing expires parameter")
}
// Parse expires timestamp
expires, err := strconv.ParseInt(expiresStr, 10, 64)
if err != nil {
log.WithError(err).Error("Failed to get file information")
http.Error(w, "Not Found", http.StatusNotFound)
downloadErrorsTotal.Inc()
return fmt.Errorf("invalid expires parameter: %v", err)
}
// Check if signature has expired
now := time.Now().Unix()
if now > expires {
return errors.New("signature has expired")
}
// Construct message for HMAC verification
// Format: METHOD\nEXPIRES\nPATH
message := fmt.Sprintf("%s\n%s\n%s", r.Method, expiresStr, r.URL.Path)
// Calculate expected HMAC signature
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(message))
expectedSignature := hex.EncodeToString(h.Sum(nil))
// Compare signatures
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
return errors.New("invalid v3 HMAC signature")
}
return nil
}
// handleUpload handles multipart/form-data file uploads via POST.
//
// Authentication is JWT when conf.Security.EnableJWT is set, otherwise HMAC
// (X-Signature header or legacy v/v2/token query parameters). The file is
// stored under the configured storage path (or ISO mount point when ISO is
// enabled) and a JSON summary {success, filename, size, duration} is
// returned. Prometheus metrics are updated on both error and success paths.
func handleUpload(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()
	activeConnections.Inc()
	defer activeConnections.Dec()
	// Only allow POST method
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		uploadErrorsTotal.Inc()
		return
	}
	// Authentication
	if conf.Security.EnableJWT {
		_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
		if err != nil {
			http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
			uploadErrorsTotal.Inc()
			return
		}
		log.Debugf("JWT authentication successful for upload request: %s", r.URL.Path)
	} else {
		err := validateHMAC(r, conf.Security.Secret)
		if err != nil {
			http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
			uploadErrorsTotal.Inc()
			return
		}
		log.Debugf("HMAC authentication successful for upload request: %s", r.URL.Path)
	}
	// Parse multipart form
	err := r.ParseMultipartForm(32 << 20) // 32MB max memory
	if err != nil {
		http.Error(w, fmt.Sprintf("Error parsing multipart form: %v", err), http.StatusBadRequest)
		uploadErrorsTotal.Inc()
		return
	}
	// Get file from form
	file, header, err := r.FormFile("file")
	if err != nil {
		http.Error(w, fmt.Sprintf("Error getting file from form: %v", err), http.StatusBadRequest)
		uploadErrorsTotal.Inc()
		return
	}
	defer file.Close()
	// Validate file extension if configured
	if len(conf.Uploads.AllowedExtensions) > 0 {
		ext := strings.ToLower(filepath.Ext(header.Filename))
		allowed := false
		for _, allowedExt := range conf.Uploads.AllowedExtensions {
			if ext == allowedExt {
				allowed = true
				break
			}
		}
		if !allowed {
			http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
			uploadErrorsTotal.Inc()
			return
		}
	}
	// Generate filename based on configuration
	var filename string
	switch conf.Server.FileNaming {
	case "HMAC":
		// Generate HMAC-based filename
		h := hmac.New(sha256.New, []byte(conf.Security.Secret))
		h.Write([]byte(header.Filename + time.Now().String()))
		filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(header.Filename)
	default: // "original" or "None"
		// Use only the base name: a client-supplied name like "../x" must not
		// escape the storage directory (filepath.Join would clean it upward).
		filename = filepath.Base(header.Filename)
	}
	// Create full file path
	storagePath := conf.Server.StoragePath
	if conf.ISO.Enabled {
		storagePath = conf.ISO.MountPoint
	}
	absFilename := filepath.Join(storagePath, filename)
	// Create the file
	dst, err := os.Create(absFilename)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		return
	}
	defer dst.Close()
	// Copy file content
	written, err := io.Copy(dst, file)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		// Clean up partial file
		os.Remove(absFilename)
		return
	}
	// Handle deduplication if enabled (best effort; failure only logged)
	if conf.Server.DeduplicationEnabled {
		ctx := context.Background()
		err = handleDeduplication(ctx, absFilename)
		if err != nil {
			log.Warnf("Deduplication failed for %s: %v", absFilename, err)
		}
	}
	// Update metrics
	duration := time.Since(startTime)
	uploadDuration.Observe(duration.Seconds())
	uploadsTotal.Inc()
	uploadSizeBytes.Observe(float64(written))
	// Return success response
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	response := map[string]interface{}{
		"success":  true,
		"filename": filename,
		"size":     written,
		"duration": duration.String(),
	}
	// Create JSON response; fall back to hand-built JSON if Marshal fails.
	if jsonBytes, err := json.Marshal(response); err == nil {
		w.Write(jsonBytes)
	} else {
		fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d}`, filename, written)
	}
	log.Infof("Successfully uploaded %s (%s) in %s", filename, formatBytes(written), duration)
}
// handleDownload serves a stored file as an attachment.
//
// Authentication is JWT when conf.Security.EnableJWT is set, otherwise HMAC.
// The file name is taken from the URL path after the "/download/" prefix and
// sanitized against the configured storage root before being streamed back
// with a pooled copy buffer. Prometheus metrics are updated on both error
// and success paths.
func handleDownload(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()
	activeConnections.Inc()
	defer activeConnections.Dec()
	// Authentication: JWT takes precedence when enabled; otherwise fall back
	// to the HMAC schemes understood by validateHMAC.
	if conf.Security.EnableJWT {
		_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
		if err != nil {
			http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
			downloadErrorsTotal.Inc()
			return
		}
		log.Debugf("JWT authentication successful for download request: %s", r.URL.Path)
	} else {
		err := validateHMAC(r, conf.Security.Secret)
		if err != nil {
			http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
			downloadErrorsTotal.Inc()
			return
		}
		log.Debugf("HMAC authentication successful for download request: %s", r.URL.Path)
	}
	// The requested name is everything after "/download/".
	filename := strings.TrimPrefix(r.URL.Path, "/download/")
	if filename == "" {
		http.Error(w, "Filename not specified", http.StatusBadRequest)
		downloadErrorsTotal.Inc()
		return
	}
	// Resolve against the storage root; rejects traversal outside it.
	absFilename, err := sanitizeFilePath(conf.Server.StoragePath, filename) // Use sanitizeFilePath from helpers.go
	if err != nil {
		http.Error(w, fmt.Sprintf("Invalid file path: %v", err), http.StatusBadRequest)
		downloadErrorsTotal.Inc()
		return
	}
	fileInfo, err := os.Stat(absFilename)
	if os.IsNotExist(err) {
		http.Error(w, "File not found", http.StatusNotFound)
		downloadErrorsTotal.Inc()
		return
	}
	if err != nil {
		http.Error(w, fmt.Sprintf("Error accessing file: %v", err), http.StatusInternalServerError)
		downloadErrorsTotal.Inc()
		return
	}
	if fileInfo.IsDir() {
		http.Error(w, "Cannot download a directory", http.StatusBadRequest)
		downloadErrorsTotal.Inc()
		return
	}
	file, err := os.Open(absFilename)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error opening file: %v", err), http.StatusInternalServerError)
		downloadErrorsTotal.Inc()
		return
	}
	defer file.Close()
	// Headers must be set before the first body write below.
	w.Header().Set("Content-Disposition", "attachment; filename=\""+filepath.Base(absFilename)+"\"")
	w.Header().Set("Content-Type", "application/octet-stream")
	w.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
	// Use a pooled buffer for copying to avoid a per-request allocation.
	bufPtr := bufferPool.Get().(*[]byte)
	defer bufferPool.Put(bufPtr)
	buf := *bufPtr
	n, err := io.CopyBuffer(w, file, buf)
	if err != nil {
		log.Errorf("Error during download of %s: %v", absFilename, err)
		// Don't write http.Error here if headers already sent
		downloadErrorsTotal.Inc()
		return // Ensure we don't try to record metrics if there was an error during copy
	}
	duration := time.Since(startTime)
	downloadDuration.Observe(duration.Seconds())
	downloadsTotal.Inc()
	downloadSizeBytes.Observe(float64(n))
	log.Infof("Successfully downloaded %s (%s) in %s", absFilename, formatBytes(n), duration)
}
// handleV3Upload handles PUT requests for the v3 protocol
// (mod_http_upload_external).
//
// The request must carry a valid v3 HMAC signature (see validateV3HMAC).
// The upload path has the form /uuid/subdir/filename.ext; only the last
// segment is used as the file name. The file is written under the storage
// path (or ISO mount point) and a JSON summary is returned.
func handleV3Upload(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()
	activeConnections.Inc()
	defer activeConnections.Dec()
	// Only allow PUT method for v3
	if r.Method != http.MethodPut {
		http.Error(w, "Method not allowed for v3 uploads", http.StatusMethodNotAllowed)
		uploadErrorsTotal.Inc()
		return
	}
	// Validate v3 HMAC signature (signed METHOD\nEXPIRES\nPATH)
	err := validateV3HMAC(r, conf.Security.Secret)
	if err != nil {
		http.Error(w, fmt.Sprintf("v3 Authentication failed: %v", err), http.StatusUnauthorized)
		uploadErrorsTotal.Inc()
		return
	}
	log.Debugf("v3 HMAC authentication successful for upload request: %s", r.URL.Path)
	// Extract filename from the URL path
	// Path format: /uuid/subdir/filename.ext
	pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
	if len(pathParts) < 1 {
		http.Error(w, "Invalid upload path", http.StatusBadRequest)
		uploadErrorsTotal.Inc()
		return
	}
	// Use the last part as filename (cannot contain "/" after the split)
	originalFilename := pathParts[len(pathParts)-1]
	if originalFilename == "" {
		http.Error(w, "No filename specified", http.StatusBadRequest)
		uploadErrorsTotal.Inc()
		return
	}
	// Validate file extension if configured
	if len(conf.Uploads.AllowedExtensions) > 0 {
		ext := strings.ToLower(filepath.Ext(originalFilename))
		allowed := false
		for _, allowedExt := range conf.Uploads.AllowedExtensions {
			if ext == allowedExt {
				allowed = true
				break
			}
		}
		if !allowed {
			http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
			uploadErrorsTotal.Inc()
			return
		}
	}
	// Generate filename based on configuration
	var filename string
	switch conf.Server.FileNaming {
	case "HMAC":
		// Generate HMAC-based filename (keyed hash of name + timestamp)
		h := hmac.New(sha256.New, []byte(conf.Security.Secret))
		h.Write([]byte(originalFilename + time.Now().String()))
		filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(originalFilename)
	default: // "original" or "None"
		filename = originalFilename
	}
	// Create full file path
	storagePath := conf.Server.StoragePath
	if conf.ISO.Enabled {
		storagePath = conf.ISO.MountPoint
	}
	absFilename := filepath.Join(storagePath, filename)
	// Create the file
	dst, err := os.Create(absFilename)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		return
	}
	defer dst.Close()
	// Copy file content from request body
	written, err := io.Copy(dst, r.Body)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		// Clean up partial file
		os.Remove(absFilename)
		return
	}
	// Handle deduplication if enabled (best effort; failure only logged)
	if conf.Server.DeduplicationEnabled {
		ctx := context.Background()
		err = handleDeduplication(ctx, absFilename)
		if err != nil {
			log.Warnf("Deduplication failed for %s: %v", absFilename, err)
		}
	}
	// Update metrics
	duration := time.Since(startTime)
	uploadDuration.Observe(duration.Seconds())
	uploadsTotal.Inc()
	uploadSizeBytes.Observe(float64(written))
	// Return success response
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	response := map[string]interface{}{
		"success":  true,
		"filename": filename,
		"size":     written,
		"duration": duration.String(),
	}
	// Create JSON response; fall back to hand-built JSON if Marshal fails.
	if jsonBytes, err := json.Marshal(response); err == nil {
		w.Write(jsonBytes)
	} else {
		fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d}`, filename, written)
	}
	log.Infof("Successfully uploaded %s via v3 protocol (%s) in %s", filename, formatBytes(written), duration)
}
// handleLegacyUpload handles PUT requests for legacy protocols (v, v2, token).
//
// The request must carry a valid legacy HMAC signature (see validateHMAC).
// Unlike v3, the full request path (minus the leading "/") is preserved as
// the storage-relative file path for XMPP compatibility, so parent
// directories are created as needed. Responds 201 Created on success.
func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()
	activeConnections.Inc()
	defer activeConnections.Dec()
	log.Debugf("handleLegacyUpload: Processing request to %s with query: %s", r.URL.Path, r.URL.RawQuery)
	// Only allow PUT method for legacy uploads
	if r.Method != http.MethodPut {
		http.Error(w, "Method not allowed for legacy uploads", http.StatusMethodNotAllowed)
		uploadErrorsTotal.Inc()
		return
	}
	// Validate legacy HMAC signature
	err := validateHMAC(r, conf.Security.Secret)
	if err != nil {
		http.Error(w, fmt.Sprintf("Legacy Authentication failed: %v", err), http.StatusUnauthorized)
		uploadErrorsTotal.Inc()
		return
	}
	// Extract filename from the URL path
	fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
	if fileStorePath == "" {
		http.Error(w, "No filename specified", http.StatusBadRequest)
		uploadErrorsTotal.Inc()
		return
	}
	// Validate file extension if configured
	if len(conf.Uploads.AllowedExtensions) > 0 {
		ext := strings.ToLower(filepath.Ext(fileStorePath))
		allowed := false
		for _, allowedExt := range conf.Uploads.AllowedExtensions {
			if ext == allowedExt {
				allowed = true
				break
			}
		}
		if !allowed {
			http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
			uploadErrorsTotal.Inc()
			return
		}
	}
	// Create full file path
	storagePath := conf.Server.StoragePath
	if conf.ISO.Enabled {
		storagePath = conf.ISO.MountPoint
	}
	// Generate filename based on configuration
	var absFilename string
	var filename string
	switch conf.Server.FileNaming {
	case "HMAC":
		// Generate HMAC-based filename (keyed hash of path + timestamp)
		h := hmac.New(sha256.New, []byte(conf.Security.Secret))
		h.Write([]byte(fileStorePath + time.Now().String()))
		filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(fileStorePath)
		absFilename = filepath.Join(storagePath, filename)
	default: // "original" or "None"
		// Preserve full directory structure for legacy XMPP compatibility;
		// sanitizeFilePath rejects traversal outside the storage root.
		var sanitizeErr error
		absFilename, sanitizeErr = sanitizeFilePath(storagePath, fileStorePath)
		if sanitizeErr != nil {
			http.Error(w, fmt.Sprintf("Invalid file path: %v", sanitizeErr), http.StatusBadRequest)
			uploadErrorsTotal.Inc()
			return
		}
		filename = filepath.Base(fileStorePath) // For logging purposes
	}
	// Create directory structure if it doesn't exist
	if err := os.MkdirAll(filepath.Dir(absFilename), 0755); err != nil {
		http.Error(w, fmt.Sprintf("Error creating directory: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		return
	}
	// Create the file
	dst, err := os.Create(absFilename)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		return
	}
	defer dst.Close()
	// Log upload start for large files
	if r.ContentLength > 10*1024*1024 { // Log for files > 10MB
		log.Infof("Starting upload of %s (%.1f MiB)", filename, float64(r.ContentLength)/(1024*1024))
	}
	// Copy file content from request body with progress reporting
	written, err := copyWithProgress(dst, r.Body, r.ContentLength, filename)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		// Clean up partial file
		os.Remove(absFilename)
		return
	}
	// Handle deduplication if enabled (best effort; failure only logged)
	if conf.Server.DeduplicationEnabled {
		ctx := context.Background()
		err = handleDeduplication(ctx, absFilename)
		if err != nil {
			log.Warnf("Deduplication failed for %s: %v", absFilename, err)
		}
	}
	// Update metrics
	duration := time.Since(startTime)
	uploadDuration.Observe(duration.Seconds())
	uploadsTotal.Inc()
	uploadSizeBytes.Observe(float64(written))
	// Return success response (201 Created for legacy compatibility)
	w.WriteHeader(http.StatusCreated)
	log.Infof("Successfully uploaded %s via legacy protocol (%s) in %s", filename, formatBytes(written), duration)
}
// handleLegacyDownload handles GET/HEAD requests for legacy downloads.
//
// The storage-relative path is the URL path minus the leading "/". When
// resumable downloads are enabled the request is delegated to
// handleResumableDownload (which honors Range headers); otherwise the whole
// file is served. HEAD requests return headers only.
//
// NOTE(review): unlike handleDownload, this path performs no authentication
// and joins the raw URL path without sanitizeFilePath — presumably it sits
// behind a reverse proxy that enforces both; confirm before exposing directly.
func handleLegacyDownload(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()
	activeConnections.Inc()
	defer activeConnections.Dec()
	// Extract filename from the URL path
	fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
	if fileStorePath == "" {
		http.Error(w, "No filename specified", http.StatusBadRequest)
		downloadErrorsTotal.Inc()
		return
	}
	// Create full file path
	storagePath := conf.Server.StoragePath
	if conf.ISO.Enabled {
		storagePath = conf.ISO.MountPoint
	}
	absFilename := filepath.Join(storagePath, fileStorePath)
	fileInfo, err := os.Stat(absFilename)
	if os.IsNotExist(err) {
		http.Error(w, "File not found", http.StatusNotFound)
		downloadErrorsTotal.Inc()
		return
	}
	if err != nil {
		http.Error(w, fmt.Sprintf("Error accessing file: %v", err), http.StatusInternalServerError)
		downloadErrorsTotal.Inc()
		return
	}
	if fileInfo.IsDir() {
		http.Error(w, "Cannot download a directory", http.StatusBadRequest)
		downloadErrorsTotal.Inc()
		return
	}
	// Set appropriate headers (content type guessed from the extension)
	contentType := mime.TypeByExtension(filepath.Ext(fileStorePath))
	if contentType == "" {
		contentType = "application/octet-stream"
	}
	w.Header().Set("Content-Type", contentType)
	w.Header().Set("Content-Length", strconv.FormatInt(fileInfo.Size(), 10))
	// Handle resumable downloads (Range support) when enabled
	if conf.Uploads.ResumableUploadsEnabled {
		handleResumableDownload(absFilename, w, r, fileInfo.Size())
		return
	}
	// For HEAD requests, only send headers
	if r.Method == http.MethodHead {
		// Content-Length was already set above; this re-set is a harmless no-op.
		w.Header().Set("Content-Length", strconv.FormatInt(fileInfo.Size(), 10))
		downloadsTotal.Inc()
		return
	} else {
		// Measure download duration (shadows the outer startTime on purpose:
		// only the serve time is observed here, not auth/stat overhead)
		startTime := time.Now()
		log.Infof("Initiating download for file: %s", absFilename)
		http.ServeFile(w, r, absFilename)
		downloadDuration.Observe(time.Since(startTime).Seconds())
		downloadSizeBytes.Observe(float64(fileInfo.Size()))
		downloadsTotal.Inc()
		log.Infof("File downloaded successfully: %s", absFilename)
		return
	}
}
// createFile streams the request body into tempFilename using a large
// buffered writer, creating any missing parent directories first. It returns
// an error if the directory, file, a write, a read, or the final flush fails.
func createFile(tempFilename string, r *http.Request) error {
	dir := filepath.Dir(tempFilename)
	if err := os.MkdirAll(dir, os.ModePerm); err != nil {
		log.WithError(err).Errorf("Failed to create directory %s", dir)
		return fmt.Errorf("failed to create directory %s: %w", dir, err)
	}

	// Open (truncating) the destination file.
	out, err := os.OpenFile(tempFilename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		log.WithError(err).Errorf("Failed to create file %s", tempFilename)
		return fmt.Errorf("failed to create file %s: %w", tempFilename, err)
	}
	defer out.Close()

	// Large buffer (4 MB) keeps syscall count low for big uploads.
	const bufferSize = 4 * 1024 * 1024
	bw := bufio.NewWriterSize(out, bufferSize)
	chunk := make([]byte, bufferSize)
	var written int64
	for {
		nr, rerr := r.Body.Read(chunk)
		if nr > 0 {
			written += int64(nr)
			if _, werr := bw.Write(chunk[:nr]); werr != nil {
				log.WithError(werr).Errorf("Failed to write to file %s", tempFilename)
				return fmt.Errorf("failed to write to file %s: %w", tempFilename, werr)
			}
		}
		if rerr != nil {
			if rerr == io.EOF {
				break
			}
			log.WithError(rerr).Error("Failed to read request body")
			return fmt.Errorf("failed to read request body: %w", rerr)
		}
	}

	if err := bw.Flush(); err != nil {
		log.WithError(err).Errorf("Failed to flush buffer to file %s", tempFilename)
		return fmt.Errorf("failed to flush buffer to file %s: %w", tempFilename, err)
	}

	log.WithFields(logrus.Fields{
		"temp_file":   tempFilename,
		"total_bytes": written,
	}).Info("File uploaded successfully")
	uploadSizeBytes.Observe(float64(written))
	return nil
}
// scanFileWithClamAV submits filePath to the shared ClamAV client and blocks
// until a scan result arrives. It returns nil when the file is clean, and a
// descriptive error when a virus is found, the scan cannot be started, or an
// unexpected status comes back.
func scanFileWithClamAV(filePath string) error {
	log.WithField("file", filePath).Info("Scanning file with ClamAV")

	resultCh, err := clamClient.ScanFile(filePath)
	if err != nil {
		log.WithError(err).Error("Failed to initiate ClamAV scan")
		return fmt.Errorf("failed to initiate ClamAV scan: %w", err)
	}

	// Block until the daemon reports back; a nil result means the channel
	// closed without delivering a verdict.
	res := <-resultCh
	if res == nil {
		log.Error("Failed to receive scan result from ClamAV")
		return fmt.Errorf("failed to receive scan result from ClamAV")
	}

	switch res.Status {
	case clamd.RES_OK:
		log.WithField("file", filePath).Info("ClamAV scan passed")
		return nil
	case clamd.RES_FOUND:
		log.WithFields(logrus.Fields{
			"file":        filePath,
			"description": res.Description,
		}).Warn("ClamAV detected a virus")
		return fmt.Errorf("virus detected: %s", res.Description)
	default:
		log.WithFields(logrus.Fields{
			"file":        filePath,
			"status":      res.Status,
			"description": res.Description,
		}).Warn("ClamAV scan returned unexpected status")
		return fmt.Errorf("ClamAV scan returned unexpected status: %s", res.Description)
	}
}
// initClamAV connects to the ClamAV daemon over the given Unix socket and
// verifies it responds to a ping. It returns the ready client, or an error
// when the socket path is empty or the daemon is unreachable.
func initClamAV(socket string) (*clamd.Clamd, error) {
	if socket == "" {
		log.Error("ClamAV socket path is not configured.")
		return nil, fmt.Errorf("ClamAV socket path is not configured")
	}

	client := clamd.NewClamd("unix:" + socket)
	if err := client.Ping(); err != nil {
		log.Errorf("Failed to connect to ClamAV at %s: %v", socket, err)
		return nil, fmt.Errorf("failed to connect to ClamAV: %w", err)
	}

	log.Info("Connected to ClamAV successfully.")
	return client, nil
}
// handleResumableDownload serves absFilename honoring a single-range HTTP
// Range header ("bytes=start-end"). Without a Range header the full file is
// served via http.ServeFile. Only one contiguous range is supported;
// malformed or out-of-bounds ranges get 416 Requested Range Not Satisfiable.
// Metrics are recorded for both full and partial transfers.
func handleResumableDownload(absFilename string, w http.ResponseWriter, r *http.Request, fileSize int64) {
	rangeHeader := r.Header.Get("Range")
	if rangeHeader == "" {
		// If no Range header, serve the full file
		startTime := time.Now()
		http.ServeFile(w, r, absFilename)
		downloadDuration.Observe(time.Since(startTime).Seconds())
		downloadSizeBytes.Observe(float64(fileSize))
		downloadsTotal.Inc()
		return
	}
	// Parse Range header of the form "bytes=start-end"
	ranges := strings.Split(strings.TrimPrefix(rangeHeader, "bytes="), "-")
	if len(ranges) != 2 {
		http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
		downloadErrorsTotal.Inc()
		return
	}
	start, err := strconv.ParseInt(ranges[0], 10, 64)
	if err != nil {
		http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
		downloadErrorsTotal.Inc()
		return
	}
	// Calculate end byte (inclusive); default is the last byte of the file.
	end := fileSize - 1
	if ranges[1] != "" {
		end, err = strconv.ParseInt(ranges[1], 10, 64)
		if err != nil || end >= fileSize {
			http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
			downloadErrorsTotal.Inc()
			return
		}
	}
	// Reject inverted or out-of-file ranges before touching the response.
	if start < 0 || start > end || start >= fileSize {
		http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
		downloadErrorsTotal.Inc()
		return
	}
	// Open and position the file BEFORE writing headers so failures can
	// still produce a proper error status.
	file, err := os.Open(absFilename)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error opening file: %v", err), http.StatusInternalServerError)
		downloadErrorsTotal.Inc()
		return
	}
	defer file.Close()
	if _, err := file.Seek(start, io.SeekStart); err != nil {
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		downloadErrorsTotal.Inc()
		return
	}
	// Set response headers for partial content
	w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, fileSize))
	w.Header().Set("Content-Length", strconv.FormatInt(end-start+1, 10))
	w.Header().Set("Accept-Ranges", "bytes")
	w.WriteHeader(http.StatusPartialContent)
	// Copy exactly end-start+1 bytes to the response.
	buffer := make([]byte, 32*1024) // 32KB buffer
	remaining := end - start + 1
	startTime := time.Now()
	for remaining > 0 {
		if int64(len(buffer)) > remaining {
			buffer = buffer[:remaining]
		}
		n, readErr := file.Read(buffer)
		if n > 0 {
			if _, writeErr := w.Write(buffer[:n]); writeErr != nil {
				log.WithError(writeErr).Error("Failed to write to response")
				downloadErrorsTotal.Inc()
				return
			}
			remaining -= int64(n)
		}
		if readErr != nil {
			// Headers are already sent; just log and stop on a real error.
			if readErr != io.EOF {
				log.WithError(readErr).Error("Error reading file during resumable download")
				downloadErrorsTotal.Inc()
			}
			break
		}
	}
	downloadDuration.Observe(time.Since(startTime).Seconds())
	downloadSizeBytes.Observe(float64(end - start + 1))
	downloadsTotal.Inc()
}
// Handle chunked uploads with bufio.Writer
func handleChunkedUpload(tempFilename string, r *http.Request) error {
log.WithField("file", tempFilename).Info("Handling chunked upload to temporary file")
// printValidationChecks prints all available validation checks
func printValidationChecks() {
fmt.Println("HMAC File Server Configuration Validation Checks")
fmt.Println("=================================================")
fmt.Println()
// Ensure the directory exists
absDirectory := filepath.Dir(tempFilename)
err := os.MkdirAll(absDirectory, os.ModePerm)
if err != nil {
log.WithError(err).Errorf("Failed to create directory %s for chunked upload", absDirectory)
return fmt.Errorf("failed to create directory %s: %w", absDirectory, err)
}
fmt.Println("🔍 CORE VALIDATION CHECKS:")
fmt.Println(" ✓ server.* - Server configuration (ports, paths, protocols)")
fmt.Println(" ✓ security.* - Security settings (secrets, JWT, authentication)")
fmt.Println(" ✓ logging.* - Logging configuration (levels, files, rotation)")
fmt.Println(" ✓ timeouts.* - Timeout settings (read, write, idle)")
fmt.Println(" ✓ uploads.* - Upload configuration (extensions, chunk size)")
fmt.Println(" ✓ downloads.* - Download configuration (extensions, chunk size)")
fmt.Println(" ✓ workers.* - Worker pool configuration (count, queue size)")
fmt.Println(" ✓ redis.* - Redis configuration (address, credentials)")
fmt.Println(" ✓ clamav.* - ClamAV antivirus configuration")
fmt.Println(" ✓ versioning.* - File versioning configuration")
fmt.Println(" ✓ deduplication.* - File deduplication configuration")
fmt.Println(" ✓ iso.* - ISO filesystem configuration")
fmt.Println()
targetFile, err := os.OpenFile(tempFilename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
log.WithError(err).Error("Failed to open temporary file for chunked upload")
return err
}
defer targetFile.Close()
fmt.Println("🔐 SECURITY CHECKS:")
fmt.Println(" ✓ Secret strength analysis (length, entropy, patterns)")
fmt.Println(" ✓ Default/example value detection")
fmt.Println(" ✓ JWT algorithm security recommendations")
fmt.Println(" ✓ Network binding security (0.0.0.0 warnings)")
fmt.Println(" ✓ File permission analysis")
fmt.Println(" ✓ Debug logging security implications")
fmt.Println()
writer := bufio.NewWriterSize(targetFile, int(conf.Uploads.ChunkSize))
buffer := make([]byte, conf.Uploads.ChunkSize)
fmt.Println("⚡ PERFORMANCE CHECKS:")
fmt.Println(" ✓ Worker count vs CPU cores optimization")
fmt.Println(" ✓ Queue size vs memory usage analysis")
fmt.Println(" ✓ Timeout configuration balance")
fmt.Println(" ✓ Large file handling preparation")
fmt.Println(" ✓ Memory-intensive configuration detection")
fmt.Println()
totalBytes := int64(0)
for {
n, err := r.Body.Read(buffer)
if n > 0 {
totalBytes += int64(n)
_, writeErr := writer.Write(buffer[:n])
if writeErr != nil {
log.WithError(writeErr).Error("Failed to write chunk to temporary file")
return writeErr
}
}
if err != nil {
if err == io.EOF {
break // Finished reading the body
}
log.WithError(err).Error("Error reading from request body")
return err
}
}
fmt.Println("🌐 CONNECTIVITY CHECKS:")
fmt.Println(" ✓ Redis server connectivity testing")
fmt.Println(" ✓ ClamAV socket accessibility")
fmt.Println(" ✓ Network address format validation")
fmt.Println(" ✓ DNS resolution testing")
fmt.Println()
err = writer.Flush()
if err != nil {
log.WithError(err).Error("Failed to flush buffer to temporary file")
return err
}
fmt.Println("💾 SYSTEM RESOURCE CHECKS:")
fmt.Println(" ✓ CPU core availability analysis")
fmt.Println(" ✓ Memory usage monitoring")
fmt.Println(" ✓ Disk space validation")
fmt.Println(" ✓ Directory write permissions")
fmt.Println(" ✓ Goroutine count analysis")
fmt.Println()
log.WithFields(logrus.Fields{
"temp_file": tempFilename,
"total_bytes": totalBytes,
}).Info("Chunked upload completed successfully")
fmt.Println("🔄 CROSS-SECTION VALIDATION:")
fmt.Println(" ✓ Path conflict detection")
fmt.Println(" ✓ Extension compatibility checks")
fmt.Println(" ✓ Configuration consistency validation")
fmt.Println()
uploadSizeBytes.Observe(float64(totalBytes))
return nil
}
// getFileInfo returns os.FileInfo for absFilename, consulting the in-memory
// fileInfoCache first and populating it with the default expiration on a miss.
func getFileInfo(absFilename string) (os.FileInfo, error) {
	if cached, ok := fileInfoCache.Get(absFilename); ok {
		if info, valid := cached.(os.FileInfo); valid {
			return info, nil
		}
	}
	info, err := os.Stat(absFilename)
	if err != nil {
		return nil, err
	}
	fileInfoCache.Set(absFilename, info, cache.DefaultExpiration)
	return info, nil
}
// monitorNetwork polls the primary IP address every 10 seconds and queues an
// IP_CHANGE event on networkEvents whenever it changes. Runs until ctx is
// cancelled.
func monitorNetwork(ctx context.Context) {
	lastIP := getCurrentIPAddress()
	for {
		select {
		case <-ctx.Done():
			log.Info("Stopping network monitor.")
			return
		case <-time.After(10 * time.Second):
			ip := getCurrentIPAddress()
			if ip == "" || ip == lastIP {
				continue
			}
			lastIP = ip
			// Non-blocking send: drop the event rather than stall the monitor.
			select {
			case networkEvents <- NetworkEvent{Type: "IP_CHANGE", Details: lastIP}:
				log.WithField("new_ip", lastIP).Info("Queued IP_CHANGE event")
			default:
				log.Warn("Network event channel is full. Dropping IP_CHANGE event.")
			}
		}
	}
}
// handleNetworkEvents consumes events from the networkEvents channel until the
// context is cancelled or the channel is closed.
func handleNetworkEvents(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			log.Info("Stopping network event handler.")
			return
		case ev, open := <-networkEvents:
			if !open {
				log.Info("Network events channel closed.")
				return
			}
			if ev.Type == "IP_CHANGE" {
				log.WithField("new_ip", ev.Details).Info("Network change detected")
				// Example: Update Prometheus gauge or trigger alerts
				// activeConnections.Set(float64(getActiveConnections()))
			}
			// Additional event types can be handled here
		}
	}
}
// getCurrentIPAddress returns the first global-unicast IPv4 address found on a
// non-loopback interface that is up, or "" when none is available.
func getCurrentIPAddress() string {
	ifaces, err := net.Interfaces()
	if err != nil {
		log.WithError(err).Error("Failed to get network interfaces")
		return ""
	}
	for _, nic := range ifaces {
		isUp := nic.Flags&net.FlagUp != 0
		isLoopback := nic.Flags&net.FlagLoopback != 0
		if !isUp || isLoopback {
			continue // Skip interfaces that are down or loopback
		}
		addrs, err := nic.Addrs()
		if err != nil {
			log.WithError(err).Errorf("Failed to get addresses for interface %s", nic.Name)
			continue
		}
		for _, a := range addrs {
			ipnet, ok := a.(*net.IPNet)
			if !ok {
				continue
			}
			if ipnet.IP.IsGlobalUnicast() && ipnet.IP.To4() != nil {
				return ipnet.IP.String()
			}
		}
	}
	return ""
}
// setupGracefulShutdown sets up handling for graceful server shutdown.
// On SIGINT/SIGTERM it drains the HTTP server with a 30-second deadline,
// cancels the shared context, closes the worker channels, and exits the
// process with status 0.
func setupGracefulShutdown(server *http.Server, cancel context.CancelFunc) {
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		sig := <-quit
		log.Infof("Received signal %s. Initiating shutdown...", sig)
		// Create a deadline for in-flight HTTP requests to complete.
		ctxShutdown, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer shutdownCancel()
		// Attempt graceful shutdown of the listener first, so no new
		// work arrives while the queues below are being closed.
		if err := server.Shutdown(ctxShutdown); err != nil {
			log.Errorf("Server shutdown failed: %v", err)
		} else {
			log.Info("Server shutdown gracefully.")
		}
		// Signal other goroutines (workers, monitors) to stop.
		cancel()
		// Close the upload, scan, and network event channels.
		// NOTE(review): if any producer can still send on these channels
		// after cancel() (e.g. monitorNetwork between its ctx check and its
		// send), closing them here can panic — confirm all senders observe
		// cancellation before this point.
		close(uploadQueue)
		log.Info("Upload queue closed.")
		close(scanQueue)
		log.Info("Scan queue closed.")
		close(networkEvents)
		log.Info("Network events channel closed.")
		log.Info("Shutdown process completed. Exiting application.")
		os.Exit(0)
	}()
}
// initRedis creates the package-level Redis client from the [redis] section of
// the configuration and verifies connectivity with a ping (5-second timeout).
// It is a no-op when Redis is disabled; on connection failure it terminates
// the process via log.Fatalf. On success it marks redisConnected (under mu)
// and starts the background health monitor.
func initRedis() {
	if !conf.Redis.RedisEnabled {
		log.Info("Redis is disabled in configuration.")
		return
	}
	redisClient = redis.NewClient(&redis.Options{
		Addr: conf.Redis.RedisAddr,
		Password: conf.Redis.RedisPassword,
		DB: conf.Redis.RedisDBIndex,
	})
	// Test the Redis connection before declaring the client usable.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	_, err := redisClient.Ping(ctx).Result()
	if err != nil {
		log.Fatalf("Failed to connect to Redis: %v", err)
	}
	log.Info("Connected to Redis successfully")
	// Set initial connection status (shared with MonitorRedisHealth).
	mu.Lock()
	redisConnected = true
	mu.Unlock()
	// Start monitoring Redis health.
	// NOTE(review): context.Background() means this monitor is never stopped
	// by the server's shutdown context — confirm that is intentional.
	go MonitorRedisHealth(context.Background(), redisClient, parseDuration(conf.Redis.RedisHealthCheckInterval))
}
// MonitorRedisHealth periodically pings Redis and updates the shared
// redisConnected flag (under mu), logging only on state transitions plus a
// debug line on every successful check. Runs until ctx is cancelled.
func MonitorRedisHealth(ctx context.Context, client *redis.Client, checkInterval time.Duration) {
	ticker := time.NewTicker(checkInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			log.Info("Stopping Redis health monitor.")
			return
		case <-ticker.C:
			pingErr := client.Ping(ctx).Err()
			healthy := pingErr == nil
			mu.Lock()
			switch {
			case !healthy && redisConnected:
				log.Errorf("Redis health check failed: %v", pingErr)
			case healthy && !redisConnected:
				log.Info("Redis reconnected successfully")
			}
			redisConnected = healthy
			if healthy {
				log.Debug("Redis health check succeeded.")
			}
			mu.Unlock()
		}
	}
}
// parseDuration parses durationStr with time.ParseDuration, falling back to a
// default of 30 seconds (with a warning) when the string is invalid.
func parseDuration(durationStr string) time.Duration {
	d, parseErr := time.ParseDuration(durationStr)
	if parseErr != nil {
		log.WithError(parseErr).Warn("Invalid duration format, using default 30s")
		return 30 * time.Second
	}
	return d
}
// runFileCleaner periodically (hourly) walks storeDir and removes regular
// files whose modification time is older than ttl. Runs until ctx is
// cancelled.
func runFileCleaner(ctx context.Context, storeDir string, ttl time.Duration) {
	ticker := time.NewTicker(1 * time.Hour)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			log.Info("Stopping file cleaner.")
			return
		case <-ticker.C:
			// A file expires when its mtime is strictly older than now-ttl.
			cutoff := time.Now().Add(-ttl)
			walkErr := filepath.Walk(storeDir, func(path string, info os.FileInfo, err error) error {
				switch {
				case err != nil:
					return err
				case info.IsDir():
					return nil
				case info.ModTime().Before(cutoff):
					if rmErr := os.Remove(path); rmErr != nil {
						log.WithError(rmErr).Errorf("Failed to remove expired file: %s", path)
					} else {
						log.Infof("Removed expired file: %s", path)
					}
				}
				return nil
			})
			if walkErr != nil {
				log.WithError(walkErr).Error("Error walking store directory for file cleaning")
			}
		}
	}
}
// DeduplicateFiles scans the store directory and removes duplicate files based
// on SHA256 hash. It retains one copy of each unique file and replaces
// duplicates with hard links to the first-seen copy.
//
// Fixes over the previous version:
//   - Duplicates were removed BEFORE os.Link was attempted, so a failed link
//     (e.g. across filesystems) permanently lost the file. Each duplicate is
//     now replaced by staging the hard link under a temporary name and
//     atomically renaming it over the duplicate.
//   - A walk error used to return before close(fileChan)/wg.Wait(), leaking
//     the worker goroutines; the channel is now always closed and drained.
func DeduplicateFiles(storeDir string) error {
	hashMap := make(map[string]string) // map[hash]path of first file seen with that hash
	var mu sync.Mutex
	var wg sync.WaitGroup
	fileChan := make(chan string, 100)

	// Worker pool: hash files concurrently and link duplicates.
	numWorkers := 10
	for i := 0; i < numWorkers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for filePath := range fileChan {
				hash, err := computeFileHash(filePath)
				if err != nil {
					logrus.WithError(err).Errorf("Failed to compute hash for %s", filePath)
					continue
				}
				mu.Lock()
				original, exists := hashMap[hash]
				if !exists {
					hashMap[hash] = filePath
				}
				mu.Unlock()
				if !exists {
					continue // first occurrence of this content; keep it
				}
				// Duplicate found: stage a hard link beside the duplicate,
				// then rename it into place (atomic on POSIX filesystems),
				// so the duplicate's content is never lost on failure.
				tmpLink := filePath + ".dedup.tmp"
				if linkErr := os.Link(original, tmpLink); linkErr != nil {
					logrus.WithError(linkErr).Errorf("Failed to create hard link from %s to %s", original, filePath)
					continue
				}
				if renameErr := os.Rename(tmpLink, filePath); renameErr != nil {
					logrus.WithError(renameErr).Errorf("Failed to replace duplicate file %s", filePath)
					os.Remove(tmpLink) // best-effort cleanup of the staged link
					continue
				}
				logrus.Infof("Removed duplicate %s and linked to %s", filePath, original)
			}
		}()
	}

	// Feed every regular file under storeDir to the workers.
	walkErr := filepath.Walk(storeDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			logrus.WithError(err).Errorf("Error accessing path %s", path)
			return nil // skip unreadable entries rather than aborting the walk
		}
		if !info.Mode().IsRegular() {
			return nil
		}
		fileChan <- path
		return nil
	})

	// Always release the workers, even when the walk itself failed.
	close(fileChan)
	wg.Wait()

	if walkErr != nil {
		return fmt.Errorf("error walking the path %s: %w", storeDir, walkErr)
	}
	return nil
}
// computeFileHash returns the hex-encoded SHA256 digest of the file at
// filePath, streaming the contents so memory use stays constant.
func computeFileHash(filePath string) (string, error) {
	f, err := os.Open(filePath)
	if err != nil {
		return "", fmt.Errorf("unable to open file %s: %w", filePath, err)
	}
	defer f.Close()
	digest := sha256.New()
	if _, copyErr := io.Copy(digest, f); copyErr != nil {
		return "", fmt.Errorf("error hashing file %s: %w", filePath, copyErr)
	}
	sum := digest.Sum(nil)
	return hex.EncodeToString(sum), nil
}
// handleMultipartUpload processes a multipart/form-data upload request:
// it validates the file extension against the configured allow-list, writes
// the upload to a ".tmp" sibling of absFilename, optionally scans it with
// ClamAV, versions any existing file when versioning is enabled, and finally
// renames the temp file into place. On any failure it writes an HTTP error
// response (where appropriate), removes the temp file, and returns the error.
func handleMultipartUpload(w http.ResponseWriter, r *http.Request, absFilename string) error {
	err := r.ParseMultipartForm(32 << 20) // 32MB is the default used by FormFile
	if err != nil {
		log.WithError(err).Error("Failed to parse multipart form")
		http.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
		return err
	}
	file, handler, err := r.FormFile("file")
	if err != nil {
		log.WithError(err).Error("Failed to retrieve file from form data")
		http.Error(w, "Failed to retrieve file from form data", http.StatusBadRequest)
		return err
	}
	defer file.Close()
	// Validate file extension against the configured allow-list.
	if !isExtensionAllowed(handler.Filename) {
		log.WithFields(logrus.Fields{
			"filename": handler.Filename,
			"extension": filepath.Ext(handler.Filename),
		}).Warn("Attempted upload with disallowed file extension")
		http.Error(w, "Disallowed file extension. Allowed extensions are: "+strings.Join(conf.Uploads.AllowedExtensions, ", "), http.StatusForbidden)
		uploadErrorsTotal.Inc()
		return fmt.Errorf("disallowed file extension")
	}
	// Create a temporary file beside the final destination so the rename
	// below stays on the same filesystem.
	// NOTE(review): 0666 is more permissive than the 0644 used for chunked
	// uploads elsewhere in this file — confirm the difference is intended.
	tempFilename := absFilename + ".tmp"
	tempFile, err := os.OpenFile(tempFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
	if err != nil {
		log.WithError(err).Error("Failed to create temporary file")
		http.Error(w, "Failed to create temporary file", http.StatusInternalServerError)
		return err
	}
	defer tempFile.Close()
	// Copy the uploaded file to the temporary file.
	_, err = io.Copy(tempFile, file)
	if err != nil {
		log.WithError(err).Error("Failed to copy uploaded file to temporary file")
		http.Error(w, "Failed to copy uploaded file", http.StatusInternalServerError)
		return err
	}
	// Perform ClamAV scan on the temporary file (only when a scanner client
	// was configured); the temp file is discarded on a positive scan.
	if clamClient != nil {
		err := scanFileWithClamAV(tempFilename)
		if err != nil {
			log.WithFields(logrus.Fields{
				"file": tempFilename,
				"error": err,
			}).Warn("ClamAV detected a virus or scan failed")
			os.Remove(tempFilename)
			uploadErrorsTotal.Inc()
			return err
		}
	}
	// Handle file versioning if enabled: archive the existing file before it
	// is overwritten by the rename below.
	if conf.Versioning.EnableVersioning {
		existing, _ := fileExists(absFilename)
		if existing {
			err := versionFile(absFilename)
			if err != nil {
				log.WithFields(logrus.Fields{
					"file": absFilename,
					"error": err,
				}).Error("Error versioning file")
				os.Remove(tempFilename)
				return err
			}
		}
	}
	// Move the temporary file to the final destination (atomic on the same
	// filesystem, so readers never observe a partially-written file).
	err = os.Rename(tempFilename, absFilename)
	if err != nil {
		log.WithFields(logrus.Fields{
			"temp_file": tempFilename,
			"final_file": absFilename,
			"error": err,
		}).Error("Failed to move file to final destination")
		os.Remove(tempFilename)
		return err
	}
	log.WithFields(logrus.Fields{
		"file": absFilename,
	}).Info("File uploaded and scanned successfully")
	uploadsTotal.Inc()
	return nil
}
// sanitizeFilePath resolves filePath relative to baseDir and ensures the
// result cannot escape the base directory.
//
// The previous check used strings.HasPrefix on the absolute paths, which a
// sibling directory defeats: with baseDir "/data/store", the traversal
// "../store-evil/x" resolves to "/data/store-evil/x" and still passes the
// prefix test. Containment is now decided with filepath.Rel: the resolved
// path is inside baseDir iff the relative path does not start with "..".
// Returns the cleaned absolute path, or an error for paths outside baseDir.
func sanitizeFilePath(baseDir, filePath string) (string, error) {
	// Resolve the absolute base directory.
	absBaseDir, err := filepath.Abs(baseDir)
	if err != nil {
		return "", fmt.Errorf("failed to resolve base directory: %w", err)
	}
	// Join and clean the candidate path ("." and ".." are collapsed here).
	absFilePath, err := filepath.Abs(filepath.Join(absBaseDir, filePath))
	if err != nil {
		return "", fmt.Errorf("failed to resolve file path: %w", err)
	}
	// Containment check that is immune to the sibling-prefix bypass.
	rel, err := filepath.Rel(absBaseDir, absFilePath)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
		return "", fmt.Errorf("invalid file path: %s", filePath)
	}
	return absFilePath, nil
}
// checkStorageSpace verifies that the filesystem holding storagePath has at
// least minFreeBytes available to unprivileged users (Statfs Bavail * Bsize).
func checkStorageSpace(storagePath string, minFreeBytes int64) error {
	var fsStat syscall.Statfs_t
	if err := syscall.Statfs(storagePath, &fsStat); err != nil {
		return fmt.Errorf("failed to get filesystem stats: %w", err)
	}
	// Bavail counts blocks available to non-root; multiply by block size.
	freeBytes := fsStat.Bavail * uint64(fsStat.Bsize)
	if int64(freeBytes) < minFreeBytes {
		return fmt.Errorf("not enough free space: %d bytes available, %d bytes required", freeBytes, minFreeBytes)
	}
	return nil
}
// computeSHA256 returns the hex-encoded SHA256 checksum of the file at
// filePath, streaming the contents through the hasher.
func computeSHA256(filePath string) (string, error) {
	f, err := os.Open(filePath)
	if err != nil {
		return "", fmt.Errorf("failed to open file for checksum: %w", err)
	}
	defer f.Close()
	digest := sha256.New()
	if _, copyErr := io.Copy(digest, f); copyErr != nil {
		return "", fmt.Errorf("failed to compute checksum: %w", copyErr)
	}
	return hex.EncodeToString(digest.Sum(nil)), nil
}
// handleDeduplication handles file deduplication using SHA256 checksum and hard links
func handleDeduplication(ctx context.Context, absFilename string) error {
// Compute checksum of the uploaded file
checksum, err := computeSHA256(absFilename)
if err != nil {
log.Errorf("Failed to compute SHA256 for %s: %v", absFilename, err)
return fmt.Errorf("checksum computation failed: %w", err)
}
log.Debugf("Computed checksum for %s: %s", absFilename, checksum)
// Check Redis for existing checksum
existingPath, err := redisClient.Get(ctx, checksum).Result()
if err != nil && err != redis.Nil {
log.Errorf("Redis error while fetching checksum %s: %v", checksum, err)
return fmt.Errorf("redis error: %w", err)
}
if err != redis.Nil {
// Duplicate found, create hard link
log.Infof("Duplicate detected: %s already exists at %s", absFilename, existingPath)
err = os.Link(existingPath, absFilename)
if err != nil {
log.Errorf("Failed to create hard link from %s to %s: %v", existingPath, absFilename, err)
return fmt.Errorf("failed to create hard link: %w", err)
}
log.Infof("Created hard link from %s to %s", existingPath, absFilename)
return nil
}
// No duplicate found, store checksum in Redis
err = redisClient.Set(ctx, checksum, absFilename, 0).Err()
if err != nil {
log.Errorf("Failed to store checksum %s in Redis: %v", checksum, err)
return fmt.Errorf("failed to store checksum in Redis: %w", err)
}
log.Infof("Stored new file checksum in Redis: %s -> %s", checksum, absFilename)
return nil
fmt.Println("📋 USAGE EXAMPLES:")
fmt.Println(" hmac-file-server --validate-config # Full validation")
fmt.Println(" hmac-file-server --check-security # Security checks only")
fmt.Println(" hmac-file-server --check-performance # Performance checks only")
fmt.Println(" hmac-file-server --check-connectivity # Network checks only")
fmt.Println(" hmac-file-server --validate-quiet # Errors only")
fmt.Println(" hmac-file-server --validate-verbose # Detailed output")
fmt.Println(" hmac-file-server --check-fixable # Auto-fixable issues")
fmt.Println()
}