Files
dbbackup/internal/database/postgresql.go
Renz 9b3c3f2b1b Initial commit: Database Backup Tool v1.1.0
- PostgreSQL and MySQL support
- Interactive TUI with fixed menu navigation
- Line-by-line progress display
- CPU-aware parallel processing
- Cross-platform build support
- Configuration settings menu
- Silent mode for TUI operations
2025-10-22 19:27:38 +00:00

427 lines
11 KiB
Go

package database
import (
"context"
"database/sql"
"fmt"
"os"
"os/exec"
"strconv"
"strings"
"dbbackup/internal/config"
"dbbackup/internal/logger"
)
// PostgreSQL implements Database interface for PostgreSQL
type PostgreSQL struct {
baseDatabase
}
// NewPostgreSQL creates a new PostgreSQL database instance
func NewPostgreSQL(cfg *config.Config, log logger.Logger) *PostgreSQL {
return &PostgreSQL{
baseDatabase: baseDatabase{
cfg: cfg,
log: log,
},
}
}
// Connect establishes a connection to PostgreSQL
func (p *PostgreSQL) Connect(ctx context.Context) error {
// Build PostgreSQL DSN
dsn := p.buildDSN()
p.dsn = dsn
p.log.Debug("Connecting to PostgreSQL", "dsn", sanitizeDSN(dsn))
db, err := sql.Open("postgres", dsn)
if err != nil {
return fmt.Errorf("failed to open PostgreSQL connection: %w", err)
}
// Configure connection pool
db.SetMaxOpenConns(10)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(0)
// Test connection
timeoutCtx, cancel := buildTimeout(ctx, 0)
defer cancel()
if err := db.PingContext(timeoutCtx); err != nil {
db.Close()
return fmt.Errorf("failed to ping PostgreSQL: %w", err)
}
p.db = db
p.log.Info("Connected to PostgreSQL successfully")
return nil
}
// ListDatabases returns list of non-template databases
func (p *PostgreSQL) ListDatabases(ctx context.Context) ([]string, error) {
if p.db == nil {
return nil, fmt.Errorf("not connected to database")
}
query := `SELECT datname FROM pg_database
WHERE datistemplate = false
ORDER BY datname`
rows, err := p.db.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to query databases: %w", err)
}
defer rows.Close()
var databases []string
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return nil, fmt.Errorf("failed to scan database name: %w", err)
}
databases = append(databases, name)
}
return databases, rows.Err()
}
// ListTables returns list of tables in a database
func (p *PostgreSQL) ListTables(ctx context.Context, database string) ([]string, error) {
if p.db == nil {
return nil, fmt.Errorf("not connected to database")
}
query := `SELECT schemaname||'.'||tablename as full_name
FROM pg_tables
WHERE schemaname NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
ORDER BY schemaname, tablename`
rows, err := p.db.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to query tables: %w", err)
}
defer rows.Close()
var tables []string
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return nil, fmt.Errorf("failed to scan table name: %w", err)
}
tables = append(tables, name)
}
return tables, rows.Err()
}
// CreateDatabase creates a new database
func (p *PostgreSQL) CreateDatabase(ctx context.Context, name string) error {
if p.db == nil {
return fmt.Errorf("not connected to database")
}
// PostgreSQL doesn't support CREATE DATABASE in transactions or prepared statements
query := fmt.Sprintf("CREATE DATABASE %s", name)
_, err := p.db.ExecContext(ctx, query)
if err != nil {
return fmt.Errorf("failed to create database %s: %w", name, err)
}
p.log.Info("Created database", "name", name)
return nil
}
// DropDatabase drops a database
func (p *PostgreSQL) DropDatabase(ctx context.Context, name string) error {
if p.db == nil {
return fmt.Errorf("not connected to database")
}
// Force drop connections and drop database
query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", name)
_, err := p.db.ExecContext(ctx, query)
if err != nil {
return fmt.Errorf("failed to drop database %s: %w", name, err)
}
p.log.Info("Dropped database", "name", name)
return nil
}
// DatabaseExists checks if a database exists
func (p *PostgreSQL) DatabaseExists(ctx context.Context, name string) (bool, error) {
if p.db == nil {
return false, fmt.Errorf("not connected to database")
}
query := `SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)`
var exists bool
err := p.db.QueryRowContext(ctx, query, name).Scan(&exists)
if err != nil {
return false, fmt.Errorf("failed to check database existence: %w", err)
}
return exists, nil
}
// GetVersion returns PostgreSQL version
func (p *PostgreSQL) GetVersion(ctx context.Context) (string, error) {
if p.db == nil {
return "", fmt.Errorf("not connected to database")
}
var version string
err := p.db.QueryRowContext(ctx, "SELECT version()").Scan(&version)
if err != nil {
return "", fmt.Errorf("failed to get version: %w", err)
}
return version, nil
}
// GetDatabaseSize returns database size in bytes
func (p *PostgreSQL) GetDatabaseSize(ctx context.Context, database string) (int64, error) {
if p.db == nil {
return 0, fmt.Errorf("not connected to database")
}
query := `SELECT pg_database_size($1)`
var size int64
err := p.db.QueryRowContext(ctx, query, database).Scan(&size)
if err != nil {
return 0, fmt.Errorf("failed to get database size: %w", err)
}
return size, nil
}
// GetTableRowCount returns approximate row count for a table
func (p *PostgreSQL) GetTableRowCount(ctx context.Context, database, table string) (int64, error) {
if p.db == nil {
return 0, fmt.Errorf("not connected to database")
}
// Use pg_stat_user_tables for approximate count (faster)
parts := strings.Split(table, ".")
if len(parts) != 2 {
return 0, fmt.Errorf("table name must be in format schema.table")
}
query := `SELECT COALESCE(n_tup_ins, 0) FROM pg_stat_user_tables
WHERE schemaname = $1 AND relname = $2`
var count int64
err := p.db.QueryRowContext(ctx, query, parts[0], parts[1]).Scan(&count)
if err != nil {
// Fallback to exact count if stats not available
exactQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
err = p.db.QueryRowContext(ctx, exactQuery).Scan(&count)
if err != nil {
return 0, fmt.Errorf("failed to get table row count: %w", err)
}
}
return count, nil
}
// BuildBackupCommand builds pg_dump command
func (p *PostgreSQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string {
cmd := []string{"pg_dump"}
// Connection parameters
if p.cfg.Host != "localhost" {
cmd = append(cmd, "-h", p.cfg.Host)
cmd = append(cmd, "-p", strconv.Itoa(p.cfg.Port))
cmd = append(cmd, "--no-password")
}
cmd = append(cmd, "-U", p.cfg.User)
// Format and compression
if options.Format != "" {
cmd = append(cmd, "--format="+options.Format)
} else {
cmd = append(cmd, "--format=custom")
}
if options.Compression > 0 {
cmd = append(cmd, "--compress="+strconv.Itoa(options.Compression))
}
// Parallel jobs (only for directory format)
if options.Parallel > 1 && options.Format == "directory" {
cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
}
// Options
if options.Blobs {
cmd = append(cmd, "--blobs")
}
if options.SchemaOnly {
cmd = append(cmd, "--schema-only")
}
if options.DataOnly {
cmd = append(cmd, "--data-only")
}
if options.NoOwner {
cmd = append(cmd, "--no-owner")
}
if options.NoPrivileges {
cmd = append(cmd, "--no-privileges")
}
if options.Role != "" {
cmd = append(cmd, "--role="+options.Role)
}
// Database and output
cmd = append(cmd, "--dbname="+database)
cmd = append(cmd, "--file="+outputFile)
return cmd
}
// BuildRestoreCommand builds pg_restore command
func (p *PostgreSQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string {
cmd := []string{"pg_restore"}
// Connection parameters
if p.cfg.Host != "localhost" {
cmd = append(cmd, "-h", p.cfg.Host)
cmd = append(cmd, "-p", strconv.Itoa(p.cfg.Port))
cmd = append(cmd, "--no-password")
}
cmd = append(cmd, "-U", p.cfg.User)
// Parallel jobs
if options.Parallel > 1 {
cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
}
// Options
if options.Clean {
cmd = append(cmd, "--clean")
}
if options.IfExists {
cmd = append(cmd, "--if-exists")
}
if options.NoOwner {
cmd = append(cmd, "--no-owner")
}
if options.NoPrivileges {
cmd = append(cmd, "--no-privileges")
}
if options.SingleTransaction {
cmd = append(cmd, "--single-transaction")
}
// Database and input
cmd = append(cmd, "--dbname="+database)
cmd = append(cmd, inputFile)
return cmd
}
// BuildSampleQuery builds SQL query for sampling data
func (p *PostgreSQL) BuildSampleQuery(database, table string, strategy SampleStrategy) string {
switch strategy.Type {
case "ratio":
// Every Nth record using row_number
return fmt.Sprintf("SELECT * FROM (SELECT *, row_number() OVER () as rn FROM %s) t WHERE rn %% %d = 1",
table, strategy.Value)
case "percent":
// Percentage sampling using TABLESAMPLE (PostgreSQL 9.5+)
return fmt.Sprintf("SELECT * FROM %s TABLESAMPLE BERNOULLI(%d)", table, strategy.Value)
case "count":
// First N records
return fmt.Sprintf("SELECT * FROM %s LIMIT %d", table, strategy.Value)
default:
return fmt.Sprintf("SELECT * FROM %s LIMIT 1000", table)
}
}
// ValidateBackupTools checks if required PostgreSQL tools are available
func (p *PostgreSQL) ValidateBackupTools() error {
tools := []string{"pg_dump", "pg_restore", "pg_dumpall", "psql"}
for _, tool := range tools {
if _, err := exec.LookPath(tool); err != nil {
return fmt.Errorf("required tool not found: %s", tool)
}
}
return nil
}
// buildDSN constructs PostgreSQL connection string
func (p *PostgreSQL) buildDSN() string {
dsn := fmt.Sprintf("user=%s dbname=%s", p.cfg.User, p.cfg.Database)
if p.cfg.Password != "" {
dsn += " password=" + p.cfg.Password
}
// For localhost connections, try socket first for peer auth
if p.cfg.Host == "localhost" && p.cfg.Password == "" {
// Try Unix socket connection for peer authentication
// Common PostgreSQL socket locations
socketDirs := []string{
"/var/run/postgresql",
"/tmp",
"/var/lib/pgsql",
}
for _, dir := range socketDirs {
socketPath := fmt.Sprintf("%s/.s.PGSQL.%d", dir, p.cfg.Port)
if _, err := os.Stat(socketPath); err == nil {
dsn += " host=" + dir
p.log.Debug("Using PostgreSQL socket", "path", socketPath)
break
}
}
} else if p.cfg.Host != "localhost" || p.cfg.Password != "" {
// Use TCP connection
dsn += " host=" + p.cfg.Host
dsn += " port=" + strconv.Itoa(p.cfg.Port)
}
if p.cfg.SSLMode != "" && !p.cfg.Insecure {
// Map SSL modes to supported values for lib/pq
switch strings.ToLower(p.cfg.SSLMode) {
case "prefer", "preferred":
dsn += " sslmode=require" // lib/pq default, closest to prefer
case "require", "required":
dsn += " sslmode=require"
case "verify-ca":
dsn += " sslmode=verify-ca"
case "verify-full", "verify-identity":
dsn += " sslmode=verify-full"
case "disable", "disabled":
dsn += " sslmode=disable"
default:
dsn += " sslmode=require" // Safe default
}
} else if p.cfg.Insecure {
dsn += " sslmode=disable"
}
return dsn
}
// sanitizeDSN removes password from DSN for logging
func sanitizeDSN(dsn string) string {
// Simple password removal for logging
parts := strings.Split(dsn, " ")
var sanitized []string
for _, part := range parts {
if strings.HasPrefix(part, "password=") {
sanitized = append(sanitized, "password=***")
} else {
sanitized = append(sanitized, part)
}
}
return strings.Join(sanitized, " ")
}