Files
hmac-file-server/cmd/server/main.go
Alexander Renz da403de111 Add test script for large file asynchronous post-processing
- Implemented a comprehensive test script to validate the new asynchronous handling of large file uploads (>1GB).
- The script checks for immediate HTTP responses, verifies server configurations for deduplication and virus scanning, and ensures server responsiveness during rapid uploads.
- Included checks for relevant response headers and session tracking.
- Documented the problem being solved, implementation details, and next steps for deployment and monitoring.
2025-08-26 20:20:05 +00:00

3828 lines
133 KiB
Go
Raw Blame History

// main.go
package main
import (
	"bufio"
	"context"
	"crypto/hmac"
	crand "crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/dutchcoders/go-clamd" // ClamAV integration
	"github.com/go-redis/redis/v8"    // Redis integration
	jwt "github.com/golang-jwt/jwt/v5"
	"github.com/patrickmn/go-cache"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/shirou/gopsutil/cpu"
	"github.com/shirou/gopsutil/mem"
	"github.com/sirupsen/logrus"
	"github.com/spf13/viper"
)
// NetworkResilientSession represents a persistent session for network switching.
// Sessions are kept by SessionStore (Redis, an in-process go-cache and a local
// map) and are encoded as JSON, using the tags below, when written to Redis.
type NetworkResilientSession struct {
	SessionID string `json:"session_id"`
	UserJID string `json:"user_jid"`
	OriginalToken string `json:"original_token"`
	CreatedAt time.Time `json:"created_at"`
	// LastSeen is overwritten with time.Now() by SessionStore.StoreSession; the
	// local-map backend treats sessions older than 72 hours as expired.
	LastSeen time.Time `json:"last_seen"`
	NetworkHistory []NetworkEvent `json:"network_history"`
	UploadContext *UploadContext `json:"upload_context,omitempty"`
	// NOTE(review): RefreshCount/MaxRefreshes are not enforced anywhere in this
	// part of the file — presumably checked by the token-refresh code; verify.
	RefreshCount int `json:"refresh_count"`
	MaxRefreshes int `json:"max_refreshes"`
	LastIP string `json:"last_ip"`
	UserAgent string `json:"user_agent"`
	SecurityLevel int `json:"security_level"` // 1=normal, 2=challenge, 3=reauth
	LastSecurityCheck time.Time `json:"last_security_check"`
	NetworkChangeCount int `json:"network_change_count"`
	StandbyDetected bool `json:"standby_detected"`
	LastActivity time.Time `json:"last_activity"`
}
// NetworkEvent records one network transition observed during a session.
// It is the element type of NetworkResilientSession.NetworkHistory and of the
// package-level networkEvents channel.
type NetworkEvent struct {
	Timestamp time.Time `json:"timestamp"`
	// FromNetwork/ToNetwork presumably hold labels such as those returned by
	// detectNetworkContext (e.g. "wifi_private") — confirm at the call sites.
	FromNetwork string `json:"from_network"`
	ToNetwork string `json:"to_network"`
	ClientIP string `json:"client_ip"`
	UserAgent string `json:"user_agent"`
	EventType string `json:"event_type"` // "switch", "resume", "refresh"
}
// UploadContext holds the state of an in-progress upload so that the upload
// can continue across network changes. It is embedded in a session as
// NetworkResilientSession.UploadContext.
//
// The control channels are tagged `json:"-"`, so they are not part of the
// JSON stored in Redis; a session decoded from Redis has nil channels.
type UploadContext struct {
	Filename string `json:"filename"`
	TotalSize int64 `json:"total_size"`
	UploadedBytes int64 `json:"uploaded_bytes"`
	ChunkSize int64 `json:"chunk_size"`
	LastChunk int `json:"last_chunk"`
	ETag string `json:"etag,omitempty"`
	UploadPath string `json:"upload_path"`
	ContentType string `json:"content_type"`
	LastUpdate time.Time `json:"last_update"`
	SessionID string `json:"session_id"`
	// NOTE(review): the pause/resume/cancel semantics of these channels are
	// implemented outside this view — presumably signalled by the upload
	// handler; verify before relying on them.
	PauseChan chan bool `json:"-"`
	ResumeChan chan bool `json:"-"`
	CancelChan chan bool `json:"-"`
	IsPaused bool `json:"is_paused"`
}
// SessionStore manages persistent sessions for network resilience.
// Writes go to every configured backend (Redis, memory cache, local map);
// reads try them in that order. When enabled is false every method returns
// immediately without touching any backend.
type SessionStore struct {
	// storage is the local-map fallback backend, guarded by mutex.
	storage map[string]*NetworkResilientSession
	mutex sync.RWMutex
	// cleanupTicker drives cleanupRoutine, which purges stale entries from storage.
	cleanupTicker *time.Ticker
	// redisClient is optional; it is nil when no Redis URL is configured or the
	// initial ping failed.
	redisClient *redis.Client
	// memoryCache is a go-cache instance created with a 72h default expiry.
	memoryCache *cache.Cache
	enabled bool
}
// Global session store, assigned by initializeSessionStore.
var sessionStore *SessionStore
// GetSession returns the session stored under sessionID, or nil when the
// store is nil or disabled, sessionID is empty, or no live session is found.
//
// Backends are consulted in order: Redis (key "session:<id>"), the go-cache
// memory cache, then the local map, where entries whose LastSeen is older than
// 72 hours are ignored. A Redis hit yields a freshly decoded copy; the other
// two backends return the shared pointer that was stored.
//
// Only the local map is read under s.mutex. The Redis client and go-cache are
// safe for concurrent use, so querying them without the lock keeps a slow
// Redis round-trip from stalling concurrent StoreSession/DeleteSession calls.
// The Redis lookup is bounded by a 5 second timeout. Redis failures other than
// "key not found", and undecodable Redis entries, are logged instead of being
// silently ignored.
func (s *SessionStore) GetSession(sessionID string) *NetworkResilientSession {
	if s == nil || !s.enabled || sessionID == "" {
		return nil
	}

	// Try Redis first if available.
	if s.redisClient != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		sessionData, err := s.redisClient.Get(ctx, "session:"+sessionID).Result()
		cancel()
		switch {
		case err == nil:
			var session NetworkResilientSession
			jsonErr := json.Unmarshal([]byte(sessionData), &session)
			if jsonErr == nil {
				log.Debugf("📊 Session retrieved from Redis: %s", sessionID)
				return &session
			}
			log.Warnf("📊 Session store: ignoring undecodable Redis entry for %s: %v", sessionID, jsonErr)
		case errors.Is(err, redis.Nil):
			// Key not present in Redis; fall back to the in-process backends.
		default:
			log.Warnf("📊 Session store: Redis lookup for %s failed: %v", sessionID, err)
		}
	}

	// Fallback to memory cache.
	if s.memoryCache != nil {
		if cached, found := s.memoryCache.Get(sessionID); found {
			if session, ok := cached.(*NetworkResilientSession); ok {
				log.Debugf("📊 Session retrieved from memory: %s", sessionID)
				return session
			}
		}
	}

	// Fallback to the in-memory map, which is the only backend guarded by s.mutex.
	s.mutex.RLock()
	defer s.mutex.RUnlock()
	if session, exists := s.storage[sessionID]; exists && time.Since(session.LastSeen) < 72*time.Hour {
		log.Debugf("📊 Session retrieved from storage: %s", sessionID)
		return session
	}
	return nil
}
// StoreSession stamps session.LastSeen with the current time and writes the
// session to every available backend: Redis (JSON under "session:<id>" with a
// 72h TTL), the memory cache (72h expiry) and the local map. It is a no-op
// when the store is nil or disabled, sessionID is empty, or session is nil.
//
// The write lock is held across all backends so that concurrent stores for
// the same ID reach every backend in the same order. Encoding and Redis write
// failures are logged rather than silently dropped, and "stored in Redis" is
// only logged when the write actually succeeded. The Redis write is bounded by
// a 5 second timeout.
func (s *SessionStore) StoreSession(sessionID string, session *NetworkResilientSession) {
	if s == nil || !s.enabled || sessionID == "" || session == nil {
		return
	}
	s.mutex.Lock()
	defer s.mutex.Unlock()
	session.LastSeen = time.Now()

	// Store in Redis if available.
	if s.redisClient != nil {
		sessionData, err := json.Marshal(session)
		if err != nil {
			log.Warnf("📊 Session store: cannot encode session %s for Redis: %v", sessionID, err)
		} else {
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			err = s.redisClient.Set(ctx, "session:"+sessionID, sessionData, 72*time.Hour).Err()
			cancel()
			if err != nil {
				log.Warnf("📊 Session store: Redis write for %s failed: %v", sessionID, err)
			} else {
				log.Debugf("📊 Session stored in Redis: %s", sessionID)
			}
		}
	}

	// Store in memory cache.
	if s.memoryCache != nil {
		s.memoryCache.Set(sessionID, session, 72*time.Hour)
		log.Debugf("📊 Session stored in memory: %s", sessionID)
	}

	// Store in local map as final fallback.
	s.storage[sessionID] = session
	log.Debugf("📊 Session stored in local storage: %s", sessionID)
}
// DeleteSession removes sessionID from every backend: Redis (key
// "session:<id>"), the memory cache and the local map. It is a no-op when the
// store is nil or disabled, or sessionID is empty.
//
// A failed Redis delete is logged rather than silently dropped. GetSession
// consults Redis first, so a surviving Redis entry would keep being returned
// until its TTL expires. The Redis call is bounded by a 5 second timeout.
func (s *SessionStore) DeleteSession(sessionID string) {
	if s == nil || !s.enabled || sessionID == "" {
		return
	}
	s.mutex.Lock()
	defer s.mutex.Unlock()

	// Remove from Redis.
	if s.redisClient != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		err := s.redisClient.Del(ctx, "session:"+sessionID).Err()
		cancel()
		if err != nil {
			log.Warnf("📊 Session store: Redis delete for %s failed: %v", sessionID, err)
		}
	}

	// Remove from memory cache.
	if s.memoryCache != nil {
		s.memoryCache.Delete(sessionID)
	}

	// Remove from local storage.
	delete(s.storage, sessionID)
	log.Debugf("📊 Session deleted: %s", sessionID)
}
// cleanupRoutine sweeps the local-map backend on every tick of cleanupTicker.
// It deletes sessions whose LastSeen is more than 72 hours old. It returns
// immediately for a disabled store; otherwise it loops for as long as the
// ticker delivers ticks, and nothing in it ever exits the loop.
//
// Only s.storage is swept here. Redis and the memory cache receive a 72h
// expiry when StoreSession writes to them.
func (s *SessionStore) cleanupRoutine() {
	if s.enabled {
		for {
			<-s.cleanupTicker.C
			s.mutex.Lock()
			for id, sess := range s.storage {
				if time.Since(sess.LastSeen) <= 72*time.Hour {
					continue // still fresh
				}
				delete(s.storage, id)
				log.Debugf("🧹 Cleaned up expired session: %s", id)
			}
			s.mutex.Unlock()
		}
	}
}
// initializeSessionStore builds the global sessionStore from configuration.
//
// When session_store.enabled is false, a disabled store is installed, and all
// of its methods return immediately. Otherwise the store gets:
//   - a local map;
//   - a go-cache memory cache (72h default expiry, hourly cleanup);
//   - a Redis client, when session_store.redis_url is set and parses, and the
//     server answers a ping within 5 seconds.
//
// A cleanupRoutine goroutine is then started for the enabled store.
//
// Other guarantees:
//   - The store is fully built before it is assigned to the sessionStore
//     global.
//   - A Redis client whose ping fails is closed instead of leaking its
//     connection pool.
//   - Only the Redis address is logged. The URL itself can embed a password,
//     so it is never logged.
func initializeSessionStore() {
	if !viper.GetBool("session_store.enabled") {
		log.Infof("📊 Session store disabled in configuration")
		sessionStore = &SessionStore{enabled: false}
		return
	}

	store := &SessionStore{
		storage:       make(map[string]*NetworkResilientSession),
		cleanupTicker: time.NewTicker(30 * time.Minute),
		memoryCache:   cache.New(72*time.Hour, 1*time.Hour),
		enabled:       true,
	}

	// Optional Redis backend.
	if redisURL := viper.GetString("session_store.redis_url"); redisURL != "" {
		opt, err := redis.ParseURL(redisURL)
		if err != nil {
			log.Warnf("📊 Session store: Invalid Redis URL, using memory backend: %v", err)
		} else {
			client := redis.NewClient(opt)
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			pingErr := client.Ping(ctx).Err()
			cancel()
			if pingErr == nil {
				store.redisClient = client
				log.Infof("📊 Session store: Redis backend initialized (%s)", opt.Addr)
			} else {
				log.Warnf("📊 Session store: Redis connection failed, using memory backend: %v", pingErr)
				// The client will never be used; release its connection pool.
				if closeErr := client.Close(); closeErr != nil {
					log.Debugf("📊 Session store: closing unused Redis client: %v", closeErr)
				}
			}
		}
	}
	if store.redisClient == nil {
		log.Infof("📊 Session store: Memory backend initialized")
	}

	sessionStore = store
	// Start cleanup routine.
	go store.cleanupRoutine()
}
// generateSessionID returns a new session identifier of the form
// "sess_<16 lowercase hex chars>" for the given user and file.
//
// The identifier is the truncated SHA-256 of the user JID, the filename, the
// current time in nanoseconds and 16 bytes from crypto/rand. Without the
// random component, anyone who knew the JID and filename could reproduce the
// ID by guessing the request time.
func generateSessionID(userJID, filename string) string {
	var nonce [16]byte
	// crypto/rand.Read only fails if the OS entropy source is unavailable.
	// The timestamp below still keeps IDs unique in that unlikely case.
	_, _ = crand.Read(nonce[:])

	h := sha256.New()
	fmt.Fprintf(h, "%s:%s:%d:", userJID, filename, time.Now().UnixNano())
	h.Write(nonce[:])
	return "sess_" + hex.EncodeToString(h.Sum(nil))[:16]
}
// generateUploadSessionID returns a new identifier of the form
// "upload_<16 lowercase hex chars>" for multi-upload scenarios.
//
// The identifier is the truncated SHA-256 of the upload type, user agent,
// client IP, the current time in nanoseconds and 16 bytes from crypto/rand.
// The random component keeps the ID from being reproducible by someone who
// knows the (largely public) request metadata and roughly when it was sent.
func generateUploadSessionID(uploadType, userAgent, clientIP string) string {
	var nonce [16]byte
	// crypto/rand.Read only fails if the OS entropy source is unavailable.
	// The timestamp below still keeps IDs unique in that unlikely case.
	_, _ = crand.Read(nonce[:])

	h := sha256.New()
	fmt.Fprintf(h, "%s:%s:%s:%d:", uploadType, userAgent, clientIP, time.Now().UnixNano())
	h.Write(nonce[:])
	return "upload_" + hex.EncodeToString(h.Sum(nil))[:16]
}
// detectNetworkContext returns a coarse, heuristic label for the client's
// network. Checks run in this order, and the first match wins:
//
//   - "cellular_lte":     client IP, or any X-Forwarded-For hop, is in 10.0.0.0/8
//   - "wifi_private":     client IP is in 192.168.0.0/16 or 172.16.0.0/12
//   - "mobile_network":   User-Agent contains "Mobile" or "Android"
//   - "localhost":        client IP is a loopback address
//   - "external_network": anything else
//
// Addresses are parsed rather than substring-matched. The earlier
// strings.Contains checks made three mistakes:
//   - they classified 192.168.10.5 and 110.1.2.3 as 10/8;
//   - they treated every 172.x.x.x address, mostly public ones, as private;
//   - they matched any address ending in "::1" as loopback.
//
// The client IP may be a bare address or "host:port"; values that do not
// parse as IPs match none of the IP-based rules.
func detectNetworkContext(r *http.Request) string {
	// parseIP accepts "1.2.3.4", "1.2.3.4:80", "::1" and "[::1]:80".
	parseIP := func(s string) net.IP {
		s = strings.TrimSpace(s)
		if host, _, err := net.SplitHostPort(s); err == nil {
			s = host
		}
		return net.ParseIP(strings.Trim(s, "[]"))
	}
	// inTenNet reports whether ip is an IPv4 address in 10.0.0.0/8.
	inTenNet := func(ip net.IP) bool {
		v4 := ip.To4()
		return v4 != nil && v4[0] == 10
	}

	clientIP := parseIP(getClientIP(r))
	userAgent := r.Header.Get("User-Agent")

	forwardedTenNet := false
	for _, hop := range strings.Split(r.Header.Get("X-Forwarded-For"), ",") {
		if inTenNet(parseIP(hop)) {
			forwardedTenNet = true
			break
		}
	}

	v4 := clientIP.To4() // nil for IPv6 or unparsable addresses
	switch {
	case forwardedTenNet || inTenNet(clientIP):
		return "cellular_lte"
	case v4 != nil && ((v4[0] == 192 && v4[1] == 168) || (v4[0] == 172 && v4[1]&0xF0 == 16)):
		return "wifi_private"
	case strings.Contains(userAgent, "Mobile") || strings.Contains(userAgent, "Android"):
		return "mobile_network"
	case clientIP != nil && clientIP.IsLoopback():
		return "localhost"
	default:
		return "external_network"
	}
}
// setSessionHeaders advertises the session to the client through three
// response headers:
//   - X-Session-ID carries the session ID.
//   - X-Session-Timeout is fixed at 259200 seconds (72 hours), the same
//     lifetime the session store uses.
//   - X-Network-Resilience is set to "enabled".
//
// Existing values for these headers are replaced.
func setSessionHeaders(w http.ResponseWriter, sessionID string) {
	hdr := w.Header()
	pairs := [...][2]string{
		{"X-Session-ID", sessionID},
		{"X-Session-Timeout", "259200"},
		{"X-Network-Resilience", "enabled"},
	}
	for _, kv := range pairs {
		hdr.Set(kv[0], kv[1])
	}
}
// getSessionIDFromRequest extracts the client's session ID. Sources are
// tried in this order:
//  1. the X-Session-ID header;
//  2. the session_id query parameter;
//  3. an "Authorization: Bearer <token>" header (used by some XMPP clients).
//
// For a bearer token, the ID is "auth_" followed by the first 16 hex chars of
// SHA-256(token), so the same token always maps to the same session.
// It returns "" when none of these are present.
//
// The auth scheme is matched case-insensitively, as RFC 7235 §2.1 requires,
// and surrounding whitespace around the token is ignored. An empty token
// yields "". Hashing it would give every such client the same session ID.
func getSessionIDFromRequest(r *http.Request) string {
	// Try header first.
	if sessionID := r.Header.Get("X-Session-ID"); sessionID != "" {
		return sessionID
	}
	// Try query parameter.
	if sessionID := r.URL.Query().Get("session_id"); sessionID != "" {
		return sessionID
	}
	// Try the Authorization header (for some XMPP clients).
	scheme, token, found := strings.Cut(strings.TrimSpace(r.Header.Get("Authorization")), " ")
	if found && strings.EqualFold(scheme, "Bearer") {
		if token = strings.TrimSpace(token); token != "" {
			sum := sha256.Sum256([]byte(token))
			return "auth_" + hex.EncodeToString(sum[:])[:16]
		}
	}
	return ""
}
// parseSize converts a human-readable size string such as "512KB", "10 MB"
// or "2gb" into a number of bytes.
//
// Rules:
//   - Units are case-insensitive: B, KB, MB, GB and TB.
//   - Multiples are binary (1 KB = 1024 bytes).
//   - Surrounding whitespace, and whitespace between the number and the unit,
//     is ignored.
//   - An error is returned for negative or fractional values, a missing or
//     unknown unit, and results that would overflow int64.
func parseSize(sizeStr string) (int64, error) {
	const maxInt64 = 1<<63 - 1

	sizeStr = strings.TrimSpace(sizeStr)
	if len(sizeStr) < 2 {
		return 0, fmt.Errorf("invalid size string: %s", sizeStr)
	}

	// The unit is the trailing run of ASCII letters; the rest is the number.
	isLetter := func(c byte) bool { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') }
	cut := len(sizeStr)
	for cut > 0 && isLetter(sizeStr[cut-1]) {
		cut--
	}
	unit := strings.ToUpper(sizeStr[cut:])
	valueStr := strings.TrimSpace(sizeStr[:cut])

	var multiplier int64
	switch unit {
	case "B":
		multiplier = 1
	case "KB":
		multiplier = 1 << 10
	case "MB":
		multiplier = 1 << 20
	case "GB":
		multiplier = 1 << 30
	case "TB":
		multiplier = 1 << 40
	case "":
		return 0, fmt.Errorf("missing size unit in %q", sizeStr)
	default:
		return 0, fmt.Errorf("unknown size unit: %s", unit)
	}

	value, err := strconv.ParseInt(valueStr, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("invalid size value: %v", err)
	}
	if value < 0 {
		return 0, fmt.Errorf("invalid size value: %s must not be negative", valueStr)
	}
	if value > maxInt64/multiplier {
		return 0, fmt.Errorf("size %s overflows int64", sizeStr)
	}
	return value * multiplier, nil
}
// parseTTL converts a human-readable TTL string into a time.Duration.
//
// The input is one or more <integer><unit> pairs, e.g. "30d", "12h" or
// "1d12h", and the pairs are summed. Units are case-insensitive:
//   - s: second
//   - m: minute
//   - h: hour
//   - d: 24h
//   - w: 7d
//   - y: 365d
//
// Surrounding whitespace is ignored. These inputs are errors:
//   - an empty string;
//   - a number without a unit, or a unit without a number;
//   - an unknown unit, or trailing garbage;
//   - a total that overflows time.Duration.
//
// Earlier versions silently ignored everything after the first unit, so
// "1h30m" parsed as 1h and "30dxyz" as 30d.
func parseTTL(ttlStr string) (time.Duration, error) {
	const maxDuration = time.Duration(1<<63 - 1)

	rest := strings.ToLower(strings.TrimSpace(ttlStr))
	if rest == "" {
		return 0, fmt.Errorf("TTL string cannot be empty")
	}

	var total time.Duration
	for rest != "" {
		digits := 0
		for digits < len(rest) && rest[digits] >= '0' && rest[digits] <= '9' {
			digits++
		}
		// ParseInt also reports the "no digits" case (empty string) as an error.
		val, err := strconv.ParseInt(rest[:digits], 10, 64)
		if err != nil {
			return 0, fmt.Errorf("invalid TTL value: %v", err)
		}
		if digits == len(rest) {
			return 0, fmt.Errorf("missing TTL unit after %q", rest)
		}

		var unit time.Duration
		switch rest[digits] {
		case 's':
			unit = time.Second
		case 'm':
			unit = time.Minute
		case 'h':
			unit = time.Hour
		case 'd':
			unit = 24 * time.Hour
		case 'w':
			unit = 7 * 24 * time.Hour
		case 'y':
			unit = 365 * 24 * time.Hour
		default:
			// Report the full rune so non-ASCII input is shown correctly.
			return 0, fmt.Errorf("unknown TTL unit: %c", []rune(rest[digits:])[0])
		}

		// val is non-negative (only digits were consumed) and total >= 0, so
		// neither comparison below can itself overflow.
		if val > int64(maxDuration/unit) || time.Duration(val)*unit > maxDuration-total {
			return 0, fmt.Errorf("TTL %q overflows time.Duration", ttlStr)
		}
		total += time.Duration(val) * unit
		rest = rest[digits+1:] // all valid units are single ASCII bytes
	}
	return total, nil
}
// Configuration structures

// ServerConfig holds the [server] section of the configuration (Config.Server).
// Tag names must match the keys used in config files; several were corrected
// to the spelling the example config uses (see "Fixed to match config").
type ServerConfig struct {
	ListenAddress string `toml:"listen_address" mapstructure:"listen_address"`
	StoragePath string `toml:"storage_path" mapstructure:"storage_path"`
	MetricsEnabled bool `toml:"metricsenabled" mapstructure:"metricsenabled"` // Fixed to match config
	MetricsPath string `toml:"metrics_path" mapstructure:"metrics_path"`
	// NOTE(review): main writes the PID file using PIDFilePath ("pidfilepath"),
	// not this field; PidFile looks like a legacy duplicate — verify.
	PidFile string `toml:"pid_file" mapstructure:"pid_file"`
	// Human-readable size string — presumably parsed with parseSize (e.g. "10GB").
	MaxUploadSize string `toml:"max_upload_size" mapstructure:"max_upload_size"`
	MaxHeaderBytes int `toml:"max_header_bytes" mapstructure:"max_header_bytes"`
	CleanupInterval string `toml:"cleanup_interval" mapstructure:"cleanup_interval"`
	MaxFileAge string `toml:"max_file_age" mapstructure:"max_file_age"`
	// NOTE(review): PreCache and PreCaching ("precaching", below) look like
	// duplicates; main logs PreCaching — confirm which one is honoured.
	PreCache bool `toml:"pre_cache" mapstructure:"pre_cache"`
	PreCacheWorkers int `toml:"pre_cache_workers" mapstructure:"pre_cache_workers"`
	PreCacheInterval string `toml:"pre_cache_interval" mapstructure:"pre_cache_interval"`
	GlobalExtensions []string `toml:"global_extensions" mapstructure:"global_extensions"`
	DeduplicationEnabled bool `toml:"deduplication_enabled" mapstructure:"deduplication_enabled"`
	MinFreeBytes string `toml:"min_free_bytes" mapstructure:"min_free_bytes"`
	FileNaming string `toml:"file_naming" mapstructure:"file_naming"`
	// ForceProtocol presumably feeds initializeNetworkProtocol ("ipv4", "ipv6", "auto").
	ForceProtocol string `toml:"force_protocol" mapstructure:"force_protocol"`
	EnableDynamicWorkers bool `toml:"enable_dynamic_workers" mapstructure:"enable_dynamic_workers"`
	WorkerScaleUpThresh int `toml:"worker_scale_up_thresh" mapstructure:"worker_scale_up_thresh"`
	WorkerScaleDownThresh int `toml:"worker_scale_down_thresh" mapstructure:"worker_scale_down_thresh"`
	UnixSocket bool `toml:"unixsocket" mapstructure:"unixsocket"` // Added missing field from example/logs
	MetricsPort string `toml:"metricsport" mapstructure:"metricsport"` // Fixed to match config
	FileTTL string `toml:"filettl" mapstructure:"filettl"` // Fixed to match config
	FileTTLEnabled bool `toml:"filettlenabled" mapstructure:"filettlenabled"` // Fixed to match config
	AutoAdjustWorkers bool `toml:"autoadjustworkers" mapstructure:"autoadjustworkers"` // Fixed to match config
	NetworkEvents bool `toml:"networkevents" mapstructure:"networkevents"` // Fixed to match config
	PIDFilePath string `toml:"pidfilepath" mapstructure:"pidfilepath"` // Fixed to match config
	CleanUponExit bool `toml:"clean_upon_exit" mapstructure:"clean_upon_exit"` // Added missing field
	PreCaching bool `toml:"precaching" mapstructure:"precaching"` // Fixed to match config
	BindIP string `toml:"bind_ip" mapstructure:"bind_ip"` // Added missing field
}
// UploadsConfig holds the [uploads] section of the configuration (Config.Uploads).
type UploadsConfig struct {
	AllowedExtensions []string `toml:"allowed_extensions" mapstructure:"allowed_extensions"`
	ChunkedUploadsEnabled bool `toml:"chunked_uploads_enabled" mapstructure:"chunked_uploads_enabled"`
	ChunkSize string `toml:"chunk_size" mapstructure:"chunk_size"`
	ResumableUploadsEnabled bool `toml:"resumable_uploads_enabled" mapstructure:"resumable_uploads_enabled"`
	SessionTimeout string `toml:"sessiontimeout" mapstructure:"sessiontimeout"`
	MaxRetries int `toml:"maxretries" mapstructure:"maxretries"`
}
// DownloadsConfig holds the [downloads] section of the configuration (Config.Downloads).
type DownloadsConfig struct {
	AllowedExtensions []string `toml:"allowed_extensions" mapstructure:"allowed_extensions"`
	ChunkedDownloadsEnabled bool `toml:"chunked_downloads_enabled" mapstructure:"chunked_downloads_enabled"`
	ChunkSize string `toml:"chunk_size" mapstructure:"chunk_size"`
	ResumableDownloadsEnabled bool `toml:"resumable_downloads_enabled" mapstructure:"resumable_downloads_enabled"`
}
// SecurityConfig holds the [security] section of the configuration
// (Config.Security): the shared HMAC secret, JWT settings, and the
// network-change / standby re-authentication policy.
type SecurityConfig struct {
	Secret string `toml:"secret" mapstructure:"secret"`
	EnableJWT bool `toml:"enablejwt" mapstructure:"enablejwt"` // Added EnableJWT field
	JWTSecret string `toml:"jwtsecret" mapstructure:"jwtsecret"`
	JWTAlgorithm string `toml:"jwtalgorithm" mapstructure:"jwtalgorithm"`
	JWTExpiration string `toml:"jwtexpiration" mapstructure:"jwtexpiration"`
	EnhancedSecurity bool `toml:"enhanced_security" mapstructure:"enhanced_security"`
	ChallengeOnNetworkChange bool `toml:"challenge_on_network_change" mapstructure:"challenge_on_network_change"`
	ReauthOnLongStandby bool `toml:"reauth_on_long_standby" mapstructure:"reauth_on_long_standby"`
	StandbyThresholdMinutes int `toml:"standby_threshold_minutes" mapstructure:"standby_threshold_minutes"`
	LongStandbyThresholdHours int `toml:"long_standby_threshold_hours" mapstructure:"long_standby_threshold_hours"`
}
// LoggingConfig holds the [logging] section. main applies Level with
// logrus.ParseLevel, falling back to "info" when it does not parse.
type LoggingConfig struct {
	Level string `mapstructure:"level"`
	File string `mapstructure:"file"`
	MaxSize int `mapstructure:"max_size"`
	MaxBackups int `mapstructure:"max_backups"`
	MaxAge int `mapstructure:"max_age"`
	Compress bool `mapstructure:"compress"`
}
// DeduplicationConfig holds the [deduplication] section.
type DeduplicationConfig struct {
	Enabled bool `mapstructure:"enabled"`
	Directory string `mapstructure:"directory"`
	MaxSize string `mapstructure:"maxsize"`
}
// ISOConfig holds the [iso] section. createAndMountISO uses ContainerFile as
// the path of the ISO image it creates and mounts.
type ISOConfig struct {
	Enabled bool `mapstructure:"enabled"`
	MountPoint string `mapstructure:"mountpoint"`
	Size string `mapstructure:"size"`
	Charset string `mapstructure:"charset"`
	ContainerFile string `mapstructure:"containerfile"` // Added missing field
}
// TimeoutConfig holds the [timeouts] section. Values are strings —
// presumably duration strings such as "30s"; confirm against the parser.
type TimeoutConfig struct {
	Read string `mapstructure:"readtimeout" toml:"readtimeout"`
	Write string `mapstructure:"writetimeout" toml:"writetimeout"`
	Idle string `mapstructure:"idletimeout" toml:"idletimeout"`
	Shutdown string `mapstructure:"shutdown" toml:"shutdown"`
}
// VersioningConfig holds the [versioning] section.
type VersioningConfig struct {
	Enabled bool `mapstructure:"enableversioning" toml:"enableversioning"` // Corrected to match example config
	Backend string `mapstructure:"backend" toml:"backend"`
	MaxRevs int `mapstructure:"maxversions" toml:"maxversions"` // Corrected to match example config
}
// ClamAVConfig holds the [clamav] section. processScan skips scanning
// entirely when ClamAVEnabled is false.
type ClamAVConfig struct {
	ClamAVEnabled bool `mapstructure:"clamavenabled"`
	ClamAVSocket string `mapstructure:"clamavsocket"`
	NumScanWorkers int `mapstructure:"numscanworkers"`
	ScanFileExtensions []string `mapstructure:"scanfileextensions"`
	MaxScanSize string `mapstructure:"maxscansize"`
}
// RedisConfig holds the [redis] section. Note this is separate from the
// session store's Redis backend, which is configured via
// session_store.redis_url in initializeSessionStore.
type RedisConfig struct {
	RedisEnabled bool `mapstructure:"redisenabled"`
	RedisDBIndex int `mapstructure:"redisdbindex"`
	RedisAddr string `mapstructure:"redisaddr"`
	RedisPassword string `mapstructure:"redispassword"`
	RedisHealthCheckInterval string `mapstructure:"redishealthcheckinterval"`
}
// WorkersConfig holds the [workers] section.
type WorkersConfig struct {
	NumWorkers int `mapstructure:"numworkers"`
	UploadQueueSize int `mapstructure:"uploadqueuesize"`
}
// FileConfig is the [file] section; it currently declares no fields.
type FileConfig struct {
}
// BuildConfig is the [build] section.
type BuildConfig struct {
	Version string `mapstructure:"version"` // Updated version
}
// NetworkResilienceConfig holds the [network_resilience] section. Interval and
// latency settings are strings — presumably duration strings; confirm where
// they are parsed.
type NetworkResilienceConfig struct {
	FastDetection bool `toml:"fast_detection" mapstructure:"fast_detection"`
	QualityMonitoring bool `toml:"quality_monitoring" mapstructure:"quality_monitoring"`
	PredictiveSwitching bool `toml:"predictive_switching" mapstructure:"predictive_switching"`
	MobileOptimizations bool `toml:"mobile_optimizations" mapstructure:"mobile_optimizations"`
	DetectionInterval string `toml:"detection_interval" mapstructure:"detection_interval"`
	QualityCheckInterval string `toml:"quality_check_interval" mapstructure:"quality_check_interval"`
	MaxDetectionInterval string `toml:"max_detection_interval" mapstructure:"max_detection_interval"`
	// Multi-interface support
	MultiInterfaceEnabled bool `toml:"multi_interface_enabled" mapstructure:"multi_interface_enabled"`
	InterfacePriority []string `toml:"interface_priority" mapstructure:"interface_priority"`
	AutoSwitchEnabled bool `toml:"auto_switch_enabled" mapstructure:"auto_switch_enabled"`
	SwitchThresholdLatency string `toml:"switch_threshold_latency" mapstructure:"switch_threshold_latency"`
	SwitchThresholdPacketLoss float64 `toml:"switch_threshold_packet_loss" mapstructure:"switch_threshold_packet_loss"`
	QualityDegradationThreshold float64 `toml:"quality_degradation_threshold" mapstructure:"quality_degradation_threshold"`
	MaxSwitchAttempts int `toml:"max_switch_attempts" mapstructure:"max_switch_attempts"`
	SwitchDetectionInterval string `toml:"switch_detection_interval" mapstructure:"switch_detection_interval"`
}
// ClientNetworkConfigTOML is used for loading from TOML where timeout is a string.
// It is decoded from the [client_network_support] section. main converts it
// into a ClientNetworkConfig:
//   - SessionMigrationTimeout is parsed with time.ParseDuration, defaulting
//     to 5m when empty or invalid.
//   - MaxIPChangesPerSession defaults to 10 when 0.
type ClientNetworkConfigTOML struct {
	SessionBasedTracking bool `toml:"session_based_tracking" mapstructure:"session_based_tracking"`
	AllowIPChanges bool `toml:"allow_ip_changes" mapstructure:"allow_ip_changes"`
	SessionMigrationTimeout string `toml:"session_migration_timeout" mapstructure:"session_migration_timeout"`
	MaxIPChangesPerSession int `toml:"max_ip_changes_per_session" mapstructure:"max_ip_changes_per_session"`
	// NOTE(review): main does not copy ClientConnectionDetection into
	// ClientNetworkConfig, so this setting appears to have no effect — verify.
	ClientConnectionDetection bool `toml:"client_connection_detection" mapstructure:"client_connection_detection"`
	AdaptToClientNetwork bool `toml:"adapt_to_client_network" mapstructure:"adapt_to_client_network"`
}
// This is the main Config struct to be used.
// Each field maps to a top-level configuration table via its mapstructure tag
// (e.g. [server], [logging], [client_network_support]). main stores the loaded
// configuration in the package-level conf variable.
type Config struct {
	Server ServerConfig `mapstructure:"server"`
	Logging LoggingConfig `mapstructure:"logging"`
	Deduplication DeduplicationConfig `mapstructure:"deduplication"` // Added
	ISO ISOConfig `mapstructure:"iso"` // Added
	Timeouts TimeoutConfig `mapstructure:"timeouts"` // Added
	Security SecurityConfig `mapstructure:"security"`
	Versioning VersioningConfig `mapstructure:"versioning"` // Added
	Uploads UploadsConfig `mapstructure:"uploads"`
	Downloads DownloadsConfig `mapstructure:"downloads"`
	ClamAV ClamAVConfig `mapstructure:"clamav"`
	Redis RedisConfig `mapstructure:"redis"`
	Workers WorkersConfig `mapstructure:"workers"`
	File FileConfig `mapstructure:"file"`
	Build BuildConfig `mapstructure:"build"`
	NetworkResilience NetworkResilienceConfig `mapstructure:"network_resilience"`
	ClientNetwork ClientNetworkConfigTOML `mapstructure:"client_network_support"`
}
// UploadTask describes one upload job: the absolute target path, the
// originating request, and a channel on which the outcome is reported
// (presumably nil error on success — confirm with the worker code).
type UploadTask struct {
	AbsFilename string
	Request *http.Request
	Result chan error
}
// ScanTask describes one virus-scan job for processScan; AbsFilename is the
// file to scan and Result presumably receives processScan's error.
type ScanTask struct {
	AbsFilename string
	Result chan error
}
// FileMetadata stores per-file metadata; currently only the creation date of
// the file.
type FileMetadata struct {
	CreationDate time.Time
}
// processScan scans task.AbsFilename with ClamAV.
//
// If ClamAV is disabled in the configuration (read under confMutex), the scan
// is skipped and nil is returned. Otherwise one slot of the shared semaphore
// is held for the duration of the scan, which bounds concurrent scans. A scan
// failure is logged with the file name and returned to the caller.
func processScan(task ScanTask) error {
	file := task.AbsFilename

	confMutex.RLock()
	scanningOn := conf.ClamAV.ClamAVEnabled
	confMutex.RUnlock()

	if !scanningOn {
		log.Infof("ClamAV disabled, skipping scan for file: %s", file)
		return nil
	}

	log.Infof("Started processing scan for file: %s", file)
	semaphore <- struct{}{}
	defer func() { <-semaphore }()

	if scanErr := scanFileWithClamAV(file); scanErr != nil {
		log.WithFields(logrus.Fields{"file": file, "error": scanErr}).Error("Failed to scan file")
		return scanErr
	}

	log.Infof("Finished processing scan for file: %s", file)
	return nil
}
// Package-level server state. The code that assigns most of these (metrics,
// caches, clients, worker pool) is not in this part of the file.
var (
	// conf is the active configuration; main assigns it from LoadSimplifiedConfig.
	conf Config
	versionString string
	// log is the shared logrus logger used throughout the server.
	log = logrus.New()
	fileInfoCache *cache.Cache
	fileMetadataCache *cache.Cache
	clamClient *clamd.Clamd
	// redisClient is the general-purpose Redis client; it is distinct from the
	// session store's own SessionStore.redisClient.
	redisClient *redis.Client
	redisConnected bool
	confMutex sync.RWMutex // Protects the global 'conf' variable and related critical sections.
	// Use RLock() for reading, Lock() for writing.

	// Prometheus metrics (registration is done elsewhere — not visible here).
	uploadDuration prometheus.Histogram
	uploadErrorsTotal prometheus.Counter
	uploadsTotal prometheus.Counter
	downloadDuration prometheus.Histogram
	downloadsTotal prometheus.Counter
	downloadErrorsTotal prometheus.Counter
	memoryUsage prometheus.Gauge
	cpuUsage prometheus.Gauge
	activeConnections prometheus.Gauge
	requestsTotal *prometheus.CounterVec
	goroutines prometheus.Gauge
	uploadSizeBytes prometheus.Histogram
	downloadSizeBytes prometheus.Histogram
	filesDeduplicatedTotal prometheus.Counter
	deduplicationErrorsTotal prometheus.Counter
	// ISO metrics are incremented by createAndMountISO.
	isoContainersCreatedTotal prometheus.Counter
	isoCreationErrorsTotal prometheus.Counter
	isoContainersMountedTotal prometheus.Counter
	isoMountErrorsTotal prometheus.Counter
	workerPool *WorkerPool
	// networkEvents carries NetworkEvent values; producers/consumers are outside this view.
	networkEvents chan NetworkEvent
	workerAdjustmentsTotal prometheus.Counter
	workerReAdjustmentsTotal prometheus.Counter
)
// bufferPool hands out *[]byte scratch buffers of 32 KiB each. A pointer is
// pooled instead of the slice itself, the usual idiom for avoiding an extra
// allocation when the buffer is returned with Put.
var bufferPool = sync.Pool{
	New: func() interface{} {
		buf := make([]byte, 32*1024)
		return &buf
	},
}
// maxConcurrentOperations is the capacity of semaphore.
const maxConcurrentOperations = 10
// semaphore is a counting semaphore: a send acquires a slot and a receive
// releases it (processScan holds one slot per scan), so at most
// maxConcurrentOperations holders can proceed at once.
var semaphore = make(chan struct{}, maxConcurrentOperations)
// Global client connection tracker for multi-interface support
// (assigned in main via NewClientConnectionTracker).
var clientTracker *ClientConnectionTracker
// logMessages buffers messages that flushLogMessages later writes to log at
// Info level; it is guarded by logMu.
var logMessages []string
var logMu sync.Mutex
// flushLogMessages writes every message buffered in logMessages to log at
// Info level, in order, and empties the buffer.
//
// The buffer is swapped out under logMu and the messages are emitted after
// the lock is released. Goroutines appending to logMessages therefore are not
// blocked on logger I/O. Any code reached from the logger that appends to
// logMessages also cannot deadlock on logMu.
func flushLogMessages() {
	logMu.Lock()
	pending := logMessages
	logMessages = nil
	logMu.Unlock()

	for _, msg := range pending {
		log.Info(msg)
	}
}
// writePIDFile records the current process ID at pidPath.
//
// The file contains the decimal PID with no trailing newline and is written
// with mode 0644 via os.WriteFile. Success is logged at Info level. On
// failure, the error is logged and returned unchanged.
func writePIDFile(pidPath string) error {
	self := os.Getpid()
	if werr := os.WriteFile(pidPath, []byte(strconv.Itoa(self)), 0644); werr != nil {
		log.Errorf("Failed to write PID file: %v", werr)
		return werr
	}
	log.Infof("PID %d written to %s", self, pidPath)
	return nil
}
// removePIDFile deletes the PID file at pidPath. It never returns an error:
// a failed removal is logged at Error level, and success at Info level.
func removePIDFile(pidPath string) {
	if rmErr := os.Remove(pidPath); rmErr != nil {
		log.Errorf("Failed to remove PID file: %v", rmErr)
		return
	}
	log.Infof("PID file %s removed successfully", pidPath)
}
// createAndMountISO creates the ISO container file named by
// conf.ISO.ContainerFile, formats it, and loop-mounts it at mountpoint.
//
// Steps:
//  1. `dd` writes one zero-filled block of `size` bytes (size is passed as dd's bs=).
//  2. `mkfs -t iso9660 -input-charset <charset>` formats the file.
//  3. The mount point directory is created if it does not exist.
//  4. `mount -o loop` attaches the file at mountpoint.
//
// Metrics:
//   - Failures in steps 1-2, and an empty container path, increment
//     isoCreationErrorsTotal.
//   - Failures in steps 3-4 increment isoMountErrorsTotal.
//   - On success, isoContainersCreatedTotal and isoContainersMountedTotal
//     are both incremented.
//
// Previously a formatting failure was not counted. Command errors now include
// the command's trimmed combined output, when there is any, so the cause
// appears in logs. The underlying *exec.ExitError stays reachable via
// errors.As.
func createAndMountISO(size, mountpoint, charset string) error {
	// NOTE(review): conf is read without confMutex, as before; this assumes
	// conf is not mutated concurrently with this call — confirm at call sites.
	isoPath := conf.ISO.ContainerFile
	if isoPath == "" {
		isoCreationErrorsTotal.Inc()
		return errors.New("failed to create ISO file: no ISO container file configured")
	}

	// run executes a command and, on failure, attaches its trimmed combined
	// output to the returned error.
	run := func(name string, args ...string) error {
		out, err := exec.Command(name, args...).CombinedOutput()
		if err == nil {
			return nil
		}
		if msg := strings.TrimSpace(string(out)); msg != "" {
			return fmt.Errorf("%w: %s", err, msg)
		}
		return err
	}

	// Create an empty ISO file.
	if err := run("dd", "if=/dev/zero", fmt.Sprintf("of=%s", isoPath), fmt.Sprintf("bs=%s", size), "count=1"); err != nil {
		isoCreationErrorsTotal.Inc()
		return fmt.Errorf("failed to create ISO file: %w", err)
	}

	// Format the ISO file with a filesystem.
	if err := run("mkfs", "-t", "iso9660", "-input-charset", charset, isoPath); err != nil {
		isoCreationErrorsTotal.Inc()
		return fmt.Errorf("failed to format ISO file: %w", err)
	}

	// Create the mount point directory if it doesn't exist.
	if err := os.MkdirAll(mountpoint, os.ModePerm); err != nil {
		isoMountErrorsTotal.Inc()
		return fmt.Errorf("failed to create mount point: %w", err)
	}

	// Mount the ISO file.
	if err := run("mount", "-o", "loop", isoPath, mountpoint); err != nil {
		isoMountErrorsTotal.Inc()
		return fmt.Errorf("failed to mount ISO file: %w", err)
	}

	isoContainersCreatedTotal.Inc()
	isoContainersMountedTotal.Inc()
	return nil
}
// initializeNetworkProtocol returns a net.Dialer with a 5 second connect
// timeout that honours forceProtocol:
//
//   - "ipv4": sockets of the IPv6 family (tcp6/udp6/ip6) are refused in Control
//   - "ipv6": sockets of the IPv4 family (tcp4/udp4/ip4) are refused in Control
//   - "auto" or "": no restriction, with RFC 6555 fast fallback enabled
//
// forceProtocol is matched case-insensitively after trimming whitespace. Any
// other value returns an error.
//
// The deprecated Dialer.DualStack field is no longer set. Since Go 1.12, fast
// fallback is controlled by FallbackDelay, so the old DualStack: false had no
// effect. The forced modes now set FallbackDelay to -1, which makes the dialer
// try candidate addresses one after another. Without it, the dialer would
// race a connection in the other address family that Control rejects anyway.
func initializeNetworkProtocol(forceProtocol string) (*net.Dialer, error) {
	mode := strings.ToLower(strings.TrimSpace(forceProtocol))
	// Handle empty/default value.
	if mode == "" {
		mode = "auto"
	}

	// Control receives the concrete network, e.g. "tcp4" or "udp6", even when
	// the caller dialed plain "tcp"/"udp".
	isV6 := func(network string) bool { return strings.HasSuffix(network, "6") }
	isV4 := func(network string) bool { return strings.HasSuffix(network, "4") }

	switch mode {
	case "ipv4":
		return &net.Dialer{
			Timeout:       5 * time.Second,
			FallbackDelay: -1,
			Control: func(network, address string, c syscall.RawConn) error {
				if isV6(network) {
					return fmt.Errorf("IPv6 is disabled by forceprotocol setting")
				}
				return nil
			},
		}, nil
	case "ipv6":
		return &net.Dialer{
			Timeout:       5 * time.Second,
			FallbackDelay: -1,
			Control: func(network, address string, c syscall.RawConn) error {
				if isV4(network) {
					return fmt.Errorf("IPv4 is disabled by forceprotocol setting")
				}
				return nil
			},
		}, nil
	case "auto":
		// Zero FallbackDelay means the default Happy Eyeballs delay (300ms).
		return &net.Dialer{
			Timeout: 5 * time.Second,
		}, nil
	default:
		return nil, fmt.Errorf("invalid forceprotocol value: %s", forceProtocol)
	}
}
// dualStackClient is a package-level HTTP client. It is not assigned in this
// part of the file — presumably main builds it from the dialer returned by
// initializeNetworkProtocol; verify before relying on it being non-nil.
var dualStackClient *http.Client
func main() {
var configFile string
flag.StringVar(&configFile, "config", "./config.toml", "Path to configuration file \"config.toml\".")
var genConfig bool
var genConfigAdvanced bool
var genConfigPath string
var validateOnly bool
var runConfigTests bool
var validateQuiet bool
var validateVerbose bool
var validateFixable bool
var validateSecurity bool
var validatePerformance bool
var validateConnectivity bool
var listValidationChecks bool
var showVersion bool
flag.BoolVar(&genConfig, "genconfig", false, "Print minimal configuration example and exit.")
flag.BoolVar(&genConfigAdvanced, "genconfig-advanced", false, "Print advanced configuration template and exit.")
flag.StringVar(&genConfigPath, "genconfig-path", "", "Write configuration to the given file and exit.")
flag.BoolVar(&validateOnly, "validate-config", false, "Validate configuration and exit without starting server.")
flag.BoolVar(&runConfigTests, "test-config", false, "Run configuration validation test scenarios and exit.")
flag.BoolVar(&validateQuiet, "validate-quiet", false, "Only show errors during validation (suppress warnings and info).")
flag.BoolVar(&validateVerbose, "validate-verbose", false, "Show detailed validation information including system checks.")
flag.BoolVar(&validateFixable, "check-fixable", false, "Only show validation issues that can be automatically fixed.")
flag.BoolVar(&validateSecurity, "check-security", false, "Run only security-related validation checks.")
flag.BoolVar(&validatePerformance, "check-performance", false, "Run only performance-related validation checks.")
flag.BoolVar(&validateConnectivity, "check-connectivity", false, "Run only network connectivity validation checks.")
flag.BoolVar(&listValidationChecks, "list-checks", false, "List all available validation checks and exit.")
flag.BoolVar(&showVersion, "version", false, "Show version information and exit.")
flag.Parse()
if showVersion {
fmt.Printf("HMAC File Server v3.3.0\n")
os.Exit(0)
}
if listValidationChecks {
printValidationChecks()
os.Exit(0)
}
if genConfig {
fmt.Println("# Option 1: Minimal Configuration (recommended for most users)")
fmt.Println(GenerateMinimalConfig())
fmt.Println("\n# Option 2: Advanced Configuration Template (for fine-tuning)")
fmt.Println("# Use -genconfig-advanced to generate the advanced template")
os.Exit(0)
}
if genConfigAdvanced {
fmt.Println(GenerateAdvancedConfigTemplate())
os.Exit(0)
}
if genConfigPath != "" {
var content string
if genConfigAdvanced {
content = GenerateAdvancedConfigTemplate()
} else {
content = GenerateMinimalConfig()
}
f, err := os.Create(genConfigPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to create file: %v\n", err)
os.Exit(1)
}
defer f.Close()
w := bufio.NewWriter(f)
fmt.Fprint(w, content)
w.Flush()
fmt.Printf("Configuration written to %s\n", genConfigPath)
os.Exit(0)
}
if runConfigTests {
RunConfigTests()
os.Exit(0)
}
// Load configuration using simplified approach
loadedConfig, err := LoadSimplifiedConfig(configFile)
if err != nil {
// If no config file exists, offer to create a minimal one
if configFile == "./config.toml" || configFile == "" {
fmt.Println("No configuration file found. Creating a minimal config.toml...")
if err := createMinimalConfig(); err != nil {
log.Fatalf("Failed to create minimal config: %v", err)
}
fmt.Println("Minimal config.toml created. Please review and modify as needed, then restart the server.")
os.Exit(0)
}
log.Fatalf("Failed to load configuration: %v", err)
}
conf = *loadedConfig
configFileGlobal = configFile // Store for validation helper functions
log.Info("Configuration loaded successfully.")
err = validateConfig(&conf)
if err != nil {
log.Fatalf("Configuration validation failed: %v", err)
}
log.Info("Configuration validated successfully.")
// Perform comprehensive configuration validation
validationResult := ValidateConfigComprehensive(&conf)
// Initialize client connection tracker for multi-interface support
clientNetworkConfig := &ClientNetworkConfig{
SessionBasedTracking: conf.ClientNetwork.SessionBasedTracking,
AllowIPChanges: conf.ClientNetwork.AllowIPChanges,
MaxIPChangesPerSession: conf.ClientNetwork.MaxIPChangesPerSession,
AdaptToClientNetwork: conf.ClientNetwork.AdaptToClientNetwork,
}
// Parse session migration timeout
if conf.ClientNetwork.SessionMigrationTimeout != "" {
if timeout, err := time.ParseDuration(conf.ClientNetwork.SessionMigrationTimeout); err == nil {
clientNetworkConfig.SessionMigrationTimeout = timeout
} else {
clientNetworkConfig.SessionMigrationTimeout = 5 * time.Minute // default
}
} else {
clientNetworkConfig.SessionMigrationTimeout = 5 * time.Minute // default
}
// Set defaults if not configured
if clientNetworkConfig.MaxIPChangesPerSession == 0 {
clientNetworkConfig.MaxIPChangesPerSession = 10
}
// Initialize the client tracker
clientTracker = NewClientConnectionTracker(clientNetworkConfig)
if clientTracker != nil {
clientTracker.StartCleanupRoutine()
log.Info("Client multi-interface support initialized")
}
// Initialize session store for network resilience
initializeSessionStore()
log.Info("Session store for network switching initialized")
PrintValidationResults(validationResult)
if validationResult.HasErrors() {
log.Fatal("Cannot start server due to configuration errors. Please fix the above issues and try again.")
}
// Handle specialized validation flags
if validateSecurity || validatePerformance || validateConnectivity || validateQuiet || validateVerbose || validateFixable {
runSpecializedValidation(&conf, validateSecurity, validatePerformance, validateConnectivity, validateQuiet, validateVerbose, validateFixable)
os.Exit(0)
}
// If only validation was requested, exit now
if validateOnly {
if validationResult.HasErrors() {
log.Error("Configuration validation failed with errors. Review the errors above.")
os.Exit(1)
} else if validationResult.HasWarnings() {
log.Info("Configuration is valid but has warnings. Review the warnings above.")
os.Exit(0)
} else {
log.Info("Configuration validation completed successfully!")
os.Exit(0)
}
}
// Set log level based on configuration
level, err := logrus.ParseLevel(conf.Logging.Level)
if err != nil {
log.Warnf("Invalid log level '%s', defaulting to 'info'", conf.Logging.Level)
level = logrus.InfoLevel
}
log.SetLevel(level)
log.Infof("Log level set to: %s", level.String())
// Log configuration settings using [logging] section
log.Infof("Server ListenAddress: %s", conf.Server.ListenAddress) // Corrected field name
log.Infof("Server UnixSocket: %v", conf.Server.UnixSocket)
log.Infof("Server StoragePath: %s", conf.Server.StoragePath)
log.Infof("Logging Level: %s", conf.Logging.Level)
log.Infof("Logging File: %s", conf.Logging.File)
log.Infof("Server MetricsEnabled: %v", conf.Server.MetricsEnabled)
log.Infof("Server MetricsPort: %s", conf.Server.MetricsPort) // Corrected field name
log.Infof("Server FileTTL: %s", conf.Server.FileTTL) // Corrected field name
log.Infof("Server MinFreeBytes: %s", conf.Server.MinFreeBytes)
log.Infof("Server AutoAdjustWorkers: %v", conf.Server.AutoAdjustWorkers) // Corrected field name
log.Infof("Server NetworkEvents: %v", conf.Server.NetworkEvents) // Corrected field name
log.Infof("Server PIDFilePath: %s", conf.Server.PIDFilePath) // Corrected field name
log.Infof("Server CleanUponExit: %v", conf.Server.CleanUponExit) // Corrected field name
log.Infof("Server PreCaching: %v", conf.Server.PreCaching) // Corrected field name
log.Infof("Server FileTTLEnabled: %v", conf.Server.FileTTLEnabled) // Corrected field name
log.Infof("Server DeduplicationEnabled: %v", conf.Server.DeduplicationEnabled)
log.Infof("Server BindIP: %s", conf.Server.BindIP) // Corrected field name
log.Infof("Server FileNaming: %s", conf.Server.FileNaming)
log.Infof("Server ForceProtocol: %s", conf.Server.ForceProtocol)
err = writePIDFile(conf.Server.PIDFilePath) // Corrected field name
if err != nil {
log.Fatalf("Error writing PID file: %v", err)
}
log.Debug("DEBUG: PID file written successfully")
log.Debugf("DEBUG: Config logging file: %s", conf.Logging.File)
setupLogging()
log.Debug("DEBUG: Logging setup completed")
logSystemInfo()
log.Debug("DEBUG: System info logged")
// Initialize metrics before using any Prometheus counters
initMetrics()
log.Debug("DEBUG: Metrics initialized")
initializeWorkerSettings(&conf.Server, &conf.Workers, &conf.ClamAV)
log.Debug("DEBUG: Worker settings initialized")
if conf.ISO.Enabled {
err := createAndMountISO(conf.ISO.Size, conf.ISO.MountPoint, conf.ISO.Charset)
if err != nil {
log.Fatalf("Failed to create and mount ISO container: %v", err)
}
log.Infof("ISO container mounted at %s", conf.ISO.MountPoint)
}
// Set storage path to ISO mount point if ISO is enabled
storagePath := conf.Server.StoragePath
if conf.ISO.Enabled {
storagePath = conf.ISO.MountPoint
}
fileInfoCache = cache.New(5*time.Minute, 10*time.Minute)
fileMetadataCache = cache.New(5*time.Minute, 10*time.Minute)
if conf.Server.PreCaching { // Corrected field name
go func() {
log.Info("Starting pre-caching of storage path...")
// Use helper function
err := precacheStoragePath(storagePath)
if err != nil {
log.Warnf("Pre-caching storage path failed: %v", err)
} else {
log.Info("Pre-cached all files in the storage path.")
log.Info("Pre-caching status: complete.")
}
}()
}
err = os.MkdirAll(storagePath, os.ModePerm)
if err != nil {
log.Fatalf("Error creating store directory: %v", err)
}
log.WithField("directory", storagePath).Info("Store directory is ready")
// Use helper function
err = checkFreeSpaceWithRetry(storagePath, 3, 5*time.Second)
if err != nil {
log.Fatalf("Insufficient free space: %v", err)
}
initializeWorkerSettings(&conf.Server, &conf.Workers, &conf.ClamAV)
log.Info("Prometheus metrics initialized.")
networkEvents = make(chan NetworkEvent, 100)
log.Info("Network event channel initialized.")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Legacy network monitoring disabled - now handled by NetworkResilienceManager
// if conf.Server.NetworkEvents { // Corrected field name
// go monitorNetwork(ctx) // OLD: Basic network monitoring (replaced by NetworkResilienceManager)
// go handleNetworkEvents(ctx) // OLD: Basic event logging (replaced by NetworkResilienceManager)
// }
go updateSystemMetrics(ctx)
if conf.ClamAV.ClamAVEnabled {
var clamErr error
clamClient, clamErr = initClamAV(conf.ClamAV.ClamAVSocket) // Assuming initClamAV is defined in helpers.go or elsewhere
if clamErr != nil {
log.WithError(clamErr).Warn("ClamAV client initialization failed. Continuing without ClamAV.")
} else {
log.Info("ClamAV client initialized successfully.")
}
}
if conf.Redis.RedisEnabled {
initRedis() // Assuming initRedis is defined in helpers.go or elsewhere
}
router := setupRouter() // Assuming setupRouter is defined (likely in this file or router.go
// Initialize enhancements and enhance the router
InitializeEnhancements(router)
go handleFileCleanup(&conf) // Directly call handleFileCleanup
readTimeout, err := time.ParseDuration(conf.Timeouts.Read) // Corrected field name
if err != nil {
log.Fatalf("Invalid ReadTimeout: %v", err)
}
writeTimeout, err := time.ParseDuration(conf.Timeouts.Write) // Corrected field name
if err != nil {
log.Fatalf("Invalid WriteTimeout: %v", err)
}
idleTimeout, err := time.ParseDuration(conf.Timeouts.Idle) // Corrected field name
if err != nil {
log.Fatalf("Invalid IdleTimeout: %v", err)
}
// Initialize network protocol based on forceprotocol setting
dialer, err := initializeNetworkProtocol(conf.Server.ForceProtocol)
if err != nil {
log.Fatalf("Failed to initialize network protocol: %v", err)
}
// Enhanced dual-stack HTTP client for robust IPv4/IPv6 and resource management
// See: https://pkg.go.dev/net/http#Transport for details on these settings
dualStackClient = &http.Client{
Transport: &http.Transport{
DialContext: dialer.DialContext,
IdleConnTimeout: 90 * time.Second, // Close idle connections after 90s
MaxIdleConns: 100, // Max idle connections across all hosts
MaxIdleConnsPerHost: 10, // Max idle connections per host
TLSHandshakeTimeout: 10 * time.Second, // Timeout for TLS handshake
ResponseHeaderTimeout: 15 * time.Second, // Timeout for reading response headers
},
}
server := &http.Server{
Addr: conf.Server.BindIP + ":" + conf.Server.ListenAddress, // Use BindIP + ListenAddress (port)
Handler: router,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
IdleTimeout: idleTimeout,
MaxHeaderBytes: 1 << 20, // 1 MB
}
if conf.Server.MetricsEnabled {
var wg sync.WaitGroup
go func() {
http.Handle("/metrics", promhttp.Handler())
log.Infof("Metrics server started on port %s", conf.Server.MetricsPort) // Corrected field name
if err := http.ListenAndServe(":"+conf.Server.MetricsPort, nil); err != nil { // Corrected field name
log.Fatalf("Metrics server failed: %v", err)
}
wg.Wait()
}()
}
setupGracefulShutdown(server, cancel) // Assuming setupGracefulShutdown is defined
if conf.Server.AutoAdjustWorkers { // Corrected field name
go monitorWorkerPerformance(ctx, &conf.Server, &conf.Workers, &conf.ClamAV)
}
versionString = "3.3.0" // Set a default version for now
if conf.Build.Version != "" {
versionString = conf.Build.Version
}
log.Infof("Running version: %s", versionString)
log.Infof("Starting HMAC file server %s...", versionString)
if conf.Server.UnixSocket {
socketPath := "/tmp/hmac-file-server.sock" // Use a default socket path since ListenAddress is now a port
if err := os.RemoveAll(socketPath); err != nil {
log.Fatalf("Failed to remove existing Unix socket: %v", err)
}
listener, err := net.Listen("unix", socketPath)
if err != nil {
log.Fatalf("Failed to listen on Unix socket %s: %v", socketPath, err)
}
defer listener.Close()
log.Infof("Server listening on Unix socket: %s", socketPath)
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server failed: %v", err)
}
} else {
if conf.Server.BindIP == "0.0.0.0" {
log.Info("Binding to 0.0.0.0. Any net/http logs you see are normal for this universal address.")
}
log.Infof("Server listening on %s", server.Addr)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server failed: %v", err)
}
}
// Start file cleanup in a separate goroutine
// Use helper function
go handleFileCleanup(&conf)
}
// printExampleConfig writes a sample configuration file (TOML) to standard
// output. The text is emitted verbatim, beginning with an empty line, so it
// can be redirected straight into a file and edited.
func printExampleConfig() {
	const exampleConfig = `
[server]
bind_ip = "0.0.0.0"
listenport = "8080"
unixsocket = false
storagepath = "./uploads"
logfile = "/var/log/hmac-file-server.log"
metricsenabled = true
metricsport = "9090"
minfreebytes = "100MB"
filettl = "8760h"
filettlenabled = true
autoadjustworkers = true
networkevents = true
pidfilepath = "/var/run/hmacfileserver.pid"
cleanuponexit = true
precaching = true
deduplicationenabled = true
globalextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
# FileNaming options: "HMAC", "None"
filenaming = "HMAC"
forceprotocol = "auto"
[logging]
level = "info"
file = "/var/log/hmac-file-server.log"
max_size = 100
max_backups = 7
max_age = 30
compress = true
[deduplication]
enabled = true
directory = "./deduplication"
[iso]
enabled = true
size = "1GB"
mountpoint = "/mnt/iso"
charset = "utf-8"
containerfile = "/mnt/iso/container.iso"
[timeouts]
readtimeout = "4800s"
writetimeout = "4800s"
idletimeout = "4800s"
[security]
secret = "changeme"
enablejwt = false
jwtsecret = "anothersecretkey"
jwtalgorithm = "HS256"
jwtexpiration = "24h"
[versioning]
enableversioning = false
maxversions = 1
[uploads]
resumableuploadsenabled = true
chunkeduploadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[downloads]
resumabledownloadsenabled = true
chunkeddownloadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[clamav]
clamavenabled = true
clamavsocket = "/var/run/clamav/clamd.ctl"
numscanworkers = 2
scanfileextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[redis]
redisenabled = true
redisdbindex = 0
redisaddr = "localhost:6379"
redispassword = ""
redishealthcheckinterval = "120s"
[workers]
numworkers = 4
uploadqueuesize = 50
[file]
# Add file-specific configurations here
[build]
version = "3.3.0"
`
	fmt.Print(exampleConfig)
}
// getExampleConfigString returns a sample configuration document in TOML
// format.
//
// Each TOML table appears exactly once. Redefining a table is invalid TOML,
// and earlier revisions of this sample repeated [security], [uploads] and
// [downloads]. For [uploads] and [downloads], the keys from both copies
// (snake_case and the older concatenated spellings) are merged into one table
// because they do not collide. For [security], whose two copies had identical
// keys, the first copy's values are kept.
func getExampleConfigString() string {
	return `[server]
listen_address = ":8080"
storage_path = "/srv/hmac-file-server/uploads"
metrics_enabled = true
metrics_path = "/metrics"
pid_file = "/var/run/hmac-file-server.pid"
max_upload_size = "10GB" # Supports B, KB, MB, GB, TB
max_header_bytes = 1048576 # 1MB
cleanup_interval = "24h"
max_file_age = "720h" # 30 days
pre_cache = true
pre_cache_workers = 4
pre_cache_interval = "1h"
global_extensions = [".txt", ".dat", ".iso"] # If set, overrides upload/download extensions
deduplication_enabled = true
min_free_bytes = "1GB" # Minimum free space required for uploads
file_naming = "original" # Options: "original", "HMAC"
force_protocol = "" # Options: "http", "https" - if set, redirects to this protocol
enable_dynamic_workers = true # Enable dynamic worker scaling
worker_scale_up_thresh = 50 # Queue length to scale up workers
worker_scale_down_thresh = 10 # Queue length to scale down workers
[uploads]
allowed_extensions = [".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"]
chunked_uploads_enabled = true
chunk_size = "10MB"
resumable_uploads_enabled = true
max_resumable_age = "48h"
resumableuploadsenabled = true
chunkeduploadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[downloads]
allowed_extensions = [".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"]
chunked_downloads_enabled = true
chunk_size = "10MB"
resumable_downloads_enabled = true
resumabledownloadsenabled = true
chunkeddownloadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[security]
secret = "your-very-secret-hmac-key"
enablejwt = false
jwtsecret = "anothersecretkey"
jwtalgorithm = "HS256"
jwtexpiration = "24h"
[logging]
level = "info"
file = "/var/log/hmac-file-server.log"
max_size = 100
max_backups = 7
max_age = 30
compress = true
[deduplication]
enabled = true
directory = "./deduplication"
[iso]
enabled = true
size = "1GB"
mountpoint = "/mnt/iso"
charset = "utf-8"
containerfile = "/mnt/iso/container.iso"
[timeouts]
readtimeout = "4800s"
writetimeout = "4800s"
idletimeout = "4800s"
[versioning]
enableversioning = false
maxversions = 1
[clamav]
clamavenabled = true
clamavsocket = "/var/run/clamav/clamd.ctl"
numscanworkers = 2
scanfileextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp"]
[redis]
redisenabled = true
redisdbindex = 0
redisaddr = "localhost:6379"
redispassword = ""
redishealthcheckinterval = "120s"
[workers]
numworkers = 4
uploadqueuesize = 50
[file]
# Add file-specific configurations here
[build]
version = "3.3.0"
`
}
// max returns the larger of a and b (a when they are equal).
//
// On Go 1.21+ this package-level function shadows the built-in max of the
// same name inside this package.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
func autoAdjustWorkers() (int, int) {
v, _ := mem.VirtualMemory()
cpuCores, _ := cpu.Counts(true)
numWorkers := cpuCores * 2
if v.Available < 4*1024*1024*1024 { // Less than 4GB available
numWorkers = max(numWorkers/2, 1)
} else if v.Available < 8*1024*1024*1024 { // Less than 8GB available
numWorkers = max(numWorkers*3/4, 1)
}
queueSize := numWorkers * 10
log.Infof("Auto-adjusting workers: NumWorkers=%d, UploadQueueSize=%d", numWorkers, queueSize)
workerAdjustmentsTotal.Inc()
return numWorkers, queueSize
}
// initializeWorkerSettings finalizes worker-pool sizing.
//
// When server.AutoAdjustWorkers is false, the configured values are left as
// they are and only logged. Otherwise workers.NumWorkers and
// workers.UploadQueueSize are overwritten with the result of
// autoAdjustWorkers, and clamav.NumScanWorkers is set to half the upload
// workers (minimum 1).
func initializeWorkerSettings(server *ServerConfig, workers *WorkersConfig, clamav *ClamAVConfig) {
	if !server.AutoAdjustWorkers {
		log.Infof("Manual configuration in effect: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
			workers.NumWorkers, workers.UploadQueueSize, clamav.NumScanWorkers)
		return
	}

	count, queue := autoAdjustWorkers()
	workers.NumWorkers, workers.UploadQueueSize = count, queue
	clamav.NumScanWorkers = max(count/2, 1)
	log.Infof("AutoAdjustWorkers enabled: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
		workers.NumWorkers, workers.UploadQueueSize, clamav.NumScanWorkers)
}
// monitorWorkerPerformance recomputes worker sizing every five minutes until
// ctx is cancelled.
//
// On each tick where server.AutoAdjustWorkers is true, it overwrites
// w.NumWorkers, w.UploadQueueSize and clamav.NumScanWorkers (half the
// workers, minimum 1) and increments workerReAdjustmentsTotal. It only
// updates these config fields; whether any running pool picks up the new
// values depends on code outside this function.
func monitorWorkerPerformance(ctx context.Context, server *ServerConfig, w *WorkersConfig, clamav *ClamAVConfig) {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			log.Info("Stopping worker performance monitor.")
			return
		case <-ticker.C:
		}

		if !server.AutoAdjustWorkers {
			continue
		}

		count, queue := autoAdjustWorkers()
		w.NumWorkers = count
		w.UploadQueueSize = queue
		clamav.NumScanWorkers = max(count/2, 1)
		log.Infof("Re-adjusted workers: NumWorkers=%d, UploadQueueSize=%d, NumScanWorkers=%d",
			w.NumWorkers, w.UploadQueueSize, clamav.NumScanWorkers)
		workerReAdjustmentsTotal.Inc()
	}
}
// readConfig points the global viper instance at configFilename, reads it,
// and decodes the result into conf.
//
// A read failure is logged (with the file name) and returned unchanged. A
// decode failure is wrapped with %w so callers can still inspect the
// underlying cause via errors.Is / errors.As; the message text is the same as
// before.
func readConfig(configFilename string, conf *Config) error {
	viper.SetConfigFile(configFilename)
	if err := viper.ReadInConfig(); err != nil {
		log.WithError(err).Errorf("Unable to read config from %s", configFilename)
		return err
	}
	if err := viper.Unmarshal(conf); err != nil {
		return fmt.Errorf("unable to decode config into struct: %w", err)
	}
	return nil
}
// setDefaults registers fallback values with the global viper instance for
// server, uploads, downloads, security, logging, deduplication, ISO,
// timeout, versioning, session-store and build settings.
//
// Defaults for ClamAV, Redis, Workers and File are not registered here.
func setDefaults() {
	defaults := []struct {
		key   string
		value interface{}
	}{
		// Server
		{"server.listen_address", ":8080"},
		{"server.storage_path", "./uploads"},
		{"server.metrics_enabled", true},
		{"server.metrics_path", "/metrics"},
		{"server.pid_file", "/var/run/hmac-file-server.pid"},
		{"server.max_upload_size", "10GB"},
		{"server.max_header_bytes", 1048576}, // 1MB
		{"server.cleanup_interval", "24h"},
		{"server.max_file_age", "720h"}, // 30 days
		{"server.pre_cache", true},
		{"server.pre_cache_workers", 4},
		{"server.pre_cache_interval", "1h"},
		{"server.global_extensions", []string{}},
		{"server.deduplication_enabled", true},
		{"server.min_free_bytes", "1GB"},
		{"server.file_naming", "original"},
		{"server.force_protocol", "auto"},
		{"server.enable_dynamic_workers", true},
		{"server.worker_scale_up_thresh", 50},
		{"server.worker_scale_down_thresh", 10},

		// Uploads
		{"uploads.allowed_extensions", []string{".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"}},
		{"uploads.chunked_uploads_enabled", true},
		{"uploads.chunk_size", "10MB"},
		{"uploads.resumable_uploads_enabled", true},
		{"uploads.max_resumable_age", "48h"},

		// Downloads
		{"downloads.allowed_extensions", []string{".zip", ".rar", ".7z", ".tar.gz", ".tgz", ".gpg", ".enc", ".pgp"}},
		{"downloads.chunked_downloads_enabled", true},
		{"downloads.chunk_size", "10MB"},
		{"downloads.resumable_downloads_enabled", true},

		// Security
		{"security.secret", "your-very-secret-hmac-key"},
		{"security.enablejwt", false},
		{"security.jwtsecret", "your-256-bit-secret"},
		{"security.jwtalgorithm", "HS256"},
		{"security.jwtexpiration", "24h"},

		// Logging
		{"logging.level", "info"},
		{"logging.file", "/var/log/hmac-file-server.log"},
		{"logging.max_size", 100},
		{"logging.max_backups", 7},
		{"logging.max_age", 30},
		{"logging.compress", true},

		// Deduplication
		{"deduplication.enabled", false},
		{"deduplication.directory", "./dedup_store"},

		// ISO
		{"iso.enabled", false},
		{"iso.mount_point", "/mnt/hmac_iso"},
		{"iso.size", "1GB"},
		{"iso.charset", "utf-8"},
		{"iso.containerfile", "/var/lib/hmac-file-server/data.iso"},

		// Timeouts
		{"timeouts.read", "60s"},
		{"timeouts.write", "60s"},
		{"timeouts.idle", "120s"},
		{"timeouts.shutdown", "30s"},

		// Versioning
		{"versioning.enabled", false},
		{"versioning.backend", "simple"},
		{"versioning.max_revisions", 5},

		// Session store (network resilience)
		{"session_store.enabled", true},
		{"session_store.backend", "memory"},
		{"session_store.max_sessions", 10000},
		{"session_store.cleanup_interval", "30m"},
		{"session_store.max_session_age", "72h"},
		{"session_store.redis_url", ""},

		// Build
		{"build.version", "dev"},
	}

	for _, d := range defaults {
		viper.SetDefault(d.key, d.value)
	}
}
// validateConfig performs basic sanity checks on a decoded configuration and
// returns the first problem found, or nil.
//
// The checks run in this order:
//  1. the listen address is set;
//  2. FileTTL is set when TTL is enabled;
//  3. the read/write/idle timeouts parse as durations;
//  4. MaxRevs is positive when versioning is on;
//  5. a JWT secret is present when JWT is enabled, otherwise an HMAC secret
//     is present.
func validateConfig(c *Config) error {
	switch {
	case c.Server.ListenAddress == "":
		return errors.New("server.listen_address is required")
	case c.Server.FileTTLEnabled && c.Server.FileTTL == "":
		return errors.New("server.file_ttl is required when server.file_ttl_enabled is true")
	}

	durations := []struct{ name, value string }{
		{"read", c.Timeouts.Read},
		{"write", c.Timeouts.Write},
		{"idle", c.Timeouts.Idle},
	}
	for _, d := range durations {
		if _, err := time.ParseDuration(d.value); err != nil {
			return fmt.Errorf("invalid timeouts.%s: %v", d.name, err)
		}
	}

	switch {
	case c.Versioning.Enabled && c.Versioning.MaxRevs <= 0:
		return errors.New("versioning.max_revisions must be positive if versioning is enabled")
	case c.Security.EnableJWT && strings.TrimSpace(c.Security.JWTSecret) == "":
		return errors.New("security.jwtsecret is required when security.enablejwt is true")
	case !c.Security.EnableJWT && strings.TrimSpace(c.Security.Secret) == "":
		// HMAC is the authentication fallback when JWT is disabled.
		return errors.New("security.secret is required for HMAC authentication (when JWT is disabled)")
	}
	return nil
}
// validateJWTFromRequest extracts a JWT from the request and verifies it
// against secret.
//
// If an Authorization header is present, it must have the form
// "Bearer <token>". Otherwise the token is taken from the "token" query
// parameter.
//
// Only HMAC signing methods are accepted (the key function rejects any other
// jwt.SigningMethod).
//
// Returns the parsed token on success, or an error describing why extraction
// or validation failed.
func validateJWTFromRequest(r *http.Request, secret string) (*jwt.Token, error) {
	var tokenString string
	if authHeader := r.Header.Get("Authorization"); authHeader != "" {
		// Require "Bearer " as a leading prefix, matching the Bearer handling
		// elsewhere in this file. Splitting on "Bearer " anywhere in the
		// header would also accept values such as "Basic Bearer <tok>" or
		// "xBearer <tok>".
		if !strings.HasPrefix(authHeader, "Bearer ") {
			return nil, errors.New("invalid Authorization header format")
		}
		tokenString = strings.TrimPrefix(authHeader, "Bearer ")
		if tokenString == "" {
			return nil, errors.New("empty Bearer token")
		}
	} else {
		// Fallback to checking 'token' query parameter
		tokenString = r.URL.Query().Get("token")
		if tokenString == "" {
			return nil, errors.New("missing JWT in Authorization header or 'token' query parameter")
		}
	}

	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return []byte(secret), nil
	})
	if err != nil {
		return nil, fmt.Errorf("JWT validation failed: %w", err)
	}
	if !token.Valid {
		return nil, errors.New("invalid JWT")
	}
	return token, nil
}
// validateBearerToken validates Bearer token authentication from ejabberd module
// ENHANCED FOR 100% WIFI ↔ LTE SWITCHING AND STANDBY RECOVERY RELIABILITY
//
// The Authorization header must be "Bearer <base64>". The decoded bytes are
// compared with hmac.Equal against HMAC-SHA256(secret, payload) for up to
// five payload layouts. The payloads are built from:
//   - the "user", "expiry" and optional "session_id" query parameters;
//   - the last segment of r.URL.Path (the filename);
//   - the Content-Length header (the size; 0 if absent or unparsable).
// The first layout that matches is accepted.
//
// Expiry is enforced leniently. A base grace period of 8h is raised by:
//   - the User-Agent (desktop/mobile XMPP heuristics);
//   - the session_id / network_resilience / resume_allowed query parameters;
//   - the presence of X-Forwarded-For / X-Real-IP;
//   - a large Content-Length.
// The result is capped at 48h. Beyond that, standby, desktop-restore (48h)
// and mobile "ultra-grace" (72h) windows can still admit an expired token.
//
// Returns the claims (user, filename, size, expiry) on success, otherwise an
// error from the first failing check.
//
// NOTE(review): User-Agent, X-Forwarded-For, X-Real-IP and the query flags
// above are all supplied by the client, so any caller can opt into the
// longest grace windows by setting them. Confirm this is acceptable.
func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return nil, errors.New("missing Authorization header")
}
// Check for Bearer token format
if !strings.HasPrefix(authHeader, "Bearer ") {
return nil, errors.New("invalid Authorization header format")
}
token := strings.TrimPrefix(authHeader, "Bearer ")
if token == "" {
return nil, errors.New("empty Bearer token")
}
// Decode base64 token
// The decoded bytes are the raw MAC compared against each candidate payload below.
tokenBytes, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return nil, fmt.Errorf("invalid base64 token: %v", err)
}
// Extract claims from URL parameters
// expiry is a Unix timestamp in seconds (compared against time.Now().Unix()).
query := r.URL.Query()
user := query.Get("user")
expiryStr := query.Get("expiry")
if user == "" {
return nil, errors.New("missing user parameter")
}
if expiryStr == "" {
return nil, errors.New("missing expiry parameter")
}
expiry, err := strconv.ParseInt(expiryStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid expiry parameter: %v", err)
}
// ULTRA-FLEXIBLE GRACE PERIODS FOR NETWORK SWITCHING AND STANDBY SCENARIOS
now := time.Now().Unix()
// Base grace period: 8 hours (increased from 4 hours for better WiFi ↔ LTE reliability)
gracePeriod := int64(28800) // 8 hours base grace period for all scenarios
// Detect mobile XMPP clients and apply enhanced grace periods
// NOTE(review): the substring list includes "client" and "bot", so many
// non-XMPP user agents also classify as mobile; it also includes "dino" and
// "gajim", so desktop clients are mobile here as well.
userAgent := r.Header.Get("User-Agent")
isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "conversations") ||
strings.Contains(strings.ToLower(userAgent), "dino") ||
strings.Contains(strings.ToLower(userAgent), "gajim") ||
strings.Contains(strings.ToLower(userAgent), "android") ||
strings.Contains(strings.ToLower(userAgent), "mobile") ||
strings.Contains(strings.ToLower(userAgent), "xmpp") ||
strings.Contains(strings.ToLower(userAgent), "client") ||
strings.Contains(strings.ToLower(userAgent), "bot")
// Enhanced XMPP client detection and grace period management
// Desktop XMPP clients (Dino, Gajim) need extended grace for session restoration after restart
isDesktopXMPP := strings.Contains(strings.ToLower(userAgent), "dino") ||
strings.Contains(strings.ToLower(userAgent), "gajim")
if isMobileXMPP || isDesktopXMPP {
if isDesktopXMPP {
gracePeriod = int64(86400) // 24 hours for desktop XMPP clients (session restoration)
log.Infof("🖥️ Desktop XMPP client detected (%s), using 24-hour grace period for session restoration", userAgent)
} else {
gracePeriod = int64(43200) // 12 hours for mobile XMPP clients
log.Infof("<22> Mobile XMPP client detected (%s), using extended 12-hour grace period", userAgent)
}
}
// Network resilience parameters for session recovery
// sessionId is also part of the "session_based" payload (format 2) below.
sessionId := query.Get("session_id")
networkResilience := query.Get("network_resilience")
resumeAllowed := query.Get("resume_allowed")
// Maximum grace period for network resilience scenarios
if sessionId != "" || networkResilience == "true" || resumeAllowed == "true" {
gracePeriod = int64(86400) // 24 hours for explicit network resilience scenarios
log.Infof("🌐 Network resilience mode activated (session_id: %s, network_resilience: %s), using 24-hour grace period",
sessionId, networkResilience)
}
// Detect potential network switching scenarios
clientIP := getClientIP(r)
xForwardedFor := r.Header.Get("X-Forwarded-For")
xRealIP := r.Header.Get("X-Real-IP")
// Check for client IP change indicators (WiFi ↔ LTE switching detection)
// NOTE(review): this only checks whether either proxy header is non-empty;
// it does not compare IPs, so it cannot detect an actual network change.
if xForwardedFor != "" || xRealIP != "" {
// Client is behind proxy/NAT - likely mobile switching between networks
gracePeriod = int64(86400) // 24 hours for proxy/NAT scenarios
log.Infof("📱 Network switching detected (client IP: %s, X-Forwarded-For: %s, X-Real-IP: %s), using 24-hour grace period",
clientIP, xForwardedFor, xRealIP)
}
// Check Content-Length to identify large uploads that need extra time
// size is also bound into most HMAC payloads below; a missing or invalid
// header (the parse error is ignored) leaves size at 0.
contentLength := r.Header.Get("Content-Length")
var size int64 = 0
if contentLength != "" {
size, _ = strconv.ParseInt(contentLength, 10, 64)
// For large files (>10MB), add extra grace time for mobile uploads
if size > 10*1024*1024 {
additionalTime := (size / (10 * 1024 * 1024)) * 3600 // 1 hour per 10MB
gracePeriod += additionalTime
log.Infof("📁 Large file detected (%d bytes), extending grace period by %d seconds", size, additionalTime)
}
}
// ABSOLUTE MAXIMUM: 48 hours for extreme scenarios
// (The standby / desktop / ultra-grace windows below can still exceed this cap.)
maxAbsoluteGrace := int64(172800) // 48 hours absolute maximum
if gracePeriod > maxAbsoluteGrace {
gracePeriod = maxAbsoluteGrace
log.Infof("⚠️ Grace period capped at 48 hours maximum")
}
// STANDBY RECOVERY: Special handling for device standby scenarios
isLikelyStandbyRecovery := false
standbyGraceExtension := int64(86400) // Additional 24 hours for standby recovery
if now > expiry {
expiredTime := now - expiry
// If token expired more than grace period but less than standby window, allow standby recovery
if expiredTime > gracePeriod && expiredTime < (gracePeriod + standbyGraceExtension) {
isLikelyStandbyRecovery = true
log.Infof("💤 STANDBY RECOVERY: Token expired %d seconds ago, within standby recovery window", expiredTime)
}
// Apply grace period check
if expiredTime > gracePeriod && !isLikelyStandbyRecovery {
// DESKTOP XMPP CLIENT SESSION RESTORATION: Special handling for Dino/Gajim restart scenarios
isDesktopSessionRestore := false
if isDesktopXMPP && expiredTime < int64(172800) { // 48 hours for desktop session restore
isDesktopSessionRestore = true
log.Infof("🖥️ DESKTOP SESSION RESTORE: %s token expired %d seconds ago, allowing within 48-hour desktop restoration window", userAgent, expiredTime)
}
// Still apply ultra-generous final check for mobile scenarios
ultraMaxGrace := int64(259200) // 72 hours ultra-maximum for critical mobile scenarios
if (isMobileXMPP && expiredTime < ultraMaxGrace) || isDesktopSessionRestore {
if isMobileXMPP {
log.Warnf("⚡ ULTRA-GRACE: Mobile XMPP client token expired %d seconds ago, allowing within 72-hour ultra-grace window", expiredTime)
}
} else {
log.Warnf("❌ Bearer token expired beyond all grace periods: now=%d, expiry=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s",
now, expiry, expiredTime, gracePeriod, userAgent)
return nil, fmt.Errorf("token has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)",
expiredTime, gracePeriod)
}
} else if isLikelyStandbyRecovery {
log.Infof("✅ STANDBY RECOVERY successful: allowing token within extended standby window")
} else {
log.Infof("✅ Bearer token expired but within grace period: %d seconds remaining", gracePeriod-expiredTime)
}
} else {
log.Debugf("✅ Bearer token still valid: %d seconds until expiry", expiry-now)
}
// Extract filename and size from request with enhanced path parsing
// NOTE(review): strings.Split with a non-empty separator always returns at
// least one element, so the length check below can never fail.
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
if len(pathParts) < 1 {
return nil, errors.New("invalid upload path format")
}
// Handle different path formats from various ejabberd modules
// Both branches take the last path segment as the filename.
filename := ""
if len(pathParts) >= 3 {
filename = pathParts[len(pathParts)-1] // Standard format: /upload/uuid/filename
} else if len(pathParts) >= 1 {
filename = pathParts[len(pathParts)-1] // Simplified format: /filename
}
if filename == "" {
filename = "upload" // Fallback filename
}
// ENHANCED HMAC VALIDATION: Try multiple payload formats for maximum compatibility
// Fields are NUL-separated; the formats are tried in order and the first match wins.
var validPayload bool
var payloadFormat string
// Format 1: Network-resilient payload (mod_http_upload_hmac_network_resilient)
extendedPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%d\x00network_resilient",
user, filename, size, expiry-86400, expiry)
h1 := hmac.New(sha256.New, []byte(secret))
h1.Write([]byte(extendedPayload))
expectedMAC1 := h1.Sum(nil)
if hmac.Equal(tokenBytes, expectedMAC1) {
validPayload = true
payloadFormat = "network_resilient"
}
// Format 2: Extended payload with session support
if !validPayload {
sessionPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%s", user, filename, size, expiry, sessionId)
h2 := hmac.New(sha256.New, []byte(secret))
h2.Write([]byte(sessionPayload))
expectedMAC2 := h2.Sum(nil)
if hmac.Equal(tokenBytes, expectedMAC2) {
validPayload = true
payloadFormat = "session_based"
}
}
// Format 3: Standard payload (original mod_http_upload_hmac)
if !validPayload {
standardPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d", user, filename, size, expiry-3600)
h3 := hmac.New(sha256.New, []byte(secret))
h3.Write([]byte(standardPayload))
expectedMAC3 := h3.Sum(nil)
if hmac.Equal(tokenBytes, expectedMAC3) {
validPayload = true
payloadFormat = "standard"
}
}
// Format 4: Simplified payload (fallback compatibility)
// NOTE(review): this payload does not include expiry, yet expiry is read from
// the (client-supplied) query string. A format-4 token is therefore not bound
// to any expiry; the caller can choose the value. Confirm this fallback is
// still required.
if !validPayload {
simplePayload := fmt.Sprintf("%s\x00%s\x00%d", user, filename, size)
h4 := hmac.New(sha256.New, []byte(secret))
h4.Write([]byte(simplePayload))
expectedMAC4 := h4.Sum(nil)
if hmac.Equal(tokenBytes, expectedMAC4) {
validPayload = true
payloadFormat = "simple"
}
}
// Format 5: User-only payload (maximum fallback)
// NOTE(review): this payload binds neither filename nor size, so one token
// authorizes any file for that user until expiry. Confirm this is intended.
if !validPayload {
userPayload := fmt.Sprintf("%s\x00%d", user, expiry)
h5 := hmac.New(sha256.New, []byte(secret))
h5.Write([]byte(userPayload))
expectedMAC5 := h5.Sum(nil)
if hmac.Equal(tokenBytes, expectedMAC5) {
validPayload = true
payloadFormat = "user_only"
}
}
if !validPayload {
log.Warnf("❌ Invalid Bearer token HMAC for user %s, file %s (tried all 5 payload formats)", user, filename)
return nil, errors.New("invalid Bearer token HMAC")
}
claims := &BearerTokenClaims{
User: user,
Filename: filename,
Size: size,
Expiry: expiry,
}
log.Infof("✅ Bearer token authentication SUCCESSFUL: user=%s, file=%s, format=%s, grace_period=%d seconds",
user, filename, payloadFormat, gracePeriod)
return claims, nil
}
// evaluateSecurityLevel decides how much re-verification a session needs:
// 1 = normal, 2 = challenge-response, 3 = full re-authentication.
//
// The first call for a session (zero LastSecurityCheck) seeds
// LastSecurityCheck, LastActivity and SecurityLevel and returns 1. Later
// calls check, in order:
//   - idle time since LastActivity: over 30m returns 2, over 2h returns 3
//     (StandbyDetected is set in both cases);
//   - an IP differing from LastIP: returns 2, or 3 once NetworkChangeCount
//     exceeds 3 (the counter is incremented on each such call);
//   - a User-Agent differing from a stored one: returns 3.
//
// NOTE(review): this function never writes LastIP, UserAgent, or (after the
// first call) LastActivity. Presumably the caller updates them; otherwise
// every later call keeps reporting the same change. NetworkChangeCount is
// cumulative and never reset here, so it counts changes over the whole
// session rather than "rapid" changes.
func evaluateSecurityLevel(session *NetworkResilientSession, currentIP string, userAgent string) int {
	now := time.Now()

	if session.LastSecurityCheck.IsZero() {
		session.LastSecurityCheck = now
		session.LastActivity = now
		session.SecurityLevel = 1 // Normal level
		return 1
	}

	if idle := now.Sub(session.LastActivity); idle > 30*time.Minute {
		session.StandbyDetected = true
		log.Infof("🔒 STANDBY DETECTED: %v since last activity for session %s", idle, session.SessionID)
		if idle > 2*time.Hour {
			log.Warnf("🔐 SECURITY LEVEL 3: Long standby (%v) requires full re-authentication", idle)
			return 3
		}
		log.Infof("🔐 SECURITY LEVEL 2: Medium standby (%v) requires challenge-response", idle)
		return 2
	}

	ipChanged := session.LastIP != "" && session.LastIP != currentIP
	if ipChanged {
		session.NetworkChangeCount++
		log.Infof("🌐 NETWORK CHANGE #%d: %s → %s for session %s",
			session.NetworkChangeCount, session.LastIP, currentIP, session.SessionID)
		if session.NetworkChangeCount <= 3 {
			log.Infof("🔐 SECURITY LEVEL 2: Network change requires challenge-response")
			return 2
		}
		log.Warnf("🔐 SECURITY LEVEL 3: Multiple network changes (%d) requires full re-authentication",
			session.NetworkChangeCount)
		return 3
	}

	agentChanged := session.UserAgent != "" && session.UserAgent != userAgent
	if agentChanged {
		log.Warnf("🔐 SECURITY LEVEL 3: User agent change detected - potential device hijacking")
		return 3
	}

	return 1
}
// generateSecurityChallenge builds the Level 2 challenge string for session.
//
// The challenge is hex(HMAC-SHA256(secret, "<SessionID>:<UserJID>:<unix-seconds>")),
// where the timestamp is the current time in whole seconds. The error result
// is always nil; it exists so callers can treat generation as fallible.
func generateSecurityChallenge(session *NetworkResilientSession, secret string) (string, error) {
	issuedAt := time.Now().Unix()

	mac := hmac.New(sha256.New, []byte(secret))
	fmt.Fprintf(mac, "%s:%s:%d", session.SessionID, session.UserJID, issuedAt)

	log.Infof("🔐 Generated security challenge for session %s", session.SessionID)
	return hex.EncodeToString(mac.Sum(nil)), nil
}
// validateSecurityChallenge verifies a Level 2 challenge-response.
//
// A valid response is exactly the hex string generateSecurityChallenge issues
// for this session: hex(HMAC-SHA256(secret, "<SessionID>:<UserJID>:<unix-seconds>")).
// Challenges issued at any second within the last five minutes are accepted.
// Every second of the window must be probed because the challenge embeds the
// issuing time in whole seconds; checking only one timestamp per minute would
// reject nearly every legitimate response.
//
// The comparison uses hmac.Equal so timing does not reveal how much of the
// response matched. Responses that cannot be a hex-encoded SHA-256 MAC
// (wrong length) are rejected without computing any HMACs.
func validateSecurityChallenge(session *NetworkResilientSession, providedResponse string, secret string) bool {
	const windowSeconds = 300 // 5-minute acceptance window

	if len(providedResponse) == hex.EncodedLen(sha256.Size) {
		provided := []byte(providedResponse)
		mac := hmac.New(sha256.New, []byte(secret))
		now := time.Now().Unix()
		for age := int64(0); age <= windowSeconds; age++ {
			mac.Reset()
			fmt.Fprintf(mac, "%s:%s:%d", session.SessionID, session.UserJID, now-age)
			expected := []byte(hex.EncodeToString(mac.Sum(nil)))
			if hmac.Equal(expected, provided) {
				log.Infof("✅ Security challenge validated for session %s", session.SessionID)
				return true
			}
		}
	}

	log.Warnf("❌ Security challenge failed for session %s", session.SessionID)
	return false
}
// setSecurityHeaders advertises, via response headers, which authentication
// step the client must take next.
//
// Level 2 sets an "HMAC-Challenge" WWW-Authenticate header carrying challenge,
// level 3 asks for full re-authentication, and any other level is reported as
// X-Security-Level 1 without authentication headers. Headers are only set;
// writing the status and body is left to the caller.
func setSecurityHeaders(w http.ResponseWriter, securityLevel int, challenge string) {
	hdr := w.Header()

	if securityLevel == 2 {
		hdr.Set("WWW-Authenticate", fmt.Sprintf("HMAC-Challenge challenge=\"%s\"", challenge))
		hdr.Set("X-Security-Level", "2")
		hdr.Set("X-Auth-Required", "challenge-response")
		return
	}

	if securityLevel == 3 {
		hdr.Set("WWW-Authenticate", "HMAC realm=\"HMAC File Server\"")
		hdr.Set("X-Security-Level", "3")
		hdr.Set("X-Auth-Required", "full-authentication")
		return
	}

	hdr.Set("X-Security-Level", "1")
}
// validateBearerTokenWithSession validates Bearer token with session recovery support
// ENHANCED FOR NETWORK SWITCHING: 5G ↔ WiFi transition support with session persistence
//
// Flow:
//  1. If validateBearerToken accepts the token, the NetworkResilientSession for
//     the request (looked up by getSessionIDFromRequest, else derived with
//     generateSessionID) is loaded or created, evaluated with
//     evaluateSecurityLevel, and written back to sessionStore with the caller's
//     current IP and User-Agent.
//  2. Otherwise, if the request names a stored session younger than 72 hours
//     that still has refreshes left, claims are rebuilt from that session.
//  3. Otherwise the error from validateBearerToken is returned, wrapped.
//
// The http.ResponseWriter is read from the request context under the string
// key "responseWriter" (handleUpload stores it there) so challenge and session
// headers can be attached to the response.
//
// NOTE(review): a plain string context key is flagged by staticcheck (SA1029);
// an unexported key type would avoid collisions, but the caller that stores the
// value must change together with this function.
//
// Returns the claims on success, or nil claims and a non-nil error when the
// request must be rejected.
func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerTokenClaims, error) {
	// Step 1: Try standard Bearer token validation first
	claims, err := validateBearerToken(r, secret)
	if err == nil {
		// Token is valid - create or update session for network resilience
		sessionID := getSessionIDFromRequest(r)
		if sessionID == "" {
			sessionID = generateSessionID(claims.User, claims.Filename)
		}
		// Get or create session
		session := sessionStore.GetSession(sessionID)
		if session == nil {
			// LastSecurityCheck is set here, so evaluateSecurityLevel skips its
			// first-check initialisation branch for newly created sessions.
			session = &NetworkResilientSession{
				SessionID:          sessionID,
				UserJID:            claims.User,
				OriginalToken:      getBearerTokenFromRequest(r),
				CreatedAt:          time.Now(),
				MaxRefreshes:       10,
				NetworkHistory:     []NetworkEvent{},
				SecurityLevel:      1,
				LastSecurityCheck:  time.Now(),
				NetworkChangeCount: 0,
				StandbyDetected:    false,
				LastActivity:       time.Now(),
			}
		}
		// Update session with current network context
		currentIP := getClientIP(r)
		userAgent := r.Header.Get("User-Agent")
		// ENHANCED SECURITY: Evaluate security level based on network changes and standby
		requiredSecurityLevel := evaluateSecurityLevel(session, currentIP, userAgent)
		session.SecurityLevel = requiredSecurityLevel
		session.LastActivity = time.Now()
		// Handle security level requirements
		// NOTE(review): the error returns inside this block exit before
		// StoreSession is called. Whether the mutations above (and those made by
		// evaluateSecurityLevel, e.g. NetworkChangeCount) persist depends on
		// whether GetSession hands out a shared pointer — confirm.
		if requiredSecurityLevel > 1 {
			// Extract response writer from context for security headers
			w, ok := r.Context().Value("responseWriter").(http.ResponseWriter)
			if !ok {
				log.Errorf("❌ Could not extract response writer for security headers")
				return nil, fmt.Errorf("security evaluation failed")
			}
			switch requiredSecurityLevel {
			case 2:
				// Challenge-response required
				challenge, err := generateSecurityChallenge(session, secret)
				if err != nil {
					log.Errorf("❌ Failed to generate security challenge: %v", err)
					return nil, fmt.Errorf("security challenge generation failed")
				}
				// Check if client provided challenge response
				// (validateSecurityChallenge expects the same value that
				// generateSecurityChallenge produces, i.e. the challenge echoed back).
				challengeResponse := r.Header.Get("X-Challenge-Response")
				if challengeResponse == "" {
					// No response provided, send challenge
					setSecurityHeaders(w, 2, challenge)
					return nil, fmt.Errorf("challenge-response required for network change")
				}
				// Validate challenge response
				if !validateSecurityChallenge(session, challengeResponse, secret) {
					setSecurityHeaders(w, 2, challenge)
					return nil, fmt.Errorf("invalid challenge response")
				}
				log.Infof("✅ Challenge-response validated for session %s", sessionID)
			case 3:
				// Full re-authentication required
				setSecurityHeaders(w, 3, "")
				log.Warnf("🔐 Full re-authentication required for session %s", sessionID)
				return nil, fmt.Errorf("full re-authentication required")
			}
		}
		// Record the IP change in the session history before overwriting LastIP.
		if session.LastIP != "" && session.LastIP != currentIP {
			// Network change detected
			session.NetworkHistory = append(session.NetworkHistory, NetworkEvent{
				Timestamp:   time.Now(),
				FromNetwork: session.LastIP,
				ToNetwork:   currentIP,
				ClientIP:    currentIP,
				UserAgent:   userAgent,
				EventType:   "network_switch",
			})
			log.Infof("🌐 Network switch detected for session %s: %s → %s",
				sessionID, session.LastIP, currentIP)
		}
		session.LastIP = currentIP
		session.UserAgent = userAgent
		sessionStore.StoreSession(sessionID, session)
		// Set session headers in response
		if w, ok := r.Context().Value("responseWriter").(http.ResponseWriter); ok {
			setSessionHeaders(w, sessionID)
		}
		log.Infof("✅ Bearer token valid, session updated: %s (user: %s)", sessionID, claims.User)
		return claims, nil
	}
	// Step 2: Token validation failed - try session recovery
	// NOTE(review): this path authenticates the request from the session ID
	// alone — no token or other credential is verified — and the token produced
	// by refreshSessionToken is discarded rather than returned to the client.
	// The Filename and Size in the returned claims come from the request path
	// and Content-Length/size query parameter. Confirm this trust model.
	sessionID := getSessionIDFromRequest(r)
	if sessionID != "" {
		session := sessionStore.GetSession(sessionID)
		if session != nil {
			// Check if session is still valid (within 72-hour window)
			sessionAge := time.Since(session.CreatedAt)
			if sessionAge < 72*time.Hour {
				log.Infof("🔄 Session recovery attempt for %s (age: %v)", sessionID, sessionAge)
				// Check if we can refresh the token
				if session.RefreshCount < session.MaxRefreshes {
					_, err := refreshSessionToken(session, secret, r)
					if err == nil {
						// Token refresh successful
						session.RefreshCount++
						session.LastSeen = time.Now()
						// Add refresh event to history
						session.NetworkHistory = append(session.NetworkHistory, NetworkEvent{
							Timestamp: time.Now(),
							ClientIP:  getClientIP(r),
							UserAgent: r.Header.Get("User-Agent"),
							EventType: "token_refresh",
						})
						sessionStore.StoreSession(sessionID, session)
						// Create claims from refreshed session
						refreshedClaims := &BearerTokenClaims{
							User:     session.UserJID,
							Filename: extractFilenameFromPath(r.URL.Path),
							Size:     extractSizeFromRequest(r),
							Expiry:   time.Now().Add(24 * time.Hour).Unix(),
						}
						log.Infof("✅ Session recovery successful: %s (refresh #%d)",
							sessionID, session.RefreshCount)
						return refreshedClaims, nil
					}
				} else {
					log.Warnf("❌ Session %s exceeded maximum refreshes (%d)",
						sessionID, session.MaxRefreshes)
				}
			} else {
				log.Warnf("❌ Session %s expired (age: %v, max: 72h)", sessionID, sessionAge)
			}
		} else {
			log.Warnf("❌ Session %s not found in store", sessionID)
		}
	}
	// Step 3: No valid token or session recovery possible
	// err is still the error from validateBearerToken in Step 1 (the refresh
	// error above is scoped to its own block).
	log.Warnf("❌ Authentication failed: %v (no session recovery available)", err)
	return nil, fmt.Errorf("authentication failed: %v", err)
}
// refreshSessionToken derives a fresh token for an existing session.
//
// The token is base64(HMAC-SHA256(secret, payload)), where payload is the
// NUL-separated list
//
//	UserJID, filename, size, expiry, SessionID, "session_refresh"
//
// with filename and size taken from the request path and Content-Length/size
// parameter, and expiry set 24 hours (86400 s) from now as a Unix timestamp.
// It returns an error, without touching the session, once RefreshCount has
// reached MaxRefreshes; the caller is responsible for incrementing RefreshCount.
func refreshSessionToken(session *NetworkResilientSession, secret string, r *http.Request) (string, error) {
	if session.RefreshCount >= session.MaxRefreshes {
		return "", fmt.Errorf("maximum token refreshes exceeded")
	}

	const refreshValiditySeconds = 86400 // 24 hours
	expiry := time.Now().Unix() + refreshValiditySeconds

	payload := strings.Join([]string{
		session.UserJID,
		extractFilenameFromPath(r.URL.Path),
		strconv.FormatInt(extractSizeFromRequest(r), 10),
		strconv.FormatInt(expiry, 10),
		session.SessionID,
		"session_refresh",
	}, "\x00")

	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(payload))
	token := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	log.Infof("🆕 Generated refresh token for session %s (refresh #%d)",
		session.SessionID, session.RefreshCount+1)
	return token, nil
}
// Helper functions for token and session management

