release: hmac-file-server 3.2
This commit is contained in:
@@ -1,67 +0,0 @@
|
||||
# Server Settings
|
||||
[server]
|
||||
ListenPort = "8080"
|
||||
UnixSocket = false
|
||||
StoreDir = "./testupload"
|
||||
LogLevel = "info"
|
||||
LogFile = "./hmac-file-server.log"
|
||||
MetricsEnabled = true
|
||||
MetricsPort = "9090"
|
||||
FileTTL = "8760h"
|
||||
|
||||
# Workers and Connections
|
||||
[workers]
|
||||
NumWorkers = 2
|
||||
UploadQueueSize = 500
|
||||
|
||||
# Timeout Settings
|
||||
[timeouts]
|
||||
ReadTimeout = "600s"
|
||||
WriteTimeout = "600s"
|
||||
IdleTimeout = "600s"
|
||||
|
||||
# Security Settings
|
||||
[security]
|
||||
Secret = "a-orc-and-a-humans-is-drinking-ale"
|
||||
|
||||
# Versioning Settings
|
||||
[versioning]
|
||||
EnableVersioning = false
|
||||
MaxVersions = 1
|
||||
|
||||
# Upload/Download Settings
|
||||
[uploads]
|
||||
ResumableUploadsEnabled = true
|
||||
ChunkedUploadsEnabled = true
|
||||
ChunkSize = 16777216
|
||||
AllowedExtensions = [
|
||||
# Document formats
|
||||
".txt", ".pdf",
|
||||
|
||||
# Image formats
|
||||
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp",
|
||||
|
||||
# Video formats
|
||||
".wav", ".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".mpg", ".m4v", ".3gp", ".3g2",
|
||||
|
||||
# Audio formats
|
||||
".mp3", ".ogg"
|
||||
]
|
||||
|
||||
# ClamAV Settings
|
||||
[clamav]
|
||||
ClamAVEnabled = false
|
||||
ClamAVSocket = "/var/run/clamav/clamd.ctl"
|
||||
NumScanWorkers = 4
|
||||
|
||||
# Redis Settings
|
||||
[redis]
|
||||
RedisEnabled = false
|
||||
RedisAddr = "localhost:6379"
|
||||
RedisPassword = ""
|
||||
RedisDBIndex = 0
|
||||
RedisHealthCheckInterval = "120s"
|
||||
|
||||
# Deduplication
|
||||
[deduplication]
|
||||
enabled = false
|
||||
294
cmd/server/config_test_scenarios.go
Normal file
294
cmd/server/config_test_scenarios.go
Normal file
@@ -0,0 +1,294 @@
|
||||
// config_test_scenarios.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// ConfigTestScenario represents a test scenario for configuration validation
|
||||
type ConfigTestScenario struct {
|
||||
Name string
|
||||
Config Config
|
||||
ShouldPass bool
|
||||
ExpectedErrors []string
|
||||
ExpectedWarnings []string
|
||||
}
|
||||
|
||||
// GetConfigTestScenarios returns a set of test scenarios for configuration validation
|
||||
func GetConfigTestScenarios() []ConfigTestScenario {
|
||||
baseValidConfig := Config{
|
||||
Server: ServerConfig{
|
||||
ListenAddress: "8080",
|
||||
BindIP: "0.0.0.0",
|
||||
StoragePath: "/tmp/test-storage",
|
||||
MetricsEnabled: true,
|
||||
MetricsPort: "9090",
|
||||
FileTTLEnabled: true,
|
||||
FileTTL: "24h",
|
||||
MinFreeBytes: "1GB",
|
||||
FileNaming: "HMAC",
|
||||
ForceProtocol: "auto",
|
||||
PIDFilePath: "/tmp/test.pid",
|
||||
},
|
||||
Security: SecurityConfig{
|
||||
Secret: "test-secret-key-32-characters",
|
||||
EnableJWT: false,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
File: "/tmp/test.log",
|
||||
MaxSize: 100,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 30,
|
||||
},
|
||||
Timeouts: TimeoutConfig{
|
||||
Read: "30s",
|
||||
Write: "30s",
|
||||
Idle: "60s",
|
||||
},
|
||||
Workers: WorkersConfig{
|
||||
NumWorkers: 4,
|
||||
UploadQueueSize: 50,
|
||||
},
|
||||
Uploads: UploadsConfig{
|
||||
AllowedExtensions: []string{".txt", ".pdf", ".jpg"},
|
||||
ChunkSize: "10MB",
|
||||
},
|
||||
Downloads: DownloadsConfig{
|
||||
AllowedExtensions: []string{".txt", ".pdf", ".jpg"},
|
||||
ChunkSize: "10MB",
|
||||
},
|
||||
}
|
||||
|
||||
return []ConfigTestScenario{
|
||||
{
|
||||
Name: "Valid Basic Configuration",
|
||||
Config: baseValidConfig,
|
||||
ShouldPass: true,
|
||||
},
|
||||
{
|
||||
Name: "Missing Listen Address",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Server.ListenAddress = ""
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: false,
|
||||
ExpectedErrors: []string{"server.listen_address is required"},
|
||||
},
|
||||
{
|
||||
Name: "Invalid Port Number",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Server.ListenAddress = "99999"
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: false,
|
||||
ExpectedErrors: []string{"invalid port number"},
|
||||
},
|
||||
{
|
||||
Name: "Invalid IP Address",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Server.BindIP = "999.999.999.999"
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: false,
|
||||
ExpectedErrors: []string{"invalid IP address format"},
|
||||
},
|
||||
{
|
||||
Name: "Same Port for Server and Metrics",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Server.ListenAddress = "8080"
|
||||
c.Server.MetricsPort = "8080"
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: false,
|
||||
ExpectedErrors: []string{"metrics port cannot be the same as main listen port"},
|
||||
},
|
||||
{
|
||||
Name: "JWT Enabled Without Secret",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Security.EnableJWT = true
|
||||
c.Security.JWTSecret = ""
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: false,
|
||||
ExpectedErrors: []string{"JWT secret is required when JWT is enabled"},
|
||||
},
|
||||
{
|
||||
Name: "Short JWT Secret",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Security.EnableJWT = true
|
||||
c.Security.JWTSecret = "short"
|
||||
c.Security.JWTAlgorithm = "HS256"
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: true,
|
||||
ExpectedWarnings: []string{"JWT secret should be at least 32 characters"},
|
||||
},
|
||||
{
|
||||
Name: "Invalid Log Level",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Logging.Level = "invalid"
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: false,
|
||||
ExpectedErrors: []string{"invalid log level"},
|
||||
},
|
||||
{
|
||||
Name: "Invalid Timeout Format",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Timeouts.Read = "invalid"
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: false,
|
||||
ExpectedErrors: []string{"invalid read timeout format"},
|
||||
},
|
||||
{
|
||||
Name: "Negative Worker Count",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Workers.NumWorkers = -1
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: false,
|
||||
ExpectedErrors: []string{"number of workers must be positive"},
|
||||
},
|
||||
{
|
||||
Name: "Extensions Without Dots",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Uploads.AllowedExtensions = []string{"txt", "pdf"}
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: false,
|
||||
ExpectedErrors: []string{"file extensions must start with a dot"},
|
||||
},
|
||||
{
|
||||
Name: "High Worker Count Warning",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Workers.NumWorkers = 100
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: true,
|
||||
ExpectedWarnings: []string{"very high worker count may impact performance"},
|
||||
},
|
||||
{
|
||||
Name: "Deduplication Without Directory",
|
||||
Config: func() Config {
|
||||
c := baseValidConfig
|
||||
c.Deduplication.Enabled = true
|
||||
c.Deduplication.Directory = ""
|
||||
return c
|
||||
}(),
|
||||
ShouldPass: false,
|
||||
ExpectedErrors: []string{"deduplication directory is required"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RunConfigTests runs all configuration test scenarios
|
||||
func RunConfigTests() {
|
||||
scenarios := GetConfigTestScenarios()
|
||||
passed := 0
|
||||
failed := 0
|
||||
|
||||
fmt.Println("🧪 Running Configuration Test Scenarios")
|
||||
fmt.Println("=======================================")
|
||||
fmt.Println()
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
fmt.Printf("Test %d: %s\n", i+1, scenario.Name)
|
||||
|
||||
// Create temporary directories for testing
|
||||
tempDir := filepath.Join(os.TempDir(), fmt.Sprintf("hmac-test-%d", i))
|
||||
os.MkdirAll(tempDir, 0755)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Update paths in config to use temp directory
|
||||
scenario.Config.Server.StoragePath = filepath.Join(tempDir, "storage")
|
||||
scenario.Config.Logging.File = filepath.Join(tempDir, "test.log")
|
||||
scenario.Config.Server.PIDFilePath = filepath.Join(tempDir, "test.pid")
|
||||
if scenario.Config.Deduplication.Enabled {
|
||||
scenario.Config.Deduplication.Directory = filepath.Join(tempDir, "dedup")
|
||||
}
|
||||
|
||||
result := ValidateConfigComprehensive(&scenario.Config)
|
||||
|
||||
// Check if test passed as expected
|
||||
testPassed := true
|
||||
if scenario.ShouldPass && result.HasErrors() {
|
||||
fmt.Printf(" ❌ Expected to pass but failed with errors:\n")
|
||||
for _, err := range result.Errors {
|
||||
fmt.Printf(" • %s\n", err.Message)
|
||||
}
|
||||
testPassed = false
|
||||
} else if !scenario.ShouldPass && !result.HasErrors() {
|
||||
fmt.Printf(" ❌ Expected to fail but passed\n")
|
||||
testPassed = false
|
||||
} else if !scenario.ShouldPass && result.HasErrors() {
|
||||
// Check if expected errors are present
|
||||
expectedFound := true
|
||||
for _, expectedError := range scenario.ExpectedErrors {
|
||||
found := false
|
||||
for _, actualError := range result.Errors {
|
||||
if contains([]string{actualError.Message}, expectedError) ||
|
||||
contains([]string{actualError.Error()}, expectedError) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
fmt.Printf(" ❌ Expected error not found: %s\n", expectedError)
|
||||
expectedFound = false
|
||||
}
|
||||
}
|
||||
if !expectedFound {
|
||||
testPassed = false
|
||||
}
|
||||
}
|
||||
|
||||
// Check expected warnings
|
||||
if len(scenario.ExpectedWarnings) > 0 {
|
||||
for _, expectedWarning := range scenario.ExpectedWarnings {
|
||||
found := false
|
||||
for _, actualWarning := range result.Warnings {
|
||||
if contains([]string{actualWarning.Message}, expectedWarning) ||
|
||||
contains([]string{actualWarning.Error()}, expectedWarning) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
fmt.Printf(" ⚠️ Expected warning not found: %s\n", expectedWarning)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if testPassed {
|
||||
fmt.Printf(" ✅ Passed\n")
|
||||
passed++
|
||||
} else {
|
||||
failed++
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// Summary
|
||||
fmt.Printf("📊 Test Results: %d passed, %d failed\n", passed, failed)
|
||||
if failed > 0 {
|
||||
fmt.Printf("❌ Some tests failed. Please review the implementation.\n")
|
||||
os.Exit(1)
|
||||
} else {
|
||||
fmt.Printf("✅ All tests passed!\n")
|
||||
}
|
||||
}
|
||||
1131
cmd/server/config_validator.go
Normal file
1131
cmd/server/config_validator.go
Normal file
@@ -0,0 +1,1131 @@
|
||||
// config_validator.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConfigValidationError represents a configuration validation error
|
||||
type ConfigValidationError struct {
|
||||
Field string
|
||||
Value interface{}
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e ConfigValidationError) Error() string {
|
||||
return fmt.Sprintf("config validation error in field '%s': %s (value: %v)", e.Field, e.Message, e.Value)
|
||||
}
|
||||
|
||||
// ConfigValidationResult contains the results of config validation
|
||||
type ConfigValidationResult struct {
|
||||
Errors []ConfigValidationError
|
||||
Warnings []ConfigValidationError
|
||||
Valid bool
|
||||
}
|
||||
|
||||
// AddError adds a validation error
|
||||
func (r *ConfigValidationResult) AddError(field string, value interface{}, message string) {
|
||||
r.Errors = append(r.Errors, ConfigValidationError{Field: field, Value: value, Message: message})
|
||||
r.Valid = false
|
||||
}
|
||||
|
||||
// AddWarning adds a validation warning
|
||||
func (r *ConfigValidationResult) AddWarning(field string, value interface{}, message string) {
|
||||
r.Warnings = append(r.Warnings, ConfigValidationError{Field: field, Value: value, Message: message})
|
||||
}
|
||||
|
||||
// HasErrors returns true if there are validation errors
|
||||
func (r *ConfigValidationResult) HasErrors() bool {
|
||||
return len(r.Errors) > 0
|
||||
}
|
||||
|
||||
// HasWarnings returns true if there are validation warnings
|
||||
func (r *ConfigValidationResult) HasWarnings() bool {
|
||||
return len(r.Warnings) > 0
|
||||
}
|
||||
|
||||
// ValidateConfigComprehensive performs comprehensive configuration validation
|
||||
func ValidateConfigComprehensive(c *Config) *ConfigValidationResult {
|
||||
result := &ConfigValidationResult{Valid: true}
|
||||
|
||||
// Validate each section
|
||||
validateServerConfig(&c.Server, result)
|
||||
validateSecurityConfig(&c.Security, result)
|
||||
validateLoggingConfig(&c.Logging, result)
|
||||
validateTimeoutConfig(&c.Timeouts, result)
|
||||
validateUploadsConfig(&c.Uploads, result)
|
||||
validateDownloadsConfig(&c.Downloads, result)
|
||||
validateClamAVConfig(&c.ClamAV, result)
|
||||
validateRedisConfig(&c.Redis, result)
|
||||
validateWorkersConfig(&c.Workers, result)
|
||||
validateVersioningConfig(&c.Versioning, result)
|
||||
validateDeduplicationConfig(&c.Deduplication, result)
|
||||
validateISOConfig(&c.ISO, result)
|
||||
|
||||
// Cross-section validations
|
||||
validateCrossSection(c, result)
|
||||
|
||||
// Enhanced validations
|
||||
validateSystemResources(result)
|
||||
validateNetworkConnectivity(c, result)
|
||||
validatePerformanceSettings(c, result)
|
||||
validateSecurityHardening(c, result)
|
||||
|
||||
// Check disk space for storage paths
|
||||
if c.Server.StoragePath != "" {
|
||||
checkDiskSpace(c.Server.StoragePath, result)
|
||||
}
|
||||
if c.Deduplication.Enabled && c.Deduplication.Directory != "" {
|
||||
checkDiskSpace(c.Deduplication.Directory, result)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// validateServerConfig validates server configuration
|
||||
func validateServerConfig(server *ServerConfig, result *ConfigValidationResult) {
|
||||
// ListenAddress validation
|
||||
if server.ListenAddress == "" {
|
||||
result.AddError("server.listenport", server.ListenAddress, "listen address/port is required")
|
||||
} else {
|
||||
if !isValidPort(server.ListenAddress) {
|
||||
result.AddError("server.listenport", server.ListenAddress, "invalid port number (must be 1-65535)")
|
||||
}
|
||||
}
|
||||
|
||||
// BindIP validation
|
||||
if server.BindIP != "" {
|
||||
if ip := net.ParseIP(server.BindIP); ip == nil {
|
||||
result.AddError("server.bind_ip", server.BindIP, "invalid IP address format")
|
||||
}
|
||||
}
|
||||
|
||||
// StoragePath validation
|
||||
if server.StoragePath == "" {
|
||||
result.AddError("server.storagepath", server.StoragePath, "storage path is required")
|
||||
} else {
|
||||
if err := validateDirectoryPath(server.StoragePath, true); err != nil {
|
||||
result.AddError("server.storagepath", server.StoragePath, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// MetricsPort validation
|
||||
if server.MetricsEnabled && server.MetricsPort != "" {
|
||||
if !isValidPort(server.MetricsPort) {
|
||||
result.AddError("server.metricsport", server.MetricsPort, "invalid metrics port number")
|
||||
}
|
||||
if server.MetricsPort == server.ListenAddress {
|
||||
result.AddError("server.metricsport", server.MetricsPort, "metrics port cannot be the same as main listen port")
|
||||
}
|
||||
}
|
||||
|
||||
// Size validations
|
||||
if server.MaxUploadSize != "" {
|
||||
if _, err := parseSize(server.MaxUploadSize); err != nil {
|
||||
result.AddError("server.max_upload_size", server.MaxUploadSize, "invalid size format")
|
||||
}
|
||||
}
|
||||
|
||||
if server.MinFreeBytes != "" {
|
||||
if _, err := parseSize(server.MinFreeBytes); err != nil {
|
||||
result.AddError("server.min_free_bytes", server.MinFreeBytes, "invalid size format")
|
||||
}
|
||||
}
|
||||
|
||||
// TTL validation
|
||||
if server.FileTTLEnabled {
|
||||
if server.FileTTL == "" {
|
||||
result.AddError("server.filettl", server.FileTTL, "file TTL is required when TTL is enabled")
|
||||
} else {
|
||||
if _, err := parseTTL(server.FileTTL); err != nil {
|
||||
result.AddError("server.filettl", server.FileTTL, "invalid TTL format")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// File naming validation
|
||||
validFileNaming := []string{"HMAC", "original", "None"}
|
||||
if !contains(validFileNaming, server.FileNaming) {
|
||||
result.AddError("server.file_naming", server.FileNaming, "must be one of: HMAC, original, None")
|
||||
}
|
||||
|
||||
// Protocol validation
|
||||
validProtocols := []string{"ipv4", "ipv6", "auto", ""}
|
||||
if !contains(validProtocols, server.ForceProtocol) {
|
||||
result.AddError("server.force_protocol", server.ForceProtocol, "must be one of: ipv4, ipv6, auto, or empty")
|
||||
}
|
||||
|
||||
// PID file validation
|
||||
if server.PIDFilePath != "" {
|
||||
dir := filepath.Dir(server.PIDFilePath)
|
||||
if err := validateDirectoryPath(dir, false); err != nil {
|
||||
result.AddError("server.pidfilepath", server.PIDFilePath, fmt.Sprintf("PID file directory invalid: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Worker threshold validation
|
||||
if server.EnableDynamicWorkers {
|
||||
if server.WorkerScaleUpThresh <= 0 {
|
||||
result.AddError("server.worker_scale_up_thresh", server.WorkerScaleUpThresh, "must be positive when dynamic workers are enabled")
|
||||
}
|
||||
if server.WorkerScaleDownThresh <= 0 {
|
||||
result.AddError("server.worker_scale_down_thresh", server.WorkerScaleDownThresh, "must be positive when dynamic workers are enabled")
|
||||
}
|
||||
if server.WorkerScaleDownThresh >= server.WorkerScaleUpThresh {
|
||||
result.AddWarning("server.worker_scale_down_thresh", server.WorkerScaleDownThresh, "scale down threshold should be lower than scale up threshold")
|
||||
}
|
||||
}
|
||||
|
||||
// Extensions validation
|
||||
for _, ext := range server.GlobalExtensions {
|
||||
if !strings.HasPrefix(ext, ".") {
|
||||
result.AddError("server.global_extensions", ext, "file extensions must start with a dot")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateSecurityConfig validates security configuration
|
||||
func validateSecurityConfig(security *SecurityConfig, result *ConfigValidationResult) {
|
||||
if security.EnableJWT {
|
||||
// JWT validation
|
||||
if strings.TrimSpace(security.JWTSecret) == "" {
|
||||
result.AddError("security.jwtsecret", security.JWTSecret, "JWT secret is required when JWT is enabled")
|
||||
} else if len(security.JWTSecret) < 32 {
|
||||
result.AddWarning("security.jwtsecret", "[REDACTED]", "JWT secret should be at least 32 characters for security")
|
||||
}
|
||||
|
||||
validAlgorithms := []string{"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}
|
||||
if !contains(validAlgorithms, security.JWTAlgorithm) {
|
||||
result.AddError("security.jwtalgorithm", security.JWTAlgorithm, "unsupported JWT algorithm")
|
||||
}
|
||||
|
||||
if security.JWTExpiration != "" {
|
||||
if _, err := time.ParseDuration(security.JWTExpiration); err != nil {
|
||||
result.AddError("security.jwtexpiration", security.JWTExpiration, "invalid JWT expiration format")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// HMAC validation
|
||||
if strings.TrimSpace(security.Secret) == "" {
|
||||
result.AddError("security.secret", security.Secret, "HMAC secret is required when JWT is disabled")
|
||||
} else if len(security.Secret) < 16 {
|
||||
result.AddWarning("security.secret", "[REDACTED]", "HMAC secret should be at least 16 characters for security")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateLoggingConfig validates logging configuration
|
||||
func validateLoggingConfig(logging *LoggingConfig, result *ConfigValidationResult) {
|
||||
validLevels := []string{"panic", "fatal", "error", "warn", "warning", "info", "debug", "trace"}
|
||||
if !contains(validLevels, strings.ToLower(logging.Level)) {
|
||||
result.AddError("logging.level", logging.Level, "invalid log level")
|
||||
}
|
||||
|
||||
if logging.File != "" {
|
||||
dir := filepath.Dir(logging.File)
|
||||
if err := validateDirectoryPath(dir, false); err != nil {
|
||||
result.AddError("logging.file", logging.File, fmt.Sprintf("log file directory invalid: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
if logging.MaxSize <= 0 {
|
||||
result.AddWarning("logging.max_size", logging.MaxSize, "max size should be positive")
|
||||
}
|
||||
|
||||
if logging.MaxBackups < 0 {
|
||||
result.AddWarning("logging.max_backups", logging.MaxBackups, "max backups should be non-negative")
|
||||
}
|
||||
|
||||
if logging.MaxAge < 0 {
|
||||
result.AddWarning("logging.max_age", logging.MaxAge, "max age should be non-negative")
|
||||
}
|
||||
}
|
||||
|
||||
// validateTimeoutConfig validates timeout configuration
|
||||
func validateTimeoutConfig(timeouts *TimeoutConfig, result *ConfigValidationResult) {
|
||||
if timeouts.Read != "" {
|
||||
if duration, err := time.ParseDuration(timeouts.Read); err != nil {
|
||||
result.AddError("timeouts.read", timeouts.Read, "invalid read timeout format")
|
||||
} else if duration <= 0 {
|
||||
result.AddError("timeouts.read", timeouts.Read, "read timeout must be positive")
|
||||
}
|
||||
}
|
||||
|
||||
if timeouts.Write != "" {
|
||||
if duration, err := time.ParseDuration(timeouts.Write); err != nil {
|
||||
result.AddError("timeouts.write", timeouts.Write, "invalid write timeout format")
|
||||
} else if duration <= 0 {
|
||||
result.AddError("timeouts.write", timeouts.Write, "write timeout must be positive")
|
||||
}
|
||||
}
|
||||
|
||||
if timeouts.Idle != "" {
|
||||
if duration, err := time.ParseDuration(timeouts.Idle); err != nil {
|
||||
result.AddError("timeouts.idle", timeouts.Idle, "invalid idle timeout format")
|
||||
} else if duration <= 0 {
|
||||
result.AddError("timeouts.idle", timeouts.Idle, "idle timeout must be positive")
|
||||
}
|
||||
}
|
||||
|
||||
if timeouts.Shutdown != "" {
|
||||
if duration, err := time.ParseDuration(timeouts.Shutdown); err != nil {
|
||||
result.AddError("timeouts.shutdown", timeouts.Shutdown, "invalid shutdown timeout format")
|
||||
} else if duration <= 0 {
|
||||
result.AddError("timeouts.shutdown", timeouts.Shutdown, "shutdown timeout must be positive")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateUploadsConfig validates uploads configuration
|
||||
func validateUploadsConfig(uploads *UploadsConfig, result *ConfigValidationResult) {
|
||||
// Validate extensions
|
||||
for _, ext := range uploads.AllowedExtensions {
|
||||
if !strings.HasPrefix(ext, ".") {
|
||||
result.AddError("uploads.allowed_extensions", ext, "file extensions must start with a dot")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate chunk size
|
||||
if uploads.ChunkSize != "" {
|
||||
if _, err := parseSize(uploads.ChunkSize); err != nil {
|
||||
result.AddError("uploads.chunk_size", uploads.ChunkSize, "invalid chunk size format")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate resumable age
|
||||
if uploads.MaxResumableAge != "" {
|
||||
if _, err := time.ParseDuration(uploads.MaxResumableAge); err != nil {
|
||||
result.AddError("uploads.max_resumable_age", uploads.MaxResumableAge, "invalid resumable age format")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateDownloadsConfig validates downloads configuration
|
||||
func validateDownloadsConfig(downloads *DownloadsConfig, result *ConfigValidationResult) {
|
||||
// Validate extensions
|
||||
for _, ext := range downloads.AllowedExtensions {
|
||||
if !strings.HasPrefix(ext, ".") {
|
||||
result.AddError("downloads.allowed_extensions", ext, "file extensions must start with a dot")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate chunk size
|
||||
if downloads.ChunkSize != "" {
|
||||
if _, err := parseSize(downloads.ChunkSize); err != nil {
|
||||
result.AddError("downloads.chunk_size", downloads.ChunkSize, "invalid chunk size format")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateClamAVConfig validates ClamAV configuration
|
||||
func validateClamAVConfig(clamav *ClamAVConfig, result *ConfigValidationResult) {
|
||||
if clamav.ClamAVEnabled {
|
||||
if clamav.ClamAVSocket == "" {
|
||||
result.AddWarning("clamav.clamavsocket", clamav.ClamAVSocket, "ClamAV socket path not specified, using default")
|
||||
} else {
|
||||
// Check if socket file exists
|
||||
if _, err := os.Stat(clamav.ClamAVSocket); os.IsNotExist(err) {
|
||||
result.AddWarning("clamav.clamavsocket", clamav.ClamAVSocket, "ClamAV socket file does not exist")
|
||||
}
|
||||
}
|
||||
|
||||
if clamav.NumScanWorkers <= 0 {
|
||||
result.AddError("clamav.numscanworkers", clamav.NumScanWorkers, "number of scan workers must be positive")
|
||||
}
|
||||
|
||||
// Validate scan extensions
|
||||
for _, ext := range clamav.ScanFileExtensions {
|
||||
if !strings.HasPrefix(ext, ".") {
|
||||
result.AddError("clamav.scanfileextensions", ext, "file extensions must start with a dot")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateRedisConfig validates Redis configuration
|
||||
func validateRedisConfig(redis *RedisConfig, result *ConfigValidationResult) {
|
||||
if redis.RedisEnabled {
|
||||
if redis.RedisAddr == "" {
|
||||
result.AddError("redis.redisaddr", redis.RedisAddr, "Redis address is required when Redis is enabled")
|
||||
} else {
|
||||
// Validate address format (host:port)
|
||||
if !isValidHostPort(redis.RedisAddr) {
|
||||
result.AddError("redis.redisaddr", redis.RedisAddr, "invalid Redis address format (should be host:port)")
|
||||
}
|
||||
}
|
||||
|
||||
if redis.RedisDBIndex < 0 || redis.RedisDBIndex > 15 {
|
||||
result.AddWarning("redis.redisdbindex", redis.RedisDBIndex, "Redis DB index is typically 0-15")
|
||||
}
|
||||
|
||||
if redis.RedisHealthCheckInterval != "" {
|
||||
if _, err := time.ParseDuration(redis.RedisHealthCheckInterval); err != nil {
|
||||
result.AddError("redis.redishealthcheckinterval", redis.RedisHealthCheckInterval, "invalid health check interval format")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateWorkersConfig validates workers configuration
|
||||
func validateWorkersConfig(workers *WorkersConfig, result *ConfigValidationResult) {
|
||||
if workers.NumWorkers <= 0 {
|
||||
result.AddError("workers.numworkers", workers.NumWorkers, "number of workers must be positive")
|
||||
}
|
||||
|
||||
if workers.UploadQueueSize <= 0 {
|
||||
result.AddError("workers.uploadqueuesize", workers.UploadQueueSize, "upload queue size must be positive")
|
||||
}
|
||||
|
||||
// Performance recommendations
|
||||
if workers.NumWorkers > 50 {
|
||||
result.AddWarning("workers.numworkers", workers.NumWorkers, "very high worker count may impact performance")
|
||||
}
|
||||
|
||||
if workers.UploadQueueSize > 1000 {
|
||||
result.AddWarning("workers.uploadqueuesize", workers.UploadQueueSize, "very large queue size may impact memory usage")
|
||||
}
|
||||
}
|
||||
|
||||
// validateVersioningConfig validates versioning configuration
|
||||
func validateVersioningConfig(versioning *VersioningConfig, result *ConfigValidationResult) {
|
||||
if versioning.Enabled {
|
||||
if versioning.MaxRevs <= 0 {
|
||||
result.AddError("versioning.maxversions", versioning.MaxRevs, "max versions must be positive when versioning is enabled")
|
||||
}
|
||||
|
||||
validBackends := []string{"filesystem", "database", "s3", ""}
|
||||
if !contains(validBackends, versioning.Backend) {
|
||||
result.AddWarning("versioning.backend", versioning.Backend, "unknown versioning backend")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateDeduplicationConfig validates deduplication configuration
|
||||
func validateDeduplicationConfig(dedup *DeduplicationConfig, result *ConfigValidationResult) {
|
||||
if dedup.Enabled {
|
||||
if dedup.Directory == "" {
|
||||
result.AddError("deduplication.directory", dedup.Directory, "deduplication directory is required when deduplication is enabled")
|
||||
} else {
|
||||
if err := validateDirectoryPath(dedup.Directory, true); err != nil {
|
||||
result.AddError("deduplication.directory", dedup.Directory, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateISOConfig validates ISO configuration
|
||||
func validateISOConfig(iso *ISOConfig, result *ConfigValidationResult) {
|
||||
if iso.Enabled {
|
||||
if iso.MountPoint == "" {
|
||||
result.AddError("iso.mount_point", iso.MountPoint, "mount point is required when ISO is enabled")
|
||||
}
|
||||
|
||||
if iso.Size != "" {
|
||||
if _, err := parseSize(iso.Size); err != nil {
|
||||
result.AddError("iso.size", iso.Size, "invalid ISO size format")
|
||||
}
|
||||
}
|
||||
|
||||
if iso.ContainerFile == "" {
|
||||
result.AddWarning("iso.containerfile", iso.ContainerFile, "container file path not specified")
|
||||
}
|
||||
|
||||
validCharsets := []string{"utf-8", "iso-8859-1", "ascii", ""}
|
||||
if !contains(validCharsets, strings.ToLower(iso.Charset)) {
|
||||
result.AddWarning("iso.charset", iso.Charset, "uncommon charset specified")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateCrossSection performs cross-section validations
|
||||
func validateCrossSection(c *Config, result *ConfigValidationResult) {
|
||||
// Storage path vs deduplication directory conflict
|
||||
if c.Deduplication.Enabled && c.Server.StoragePath == c.Deduplication.Directory {
|
||||
result.AddError("deduplication.directory", c.Deduplication.Directory, "deduplication directory cannot be the same as storage path")
|
||||
}
|
||||
|
||||
// ISO mount point vs storage path conflict
|
||||
if c.ISO.Enabled && c.Server.StoragePath == c.ISO.MountPoint {
|
||||
result.AddWarning("iso.mount_point", c.ISO.MountPoint, "ISO mount point is the same as storage path")
|
||||
}
|
||||
|
||||
// Extension conflicts between uploads and downloads
|
||||
if len(c.Uploads.AllowedExtensions) > 0 && len(c.Downloads.AllowedExtensions) > 0 {
|
||||
uploadExts := make(map[string]bool)
|
||||
for _, ext := range c.Uploads.AllowedExtensions {
|
||||
uploadExts[ext] = true
|
||||
}
|
||||
|
||||
hasCommonExtensions := false
|
||||
for _, ext := range c.Downloads.AllowedExtensions {
|
||||
if uploadExts[ext] {
|
||||
hasCommonExtensions = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasCommonExtensions {
|
||||
result.AddWarning("uploads/downloads.allowed_extensions", "", "no common extensions between uploads and downloads - files may not be downloadable")
|
||||
}
|
||||
}
|
||||
|
||||
// Global extensions override warning
|
||||
if len(c.Server.GlobalExtensions) > 0 && (len(c.Uploads.AllowedExtensions) > 0 || len(c.Downloads.AllowedExtensions) > 0) {
|
||||
result.AddWarning("server.global_extensions", c.Server.GlobalExtensions, "global extensions will override upload/download extension settings")
|
||||
}
|
||||
}
|
||||
|
||||
// Enhanced Security Validation Functions
|
||||
|
||||
// checkSecretStrength analyzes the strength of secrets/passwords
|
||||
func checkSecretStrength(secret string) (score int, issues []string) {
|
||||
if len(secret) == 0 {
|
||||
return 0, []string{"secret is empty"}
|
||||
}
|
||||
|
||||
issues = []string{}
|
||||
score = 0
|
||||
|
||||
// Length scoring
|
||||
if len(secret) >= 32 {
|
||||
score += 3
|
||||
} else if len(secret) >= 16 {
|
||||
score += 2
|
||||
} else if len(secret) >= 8 {
|
||||
score += 1
|
||||
} else {
|
||||
issues = append(issues, "secret is too short")
|
||||
}
|
||||
|
||||
// Character variety scoring
|
||||
hasLower := false
|
||||
hasUpper := false
|
||||
hasDigit := false
|
||||
hasSpecial := false
|
||||
|
||||
for _, char := range secret {
|
||||
switch {
|
||||
case char >= 'a' && char <= 'z':
|
||||
hasLower = true
|
||||
case char >= 'A' && char <= 'Z':
|
||||
hasUpper = true
|
||||
case char >= '0' && char <= '9':
|
||||
hasDigit = true
|
||||
case strings.ContainsRune("!@#$%^&*()_+-=[]{}|;:,.<>?", char):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
varietyCount := 0
|
||||
if hasLower {
|
||||
varietyCount++
|
||||
}
|
||||
if hasUpper {
|
||||
varietyCount++
|
||||
}
|
||||
if hasDigit {
|
||||
varietyCount++
|
||||
}
|
||||
if hasSpecial {
|
||||
varietyCount++
|
||||
}
|
||||
|
||||
score += varietyCount
|
||||
|
||||
if varietyCount < 3 {
|
||||
issues = append(issues, "secret should contain uppercase, lowercase, numbers, and special characters")
|
||||
}
|
||||
|
||||
// Check for common patterns
|
||||
lowerSecret := strings.ToLower(secret)
|
||||
commonWeakPasswords := []string{
|
||||
"password", "123456", "qwerty", "admin", "root", "test", "guest",
|
||||
"secret", "hmac", "server", "default", "changeme", "example",
|
||||
"demo", "temp", "temporary", "fileserver", "upload", "download",
|
||||
}
|
||||
|
||||
for _, weak := range commonWeakPasswords {
|
||||
if strings.Contains(lowerSecret, weak) {
|
||||
issues = append(issues, fmt.Sprintf("contains common weak pattern: %s", weak))
|
||||
score -= 2
|
||||
}
|
||||
}
|
||||
|
||||
// Check for repeated characters
|
||||
if hasRepeatedChars(secret) {
|
||||
issues = append(issues, "contains too many repeated characters")
|
||||
score -= 1
|
||||
}
|
||||
|
||||
// Ensure score doesn't go negative
|
||||
if score < 0 {
|
||||
score = 0
|
||||
}
|
||||
|
||||
return score, issues
|
||||
}
|
||||
|
||||
// hasRepeatedChars checks if a string has excessive repeated characters
|
||||
func hasRepeatedChars(s string) bool {
|
||||
if len(s) < 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i <= len(s)-3; i++ {
|
||||
if s[i] == s[i+1] && s[i+1] == s[i+2] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isDefaultOrExampleSecret checks if a secret appears to be a default/example value
|
||||
func isDefaultOrExampleSecret(secret string) bool {
|
||||
defaultSecrets := []string{
|
||||
"your-secret-key-here",
|
||||
"change-this-secret",
|
||||
"example-secret",
|
||||
"default-secret",
|
||||
"test-secret",
|
||||
"demo-secret",
|
||||
"sample-secret",
|
||||
"placeholder",
|
||||
"PUT_YOUR_SECRET_HERE",
|
||||
"CHANGE_ME",
|
||||
"YOUR_JWT_SECRET",
|
||||
"your-hmac-secret",
|
||||
"supersecret",
|
||||
"secretkey",
|
||||
"myverysecuresecret",
|
||||
}
|
||||
|
||||
lowerSecret := strings.ToLower(strings.TrimSpace(secret))
|
||||
|
||||
for _, defaultSecret := range defaultSecrets {
|
||||
if strings.Contains(lowerSecret, strings.ToLower(defaultSecret)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for obvious patterns
|
||||
if strings.Contains(lowerSecret, "example") ||
|
||||
strings.Contains(lowerSecret, "default") ||
|
||||
strings.Contains(lowerSecret, "change") ||
|
||||
strings.Contains(lowerSecret, "replace") ||
|
||||
strings.Contains(lowerSecret, "todo") ||
|
||||
strings.Contains(lowerSecret, "fixme") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// calculateEntropy calculates the Shannon entropy of a string
|
||||
func calculateEntropy(s string) float64 {
|
||||
if len(s) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Count character frequencies
|
||||
freq := make(map[rune]int)
|
||||
for _, char := range s {
|
||||
freq[char]++
|
||||
}
|
||||
|
||||
// Calculate entropy
|
||||
entropy := 0.0
|
||||
length := float64(len(s))
|
||||
|
||||
for _, count := range freq {
|
||||
if count > 0 {
|
||||
p := float64(count) / length
|
||||
entropy -= p * (float64(count) / length) // Simplified calculation
|
||||
}
|
||||
}
|
||||
|
||||
return entropy
|
||||
}
|
||||
|
||||
// validateSecretSecurity performs comprehensive secret security validation
|
||||
func validateSecretSecurity(fieldName, secret string, result *ConfigValidationResult) {
|
||||
if secret == "" {
|
||||
return // Already handled by other validators
|
||||
}
|
||||
|
||||
// Check for default/example secrets
|
||||
if isDefaultOrExampleSecret(secret) {
|
||||
result.AddError(fieldName, "[REDACTED]", "appears to be a default or example secret - must be changed")
|
||||
return
|
||||
}
|
||||
|
||||
// Check secret strength
|
||||
score, issues := checkSecretStrength(secret)
|
||||
|
||||
if score < 3 {
|
||||
for _, issue := range issues {
|
||||
result.AddError(fieldName, "[REDACTED]", fmt.Sprintf("weak secret: %s", issue))
|
||||
}
|
||||
} else if score < 6 {
|
||||
for _, issue := range issues {
|
||||
result.AddWarning(fieldName, "[REDACTED]", fmt.Sprintf("secret could be stronger: %s", issue))
|
||||
}
|
||||
}
|
||||
|
||||
// Check entropy (simplified)
|
||||
entropy := calculateEntropy(secret)
|
||||
if entropy < 3.0 {
|
||||
result.AddWarning(fieldName, "[REDACTED]", "secret has low entropy - consider using more varied characters")
|
||||
}
|
||||
|
||||
// Length-specific warnings
|
||||
if len(secret) > 256 {
|
||||
result.AddWarning(fieldName, "[REDACTED]", "secret is very long - may impact performance")
|
||||
}
|
||||
}
|
||||
|
||||
// validateSystemResources checks system resource availability
|
||||
func validateSystemResources(result *ConfigValidationResult) {
|
||||
// Check available CPU cores
|
||||
cpuCores := runtime.NumCPU()
|
||||
if cpuCores < 2 {
|
||||
result.AddWarning("system.cpu", cpuCores, "minimum 2 CPU cores recommended for optimal performance")
|
||||
} else if cpuCores < 4 {
|
||||
result.AddWarning("system.cpu", cpuCores, "4+ CPU cores recommended for high-load environments")
|
||||
}
|
||||
|
||||
// Check available memory (basic check through runtime)
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
// Basic memory availability check (simplified version)
|
||||
// This checks current Go heap, but for production we'd want system memory
|
||||
allocMB := float64(memStats.Alloc) / 1024 / 1024
|
||||
if allocMB > 512 {
|
||||
result.AddWarning("system.memory", allocMB, "current memory usage is high - ensure adequate system memory")
|
||||
}
|
||||
|
||||
// Check for potential resource constraints
|
||||
numGoroutines := runtime.NumGoroutine()
|
||||
if numGoroutines > 1000 {
|
||||
result.AddWarning("system.goroutines", numGoroutines, "high goroutine count may indicate resource constraints")
|
||||
}
|
||||
}
|
||||
|
||||
// validateNetworkConnectivity tests network connectivity to external services
|
||||
func validateNetworkConnectivity(c *Config, result *ConfigValidationResult) {
|
||||
// Test Redis connectivity if enabled
|
||||
if c.Redis.RedisEnabled && c.Redis.RedisAddr != "" {
|
||||
if err := testNetworkConnection("tcp", c.Redis.RedisAddr, 5*time.Second); err != nil {
|
||||
result.AddWarning("redis.connectivity", c.Redis.RedisAddr, fmt.Sprintf("cannot connect to Redis: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Test ClamAV connectivity if enabled
|
||||
if c.ClamAV.ClamAVEnabled && c.ClamAV.ClamAVSocket != "" {
|
||||
// For Unix socket, test file existence and permissions
|
||||
if strings.HasPrefix(c.ClamAV.ClamAVSocket, "/") {
|
||||
if stat, err := os.Stat(c.ClamAV.ClamAVSocket); err != nil {
|
||||
result.AddWarning("clamav.connectivity", c.ClamAV.ClamAVSocket, fmt.Sprintf("ClamAV socket not accessible: %v", err))
|
||||
} else if stat.Mode()&os.ModeSocket == 0 {
|
||||
result.AddWarning("clamav.connectivity", c.ClamAV.ClamAVSocket, "specified path is not a socket file")
|
||||
}
|
||||
} else {
|
||||
// Assume TCP connection format
|
||||
if err := testNetworkConnection("tcp", c.ClamAV.ClamAVSocket, 5*time.Second); err != nil {
|
||||
result.AddWarning("clamav.connectivity", c.ClamAV.ClamAVSocket, fmt.Sprintf("cannot connect to ClamAV: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testNetworkConnection attempts to connect to a network address
|
||||
func testNetworkConnection(network, address string, timeout time.Duration) error {
|
||||
conn, err := net.DialTimeout(network, address, timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePerformanceSettings analyzes configuration for performance implications
|
||||
func validatePerformanceSettings(c *Config, result *ConfigValidationResult) {
|
||||
// Check worker configuration against system resources
|
||||
cpuCores := runtime.NumCPU()
|
||||
|
||||
if c.Workers.NumWorkers > cpuCores*4 {
|
||||
result.AddWarning("workers.performance", c.Workers.NumWorkers,
|
||||
fmt.Sprintf("worker count (%d) significantly exceeds CPU cores (%d) - may cause context switching overhead",
|
||||
c.Workers.NumWorkers, cpuCores))
|
||||
}
|
||||
|
||||
// Check ClamAV scan workers
|
||||
if c.ClamAV.ClamAVEnabled && c.ClamAV.NumScanWorkers > cpuCores {
|
||||
result.AddWarning("clamav.performance", c.ClamAV.NumScanWorkers,
|
||||
fmt.Sprintf("scan workers (%d) exceed CPU cores (%d) - may impact scanning performance",
|
||||
c.ClamAV.NumScanWorkers, cpuCores))
|
||||
}
|
||||
|
||||
// Check timeout configurations for performance balance
|
||||
if c.Timeouts.Read != "" {
|
||||
if duration, err := time.ParseDuration(c.Timeouts.Read); err == nil {
|
||||
if duration > 300*time.Second {
|
||||
result.AddWarning("timeouts.performance", c.Timeouts.Read, "very long read timeout may impact server responsiveness")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check upload size vs available resources
|
||||
if c.Server.MaxUploadSize != "" {
|
||||
if size, err := parseSize(c.Server.MaxUploadSize); err == nil {
|
||||
if size > 10*1024*1024*1024 { // 10GB
|
||||
result.AddWarning("server.performance", c.Server.MaxUploadSize, "very large max upload size requires adequate disk space and memory")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for potential memory-intensive configurations
|
||||
if c.Workers.UploadQueueSize > 500 && c.Workers.NumWorkers > 20 {
|
||||
result.AddWarning("workers.memory", fmt.Sprintf("queue:%d workers:%d", c.Workers.UploadQueueSize, c.Workers.NumWorkers),
|
||||
"high queue size with many workers may consume significant memory")
|
||||
}
|
||||
}
|
||||
|
||||
// validateSecurityHardening performs advanced security validation
|
||||
func validateSecurityHardening(c *Config, result *ConfigValidationResult) {
|
||||
// Check for default or weak configurations
|
||||
if c.Security.EnableJWT {
|
||||
if c.Security.JWTSecret == "your-secret-key-here" || c.Security.JWTSecret == "changeme" {
|
||||
result.AddError("security.jwtsecret", "[REDACTED]", "JWT secret appears to be a default value - change immediately")
|
||||
}
|
||||
|
||||
// Check JWT algorithm strength
|
||||
weakAlgorithms := []string{"HS256"} // HS256 is considered less secure than RS256
|
||||
if contains(weakAlgorithms, c.Security.JWTAlgorithm) {
|
||||
result.AddWarning("security.jwtalgorithm", c.Security.JWTAlgorithm, "consider using RS256 or ES256 for enhanced security")
|
||||
}
|
||||
} else {
|
||||
if c.Security.Secret == "your-secret-key-here" || c.Security.Secret == "changeme" || c.Security.Secret == "secret" {
|
||||
result.AddError("security.secret", "[REDACTED]", "HMAC secret appears to be a default value - change immediately")
|
||||
}
|
||||
}
|
||||
|
||||
// Check for insecure bind configurations
|
||||
if c.Server.BindIP == "0.0.0.0" {
|
||||
result.AddWarning("server.bind_ip", c.Server.BindIP, "binding to 0.0.0.0 exposes service to all interfaces - ensure firewall protection")
|
||||
}
|
||||
|
||||
// Check for development/debug settings in production
|
||||
if c.Logging.Level == "debug" || c.Logging.Level == "trace" {
|
||||
result.AddWarning("logging.security", c.Logging.Level, "debug/trace logging may expose sensitive information - use 'info' or 'warn' in production")
|
||||
}
|
||||
|
||||
// Check file permissions for sensitive paths
|
||||
if c.Server.StoragePath != "" {
|
||||
if stat, err := os.Stat(c.Server.StoragePath); err == nil {
|
||||
mode := stat.Mode().Perm()
|
||||
if mode&0077 != 0 { // World or group writable
|
||||
result.AddWarning("server.storagepath.permissions", c.Server.StoragePath, "storage directory permissions allow group/world access - consider restricting to owner-only")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkDiskSpace validates available disk space for storage paths
|
||||
func checkDiskSpace(path string, result *ConfigValidationResult) {
|
||||
if stat, err := os.Stat(path); err == nil && stat.IsDir() {
|
||||
// Get available space (platform-specific implementation would be more robust)
|
||||
// This is a simplified check - in production, use syscall.Statfs on Unix or similar
|
||||
|
||||
// For now, we'll just check if we can write a test file
|
||||
testFile := filepath.Join(path, ".disk_space_test")
|
||||
if f, err := os.Create(testFile); err != nil {
|
||||
result.AddWarning("system.disk_space", path, fmt.Sprintf("cannot write to storage directory: %v", err))
|
||||
} else {
|
||||
f.Close()
|
||||
os.Remove(testFile)
|
||||
|
||||
// Additional check: try to write a larger test file to estimate space
|
||||
const testSize = 1024 * 1024 // 1MB
|
||||
testData := make([]byte, testSize)
|
||||
if f, err := os.Create(testFile); err == nil {
|
||||
if _, err := f.Write(testData); err != nil {
|
||||
result.AddWarning("system.disk_space", path, "low disk space detected - ensure adequate storage for operations")
|
||||
}
|
||||
f.Close()
|
||||
os.Remove(testFile)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isValidPort checks if a string represents a valid port number
|
||||
func isValidPort(port string) bool {
|
||||
if p, err := strconv.Atoi(port); err != nil || p < 1 || p > 65535 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isValidHostPort checks if a string is a valid host:port combination
|
||||
func isValidHostPort(hostPort string) bool {
|
||||
host, port, err := net.SplitHostPort(hostPort)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate port
|
||||
if !isValidPort(port) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate host (can be IP, hostname, or empty for localhost)
|
||||
if host != "" {
|
||||
if ip := net.ParseIP(host); ip == nil {
|
||||
// If not an IP, check if it's a valid hostname
|
||||
matched, _ := regexp.MatchString(`^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$`, host)
|
||||
return matched
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateDirectoryPath validates a directory path
|
||||
func validateDirectoryPath(path string, createIfMissing bool) error {
|
||||
if path == "" {
|
||||
return errors.New("directory path cannot be empty")
|
||||
}
|
||||
|
||||
// Check if path exists
|
||||
if stat, err := os.Stat(path); os.IsNotExist(err) {
|
||||
if createIfMissing {
|
||||
// Try to create the directory
|
||||
if err := os.MkdirAll(path, 0755); err != nil {
|
||||
return fmt.Errorf("cannot create directory: %v", err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("directory does not exist: %s", path)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("cannot access directory: %v", err)
|
||||
} else if !stat.IsDir() {
|
||||
return fmt.Errorf("path exists but is not a directory: %s", path)
|
||||
}
|
||||
|
||||
// Check if directory is writable
|
||||
testFile := filepath.Join(path, ".write_test")
|
||||
if f, err := os.Create(testFile); err != nil {
|
||||
return fmt.Errorf("directory is not writable: %v", err)
|
||||
} else {
|
||||
f.Close()
|
||||
os.Remove(testFile)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// contains checks if a slice contains a string
|
||||
func contains(slice []string, item string) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// PrintValidationResults prints the validation results in a user-friendly format
|
||||
func PrintValidationResults(result *ConfigValidationResult) {
|
||||
if result.HasErrors() {
|
||||
log.Error("❌ Configuration validation failed with the following errors:")
|
||||
for _, err := range result.Errors {
|
||||
log.Errorf(" • %s", err.Error())
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
if result.HasWarnings() {
|
||||
log.Warn("⚠️ Configuration validation completed with warnings:")
|
||||
for _, warn := range result.Warnings {
|
||||
log.Warnf(" • %s", warn.Error())
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
if !result.HasErrors() && !result.HasWarnings() {
|
||||
log.Info("✅ Configuration validation passed successfully!")
|
||||
}
|
||||
}
|
||||
|
||||
// runSpecializedValidation performs targeted validation based on flags
|
||||
func runSpecializedValidation(c *Config, security, performance, connectivity, quiet, verbose, fixable bool) {
|
||||
result := &ConfigValidationResult{Valid: true}
|
||||
|
||||
if verbose {
|
||||
log.Info("Running specialized validation with detailed output...")
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// Run only the requested validation types
|
||||
if security {
|
||||
if verbose {
|
||||
log.Info("🔐 Running security validation checks...")
|
||||
}
|
||||
validateSecurityConfig(&c.Security, result)
|
||||
validateSecurityHardening(c, result)
|
||||
}
|
||||
|
||||
if performance {
|
||||
if verbose {
|
||||
log.Info("⚡ Running performance validation checks...")
|
||||
}
|
||||
validatePerformanceSettings(c, result)
|
||||
validateSystemResources(result)
|
||||
}
|
||||
|
||||
if connectivity {
|
||||
if verbose {
|
||||
log.Info("🌐 Running connectivity validation checks...")
|
||||
}
|
||||
validateNetworkConnectivity(c, result)
|
||||
}
|
||||
|
||||
// If no specific type is requested, run basic validation
|
||||
if !security && !performance && !connectivity {
|
||||
if verbose {
|
||||
log.Info("🔍 Running comprehensive validation...")
|
||||
}
|
||||
result = ValidateConfigComprehensive(c)
|
||||
}
|
||||
|
||||
// Filter results based on flags
|
||||
if fixable {
|
||||
filterFixableIssues(result)
|
||||
}
|
||||
|
||||
// Output results based on verbosity
|
||||
if quiet {
|
||||
printQuietValidationResults(result)
|
||||
} else if verbose {
|
||||
printVerboseValidationResults(result)
|
||||
} else {
|
||||
PrintValidationResults(result)
|
||||
}
|
||||
|
||||
// Exit with appropriate code
|
||||
if result.HasErrors() {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// filterFixableIssues removes non-fixable issues from results
|
||||
func filterFixableIssues(result *ConfigValidationResult) {
|
||||
fixablePatterns := []string{
|
||||
"permissions",
|
||||
"directory",
|
||||
"default value",
|
||||
"debug logging",
|
||||
"size format",
|
||||
"timeout format",
|
||||
"port number",
|
||||
"IP address",
|
||||
}
|
||||
|
||||
var fixableErrors []ConfigValidationError
|
||||
var fixableWarnings []ConfigValidationError
|
||||
|
||||
for _, err := range result.Errors {
|
||||
for _, pattern := range fixablePatterns {
|
||||
if strings.Contains(strings.ToLower(err.Message), pattern) {
|
||||
fixableErrors = append(fixableErrors, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, warn := range result.Warnings {
|
||||
for _, pattern := range fixablePatterns {
|
||||
if strings.Contains(strings.ToLower(warn.Message), pattern) {
|
||||
fixableWarnings = append(fixableWarnings, warn)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result.Errors = fixableErrors
|
||||
result.Warnings = fixableWarnings
|
||||
result.Valid = len(fixableErrors) == 0
|
||||
}
|
||||
|
||||
// printQuietValidationResults prints only errors
|
||||
func printQuietValidationResults(result *ConfigValidationResult) {
|
||||
if result.HasErrors() {
|
||||
for _, err := range result.Errors {
|
||||
fmt.Printf("ERROR: %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// printVerboseValidationResults prints detailed validation information
|
||||
func printVerboseValidationResults(result *ConfigValidationResult) {
|
||||
fmt.Println("📊 DETAILED VALIDATION REPORT")
|
||||
fmt.Println("============================")
|
||||
fmt.Println()
|
||||
|
||||
// System information
|
||||
fmt.Printf("🖥️ System: %d CPU cores, %d goroutines\n", runtime.NumCPU(), runtime.NumGoroutine())
|
||||
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
fmt.Printf("💾 Memory: %.2f MB allocated\n", float64(memStats.Alloc)/1024/1024)
|
||||
fmt.Println()
|
||||
|
||||
// Validation summary
|
||||
fmt.Printf("✅ Checks passed: %d\n", countPassedChecks(result))
|
||||
fmt.Printf("⚠️ Warnings: %d\n", len(result.Warnings))
|
||||
fmt.Printf("❌ Errors: %d\n", len(result.Errors))
|
||||
fmt.Println()
|
||||
|
||||
// Detailed results
|
||||
if result.HasErrors() {
|
||||
fmt.Println("🚨 CONFIGURATION ERRORS:")
|
||||
for i, err := range result.Errors {
|
||||
fmt.Printf(" %d. Field: %s\n", i+1, err.Field)
|
||||
fmt.Printf(" Issue: %s\n", err.Message)
|
||||
fmt.Printf(" Value: %v\n", err.Value)
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
|
||||
if result.HasWarnings() {
|
||||
fmt.Println("⚠️ CONFIGURATION WARNINGS:")
|
||||
for i, warn := range result.Warnings {
|
||||
fmt.Printf(" %d. Field: %s\n", i+1, warn.Field)
|
||||
fmt.Printf(" Issue: %s\n", warn.Message)
|
||||
fmt.Printf(" Value: %v\n", warn.Value)
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
|
||||
if !result.HasErrors() && !result.HasWarnings() {
|
||||
fmt.Println("🎉 All validation checks passed successfully!")
|
||||
}
|
||||
}
|
||||
|
||||
// countPassedChecks estimates the number of successful validation checks
|
||||
func countPassedChecks(result *ConfigValidationResult) int {
|
||||
// Rough estimate: total possible checks minus errors and warnings
|
||||
totalPossibleChecks := 50 // Approximate number of validation checks
|
||||
return totalPossibleChecks - len(result.Errors) - len(result.Warnings)
|
||||
}
|
||||
713
cmd/server/helpers.go
Normal file
713
cmd/server/helpers.go
Normal file
@@ -0,0 +1,713 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/dutchcoders/go-clamd"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/shirou/gopsutil/cpu"
|
||||
"github.com/shirou/gopsutil/mem"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
// WorkerPool represents a pool of workers
|
||||
type WorkerPool struct {
|
||||
workers int
|
||||
taskQueue chan UploadTask
|
||||
scanQueue chan ScanTask
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewWorkerPool creates a new worker pool
|
||||
func NewWorkerPool(workers int, queueSize int) *WorkerPool {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &WorkerPool{
|
||||
workers: workers,
|
||||
taskQueue: make(chan UploadTask, queueSize),
|
||||
scanQueue: make(chan ScanTask, queueSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the worker pool
|
||||
func (wp *WorkerPool) Start() {
|
||||
for i := 0; i < wp.workers; i++ {
|
||||
go wp.worker()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the worker pool
|
||||
func (wp *WorkerPool) Stop() {
|
||||
wp.cancel()
|
||||
close(wp.taskQueue)
|
||||
close(wp.scanQueue)
|
||||
}
|
||||
|
||||
// worker is the worker function
|
||||
func (wp *WorkerPool) worker() {
|
||||
for {
|
||||
select {
|
||||
case <-wp.ctx.Done():
|
||||
return
|
||||
case task := <-wp.taskQueue:
|
||||
if task.Result != nil {
|
||||
task.Result <- nil // Simple implementation
|
||||
}
|
||||
case scanTask := <-wp.scanQueue:
|
||||
err := processScan(scanTask)
|
||||
if scanTask.Result != nil {
|
||||
scanTask.Result <- err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stub for precacheStoragePath
|
||||
func precacheStoragePath(storagePath string) error {
|
||||
// TODO: Implement actual pre-caching logic
|
||||
// This would typically involve walking the storagePath
|
||||
// and loading file information into a cache.
|
||||
log.Infof("Pre-caching for storage path '%s' is a stub and not yet implemented.", storagePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkFreeSpaceWithRetry(path string, retries int, delay time.Duration) error {
|
||||
for i := 0; i < retries; i++ {
|
||||
minFreeBytes, err := parseSize(conf.Server.MinFreeBytes)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid MinFreeBytes: %v", err)
|
||||
}
|
||||
if err := checkStorageSpace(path, minFreeBytes); err != nil {
|
||||
log.Warnf("Free space check failed (attempt %d/%d): %v", i+1, retries, err)
|
||||
time.Sleep(delay)
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("insufficient free space after %d attempts", retries)
|
||||
}
|
||||
|
||||
func handleFileCleanup(conf *Config) {
|
||||
if !conf.Server.FileTTLEnabled {
|
||||
log.Println("File TTL is disabled.")
|
||||
return
|
||||
}
|
||||
|
||||
ttlDuration, err := parseTTL(conf.Server.FileTTL)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid TTL configuration: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("TTL cleanup enabled. Files older than %v will be deleted.", ttlDuration)
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
deleteOldFiles(conf, ttlDuration)
|
||||
}
|
||||
}
|
||||
|
||||
func computeSHA256(ctx context.Context, filePath string) (string, error) {
|
||||
if filePath == "" {
|
||||
return "", fmt.Errorf("file path is empty")
|
||||
}
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open file %s: %w", filePath, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hasher := sha256.New()
|
||||
if _, err := io.Copy(hasher, file); err != nil {
|
||||
return "", fmt.Errorf("failed to hash file: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(hasher.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func handleDeduplication(ctx context.Context, absFilename string) error {
|
||||
checksum, err := computeSHA256(ctx, absFilename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dedupDir := conf.Deduplication.Directory
|
||||
if dedupDir == "" {
|
||||
return fmt.Errorf("deduplication directory not configured")
|
||||
}
|
||||
|
||||
dedupPath := filepath.Join(dedupDir, checksum)
|
||||
if err := os.MkdirAll(dedupPath, os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existingPath := filepath.Join(dedupPath, filepath.Base(absFilename))
|
||||
if _, err := os.Stat(existingPath); err == nil {
|
||||
return os.Link(existingPath, absFilename)
|
||||
}
|
||||
|
||||
if err := os.Rename(absFilename, existingPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.Link(existingPath, absFilename)
|
||||
}
|
||||
|
||||
func handleISOContainer(absFilename string) error {
|
||||
isoPath := filepath.Join(conf.ISO.MountPoint, "container.iso")
|
||||
if err := CreateISOContainer([]string{absFilename}, isoPath, conf.ISO.Size, conf.ISO.Charset); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := MountISOContainer(isoPath, conf.ISO.MountPoint); err != nil {
|
||||
return err
|
||||
}
|
||||
return UnmountISOContainer(conf.ISO.MountPoint)
|
||||
}
|
||||
|
||||
func sanitizeFilePath(baseDir, filePath string) (string, error) {
|
||||
absBaseDir, err := filepath.Abs(baseDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
absFilePath, err := filepath.Abs(filepath.Join(absBaseDir, filePath))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !strings.HasPrefix(absFilePath, absBaseDir) {
|
||||
return "", fmt.Errorf("invalid file path: %s", filePath)
|
||||
}
|
||||
return absFilePath, nil
|
||||
}
|
||||
|
||||
// Stub for formatBytes
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
// Stub for deleteOldFiles
|
||||
func deleteOldFiles(conf *Config, ttlDuration time.Duration) {
|
||||
// TODO: Implement actual file deletion logic based on TTL
|
||||
log.Infof("deleteOldFiles is a stub and not yet implemented. It would check for files older than %v.", ttlDuration)
|
||||
}
|
||||
|
||||
// Stub for CreateISOContainer
|
||||
func CreateISOContainer(files []string, isoPath, size, charset string) error {
|
||||
// TODO: Implement actual ISO container creation logic
|
||||
log.Infof("CreateISOContainer is a stub and not yet implemented. It would create an ISO at %s.", isoPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stub for MountISOContainer
|
||||
func MountISOContainer(isoPath, mountPoint string) error {
|
||||
// TODO: Implement actual ISO container mounting logic
|
||||
log.Infof("MountISOContainer is a stub and not yet implemented. It would mount %s to %s.", isoPath, mountPoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stub for UnmountISOContainer
|
||||
func UnmountISOContainer(mountPoint string) error {
|
||||
// TODO: Implement actual ISO container unmounting logic
|
||||
log.Infof("UnmountISOContainer is a stub and not yet implemented. It would unmount %s.", mountPoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkStorageSpace(storagePath string, minFreeBytes int64) error {
|
||||
var stat syscall.Statfs_t
|
||||
if err := syscall.Statfs(storagePath, &stat); err != nil {
|
||||
return err
|
||||
}
|
||||
availableBytes := stat.Bavail * uint64(stat.Bsize)
|
||||
if int64(availableBytes) < minFreeBytes {
|
||||
return fmt.Errorf("not enough space: available %d < required %d", availableBytes, minFreeBytes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupLogging initializes logging configuration
|
||||
func setupLogging() {
|
||||
log.Infof("DEBUG: Starting setupLogging function")
|
||||
if conf.Logging.File != "" {
|
||||
log.Infof("DEBUG: Setting up file logging to: %s", conf.Logging.File)
|
||||
log.SetOutput(&lumberjack.Logger{
|
||||
Filename: conf.Logging.File,
|
||||
MaxSize: conf.Logging.MaxSize,
|
||||
MaxBackups: conf.Logging.MaxBackups,
|
||||
MaxAge: conf.Logging.MaxAge,
|
||||
Compress: conf.Logging.Compress,
|
||||
})
|
||||
log.Infof("Logging configured to file: %s", conf.Logging.File)
|
||||
}
|
||||
log.Infof("DEBUG: setupLogging function completed")
|
||||
}
|
||||
|
||||
// logSystemInfo logs system information
|
||||
func logSystemInfo() {
|
||||
memStats, err := mem.VirtualMemory()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get memory stats: %v", err)
|
||||
} else {
|
||||
log.Infof("System Memory: Total=%s, Available=%s, Used=%.1f%%",
|
||||
formatBytes(int64(memStats.Total)),
|
||||
formatBytes(int64(memStats.Available)),
|
||||
memStats.UsedPercent)
|
||||
}
|
||||
|
||||
cpuStats, err := cpu.Info()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get CPU stats: %v", err)
|
||||
} else if len(cpuStats) > 0 {
|
||||
log.Infof("CPU: %s, Cores=%d", cpuStats[0].ModelName, len(cpuStats))
|
||||
}
|
||||
|
||||
log.Infof("Go Runtime: Version=%s, NumCPU=%d, NumGoroutine=%d",
|
||||
runtime.Version(), runtime.NumCPU(), runtime.NumGoroutine())
|
||||
}
|
||||
|
||||
// initMetrics initializes Prometheus metrics
|
||||
func initMetrics() {
|
||||
uploadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Name: "upload_duration_seconds",
|
||||
Help: "Duration of upload operations in seconds",
|
||||
})
|
||||
|
||||
uploadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "upload_errors_total",
|
||||
Help: "Total number of upload errors",
|
||||
})
|
||||
|
||||
uploadsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "uploads_total",
|
||||
Help: "Total number of uploads",
|
||||
})
|
||||
|
||||
downloadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Name: "download_duration_seconds",
|
||||
Help: "Duration of download operations in seconds",
|
||||
})
|
||||
|
||||
downloadsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "downloads_total",
|
||||
Help: "Total number of downloads",
|
||||
})
|
||||
|
||||
downloadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "download_errors_total",
|
||||
Help: "Total number of download errors",
|
||||
})
|
||||
|
||||
memoryUsage = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "memory_usage_percent",
|
||||
Help: "Current memory usage percentage",
|
||||
})
|
||||
|
||||
cpuUsage = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "cpu_usage_percent",
|
||||
Help: "Current CPU usage percentage",
|
||||
})
|
||||
|
||||
activeConnections = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "active_connections_total",
|
||||
Help: "Number of active connections",
|
||||
})
|
||||
|
||||
requestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "requests_total",
|
||||
Help: "Total number of requests",
|
||||
}, []string{"method", "status"})
|
||||
|
||||
goroutines = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "goroutines_total",
|
||||
Help: "Number of goroutines",
|
||||
})
|
||||
|
||||
uploadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Name: "upload_size_bytes",
|
||||
Help: "Size of uploaded files in bytes",
|
||||
})
|
||||
|
||||
downloadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Name: "download_size_bytes",
|
||||
Help: "Size of downloaded files in bytes",
|
||||
})
|
||||
|
||||
filesDeduplicatedTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "files_deduplicated_total",
|
||||
Help: "Total number of deduplicated files",
|
||||
})
|
||||
|
||||
deduplicationErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "deduplication_errors_total",
|
||||
Help: "Total number of deduplication errors",
|
||||
})
|
||||
|
||||
isoContainersCreatedTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "iso_containers_created_total",
|
||||
Help: "Total number of ISO containers created",
|
||||
})
|
||||
|
||||
isoCreationErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "iso_creation_errors_total",
|
||||
Help: "Total number of ISO creation errors",
|
||||
})
|
||||
|
||||
isoContainersMountedTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "iso_containers_mounted_total",
|
||||
Help: "Total number of ISO containers mounted",
|
||||
})
|
||||
|
||||
isoMountErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "iso_mount_errors_total",
|
||||
Help: "Total number of ISO mount errors",
|
||||
})
|
||||
|
||||
workerAdjustmentsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "worker_adjustments_total",
|
||||
Help: "Total number of worker adjustments",
|
||||
})
|
||||
|
||||
workerReAdjustmentsTotal = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "worker_readjustments_total",
|
||||
Help: "Total number of worker readjustments",
|
||||
})
|
||||
|
||||
// Register all metrics
|
||||
prometheus.MustRegister(
|
||||
uploadDuration, uploadErrorsTotal, uploadsTotal,
|
||||
downloadDuration, downloadsTotal, downloadErrorsTotal,
|
||||
memoryUsage, cpuUsage, activeConnections, requestsTotal,
|
||||
goroutines, uploadSizeBytes, downloadSizeBytes,
|
||||
filesDeduplicatedTotal, deduplicationErrorsTotal,
|
||||
isoContainersCreatedTotal, isoCreationErrorsTotal,
|
||||
isoContainersMountedTotal, isoMountErrorsTotal,
|
||||
workerAdjustmentsTotal, workerReAdjustmentsTotal,
|
||||
)
|
||||
|
||||
log.Info("Prometheus metrics initialized successfully")
|
||||
}
|
||||
|
||||
// scanFileWithClamAV scans a file using ClamAV
|
||||
func scanFileWithClamAV(filename string) error {
|
||||
if clamClient == nil {
|
||||
return fmt.Errorf("ClamAV client not initialized")
|
||||
}
|
||||
|
||||
result, err := clamClient.ScanFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ClamAV scan failed: %w", err)
|
||||
}
|
||||
|
||||
// Handle the result channel
|
||||
if result != nil {
|
||||
select {
|
||||
case scanResult := <-result:
|
||||
if scanResult != nil && scanResult.Status != "OK" {
|
||||
return fmt.Errorf("virus detected in %s: %s", filename, scanResult.Status)
|
||||
}
|
||||
case <-time.After(30 * time.Second):
|
||||
return fmt.Errorf("ClamAV scan timeout for file: %s", filename)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("File %s passed ClamAV scan", filename)
|
||||
return nil
|
||||
}
|
||||
|
||||
// initClamAV initializes ClamAV client
|
||||
func initClamAV(socketPath string) (*clamd.Clamd, error) {
|
||||
if socketPath == "" {
|
||||
socketPath = "/var/run/clamav/clamd.ctl"
|
||||
}
|
||||
|
||||
client := clamd.NewClamd(socketPath)
|
||||
|
||||
// Test connection
|
||||
err := client.Ping()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to ping ClamAV daemon: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("ClamAV client initialized with socket: %s", socketPath)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// initRedis initializes Redis client
|
||||
func initRedis() {
|
||||
redisClient = redis.NewClient(&redis.Options{
|
||||
Addr: conf.Redis.RedisAddr,
|
||||
Password: conf.Redis.RedisPassword,
|
||||
DB: conf.Redis.RedisDBIndex,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := redisClient.Ping(ctx).Result()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to connect to Redis: %v", err)
|
||||
redisConnected = false
|
||||
} else {
|
||||
log.Info("Redis client initialized successfully")
|
||||
redisConnected = true
|
||||
}
|
||||
}
|
||||
|
||||
// monitorNetwork monitors network events
|
||||
func monitorNetwork(ctx context.Context) {
|
||||
log.Info("Starting network monitoring")
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("Network monitoring stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Simple network monitoring - check interface status
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get network interfaces: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp != 0 && iface.Flags&net.FlagLoopback == 0 {
|
||||
select {
|
||||
case networkEvents <- NetworkEvent{
|
||||
Type: "interface_up",
|
||||
Details: fmt.Sprintf("Interface %s is up", iface.Name),
|
||||
}:
|
||||
default:
|
||||
// Channel full, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleNetworkEvents handles network events
|
||||
func handleNetworkEvents(ctx context.Context) {
|
||||
log.Info("Starting network event handler")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("Network event handler stopped")
|
||||
return
|
||||
case event := <-networkEvents:
|
||||
log.Debugf("Network event: %s - %s", event.Type, event.Details)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateSystemMetrics updates system metrics
|
||||
func updateSystemMetrics(ctx context.Context) {
|
||||
ticker := time.NewTicker(15 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Update memory metrics
|
||||
if memStats, err := mem.VirtualMemory(); err == nil {
|
||||
memoryUsage.Set(memStats.UsedPercent)
|
||||
}
|
||||
|
||||
// Update CPU metrics
|
||||
if cpuPercents, err := cpu.Percent(time.Second, false); err == nil && len(cpuPercents) > 0 {
|
||||
cpuUsage.Set(cpuPercents[0])
|
||||
}
|
||||
|
||||
// Update goroutine count
|
||||
goroutines.Set(float64(runtime.NumGoroutine()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setupRouter sets up HTTP routes
|
||||
func setupRouter() *http.ServeMux {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/upload", handleUpload)
|
||||
mux.HandleFunc("/download/", handleDownload)
|
||||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
if conf.Server.MetricsEnabled {
|
||||
mux.Handle("/metrics", promhttp.Handler())
|
||||
}
|
||||
|
||||
// Catch-all handler for all upload protocols (v, v2, token, v3)
|
||||
// This must be added last as it matches all paths
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle PUT requests for all upload protocols
|
||||
if r.Method == http.MethodPut {
|
||||
query := r.URL.Query()
|
||||
|
||||
// Check if this is a v3 request (mod_http_upload_external)
|
||||
if query.Get("v3") != "" && query.Get("expires") != "" {
|
||||
handleV3Upload(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a legacy protocol request (v, v2, token)
|
||||
if query.Get("v") != "" || query.Get("v2") != "" || query.Get("token") != "" {
|
||||
handleLegacyUpload(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Handle GET/HEAD requests for downloads
|
||||
if r.Method == http.MethodGet || r.Method == http.MethodHead {
|
||||
// Only handle download requests if the path looks like a file
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
if path != "" && !strings.HasSuffix(path, "/") {
|
||||
handleLegacyDownload(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// For all other requests, return 404
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
log.Info("HTTP router configured successfully with full protocol support (v, v2, token, v3)")
|
||||
return mux
|
||||
}
|
||||
|
||||
// setupGracefulShutdown sets up graceful shutdown
|
||||
func setupGracefulShutdown(server *http.Server, cancel context.CancelFunc) {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-sigChan
|
||||
log.Info("Received shutdown signal, initiating graceful shutdown...")
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
|
||||
// Shutdown server with timeout
|
||||
ctx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
log.Errorf("Server shutdown error: %v", err)
|
||||
} else {
|
||||
log.Info("Server shutdown completed")
|
||||
}
|
||||
|
||||
// Clean up PID file
|
||||
if conf.Server.CleanUponExit {
|
||||
removePIDFile(conf.Server.PIDFilePath)
|
||||
}
|
||||
|
||||
// Stop worker pool if it exists
|
||||
if workerPool != nil {
|
||||
workerPool.Stop()
|
||||
log.Info("Worker pool stopped")
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
|
||||
// ProgressWriter wraps an io.Writer to provide upload progress reporting
|
||||
type ProgressWriter struct {
|
||||
dst io.Writer
|
||||
total int64
|
||||
written int64
|
||||
filename string
|
||||
onProgress func(written, total int64, filename string)
|
||||
lastReport time.Time
|
||||
}
|
||||
|
||||
// NewProgressWriter creates a new ProgressWriter
|
||||
func NewProgressWriter(dst io.Writer, total int64, filename string) *ProgressWriter {
|
||||
return &ProgressWriter{
|
||||
dst: dst,
|
||||
total: total,
|
||||
filename: filename,
|
||||
onProgress: func(written, total int64, filename string) {
|
||||
if total > 0 {
|
||||
percentage := float64(written) / float64(total) * 100
|
||||
sizeMiB := float64(written) / (1024 * 1024)
|
||||
totalMiB := float64(total) / (1024 * 1024)
|
||||
log.Infof("Upload progress for %s: %.1f%% (%.1f/%.1f MiB)",
|
||||
filepath.Base(filename), percentage, sizeMiB, totalMiB)
|
||||
}
|
||||
},
|
||||
lastReport: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer interface with progress reporting
|
||||
func (pw *ProgressWriter) Write(p []byte) (int, error) {
|
||||
n, err := pw.dst.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
pw.written += int64(n)
|
||||
|
||||
// Report progress every 30 seconds or every 50MB for large files
|
||||
now := time.Now()
|
||||
shouldReport := false
|
||||
|
||||
if pw.total > 100*1024*1024 { // Files larger than 100MB
|
||||
shouldReport = now.Sub(pw.lastReport) > 30*time.Second ||
|
||||
(pw.written%(50*1024*1024) == 0 && pw.written > 0)
|
||||
} else if pw.total > 10*1024*1024 { // Files larger than 10MB
|
||||
shouldReport = now.Sub(pw.lastReport) > 10*time.Second ||
|
||||
(pw.written%(10*1024*1024) == 0 && pw.written > 0)
|
||||
}
|
||||
|
||||
if shouldReport && pw.onProgress != nil {
|
||||
pw.onProgress(pw.written, pw.total, pw.filename)
|
||||
pw.lastReport = now
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// copyWithProgress copies data from src to dst with progress reporting
|
||||
func copyWithProgress(dst io.Writer, src io.Reader, total int64, filename string) (int64, error) {
|
||||
progressWriter := NewProgressWriter(dst, total, filename)
|
||||
|
||||
// Use a pooled buffer for efficient copying
|
||||
bufPtr := bufferPool.Get().(*[]byte)
|
||||
defer bufferPool.Put(bufPtr)
|
||||
buf := *bufPtr
|
||||
|
||||
return io.CopyBuffer(progressWriter, src, buf)
|
||||
}
|
||||
3373
cmd/server/main.go
3373
cmd/server/main.go
@@ -8,140 +8,285 @@ import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"sync"
|
||||
|
||||
"github.com/dutchcoders/go-clamd" // ClamAV integration
|
||||
"github.com/go-redis/redis/v8" // Redis integration
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/shirou/gopsutil/cpu"
|
||||
"github.com/shirou/gopsutil/disk"
|
||||
"github.com/shirou/gopsutil/host"
|
||||
"github.com/shirou/gopsutil/mem"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Configuration structure
|
||||
// parseSize converts a human-readable size string to bytes
|
||||
func parseSize(sizeStr string) (int64, error) {
|
||||
sizeStr = strings.TrimSpace(sizeStr)
|
||||
if len(sizeStr) < 2 {
|
||||
return 0, fmt.Errorf("invalid size string: %s", sizeStr)
|
||||
}
|
||||
|
||||
unit := strings.ToUpper(sizeStr[len(sizeStr)-2:])
|
||||
valueStr := sizeStr[:len(sizeStr)-2]
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid size value: %v", err)
|
||||
}
|
||||
|
||||
switch unit {
|
||||
case "KB":
|
||||
return int64(value) * 1024, nil
|
||||
case "MB":
|
||||
return int64(value) * 1024 * 1024, nil
|
||||
case "GB":
|
||||
return int64(value) * 1024 * 1024 * 1024, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown size unit: %s", unit)
|
||||
}
|
||||
}
|
||||
|
||||
// parseTTL converts a human-readable TTL string to a time.Duration
|
||||
func parseTTL(ttlStr string) (time.Duration, error) {
|
||||
ttlStr = strings.ToLower(strings.TrimSpace(ttlStr))
|
||||
if ttlStr == "" {
|
||||
return 0, fmt.Errorf("TTL string cannot be empty")
|
||||
}
|
||||
var valueStr string
|
||||
var unit rune
|
||||
for _, r := range ttlStr {
|
||||
if r >= '0' && r <= '9' {
|
||||
valueStr += string(r)
|
||||
} else {
|
||||
unit = r
|
||||
break
|
||||
}
|
||||
}
|
||||
val, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid TTL value: %v", err)
|
||||
}
|
||||
switch unit {
|
||||
case 's':
|
||||
return time.Duration(val) * time.Second, nil
|
||||
case 'm':
|
||||
return time.Duration(val) * time.Minute, nil
|
||||
case 'h':
|
||||
return time.Duration(val) * time.Hour, nil
|
||||
case 'd':
|
||||
return time.Duration(val) * 24 * time.Hour, nil
|
||||
case 'w':
|
||||
return time.Duration(val) * 7 * 24 * time.Hour, nil
|
||||
case 'y':
|
||||
return time.Duration(val) * 365 * 24 * time.Hour, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown TTL unit: %c", unit)
|
||||
}
|
||||
}
|
||||
|
||||
// Configuration structures
|
||||
type ServerConfig struct {
|
||||
ListenPort string `mapstructure:"ListenPort"`
|
||||
UnixSocket bool `mapstructure:"UnixSocket"`
|
||||
StoragePath string `mapstructure:"StoragePath"`
|
||||
LogLevel string `mapstructure:"LogLevel"`
|
||||
LogFile string `mapstructure:"LogFile"`
|
||||
MetricsEnabled bool `mapstructure:"MetricsEnabled"`
|
||||
MetricsPort string `mapstructure:"MetricsPort"`
|
||||
FileTTL string `mapstructure:"FileTTL"`
|
||||
MinFreeBytes int64 `mapstructure:"MinFreeBytes"` // Minimum free bytes required
|
||||
DeduplicationEnabled bool `mapstructure:"DeduplicationEnabled"`
|
||||
}
|
||||
|
||||
type TimeoutConfig struct {
|
||||
ReadTimeout string `mapstructure:"ReadTimeout"`
|
||||
WriteTimeout string `mapstructure:"WriteTimeout"`
|
||||
IdleTimeout string `mapstructure:"IdleTimeout"`
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
Secret string `mapstructure:"Secret"`
|
||||
}
|
||||
|
||||
type VersioningConfig struct {
|
||||
EnableVersioning bool `mapstructure:"EnableVersioning"`
|
||||
MaxVersions int `mapstructure:"MaxVersions"`
|
||||
ListenAddress string `toml:"listenport" mapstructure:"listenport"` // Fixed to match config file field
|
||||
StoragePath string `toml:"storagepath" mapstructure:"storagepath"` // Fixed to match config
|
||||
MetricsEnabled bool `toml:"metricsenabled" mapstructure:"metricsenabled"` // Fixed to match config
|
||||
MetricsPath string `toml:"metrics_path" mapstructure:"metrics_path"`
|
||||
PidFile string `toml:"pid_file" mapstructure:"pid_file"`
|
||||
MaxUploadSize string `toml:"max_upload_size" mapstructure:"max_upload_size"`
|
||||
MaxHeaderBytes int `toml:"max_header_bytes" mapstructure:"max_header_bytes"`
|
||||
CleanupInterval string `toml:"cleanup_interval" mapstructure:"cleanup_interval"`
|
||||
MaxFileAge string `toml:"max_file_age" mapstructure:"max_file_age"`
|
||||
PreCache bool `toml:"pre_cache" mapstructure:"pre_cache"`
|
||||
PreCacheWorkers int `toml:"pre_cache_workers" mapstructure:"pre_cache_workers"`
|
||||
PreCacheInterval string `toml:"pre_cache_interval" mapstructure:"pre_cache_interval"`
|
||||
GlobalExtensions []string `toml:"global_extensions" mapstructure:"global_extensions"`
|
||||
DeduplicationEnabled bool `toml:"deduplication_enabled" mapstructure:"deduplication_enabled"`
|
||||
MinFreeBytes string `toml:"min_free_bytes" mapstructure:"min_free_bytes"`
|
||||
FileNaming string `toml:"file_naming" mapstructure:"file_naming"`
|
||||
ForceProtocol string `toml:"force_protocol" mapstructure:"force_protocol"`
|
||||
EnableDynamicWorkers bool `toml:"enable_dynamic_workers" mapstructure:"enable_dynamic_workers"`
|
||||
WorkerScaleUpThresh int `toml:"worker_scale_up_thresh" mapstructure:"worker_scale_up_thresh"`
|
||||
WorkerScaleDownThresh int `toml:"worker_scale_down_thresh" mapstructure:"worker_scale_down_thresh"`
|
||||
UnixSocket bool `toml:"unixsocket" mapstructure:"unixsocket"` // Added missing field from example/logs
|
||||
MetricsPort string `toml:"metricsport" mapstructure:"metricsport"` // Fixed to match config
|
||||
FileTTL string `toml:"filettl" mapstructure:"filettl"` // Fixed to match config
|
||||
FileTTLEnabled bool `toml:"filettlenabled" mapstructure:"filettlenabled"` // Fixed to match config
|
||||
AutoAdjustWorkers bool `toml:"autoadjustworkers" mapstructure:"autoadjustworkers"` // Fixed to match config
|
||||
NetworkEvents bool `toml:"networkevents" mapstructure:"networkevents"` // Fixed to match config
|
||||
PIDFilePath string `toml:"pidfilepath" mapstructure:"pidfilepath"` // Fixed to match config
|
||||
CleanUponExit bool `toml:"clean_upon_exit" mapstructure:"clean_upon_exit"` // Added missing field
|
||||
PreCaching bool `toml:"precaching" mapstructure:"precaching"` // Fixed to match config
|
||||
BindIP string `toml:"bind_ip" mapstructure:"bind_ip"` // Added missing field
|
||||
}
|
||||
|
||||
type UploadsConfig struct {
|
||||
ResumableUploadsEnabled bool `mapstructure:"ResumableUploadsEnabled"`
|
||||
ChunkedUploadsEnabled bool `mapstructure:"ChunkedUploadsEnabled"`
|
||||
ChunkSize int64 `mapstructure:"ChunkSize"`
|
||||
AllowedExtensions []string `mapstructure:"AllowedExtensions"`
|
||||
AllowedExtensions []string `toml:"allowedextensions" mapstructure:"allowedextensions"`
|
||||
ChunkedUploadsEnabled bool `toml:"chunkeduploadsenabled" mapstructure:"chunkeduploadsenabled"`
|
||||
ChunkSize string `toml:"chunksize" mapstructure:"chunksize"`
|
||||
ResumableUploadsEnabled bool `toml:"resumableuploadsenabled" mapstructure:"resumableuploadsenabled"`
|
||||
MaxResumableAge string `toml:"max_resumable_age" mapstructure:"max_resumable_age"`
|
||||
}
|
||||
|
||||
type DownloadsConfig struct {
|
||||
AllowedExtensions []string `toml:"allowedextensions" mapstructure:"allowedextensions"`
|
||||
ChunkedDownloadsEnabled bool `toml:"chunkeddownloadsenabled" mapstructure:"chunkeddownloadsenabled"`
|
||||
ChunkSize string `toml:"chunksize" mapstructure:"chunksize"`
|
||||
ResumableDownloadsEnabled bool `toml:"resumable_downloads_enabled" mapstructure:"resumable_downloads_enabled"`
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
Secret string `toml:"secret" mapstructure:"secret"`
|
||||
EnableJWT bool `toml:"enablejwt" mapstructure:"enablejwt"` // Added EnableJWT field
|
||||
JWTSecret string `toml:"jwtsecret" mapstructure:"jwtsecret"`
|
||||
JWTAlgorithm string `toml:"jwtalgorithm" mapstructure:"jwtalgorithm"`
|
||||
JWTExpiration string `toml:"jwtexpiration" mapstructure:"jwtexpiration"`
|
||||
}
|
||||
|
||||
type LoggingConfig struct {
|
||||
Level string `mapstructure:"level"`
|
||||
File string `mapstructure:"file"`
|
||||
MaxSize int `mapstructure:"max_size"`
|
||||
MaxBackups int `mapstructure:"max_backups"`
|
||||
MaxAge int `mapstructure:"max_age"`
|
||||
Compress bool `mapstructure:"compress"`
|
||||
}
|
||||
|
||||
type DeduplicationConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Directory string `mapstructure:"directory"`
|
||||
}
|
||||
|
||||
type ISOConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
MountPoint string `mapstructure:"mountpoint"`
|
||||
Size string `mapstructure:"size"`
|
||||
Charset string `mapstructure:"charset"`
|
||||
ContainerFile string `mapstructure:"containerfile"` // Added missing field
|
||||
}
|
||||
|
||||
type TimeoutConfig struct {
|
||||
Read string `mapstructure:"readtimeout" toml:"readtimeout"`
|
||||
Write string `mapstructure:"writetimeout" toml:"writetimeout"`
|
||||
Idle string `mapstructure:"idletimeout" toml:"idletimeout"`
|
||||
Shutdown string `mapstructure:"shutdown" toml:"shutdown"`
|
||||
}
|
||||
|
||||
type VersioningConfig struct {
|
||||
Enabled bool `mapstructure:"enableversioning" toml:"enableversioning"` // Corrected to match example config
|
||||
Backend string `mapstructure:"backend" toml:"backend"`
|
||||
MaxRevs int `mapstructure:"maxversions" toml:"maxversions"` // Corrected to match example config
|
||||
}
|
||||
|
||||
type ClamAVConfig struct {
|
||||
ClamAVEnabled bool `mapstructure:"ClamAVEnabled"`
|
||||
ClamAVSocket string `mapstructure:"ClamAVSocket"`
|
||||
NumScanWorkers int `mapstructure:"NumScanWorkers"`
|
||||
ClamAVEnabled bool `mapstructure:"clamavenabled"`
|
||||
ClamAVSocket string `mapstructure:"clamavsocket"`
|
||||
NumScanWorkers int `mapstructure:"numscanworkers"`
|
||||
ScanFileExtensions []string `mapstructure:"scanfileextensions"`
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
RedisEnabled bool `mapstructure:"RedisEnabled"`
|
||||
RedisDBIndex int `mapstructure:"RedisDBIndex"`
|
||||
RedisAddr string `mapstructure:"RedisAddr"`
|
||||
RedisPassword string `mapstructure:"RedisPassword"`
|
||||
RedisHealthCheckInterval string `mapstructure:"RedisHealthCheckInterval"`
|
||||
RedisEnabled bool `mapstructure:"redisenabled"`
|
||||
RedisDBIndex int `mapstructure:"redisdbindex"`
|
||||
RedisAddr string `mapstructure:"redisaddr"`
|
||||
RedisPassword string `mapstructure:"redispassword"`
|
||||
RedisHealthCheckInterval string `mapstructure:"redishealthcheckinterval"`
|
||||
}
|
||||
|
||||
type WorkersConfig struct {
|
||||
NumWorkers int `mapstructure:"NumWorkers"`
|
||||
UploadQueueSize int `mapstructure:"UploadQueueSize"`
|
||||
NumWorkers int `mapstructure:"numworkers"`
|
||||
UploadQueueSize int `mapstructure:"uploadqueuesize"`
|
||||
}
|
||||
|
||||
type FileConfig struct {
|
||||
FileRevision int `mapstructure:"FileRevision"`
|
||||
}
|
||||
|
||||
type BuildConfig struct {
|
||||
Version string `mapstructure:"version"` // Updated version
|
||||
}
|
||||
|
||||
// This is the main Config struct to be used
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Timeouts TimeoutConfig `mapstructure:"timeouts"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Versioning VersioningConfig `mapstructure:"versioning"`
|
||||
Uploads UploadsConfig `mapstructure:"uploads"`
|
||||
ClamAV ClamAVConfig `mapstructure:"clamav"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Workers WorkersConfig `mapstructure:"workers"`
|
||||
File FileConfig `mapstructure:"file"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Logging LoggingConfig `mapstructure:"logging"`
|
||||
Deduplication DeduplicationConfig `mapstructure:"deduplication"` // Added
|
||||
ISO ISOConfig `mapstructure:"iso"` // Added
|
||||
Timeouts TimeoutConfig `mapstructure:"timeouts"` // Added
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Versioning VersioningConfig `mapstructure:"versioning"` // Added
|
||||
Uploads UploadsConfig `mapstructure:"uploads"`
|
||||
Downloads DownloadsConfig `mapstructure:"downloads"`
|
||||
ClamAV ClamAVConfig `mapstructure:"clamav"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Workers WorkersConfig `mapstructure:"workers"`
|
||||
File FileConfig `mapstructure:"file"`
|
||||
Build BuildConfig `mapstructure:"build"`
|
||||
}
|
||||
|
||||
// UploadTask represents a file upload task
|
||||
type UploadTask struct {
|
||||
AbsFilename string
|
||||
Request *http.Request
|
||||
Result chan error
|
||||
}
|
||||
|
||||
// ScanTask represents a file scan task
|
||||
type ScanTask struct {
|
||||
AbsFilename string
|
||||
Result chan error
|
||||
}
|
||||
|
||||
// NetworkEvent represents a network-related event
|
||||
type NetworkEvent struct {
|
||||
Type string
|
||||
Details string
|
||||
}
|
||||
|
||||
var (
|
||||
conf Config
|
||||
versionString string = "v2.0-dev"
|
||||
log = logrus.New()
|
||||
uploadQueue chan UploadTask
|
||||
networkEvents chan NetworkEvent
|
||||
fileInfoCache *cache.Cache
|
||||
clamClient *clamd.Clamd // Added for ClamAV integration
|
||||
redisClient *redis.Client // Redis client
|
||||
redisConnected bool // Redis connection status
|
||||
mu sync.RWMutex
|
||||
// Add a new field to store the creation date of files
|
||||
type FileMetadata struct {
|
||||
CreationDate time.Time
|
||||
}
|
||||
|
||||
// processScan processes a scan task
|
||||
func processScan(task ScanTask) error {
|
||||
log.Infof("Started processing scan for file: %s", task.AbsFilename)
|
||||
semaphore <- struct{}{}
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
err := scanFileWithClamAV(task.AbsFilename)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{"file": task.AbsFilename, "error": err}).Error("Failed to scan file")
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Finished processing scan for file: %s", task.AbsFilename)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
conf Config
|
||||
versionString string
|
||||
log = logrus.New()
|
||||
fileInfoCache *cache.Cache
|
||||
fileMetadataCache *cache.Cache
|
||||
clamClient *clamd.Clamd
|
||||
redisClient *redis.Client
|
||||
redisConnected bool
|
||||
confMutex sync.RWMutex // Protects the global 'conf' variable and related critical sections.
|
||||
// Use RLock() for reading, Lock() for writing.
|
||||
|
||||
// Prometheus metrics
|
||||
uploadDuration prometheus.Histogram
|
||||
uploadErrorsTotal prometheus.Counter
|
||||
uploadsTotal prometheus.Counter
|
||||
@@ -156,1768 +301,1692 @@ var (
|
||||
uploadSizeBytes prometheus.Histogram
|
||||
downloadSizeBytes prometheus.Histogram
|
||||
|
||||
// Constants for worker pool
|
||||
MinWorkers = 5 // Increased from 10 to 20 for better concurrency
|
||||
UploadQueueSize = 10000 // Increased from 5000 to 10000
|
||||
filesDeduplicatedTotal prometheus.Counter
|
||||
deduplicationErrorsTotal prometheus.Counter
|
||||
isoContainersCreatedTotal prometheus.Counter
|
||||
isoCreationErrorsTotal prometheus.Counter
|
||||
isoContainersMountedTotal prometheus.Counter
|
||||
isoMountErrorsTotal prometheus.Counter
|
||||
|
||||
// Channels
|
||||
scanQueue chan ScanTask
|
||||
ScanWorkers = 5 // Number of ClamAV scan workers
|
||||
workerPool *WorkerPool
|
||||
networkEvents chan NetworkEvent
|
||||
|
||||
workerAdjustmentsTotal prometheus.Counter
|
||||
workerReAdjustmentsTotal prometheus.Counter
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Set default configuration values
|
||||
setDefaults()
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
const maxConcurrentOperations = 10
|
||||
|
||||
var semaphore = make(chan struct{}, maxConcurrentOperations)
|
||||
|
||||
var logMessages []string
|
||||
var logMu sync.Mutex
|
||||
|
||||
func flushLogMessages() {
|
||||
logMu.Lock()
|
||||
defer logMu.Unlock()
|
||||
for _, msg := range logMessages {
|
||||
log.Info(msg)
|
||||
}
|
||||
logMessages = []string{}
|
||||
}
|
||||
|
||||
// writePIDFile writes the current process ID to the specified pid file
|
||||
func writePIDFile(pidPath string) error {
|
||||
pid := os.Getpid()
|
||||
pidStr := strconv.Itoa(pid)
|
||||
err := os.WriteFile(pidPath, []byte(pidStr), 0644)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write PID file: %v", err) // Improved error logging
|
||||
return err
|
||||
}
|
||||
log.Infof("PID %d written to %s", pid, pidPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removePIDFile removes the PID file
|
||||
func removePIDFile(pidPath string) {
|
||||
err := os.Remove(pidPath)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to remove PID file: %v", err) // Improved error logging
|
||||
} else {
|
||||
log.Infof("PID file %s removed successfully", pidPath)
|
||||
}
|
||||
}
|
||||
|
||||
// createAndMountISO creates an ISO container and mounts it to the specified mount point
|
||||
func createAndMountISO(size, mountpoint, charset string) error {
|
||||
isoPath := conf.ISO.ContainerFile
|
||||
|
||||
// Create an empty ISO file
|
||||
cmd := exec.Command("dd", "if=/dev/zero", fmt.Sprintf("of=%s", isoPath), fmt.Sprintf("bs=%s", size), "count=1")
|
||||
if err := cmd.Run(); err != nil {
|
||||
isoCreationErrorsTotal.Inc()
|
||||
return fmt.Errorf("failed to create ISO file: %w", err)
|
||||
}
|
||||
|
||||
// Format the ISO file with a filesystem
|
||||
cmd = exec.Command("mkfs", "-t", "iso9660", "-input-charset", charset, isoPath)
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to format ISO file: %w", err)
|
||||
}
|
||||
|
||||
// Create the mount point directory if it doesn't exist
|
||||
if err := os.MkdirAll(mountpoint, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("failed to create mount point: %w", err)
|
||||
}
|
||||
|
||||
// Mount the ISO file
|
||||
cmd = exec.Command("mount", "-o", "loop", isoPath, mountpoint)
|
||||
if err := cmd.Run(); err != nil {
|
||||
isoMountErrorsTotal.Inc()
|
||||
return fmt.Errorf("failed to mount ISO file: %w", err)
|
||||
}
|
||||
|
||||
isoContainersCreatedTotal.Inc()
|
||||
isoContainersMountedTotal.Inc()
|
||||
return nil
|
||||
}
|
||||
|
||||
func initializeNetworkProtocol(forceProtocol string) (*net.Dialer, error) {
|
||||
switch forceProtocol {
|
||||
case "ipv4":
|
||||
return &net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
DualStack: false,
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
if network == "tcp6" {
|
||||
return fmt.Errorf("IPv6 is disabled by forceprotocol setting")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
case "ipv6":
|
||||
return &net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
DualStack: false,
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
if network == "tcp4" {
|
||||
return fmt.Errorf("IPv4 is disabled by forceprotocol setting")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
case "auto":
|
||||
return &net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
DualStack: true,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid forceprotocol value: %s", forceProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
var dualStackClient *http.Client
|
||||
|
||||
func main() {
|
||||
setDefaults() // Call setDefaults before parsing flags or reading config
|
||||
|
||||
// Flags for configuration file
|
||||
var configFile string
|
||||
flag.StringVar(&configFile, "config", "./config.toml", "Path to configuration file \"config.toml\".")
|
||||
var genConfig bool
|
||||
var genConfigPath string
|
||||
var validateOnly bool
|
||||
var runConfigTests bool
|
||||
var validateQuiet bool
|
||||
var validateVerbose bool
|
||||
var validateFixable bool
|
||||
var validateSecurity bool
|
||||
var validatePerformance bool
|
||||
var validateConnectivity bool
|
||||
var listValidationChecks bool
|
||||
var showVersion bool
|
||||
|
||||
flag.BoolVar(&genConfig, "genconfig", false, "Print example configuration and exit.")
|
||||
flag.StringVar(&genConfigPath, "genconfig-path", "", "Write example configuration to the given file and exit.")
|
||||
flag.BoolVar(&validateOnly, "validate-config", false, "Validate configuration and exit without starting server.")
|
||||
flag.BoolVar(&runConfigTests, "test-config", false, "Run configuration validation test scenarios and exit.")
|
||||
flag.BoolVar(&validateQuiet, "validate-quiet", false, "Only show errors during validation (suppress warnings and info).")
|
||||
flag.BoolVar(&validateVerbose, "validate-verbose", false, "Show detailed validation information including system checks.")
|
||||
flag.BoolVar(&validateFixable, "check-fixable", false, "Only show validation issues that can be automatically fixed.")
|
||||
flag.BoolVar(&validateSecurity, "check-security", false, "Run only security-related validation checks.")
|
||||
flag.BoolVar(&validatePerformance, "check-performance", false, "Run only performance-related validation checks.")
|
||||
flag.BoolVar(&validateConnectivity, "check-connectivity", false, "Run only network connectivity validation checks.")
|
||||
flag.BoolVar(&listValidationChecks, "list-checks", false, "List all available validation checks and exit.")
|
||||
flag.BoolVar(&showVersion, "version", false, "Show version information and exit.")
|
||||
flag.Parse()
|
||||
|
||||
// Load configuration
|
||||
if showVersion {
|
||||
fmt.Printf("HMAC File Server v3.2\n")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if listValidationChecks {
|
||||
printValidationChecks()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if genConfig {
|
||||
printExampleConfig()
|
||||
os.Exit(0)
|
||||
}
|
||||
if genConfigPath != "" {
|
||||
f, err := os.Create(genConfigPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer f.Close()
|
||||
w := bufio.NewWriter(f)
|
||||
fmt.Fprint(w, getExampleConfigString())
|
||||
w.Flush()
|
||||
fmt.Printf("Example config written to %s\n", genConfigPath)
|
||||
os.Exit(0)
|
||||
}
|
||||
if runConfigTests {
|
||||
RunConfigTests()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Initialize Viper
|
||||
viper.SetConfigType("toml")
|
||||
|
||||
// Set default config path
|
||||
defaultConfigPath := "/etc/hmac-file-server/config.toml"
|
||||
|
||||
// Attempt to load the default config
|
||||
viper.SetConfigFile(defaultConfigPath)
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
// If default config not found, fallback to parent directory
|
||||
parentDirConfig := "../config.toml"
|
||||
viper.SetConfigFile(parentDirConfig)
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
// If still not found and -config is provided, use it
|
||||
if configFile != "" {
|
||||
viper.SetConfigFile(configFile)
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
fmt.Printf("Error loading config file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
fmt.Println("No configuration file found. Please create a config file with the following content:")
|
||||
printExampleConfig()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := readConfig(configFile, &conf)
|
||||
if err != nil {
|
||||
log.Fatalf("Error reading config: %v", err) // Fatal: application cannot proceed
|
||||
log.Fatalf("Failed to load configuration: %v\nPlease ensure your config.toml is present at one of the following paths:\n%v", err, []string{
|
||||
"/etc/hmac-file-server/config.toml",
|
||||
"../config.toml",
|
||||
"./config.toml",
|
||||
})
|
||||
}
|
||||
log.Info("Configuration loaded successfully.")
|
||||
|
||||
// Initialize file info cache
|
||||
fileInfoCache = cache.New(5*time.Minute, 10*time.Minute)
|
||||
|
||||
// Create store directory
|
||||
err = os.MkdirAll(conf.Server.StoragePath, os.ModePerm)
|
||||
err = validateConfig(&conf)
|
||||
if err != nil {
|
||||
log.Fatalf("Error creating store directory: %v", err)
|
||||
log.Fatalf("Configuration validation failed: %v", err)
|
||||
}
|
||||
log.WithField("directory", conf.Server.StoragePath).Info("Store directory is ready")
|
||||
log.Info("Configuration validated successfully.")
|
||||
|
||||
// Perform comprehensive configuration validation
|
||||
validationResult := ValidateConfigComprehensive(&conf)
|
||||
PrintValidationResults(validationResult)
|
||||
|
||||
if validationResult.HasErrors() {
|
||||
log.Fatal("Cannot start server due to configuration errors. Please fix the above issues and try again.")
|
||||
}
|
||||
|
||||
// Handle specialized validation flags
|
||||
if validateSecurity || validatePerformance || validateConnectivity || validateQuiet || validateVerbose || validateFixable {
|
||||
runSpecializedValidation(&conf, validateSecurity, validatePerformance, validateConnectivity, validateQuiet, validateVerbose, validateFixable)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// If only validation was requested, exit now
|
||||
if validateOnly {
|
||||
if validationResult.HasErrors() {
|
||||
log.Error("Configuration validation failed with errors. Review the errors above.")
|
||||
os.Exit(1)
|
||||
} else if validationResult.HasWarnings() {
|
||||
log.Info("Configuration is valid but has warnings. Review the warnings above.")
|
||||
os.Exit(0)
|
||||
} else {
|
||||
log.Info("Configuration validation completed successfully!")
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
|
||||
// Set log level based on configuration
|
||||
level, err := logrus.ParseLevel(conf.Logging.Level)
|
||||
if err != nil {
|
||||
log.Warnf("Invalid log level '%s', defaulting to 'info'", conf.Logging.Level)
|
||||
level = logrus.InfoLevel
|
||||
}
|
||||
log.SetLevel(level)
|
||||
log.Infof("Log level set to: %s", level.String())
|
||||
|
||||
// Log configuration settings using [logging] section
|
||||
log.Infof("Server ListenAddress: %s", conf.Server.ListenAddress) // Corrected field name
|
||||
log.Infof("Server UnixSocket: %v", conf.Server.UnixSocket)
|
||||
log.Infof("Server StoragePath: %s", conf.Server.StoragePath)
|
||||
log.Infof("Logging Level: %s", conf.Logging.Level)
|
||||
log.Infof("Logging File: %s", conf.Logging.File)
|
||||
log.Infof("Server MetricsEnabled: %v", conf.Server.MetricsEnabled)
|
||||
log.Infof("Server MetricsPort: %s", conf.Server.MetricsPort) // Corrected field name
|
||||
log.Infof("Server FileTTL: %s", conf.Server.FileTTL) // Corrected field name
|
||||
log.Infof("Server MinFreeBytes: %s", conf.Server.MinFreeBytes)
|
||||
log.Infof("Server AutoAdjustWorkers: %v", conf.Server.AutoAdjustWorkers) // Corrected field name
|
||||
log.Infof("Server NetworkEvents: %v", conf.Server.NetworkEvents) // Corrected field name
|
||||
log.Infof("Server PIDFilePath: %s", conf.Server.PIDFilePath) // Corrected field name
|
||||
log.Infof("Server CleanUponExit: %v", conf.Server.CleanUponExit) // Corrected field name
|
||||
log.Infof("Server PreCaching: %v", conf.Server.PreCaching) // Corrected field name
|
||||
log.Infof("Server FileTTLEnabled: %v", conf.Server.FileTTLEnabled) // Corrected field name
|
||||
log.Infof("Server DeduplicationEnabled: %v", conf.Server.DeduplicationEnabled)
|
||||
log.Infof("Server BindIP: %s", conf.Server.BindIP) // Corrected field name
|
||||
log.Infof("Server FileNaming: %s", conf.Server.FileNaming)
|
||||
log.Infof("Server ForceProtocol: %s", conf.Server.ForceProtocol)
|
||||
|
||||
err = writePIDFile(conf.Server.PIDFilePath) // Corrected field name
|
||||
if err != nil {
|
||||
log.Fatalf("Error writing PID file: %v", err)
|
||||
}
|
||||
log.Debug("DEBUG: PID file written successfully")
|
||||
|
||||
log.Debugf("DEBUG: Config logging file: %s", conf.Logging.File)
|
||||
|
||||
// Setup logging
|
||||
setupLogging()
|
||||
log.Debug("DEBUG: Logging setup completed")
|
||||
|
||||
// Log system information
|
||||
logSystemInfo()
|
||||
log.Debug("DEBUG: System info logged")
|
||||
|
||||
// Initialize Prometheus metrics
|
||||
// Initialize metrics before using any Prometheus counters
|
||||
initMetrics()
|
||||
log.Info("Prometheus metrics initialized.")
|
||||
log.Debug("DEBUG: Metrics initialized")
|
||||
|
||||
// Initialize upload and scan queues
|
||||
uploadQueue = make(chan UploadTask, conf.Workers.UploadQueueSize)
|
||||
scanQueue = make(chan ScanTask, conf.Workers.UploadQueueSize)
|
||||
networkEvents = make(chan NetworkEvent, 100)
|
||||
log.Info("Upload, scan, and network event channels initialized.")
|
||||
initializeWorkerSettings(&conf.Server, &conf.Workers, &conf.ClamAV)
|
||||
log.Debug("DEBUG: Worker settings initialized")
|
||||
|
||||
// Context for goroutines
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Start network monitoring
|
||||
go monitorNetwork(ctx)
|
||||
go handleNetworkEvents(ctx)
|
||||
|
||||
// Update system metrics
|
||||
go updateSystemMetrics(ctx)
|
||||
|
||||
// Initialize ClamAV client if enabled
|
||||
if conf.ClamAV.ClamAVEnabled {
|
||||
clamClient, err = initClamAV(conf.ClamAV.ClamAVSocket)
|
||||
if conf.ISO.Enabled {
|
||||
err := createAndMountISO(conf.ISO.Size, conf.ISO.MountPoint, conf.ISO.Charset)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"error": err.Error(),
|
||||
}).Warn("ClamAV client initialization failed. Continuing without ClamAV.")
|
||||
} else {
|
||||
log.Info("ClamAV client initialized successfully.")
|
||||
log.Fatalf("Failed to create and mount ISO container: %v", err)
|
||||
}
|
||||
log.Infof("ISO container mounted at %s", conf.ISO.MountPoint)
|
||||
}
|
||||
|
||||
// Initialize Redis client if enabled
|
||||
if conf.Redis.RedisEnabled {
|
||||
initRedis()
|
||||
// Set storage path to ISO mount point if ISO is enabled
|
||||
storagePath := conf.Server.StoragePath
|
||||
if conf.ISO.Enabled {
|
||||
storagePath = conf.ISO.MountPoint
|
||||
}
|
||||
|
||||
// Redis Initialization
|
||||
initRedis()
|
||||
log.Info("Redis client initialized and connected successfully.")
|
||||
fileInfoCache = cache.New(5*time.Minute, 10*time.Minute)
|
||||
fileMetadataCache = cache.New(5*time.Minute, 10*time.Minute)
|
||||
|
||||
// ClamAV Initialization
|
||||
if conf.ClamAV.ClamAVEnabled {
|
||||
clamClient, err = initClamAV(conf.ClamAV.ClamAVSocket)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"error": err.Error(),
|
||||
}).Warn("ClamAV client initialization failed. Continuing without ClamAV.")
|
||||
} else {
|
||||
log.Info("ClamAV client initialized successfully.")
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize worker pools
|
||||
initializeUploadWorkerPool(ctx)
|
||||
if conf.ClamAV.ClamAVEnabled && clamClient != nil {
|
||||
initializeScanWorkerPool(ctx)
|
||||
}
|
||||
|
||||
// Start Redis health monitor if Redis is enabled
|
||||
if conf.Redis.RedisEnabled && redisClient != nil {
|
||||
go MonitorRedisHealth(ctx, redisClient, parseDuration(conf.Redis.RedisHealthCheckInterval))
|
||||
}
|
||||
|
||||
// Setup router
|
||||
router := setupRouter()
|
||||
|
||||
// Start file cleaner
|
||||
fileTTL, err := time.ParseDuration(conf.Server.FileTTL)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid FileTTL: %v", err)
|
||||
}
|
||||
go runFileCleaner(ctx, conf.Server.StoragePath, fileTTL)
|
||||
|
||||
// Parse timeout durations
|
||||
readTimeout, err := time.ParseDuration(conf.Timeouts.ReadTimeout)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid ReadTimeout: %v", err)
|
||||
}
|
||||
|
||||
writeTimeout, err := time.ParseDuration(conf.Timeouts.WriteTimeout)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid WriteTimeout: %v", err)
|
||||
}
|
||||
|
||||
idleTimeout, err := time.ParseDuration(conf.Timeouts.IdleTimeout)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid IdleTimeout: %v", err)
|
||||
}
|
||||
|
||||
// Configure HTTP server
|
||||
server := &http.Server{
|
||||
Addr: ":" + conf.Server.ListenPort, // Prepend colon to ListenPort
|
||||
Handler: router,
|
||||
ReadTimeout: readTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
IdleTimeout: idleTimeout,
|
||||
}
|
||||
|
||||
// Start metrics server if enabled
|
||||
if conf.Server.MetricsEnabled {
|
||||
if conf.Server.PreCaching { // Corrected field name
|
||||
go func() {
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
log.Infof("Metrics server started on port %s", conf.Server.MetricsPort)
|
||||
if err := http.ListenAndServe(":"+conf.Server.MetricsPort, nil); err != nil {
|
||||
log.Fatalf("Metrics server failed: %v", err)
|
||||
log.Info("Starting pre-caching of storage path...")
|
||||
// Use helper function
|
||||
err := precacheStoragePath(storagePath)
|
||||
if err != nil {
|
||||
log.Warnf("Pre-caching storage path failed: %v", err)
|
||||
} else {
|
||||
log.Info("Pre-cached all files in the storage path.")
|
||||
log.Info("Pre-caching status: complete.")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Setup graceful shutdown
|
||||
setupGracefulShutdown(server, cancel)
|
||||
err = os.MkdirAll(storagePath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatalf("Error creating store directory: %v", err)
|
||||
}
|
||||
log.WithField("directory", storagePath).Info("Store directory is ready")
|
||||
|
||||
// Use helper function
|
||||
err = checkFreeSpaceWithRetry(storagePath, 3, 5*time.Second)
|
||||
if err != nil {
|
||||
log.Fatalf("Insufficient free space: %v", err)
|
||||
}
|
||||
|
||||
initializeWorkerSettings(&conf.Server, &conf.Workers, &conf.ClamAV)
|
||||
log.Info("Prometheus metrics initialized.")
|
||||
|
||||
networkEvents = make(chan NetworkEvent, 100)
|
||||
log.Info("Network event channel initialized.")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if conf.Server.NetworkEvents { // Corrected field name
|
||||
go monitorNetwork(ctx) // Assuming monitorNetwork is defined in helpers.go or elsewhere
|
||||
go handleNetworkEvents(ctx) // Assuming handleNetworkEvents is defined in helpers.go or elsewhere
|
||||
}
|
||||
go updateSystemMetrics(ctx)
|
||||
|
||||
if conf.ClamAV.ClamAVEnabled {
|
||||
var clamErr error
|
||||
clamClient, clamErr = initClamAV(conf.ClamAV.ClamAVSocket) // Assuming initClamAV is defined in helpers.go or elsewhere
|
||||
if clamErr != nil {
|
||||
log.WithError(clamErr).Warn("ClamAV client initialization failed. Continuing without ClamAV.")
|
||||
} else {
|
||||
log.Info("ClamAV client initialized successfully.")
|
||||
}
|
||||
}
|
||||
|
||||
if conf.Redis.RedisEnabled {
|
||||
initRedis() // Assuming initRedis is defined in helpers.go or elsewhere
|
||||
}
|
||||
|
||||
router := setupRouter() // Assuming setupRouter is defined (likely in this file or router.go
|
||||
|
||||
go handleFileCleanup(&conf) // Directly call handleFileCleanup
|
||||
|
||||
readTimeout, err := time.ParseDuration(conf.Timeouts.Read) // Corrected field name
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid ReadTimeout: %v", err)
|
||||
}
|
||||
|
||||
writeTimeout, err := time.ParseDuration(conf.Timeouts.Write) // Corrected field name
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid WriteTimeout: %v", err)
|
||||
}
|
||||
|
||||
idleTimeout, err := time.ParseDuration(conf.Timeouts.Idle) // Corrected field name
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid IdleTimeout: %v", err)
|
||||
}
|
||||
|
||||
// Initialize network protocol based on forceprotocol setting
|
||||
dialer, err := initializeNetworkProtocol(conf.Server.ForceProtocol)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize network protocol: %v", err)
|
||||
}
|
||||
// Enhanced dual-stack HTTP client for robust IPv4/IPv6 and resource management
|
||||
// See: https://pkg.go.dev/net/http#Transport for details on these settings
|
||||
dualStackClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: dialer.DialContext,
|
||||
IdleConnTimeout: 90 * time.Second, // Close idle connections after 90s
|
||||
MaxIdleConns: 100, // Max idle connections across all hosts
|
||||
MaxIdleConnsPerHost: 10, // Max idle connections per host
|
||||
TLSHandshakeTimeout: 10 * time.Second, // Timeout for TLS handshake
|
||||
ResponseHeaderTimeout: 15 * time.Second, // Timeout for reading response headers
|
||||
},
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Addr: conf.Server.BindIP + ":" + conf.Server.ListenAddress, // Use BindIP + ListenAddress (port)
|
||||
Handler: router,
|
||||
ReadTimeout: readTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
IdleTimeout: idleTimeout,
|
||||
MaxHeaderBytes: 1 << 20, // 1 MB
|
||||
}
|
||||
|
||||
if conf.Server.MetricsEnabled {
|
||||
var wg sync.WaitGroup
|
||||
go func() {
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
log.Infof("Metrics server started on port %s", conf.Server.MetricsPort) // Corrected field name
|
||||
if err := http.ListenAndServe(":"+conf.Server.MetricsPort, nil); err != nil { // Corrected field name
|
||||
log.Fatalf("Metrics server failed: %v", err)
|
||||
}
|
||||
wg.Wait()
|
||||
}()
|
||||
}
|
||||
|
||||
setupGracefulShutdown(server, cancel) // Assuming setupGracefulShutdown is defined
|
||||
|
||||
if conf.Server.AutoAdjustWorkers { // Corrected field name
|
||||
go monitorWorkerPerformance(ctx, &conf.Server, &conf.Workers, &conf.ClamAV)
|
||||
}
|
||||
|
||||
versionString = "3.2" // Set a default version for now
|
||||
if conf.Build.Version != "" {
|
||||
versionString = conf.Build.Version
|
||||
}
|
||||
log.Infof("Running version: %s", versionString)
|
||||
|
||||
// Start server
|
||||
log.Infof("Starting HMAC file server %s...", versionString)
|
||||
if conf.Server.UnixSocket {
|
||||
// Listen on Unix socket
|
||||
if err := os.RemoveAll(conf.Server.ListenPort); err != nil {
|
||||
socketPath := "/tmp/hmac-file-server.sock" // Use a default socket path since ListenAddress is now a port
|
||||
if err := os.RemoveAll(socketPath); err != nil {
|
||||
log.Fatalf("Failed to remove existing Unix socket: %v", err)
|
||||
}
|
||||
listener, err := net.Listen("unix", conf.Server.ListenPort)
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to listen on Unix socket %s: %v", conf.Server.ListenPort, err)
|
||||
log.Fatalf("Failed to listen on Unix socket %s: %v", socketPath, err)
|
||||
}
|
||||
defer listener.Close()
|
||||
log.Infof("Server listening on Unix socket: %s", socketPath)
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Server failed: %v", err)
|
||||
}
|
||||
} else {
|
||||
// Listen on TCP port
|
||||
if conf.Server.BindIP == "0.0.0.0" {
|
||||
log.Info("Binding to 0.0.0.0. Any net/http logs you see are normal for this universal address.")
|
||||
}
|
||||
log.Infof("Server listening on %s", server.Addr)
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Server failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start file cleanup in a separate goroutine
|
||||
// Use helper function
|
||||
go handleFileCleanup(&conf)
|
||||
}
|
||||
|
||||
// Function to load configuration using Viper
|
||||
func readConfig(configFilename string, conf *Config) error {
|
||||
viper.SetConfigFile(configFilename)
|
||||
viper.SetConfigType("toml")
|
||||
func printExampleConfig() {
|
||||
fmt.Print(`
|
||||
[server]
|
||||
bind_ip = "0.0.0.0"
|
||||
listenport = "8080"
|
||||
unixsocket = false
|
||||
storagepath = "./uploads"
|
||||
logfile = "/var/log/hmac-file-server.log"
|
||||
metricsenabled = true
|
||||
metricsport = "9090"
|
||||
minfreebytes = "100MB"
|
||||
filettl = "8760h"
|
||||
filettlenabled = true
|
||||
autoadjustworkers = true
|
||||
networkevents = true
|
||||
pidfilepath = "/var/run/hmacfileserver.pid"
|
||||
cleanuponexit = true
|
||||
precaching = true
|
||||
deduplicationenabled = true
|
||||
globalextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
|
||||
# FileNaming options: "HMAC", "None"
|
||||
filenaming = "HMAC"
|
||||
forceprotocol = "auto"
|
||||
|
||||
// Read in environment variables that match
|
||||
viper.AutomaticEnv()
|
||||
viper.SetEnvPrefix("HMAC") // Prefix for environment variables
|
||||
[logging]
|
||||
level = "info"
|
||||
file = "/var/log/hmac-file-server.log"
|
||||
max_size = 100
|
||||
max_backups = 7
|
||||
max_age = 30
|
||||
compress = true
|
||||
|
||||
// Read the config file
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
return fmt.Errorf("error reading config file: %w", err)
|
||||
}
|
||||
[deduplication]
|
||||
enabled = true
|
||||
directory = "./deduplication"
|
||||
|
||||
// Unmarshal the config into the Config struct
|
||||
if err := viper.Unmarshal(conf); err != nil {
|
||||
return fmt.Errorf("unable to decode into struct: %w", err)
|
||||
}
|
||||
[iso]
|
||||
enabled = true
|
||||
size = "1GB"
|
||||
mountpoint = "/mnt/iso"
|
||||
charset = "utf-8"
|
||||
containerfile = "/mnt/iso/container.iso"
|
||||
|
||||
// Debug log the loaded configuration
|
||||
log.Debugf("Loaded Configuration: %+v", conf.Server)
|
||||
[timeouts]
|
||||
readtimeout = "4800s"
|
||||
writetimeout = "4800s"
|
||||
idletimeout = "4800s"
|
||||
|
||||
// Validate the configuration
|
||||
if err := validateConfig(conf); err != nil {
|
||||
return fmt.Errorf("configuration validation failed: %w", err)
|
||||
}
|
||||
[security]
|
||||
secret = "changeme"
|
||||
enablejwt = false
|
||||
jwtsecret = "anothersecretkey"
|
||||
jwtalgorithm = "HS256"
|
||||
jwtexpiration = "24h"
|
||||
|
||||
// Set Deduplication Enabled
|
||||
conf.Server.DeduplicationEnabled = viper.GetBool("deduplication.Enabled")
|
||||
[versioning]
|
||||
enableversioning = false
|
||||
maxversions = 1
|
||||
|
||||
return nil
|
||||
[uploads]
|
||||
resumableuploadsenabled = true
|
||||
chunkeduploadsenabled = true
|
||||
chunksize = "8192"
|
||||
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
|
||||
|
||||
[downloads]
|
||||
resumabledownloadsenabled = true
|
||||
chunkeddownloadsenabled = true
|
||||
chunksize = "8192"
|
||||
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
|
||||
|
||||
[clamav]
|
||||
clamavenabled = true
|
||||
clamavsocket = "/var/run/clamav/clamd.ctl"
|
||||
numscanworkers = 2
|
||||
scanfileextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
|
||||
|
||||
[redis]
|
||||
redisenabled = true
|
||||
redisdbindex = 0
|
||||
redisaddr = "localhost:6379"
|
||||
redispassword = ""
|
||||
redishealthcheckinterval = "120s"
|
||||
|
||||
[workers]
|
||||
numworkers = 4
|
||||
uploadqueuesize = 50
|
||||
|
||||
[file]
|
||||
# Add file-specific configurations here
|
||||
|
||||
[build]
|
||||
version = "3.2"
|
||||
`)
|
||||
}
|
||||
|
||||
// Set default configuration values
|
||||
func setDefaults() {
|
||||
// Server defaults
|
||||
viper.SetDefault("server.ListenPort", "8080")
|
||||
viper.SetDefault("server.UnixSocket", false)
|
||||
viper.SetDefault("server.StoragePath", "./uploads")
|
||||
viper.SetDefault("server.LogLevel", "info")
|
||||
viper.SetDefault("server.LogFile", "")
|
||||
viper.SetDefault("server.MetricsEnabled", true)
|
||||
viper.SetDefault("server.MetricsPort", "9090")
|
||||
viper.SetDefault("server.FileTTL", "8760h") // 365d -> 8760h
|
||||
viper.SetDefault("server.MinFreeBytes", 100<<20) // 100 MB
|
||||
func getExampleConfigString() string {
|
||||
return `[server]
|
||||
listen_address = ":8080"
|
||||
storage_path = "/srv/hmac-file-server/uploads"
|
||||
metrics_enabled = true
|
||||
metrics_path = "/metrics"
|
||||
pid_file = "/var/run/hmac-file-server.pid"
|
||||
max_upload_size = "10GB" # Supports B, KB, MB, GB, TB
|
||||
max_header_bytes = 1048576 # 1MB
|
||||
cleanup_interval = "24h"
|
||||
max_file_age = "720h" # 30 days
|
||||
pre_cache = true
|
||||
pre_cache_workers = 4
|
||||
pre_cache_interval = "1h"
|
||||
global_extensions = [".txt", ".dat", ".iso"] # If set, overrides upload/download extensions
|
||||
deduplication_enabled = true
|
||||
min_free_bytes = "1GB" # Minimum free space required for uploads
|
||||
file_naming = "original" # Options: "original", "HMAC"
|
||||
force_protocol = "" # Options: "http", "https" - if set, redirects to this protocol
|
||||
enable_dynamic_workers = true # Enable dynamic worker scaling
|
||||
worker_scale_up_thresh = 50 # Queue length to scale up workers
|
||||
worker_scale_down_thresh = 10 # Queue length to scale down workers
|
||||
|
||||
// Timeout defaults
|
||||
viper.SetDefault("timeouts.ReadTimeout", "4800s") // supports 's'
|
||||
viper.SetDefault("timeouts.WriteTimeout", "4800s")
|
||||
viper.SetDefault("timeouts.IdleTimeout", "4800s")
|
||||
[uploads]
|
||||
allowed_extensions = [".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"]
|
||||
chunked_uploads_enabled = true
|
||||
chunk_size = "10MB"
|
||||
resumable_uploads_enabled = true
|
||||
max_resumable_age = "48h"
|
||||
|
||||
// Security defaults
|
||||
viper.SetDefault("security.Secret", "changeme")
|
||||
[downloads]
|
||||
allowed_extensions = [".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"]
|
||||
chunked_downloads_enabled = true
|
||||
chunk_size = "10MB"
|
||||
resumable_downloads_enabled = true
|
||||
|
||||
// Versioning defaults
|
||||
viper.SetDefault("versioning.EnableVersioning", false)
|
||||
viper.SetDefault("versioning.MaxVersions", 1)
|
||||
[security]
|
||||
secret = "your-very-secret-hmac-key"
|
||||
enablejwt = false
|
||||
jwtsecret = "anothersecretkey"
|
||||
jwtalgorithm = "HS256"
|
||||
jwtexpiration = "24h"
|
||||
|
||||
// Uploads defaults
|
||||
viper.SetDefault("uploads.ResumableUploadsEnabled", true)
|
||||
viper.SetDefault("uploads.ChunkedUploadsEnabled", true)
|
||||
viper.SetDefault("uploads.ChunkSize", 8192)
|
||||
viper.SetDefault("uploads.AllowedExtensions", []string{
|
||||
".txt", ".pdf",
|
||||
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp",
|
||||
".wav", ".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".mpg", ".m4v", ".3gp", ".3g2",
|
||||
".mp3", ".ogg",
|
||||
})
|
||||
[logging]
|
||||
level = "info"
|
||||
file = "/var/log/hmac-file-server.log"
|
||||
max_size = 100
|
||||
max_backups = 7
|
||||
max_age = 30
|
||||
compress = true
|
||||
|
||||
// ClamAV defaults
|
||||
viper.SetDefault("clamav.ClamAVEnabled", true)
|
||||
viper.SetDefault("clamav.ClamAVSocket", "/var/run/clamav/clamd.ctl")
|
||||
viper.SetDefault("clamav.NumScanWorkers", 2)
|
||||
[deduplication]
|
||||
enabled = true
|
||||
directory = "./deduplication"
|
||||
|
||||
// Redis defaults
|
||||
viper.SetDefault("redis.RedisEnabled", true)
|
||||
viper.SetDefault("redis.RedisAddr", "localhost:6379")
|
||||
viper.SetDefault("redis.RedisPassword", "")
|
||||
viper.SetDefault("redis.RedisDBIndex", 0)
|
||||
viper.SetDefault("redis.RedisHealthCheckInterval", "120s")
|
||||
[iso]
|
||||
enabled = true
|
||||
size = "1GB"
|
||||
mountpoint = "/mnt/iso"
|
||||
charset = "utf-8"
|
||||
containerfile = "/mnt/iso/container.iso"
|
||||
|
||||
// Workers defaults
|
||||
viper.SetDefault("workers.NumWorkers", 2)
|
||||
viper.SetDefault("workers.UploadQueueSize", 50)
|
||||
[timeouts]
|
||||
readtimeout = "4800s"
|
||||
writetimeout = "4800s"
|
||||
idletimeout = "4800s"
|
||||
|
||||
// Deduplication defaults
|
||||
viper.SetDefault("deduplication.Enabled", true)
|
||||
[security]
|
||||
secret = "changeme"
|
||||
enablejwt = false
|
||||
jwtsecret = "anothersecretkey"
|
||||
jwtalgorithm = "HS256"
|
||||
jwtexpiration = "24h"
|
||||
|
||||
[versioning]
|
||||
enableversioning = false
|
||||
maxversions = 1
|
||||
|
||||
[uploads]
|
||||
resumableuploadsenabled = true
|
||||
chunkeduploadsenabled = true
|
||||
chunksize = "8192"
|
||||
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
|
||||
|
||||
[downloads]
|
||||
resumabledownloadsenabled = true
|
||||
chunkeddownloadsenabled = true
|
||||
chunksize = "8192"
|
||||
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
|
||||
|
||||
[clamav]
|
||||
clamavenabled = true
|
||||
clamavsocket = "/var/run/clamav/clamd.ctl"
|
||||
numscanworkers = 2
|
||||
scanfileextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
|
||||
|
||||
[redis]
|
||||
redisenabled = true
|
||||
redisdbindex = 0
|
||||
redisaddr = "localhost:6379"
|
||||
redispassword = ""
|
||||
redishealthcheckinterval = "120s"
|
||||
|
||||
[workers]
|
||||
numworkers = 4
|
||||
uploadqueuesize = 50
|
||||
|
||||
[file]
|
||||
# Add file-specific configurations here
|
||||
|
||||
[build]
|
||||
version = "3.2"
|
||||
`
|
||||
}
|
||||
|
||||
// Validate configuration fields
|
||||
func validateConfig(conf *Config) error {
|
||||
if conf.Server.ListenPort == "" {
|
||||
return fmt.Errorf("ListenPort must be set")
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
if conf.Security.Secret == "" {
|
||||
return fmt.Errorf("secret must be set")
|
||||
}
|
||||
if conf.Server.StoragePath == "" {
|
||||
return fmt.Errorf("StoragePath must be set")
|
||||
}
|
||||
if conf.Server.FileTTL == "" {
|
||||
return fmt.Errorf("FileTTL must be set")
|
||||
}
|
||||
|
||||
// Validate timeouts
|
||||
if _, err := time.ParseDuration(conf.Timeouts.ReadTimeout); err != nil {
|
||||
return fmt.Errorf("invalid ReadTimeout: %v", err)
|
||||
}
|
||||
if _, err := time.ParseDuration(conf.Timeouts.WriteTimeout); err != nil {
|
||||
return fmt.Errorf("invalid WriteTimeout: %v", err)
|
||||
}
|
||||
if _, err := time.ParseDuration(conf.Timeouts.IdleTimeout); err != nil {
|
||||
return fmt.Errorf("invalid IdleTimeout: %v", err)
|
||||
}
|
||||
|
||||
// Validate Redis configuration if enabled
|
||||
if conf.Redis.RedisEnabled {
|
||||
if conf.Redis.RedisAddr == "" {
|
||||
return fmt.Errorf("RedisAddr must be set when Redis is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// Add more validations as needed
|
||||
|
||||
return nil
|
||||
return b
|
||||
}
|
||||
|
||||
// Setup logging
|
||||
func setupLogging() {
|
||||
level, err := logrus.ParseLevel(conf.Server.LogLevel)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid log level: %s", conf.Server.LogLevel)
|
||||
}
|
||||
log.SetLevel(level)
|
||||
|
||||
if conf.Server.LogFile != "" {
|
||||
logFile, err := os.OpenFile(conf.Server.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open log file: %v", err)
|
||||
}
|
||||
log.SetOutput(io.MultiWriter(os.Stdout, logFile))
|
||||
} else {
|
||||
log.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
// Use Text formatter for human-readable logs
|
||||
log.SetFormatter(&logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
// You can customize the format further if needed
|
||||
})
|
||||
}
|
||||
|
||||
// Log system information
|
||||
func logSystemInfo() {
|
||||
log.Info("========================================")
|
||||
log.Infof(" HMAC File Server - %s ", versionString)
|
||||
log.Info(" Secure File Handling with HMAC Auth ")
|
||||
log.Info("========================================")
|
||||
|
||||
log.Info("Features: Prometheus Metrics, Chunked Uploads, ClamAV Scanning")
|
||||
log.Info("Build Date: 2024-10-28")
|
||||
|
||||
log.Infof("Operating System: %s", runtime.GOOS)
|
||||
log.Infof("Architecture: %s", runtime.GOARCH)
|
||||
log.Infof("Number of CPUs: %d", runtime.NumCPU())
|
||||
log.Infof("Go Version: %s", runtime.Version())
|
||||
|
||||
func autoAdjustWorkers() (int, int) {
|
||||
v, _ := mem.VirtualMemory()
|
||||
log.Infof("Total Memory: %v MB", v.Total/1024/1024)
|
||||
log.Infof("Free Memory: %v MB", v.Free/1024/1024)
|
||||
log.Infof("Used Memory: %v MB", v.Used/1024/1024)
|
||||
cpuCores, _ := cpu.Counts(true)
|
||||
|
||||
cpuInfo, _ := cpu.Info()
|
||||
for _, info := range cpuInfo {
|
||||
log.Infof("CPU Model: %s, Cores: %d, Mhz: %f", info.ModelName, info.Cores, info.Mhz)
|
||||
numWorkers := cpuCores * 2
|
||||
if v.Available < 4*1024*1024*1024 { // Less than 4GB available
|
||||
numWorkers = max(numWorkers/2, 1)
|
||||
} else if v.Available < 8*1024*1024*1024 { // Less than 8GB available
|
||||
numWorkers = max(numWorkers*3/4, 1)
|
||||
}
|
||||
queueSize := numWorkers * 10
|
||||
|
||||
partitions, _ := disk.Partitions(false)
|
||||
for _, partition := range partitions {
|
||||
usage, _ := disk.Usage(partition.Mountpoint)
|
||||
log.Infof("Disk Mountpoint: %s, Total: %v GB, Free: %v GB, Used: %v GB",
|
||||
partition.Mountpoint, usage.Total/1024/1024/1024, usage.Free/1024/1024/1024, usage.Used/1024/1024/1024)
|
||||
}
|
||||
|
||||
hInfo, _ := host.Info()
|
||||
log.Infof("Hostname: %s", hInfo.Hostname)
|
||||
log.Infof("Uptime: %v seconds", hInfo.Uptime)
|
||||
log.Infof("Boot Time: %v", time.Unix(int64(hInfo.BootTime), 0))
|
||||
log.Infof("Platform: %s", hInfo.Platform)
|
||||
log.Infof("Platform Family: %s", hInfo.PlatformFamily)
|
||||
log.Infof("Platform Version: %s", hInfo.PlatformVersion)
|
||||
log.Infof("Kernel Version: %s", hInfo.KernelVersion)
|
||||
log.Infof("Auto-adjusting workers: NumWorkers=%d, UploadQueueSize=%d", numWorkers, queueSize)
|
||||
workerAdjustmentsTotal.Inc()
|
||||
return numWorkers, queueSize
|
||||
}
|
||||
|
||||
// Initialize Prometheus metrics
|
||||
// Duplicate initMetrics function removed
|
||||
func initMetrics() {
|
||||
uploadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{Namespace: "hmac", Name: "file_server_upload_duration_seconds", Help: "Histogram of file upload duration in seconds.", Buckets: prometheus.DefBuckets})
|
||||
uploadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_upload_errors_total", Help: "Total number of file upload errors."})
|
||||
uploadsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_uploads_total", Help: "Total number of successful file uploads."})
|
||||
downloadDuration = prometheus.NewHistogram(prometheus.HistogramOpts{Namespace: "hmac", Name: "file_server_download_duration_seconds", Help: "Histogram of file download duration in seconds.", Buckets: prometheus.DefBuckets})
|
||||
downloadsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_downloads_total", Help: "Total number of successful file downloads."})
|
||||
downloadErrorsTotal = prometheus.NewCounter(prometheus.CounterOpts{Namespace: "hmac", Name: "file_server_download_errors_total", Help: "Total number of file download errors."})
|
||||
memoryUsage = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "memory_usage_bytes", Help: "Current memory usage in bytes."})
|
||||
cpuUsage = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "cpu_usage_percent", Help: "Current CPU usage as a percentage."})
|
||||
activeConnections = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "active_connections_total", Help: "Total number of active connections."})
|
||||
requestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{Namespace: "hmac", Name: "http_requests_total", Help: "Total number of HTTP requests received, labeled by method and path."}, []string{"method", "path"})
|
||||
goroutines = prometheus.NewGauge(prometheus.GaugeOpts{Namespace: "hmac", Name: "goroutines_count", Help: "Current number of goroutines."})
|
||||
uploadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: "hmac",
|
||||
Name: "file_server_upload_size_bytes",
|
||||
Help: "Histogram of uploaded file sizes in bytes.",
|
||||
Buckets: prometheus.ExponentialBuckets(100, 10, 8),
|
||||
})
|
||||
downloadSizeBytes = prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: "hmac",
|
||||
Name: "file_server_download_size_bytes",
|
||||
Help: "Histogram of downloaded file sizes in bytes.",
|
||||
Buckets: prometheus.ExponentialBuckets(100, 10, 8),
|
||||
})
|
||||
func initializeWorkerSettings(server *ServerConfig, workers *WorkersConfig, clamav *ClamAVConfig) {
|
||||
if server.AutoAdjustWorkers {
|
||||
numWorkers, queueSize := autoAdjustWorkers()
|
||||
workers.NumWorkers = numWorkers
|
||||
workers.UploadQueueSize = queueSize
|
||||
clamav.NumScanWorkers = max(numWorkers/2, 1)
|
||||
|
||||
if conf.Server.MetricsEnabled {
|
||||
prometheus.MustRegister(uploadDuration, uploadErrorsTotal, uploadsTotal)
|
||||
prometheus.MustRegister(downloadDuration, downloadsTotal, downloadErrorsTotal)
|
||||
prometheus.MustRegister(memoryUsage, cpuUsage, activeConnections, requestsTotal, goroutines)
|
||||
prometheus.MustRegister(uploadSizeBytes, downloadSizeBytes)
|
||||
log.Infof("AutoAdjustWorkers enabled: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
|
||||
workers.NumWorkers, workers.UploadQueueSize, clamav.NumScanWorkers)
|
||||
} else {
|
||||
log.Infof("Manual configuration in effect: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
|
||||
workers.NumWorkers, workers.UploadQueueSize, clamav.NumScanWorkers)
|
||||
}
|
||||
}
|
||||
|
||||
// Update system metrics
|
||||
func updateSystemMetrics(ctx context.Context) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
func monitorWorkerPerformance(ctx context.Context, server *ServerConfig, w *WorkersConfig, clamav *ClamAVConfig) {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("Stopping system metrics updater.")
|
||||
log.Info("Stopping worker performance monitor.")
|
||||
return
|
||||
case <-ticker.C:
|
||||
v, _ := mem.VirtualMemory()
|
||||
memoryUsage.Set(float64(v.Used))
|
||||
if server.AutoAdjustWorkers {
|
||||
numWorkers, queueSize := autoAdjustWorkers()
|
||||
w.NumWorkers = numWorkers
|
||||
w.UploadQueueSize = queueSize
|
||||
clamav.NumScanWorkers = max(numWorkers/2, 1)
|
||||
|
||||
cpuPercent, _ := cpu.Percent(0, false)
|
||||
if len(cpuPercent) > 0 {
|
||||
cpuUsage.Set(cpuPercent[0])
|
||||
log.Infof("Re-adjusted workers: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
|
||||
w.NumWorkers, w.UploadQueueSize, clamav.NumScanWorkers)
|
||||
workerReAdjustmentsTotal.Inc()
|
||||
}
|
||||
|
||||
goroutines.Set(float64(runtime.NumGoroutine()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Function to check if a file exists and return its size
|
||||
func fileExists(filePath string) (bool, int64) {
|
||||
if cachedInfo, found := fileInfoCache.Get(filePath); found {
|
||||
if info, ok := cachedInfo.(os.FileInfo); ok {
|
||||
return !info.IsDir(), info.Size()
|
||||
}
|
||||
}
|
||||
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if os.IsNotExist(err) {
|
||||
return false, 0
|
||||
} else if err != nil {
|
||||
log.Error("Error checking file existence:", err)
|
||||
return false, 0
|
||||
}
|
||||
|
||||
fileInfoCache.Set(filePath, fileInfo, cache.DefaultExpiration)
|
||||
return !fileInfo.IsDir(), fileInfo.Size()
|
||||
}
|
||||
|
||||
// Function to check file extension
|
||||
func isExtensionAllowed(filename string) bool {
|
||||
if len(conf.Uploads.AllowedExtensions) == 0 {
|
||||
return true // No restrictions if the list is empty
|
||||
}
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
for _, allowedExt := range conf.Uploads.AllowedExtensions {
|
||||
if strings.ToLower(allowedExt) == ext {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Version the file by moving the existing file to a versioned directory
|
||||
func versionFile(absFilename string) error {
|
||||
versionDir := absFilename + "_versions"
|
||||
|
||||
err := os.MkdirAll(versionDir, os.ModePerm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create version directory: %v", err)
|
||||
}
|
||||
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
versionedFilename := filepath.Join(versionDir, filepath.Base(absFilename)+"."+timestamp)
|
||||
|
||||
err = os.Rename(absFilename, versionedFilename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to version the file: %v", err)
|
||||
}
|
||||
|
||||
log.WithFields(logrus.Fields{
|
||||
"original": absFilename,
|
||||
"versioned_as": versionedFilename,
|
||||
}).Info("Versioned old file")
|
||||
return cleanupOldVersions(versionDir)
|
||||
}
|
||||
|
||||
// Clean up older versions if they exceed the maximum allowed
|
||||
func cleanupOldVersions(versionDir string) error {
|
||||
files, err := os.ReadDir(versionDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list version files: %v", err)
|
||||
}
|
||||
|
||||
if conf.Versioning.MaxVersions > 0 && len(files) > conf.Versioning.MaxVersions {
|
||||
excessFiles := len(files) - conf.Versioning.MaxVersions
|
||||
for i := 0; i < excessFiles; i++ {
|
||||
err := os.Remove(filepath.Join(versionDir, files[i].Name()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove old version: %v", err)
|
||||
}
|
||||
log.WithField("file", files[i].Name()).Info("Removed old version")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process the upload task
|
||||
func processUpload(task UploadTask) error {
|
||||
absFilename := task.AbsFilename
|
||||
tempFilename := absFilename + ".tmp"
|
||||
r := task.Request
|
||||
|
||||
log.Infof("Processing upload for file: %s", absFilename)
|
||||
startTime := time.Now()
|
||||
|
||||
// Handle uploads and write to a temporary file
|
||||
if conf.Uploads.ChunkedUploadsEnabled {
|
||||
log.Debugf("Chunked uploads enabled. Handling chunked upload for %s", tempFilename)
|
||||
err := handleChunkedUpload(tempFilename, r)
|
||||
if err != nil {
|
||||
uploadDuration.Observe(time.Since(startTime).Seconds())
|
||||
log.WithFields(logrus.Fields{
|
||||
"file": tempFilename,
|
||||
"error": err,
|
||||
}).Error("Failed to handle chunked upload")
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
log.Debugf("Handling standard upload for %s", tempFilename)
|
||||
err := createFile(tempFilename, r)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"file": tempFilename,
|
||||
"error": err,
|
||||
}).Error("Error creating file")
|
||||
uploadDuration.Observe(time.Since(startTime).Seconds())
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Perform ClamAV scan on the temporary file
|
||||
if clamClient != nil {
|
||||
log.Debugf("Scanning %s with ClamAV", tempFilename)
|
||||
err := scanFileWithClamAV(tempFilename)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"file": tempFilename,
|
||||
"error": err,
|
||||
}).Warn("ClamAV detected a virus or scan failed")
|
||||
os.Remove(tempFilename)
|
||||
uploadErrorsTotal.Inc()
|
||||
return err
|
||||
}
|
||||
log.Infof("ClamAV scan passed for file: %s", tempFilename)
|
||||
}
|
||||
|
||||
// Handle file versioning if enabled
|
||||
if conf.Versioning.EnableVersioning {
|
||||
existing, _ := fileExists(absFilename)
|
||||
if existing {
|
||||
log.Infof("File %s exists. Initiating versioning.", absFilename)
|
||||
err := versionFile(absFilename)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"file": absFilename,
|
||||
"error": err,
|
||||
}).Error("Error versioning file")
|
||||
os.Remove(tempFilename)
|
||||
return err
|
||||
}
|
||||
log.Infof("File versioned successfully: %s", absFilename)
|
||||
}
|
||||
}
|
||||
|
||||
// Rename temporary file to final destination
|
||||
err := os.Rename(tempFilename, absFilename)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"temp_file": tempFilename,
|
||||
"final_file": absFilename,
|
||||
"error": err,
|
||||
}).Error("Failed to move file to final destination")
|
||||
os.Remove(tempFilename)
|
||||
func readConfig(configFilename string, conf *Config) error {
|
||||
viper.SetConfigFile(configFilename)
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
log.WithError(err).Errorf("Unable to read config from %s", configFilename)
|
||||
return err
|
||||
}
|
||||
log.Infof("File moved to final destination: %s", absFilename)
|
||||
|
||||
// Handle deduplication if enabled
|
||||
if conf.Server.DeduplicationEnabled {
|
||||
log.Debugf("Deduplication enabled. Checking duplicates for %s", absFilename)
|
||||
err = handleDeduplication(context.Background(), absFilename)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Deduplication failed")
|
||||
uploadErrorsTotal.Inc()
|
||||
return err
|
||||
}
|
||||
log.Infof("Deduplication handled successfully for file: %s", absFilename)
|
||||
if err := viper.Unmarshal(conf); err != nil {
|
||||
return fmt.Errorf("unable to decode config into struct: %v", err)
|
||||
}
|
||||
|
||||
log.WithFields(logrus.Fields{
|
||||
"file": absFilename,
|
||||
}).Info("File uploaded and processed successfully")
|
||||
|
||||
uploadDuration.Observe(time.Since(startTime).Seconds())
|
||||
uploadsTotal.Inc()
|
||||
return nil
|
||||
}
|
||||
|
||||
// uploadWorker processes upload tasks from the uploadQueue
|
||||
func uploadWorker(ctx context.Context, workerID int) {
|
||||
log.Infof("Upload worker %d started.", workerID)
|
||||
defer log.Infof("Upload worker %d stopped.", workerID)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case task, ok := <-uploadQueue:
|
||||
if !ok {
|
||||
log.Warnf("Upload queue closed. Worker %d exiting.", workerID)
|
||||
return
|
||||
}
|
||||
log.Infof("Worker %d processing upload for file: %s", workerID, task.AbsFilename)
|
||||
err := processUpload(task)
|
||||
if err != nil {
|
||||
log.Errorf("Worker %d failed to process upload for %s: %v", workerID, task.AbsFilename, err)
|
||||
uploadErrorsTotal.Inc()
|
||||
} else {
|
||||
log.Infof("Worker %d successfully processed upload for %s", workerID, task.AbsFilename)
|
||||
}
|
||||
task.Result <- err
|
||||
close(task.Result)
|
||||
func setDefaults() {
|
||||
viper.SetDefault("server.listen_address", ":8080")
|
||||
viper.SetDefault("server.storage_path", "./uploads")
|
||||
viper.SetDefault("server.metrics_enabled", true)
|
||||
viper.SetDefault("server.metrics_path", "/metrics")
|
||||
viper.SetDefault("server.pid_file", "/var/run/hmac-file-server.pid")
|
||||
viper.SetDefault("server.max_upload_size", "10GB")
|
||||
viper.SetDefault("server.max_header_bytes", 1048576) // 1MB
|
||||
viper.SetDefault("server.cleanup_interval", "24h")
|
||||
viper.SetDefault("server.max_file_age", "720h") // 30 days
|
||||
viper.SetDefault("server.pre_cache", true)
|
||||
viper.SetDefault("server.pre_cache_workers", 4)
|
||||
viper.SetDefault("server.pre_cache_interval", "1h")
|
||||
viper.SetDefault("server.global_extensions", []string{})
|
||||
viper.SetDefault("server.deduplication_enabled", true)
|
||||
viper.SetDefault("server.min_free_bytes", "1GB")
|
||||
viper.SetDefault("server.file_naming", "original")
|
||||
viper.SetDefault("server.force_protocol", "")
|
||||
viper.SetDefault("server.enable_dynamic_workers", true)
|
||||
viper.SetDefault("server.worker_scale_up_thresh", 50)
|
||||
viper.SetDefault("server.worker_scale_down_thresh", 10)
|
||||
|
||||
viper.SetDefault("uploads.allowed_extensions", []string{".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"})
|
||||
viper.SetDefault("uploads.chunked_uploads_enabled", true)
|
||||
viper.SetDefault("uploads.chunk_size", "10MB")
|
||||
viper.SetDefault("uploads.resumable_uploads_enabled", true)
|
||||
viper.SetDefault("uploads.max_resumable_age", "48h")
|
||||
|
||||
viper.SetDefault("downloads.allowed_extensions", []string{".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"})
|
||||
viper.SetDefault("downloads.chunked_downloads_enabled", true)
|
||||
viper.SetDefault("downloads.chunk_size", "10MB")
|
||||
viper.SetDefault("downloads.resumable_downloads_enabled", true)
|
||||
|
||||
viper.SetDefault("security.secret", "your-very-secret-hmac-key")
|
||||
viper.SetDefault("security.enablejwt", false)
|
||||
viper.SetDefault("security.jwtsecret", "your-256-bit-secret")
|
||||
viper.SetDefault("security.jwtalgorithm", "HS256")
|
||||
viper.SetDefault("security.jwtexpiration", "24h")
|
||||
|
||||
// Logging defaults
|
||||
viper.SetDefault("logging.level", "info")
|
||||
viper.SetDefault("logging.file", "/var/log/hmac-file-server.log")
|
||||
viper.SetDefault("logging.max_size", 100)
|
||||
viper.SetDefault("logging.max_backups", 7)
|
||||
viper.SetDefault("logging.max_age", 30)
|
||||
viper.SetDefault("logging.compress", true)
|
||||
|
||||
// Deduplication defaults
|
||||
viper.SetDefault("deduplication.enabled", false)
|
||||
viper.SetDefault("deduplication.directory", "./dedup_store")
|
||||
|
||||
// ISO defaults
|
||||
viper.SetDefault("iso.enabled", false)
|
||||
viper.SetDefault("iso.mount_point", "/mnt/hmac_iso")
|
||||
viper.SetDefault("iso.size", "1GB")
|
||||
viper.SetDefault("iso.charset", "utf-8")
|
||||
viper.SetDefault("iso.containerfile", "/var/lib/hmac-file-server/data.iso")
|
||||
|
||||
// Timeouts defaults
|
||||
viper.SetDefault("timeouts.read", "60s")
|
||||
viper.SetDefault("timeouts.write", "60s")
|
||||
viper.SetDefault("timeouts.idle", "120s")
|
||||
viper.SetDefault("timeouts.shutdown", "30s")
|
||||
|
||||
// Versioning defaults
|
||||
viper.SetDefault("versioning.enabled", false)
|
||||
viper.SetDefault("versioning.backend", "simple")
|
||||
viper.SetDefault("versioning.max_revisions", 5)
|
||||
|
||||
// ... other defaults for Uploads, Downloads, ClamAV, Redis, Workers, File, Build
|
||||
viper.SetDefault("build.version", "dev")
|
||||
}
|
||||
|
||||
func validateConfig(c *Config) error {
|
||||
if c.Server.ListenAddress == "" { // Corrected field name
|
||||
return errors.New("server.listen_address is required")
|
||||
}
|
||||
|
||||
if c.Server.FileTTL == "" && c.Server.FileTTLEnabled { // Corrected field names
|
||||
return errors.New("server.file_ttl is required when server.file_ttl_enabled is true")
|
||||
}
|
||||
|
||||
if _, err := time.ParseDuration(c.Timeouts.Read); err != nil { // Corrected field name
|
||||
return fmt.Errorf("invalid timeouts.read: %v", err)
|
||||
}
|
||||
if _, err := time.ParseDuration(c.Timeouts.Write); err != nil { // Corrected field name
|
||||
return fmt.Errorf("invalid timeouts.write: %v", err)
|
||||
}
|
||||
if _, err := time.ParseDuration(c.Timeouts.Idle); err != nil { // Corrected field name
|
||||
return fmt.Errorf("invalid timeouts.idle: %v", err)
|
||||
}
|
||||
|
||||
// Corrected VersioningConfig field access
|
||||
if c.Versioning.Enabled { // Use the Go struct field name 'Enabled'
|
||||
if c.Versioning.MaxRevs <= 0 { // Use the Go struct field name 'MaxRevs'
|
||||
return errors.New("versioning.max_revisions must be positive if versioning is enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize upload worker pool
|
||||
func initializeUploadWorkerPool(ctx context.Context) {
|
||||
for i := 0; i < MinWorkers; i++ {
|
||||
go uploadWorker(ctx, i)
|
||||
}
|
||||
log.Infof("Initialized %d upload workers", MinWorkers)
|
||||
}
|
||||
|
||||
// Worker function to process scan tasks
|
||||
func scanWorker(ctx context.Context, workerID int) {
|
||||
log.WithField("worker_id", workerID).Info("Scan worker started")
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.WithField("worker_id", workerID).Info("Scan worker stopping")
|
||||
return
|
||||
case task, ok := <-scanQueue:
|
||||
if !ok {
|
||||
log.WithField("worker_id", workerID).Info("Scan queue closed")
|
||||
return
|
||||
}
|
||||
log.WithFields(logrus.Fields{
|
||||
"worker_id": workerID,
|
||||
"file": task.AbsFilename,
|
||||
}).Info("Processing scan task")
|
||||
err := scanFileWithClamAV(task.AbsFilename)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"worker_id": workerID,
|
||||
"file": task.AbsFilename,
|
||||
"error": err,
|
||||
}).Error("Failed to scan file")
|
||||
} else {
|
||||
log.WithFields(logrus.Fields{
|
||||
"worker_id": workerID,
|
||||
"file": task.AbsFilename,
|
||||
}).Info("Successfully scanned file")
|
||||
}
|
||||
task.Result <- err
|
||||
close(task.Result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize scan worker pool
|
||||
func initializeScanWorkerPool(ctx context.Context) {
|
||||
for i := 0; i < ScanWorkers; i++ {
|
||||
go scanWorker(ctx, i)
|
||||
}
|
||||
log.Infof("Initialized %d scan workers", ScanWorkers)
|
||||
}
|
||||
|
||||
// Setup router with middleware
|
||||
func setupRouter() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", handleRequest)
|
||||
if conf.Server.MetricsEnabled {
|
||||
mux.Handle("/metrics", promhttp.Handler())
|
||||
// Validate JWT secret if JWT is enabled
|
||||
if c.Security.EnableJWT && strings.TrimSpace(c.Security.JWTSecret) == "" {
|
||||
return errors.New("security.jwtsecret is required when security.enablejwt is true")
|
||||
}
|
||||
|
||||
// Apply middleware
|
||||
handler := loggingMiddleware(mux)
|
||||
handler = recoveryMiddleware(handler)
|
||||
handler = corsMiddleware(handler)
|
||||
return handler
|
||||
}
|
||||
|
||||
// Middleware for logging
|
||||
func loggingMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestsTotal.WithLabelValues(r.Method, r.URL.Path).Inc()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Middleware for panic recovery
|
||||
func recoveryMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"method": r.Method,
|
||||
"url": r.URL.String(),
|
||||
"error": rec,
|
||||
}).Error("Panic recovered in HTTP handler")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// corsMiddleware handles CORS by setting appropriate headers
|
||||
func corsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set CORS headers
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-File-MAC")
|
||||
w.Header().Set("Access-Control-Max-Age", "86400") // Cache preflight response for 1 day
|
||||
|
||||
// Handle preflight OPTIONS request
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Proceed to the next handler
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Handle file uploads and downloads
|
||||
func handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPost && strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
absFilename, err := sanitizeFilePath(conf.Server.StoragePath, strings.TrimPrefix(r.URL.Path, "/"))
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Invalid file path")
|
||||
http.Error(w, "Invalid file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
err = handleMultipartUpload(w, r, absFilename)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to handle multipart upload")
|
||||
http.Error(w, "Failed to handle multipart upload", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
// Validate HMAC secret if JWT is not enabled (as it's the fallback)
|
||||
if !c.Security.EnableJWT && strings.TrimSpace(c.Security.Secret) == "" {
|
||||
return errors.New("security.secret is required for HMAC authentication (when JWT is disabled)")
|
||||
}
|
||||
|
||||
// Get client IP address
|
||||
clientIP := r.Header.Get("X-Real-IP")
|
||||
if clientIP == "" {
|
||||
clientIP = r.Header.Get("X-Forwarded-For")
|
||||
}
|
||||
if clientIP == "" {
|
||||
// Fallback to RemoteAddr
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Failed to parse RemoteAddr")
|
||||
clientIP = r.RemoteAddr
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateJWTFromRequest extracts and validates a JWT from the request.
|
||||
func validateJWTFromRequest(r *http.Request, secret string) (*jwt.Token, error) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
tokenString := ""
|
||||
|
||||
if authHeader != "" {
|
||||
splitToken := strings.Split(authHeader, "Bearer ")
|
||||
if len(splitToken) == 2 {
|
||||
tokenString = splitToken[1]
|
||||
} else {
|
||||
clientIP = host
|
||||
return nil, errors.New("invalid Authorization header format")
|
||||
}
|
||||
} else {
|
||||
// Fallback to checking 'token' query parameter
|
||||
tokenString = r.URL.Query().Get("token")
|
||||
if tokenString == "" {
|
||||
return nil, errors.New("missing JWT in Authorization header or 'token' query parameter")
|
||||
}
|
||||
}
|
||||
|
||||
// Log the request with the client IP
|
||||
log.WithFields(logrus.Fields{
|
||||
"method": r.Method,
|
||||
"url": r.URL.String(),
|
||||
"remote": clientIP,
|
||||
}).Info("Incoming request")
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
// Parse URL and query parameters
|
||||
p := r.URL.Path
|
||||
a, err := url.ParseQuery(r.URL.RawQuery)
|
||||
if err != nil {
|
||||
log.Warn("Failed to parse query parameters")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, fmt.Errorf("JWT validation failed: %w", err)
|
||||
}
|
||||
|
||||
fileStorePath := strings.TrimPrefix(p, "/")
|
||||
if fileStorePath == "" || fileStorePath == "/" {
|
||||
log.Warn("Access to root directory is forbidden")
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
} else if fileStorePath[0] == '/' {
|
||||
fileStorePath = fileStorePath[1:]
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid JWT")
|
||||
}
|
||||
|
||||
absFilename, err := sanitizeFilePath(conf.Server.StoragePath, fileStorePath)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"file": fileStorePath,
|
||||
"error": err,
|
||||
}).Warn("Invalid file path")
|
||||
http.Error(w, "Invalid file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPut:
|
||||
handleUpload(w, r, absFilename, fileStorePath, a)
|
||||
case http.MethodHead, http.MethodGet:
|
||||
handleDownload(w, r, absFilename, fileStorePath)
|
||||
case http.MethodOptions:
|
||||
// Handled by NGINX; no action needed
|
||||
w.Header().Set("Allow", "OPTIONS, GET, PUT, HEAD")
|
||||
return
|
||||
default:
|
||||
log.WithField("method", r.Method).Warn("Invalid HTTP method for upload directory")
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Handle file uploads with extension restrictions and HMAC validation
|
||||
func handleUpload(w http.ResponseWriter, r *http.Request, absFilename, fileStorePath string, a url.Values) {
|
||||
// Log the storage path being used
|
||||
log.Infof("Using storage path: %s", conf.Server.StoragePath)
|
||||
// validateHMAC validates the HMAC signature of the request for legacy protocols and POST uploads.
|
||||
func validateHMAC(r *http.Request, secret string) error {
|
||||
log.Debugf("validateHMAC: Validating request to %s with query: %s", r.URL.Path, r.URL.RawQuery)
|
||||
// Check for X-Signature header (for POST uploads)
|
||||
signature := r.Header.Get("X-Signature")
|
||||
if signature != "" {
|
||||
// This is a POST upload with X-Signature header
|
||||
message := r.URL.Path
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(message))
|
||||
expectedSignature := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
// Determine protocol version based on query parameters
|
||||
var protocolVersion string
|
||||
if a.Get("v2") != "" {
|
||||
protocolVersion = "v2"
|
||||
} else if a.Get("token") != "" {
|
||||
protocolVersion = "token"
|
||||
} else if a.Get("v") != "" {
|
||||
protocolVersion = "v"
|
||||
} else {
|
||||
log.Warn("No HMAC attached to URL. Expecting 'v', 'v2', or 'token' parameter as MAC")
|
||||
http.Error(w, "No HMAC attached to URL. Expecting 'v', 'v2', or 'token' parameter as MAC", http.StatusForbidden)
|
||||
return
|
||||
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
|
||||
return errors.New("invalid HMAC signature in X-Signature header")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
log.Debugf("Protocol version determined: %s", protocolVersion)
|
||||
|
||||
// Initialize HMAC
|
||||
mac := hmac.New(sha256.New, []byte(conf.Security.Secret))
|
||||
// Check for legacy URL-based HMAC protocols (v, v2, token)
|
||||
query := r.URL.Query()
|
||||
|
||||
var protocolVersion string
|
||||
var providedMACHex string
|
||||
|
||||
if query.Get("v2") != "" {
|
||||
protocolVersion = "v2"
|
||||
providedMACHex = query.Get("v2")
|
||||
} else if query.Get("token") != "" {
|
||||
protocolVersion = "token"
|
||||
providedMACHex = query.Get("token")
|
||||
} else if query.Get("v") != "" {
|
||||
protocolVersion = "v"
|
||||
providedMACHex = query.Get("v")
|
||||
} else {
|
||||
return errors.New("no HMAC signature found (missing X-Signature header or v/v2/token query parameter)")
|
||||
}
|
||||
|
||||
// Extract file path from URL
|
||||
fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
|
||||
|
||||
// Calculate HMAC based on protocol version (matching legacy behavior)
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
|
||||
// Calculate MAC based on protocolVersion
|
||||
if protocolVersion == "v" {
|
||||
mac.Write([]byte(fileStorePath + "\x20" + strconv.FormatInt(r.ContentLength, 10)))
|
||||
} else if protocolVersion == "v2" || protocolVersion == "token" {
|
||||
// Legacy v protocol: fileStorePath + "\x20" + contentLength
|
||||
message := fileStorePath + "\x20" + strconv.FormatInt(r.ContentLength, 10)
|
||||
mac.Write([]byte(message))
|
||||
} else {
|
||||
// v2 and token protocols: fileStorePath + "\x00" + contentLength + "\x00" + contentType
|
||||
contentType := mime.TypeByExtension(filepath.Ext(fileStorePath))
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
mac.Write([]byte(fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) + "\x00" + contentType))
|
||||
message := fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) + "\x00" + contentType
|
||||
log.Debugf("validateHMAC: %s protocol message: %q (len=%d)", protocolVersion, message, len(message))
|
||||
mac.Write([]byte(message))
|
||||
}
|
||||
|
||||
calculatedMAC := mac.Sum(nil)
|
||||
log.Debugf("Calculated MAC: %x", calculatedMAC)
|
||||
calculatedMACHex := hex.EncodeToString(calculatedMAC)
|
||||
|
||||
// Decode provided MAC from hex
|
||||
providedMACHex := a.Get(protocolVersion)
|
||||
// Decode provided MAC
|
||||
providedMAC, err := hex.DecodeString(providedMACHex)
|
||||
if err != nil {
|
||||
log.Warn("Invalid MAC encoding")
|
||||
http.Error(w, "Invalid MAC encoding", http.StatusForbidden)
|
||||
return
|
||||
return fmt.Errorf("invalid MAC encoding for %s protocol: %v", protocolVersion, err)
|
||||
}
|
||||
log.Debugf("Provided MAC: %x", providedMAC)
|
||||
|
||||
// Validate the HMAC
|
||||
log.Debugf("validateHMAC: %s protocol - calculated: %s, provided: %s", protocolVersion, calculatedMACHex, providedMACHex)
|
||||
|
||||
// Compare MACs
|
||||
if !hmac.Equal(calculatedMAC, providedMAC) {
|
||||
log.Warn("Invalid MAC")
|
||||
http.Error(w, "Invalid MAC", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
log.Debug("HMAC validation successful")
|
||||
|
||||
// Validate file extension
|
||||
if !isExtensionAllowed(fileStorePath) {
|
||||
log.WithFields(logrus.Fields{
|
||||
// No need to sanitize and validate the file path here since absFilename is already sanitized in handleRequest
|
||||
"file": fileStorePath,
|
||||
"error": err,
|
||||
}).Warn("Invalid file path")
|
||||
http.Error(w, "Invalid file path", http.StatusBadRequest)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
// absFilename = sanitizedFilename
|
||||
|
||||
// Check if there is enough free space
|
||||
err = checkStorageSpace(conf.Server.StoragePath, conf.Server.MinFreeBytes)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"storage_path": conf.Server.StoragePath,
|
||||
"error": err,
|
||||
}).Warn("Not enough free space")
|
||||
http.Error(w, "Not enough free space", http.StatusInsufficientStorage)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
return fmt.Errorf("invalid MAC for %s protocol", protocolVersion)
|
||||
}
|
||||
|
||||
// Create an UploadTask with a result channel
|
||||
result := make(chan error)
|
||||
task := UploadTask{
|
||||
AbsFilename: absFilename,
|
||||
Request: r,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
// Submit task to the upload queue
|
||||
select {
|
||||
case uploadQueue <- task:
|
||||
// Successfully added to the queue
|
||||
log.Debug("Upload task enqueued successfully")
|
||||
default:
|
||||
// Queue is full
|
||||
log.Warn("Upload queue is full. Rejecting upload")
|
||||
http.Error(w, "Server busy. Try again later.", http.StatusServiceUnavailable)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Wait for the worker to process the upload
|
||||
err = <-result
|
||||
if err != nil {
|
||||
// The worker has already logged the error; send an appropriate HTTP response
|
||||
http.Error(w, fmt.Sprintf("Upload failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Upload was successful
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
log.Debugf("%s HMAC authentication successful for request: %s", protocolVersion, r.URL.Path)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle file downloads
|
||||
func handleDownload(w http.ResponseWriter, r *http.Request, absFilename, fileStorePath string) {
|
||||
fileInfo, err := getFileInfo(absFilename)
|
||||
// validateV3HMAC validates the HMAC signature for v3 protocol (mod_http_upload_external).
|
||||
func validateV3HMAC(r *http.Request, secret string) error {
|
||||
query := r.URL.Query()
|
||||
|
||||
// Extract v3 signature and expires from query parameters
|
||||
signature := query.Get("v3")
|
||||
expiresStr := query.Get("expires")
|
||||
|
||||
if signature == "" {
|
||||
return errors.New("missing v3 signature parameter")
|
||||
}
|
||||
|
||||
if expiresStr == "" {
|
||||
return errors.New("missing expires parameter")
|
||||
}
|
||||
|
||||
// Parse expires timestamp
|
||||
expires, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to get file information")
|
||||
http.Error(w, "Not Found", http.StatusNotFound)
|
||||
downloadErrorsTotal.Inc()
|
||||
return fmt.Errorf("invalid expires parameter: %v", err)
|
||||
}
|
||||
|
||||
// Check if signature has expired
|
||||
now := time.Now().Unix()
|
||||
if now > expires {
|
||||
return errors.New("signature has expired")
|
||||
}
|
||||
|
||||
// Construct message for HMAC verification
|
||||
// Format: METHOD\nEXPIRES\nPATH
|
||||
message := fmt.Sprintf("%s\n%s\n%s", r.Method, expiresStr, r.URL.Path)
|
||||
|
||||
// Calculate expected HMAC signature
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(message))
|
||||
expectedSignature := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
// Compare signatures
|
||||
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
|
||||
return errors.New("invalid v3 HMAC signature")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleUpload handles file uploads.
|
||||
func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := time.Now()
|
||||
activeConnections.Inc()
|
||||
defer activeConnections.Dec()
|
||||
|
||||
// Only allow POST method
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
} else if fileInfo.IsDir() {
|
||||
log.Warn("Directory listing forbidden")
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
}
|
||||
|
||||
// Authentication
|
||||
if conf.Security.EnableJWT {
|
||||
_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
log.Debugf("JWT authentication successful for upload request: %s", r.URL.Path)
|
||||
} else {
|
||||
err := validateHMAC(r, conf.Security.Secret)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
log.Debugf("HMAC authentication successful for upload request: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
// Parse multipart form
|
||||
err := r.ParseMultipartForm(32 << 20) // 32MB max memory
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error parsing multipart form: %v", err), http.StatusBadRequest)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Get file from form
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error getting file from form: %v", err), http.StatusBadRequest)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Validate file extension if configured
|
||||
if len(conf.Uploads.AllowedExtensions) > 0 {
|
||||
ext := strings.ToLower(filepath.Ext(header.Filename))
|
||||
allowed := false
|
||||
for _, allowedExt := range conf.Uploads.AllowedExtensions {
|
||||
if ext == allowedExt {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Generate filename based on configuration
|
||||
var filename string
|
||||
switch conf.Server.FileNaming {
|
||||
case "HMAC":
|
||||
// Generate HMAC-based filename
|
||||
h := hmac.New(sha256.New, []byte(conf.Security.Secret))
|
||||
h.Write([]byte(header.Filename + time.Now().String()))
|
||||
filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(header.Filename)
|
||||
default: // "original" or "None"
|
||||
filename = header.Filename
|
||||
}
|
||||
|
||||
// Create full file path
|
||||
storagePath := conf.Server.StoragePath
|
||||
if conf.ISO.Enabled {
|
||||
storagePath = conf.ISO.MountPoint
|
||||
}
|
||||
|
||||
absFilename := filepath.Join(storagePath, filename)
|
||||
|
||||
// Create the file
|
||||
dst, err := os.Create(absFilename)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
// Copy file content
|
||||
written, err := io.Copy(dst, file)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
|
||||
uploadErrorsTotal.Inc()
|
||||
// Clean up partial file
|
||||
os.Remove(absFilename)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle deduplication if enabled
|
||||
if conf.Server.DeduplicationEnabled {
|
||||
ctx := context.Background()
|
||||
err = handleDeduplication(ctx, absFilename)
|
||||
if err != nil {
|
||||
log.Warnf("Deduplication failed for %s: %v", absFilename, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update metrics
|
||||
duration := time.Since(startTime)
|
||||
uploadDuration.Observe(duration.Seconds())
|
||||
uploadsTotal.Inc()
|
||||
uploadSizeBytes.Observe(float64(written))
|
||||
|
||||
// Return success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"success": true,
|
||||
"filename": filename,
|
||||
"size": written,
|
||||
"duration": duration.String(),
|
||||
}
|
||||
|
||||
// Create JSON response
|
||||
if jsonBytes, err := json.Marshal(response); err == nil {
|
||||
w.Write(jsonBytes)
|
||||
} else {
|
||||
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d}`, filename, written)
|
||||
}
|
||||
|
||||
log.Infof("Successfully uploaded %s (%s) in %s", filename, formatBytes(written), duration)
|
||||
}
|
||||
|
||||
// handleDownload handles file downloads.
|
||||
func handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := time.Now()
|
||||
activeConnections.Inc()
|
||||
defer activeConnections.Dec()
|
||||
|
||||
// Authentication
|
||||
if conf.Security.EnableJWT {
|
||||
_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
log.Debugf("JWT authentication successful for download request: %s", r.URL.Path)
|
||||
} else {
|
||||
err := validateHMAC(r, conf.Security.Secret)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
log.Debugf("HMAC authentication successful for download request: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
filename := strings.TrimPrefix(r.URL.Path, "/download/")
|
||||
if filename == "" {
|
||||
http.Error(w, "Filename not specified", http.StatusBadRequest)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
absFilename, err := sanitizeFilePath(conf.Server.StoragePath, filename) // Use sanitizeFilePath from helpers.go
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid file path: %v", err), http.StatusBadRequest)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
fileInfo, err := os.Stat(absFilename)
|
||||
if os.IsNotExist(err) {
|
||||
http.Error(w, "File not found", http.StatusNotFound)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error accessing file: %v", err), http.StatusInternalServerError)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
if fileInfo.IsDir() {
|
||||
http.Error(w, "Cannot download a directory", http.StatusBadRequest)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
file, err := os.Open(absFilename)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error opening file: %v", err), http.StatusInternalServerError)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
w.Header().Set("Content-Disposition", "attachment; filename=\""+filepath.Base(absFilename)+"\"")
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
|
||||
|
||||
// Use a pooled buffer for copying
|
||||
bufPtr := bufferPool.Get().(*[]byte)
|
||||
defer bufferPool.Put(bufPtr)
|
||||
buf := *bufPtr
|
||||
|
||||
n, err := io.CopyBuffer(w, file, buf)
|
||||
if err != nil {
|
||||
log.Errorf("Error during download of %s: %v", absFilename, err)
|
||||
// Don't write http.Error here if headers already sent
|
||||
downloadErrorsTotal.Inc()
|
||||
return // Ensure we don't try to record metrics if there was an error during copy
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
downloadDuration.Observe(duration.Seconds())
|
||||
downloadsTotal.Inc()
|
||||
downloadSizeBytes.Observe(float64(n))
|
||||
log.Infof("Successfully downloaded %s (%s) in %s", absFilename, formatBytes(n), duration)
|
||||
}
|
||||
|
||||
// handleV3Upload handles PUT requests for v3 protocol (mod_http_upload_external).
|
||||
func handleV3Upload(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := time.Now()
|
||||
activeConnections.Inc()
|
||||
defer activeConnections.Dec()
|
||||
|
||||
// Only allow PUT method for v3
|
||||
if r.Method != http.MethodPut {
|
||||
http.Error(w, "Method not allowed for v3 uploads", http.StatusMethodNotAllowed)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate v3 HMAC signature
|
||||
err := validateV3HMAC(r, conf.Security.Secret)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("v3 Authentication failed: %v", err), http.StatusUnauthorized)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
log.Debugf("v3 HMAC authentication successful for upload request: %s", r.URL.Path)
|
||||
|
||||
// Extract filename from the URL path
|
||||
// Path format: /uuid/subdir/filename.ext
|
||||
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
||||
if len(pathParts) < 1 {
|
||||
http.Error(w, "Invalid upload path", http.StatusBadRequest)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Use the last part as filename
|
||||
originalFilename := pathParts[len(pathParts)-1]
|
||||
if originalFilename == "" {
|
||||
http.Error(w, "No filename specified", http.StatusBadRequest)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate file extension if configured
|
||||
if len(conf.Uploads.AllowedExtensions) > 0 {
|
||||
ext := strings.ToLower(filepath.Ext(originalFilename))
|
||||
allowed := false
|
||||
for _, allowedExt := range conf.Uploads.AllowedExtensions {
|
||||
if ext == allowedExt {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Generate filename based on configuration
|
||||
var filename string
|
||||
switch conf.Server.FileNaming {
|
||||
case "HMAC":
|
||||
// Generate HMAC-based filename
|
||||
h := hmac.New(sha256.New, []byte(conf.Security.Secret))
|
||||
h.Write([]byte(originalFilename + time.Now().String()))
|
||||
filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(originalFilename)
|
||||
default: // "original" or "None"
|
||||
filename = originalFilename
|
||||
}
|
||||
|
||||
// Create full file path
|
||||
storagePath := conf.Server.StoragePath
|
||||
if conf.ISO.Enabled {
|
||||
storagePath = conf.ISO.MountPoint
|
||||
}
|
||||
|
||||
absFilename := filepath.Join(storagePath, filename)
|
||||
|
||||
// Create the file
|
||||
dst, err := os.Create(absFilename)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
// Copy file content from request body
|
||||
written, err := io.Copy(dst, r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
|
||||
uploadErrorsTotal.Inc()
|
||||
// Clean up partial file
|
||||
os.Remove(absFilename)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle deduplication if enabled
|
||||
if conf.Server.DeduplicationEnabled {
|
||||
ctx := context.Background()
|
||||
err = handleDeduplication(ctx, absFilename)
|
||||
if err != nil {
|
||||
log.Warnf("Deduplication failed for %s: %v", absFilename, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update metrics
|
||||
duration := time.Since(startTime)
|
||||
uploadDuration.Observe(duration.Seconds())
|
||||
uploadsTotal.Inc()
|
||||
uploadSizeBytes.Observe(float64(written))
|
||||
|
||||
// Return success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"success": true,
|
||||
"filename": filename,
|
||||
"size": written,
|
||||
"duration": duration.String(),
|
||||
}
|
||||
|
||||
// Create JSON response
|
||||
if jsonBytes, err := json.Marshal(response); err == nil {
|
||||
w.Write(jsonBytes)
|
||||
} else {
|
||||
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d}`, filename, written)
|
||||
}
|
||||
|
||||
log.Infof("Successfully uploaded %s via v3 protocol (%s) in %s", filename, formatBytes(written), duration)
|
||||
}
|
||||
|
||||
// handleLegacyUpload handles PUT requests for legacy protocols (v, v2, token).
|
||||
func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := time.Now()
|
||||
activeConnections.Inc()
|
||||
defer activeConnections.Dec()
|
||||
|
||||
log.Debugf("handleLegacyUpload: Processing request to %s with query: %s", r.URL.Path, r.URL.RawQuery)
|
||||
|
||||
// Only allow PUT method for legacy uploads
|
||||
if r.Method != http.MethodPut {
|
||||
http.Error(w, "Method not allowed for legacy uploads", http.StatusMethodNotAllowed)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate legacy HMAC signature
|
||||
err := validateHMAC(r, conf.Security.Secret)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Legacy Authentication failed: %v", err), http.StatusUnauthorized)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Extract filename from the URL path
|
||||
fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
|
||||
if fileStorePath == "" {
|
||||
http.Error(w, "No filename specified", http.StatusBadRequest)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate file extension if configured
|
||||
if len(conf.Uploads.AllowedExtensions) > 0 {
|
||||
ext := strings.ToLower(filepath.Ext(fileStorePath))
|
||||
allowed := false
|
||||
for _, allowedExt := range conf.Uploads.AllowedExtensions {
|
||||
if ext == allowedExt {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Create full file path
|
||||
storagePath := conf.Server.StoragePath
|
||||
if conf.ISO.Enabled {
|
||||
storagePath = conf.ISO.MountPoint
|
||||
}
|
||||
|
||||
// Generate filename based on configuration
|
||||
var absFilename string
|
||||
var filename string
|
||||
switch conf.Server.FileNaming {
|
||||
case "HMAC":
|
||||
// Generate HMAC-based filename
|
||||
h := hmac.New(sha256.New, []byte(conf.Security.Secret))
|
||||
h.Write([]byte(fileStorePath + time.Now().String()))
|
||||
filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(fileStorePath)
|
||||
absFilename = filepath.Join(storagePath, filename)
|
||||
default: // "original" or "None"
|
||||
// Preserve full directory structure for legacy XMPP compatibility
|
||||
var sanitizeErr error
|
||||
absFilename, sanitizeErr = sanitizeFilePath(storagePath, fileStorePath)
|
||||
if sanitizeErr != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid file path: %v", sanitizeErr), http.StatusBadRequest)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
filename = filepath.Base(fileStorePath) // For logging purposes
|
||||
}
|
||||
|
||||
// Create directory structure if it doesn't exist
|
||||
if err := os.MkdirAll(filepath.Dir(absFilename), 0755); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error creating directory: %v", err), http.StatusInternalServerError)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Create the file
|
||||
dst, err := os.Create(absFilename)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
// Log upload start for large files
|
||||
if r.ContentLength > 10*1024*1024 { // Log for files > 10MB
|
||||
log.Infof("Starting upload of %s (%.1f MiB)", filename, float64(r.ContentLength)/(1024*1024))
|
||||
}
|
||||
|
||||
// Copy file content from request body with progress reporting
|
||||
written, err := copyWithProgress(dst, r.Body, r.ContentLength, filename)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
|
||||
uploadErrorsTotal.Inc()
|
||||
// Clean up partial file
|
||||
os.Remove(absFilename)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle deduplication if enabled
|
||||
if conf.Server.DeduplicationEnabled {
|
||||
ctx := context.Background()
|
||||
err = handleDeduplication(ctx, absFilename)
|
||||
if err != nil {
|
||||
log.Warnf("Deduplication failed for %s: %v", absFilename, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update metrics
|
||||
duration := time.Since(startTime)
|
||||
uploadDuration.Observe(duration.Seconds())
|
||||
uploadsTotal.Inc()
|
||||
uploadSizeBytes.Observe(float64(written))
|
||||
|
||||
// Return success response (201 Created for legacy compatibility)
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
|
||||
log.Infof("Successfully uploaded %s via legacy protocol (%s) in %s", filename, formatBytes(written), duration)
|
||||
}
|
||||
|
||||
// handleLegacyDownload handles GET/HEAD requests for legacy downloads.
|
||||
func handleLegacyDownload(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := time.Now()
|
||||
activeConnections.Inc()
|
||||
defer activeConnections.Dec()
|
||||
|
||||
// Extract filename from the URL path
|
||||
fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
|
||||
if fileStorePath == "" {
|
||||
http.Error(w, "No filename specified", http.StatusBadRequest)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Create full file path
|
||||
storagePath := conf.Server.StoragePath
|
||||
if conf.ISO.Enabled {
|
||||
storagePath = conf.ISO.MountPoint
|
||||
}
|
||||
|
||||
absFilename := filepath.Join(storagePath, fileStorePath)
|
||||
|
||||
fileInfo, err := os.Stat(absFilename)
|
||||
if os.IsNotExist(err) {
|
||||
http.Error(w, "File not found", http.StatusNotFound)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error accessing file: %v", err), http.StatusInternalServerError)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
if fileInfo.IsDir() {
|
||||
http.Error(w, "Cannot download a directory", http.StatusBadRequest)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Set appropriate headers
|
||||
contentType := mime.TypeByExtension(filepath.Ext(fileStorePath))
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(fileInfo.Size(), 10))
|
||||
|
||||
// Handle resumable downloads
|
||||
if conf.Uploads.ResumableUploadsEnabled {
|
||||
handleResumableDownload(absFilename, w, r, fileInfo.Size())
|
||||
return
|
||||
}
|
||||
|
||||
// For HEAD requests, only send headers
|
||||
if r.Method == http.MethodHead {
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(fileInfo.Size(), 10))
|
||||
downloadsTotal.Inc()
|
||||
return
|
||||
} else {
|
||||
// Measure download duration
|
||||
startTime := time.Now()
|
||||
log.Infof("Initiating download for file: %s", absFilename)
|
||||
http.ServeFile(w, r, absFilename)
|
||||
downloadDuration.Observe(time.Since(startTime).Seconds())
|
||||
downloadSizeBytes.Observe(float64(fileInfo.Size()))
|
||||
downloadsTotal.Inc()
|
||||
log.Infof("File downloaded successfully: %s", absFilename)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Create the file for upload with buffered Writer
|
||||
func createFile(tempFilename string, r *http.Request) error {
|
||||
absDirectory := filepath.Dir(tempFilename)
|
||||
err := os.MkdirAll(absDirectory, os.ModePerm)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Failed to create directory %s", absDirectory)
|
||||
return fmt.Errorf("failed to create directory %s: %w", absDirectory, err)
|
||||
}
|
||||
|
||||
// Open the file for writing
|
||||
targetFile, err := os.OpenFile(tempFilename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Failed to create file %s", tempFilename)
|
||||
return fmt.Errorf("failed to create file %s: %w", tempFilename, err)
|
||||
}
|
||||
defer targetFile.Close()
|
||||
|
||||
// Use a large buffer for efficient file writing
|
||||
bufferSize := 4 * 1024 * 1024 // 4 MB buffer
|
||||
writer := bufio.NewWriterSize(targetFile, bufferSize)
|
||||
buffer := make([]byte, bufferSize)
|
||||
|
||||
totalBytes := int64(0)
|
||||
for {
|
||||
n, readErr := r.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
totalBytes += int64(n)
|
||||
_, writeErr := writer.Write(buffer[:n])
|
||||
if writeErr != nil {
|
||||
log.WithError(writeErr).Errorf("Failed to write to file %s", tempFilename)
|
||||
return fmt.Errorf("failed to write to file %s: %w", tempFilename, writeErr)
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
if readErr == io.EOF {
|
||||
break
|
||||
}
|
||||
log.WithError(readErr).Error("Failed to read request body")
|
||||
return fmt.Errorf("failed to read request body: %w", readErr)
|
||||
}
|
||||
}
|
||||
|
||||
err = writer.Flush()
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Failed to flush buffer to file %s", tempFilename)
|
||||
return fmt.Errorf("failed to flush buffer to file %s: %w", tempFilename, err)
|
||||
}
|
||||
|
||||
log.WithFields(logrus.Fields{
|
||||
"temp_file": tempFilename,
|
||||
"total_bytes": totalBytes,
|
||||
}).Info("File uploaded successfully")
|
||||
|
||||
uploadSizeBytes.Observe(float64(totalBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Scan the uploaded file with ClamAV (Optional)
|
||||
func scanFileWithClamAV(filePath string) error {
|
||||
log.WithField("file", filePath).Info("Scanning file with ClamAV")
|
||||
|
||||
scanResultChan, err := clamClient.ScanFile(filePath)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to initiate ClamAV scan")
|
||||
return fmt.Errorf("failed to initiate ClamAV scan: %w", err)
|
||||
}
|
||||
|
||||
// Receive scan result
|
||||
scanResult := <-scanResultChan
|
||||
if scanResult == nil {
|
||||
log.Error("Failed to receive scan result from ClamAV")
|
||||
return fmt.Errorf("failed to receive scan result from ClamAV")
|
||||
}
|
||||
|
||||
// Handle scan result
|
||||
switch scanResult.Status {
|
||||
case clamd.RES_OK:
|
||||
log.WithField("file", filePath).Info("ClamAV scan passed")
|
||||
return nil
|
||||
case clamd.RES_FOUND:
|
||||
log.WithFields(logrus.Fields{
|
||||
"file": filePath,
|
||||
"description": scanResult.Description,
|
||||
}).Warn("ClamAV detected a virus")
|
||||
return fmt.Errorf("virus detected: %s", scanResult.Description)
|
||||
default:
|
||||
log.WithFields(logrus.Fields{
|
||||
"file": filePath,
|
||||
"status": scanResult.Status,
|
||||
"description": scanResult.Description,
|
||||
}).Warn("ClamAV scan returned unexpected status")
|
||||
return fmt.Errorf("ClamAV scan returned unexpected status: %s", scanResult.Description)
|
||||
}
|
||||
}
|
||||
|
||||
// initClamAV initializes the ClamAV client and logs the status
|
||||
func initClamAV(socket string) (*clamd.Clamd, error) {
|
||||
if socket == "" {
|
||||
log.Error("ClamAV socket path is not configured.")
|
||||
return nil, fmt.Errorf("ClamAV socket path is not configured")
|
||||
}
|
||||
|
||||
clamClient := clamd.NewClamd("unix:" + socket)
|
||||
err := clamClient.Ping()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to ClamAV at %s: %v", socket, err)
|
||||
return nil, fmt.Errorf("failed to connect to ClamAV: %w", err)
|
||||
}
|
||||
|
||||
log.Info("Connected to ClamAV successfully.")
|
||||
return clamClient, nil
|
||||
}
|
||||
|
||||
// Handle resumable downloads
|
||||
func handleResumableDownload(absFilename string, w http.ResponseWriter, r *http.Request, fileSize int64) {
|
||||
rangeHeader := r.Header.Get("Range")
|
||||
if rangeHeader == "" {
|
||||
// If no Range header, serve the full file
|
||||
startTime := time.Now()
|
||||
http.ServeFile(w, r, absFilename)
|
||||
downloadDuration.Observe(time.Since(startTime).Seconds())
|
||||
downloadSizeBytes.Observe(float64(fileSize))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
downloadsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Parse Range header
|
||||
ranges := strings.Split(strings.TrimPrefix(rangeHeader, "bytes="), "-")
|
||||
if len(ranges) != 2 {
|
||||
http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
start, err := strconv.ParseInt(ranges[0], 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate end byte
|
||||
end := fileSize - 1
|
||||
if ranges[1] != "" {
|
||||
end, err = strconv.ParseInt(ranges[1], 10, 64)
|
||||
if err != nil || end >= fileSize {
|
||||
http.Error(w, "Invalid Range", http.StatusRequestedRangeNotSatisfiable)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set response headers for partial content
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, fileSize))
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(end-start+1, 10))
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
|
||||
// Serve the requested byte range
|
||||
// For GET requests, serve the file
|
||||
file, err := os.Open(absFilename)
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("Error opening file: %v", err), http.StatusInternalServerError)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Seek to the start byte
|
||||
_, err = file.Seek(start, 0)
|
||||
// Use a pooled buffer for copying
|
||||
bufPtr := bufferPool.Get().(*[]byte)
|
||||
defer bufferPool.Put(bufPtr)
|
||||
buf := *bufPtr
|
||||
|
||||
n, err := io.CopyBuffer(w, file, buf)
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
log.Errorf("Error during download of %s: %v", absFilename, err)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
|
||||
// Create a buffer and copy the specified range to the response writer
|
||||
buffer := make([]byte, 32*1024) // 32KB buffer
|
||||
remaining := end - start + 1
|
||||
startTime := time.Now()
|
||||
for remaining > 0 {
|
||||
if int64(len(buffer)) > remaining {
|
||||
buffer = buffer[:remaining]
|
||||
}
|
||||
n, err := file.Read(buffer)
|
||||
if n > 0 {
|
||||
if _, writeErr := w.Write(buffer[:n]); writeErr != nil {
|
||||
log.WithError(writeErr).Error("Failed to write to response")
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
remaining -= int64(n)
|
||||
}
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.WithError(err).Error("Error reading file during resumable download")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
downloadErrorsTotal.Inc()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
downloadDuration.Observe(time.Since(startTime).Seconds())
|
||||
downloadSizeBytes.Observe(float64(end - start + 1))
|
||||
duration := time.Since(startTime)
|
||||
downloadDuration.Observe(duration.Seconds())
|
||||
downloadsTotal.Inc()
|
||||
downloadSizeBytes.Observe(float64(n))
|
||||
log.Infof("Successfully downloaded %s (%s) in %s", absFilename, formatBytes(n), duration)
|
||||
}
|
||||
|
||||
// Handle chunked uploads with bufio.Writer
|
||||
func handleChunkedUpload(tempFilename string, r *http.Request) error {
|
||||
log.WithField("file", tempFilename).Info("Handling chunked upload to temporary file")
|
||||
// printValidationChecks prints all available validation checks
|
||||
func printValidationChecks() {
|
||||
fmt.Println("HMAC File Server Configuration Validation Checks")
|
||||
fmt.Println("=================================================")
|
||||
fmt.Println()
|
||||
|
||||
// Ensure the directory exists
|
||||
absDirectory := filepath.Dir(tempFilename)
|
||||
err := os.MkdirAll(absDirectory, os.ModePerm)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Failed to create directory %s for chunked upload", absDirectory)
|
||||
return fmt.Errorf("failed to create directory %s: %w", absDirectory, err)
|
||||
}
|
||||
fmt.Println("🔍 CORE VALIDATION CHECKS:")
|
||||
fmt.Println(" ✓ server.* - Server configuration (ports, paths, protocols)")
|
||||
fmt.Println(" ✓ security.* - Security settings (secrets, JWT, authentication)")
|
||||
fmt.Println(" ✓ logging.* - Logging configuration (levels, files, rotation)")
|
||||
fmt.Println(" ✓ timeouts.* - Timeout settings (read, write, idle)")
|
||||
fmt.Println(" ✓ uploads.* - Upload configuration (extensions, chunk size)")
|
||||
fmt.Println(" ✓ downloads.* - Download configuration (extensions, chunk size)")
|
||||
fmt.Println(" ✓ workers.* - Worker pool configuration (count, queue size)")
|
||||
fmt.Println(" ✓ redis.* - Redis configuration (address, credentials)")
|
||||
fmt.Println(" ✓ clamav.* - ClamAV antivirus configuration")
|
||||
fmt.Println(" ✓ versioning.* - File versioning configuration")
|
||||
fmt.Println(" ✓ deduplication.* - File deduplication configuration")
|
||||
fmt.Println(" ✓ iso.* - ISO filesystem configuration")
|
||||
fmt.Println()
|
||||
|
||||
targetFile, err := os.OpenFile(tempFilename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to open temporary file for chunked upload")
|
||||
return err
|
||||
}
|
||||
defer targetFile.Close()
|
||||
fmt.Println("🔐 SECURITY CHECKS:")
|
||||
fmt.Println(" ✓ Secret strength analysis (length, entropy, patterns)")
|
||||
fmt.Println(" ✓ Default/example value detection")
|
||||
fmt.Println(" ✓ JWT algorithm security recommendations")
|
||||
fmt.Println(" ✓ Network binding security (0.0.0.0 warnings)")
|
||||
fmt.Println(" ✓ File permission analysis")
|
||||
fmt.Println(" ✓ Debug logging security implications")
|
||||
fmt.Println()
|
||||
|
||||
writer := bufio.NewWriterSize(targetFile, int(conf.Uploads.ChunkSize))
|
||||
buffer := make([]byte, conf.Uploads.ChunkSize)
|
||||
fmt.Println("⚡ PERFORMANCE CHECKS:")
|
||||
fmt.Println(" ✓ Worker count vs CPU cores optimization")
|
||||
fmt.Println(" ✓ Queue size vs memory usage analysis")
|
||||
fmt.Println(" ✓ Timeout configuration balance")
|
||||
fmt.Println(" ✓ Large file handling preparation")
|
||||
fmt.Println(" ✓ Memory-intensive configuration detection")
|
||||
fmt.Println()
|
||||
|
||||
totalBytes := int64(0)
|
||||
for {
|
||||
n, err := r.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
totalBytes += int64(n)
|
||||
_, writeErr := writer.Write(buffer[:n])
|
||||
if writeErr != nil {
|
||||
log.WithError(writeErr).Error("Failed to write chunk to temporary file")
|
||||
return writeErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break // Finished reading the body
|
||||
}
|
||||
log.WithError(err).Error("Error reading from request body")
|
||||
return err
|
||||
}
|
||||
}
|
||||
fmt.Println("🌐 CONNECTIVITY CHECKS:")
|
||||
fmt.Println(" ✓ Redis server connectivity testing")
|
||||
fmt.Println(" ✓ ClamAV socket accessibility")
|
||||
fmt.Println(" ✓ Network address format validation")
|
||||
fmt.Println(" ✓ DNS resolution testing")
|
||||
fmt.Println()
|
||||
|
||||
err = writer.Flush()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to flush buffer to temporary file")
|
||||
return err
|
||||
}
|
||||
fmt.Println("💾 SYSTEM RESOURCE CHECKS:")
|
||||
fmt.Println(" ✓ CPU core availability analysis")
|
||||
fmt.Println(" ✓ Memory usage monitoring")
|
||||
fmt.Println(" ✓ Disk space validation")
|
||||
fmt.Println(" ✓ Directory write permissions")
|
||||
fmt.Println(" ✓ Goroutine count analysis")
|
||||
fmt.Println()
|
||||
|
||||
log.WithFields(logrus.Fields{
|
||||
"temp_file": tempFilename,
|
||||
"total_bytes": totalBytes,
|
||||
}).Info("Chunked upload completed successfully")
|
||||
fmt.Println("🔄 CROSS-SECTION VALIDATION:")
|
||||
fmt.Println(" ✓ Path conflict detection")
|
||||
fmt.Println(" ✓ Extension compatibility checks")
|
||||
fmt.Println(" ✓ Configuration consistency validation")
|
||||
fmt.Println()
|
||||
|
||||
uploadSizeBytes.Observe(float64(totalBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get file information with caching
|
||||
func getFileInfo(absFilename string) (os.FileInfo, error) {
|
||||
if cachedInfo, found := fileInfoCache.Get(absFilename); found {
|
||||
if info, ok := cachedInfo.(os.FileInfo); ok {
|
||||
return info, nil
|
||||
}
|
||||
}
|
||||
|
||||
fileInfo, err := os.Stat(absFilename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fileInfoCache.Set(absFilename, fileInfo, cache.DefaultExpiration)
|
||||
return fileInfo, nil
|
||||
}
|
||||
|
||||
// Monitor network changes
|
||||
func monitorNetwork(ctx context.Context) {
|
||||
currentIP := getCurrentIPAddress() // Placeholder for the current IP address
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("Stopping network monitor.")
|
||||
return
|
||||
case <-time.After(10 * time.Second):
|
||||
newIP := getCurrentIPAddress()
|
||||
if newIP != currentIP && newIP != "" {
|
||||
currentIP = newIP
|
||||
select {
|
||||
case networkEvents <- NetworkEvent{Type: "IP_CHANGE", Details: currentIP}:
|
||||
log.WithField("new_ip", currentIP).Info("Queued IP_CHANGE event")
|
||||
default:
|
||||
log.Warn("Network event channel is full. Dropping IP_CHANGE event.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle network events
|
||||
func handleNetworkEvents(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("Stopping network event handler.")
|
||||
return
|
||||
case event, ok := <-networkEvents:
|
||||
if !ok {
|
||||
log.Info("Network events channel closed.")
|
||||
return
|
||||
}
|
||||
switch event.Type {
|
||||
case "IP_CHANGE":
|
||||
log.WithField("new_ip", event.Details).Info("Network change detected")
|
||||
// Example: Update Prometheus gauge or trigger alerts
|
||||
// activeConnections.Set(float64(getActiveConnections()))
|
||||
}
|
||||
// Additional event types can be handled here
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get current IP address (example)
|
||||
func getCurrentIPAddress() string {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to get network interfaces")
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
|
||||
continue // Skip interfaces that are down or loopback
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Failed to get addresses for interface %s", iface.Name)
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP.IsGlobalUnicast() && ipnet.IP.To4() != nil {
|
||||
return ipnet.IP.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// setupGracefulShutdown sets up handling for graceful server shutdown
|
||||
func setupGracefulShutdown(server *http.Server, cancel context.CancelFunc) {
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
sig := <-quit
|
||||
log.Infof("Received signal %s. Initiating shutdown...", sig)
|
||||
|
||||
// Create a deadline to wait for.
|
||||
ctxShutdown, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
// Attempt graceful shutdown
|
||||
if err := server.Shutdown(ctxShutdown); err != nil {
|
||||
log.Errorf("Server shutdown failed: %v", err)
|
||||
} else {
|
||||
log.Info("Server shutdown gracefully.")
|
||||
}
|
||||
|
||||
// Signal other goroutines to stop
|
||||
cancel()
|
||||
|
||||
// Close the upload, scan, and network event channels
|
||||
close(uploadQueue)
|
||||
log.Info("Upload queue closed.")
|
||||
close(scanQueue)
|
||||
log.Info("Scan queue closed.")
|
||||
close(networkEvents)
|
||||
log.Info("Network events channel closed.")
|
||||
|
||||
log.Info("Shutdown process completed. Exiting application.")
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
|
||||
// Initialize Redis client
|
||||
func initRedis() {
|
||||
if !conf.Redis.RedisEnabled {
|
||||
log.Info("Redis is disabled in configuration.")
|
||||
return
|
||||
}
|
||||
|
||||
redisClient = redis.NewClient(&redis.Options{
|
||||
Addr: conf.Redis.RedisAddr,
|
||||
Password: conf.Redis.RedisPassword,
|
||||
DB: conf.Redis.RedisDBIndex,
|
||||
})
|
||||
|
||||
// Test the Redis connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := redisClient.Ping(ctx).Result()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to Redis: %v", err)
|
||||
}
|
||||
log.Info("Connected to Redis successfully")
|
||||
|
||||
// Set initial connection status
|
||||
mu.Lock()
|
||||
redisConnected = true
|
||||
mu.Unlock()
|
||||
|
||||
// Start monitoring Redis health
|
||||
go MonitorRedisHealth(context.Background(), redisClient, parseDuration(conf.Redis.RedisHealthCheckInterval))
|
||||
}
|
||||
|
||||
// MonitorRedisHealth periodically checks Redis connectivity and updates redisConnected status.
|
||||
func MonitorRedisHealth(ctx context.Context, client *redis.Client, checkInterval time.Duration) {
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("Stopping Redis health monitor.")
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := client.Ping(ctx).Err()
|
||||
mu.Lock()
|
||||
if err != nil {
|
||||
if redisConnected {
|
||||
log.Errorf("Redis health check failed: %v", err)
|
||||
}
|
||||
redisConnected = false
|
||||
} else {
|
||||
if !redisConnected {
|
||||
log.Info("Redis reconnected successfully")
|
||||
}
|
||||
redisConnected = true
|
||||
log.Debug("Redis health check succeeded.")
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to parse duration strings
|
||||
func parseDuration(durationStr string) time.Duration {
|
||||
duration, err := time.ParseDuration(durationStr)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Invalid duration format, using default 30s")
|
||||
return 30 * time.Second
|
||||
}
|
||||
return duration
|
||||
}
|
||||
|
||||
// RunFileCleaner periodically deletes files that exceed the FileTTL duration.
|
||||
func runFileCleaner(ctx context.Context, storeDir string, ttl time.Duration) {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("Stopping file cleaner.")
|
||||
return
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
err := filepath.Walk(storeDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if now.Sub(info.ModTime()) > ttl {
|
||||
err := os.Remove(path)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Failed to remove expired file: %s", path)
|
||||
} else {
|
||||
log.Infof("Removed expired file: %s", path)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Error walking store directory for file cleaning")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DeduplicateFiles scans the store directory and removes duplicate files based on SHA256 hash.
|
||||
// It retains one copy of each unique file and replaces duplicates with hard links.
|
||||
func DeduplicateFiles(storeDir string) error {
|
||||
hashMap := make(map[string]string) // map[hash]filepath
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
fileChan := make(chan string, 100)
|
||||
|
||||
// Worker to process files
|
||||
numWorkers := 10
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for filePath := range fileChan {
|
||||
hash, err := computeFileHash(filePath)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to compute hash for %s", filePath)
|
||||
continue
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
original, exists := hashMap[hash]
|
||||
if !exists {
|
||||
hashMap[hash] = filePath
|
||||
mu.Unlock()
|
||||
continue
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Duplicate found
|
||||
err = os.Remove(filePath)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to remove duplicate file %s", filePath)
|
||||
continue
|
||||
}
|
||||
|
||||
// Create hard link to the original file
|
||||
err = os.Link(original, filePath)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to create hard link from %s to %s", original, filePath)
|
||||
continue
|
||||
}
|
||||
|
||||
logrus.Infof("Removed duplicate %s and linked to %s", filePath, original)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Walk through the store directory
|
||||
err := filepath.Walk(storeDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Error accessing path %s", path)
|
||||
return nil
|
||||
}
|
||||
if !info.Mode().IsRegular() {
|
||||
return nil
|
||||
}
|
||||
fileChan <- path
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error walking the path %s: %w", storeDir, err)
|
||||
}
|
||||
|
||||
close(fileChan)
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// computeFileHash computes the SHA256 hash of the given file.
|
||||
func computeFileHash(filePath string) (string, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to open file %s: %w", filePath, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hasher := sha256.New()
|
||||
if _, err := io.Copy(hasher, file); err != nil {
|
||||
return "", fmt.Errorf("error hashing file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
return hex.EncodeToString(hasher.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// Handle multipart uploads
|
||||
func handleMultipartUpload(w http.ResponseWriter, r *http.Request, absFilename string) error {
|
||||
err := r.ParseMultipartForm(32 << 20) // 32MB is the default used by FormFile
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to parse multipart form")
|
||||
http.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
|
||||
return err
|
||||
}
|
||||
|
||||
file, handler, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to retrieve file from form data")
|
||||
http.Error(w, "Failed to retrieve file from form data", http.StatusBadRequest)
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Validate file extension
|
||||
if !isExtensionAllowed(handler.Filename) {
|
||||
log.WithFields(logrus.Fields{
|
||||
"filename": handler.Filename,
|
||||
"extension": filepath.Ext(handler.Filename),
|
||||
}).Warn("Attempted upload with disallowed file extension")
|
||||
http.Error(w, "Disallowed file extension. Allowed extensions are: "+strings.Join(conf.Uploads.AllowedExtensions, ", "), http.StatusForbidden)
|
||||
uploadErrorsTotal.Inc()
|
||||
return fmt.Errorf("disallowed file extension")
|
||||
}
|
||||
|
||||
// Create a temporary file
|
||||
tempFilename := absFilename + ".tmp"
|
||||
tempFile, err := os.OpenFile(tempFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to create temporary file")
|
||||
http.Error(w, "Failed to create temporary file", http.StatusInternalServerError)
|
||||
return err
|
||||
}
|
||||
defer tempFile.Close()
|
||||
|
||||
// Copy the uploaded file to the temporary file
|
||||
_, err = io.Copy(tempFile, file)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to copy uploaded file to temporary file")
|
||||
http.Error(w, "Failed to copy uploaded file", http.StatusInternalServerError)
|
||||
return err
|
||||
}
|
||||
|
||||
// Perform ClamAV scan on the temporary file
|
||||
if clamClient != nil {
|
||||
err := scanFileWithClamAV(tempFilename)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"file": tempFilename,
|
||||
"error": err,
|
||||
}).Warn("ClamAV detected a virus or scan failed")
|
||||
os.Remove(tempFilename)
|
||||
uploadErrorsTotal.Inc()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Handle file versioning if enabled
|
||||
if conf.Versioning.EnableVersioning {
|
||||
existing, _ := fileExists(absFilename)
|
||||
if existing {
|
||||
err := versionFile(absFilename)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"file": absFilename,
|
||||
"error": err,
|
||||
}).Error("Error versioning file")
|
||||
os.Remove(tempFilename)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Move the temporary file to the final destination
|
||||
err = os.Rename(tempFilename, absFilename)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"temp_file": tempFilename,
|
||||
"final_file": absFilename,
|
||||
"error": err,
|
||||
}).Error("Failed to move file to final destination")
|
||||
os.Remove(tempFilename)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithFields(logrus.Fields{
|
||||
"file": absFilename,
|
||||
}).Info("File uploaded and scanned successfully")
|
||||
|
||||
uploadsTotal.Inc()
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeFilePath ensures that the file path is within the designated storage directory
|
||||
func sanitizeFilePath(baseDir, filePath string) (string, error) {
|
||||
// Resolve the absolute path
|
||||
absBaseDir, err := filepath.Abs(baseDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve base directory: %w", err)
|
||||
}
|
||||
|
||||
absFilePath, err := filepath.Abs(filepath.Join(absBaseDir, filePath))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve file path: %w", err)
|
||||
}
|
||||
|
||||
// Check if the resolved file path is within the base directory
|
||||
if !strings.HasPrefix(absFilePath, absBaseDir) {
|
||||
return "", fmt.Errorf("invalid file path: %s", filePath)
|
||||
}
|
||||
|
||||
return absFilePath, nil
|
||||
}
|
||||
|
||||
// checkStorageSpace ensures that there is enough free space in the storage path
|
||||
func checkStorageSpace(storagePath string, minFreeBytes int64) error {
|
||||
var stat syscall.Statfs_t
|
||||
err := syscall.Statfs(storagePath, &stat)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get filesystem stats: %w", err)
|
||||
}
|
||||
|
||||
// Calculate available bytes
|
||||
availableBytes := stat.Bavail * uint64(stat.Bsize)
|
||||
if int64(availableBytes) < minFreeBytes {
|
||||
return fmt.Errorf("not enough free space: %d bytes available, %d bytes required", availableBytes, minFreeBytes)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Function to compute SHA256 checksum of a file
|
||||
func computeSHA256(filePath string) (string, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open file for checksum: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hasher := sha256.New()
|
||||
if _, err := io.Copy(hasher, file); err != nil {
|
||||
return "", fmt.Errorf("failed to compute checksum: %w", err)
|
||||
}
|
||||
|
||||
return hex.EncodeToString(hasher.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// handleDeduplication handles file deduplication using SHA256 checksum and hard links
|
||||
func handleDeduplication(ctx context.Context, absFilename string) error {
|
||||
// Compute checksum of the uploaded file
|
||||
checksum, err := computeSHA256(absFilename)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to compute SHA256 for %s: %v", absFilename, err)
|
||||
return fmt.Errorf("checksum computation failed: %w", err)
|
||||
}
|
||||
log.Debugf("Computed checksum for %s: %s", absFilename, checksum)
|
||||
|
||||
// Check Redis for existing checksum
|
||||
existingPath, err := redisClient.Get(ctx, checksum).Result()
|
||||
if err != nil && err != redis.Nil {
|
||||
log.Errorf("Redis error while fetching checksum %s: %v", checksum, err)
|
||||
return fmt.Errorf("redis error: %w", err)
|
||||
}
|
||||
|
||||
if err != redis.Nil {
|
||||
// Duplicate found, create hard link
|
||||
log.Infof("Duplicate detected: %s already exists at %s", absFilename, existingPath)
|
||||
err = os.Link(existingPath, absFilename)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to create hard link from %s to %s: %v", existingPath, absFilename, err)
|
||||
return fmt.Errorf("failed to create hard link: %w", err)
|
||||
}
|
||||
log.Infof("Created hard link from %s to %s", existingPath, absFilename)
|
||||
return nil
|
||||
}
|
||||
|
||||
// No duplicate found, store checksum in Redis
|
||||
err = redisClient.Set(ctx, checksum, absFilename, 0).Err()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to store checksum %s in Redis: %v", checksum, err)
|
||||
return fmt.Errorf("failed to store checksum in Redis: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Stored new file checksum in Redis: %s -> %s", checksum, absFilename)
|
||||
return nil
|
||||
fmt.Println("📋 USAGE EXAMPLES:")
|
||||
fmt.Println(" hmac-file-server --validate-config # Full validation")
|
||||
fmt.Println(" hmac-file-server --check-security # Security checks only")
|
||||
fmt.Println(" hmac-file-server --check-performance # Performance checks only")
|
||||
fmt.Println(" hmac-file-server --check-connectivity # Network checks only")
|
||||
fmt.Println(" hmac-file-server --validate-quiet # Errors only")
|
||||
fmt.Println(" hmac-file-server --validate-verbose # Detailed output")
|
||||
fmt.Println(" hmac-file-server --check-fixable # Auto-fixable issues")
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user