Fix: Auth Session

This commit is contained in:
2025-08-26 15:53:36 +00:00
parent 71a62eca3f
commit 9b5b3ae820
25 changed files with 1142 additions and 44 deletions

View File

@ -708,7 +708,7 @@ func handleUploadWithAdaptiveIO(w http.ResponseWriter, r *http.Request) {
// Use adaptive streaming engine
clientIP := getClientIP(r)
sessionID := generateSessionID()
sessionID := generateSessionID("", "")
written, err := globalStreamingEngine.StreamWithAdaptation(
dst,
@ -804,7 +804,7 @@ func handleDownloadWithAdaptiveIO(w http.ResponseWriter, r *http.Request) {
// Use adaptive streaming engine
clientIP := getClientIP(r)
sessionID := generateSessionID()
sessionID := generateSessionID("", "")
n, err := globalStreamingEngine.StreamWithAdaptation(
w,

View File

@ -613,8 +613,12 @@ func monitorNetwork(ctx context.Context) {
if iface.Flags&net.FlagUp != 0 && iface.Flags&net.FlagLoopback == 0 {
select {
case networkEvents <- NetworkEvent{
Type: "interface_up",
Details: fmt.Sprintf("Interface %s is up", iface.Name),
Timestamp: time.Now(),
EventType: "interface_up",
ToNetwork: iface.Name,
FromNetwork: "unknown",
ClientIP: "",
UserAgent: "",
}:
default:
// Channel full, skip
@ -635,7 +639,7 @@ func handleNetworkEvents(ctx context.Context) {
log.Info("Network event handler stopped")
return
case event := <-networkEvents:
log.Debugf("Network event: %s - %s", event.Type, event.Details)
log.Debugf("Network event: %s - From: %s To: %s", event.EventType, event.FromNetwork, event.ToNetwork)
}
}
}

View File

@ -6,7 +6,6 @@ import (
"bufio"
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
@ -38,6 +37,282 @@ import (
"github.com/spf13/viper"
)
// NetworkResilientSession represents a persistent session for network switching
type NetworkResilientSession struct {
SessionID string `json:"session_id"`
UserJID string `json:"user_jid"`
OriginalToken string `json:"original_token"`
CreatedAt time.Time `json:"created_at"`
LastSeen time.Time `json:"last_seen"`
NetworkHistory []NetworkEvent `json:"network_history"`
UploadContext *UploadContext `json:"upload_context,omitempty"`
RefreshCount int `json:"refresh_count"`
MaxRefreshes int `json:"max_refreshes"`
LastIP string `json:"last_ip"`
UserAgent string `json:"user_agent"`
}
// NetworkEvent tracks network transitions during session
type NetworkEvent struct {
Timestamp time.Time `json:"timestamp"`
FromNetwork string `json:"from_network"`
ToNetwork string `json:"to_network"`
ClientIP string `json:"client_ip"`
UserAgent string `json:"user_agent"`
EventType string `json:"event_type"` // "switch", "resume", "refresh"
}
// UploadContext maintains upload state across network changes and network resilience channels
type UploadContext struct {
Filename string `json:"filename"`
TotalSize int64 `json:"total_size"`
UploadedBytes int64 `json:"uploaded_bytes"`
ChunkSize int64 `json:"chunk_size"`
LastChunk int `json:"last_chunk"`
ETag string `json:"etag,omitempty"`
UploadPath string `json:"upload_path"`
ContentType string `json:"content_type"`
LastUpdate time.Time `json:"last_update"`
SessionID string `json:"session_id"`
PauseChan chan bool `json:"-"`
ResumeChan chan bool `json:"-"`
CancelChan chan bool `json:"-"`
IsPaused bool `json:"is_paused"`
}
// SessionStore manages persistent sessions for network resilience
type SessionStore struct {
storage map[string]*NetworkResilientSession
mutex sync.RWMutex
cleanupTicker *time.Ticker
redisClient *redis.Client
memoryCache *cache.Cache
enabled bool
}
// Global session store
var sessionStore *SessionStore
// Session storage methods
func (s *SessionStore) GetSession(sessionID string) *NetworkResilientSession {
if !s.enabled || sessionID == "" {
return nil
}
s.mutex.RLock()
defer s.mutex.RUnlock()
// Try Redis first if available
if s.redisClient != nil {
ctx := context.Background()
sessionData, err := s.redisClient.Get(ctx, "session:"+sessionID).Result()
if err == nil {
var session NetworkResilientSession
if json.Unmarshal([]byte(sessionData), &session) == nil {
log.Debugf("📊 Session retrieved from Redis: %s", sessionID)
return &session
}
}
}
// Fallback to memory cache
if s.memoryCache != nil {
if sessionData, found := s.memoryCache.Get(sessionID); found {
if session, ok := sessionData.(*NetworkResilientSession); ok {
log.Debugf("📊 Session retrieved from memory: %s", sessionID)
return session
}
}
}
// Fallback to in-memory map
if session, exists := s.storage[sessionID]; exists {
if time.Since(session.LastSeen) < 72*time.Hour {
log.Debugf("📊 Session retrieved from storage: %s", sessionID)
return session
}
}
return nil
}
func (s *SessionStore) StoreSession(sessionID string, session *NetworkResilientSession) {
if !s.enabled || sessionID == "" || session == nil {
return
}
s.mutex.Lock()
defer s.mutex.Unlock()
session.LastSeen = time.Now()
// Store in Redis if available
if s.redisClient != nil {
ctx := context.Background()
sessionData, err := json.Marshal(session)
if err == nil {
s.redisClient.Set(ctx, "session:"+sessionID, sessionData, 72*time.Hour)
log.Debugf("📊 Session stored in Redis: %s", sessionID)
}
}
// Store in memory cache
if s.memoryCache != nil {
s.memoryCache.Set(sessionID, session, 72*time.Hour)
log.Debugf("📊 Session stored in memory: %s", sessionID)
}
// Store in local map as final fallback
s.storage[sessionID] = session
log.Debugf("📊 Session stored in local storage: %s", sessionID)
}
func (s *SessionStore) DeleteSession(sessionID string) {
if !s.enabled || sessionID == "" {
return
}
s.mutex.Lock()
defer s.mutex.Unlock()
// Remove from Redis
if s.redisClient != nil {
ctx := context.Background()
s.redisClient.Del(ctx, "session:"+sessionID)
}
// Remove from memory cache
if s.memoryCache != nil {
s.memoryCache.Delete(sessionID)
}
// Remove from local storage
delete(s.storage, sessionID)
log.Debugf("📊 Session deleted: %s", sessionID)
}
func (s *SessionStore) cleanupRoutine() {
if !s.enabled {
return
}
for range s.cleanupTicker.C {
s.mutex.Lock()
for sessionID, session := range s.storage {
if time.Since(session.LastSeen) > 72*time.Hour {
delete(s.storage, sessionID)
log.Debugf("🧹 Cleaned up expired session: %s", sessionID)
}
}
s.mutex.Unlock()
}
}
// Initialize session store
func initializeSessionStore() {
enabled := viper.GetBool("session_store.enabled")
if !enabled {
log.Infof("📊 Session store disabled in configuration")
sessionStore = &SessionStore{enabled: false}
return
}
sessionStore = &SessionStore{
storage: make(map[string]*NetworkResilientSession),
cleanupTicker: time.NewTicker(30 * time.Minute),
enabled: true,
}
// Initialize memory cache
sessionStore.memoryCache = cache.New(72*time.Hour, 1*time.Hour)
// Optional Redis backend
if redisURL := viper.GetString("session_store.redis_url"); redisURL != "" {
opt, err := redis.ParseURL(redisURL)
if err == nil {
sessionStore.redisClient = redis.NewClient(opt)
// Test Redis connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := sessionStore.redisClient.Ping(ctx).Err(); err == nil {
log.Infof("📊 Session store: Redis backend initialized (%s)", redisURL)
} else {
log.Warnf("📊 Session store: Redis connection failed, using memory backend: %v", err)
sessionStore.redisClient = nil
}
} else {
log.Warnf("📊 Session store: Invalid Redis URL, using memory backend: %v", err)
}
}
if sessionStore.redisClient == nil {
log.Infof("📊 Session store: Memory backend initialized")
}
// Start cleanup routine
go sessionStore.cleanupRoutine()
}
// Generate session ID from user and context
func generateSessionID(userJID, filename string) string {
h := sha256.New()
h.Write([]byte(fmt.Sprintf("%s:%s:%d", userJID, filename, time.Now().UnixNano())))
return fmt.Sprintf("sess_%s", hex.EncodeToString(h.Sum(nil))[:16])
}
// Detect network context for intelligent switching
func detectNetworkContext(r *http.Request) string {
clientIP := getClientIP(r)
userAgent := r.Header.Get("User-Agent")
xForwardedFor := r.Header.Get("X-Forwarded-For")
// Detect network type based on IP ranges and headers
if strings.Contains(xForwardedFor, "10.") || strings.Contains(clientIP, "10.") {
return "cellular_lte"
} else if strings.Contains(clientIP, "192.168.") || strings.Contains(clientIP, "172.") {
return "wifi_private"
} else if strings.Contains(userAgent, "Mobile") || strings.Contains(userAgent, "Android") {
return "mobile_network"
} else if strings.Contains(clientIP, "127.0.0.1") || strings.Contains(clientIP, "::1") {
return "localhost"
}
return "external_network"
}
// Add session response headers for client tracking
func setSessionHeaders(w http.ResponseWriter, sessionID string) {
w.Header().Set("X-Session-ID", sessionID)
w.Header().Set("X-Session-Timeout", "259200") // 72 hours in seconds
w.Header().Set("X-Network-Resilience", "enabled")
}
// Extract session ID from request
func getSessionIDFromRequest(r *http.Request) string {
// Try header first
if sessionID := r.Header.Get("X-Session-ID"); sessionID != "" {
return sessionID
}
// Try query parameter
if sessionID := r.URL.Query().Get("session_id"); sessionID != "" {
return sessionID
}
// Try from Authorization header (for some XMPP clients)
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
token := strings.TrimPrefix(auth, "Bearer ")
// Generate consistent session ID from token
h := sha256.New()
h.Write([]byte(token))
return fmt.Sprintf("auth_%s", hex.EncodeToString(h.Sum(nil))[:16])
}
return ""
}
// parseSize converts a human-readable size string to bytes
func parseSize(sizeStr string) (int64, error) {
sizeStr = strings.TrimSpace(sizeStr)
@ -285,11 +560,6 @@ type ScanTask struct {
Result chan error
}
type NetworkEvent struct {
Type string
Details string
}
// Add a new field to store the creation date of files
type FileMetadata struct {
CreationDate time.Time
@ -620,6 +890,11 @@ func main() {
clientTracker.StartCleanupRoutine()
log.Info("Client multi-interface support initialized")
}
// Initialize session store for network resilience
initializeSessionStore()
log.Info("Session store for network switching initialized")
PrintValidationResults(validationResult)
if validationResult.HasErrors() {
@ -1228,6 +1503,14 @@ func setDefaults() {
viper.SetDefault("versioning.backend", "simple")
viper.SetDefault("versioning.max_revisions", 5)
// Session store defaults for network resilience
viper.SetDefault("session_store.enabled", true)
viper.SetDefault("session_store.backend", "memory")
viper.SetDefault("session_store.max_sessions", 10000)
viper.SetDefault("session_store.cleanup_interval", "30m")
viper.SetDefault("session_store.max_session_age", "72h")
viper.SetDefault("session_store.redis_url", "")
// ... other defaults for Uploads, Downloads, ClamAV, Redis, Workers, File, Build
viper.SetDefault("build.version", "dev")
}
@ -1575,6 +1858,180 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
return claims, nil
}
// validateBearerTokenWithSession validates Bearer token with session recovery support
// ENHANCED FOR NETWORK SWITCHING: 5G ↔ WiFi transition support with session persistence
func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerTokenClaims, error) {
// Step 1: Try standard Bearer token validation first
claims, err := validateBearerToken(r, secret)
if err == nil {
// Token is valid - create or update session for network resilience
sessionID := getSessionIDFromRequest(r)
if sessionID == "" {
sessionID = generateSessionID(claims.User, claims.Filename)
}
// Get or create session
session := sessionStore.GetSession(sessionID)
if session == nil {
session = &NetworkResilientSession{
SessionID: sessionID,
UserJID: claims.User,
OriginalToken: getBearerTokenFromRequest(r),
CreatedAt: time.Now(),
MaxRefreshes: 10,
NetworkHistory: []NetworkEvent{},
}
}
// Update session with current network context
currentIP := getClientIP(r)
userAgent := r.Header.Get("User-Agent")
if session.LastIP != "" && session.LastIP != currentIP {
// Network change detected
session.NetworkHistory = append(session.NetworkHistory, NetworkEvent{
Timestamp: time.Now(),
FromNetwork: session.LastIP,
ToNetwork: currentIP,
ClientIP: currentIP,
UserAgent: userAgent,
EventType: "network_switch",
})
log.Infof("🌐 Network switch detected for session %s: %s → %s",
sessionID, session.LastIP, currentIP)
}
session.LastIP = currentIP
session.UserAgent = userAgent
sessionStore.StoreSession(sessionID, session)
// Set session headers in response
if w, ok := r.Context().Value("responseWriter").(http.ResponseWriter); ok {
setSessionHeaders(w, sessionID)
}
log.Infof("✅ Bearer token valid, session updated: %s (user: %s)", sessionID, claims.User)
return claims, nil
}
// Step 2: Token validation failed - try session recovery
sessionID := getSessionIDFromRequest(r)
if sessionID != "" {
session := sessionStore.GetSession(sessionID)
if session != nil {
// Check if session is still valid (within 72-hour window)
sessionAge := time.Since(session.CreatedAt)
if sessionAge < 72*time.Hour {
log.Infof("🔄 Session recovery attempt for %s (age: %v)", sessionID, sessionAge)
// Check if we can refresh the token
if session.RefreshCount < session.MaxRefreshes {
_, err := refreshSessionToken(session, secret, r)
if err == nil {
// Token refresh successful
session.RefreshCount++
session.LastSeen = time.Now()
// Add refresh event to history
session.NetworkHistory = append(session.NetworkHistory, NetworkEvent{
Timestamp: time.Now(),
ClientIP: getClientIP(r),
UserAgent: r.Header.Get("User-Agent"),
EventType: "token_refresh",
})
sessionStore.StoreSession(sessionID, session)
// Create claims from refreshed session
refreshedClaims := &BearerTokenClaims{
User: session.UserJID,
Filename: extractFilenameFromPath(r.URL.Path),
Size: extractSizeFromRequest(r),
Expiry: time.Now().Add(24 * time.Hour).Unix(),
}
log.Infof("✅ Session recovery successful: %s (refresh #%d)",
sessionID, session.RefreshCount)
return refreshedClaims, nil
}
} else {
log.Warnf("❌ Session %s exceeded maximum refreshes (%d)",
sessionID, session.MaxRefreshes)
}
} else {
log.Warnf("❌ Session %s expired (age: %v, max: 72h)", sessionID, sessionAge)
}
} else {
log.Warnf("❌ Session %s not found in store", sessionID)
}
}
// Step 3: No valid token or session recovery possible
log.Warnf("❌ Authentication failed: %v (no session recovery available)", err)
return nil, fmt.Errorf("authentication failed: %v", err)
}
// refreshSessionToken generates a new token for an existing session
func refreshSessionToken(session *NetworkResilientSession, secret string, r *http.Request) (string, error) {
if session.RefreshCount >= session.MaxRefreshes {
return "", fmt.Errorf("maximum token refreshes exceeded")
}
// Generate new HMAC token with extended validity
timestamp := time.Now().Unix()
expiry := timestamp + 86400 // 24 hours
filename := extractFilenameFromPath(r.URL.Path)
size := extractSizeFromRequest(r)
// Use session-based payload format for refresh
payload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%s\x00session_refresh",
session.UserJID,
filename,
size,
expiry,
session.SessionID)
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(payload))
token := base64.StdEncoding.EncodeToString(h.Sum(nil))
log.Infof("🆕 Generated refresh token for session %s (refresh #%d)",
session.SessionID, session.RefreshCount+1)
return token, nil
}
// Helper functions for token and session management
func getBearerTokenFromRequest(r *http.Request) string {
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimPrefix(authHeader, "Bearer ")
}
return ""
}
func extractFilenameFromPath(path string) string {
pathParts := strings.Split(strings.Trim(path, "/"), "/")
if len(pathParts) >= 1 {
return pathParts[len(pathParts)-1]
}
return "unknown"
}
func extractSizeFromRequest(r *http.Request) int64 {
if sizeStr := r.Header.Get("Content-Length"); sizeStr != "" {
if size, err := strconv.ParseInt(sizeStr, 10, 64); err == nil {
return size
}
}
if sizeStr := r.URL.Query().Get("size"); sizeStr != "" {
if size, err := strconv.ParseInt(sizeStr, 10, 64); err == nil {
return size
}
}
return 0
}
// BearerTokenClaims represents the claims extracted from a Bearer token
type BearerTokenClaims struct {
User string
@ -1895,24 +2352,6 @@ func validateV3HMAC(r *http.Request, secret string) error {
return nil
}
// generateSessionID creates a unique session ID for client tracking
// ENHANCED FOR NETWORK SWITCHING SCENARIOS
func generateSessionID() string {
// Use multiple entropy sources for better uniqueness across network switches
timestamp := time.Now().UnixNano()
randomBytes := make([]byte, 16)
if _, err := rand.Read(randomBytes); err != nil {
// Fallback to time-based generation if random fails
h := sha256.Sum256([]byte(fmt.Sprintf("%d%s", timestamp, conf.Security.Secret)))
return fmt.Sprintf("session_%x", h[:8])
}
// Combine timestamp, random bytes, and secret for maximum entropy
combined := fmt.Sprintf("%d_%x_%s", timestamp, randomBytes, conf.Security.Secret)
h := sha256.Sum256([]byte(combined))
return fmt.Sprintf("session_%x", h[:12])
}
// copyWithProgressTracking copies data with progress tracking for large downloads
func copyWithProgressTracking(dst io.Writer, src io.Reader, buf []byte, totalSize int64, clientIP string) (int64, error) {
var written int64
@ -1966,18 +2405,28 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
// Bearer token authentication (ejabberd module) - now with enhanced network switching support
claims, err := validateBearerToken(r, conf.Security.Secret)
// Bearer token authentication with session recovery for network switching
// Store response writer in context for session headers
ctx := context.WithValue(r.Context(), "responseWriter", w)
r = r.WithContext(ctx)
claims, err := validateBearerTokenWithSession(r, conf.Security.Secret)
if err != nil {
// Enhanced error logging for network switching scenarios
clientIP := getClientIP(r)
userAgent := r.Header.Get("User-Agent")
log.Warnf("🔴 Bearer Token Authentication failed for IP %s, User-Agent: %s, Error: %v", clientIP, userAgent, err)
sessionID := getSessionIDFromRequest(r)
log.Warnf("🔴 Authentication failed for IP %s, User-Agent: %s, Session: %s, Error: %v",
clientIP, userAgent, sessionID, err)
// Check if this might be a network switching scenario and provide helpful response
if strings.Contains(err.Error(), "expired") {
if strings.Contains(err.Error(), "expired") || strings.Contains(err.Error(), "invalid") {
w.Header().Set("X-Network-Switch-Detected", "true")
w.Header().Set("X-Retry-After", "30") // Suggest retry after 30 seconds
w.Header().Set("X-Session-Recovery", "available")
if sessionID != "" {
w.Header().Set("X-Session-ID", sessionID)
}
}
http.Error(w, fmt.Sprintf("Bearer Token Authentication failed: %v", err), http.StatusUnauthorized)
@ -2030,7 +2479,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
}
if sessionID == "" {
// Generate new session ID with enhanced entropy
sessionID = generateSessionID()
sessionID = generateSessionID("", "")
}
clientIP := getClientIP(r)

View File

@ -98,15 +98,6 @@ type AdaptiveTicker struct {
done chan bool
}
// UploadContext tracks active upload state
type UploadContext struct {
SessionID string
PauseChan chan bool
ResumeChan chan bool
CancelChan chan bool
IsPaused bool
}
// NewNetworkResilienceManager creates a new network resilience manager with enhanced capabilities
func NewNetworkResilienceManager() *NetworkResilienceManager {
// Get configuration from global config, with sensible defaults

View File

@ -62,7 +62,7 @@ func (s *UploadSessionStore) CreateSession(filename string, totalSize int64, cli
s.mutex.Lock()
defer s.mutex.Unlock()
sessionID := generateSessionID()
sessionID := generateSessionID("", filename)
tempDir := filepath.Join(s.tempDir, sessionID)
os.MkdirAll(tempDir, 0755)