// getBearerTokenFromRequest returns whatever follows the case-sensitive
// "Bearer " prefix of the Authorization header, or "" when the header is
// absent or uses another scheme.
func getBearerTokenFromRequest(r *http.Request) string {
	const scheme = "Bearer "
	header := r.Header.Get("Authorization")
	if len(header) < len(scheme) || header[:len(scheme)] != scheme {
		return ""
	}
	return header[len(scheme):]
}
// extractFilenameFromPath returns the last segment of a slash-separated URL
// path, ignoring leading and trailing slashes ("/upload/a/b.txt/" -> "b.txt").
// When the path contains no segment at all ("" or only slashes) it returns the
// placeholder "unknown" instead of an empty string.
func extractFilenameFromPath(path string) string {
	trimmed := strings.Trim(path, "/")
	if trimmed == "" {
		return "unknown"
	}
	// After trimming, the text after the last separator is non-empty; when
	// there is no separator LastIndex returns -1 and the whole string is used.
	return trimmed[strings.LastIndex(trimmed, "/")+1:]
}
// extractSizeFromRequest reports the upload size announced by the client.
// The Content-Length header takes priority, then the "size" query parameter;
// the first value that parses as a base-10 int64 is returned, otherwise 0.
func extractSizeFromRequest(r *http.Request) int64 {
	parse := func(raw string) (int64, bool) {
		if raw == "" {
			return 0, false
		}
		n, err := strconv.ParseInt(raw, 10, 64)
		return n, err == nil
	}

	if n, ok := parse(r.Header.Get("Content-Length")); ok {
		return n
	}
	if n, ok := parse(r.URL.Query().Get("size")); ok {
		return n
	}
	return 0
}
// BearerTokenClaims represents the claims extracted from a Bearer token.
// validateBearerToken builds it from a verified token; during session recovery
// validateBearerTokenWithSession rebuilds it from stored session data and the
// request itself.
type BearerTokenClaims struct {
	User     string // authenticated user; stored as NetworkResilientSession.UserJID
	Filename string // filename the token was issued for
	Size     int64  // announced upload size (NOTE(review): presumably bytes — confirm)
	Expiry   int64  // expiry as a Unix timestamp in seconds
}
// validateHMAC validates the HMAC signature of the request for legacy protocols and POST uploads.
// ENHANCED FOR 100% WIFI ↔ LTE SWITCHING AND STANDBY RECOVERY RELIABILITY
//
// Two signing schemes are accepted, both HMAC-SHA256 keyed with secret:
//
//   - X-Signature header (POST uploads): hex(HMAC(URL path)).
//   - Query parameter v2, token or v (checked in that priority order), carrying
//     hex(HMAC(message)). fileStorePath is the URL path without its leading "/"
//     and len is r.ContentLength in decimal. The accepted messages, tried in order, are:
//     v:         "<fileStorePath> <len>", "<fileStorePath>"
//     v2/token:  "<fileStorePath>\x00<len>\x00<content type>",
//     "<fileStorePath>\x00<len>", "<fileStorePath>"
//
// Hex signatures are decoded before comparison, so upper- and lower-case hex
// are both accepted for every scheme, and comparisons are constant-time.
// Returns nil on success, otherwise an error naming the failed scheme.
func validateHMAC(r *http.Request, secret string) error {
	log.Debugf("🔍 validateHMAC: Validating request to %s with query: %s", r.URL.Path, r.URL.RawQuery)

	// sign returns the raw HMAC-SHA256 of message under secret.
	sign := func(message string) []byte {
		mac := hmac.New(sha256.New, []byte(secret))
		mac.Write([]byte(message))
		return mac.Sum(nil)
	}

	// POST uploads authenticate with an X-Signature header over the URL path.
	if signature := r.Header.Get("X-Signature"); signature != "" {
		provided, err := hex.DecodeString(signature)
		if err != nil || !hmac.Equal(provided, sign(r.URL.Path)) {
			log.Warnf("❌ Invalid HMAC signature in X-Signature header")
			return errors.New("invalid HMAC signature in X-Signature header")
		}
		log.Debugf("✅ X-Signature HMAC authentication successful")
		return nil
	}

	// Legacy URL-based protocols; the first non-empty parameter wins.
	query := r.URL.Query()
	var protocolVersion, providedMACHex string
	switch {
	case query.Get("v2") != "":
		protocolVersion, providedMACHex = "v2", query.Get("v2")
	case query.Get("token") != "":
		protocolVersion, providedMACHex = "token", query.Get("token")
	case query.Get("v") != "":
		protocolVersion, providedMACHex = "v", query.Get("v")
	default:
		return errors.New("no HMAC signature found (missing X-Signature header or v/v2/token query parameter)")
	}

	fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
	contentLength := strconv.FormatInt(r.ContentLength, 10)

	// Candidate message layouts, most specific first; the first match wins.
	type macFormat struct{ name, message string }
	var formats []macFormat
	if protocolVersion == "v" {
		formats = []macFormat{
			{"v_standard", fileStorePath + "\x20" + contentLength},
			{"v_simple", fileStorePath},
		}
	} else {
		contentType := GetContentType(fileStorePath)
		formats = []macFormat{
			{protocolVersion + "_standard", fileStorePath + "\x00" + contentLength + "\x00" + contentType},
			{protocolVersion + "_no_content_type", fileStorePath + "\x00" + contentLength},
			{protocolVersion + "_simple", fileStorePath},
		}
	}

	// Decode the client's MAC once; malformed hex cannot match any layout.
	messageFormat := ""
	if providedMAC, err := hex.DecodeString(providedMACHex); err != nil {
		log.Debugf("🔍 validateHMAC: %s value is not valid hex: %v", protocolVersion, err)
	} else {
		for _, f := range formats {
			if hmac.Equal(sign(f.message), providedMAC) {
				messageFormat = f.name
				break
			}
		}
	}

	if messageFormat == "" {
		log.Warnf("❌ Invalid MAC for %s protocol (tried all formats)", protocolVersion)
		return fmt.Errorf("invalid MAC for %s protocol", protocolVersion)
	}

	log.Infof("✅ %s HMAC authentication SUCCESSFUL: format=%s, path=%s",
		protocolVersion, messageFormat, r.URL.Path)
	return nil
}
// validateV3HMAC validates the HMAC signature for v3 protocol (mod_http_upload_external).
// ENHANCED FOR 100% WIFI ↔ LTE SWITCHING AND STANDBY RECOVERY RELIABILITY
//
// The request must carry the query parameters "v3" (hex HMAC-SHA256 signature)
// and "expires" (Unix seconds). Signatures past their expiry are still accepted
// within a grace period computed below. The signature is then compared against
// three message layouts; the first match wins.
//
// Returns nil when the request is authenticated, otherwise an error describing
// the missing/invalid parameter, the expiry, or the signature mismatch.
//
// NOTE(review): every input that widens the grace period is client-controlled
// and none is covered by the Format 1 signature (which binds only method,
// expires and path): User-Agent substrings (including generic ones such as
// "client" and "bot"), the session_id / network_resilience / resume_allowed
// query parameters, the X-Forwarded-For / X-Real-IP headers, and
// Content-Length. Any caller can therefore claim the 24h (or longer) grace
// period. Confirm this is acceptable.
func validateV3HMAC(r *http.Request, secret string) error {
	query := r.URL.Query()
	// Extract v3 signature and expires from query parameters
	signature := query.Get("v3")
	expiresStr := query.Get("expires")
	if signature == "" {
		return errors.New("missing v3 signature parameter")
	}
	if expiresStr == "" {
		return errors.New("missing expires parameter")
	}
	// Parse expires timestamp
	expires, err := strconv.ParseInt(expiresStr, 10, 64)
	if err != nil {
		return fmt.Errorf("invalid expires parameter: %v", err)
	}
	// ULTRA-FLEXIBLE GRACE PERIODS FOR V3 PROTOCOL NETWORK SWITCHING
	now := time.Now().Unix()
	if now > expires {
		// Base grace period: 8 hours (significantly increased for WiFi ↔ LTE reliability)
		gracePeriod := int64(28800) // 8 hours base grace period
		// Enhanced mobile XMPP client detection
		userAgent := r.Header.Get("User-Agent")
		isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "gajim") ||
			strings.Contains(strings.ToLower(userAgent), "dino") ||
			strings.Contains(strings.ToLower(userAgent), "conversations") ||
			strings.Contains(strings.ToLower(userAgent), "android") ||
			strings.Contains(strings.ToLower(userAgent), "mobile") ||
			strings.Contains(strings.ToLower(userAgent), "xmpp") ||
			strings.Contains(strings.ToLower(userAgent), "client") ||
			strings.Contains(strings.ToLower(userAgent), "bot")
		if isMobileXMPP {
			gracePeriod = int64(43200) // 12 hours for mobile XMPP clients
			log.Infof("📱 V3: Mobile XMPP client detected (%s), using 12-hour grace period", userAgent)
		}
		// Network resilience parameters for V3 protocol
		sessionId := query.Get("session_id")
		networkResilience := query.Get("network_resilience")
		resumeAllowed := query.Get("resume_allowed")
		if sessionId != "" || networkResilience == "true" || resumeAllowed == "true" {
			gracePeriod = int64(86400) // 24 hours for network resilience scenarios
			log.Infof("🌐 V3: Network resilience mode detected, using 24-hour grace period")
		}
		// Detect network switching indicators
		// (clientIP is only used for the log line below.)
		clientIP := getClientIP(r)
		xForwardedFor := r.Header.Get("X-Forwarded-For")
		xRealIP := r.Header.Get("X-Real-IP")
		if xForwardedFor != "" || xRealIP != "" {
			// Client behind proxy/NAT - likely mobile network switching
			gracePeriod = int64(86400) // 24 hours for proxy/NAT scenarios
			log.Infof("🔄 V3: Network switching detected (IP: %s, X-Forwarded-For: %s), using 24-hour grace period",
				clientIP, xForwardedFor)
		}
		// Large file uploads get additional grace time
		if contentLengthStr := r.Header.Get("Content-Length"); contentLengthStr != "" {
			if contentLength, parseErr := strconv.ParseInt(contentLengthStr, 10, 64); parseErr == nil {
				// For files > 10MB, add additional grace time
				if contentLength > 10*1024*1024 {
					additionalTime := (contentLength / (10 * 1024 * 1024)) * 3600 // 1 hour per 10MB
					gracePeriod += additionalTime
					log.Infof("📁 V3: Large file (%d bytes), extending grace period by %d seconds",
						contentLength, additionalTime)
				}
			}
		}
		// Maximum grace period cap: 48 hours
		maxGracePeriod := int64(172800) // 48 hours absolute maximum
		if gracePeriod > maxGracePeriod {
			gracePeriod = maxGracePeriod
			log.Infof("⚠️ V3: Grace period capped at 48 hours maximum")
		}
		// STANDBY RECOVERY: Handle device standby scenarios
		// NOTE(review): the 48h cap above is not final. Every client is accepted
		// while expiredTime < gracePeriod+24h (standby window), and clients whose
		// User-Agent matched isMobileXMPP are accepted up to 72h after expiry.
		expiredTime := now - expires
		standbyGraceExtension := int64(86400) // Additional 24 hours for standby
		isLikelyStandbyRecovery := expiredTime > gracePeriod && expiredTime < (gracePeriod+standbyGraceExtension)
		if expiredTime > gracePeriod && !isLikelyStandbyRecovery {
			// Ultra-generous final check for mobile scenarios
			ultraMaxGrace := int64(259200) // 72 hours ultra-maximum for critical scenarios
			if isMobileXMPP && expiredTime < ultraMaxGrace {
				log.Warnf("⚡ V3 ULTRA-GRACE: Mobile client token expired %d seconds ago, allowing within 72-hour window", expiredTime)
			} else {
				log.Warnf("❌ V3 signature expired beyond all grace periods: now=%d, expires=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s",
					now, expires, expiredTime, gracePeriod, userAgent)
				return fmt.Errorf("signature has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)",
					expiredTime, gracePeriod)
			}
		} else if isLikelyStandbyRecovery {
			log.Infof("💤 V3 STANDBY RECOVERY: Allowing signature within extended standby window (expired %d seconds ago)", expiredTime)
		} else {
			log.Infof("✅ V3 signature within grace period: %d seconds remaining", gracePeriod-expiredTime)
		}
	} else {
		log.Debugf("✅ V3 signature still valid: %d seconds until expiry", expires-now)
	}
	// ENHANCED MESSAGE CONSTRUCTION: Try multiple formats for compatibility
	// Signatures are compared as hex strings, so the comparison is
	// case-sensitive (lower-case hex from hex.EncodeToString).
	var validSignature bool
	var messageFormat string
	// Format 1: Standard v3 format
	message1 := fmt.Sprintf("%s\n%s\n%s", r.Method, expiresStr, r.URL.Path)
	h1 := hmac.New(sha256.New, []byte(secret))
	h1.Write([]byte(message1))
	expectedSignature1 := hex.EncodeToString(h1.Sum(nil))
	if hmac.Equal([]byte(signature), []byte(expectedSignature1)) {
		validSignature = true
		messageFormat = "standard_v3"
	}
	// Format 2: Alternative format with query string
	// NOTE(review): r.URL.RawQuery includes the v3 parameter itself, so this
	// message would have to contain its own signature; it appears unable to
	// match in practice.
	if !validSignature {
		pathWithQuery := r.URL.Path
		if r.URL.RawQuery != "" {
			pathWithQuery += "?" + r.URL.RawQuery
		}
		message2 := fmt.Sprintf("%s\n%s\n%s", r.Method, expiresStr, pathWithQuery)
		h2 := hmac.New(sha256.New, []byte(secret))
		h2.Write([]byte(message2))
		expectedSignature2 := hex.EncodeToString(h2.Sum(nil))
		if hmac.Equal([]byte(signature), []byte(expectedSignature2)) {
			validSignature = true
			messageFormat = "with_query"
		}
	}
	// Format 3: Simplified format (fallback)
	// NOTE(review): this layout omits expires, so a signature valid under it is
	// accepted with any expires value the client chooses — effectively a
	// signature that never expires. Confirm before keeping this fallback.
	if !validSignature {
		message3 := fmt.Sprintf("%s\n%s", r.Method, r.URL.Path)
		h3 := hmac.New(sha256.New, []byte(secret))
		h3.Write([]byte(message3))
		expectedSignature3 := hex.EncodeToString(h3.Sum(nil))
		if hmac.Equal([]byte(signature), []byte(expectedSignature3)) {
			validSignature = true
			messageFormat = "simplified"
		}
	}
	if !validSignature {
		log.Warnf("❌ Invalid V3 HMAC signature (tried all 3 formats)")
		return errors.New("invalid v3 HMAC signature")
	}
	log.Infof("✅ V3 HMAC authentication SUCCESSFUL: format=%s, method=%s, path=%s",
		messageFormat, r.Method, r.URL.Path)
	return nil
}
// copyWithProgressTracking copies src to dst through buf and returns the
// number of bytes written.
//
// For transfers whose declared totalSize exceeds 50 MiB, progress is logged
// once at least 10 MiB have been written since the previous progress line, or
// when more than 30 seconds have passed since it. clientIP only appears in
// those log lines.
//
// io.EOF from src ends the copy successfully. Any other read or write error is
// returned with the byte count so far. A writer that accepts fewer bytes than
// offered without reporting an error yields io.ErrShortWrite, as io.Copy does.
// An empty buf is replaced by a 32 KiB buffer instead of issuing zero-length
// reads.
func copyWithProgressTracking(dst io.Writer, src io.Reader, buf []byte, totalSize int64, clientIP string) (int64, error) {
	const (
		largeTransferThreshold = 50 * 1024 * 1024
		logEveryBytes          = 10 * 1024 * 1024
		logEveryInterval       = 30 * time.Second
	)
	if len(buf) == 0 {
		buf = make([]byte, 32*1024)
	}

	var written, loggedAt int64
	lastLogTime := time.Now()
	for {
		n, err := src.Read(buf)
		if n > 0 {
			w, werr := dst.Write(buf[:n])
			written += int64(w)
			if werr != nil {
				return written, werr
			}
			if w < n {
				return written, io.ErrShortWrite
			}
			// Compare the byte delta since the last progress line rather than
			// testing for an exact multiple: reads rarely land on 10 MiB
			// boundaries, and the modulo form is easy to mis-parenthesise.
			if totalSize > largeTransferThreshold &&
				(written-loggedAt >= logEveryBytes || time.Since(lastLogTime) > logEveryInterval) {
				progress := float64(written) / float64(totalSize) * 100
				log.Infof("📥 Download progress: %.1f%% (%s/%s) for IP %s",
					progress, formatBytes(written), formatBytes(totalSize), clientIP)
				loggedAt = written
				lastLogTime = time.Now()
			}
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			return written, err
		}
	}
	return written, nil
}
// handleUpload handles file uploads.
// ENHANCED FOR 100% WIFI ↔ LTE SWITCHING AND STANDBY RECOVERY RELIABILITY
func handleUpload(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
activeConnections.Inc()
defer activeConnections.Dec()
// Enhanced session handling for multi-upload scenarios (Gajim fix)
sessionID := r.Header.Get("X-Session-ID")
if sessionID == "" {
// Generate session ID for multi-upload tracking
sessionID = generateUploadSessionID("upload", r.Header.Get("User-Agent"), getClientIP(r))
}
// Set session headers for client continuation
w.Header().Set("X-Session-ID", sessionID)
w.Header().Set("X-Upload-Session-Timeout", "3600") // 1 hour
// Only allow POST method
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
uploadErrorsTotal.Inc()
return
}
// ENHANCED AUTHENTICATION with network switching support
var bearerClaims *BearerTokenClaims
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
// Bearer token authentication with session recovery for network switching
// Store response writer in context for session headers
ctx := context.WithValue(r.Context(), "responseWriter", w)
r = r.WithContext(ctx)
claims, err := validateBearerTokenWithSession(r, conf.Security.Secret)
if err != nil {
// Enhanced error logging for network switching scenarios
clientIP := getClientIP(r)
userAgent := r.Header.Get("User-Agent")
sessionID := getSessionIDFromRequest(r)
log.Warnf("🔴 Authentication failed for IP %s, User-Agent: %s, Session: %s, Error: %v",
clientIP, userAgent, sessionID, err)
// Check if this might be a network switching scenario and provide helpful response
if strings.Contains(err.Error(), "expired") || strings.Contains(err.Error(), "invalid") {
w.Header().Set("X-Network-Switch-Detected", "true")
w.Header().Set("X-Retry-After", "30") // Suggest retry after 30 seconds
w.Header().Set("X-Session-Recovery", "available")
if sessionID != "" {
w.Header().Set("X-Session-ID", sessionID)
}
}
http.Error(w, fmt.Sprintf("Bearer Token Authentication failed: %v", err), http.StatusUnauthorized)
uploadErrorsTotal.Inc()
return
}
bearerClaims = claims
log.Infof("✅ Bearer token authentication successful: user=%s, file=%s, IP=%s",
claims.User, claims.Filename, getClientIP(r))
// Add comprehensive response headers for audit logging and client tracking
w.Header().Set("X-Authenticated-User", claims.User)
w.Header().Set("X-Auth-Method", "Bearer-Token")
w.Header().Set("X-Client-IP", getClientIP(r))
w.Header().Set("X-Network-Switch-Support", "enabled")
} else if conf.Security.EnableJWT {
// JWT authentication
_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
if err != nil {
log.Warnf("🔴 JWT Authentication failed for IP %s: %v", getClientIP(r), err)
http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
uploadErrorsTotal.Inc()
return
}
log.Infof("✅ JWT authentication successful for upload request: %s", r.URL.Path)
w.Header().Set("X-Auth-Method", "JWT")
} else {
// HMAC authentication with enhanced network switching support
err := validateHMAC(r, conf.Security.Secret)
if err != nil {
log.Warnf("🔴 HMAC Authentication failed for IP %s: %v", getClientIP(r), err)
http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
uploadErrorsTotal.Inc()
return
}
log.Infof("✅ HMAC authentication successful for upload request: %s", r.URL.Path)
w.Header().Set("X-Auth-Method", "HMAC")
}
// ENHANCED CLIENT MULTI-INTERFACE TRACKING with network switching detection
var clientSession *ClientSession
if clientTracker != nil && conf.ClientNetwork.SessionBasedTracking {
// Enhanced session ID extraction from multiple sources
sessionID := r.Header.Get("X-Upload-Session-ID")
if sessionID == "" {
sessionID = r.FormValue("session_id")
}
if sessionID == "" {
sessionID = r.URL.Query().Get("session_id")
}
if sessionID == "" {
// Generate new session ID with enhanced entropy
sessionID = generateSessionID("", "")
}
clientIP := getClientIP(r)
// Detect potential network switching
xForwardedFor := r.Header.Get("X-Forwarded-For")
xRealIP := r.Header.Get("X-Real-IP")
networkSwitchIndicators := xForwardedFor != "" || xRealIP != ""
if networkSwitchIndicators {
log.Infof("🔄 Network switching indicators detected: session=%s, client_ip=%s, x_forwarded_for=%s, x_real_ip=%s",
sessionID, clientIP, xForwardedFor, xRealIP)
w.Header().Set("X-Network-Switch-Detected", "true")
}
clientSession = clientTracker.TrackClientSession(sessionID, clientIP, r)
// Enhanced session response headers for client coordination
w.Header().Set("X-Upload-Session-ID", sessionID)
w.Header().Set("X-Session-IP-Count", fmt.Sprintf("%d", len(clientSession.ClientIPs)))
w.Header().Set("X-Connection-Type", clientSession.ConnectionType)
log.Infof("🔗 Client session tracking: %s from IP %s (connection: %s, total_ips: %d)",
sessionID, clientIP, clientSession.ConnectionType, len(clientSession.ClientIPs))
// Add user context for Bearer token authentication
if bearerClaims != nil {
log.Infof("👤 Session associated with XMPP user: %s", bearerClaims.User)
w.Header().Set("X-XMPP-User", bearerClaims.User)
}
}
// Parse multipart form with enhanced error handling
err := r.ParseMultipartForm(32 << 20) // 32MB max memory
if err != nil {
log.Errorf("🔴 Error parsing multipart form from IP %s: %v", getClientIP(r), err)
http.Error(w, fmt.Sprintf("Error parsing multipart form: %v", err), http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
// Get file from form with enhanced validation
file, header, err := r.FormFile("file")
if err != nil {
log.Errorf("🔴 Error getting file from form (IP: %s): %v", getClientIP(r), err)
http.Error(w, fmt.Sprintf("Error getting file from form: %v", err), http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
defer file.Close()
// Validate file size against max_upload_size if configured
if conf.Server.MaxUploadSize != "" {
maxSizeBytes, err := parseSize(conf.Server.MaxUploadSize)
if err != nil {
log.Errorf("🔴 Invalid max_upload_size configuration: %v", err)
http.Error(w, "Server configuration error", http.StatusInternalServerError)
uploadErrorsTotal.Inc()
return
}
if header.Size > maxSizeBytes {
log.Warnf("⚠️ File size %s exceeds maximum allowed size %s (IP: %s)",
formatBytes(header.Size), conf.Server.MaxUploadSize, getClientIP(r))
http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
formatBytes(header.Size), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge)
uploadErrorsTotal.Inc()
return
}
}
// Validate file extension if configured
if len(conf.Uploads.AllowedExtensions) > 0 {
ext := strings.ToLower(filepath.Ext(header.Filename))
allowed := false
for _, allowedExt := range conf.Uploads.AllowedExtensions {
if ext == allowedExt {
allowed = true
break
}
}
if !allowed {
log.Warnf("⚠️ File extension %s not allowed (IP: %s, file: %s)", ext, getClientIP(r), header.Filename)
http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
uploadErrorsTotal.Inc()
return
}
}
// Generate filename based on configuration
var filename string
switch conf.Server.FileNaming {
case "HMAC":
// Generate HMAC-based filename with enhanced entropy
h := hmac.New(sha256.New, []byte(conf.Security.Secret))
h.Write([]byte(header.Filename + time.Now().String() + getClientIP(r)))
filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(header.Filename)
default: // "original" or "None"
filename = header.Filename
}
// Create full file path
storagePath := conf.Server.StoragePath
if conf.ISO.Enabled {
storagePath = conf.ISO.MountPoint
}
absFilename := filepath.Join(storagePath, filename)
// Pre-upload deduplication check: if file already exists and deduplication is enabled, return success immediately
if conf.Server.DeduplicationEnabled {
if existingFileInfo, err := os.Stat(absFilename); err == nil {
// File already exists - return success immediately for deduplication hit
duration := time.Since(startTime)
uploadDuration.Observe(duration.Seconds())
uploadsTotal.Inc()
uploadSizeBytes.Observe(float64(existingFileInfo.Size()))
filesDeduplicatedTotal.Inc()
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Deduplication-Hit", "true")
w.WriteHeader(http.StatusOK)
response := map[string]interface{}{
"success": true,
"filename": filename,
"size": existingFileInfo.Size(),
"message": "File already exists (deduplication hit)",
"upload_time": duration.String(),
}
json.NewEncoder(w).Encode(response)
log.Infof("💾 Deduplication hit: file %s already exists (%s), returning success immediately (IP: %s)",
filename, formatBytes(existingFileInfo.Size()), getClientIP(r))
return
}
}
// Create the file with enhanced error handling
dst, err := os.Create(absFilename)
if err != nil {
log.Errorf("🔴 Error creating file %s (IP: %s): %v", absFilename, getClientIP(r), err)
http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
uploadErrorsTotal.Inc()
return
}
defer dst.Close()
// Register upload with network resilience manager for WLAN/5G switching support
var uploadCtx *UploadContext
var networkSessionID string
if networkManager != nil {
networkSessionID = r.Header.Get("X-Upload-Session-ID")
if networkSessionID == "" {
networkSessionID = fmt.Sprintf("upload_%s_%d", getClientIP(r), time.Now().UnixNano())
}
uploadCtx = networkManager.RegisterUpload(networkSessionID)
defer networkManager.UnregisterUpload(networkSessionID)
log.Infof("🌐 Registered upload with network resilience: session=%s, IP=%s", networkSessionID, getClientIP(r))
// Add network resilience headers
w.Header().Set("X-Network-Resilience", "enabled")
w.Header().Set("X-Upload-Context-ID", networkSessionID)
}
// Copy file content with network resilience support and enhanced progress tracking
written, err := copyWithNetworkResilience(dst, file, uploadCtx)
if err != nil {
log.Errorf("🔴 Error saving file %s (IP: %s, session: %s): %v", filename, getClientIP(r), sessionID, err)
http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
uploadErrorsTotal.Inc()
// Clean up partial file
os.Remove(absFilename)
return
}
// ✅ CRITICAL FIX: Send immediate success response for large files (>1GB)
// This prevents client timeouts while server does post-processing
isLargeFile := header.Size > 1024*1024*1024 // 1GB threshold
if isLargeFile {
log.Infof("🚀 Large file detected (%s), sending immediate success response", formatBytes(header.Size))
// Send immediate success response to client
duration := time.Since(startTime)
uploadDuration.Observe(duration.Seconds())
uploadsTotal.Inc()
uploadSizeBytes.Observe(float64(written))
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Upload-Success", "true")
w.Header().Set("X-Upload-Duration", duration.String())
w.Header().Set("X-Large-File-Processing", "async")
w.Header().Set("X-Post-Processing", "background")
w.WriteHeader(http.StatusOK)
response := map[string]interface{}{
"success": true,
"filename": filename,
"size": written,
"duration": duration.String(),
"client_ip": getClientIP(r),
"timestamp": time.Now().Unix(),
"post_processing": "background",
}
// Add session information if available
if clientSession != nil {
response["session_id"] = clientSession.SessionID
response["connection_type"] = clientSession.ConnectionType
response["ip_count"] = len(clientSession.ClientIPs)
}
// Add user information if available
if bearerClaims != nil {
response["user"] = bearerClaims.User
}
// Send response immediately
if jsonBytes, err := json.Marshal(response); err == nil {
w.Write(jsonBytes)
} else {
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d, "post_processing": "background"}`, filename, written)
}
log.Infof("✅ Immediate response sent for large file %s (%s) in %s from IP %s",
filename, formatBytes(written), duration, getClientIP(r))
// Process deduplication asynchronously for large files
go func() {
if conf.Server.DeduplicationEnabled {
log.Infof("🔄 Starting background deduplication for large file: %s", filename)
ctx := context.Background()
err := handleDeduplication(ctx, absFilename)
if err != nil {
log.Warnf("⚠️ Background deduplication failed for %s: %v", absFilename, err)
} else {
log.Infof("✅ Background deduplication completed for %s", filename)
}
}
// Add to scan queue for virus scanning if enabled
if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 {
ext := strings.ToLower(filepath.Ext(header.Filename))
shouldScan := false
for _, scanExt := range conf.ClamAV.ScanFileExtensions {
if ext == strings.ToLower(scanExt) {
shouldScan = true
break
}
}
if shouldScan {
log.Infof("🔍 Starting background virus scan for large file: %s", filename)
err := scanFileWithClamAV(absFilename)
if err != nil {
log.Warnf("⚠️ Background virus scan failed for %s: %v", filename, err)
} else {
log.Infof("✅ Background virus scan completed for %s", filename)
}
}
}
}()
return
}
// Standard processing for small files (synchronous)
// Handle deduplication if enabled
if conf.Server.DeduplicationEnabled {
ctx := context.Background()
err = handleDeduplication(ctx, absFilename)
if err != nil {
log.Warnf("⚠️ Deduplication failed for %s (IP: %s): %v", absFilename, getClientIP(r), err)
} else {
log.Debugf("💾 Deduplication processed for %s", absFilename)
}
}
// Update metrics
duration := time.Since(startTime)
uploadDuration.Observe(duration.Seconds())
uploadsTotal.Inc()
uploadSizeBytes.Observe(float64(written))
// Enhanced success response with comprehensive metadata
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Upload-Success", "true")
w.Header().Set("X-Upload-Duration", duration.String())
w.WriteHeader(http.StatusOK)
response := map[string]interface{}{
"success": true,
"filename": filename,
"size": written,
"duration": duration.String(),
"client_ip": getClientIP(r),
"timestamp": time.Now().Unix(),
}
// Add session information if available
if clientSession != nil {
response["session_id"] = clientSession.SessionID
response["connection_type"] = clientSession.ConnectionType
response["ip_count"] = len(clientSession.ClientIPs)
}
// Add user information if available
if bearerClaims != nil {
response["user"] = bearerClaims.User
}
// Create JSON response
if jsonBytes, err := json.Marshal(response); err == nil {
w.Write(jsonBytes)
} else {
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d}`, filename, written)
}
log.Infof("✅ Successfully uploaded %s (%s) in %s from IP %s (session: %s)",
filename, formatBytes(written), duration, getClientIP(r), sessionID)
}
// handleDownload handles authenticated file downloads under /download/.
// ENHANCED FOR 100% WIFI ↔ LTE SWITCHING AND STANDBY RECOVERY RELIABILITY
//
// Authentication uses JWT when conf.Security.EnableJWT is set, otherwise the
// HMAC scheme checked by validateHMAC. The requested name is resolved with
// sanitizeFilePath against the storage path (or the ISO mount point when ISO
// mode is enabled). Clients whose User-Agent contains one of a few mobile/XMPP
// markers receive extra X-* hint headers and a Cache-Control header.
func handleDownload(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()
	activeConnections.Inc()
	defer activeConnections.Dec()

	clientIP := getClientIP(r)

	// Enhanced Authentication with network switching tolerance
	if conf.Security.EnableJWT {
		if _, err := validateJWTFromRequest(r, conf.Security.JWTSecret); err != nil {
			log.Warnf("🔴 JWT Authentication failed for download from IP %s: %v", clientIP, err)
			http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
			downloadErrorsTotal.Inc()
			return
		}
		log.Infof("✅ JWT authentication successful for download request: %s", r.URL.Path)
		w.Header().Set("X-Auth-Method", "JWT")
	} else {
		if err := validateHMAC(r, conf.Security.Secret); err != nil {
			log.Warnf("🔴 HMAC Authentication failed for download from IP %s: %v", clientIP, err)
			http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
			downloadErrorsTotal.Inc()
			return
		}
		log.Infof("✅ HMAC authentication successful for download request: %s", r.URL.Path)
		w.Header().Set("X-Auth-Method", "HMAC")
	}

	// Extract filename with enhanced path handling
	filename := strings.TrimPrefix(r.URL.Path, "/download/")
	if filename == "" {
		log.Warnf("⚠️ No filename specified in download request from IP %s", clientIP)
		http.Error(w, "Filename not specified", http.StatusBadRequest)
		downloadErrorsTotal.Inc()
		return
	}

	// Use storage path or ISO mount point
	storagePath := conf.Server.StoragePath
	if conf.ISO.Enabled {
		storagePath = conf.ISO.MountPoint
	}
	absFilename, err := sanitizeFilePath(storagePath, filename)
	if err != nil {
		log.Warnf("🔴 Invalid file path requested from IP %s: %s, error: %v", clientIP, filename, err)
		http.Error(w, fmt.Sprintf("Invalid file path: %v", err), http.StatusBadRequest)
		downloadErrorsTotal.Inc()
		return
	}

	// Classify the client once; the result drives both the 404 hints and the
	// caching headers on the success path.
	userAgent := r.Header.Get("User-Agent")
	lowerUA := strings.ToLower(userAgent)
	isMobileXMPP := false
	for _, marker := range []string{"conversations", "dino", "gajim", "android", "mobile", "xmpp"} {
		if strings.Contains(lowerUA, marker) {
			isMobileXMPP = true
			break
		}
	}

	// Enhanced file existence and accessibility check
	fileInfo, err := os.Stat(absFilename)
	if os.IsNotExist(err) {
		log.Warnf("🔴 File not found: %s (requested by IP %s)", absFilename, clientIP)
		// Enhanced 404 response with network switching hints
		w.Header().Set("X-File-Not-Found", "true")
		w.Header().Set("X-Client-IP", clientIP)
		w.Header().Set("X-Network-Switch-Support", "enabled")
		if isMobileXMPP {
			w.Header().Set("X-Mobile-Client-Detected", "true")
			w.Header().Set("X-Retry-Suggestion", "30") // Suggest retry after 30 seconds
			log.Infof("📱 Mobile XMPP client file not found - may be network switching issue: %s", userAgent)
		}
		http.Error(w, "File not found", http.StatusNotFound)
		downloadErrorsTotal.Inc()
		return
	}
	if err != nil {
		log.Errorf("🔴 Error accessing file %s from IP %s: %v", absFilename, clientIP, err)
		http.Error(w, fmt.Sprintf("Error accessing file: %v", err), http.StatusInternalServerError)
		downloadErrorsTotal.Inc()
		return
	}
	if fileInfo.IsDir() {
		log.Warnf("⚠️ Attempt to download directory %s from IP %s", absFilename, clientIP)
		http.Error(w, "Cannot download a directory", http.StatusBadRequest)
		downloadErrorsTotal.Inc()
		return
	}

	// Open with a small progressive backoff (1s, 2s). The wait is abandoned
	// as soon as the client goes away so a disconnected request does not keep
	// a handler goroutine sleeping.
	const maxRetries = 3
	var file *os.File
	for attempt := 1; ; attempt++ {
		file, err = os.Open(absFilename)
		if err == nil {
			break
		}
		if attempt >= maxRetries {
			log.Errorf("🔴 Failed to open file %s after %d attempts from IP %s: %v",
				absFilename, maxRetries, clientIP, err)
			http.Error(w, fmt.Sprintf("Error opening file: %v", err), http.StatusInternalServerError)
			downloadErrorsTotal.Inc()
			return
		}
		log.Warnf("⚠️ Attempt %d/%d: Error opening file %s from IP %s: %v (retrying...)",
			attempt, maxRetries, absFilename, clientIP, err)
		timer := time.NewTimer(time.Duration(attempt) * time.Second)
		select {
		case <-timer.C:
		case <-r.Context().Done():
			timer.Stop()
			log.Warnf("⚠️ Client %s disconnected while retrying open of %s: %v", clientIP, absFilename, r.Context().Err())
			downloadErrorsTotal.Inc()
			return
		}
	}
	defer file.Close()

	// Escape characters that would terminate or corrupt the quoted-string
	// filename parameter of Content-Disposition.
	dispositionName := strings.NewReplacer(`\`, `\\`, `"`, `\"`, "\r", "", "\n", "").Replace(filepath.Base(absFilename))

	// Enhanced response headers with network switching support
	w.Header().Set("Content-Disposition", "attachment; filename=\""+dispositionName+"\"")
	w.Header().Set("Content-Type", "application/octet-stream")
	w.Header().Set("Content-Length", strconv.FormatInt(fileInfo.Size(), 10))
	w.Header().Set("X-Client-IP", clientIP)
	w.Header().Set("X-Network-Switch-Support", "enabled")
	w.Header().Set("X-File-Path", filename)
	w.Header().Set("X-Download-Start-Time", strconv.FormatInt(time.Now().Unix(), 10))

	// Add cache control headers for mobile network optimization
	if isMobileXMPP {
		w.Header().Set("X-Mobile-Client-Detected", "true")
		w.Header().Set("Cache-Control", "public, max-age=86400") // 24 hours cache for mobile
		w.Header().Set("X-Mobile-Optimized", "true")
		log.Infof("📱 Mobile XMPP client download detected, applying mobile optimizations")
	}

	// Enhanced file transfer with buffered copy and progress tracking
	bufPtr := bufferPool.Get().(*[]byte)
	defer bufferPool.Put(bufPtr)
	buf := *bufPtr

	// Track download progress for large files
	if fileInfo.Size() > 10*1024*1024 { // Log progress for files > 10MB
		log.Infof("📥 Starting download of %s (%.1f MiB) for IP %s",
			filepath.Base(absFilename), float64(fileInfo.Size())/(1024*1024), clientIP)
	}

	// Enhanced copy with network resilience
	n, err := copyWithProgressTracking(w, file, buf, fileInfo.Size(), clientIP)
	if err != nil {
		log.Errorf("🔴 Error during download of %s for IP %s: %v", absFilename, clientIP, err)
		// Headers (and possibly body bytes) are already sent; no http.Error here.
		downloadErrorsTotal.Inc()
		return
	}

	// Update metrics and log success
	duration := time.Since(startTime)
	downloadDuration.Observe(duration.Seconds())
	downloadsTotal.Inc()
	downloadSizeBytes.Observe(float64(n))
	log.Infof("✅ Successfully downloaded %s (%s) in %s for IP %s (session complete)",
		filepath.Base(absFilename), formatBytes(n), duration, clientIP)
}
// handleV3Upload handles PUT requests for v3 protocol (mod_http_upload_external).
//
// The request must carry a valid v3 signature (validateV3HMAC). The last URL
// path segment is the original filename. The stored name is either that name
// or an HMAC-derived name, depending on conf.Server.FileNaming. When
// max_upload_size is configured it is enforced both against the declared
// Content-Length and against the bytes actually read, so chunked bodies are
// limited too.
//
// Uploads larger than 1 GiB get an immediate JSON response; deduplication and
// ClamAV scanning then run in a background goroutine. Smaller uploads are
// deduplicated synchronously before the JSON response is written.
func handleV3Upload(w http.ResponseWriter, r *http.Request) {
	// Above this size post-processing moves to the background so clients do
	// not time out waiting for the response.
	const largeFileThreshold = 1024 * 1024 * 1024 // 1GB

	startTime := time.Now()
	activeConnections.Inc()
	defer activeConnections.Dec()

	// Only allow PUT method for v3
	if r.Method != http.MethodPut {
		http.Error(w, "Method not allowed for v3 uploads", http.StatusMethodNotAllowed)
		uploadErrorsTotal.Inc()
		return
	}

	// Validate v3 HMAC signature
	if err := validateV3HMAC(r, conf.Security.Secret); err != nil {
		http.Error(w, fmt.Sprintf("v3 Authentication failed: %v", err), http.StatusUnauthorized)
		uploadErrorsTotal.Inc()
		return
	}
	log.Debugf("v3 HMAC authentication successful for upload request: %s", r.URL.Path)

	// Path format: /uuid/subdir/filename.ext; the last segment is the filename.
	// strings.Split always returns at least one element, so indexing is safe.
	pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
	originalFilename := pathParts[len(pathParts)-1]
	if originalFilename == "" {
		http.Error(w, "No filename specified", http.StatusBadRequest)
		uploadErrorsTotal.Inc()
		return
	}
	// "." and ".." would make filepath.Join below resolve to the storage
	// directory itself or to its parent instead of a file inside it.
	if originalFilename == "." || originalFilename == ".." {
		http.Error(w, "Invalid upload path", http.StatusBadRequest)
		uploadErrorsTotal.Inc()
		return
	}

	// Validate file extension if configured (case-insensitive: ext is lowered,
	// the configured entries may not be).
	if len(conf.Uploads.AllowedExtensions) > 0 {
		ext := strings.ToLower(filepath.Ext(originalFilename))
		allowed := false
		for _, allowedExt := range conf.Uploads.AllowedExtensions {
			if strings.EqualFold(ext, allowedExt) {
				allowed = true
				break
			}
		}
		if !allowed {
			http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
			uploadErrorsTotal.Inc()
			return
		}
	}

	// Validate file size against max_upload_size if configured
	if conf.Server.MaxUploadSize != "" {
		maxSizeBytes, err := parseSize(conf.Server.MaxUploadSize)
		if err != nil {
			log.Errorf("Invalid max_upload_size configuration: %v", err)
			http.Error(w, "Server configuration error", http.StatusInternalServerError)
			uploadErrorsTotal.Inc()
			return
		}
		if r.ContentLength > maxSizeBytes {
			http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
				formatBytes(r.ContentLength), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge)
			uploadErrorsTotal.Inc()
			return
		}
		// ContentLength is -1 for chunked uploads, which slips past the check
		// above; bound the body itself so the limit always applies.
		r.Body = http.MaxBytesReader(w, r.Body, maxSizeBytes)
	}

	// Generate filename based on configuration
	var filename string
	switch conf.Server.FileNaming {
	case "HMAC":
		h := hmac.New(sha256.New, []byte(conf.Security.Secret))
		h.Write([]byte(originalFilename + time.Now().String()))
		filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(originalFilename)
	default: // "original" or "None"
		filename = originalFilename
	}

	// Create full file path
	storagePath := conf.Server.StoragePath
	if conf.ISO.Enabled {
		storagePath = conf.ISO.MountPoint
	}
	absFilename := filepath.Join(storagePath, filename)

	// Pre-upload deduplication check: if file already exists and deduplication is enabled, return success immediately
	if conf.Server.DeduplicationEnabled {
		if existingFileInfo, err := os.Stat(absFilename); err == nil {
			duration := time.Since(startTime)
			uploadDuration.Observe(duration.Seconds())
			uploadsTotal.Inc()
			uploadSizeBytes.Observe(float64(existingFileInfo.Size()))
			filesDeduplicatedTotal.Inc()

			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			response := map[string]interface{}{
				"success":  true,
				"filename": filename,
				"size":     existingFileInfo.Size(),
				"message":  "File already exists (deduplication hit)",
			}
			json.NewEncoder(w).Encode(response)

			log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately",
				filename, formatBytes(existingFileInfo.Size()))
			return
		}
	}

	// Create the file
	dst, err := os.Create(absFilename)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		return
	}

	// Copy file content from request body.
	written, err := io.Copy(dst, r.Body)
	// Close explicitly rather than via defer. A close error (close(2) can
	// surface deferred write errors on some filesystems) must fail the upload,
	// and a partial file must be closed before it is removed.
	closeErr := dst.Close()
	if err != nil {
		os.Remove(absFilename) // clean up partial file
		var maxBytesErr *http.MaxBytesError
		if errors.As(err, &maxBytesErr) {
			http.Error(w, fmt.Sprintf("File size exceeds maximum allowed size %s",
				conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge)
		} else {
			http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
		}
		uploadErrorsTotal.Inc()
		return
	}
	if closeErr != nil {
		os.Remove(absFilename) // content may be incomplete on disk
		http.Error(w, fmt.Sprintf("Error saving file: %v", closeErr), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		return
	}

	// ✅ CRITICAL FIX: Send immediate success response for large files (>1GB)
	// This prevents client timeouts while server does post-processing
	if written > largeFileThreshold {
		log.Infof("🚀 Large file detected (%s), sending immediate success response (v3)", formatBytes(written))

		duration := time.Since(startTime)
		uploadDuration.Observe(duration.Seconds())
		uploadsTotal.Inc()
		uploadSizeBytes.Observe(float64(written))

		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("X-Upload-Success", "true")
		w.Header().Set("X-Upload-Duration", duration.String())
		w.Header().Set("X-Large-File-Processing", "async")
		w.Header().Set("X-Post-Processing", "background")
		w.WriteHeader(http.StatusOK)

		response := map[string]interface{}{
			"success":         true,
			"filename":        filename,
			"size":            written,
			"duration":        duration.String(),
			"protocol":        "v3",
			"post_processing": "background",
		}
		if jsonBytes, err := json.Marshal(response); err == nil {
			w.Write(jsonBytes)
		} else {
			fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d, "post_processing": "background"}`, filename, written)
		}

		log.Infof("✅ Immediate response sent for large file %s (%s) in %s via v3 protocol",
			filename, formatBytes(written), duration)

		// Deduplication and virus scanning continue in the background.
		go func() {
			if conf.Server.DeduplicationEnabled {
				log.Infof("🔄 Starting background deduplication for large file (v3): %s", filename)
				if err := handleDeduplication(context.Background(), absFilename); err != nil {
					log.Warnf("⚠️ Background deduplication failed for %s: %v", absFilename, err)
				} else {
					log.Infof("✅ Background deduplication completed for %s (v3)", filename)
				}
			}

			if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 {
				ext := strings.ToLower(filepath.Ext(originalFilename))
				shouldScan := false
				for _, scanExt := range conf.ClamAV.ScanFileExtensions {
					if ext == strings.ToLower(scanExt) {
						shouldScan = true
						break
					}
				}
				if shouldScan {
					log.Infof("🔍 Starting background virus scan for large file (v3): %s", filename)
					if err := scanFileWithClamAV(absFilename); err != nil {
						log.Warnf("⚠️ Background virus scan failed for %s: %v", filename, err)
					} else {
						log.Infof("✅ Background virus scan completed for %s (v3)", filename)
					}
				}
			}
		}()
		return
	}

	// Standard processing for small files (synchronous)
	if conf.Server.DeduplicationEnabled {
		if err := handleDeduplication(context.Background(), absFilename); err != nil {
			log.Warnf("Deduplication failed for %s: %v", absFilename, err)
		}
	}

	// Update metrics
	duration := time.Since(startTime)
	uploadDuration.Observe(duration.Seconds())
	uploadsTotal.Inc()
	uploadSizeBytes.Observe(float64(written))

	// Return success response
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	response := map[string]interface{}{
		"success":  true,
		"filename": filename,
		"size":     written,
		"duration": duration.String(),
	}
	if jsonBytes, err := json.Marshal(response); err == nil {
		w.Write(jsonBytes)
	} else {
		fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d}`, filename, written)
	}

	log.Infof("Successfully uploaded %s via v3 protocol (%s) in %s", filename, formatBytes(written), duration)
}
// handleLegacyUpload handles PUT requests for legacy protocols (v, v2, token).
//
// The request must pass validateHMAC. The URL path (minus its leading slash)
// is the requested storage path:
//   - With FileNaming "HMAC", only its extension is kept and the file is stored
//     under an HMAC-derived name in the storage directory.
//   - Otherwise the path is resolved with sanitizeFilePath and its directory
//     structure is created.
//
// When max_upload_size is configured it is enforced against both the declared
// Content-Length and the bytes actually read. Success is answered with
// 201 Created and no body. Uploads larger than 1 GiB are acknowledged
// immediately and deduplicated / virus-scanned in a background goroutine.
func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
	// Above this size post-processing moves to the background so clients do
	// not time out waiting for the response.
	const largeFileThreshold = 1024 * 1024 * 1024 // 1GB

	startTime := time.Now()
	activeConnections.Inc()
	defer activeConnections.Dec()

	log.Debugf("handleLegacyUpload: method=%s path=%s query=%s", r.Method, r.URL.Path, r.URL.RawQuery)

	// Enhanced session handling for multi-upload scenarios (Gajim XMPP fix)
	sessionID := r.Header.Get("X-Session-ID")
	if sessionID == "" {
		// Generate session ID for XMPP multi-upload tracking
		sessionID = generateUploadSessionID("legacy", r.Header.Get("User-Agent"), getClientIP(r))
	}
	// Set session headers for XMPP client continuation
	w.Header().Set("X-Session-ID", sessionID)
	w.Header().Set("X-Upload-Session-Timeout", "3600") // 1 hour
	w.Header().Set("X-Upload-Type", "legacy-xmpp")

	// Only allow PUT method for legacy uploads
	if r.Method != http.MethodPut {
		http.Error(w, "Method not allowed for legacy uploads", http.StatusMethodNotAllowed)
		uploadErrorsTotal.Inc()
		return
	}

	// Validate legacy HMAC signature
	if err := validateHMAC(r, conf.Security.Secret); err != nil {
		http.Error(w, fmt.Sprintf("Legacy Authentication failed: %v", err), http.StatusUnauthorized)
		uploadErrorsTotal.Inc()
		return
	}
	log.Debugf("✅ HMAC validation passed for: %s", r.URL.Path)

	// Extract filename from the URL path
	fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
	if fileStorePath == "" {
		log.Debugf("❌ No filename specified")
		http.Error(w, "No filename specified", http.StatusBadRequest)
		uploadErrorsTotal.Inc()
		return
	}
	log.Debugf("✅ File path extracted: %s", fileStorePath)

	// Validate file extension if configured (case-insensitive: ext is lowered,
	// the configured entries may not be).
	if len(conf.Uploads.AllowedExtensions) > 0 {
		ext := strings.ToLower(filepath.Ext(fileStorePath))
		log.Debugf("Checking file extension %s against allowed extensions %v", ext, conf.Uploads.AllowedExtensions)
		allowed := false
		for _, allowedExt := range conf.Uploads.AllowedExtensions {
			if strings.EqualFold(ext, allowedExt) {
				allowed = true
				break
			}
		}
		if !allowed {
			log.Debugf("Extension %s not found in allowed list", ext)
			http.Error(w, fmt.Sprintf("File extension %s not allowed", ext), http.StatusBadRequest)
			uploadErrorsTotal.Inc()
			return
		}
		log.Debugf("File extension %s is allowed", ext)
	}

	// Validate file size against max_upload_size if configured
	if conf.Server.MaxUploadSize != "" {
		maxSizeBytes, err := parseSize(conf.Server.MaxUploadSize)
		if err != nil {
			log.Errorf("Invalid max_upload_size configuration: %v", err)
			http.Error(w, "Server configuration error", http.StatusInternalServerError)
			uploadErrorsTotal.Inc()
			return
		}
		if r.ContentLength > maxSizeBytes {
			http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
				formatBytes(r.ContentLength), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge)
			uploadErrorsTotal.Inc()
			return
		}
		// ContentLength is -1 for chunked uploads, which slips past the check
		// above; bound the body itself so the limit always applies.
		r.Body = http.MaxBytesReader(w, r.Body, maxSizeBytes)
	}

	// Create full file path
	storagePath := conf.Server.StoragePath
	if conf.ISO.Enabled {
		storagePath = conf.ISO.MountPoint
	}

	// Generate filename based on configuration
	var absFilename string
	var filename string
	switch conf.Server.FileNaming {
	case "HMAC":
		h := hmac.New(sha256.New, []byte(conf.Security.Secret))
		h.Write([]byte(fileStorePath + time.Now().String()))
		filename = hex.EncodeToString(h.Sum(nil)) + filepath.Ext(fileStorePath)
		absFilename = filepath.Join(storagePath, filename)
	default: // "original" or "None"
		// Preserve full directory structure for legacy XMPP compatibility
		var sanitizeErr error
		absFilename, sanitizeErr = sanitizeFilePath(storagePath, fileStorePath)
		if sanitizeErr != nil {
			http.Error(w, fmt.Sprintf("Invalid file path: %v", sanitizeErr), http.StatusBadRequest)
			uploadErrorsTotal.Inc()
			return
		}
		filename = filepath.Base(fileStorePath) // For logging purposes
	}

	// Create directory structure if it doesn't exist
	if err := os.MkdirAll(filepath.Dir(absFilename), 0755); err != nil {
		http.Error(w, fmt.Sprintf("Error creating directory: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		return
	}

	// Pre-upload deduplication check: if file already exists and deduplication is enabled, return success immediately
	if conf.Server.DeduplicationEnabled {
		if existingFileInfo, err := os.Stat(absFilename); err == nil {
			duration := time.Since(startTime)
			uploadDuration.Observe(duration.Seconds())
			uploadsTotal.Inc()
			uploadSizeBytes.Observe(float64(existingFileInfo.Size()))
			filesDeduplicatedTotal.Inc()

			w.WriteHeader(http.StatusCreated) // 201 Created for legacy compatibility
			log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately",
				filename, formatBytes(existingFileInfo.Size()))
			return
		}
	}

	// Create the file
	dst, err := os.Create(absFilename)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error creating file: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		return
	}

	// Log upload start for large files
	if r.ContentLength > 10*1024*1024 { // Log for files > 10MB
		log.Infof("Starting upload of %s (%.1f MiB)", filename, float64(r.ContentLength)/(1024*1024))
	}

	// Copy file content from request body with progress reporting
	written, err := copyWithProgress(dst, r.Body, r.ContentLength, filename)
	// Close explicitly rather than via defer. A close error (close(2) can
	// surface deferred write errors on some filesystems) must fail the upload,
	// and a partial file must be closed before it is removed.
	closeErr := dst.Close()
	if err != nil {
		os.Remove(absFilename) // clean up partial file
		// Only recognized when copyWithProgress returns/wraps the reader
		// error with %w; otherwise this falls through to the 500 path.
		var maxBytesErr *http.MaxBytesError
		if errors.As(err, &maxBytesErr) {
			http.Error(w, fmt.Sprintf("File size exceeds maximum allowed size %s",
				conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge)
		} else {
			http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
		}
		uploadErrorsTotal.Inc()
		return
	}
	if closeErr != nil {
		os.Remove(absFilename) // content may be incomplete on disk
		http.Error(w, fmt.Sprintf("Error saving file: %v", closeErr), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		return
	}

	// ✅ CRITICAL FIX: Send immediate success response for large files (>1GB)
	// This prevents client timeouts while server does post-processing
	if written > largeFileThreshold {
		log.Infof("🚀 Large file detected (%s), sending immediate success response (legacy)", formatBytes(written))

		duration := time.Since(startTime)
		uploadDuration.Observe(duration.Seconds())
		uploadsTotal.Inc()
		uploadSizeBytes.Observe(float64(written))

		// Return success response (201 Created for legacy compatibility)
		w.Header().Set("X-Upload-Success", "true")
		w.Header().Set("X-Upload-Duration", duration.String())
		w.Header().Set("X-Large-File-Processing", "async")
		w.Header().Set("X-Post-Processing", "background")
		w.WriteHeader(http.StatusCreated)

		log.Infof("✅ Immediate response sent for large file %s (%s) in %s via legacy protocol",
			filename, formatBytes(written), duration)

		// Deduplication and virus scanning continue in the background.
		go func() {
			if conf.Server.DeduplicationEnabled {
				log.Infof("🔄 Starting background deduplication for large file (legacy): %s", filename)
				if err := handleDeduplication(context.Background(), absFilename); err != nil {
					log.Warnf("⚠️ Background deduplication failed for %s: %v", absFilename, err)
				} else {
					log.Infof("✅ Background deduplication completed for %s (legacy)", filename)
				}
			}

			if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 {
				ext := strings.ToLower(filepath.Ext(fileStorePath))
				shouldScan := false
				for _, scanExt := range conf.ClamAV.ScanFileExtensions {
					if ext == strings.ToLower(scanExt) {
						shouldScan = true
						break
					}
				}
				if shouldScan {
					log.Infof("🔍 Starting background virus scan for large file (legacy): %s", filename)
					if err := scanFileWithClamAV(absFilename); err != nil {
						log.Warnf("⚠️ Background virus scan failed for %s: %v", filename, err)
					} else {
						log.Infof("✅ Background virus scan completed for %s (legacy)", filename)
					}
				}
			}
		}()
		return
	}

	// Standard processing for small files (synchronous)
	if conf.Server.DeduplicationEnabled {
		if err := handleDeduplication(context.Background(), absFilename); err != nil {
			log.Warnf("Deduplication failed for %s: %v", absFilename, err)
		}
	}

	// Update metrics
	duration := time.Since(startTime)
	uploadDuration.Observe(duration.Seconds())
	uploadsTotal.Inc()
	uploadSizeBytes.Observe(float64(written))

	// Return success response (201 Created for legacy compatibility)
	w.WriteHeader(http.StatusCreated)

	log.Infof("Successfully uploaded %s via legacy protocol (%s) in %s", filename, formatBytes(written), duration)
}
// handleLegacyDownload handles GET/HEAD requests for legacy downloads.
//
// The URL path (minus its leading slash) is resolved with sanitizeFilePath
// against the storage path, or the ISO mount point when ISO mode is enabled.
// HEAD requests receive only Content-Type and Content-Length. GET requests
// stream the file body through a pooled buffer. This handler performs no
// authentication itself.
func handleLegacyDownload(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()
	activeConnections.Inc()
	defer activeConnections.Dec()

	// Extract filename from the URL path
	fileStorePath := strings.TrimPrefix(r.URL.Path, "/")
	if fileStorePath == "" {
		http.Error(w, "No filename specified", http.StatusBadRequest)
		downloadErrorsTotal.Inc()
		return
	}

	// Create full file path
	storagePath := conf.Server.StoragePath
	if conf.ISO.Enabled {
		storagePath = conf.ISO.MountPoint
	}
	// Resolve through sanitizeFilePath (the helper the other handlers use)
	// rather than a bare filepath.Join: Join cleans ".." segments, so a decoded
	// path such as "../../etc/passwd" would resolve outside storagePath.
	// NOTE(review): assumes sanitizeFilePath rejects paths escaping its base.
	absFilename, err := sanitizeFilePath(storagePath, fileStorePath)
	if err != nil {
		log.Warnf("🔴 Invalid file path requested from IP %s: %s, error: %v", getClientIP(r), fileStorePath, err)
		http.Error(w, fmt.Sprintf("Invalid file path: %v", err), http.StatusBadRequest)
		downloadErrorsTotal.Inc()
		return
	}

	fileInfo, err := os.Stat(absFilename)
	if os.IsNotExist(err) {
		http.Error(w, "File not found", http.StatusNotFound)
		downloadErrorsTotal.Inc()
		return
	}
	if err != nil {
		http.Error(w, fmt.Sprintf("Error accessing file: %v", err), http.StatusInternalServerError)
		downloadErrorsTotal.Inc()
		return
	}
	if fileInfo.IsDir() {
		http.Error(w, "Cannot download a directory", http.StatusBadRequest)
		downloadErrorsTotal.Inc()
		return
	}

	// Set appropriate headers
	contentType := GetContentType(fileStorePath)
	w.Header().Set("Content-Type", contentType)
	w.Header().Set("Content-Length", strconv.FormatInt(fileInfo.Size(), 10))

	// For HEAD requests, only send headers
	if r.Method == http.MethodHead {
		w.WriteHeader(http.StatusOK)
		downloadsTotal.Inc()
		return
	}

	// For GET requests, serve the file
	file, err := os.Open(absFilename)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error opening file: %v", err), http.StatusInternalServerError)
		downloadErrorsTotal.Inc()
		return
	}
	defer file.Close()

	// Use a pooled buffer for copying
	bufPtr := bufferPool.Get().(*[]byte)
	defer bufferPool.Put(bufPtr)
	buf := *bufPtr

	n, err := io.CopyBuffer(w, file, buf)
	if err != nil {
		log.Errorf("Error during download of %s: %v", absFilename, err)
		downloadErrorsTotal.Inc()
		return
	}

	duration := time.Since(startTime)
	downloadDuration.Observe(duration.Seconds())
	downloadsTotal.Inc()
	downloadSizeBytes.Observe(float64(n))

	log.Infof("Successfully downloaded %s (%s) in %s", absFilename, formatBytes(n), duration)
}
// printValidationChecks writes a catalogue of the configuration validation
// checks to standard output. The checks are grouped by category (core,
// security, performance, connectivity, system resources, cross-section) and
// followed by example command-line invocations. Each entry below is emitted
// as one line; an empty string produces a blank separator line.
func printValidationChecks() {
	output := []string{
		"HMAC File Server Configuration Validation Checks",
		"=================================================",
		"",
		"🔍 CORE VALIDATION CHECKS:",
		" ✓ server.* - Server configuration (ports, paths, protocols)",
		" ✓ security.* - Security settings (secrets, JWT, authentication)",
		" ✓ logging.* - Logging configuration (levels, files, rotation)",
		" ✓ timeouts.* - Timeout settings (read, write, idle)",
		" ✓ uploads.* - Upload configuration (extensions, chunk size)",
		" ✓ downloads.* - Download configuration (extensions, chunk size)",
		" ✓ workers.* - Worker pool configuration (count, queue size)",
		" ✓ redis.* - Redis configuration (address, credentials)",
		" ✓ clamav.* - ClamAV antivirus configuration",
		" ✓ versioning.* - File versioning configuration",
		" ✓ deduplication.* - File deduplication configuration",
		" ✓ iso.* - ISO filesystem configuration",
		"",
		"🔐 SECURITY CHECKS:",
		" ✓ Secret strength analysis (length, entropy, patterns)",
		" ✓ Default/example value detection",
		" ✓ JWT algorithm security recommendations",
		" ✓ Network binding security (0.0.0.0 warnings)",
		" ✓ File permission analysis",
		" ✓ Debug logging security implications",
		"",
		"⚡ PERFORMANCE CHECKS:",
		" ✓ Worker count vs CPU cores optimization",
		" ✓ Queue size vs memory usage analysis",
		" ✓ Timeout configuration balance",
		" ✓ Large file handling preparation",
		" ✓ Memory-intensive configuration detection",
		"",
		"🌐 CONNECTIVITY CHECKS:",
		" ✓ Redis server connectivity testing",
		" ✓ ClamAV socket accessibility",
		" ✓ Network address format validation",
		" ✓ DNS resolution testing",
		"",
		"💾 SYSTEM RESOURCE CHECKS:",
		" ✓ CPU core availability analysis",
		" ✓ Memory usage monitoring",
		" ✓ Disk space validation",
		" ✓ Directory write permissions",
		" ✓ Goroutine count analysis",
		"",
		"🔄 CROSS-SECTION VALIDATION:",
		" ✓ Path conflict detection",
		" ✓ Extension compatibility checks",
		" ✓ Configuration consistency validation",
		"",
		"📋 USAGE EXAMPLES:",
		" hmac-file-server --validate-config # Full validation",
		" hmac-file-server --check-security # Security checks only",
		" hmac-file-server --check-performance # Performance checks only",
		" hmac-file-server --check-connectivity # Network checks only",
		" hmac-file-server --validate-quiet # Errors only",
		" hmac-file-server --validate-verbose # Detailed output",
		" hmac-file-server --check-fixable # Auto-fixable issues",
		"",
	}
	for _, text := range output {
		// fmt.Println("") writes just "\n", identical to fmt.Println().
		fmt.Println(text)
	}
}