Initial commit: Database Backup Tool v1.1.0
- PostgreSQL and MySQL support - Interactive TUI with fixed menu navigation - Line-by-line progress display - CPU-aware parallel processing - Cross-platform build support - Configuration settings menu - Silent mode for TUI operations
This commit is contained in:
133
internal/database/interface.go
Normal file
133
internal/database/interface.go
Normal file
@ -0,0 +1,133 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/logger"
|
||||
|
||||
_ "github.com/lib/pq" // PostgreSQL driver
|
||||
_ "github.com/go-sql-driver/mysql" // MySQL driver
|
||||
)
|
||||
|
||||
// Database represents a database connection and operations
|
||||
type Database interface {
|
||||
// Connection management
|
||||
Connect(ctx context.Context) error
|
||||
Close() error
|
||||
Ping(ctx context.Context) error
|
||||
|
||||
// Database discovery
|
||||
ListDatabases(ctx context.Context) ([]string, error)
|
||||
ListTables(ctx context.Context, database string) ([]string, error)
|
||||
|
||||
// Database operations
|
||||
CreateDatabase(ctx context.Context, name string) error
|
||||
DropDatabase(ctx context.Context, name string) error
|
||||
DatabaseExists(ctx context.Context, name string) (bool, error)
|
||||
|
||||
// Information
|
||||
GetVersion(ctx context.Context) (string, error)
|
||||
GetDatabaseSize(ctx context.Context, database string) (int64, error)
|
||||
GetTableRowCount(ctx context.Context, database, table string) (int64, error)
|
||||
|
||||
// Backup/Restore command building
|
||||
BuildBackupCommand(database, outputFile string, options BackupOptions) []string
|
||||
BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string
|
||||
BuildSampleQuery(database, table string, strategy SampleStrategy) string
|
||||
|
||||
// Validation
|
||||
ValidateBackupTools() error
|
||||
}
|
||||
|
||||
// BackupOptions holds options for backup operations
|
||||
type BackupOptions struct {
|
||||
Compression int
|
||||
Parallel int
|
||||
Format string // "custom", "plain", "directory"
|
||||
Blobs bool
|
||||
SchemaOnly bool
|
||||
DataOnly bool
|
||||
NoOwner bool
|
||||
NoPrivileges bool
|
||||
Clean bool
|
||||
IfExists bool
|
||||
Role string
|
||||
}
|
||||
|
||||
// RestoreOptions holds options for restore operations
|
||||
type RestoreOptions struct {
|
||||
Parallel int
|
||||
Clean bool
|
||||
IfExists bool
|
||||
NoOwner bool
|
||||
NoPrivileges bool
|
||||
SingleTransaction bool
|
||||
}
|
||||
|
||||
// SampleStrategy defines how to sample data
|
||||
type SampleStrategy struct {
|
||||
Type string // "ratio", "percent", "count"
|
||||
Value int
|
||||
}
|
||||
|
||||
// DatabaseInfo holds database metadata
|
||||
type DatabaseInfo struct {
|
||||
Name string
|
||||
Size int64
|
||||
Owner string
|
||||
Encoding string
|
||||
Collation string
|
||||
Tables []TableInfo
|
||||
}
|
||||
|
||||
// TableInfo holds table metadata
|
||||
type TableInfo struct {
|
||||
Schema string
|
||||
Name string
|
||||
RowCount int64
|
||||
Size int64
|
||||
}
|
||||
|
||||
// New creates a new database instance based on configuration
|
||||
func New(cfg *config.Config, log logger.Logger) (Database, error) {
|
||||
if cfg.IsPostgreSQL() {
|
||||
return NewPostgreSQL(cfg, log), nil
|
||||
} else if cfg.IsMySQL() {
|
||||
return NewMySQL(cfg, log), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported database type: %s", cfg.DatabaseType)
|
||||
}
|
||||
|
||||
// Common database implementation
|
||||
type baseDatabase struct {
|
||||
cfg *config.Config
|
||||
log logger.Logger
|
||||
db *sql.DB
|
||||
dsn string
|
||||
}
|
||||
|
||||
func (b *baseDatabase) Close() error {
|
||||
if b.db != nil {
|
||||
return b.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *baseDatabase) Ping(ctx context.Context) error {
|
||||
if b.db == nil {
|
||||
return fmt.Errorf("database not connected")
|
||||
}
|
||||
return b.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
// buildTimeout creates a context with timeout for database operations
|
||||
func buildTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
return context.WithTimeout(ctx, timeout)
|
||||
}
|
||||
410
internal/database/mysql.go
Normal file
410
internal/database/mysql.go
Normal file
@ -0,0 +1,410 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/logger"
|
||||
)
|
||||
|
||||
// MySQL implements Database interface for MySQL
|
||||
type MySQL struct {
|
||||
baseDatabase
|
||||
}
|
||||
|
||||
// NewMySQL creates a new MySQL database instance
|
||||
func NewMySQL(cfg *config.Config, log logger.Logger) *MySQL {
|
||||
return &MySQL{
|
||||
baseDatabase: baseDatabase{
|
||||
cfg: cfg,
|
||||
log: log,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes a connection to MySQL
|
||||
func (m *MySQL) Connect(ctx context.Context) error {
|
||||
// Build MySQL DSN
|
||||
dsn := m.buildDSN()
|
||||
m.dsn = dsn
|
||||
|
||||
m.log.Debug("Connecting to MySQL", "dsn", sanitizeMySQLDSN(dsn))
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open MySQL connection: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(10)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(0)
|
||||
|
||||
// Test connection
|
||||
timeoutCtx, cancel := buildTimeout(ctx, 0)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(timeoutCtx); err != nil {
|
||||
db.Close()
|
||||
return fmt.Errorf("failed to ping MySQL: %w", err)
|
||||
}
|
||||
|
||||
m.db = db
|
||||
m.log.Info("Connected to MySQL successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDatabases returns list of databases (excluding system databases)
|
||||
func (m *MySQL) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
if m.db == nil {
|
||||
return nil, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SHOW DATABASES`
|
||||
|
||||
rows, err := m.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query databases: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var databases []string
|
||||
systemDbs := map[string]bool{
|
||||
"information_schema": true,
|
||||
"performance_schema": true,
|
||||
"mysql": true,
|
||||
"sys": true,
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan database name: %w", err)
|
||||
}
|
||||
|
||||
// Skip system databases
|
||||
if !systemDbs[name] {
|
||||
databases = append(databases, name)
|
||||
}
|
||||
}
|
||||
|
||||
return databases, rows.Err()
|
||||
}
|
||||
|
||||
// ListTables returns list of tables in a database
|
||||
func (m *MySQL) ListTables(ctx context.Context, database string) ([]string, error) {
|
||||
if m.db == nil {
|
||||
return nil, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT table_name FROM information_schema.tables
|
||||
WHERE table_schema = ? AND table_type = 'BASE TABLE'
|
||||
ORDER BY table_name`
|
||||
|
||||
rows, err := m.db.QueryContext(ctx, query, database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table name: %w", err)
|
||||
}
|
||||
tables = append(tables, name)
|
||||
}
|
||||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// CreateDatabase creates a new database
|
||||
func (m *MySQL) CreateDatabase(ctx context.Context, name string) error {
|
||||
if m.db == nil {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", name)
|
||||
_, err := m.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database %s: %w", name, err)
|
||||
}
|
||||
|
||||
m.log.Info("Created database", "name", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropDatabase drops a database
|
||||
func (m *MySQL) DropDatabase(ctx context.Context, name string) error {
|
||||
if m.db == nil {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", name)
|
||||
_, err := m.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to drop database %s: %w", name, err)
|
||||
}
|
||||
|
||||
m.log.Info("Dropped database", "name", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DatabaseExists checks if a database exists
|
||||
func (m *MySQL) DatabaseExists(ctx context.Context, name string) (bool, error) {
|
||||
if m.db == nil {
|
||||
return false, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?`
|
||||
var dbName string
|
||||
err := m.db.QueryRowContext(ctx, query, name).Scan(&dbName)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check database existence: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetVersion returns MySQL version
|
||||
func (m *MySQL) GetVersion(ctx context.Context) (string, error) {
|
||||
if m.db == nil {
|
||||
return "", fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
var version string
|
||||
err := m.db.QueryRowContext(ctx, "SELECT VERSION()").Scan(&version)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get version: %w", err)
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// GetDatabaseSize returns database size in bytes
|
||||
func (m *MySQL) GetDatabaseSize(ctx context.Context, database string) (int64, error) {
|
||||
if m.db == nil {
|
||||
return 0, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT COALESCE(SUM(data_length + index_length), 0) as size_bytes
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = ?`
|
||||
|
||||
var size int64
|
||||
err := m.db.QueryRowContext(ctx, query, database).Scan(&size)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get database size: %w", err)
|
||||
}
|
||||
|
||||
return size, nil
|
||||
}
|
||||
|
||||
// GetTableRowCount returns row count for a table
|
||||
func (m *MySQL) GetTableRowCount(ctx context.Context, database, table string) (int64, error) {
|
||||
if m.db == nil {
|
||||
return 0, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
// First try information_schema for approximate count (faster)
|
||||
query := `SELECT table_rows FROM information_schema.tables
|
||||
WHERE table_schema = ? AND table_name = ?`
|
||||
|
||||
var count int64
|
||||
err := m.db.QueryRowContext(ctx, query, database, table).Scan(&count)
|
||||
if err != nil || count == 0 {
|
||||
// Fallback to exact count
|
||||
exactQuery := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s`", database, table)
|
||||
err = m.db.QueryRowContext(ctx, exactQuery).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get table row count: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BuildBackupCommand builds mysqldump command
|
||||
func (m *MySQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string {
|
||||
cmd := []string{"mysqldump"}
|
||||
|
||||
// Connection parameters
|
||||
cmd = append(cmd, "-h", m.cfg.Host)
|
||||
cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port))
|
||||
cmd = append(cmd, "-u", m.cfg.User)
|
||||
|
||||
if m.cfg.Password != "" {
|
||||
cmd = append(cmd, "-p"+m.cfg.Password)
|
||||
}
|
||||
|
||||
// SSL options
|
||||
if m.cfg.Insecure {
|
||||
cmd = append(cmd, "--skip-ssl")
|
||||
} else if m.cfg.SSLMode != "" {
|
||||
// MySQL SSL modes: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY
|
||||
switch strings.ToLower(m.cfg.SSLMode) {
|
||||
case "disable", "disabled":
|
||||
cmd = append(cmd, "--skip-ssl")
|
||||
case "require", "required":
|
||||
cmd = append(cmd, "--ssl-mode=REQUIRED")
|
||||
case "verify-ca":
|
||||
cmd = append(cmd, "--ssl-mode=VERIFY_CA")
|
||||
case "verify-full", "verify-identity":
|
||||
cmd = append(cmd, "--ssl-mode=VERIFY_IDENTITY")
|
||||
default:
|
||||
cmd = append(cmd, "--ssl-mode=PREFERRED")
|
||||
}
|
||||
}
|
||||
|
||||
// Backup options
|
||||
cmd = append(cmd, "--single-transaction") // Consistent backup
|
||||
cmd = append(cmd, "--routines") // Include stored procedures/functions
|
||||
cmd = append(cmd, "--triggers") // Include triggers
|
||||
cmd = append(cmd, "--events") // Include events
|
||||
|
||||
if options.SchemaOnly {
|
||||
cmd = append(cmd, "--no-data")
|
||||
} else if options.DataOnly {
|
||||
cmd = append(cmd, "--no-create-info")
|
||||
}
|
||||
|
||||
if options.NoOwner || options.NoPrivileges {
|
||||
cmd = append(cmd, "--skip-add-drop-table")
|
||||
}
|
||||
|
||||
// Compression (handled externally for MySQL)
|
||||
// Output redirection will be handled by caller
|
||||
|
||||
// Database
|
||||
cmd = append(cmd, database)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// BuildRestoreCommand builds mysql restore command
|
||||
func (m *MySQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string {
|
||||
cmd := []string{"mysql"}
|
||||
|
||||
// Connection parameters
|
||||
cmd = append(cmd, "-h", m.cfg.Host)
|
||||
cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port))
|
||||
cmd = append(cmd, "-u", m.cfg.User)
|
||||
|
||||
if m.cfg.Password != "" {
|
||||
cmd = append(cmd, "-p"+m.cfg.Password)
|
||||
}
|
||||
|
||||
// SSL options
|
||||
if m.cfg.Insecure {
|
||||
cmd = append(cmd, "--skip-ssl")
|
||||
}
|
||||
|
||||
// Options
|
||||
if options.SingleTransaction {
|
||||
cmd = append(cmd, "--single-transaction")
|
||||
}
|
||||
|
||||
// Database
|
||||
cmd = append(cmd, database)
|
||||
|
||||
// Input file (will be handled via stdin redirection)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// BuildSampleQuery builds SQL query for sampling data
|
||||
func (m *MySQL) BuildSampleQuery(database, table string, strategy SampleStrategy) string {
|
||||
switch strategy.Type {
|
||||
case "ratio":
|
||||
// Every Nth record using row_number (MySQL 8.0+) or modulo
|
||||
return fmt.Sprintf("SELECT * FROM (SELECT *, (@row_number:=@row_number + 1) AS rn FROM %s.%s CROSS JOIN (SELECT @row_number:=0) AS t) AS numbered WHERE rn %% %d = 1",
|
||||
database, table, strategy.Value)
|
||||
case "percent":
|
||||
// Percentage sampling using RAND()
|
||||
return fmt.Sprintf("SELECT * FROM %s.%s WHERE RAND() <= %f",
|
||||
database, table, float64(strategy.Value)/100.0)
|
||||
case "count":
|
||||
// First N records
|
||||
return fmt.Sprintf("SELECT * FROM %s.%s LIMIT %d", database, table, strategy.Value)
|
||||
default:
|
||||
return fmt.Sprintf("SELECT * FROM %s.%s LIMIT 1000", database, table)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateBackupTools checks if required MySQL tools are available
|
||||
func (m *MySQL) ValidateBackupTools() error {
|
||||
tools := []string{"mysqldump", "mysql"}
|
||||
|
||||
for _, tool := range tools {
|
||||
if _, err := exec.LookPath(tool); err != nil {
|
||||
return fmt.Errorf("required tool not found: %s", tool)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildDSN constructs MySQL connection string
|
||||
func (m *MySQL) buildDSN() string {
|
||||
dsn := ""
|
||||
|
||||
if m.cfg.User != "" {
|
||||
dsn += m.cfg.User
|
||||
}
|
||||
|
||||
if m.cfg.Password != "" {
|
||||
dsn += ":" + m.cfg.Password
|
||||
}
|
||||
|
||||
dsn += "@"
|
||||
|
||||
if m.cfg.Host != "" && m.cfg.Host != "localhost" {
|
||||
dsn += "tcp(" + m.cfg.Host + ":" + strconv.Itoa(m.cfg.Port) + ")"
|
||||
}
|
||||
|
||||
dsn += "/" + m.cfg.Database
|
||||
|
||||
// Add connection parameters
|
||||
params := []string{}
|
||||
|
||||
if m.cfg.Insecure {
|
||||
params = append(params, "tls=skip-verify")
|
||||
} else if m.cfg.SSLMode != "" {
|
||||
switch strings.ToLower(m.cfg.SSLMode) {
|
||||
case "disable", "disabled":
|
||||
params = append(params, "tls=false")
|
||||
case "require", "required":
|
||||
params = append(params, "tls=true")
|
||||
}
|
||||
}
|
||||
|
||||
// Add charset
|
||||
params = append(params, "charset=utf8mb4")
|
||||
params = append(params, "parseTime=true")
|
||||
|
||||
if len(params) > 0 {
|
||||
dsn += "?" + strings.Join(params, "&")
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
// sanitizeMySQLDSN removes password from DSN for logging
|
||||
func sanitizeMySQLDSN(dsn string) string {
|
||||
// Find password part and replace it
|
||||
if idx := strings.Index(dsn, ":"); idx != -1 {
|
||||
if endIdx := strings.Index(dsn[idx:], "@"); endIdx != -1 {
|
||||
return dsn[:idx] + ":***" + dsn[idx+endIdx:]
|
||||
}
|
||||
}
|
||||
return dsn
|
||||
}
|
||||
427
internal/database/postgresql.go
Normal file
427
internal/database/postgresql.go
Normal file
@ -0,0 +1,427 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/logger"
|
||||
)
|
||||
|
||||
// PostgreSQL implements Database interface for PostgreSQL
|
||||
type PostgreSQL struct {
|
||||
baseDatabase
|
||||
}
|
||||
|
||||
// NewPostgreSQL creates a new PostgreSQL database instance
|
||||
func NewPostgreSQL(cfg *config.Config, log logger.Logger) *PostgreSQL {
|
||||
return &PostgreSQL{
|
||||
baseDatabase: baseDatabase{
|
||||
cfg: cfg,
|
||||
log: log,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes a connection to PostgreSQL
|
||||
func (p *PostgreSQL) Connect(ctx context.Context) error {
|
||||
// Build PostgreSQL DSN
|
||||
dsn := p.buildDSN()
|
||||
p.dsn = dsn
|
||||
|
||||
p.log.Debug("Connecting to PostgreSQL", "dsn", sanitizeDSN(dsn))
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open PostgreSQL connection: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(10)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(0)
|
||||
|
||||
// Test connection
|
||||
timeoutCtx, cancel := buildTimeout(ctx, 0)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(timeoutCtx); err != nil {
|
||||
db.Close()
|
||||
return fmt.Errorf("failed to ping PostgreSQL: %w", err)
|
||||
}
|
||||
|
||||
p.db = db
|
||||
p.log.Info("Connected to PostgreSQL successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDatabases returns list of non-template databases
|
||||
func (p *PostgreSQL) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
if p.db == nil {
|
||||
return nil, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT datname FROM pg_database
|
||||
WHERE datistemplate = false
|
||||
ORDER BY datname`
|
||||
|
||||
rows, err := p.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query databases: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var databases []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan database name: %w", err)
|
||||
}
|
||||
databases = append(databases, name)
|
||||
}
|
||||
|
||||
return databases, rows.Err()
|
||||
}
|
||||
|
||||
// ListTables returns list of tables in a database
|
||||
func (p *PostgreSQL) ListTables(ctx context.Context, database string) ([]string, error) {
|
||||
if p.db == nil {
|
||||
return nil, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT schemaname||'.'||tablename as full_name
|
||||
FROM pg_tables
|
||||
WHERE schemaname NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
|
||||
ORDER BY schemaname, tablename`
|
||||
|
||||
rows, err := p.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table name: %w", err)
|
||||
}
|
||||
tables = append(tables, name)
|
||||
}
|
||||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// CreateDatabase creates a new database
|
||||
func (p *PostgreSQL) CreateDatabase(ctx context.Context, name string) error {
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
// PostgreSQL doesn't support CREATE DATABASE in transactions or prepared statements
|
||||
query := fmt.Sprintf("CREATE DATABASE %s", name)
|
||||
_, err := p.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database %s: %w", name, err)
|
||||
}
|
||||
|
||||
p.log.Info("Created database", "name", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropDatabase drops a database
|
||||
func (p *PostgreSQL) DropDatabase(ctx context.Context, name string) error {
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
// Force drop connections and drop database
|
||||
query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", name)
|
||||
_, err := p.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to drop database %s: %w", name, err)
|
||||
}
|
||||
|
||||
p.log.Info("Dropped database", "name", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DatabaseExists checks if a database exists
|
||||
func (p *PostgreSQL) DatabaseExists(ctx context.Context, name string) (bool, error) {
|
||||
if p.db == nil {
|
||||
return false, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)`
|
||||
var exists bool
|
||||
err := p.db.QueryRowContext(ctx, query, name).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check database existence: %w", err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// GetVersion returns PostgreSQL version
|
||||
func (p *PostgreSQL) GetVersion(ctx context.Context) (string, error) {
|
||||
if p.db == nil {
|
||||
return "", fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
var version string
|
||||
err := p.db.QueryRowContext(ctx, "SELECT version()").Scan(&version)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get version: %w", err)
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// GetDatabaseSize returns database size in bytes
|
||||
func (p *PostgreSQL) GetDatabaseSize(ctx context.Context, database string) (int64, error) {
|
||||
if p.db == nil {
|
||||
return 0, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT pg_database_size($1)`
|
||||
var size int64
|
||||
err := p.db.QueryRowContext(ctx, query, database).Scan(&size)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get database size: %w", err)
|
||||
}
|
||||
|
||||
return size, nil
|
||||
}
|
||||
|
||||
// GetTableRowCount returns approximate row count for a table
|
||||
func (p *PostgreSQL) GetTableRowCount(ctx context.Context, database, table string) (int64, error) {
|
||||
if p.db == nil {
|
||||
return 0, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
// Use pg_stat_user_tables for approximate count (faster)
|
||||
parts := strings.Split(table, ".")
|
||||
if len(parts) != 2 {
|
||||
return 0, fmt.Errorf("table name must be in format schema.table")
|
||||
}
|
||||
|
||||
query := `SELECT COALESCE(n_tup_ins, 0) FROM pg_stat_user_tables
|
||||
WHERE schemaname = $1 AND relname = $2`
|
||||
|
||||
var count int64
|
||||
err := p.db.QueryRowContext(ctx, query, parts[0], parts[1]).Scan(&count)
|
||||
if err != nil {
|
||||
// Fallback to exact count if stats not available
|
||||
exactQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
|
||||
err = p.db.QueryRowContext(ctx, exactQuery).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get table row count: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BuildBackupCommand builds pg_dump command
|
||||
func (p *PostgreSQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string {
|
||||
cmd := []string{"pg_dump"}
|
||||
|
||||
// Connection parameters
|
||||
if p.cfg.Host != "localhost" {
|
||||
cmd = append(cmd, "-h", p.cfg.Host)
|
||||
cmd = append(cmd, "-p", strconv.Itoa(p.cfg.Port))
|
||||
cmd = append(cmd, "--no-password")
|
||||
}
|
||||
cmd = append(cmd, "-U", p.cfg.User)
|
||||
|
||||
// Format and compression
|
||||
if options.Format != "" {
|
||||
cmd = append(cmd, "--format="+options.Format)
|
||||
} else {
|
||||
cmd = append(cmd, "--format=custom")
|
||||
}
|
||||
|
||||
if options.Compression > 0 {
|
||||
cmd = append(cmd, "--compress="+strconv.Itoa(options.Compression))
|
||||
}
|
||||
|
||||
// Parallel jobs (only for directory format)
|
||||
if options.Parallel > 1 && options.Format == "directory" {
|
||||
cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
|
||||
}
|
||||
|
||||
// Options
|
||||
if options.Blobs {
|
||||
cmd = append(cmd, "--blobs")
|
||||
}
|
||||
if options.SchemaOnly {
|
||||
cmd = append(cmd, "--schema-only")
|
||||
}
|
||||
if options.DataOnly {
|
||||
cmd = append(cmd, "--data-only")
|
||||
}
|
||||
if options.NoOwner {
|
||||
cmd = append(cmd, "--no-owner")
|
||||
}
|
||||
if options.NoPrivileges {
|
||||
cmd = append(cmd, "--no-privileges")
|
||||
}
|
||||
if options.Role != "" {
|
||||
cmd = append(cmd, "--role="+options.Role)
|
||||
}
|
||||
|
||||
// Database and output
|
||||
cmd = append(cmd, "--dbname="+database)
|
||||
cmd = append(cmd, "--file="+outputFile)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// BuildRestoreCommand builds pg_restore command
|
||||
func (p *PostgreSQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string {
|
||||
cmd := []string{"pg_restore"}
|
||||
|
||||
// Connection parameters
|
||||
if p.cfg.Host != "localhost" {
|
||||
cmd = append(cmd, "-h", p.cfg.Host)
|
||||
cmd = append(cmd, "-p", strconv.Itoa(p.cfg.Port))
|
||||
cmd = append(cmd, "--no-password")
|
||||
}
|
||||
cmd = append(cmd, "-U", p.cfg.User)
|
||||
|
||||
// Parallel jobs
|
||||
if options.Parallel > 1 {
|
||||
cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
|
||||
}
|
||||
|
||||
// Options
|
||||
if options.Clean {
|
||||
cmd = append(cmd, "--clean")
|
||||
}
|
||||
if options.IfExists {
|
||||
cmd = append(cmd, "--if-exists")
|
||||
}
|
||||
if options.NoOwner {
|
||||
cmd = append(cmd, "--no-owner")
|
||||
}
|
||||
if options.NoPrivileges {
|
||||
cmd = append(cmd, "--no-privileges")
|
||||
}
|
||||
if options.SingleTransaction {
|
||||
cmd = append(cmd, "--single-transaction")
|
||||
}
|
||||
|
||||
// Database and input
|
||||
cmd = append(cmd, "--dbname="+database)
|
||||
cmd = append(cmd, inputFile)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// BuildSampleQuery builds SQL query for sampling data
|
||||
func (p *PostgreSQL) BuildSampleQuery(database, table string, strategy SampleStrategy) string {
|
||||
switch strategy.Type {
|
||||
case "ratio":
|
||||
// Every Nth record using row_number
|
||||
return fmt.Sprintf("SELECT * FROM (SELECT *, row_number() OVER () as rn FROM %s) t WHERE rn %% %d = 1",
|
||||
table, strategy.Value)
|
||||
case "percent":
|
||||
// Percentage sampling using TABLESAMPLE (PostgreSQL 9.5+)
|
||||
return fmt.Sprintf("SELECT * FROM %s TABLESAMPLE BERNOULLI(%d)", table, strategy.Value)
|
||||
case "count":
|
||||
// First N records
|
||||
return fmt.Sprintf("SELECT * FROM %s LIMIT %d", table, strategy.Value)
|
||||
default:
|
||||
return fmt.Sprintf("SELECT * FROM %s LIMIT 1000", table)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateBackupTools checks if required PostgreSQL tools are available
|
||||
func (p *PostgreSQL) ValidateBackupTools() error {
|
||||
tools := []string{"pg_dump", "pg_restore", "pg_dumpall", "psql"}
|
||||
|
||||
for _, tool := range tools {
|
||||
if _, err := exec.LookPath(tool); err != nil {
|
||||
return fmt.Errorf("required tool not found: %s", tool)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildDSN constructs PostgreSQL connection string
|
||||
func (p *PostgreSQL) buildDSN() string {
|
||||
dsn := fmt.Sprintf("user=%s dbname=%s", p.cfg.User, p.cfg.Database)
|
||||
|
||||
if p.cfg.Password != "" {
|
||||
dsn += " password=" + p.cfg.Password
|
||||
}
|
||||
|
||||
// For localhost connections, try socket first for peer auth
|
||||
if p.cfg.Host == "localhost" && p.cfg.Password == "" {
|
||||
// Try Unix socket connection for peer authentication
|
||||
// Common PostgreSQL socket locations
|
||||
socketDirs := []string{
|
||||
"/var/run/postgresql",
|
||||
"/tmp",
|
||||
"/var/lib/pgsql",
|
||||
}
|
||||
|
||||
for _, dir := range socketDirs {
|
||||
socketPath := fmt.Sprintf("%s/.s.PGSQL.%d", dir, p.cfg.Port)
|
||||
if _, err := os.Stat(socketPath); err == nil {
|
||||
dsn += " host=" + dir
|
||||
p.log.Debug("Using PostgreSQL socket", "path", socketPath)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if p.cfg.Host != "localhost" || p.cfg.Password != "" {
|
||||
// Use TCP connection
|
||||
dsn += " host=" + p.cfg.Host
|
||||
dsn += " port=" + strconv.Itoa(p.cfg.Port)
|
||||
}
|
||||
|
||||
if p.cfg.SSLMode != "" && !p.cfg.Insecure {
|
||||
// Map SSL modes to supported values for lib/pq
|
||||
switch strings.ToLower(p.cfg.SSLMode) {
|
||||
case "prefer", "preferred":
|
||||
dsn += " sslmode=require" // lib/pq default, closest to prefer
|
||||
case "require", "required":
|
||||
dsn += " sslmode=require"
|
||||
case "verify-ca":
|
||||
dsn += " sslmode=verify-ca"
|
||||
case "verify-full", "verify-identity":
|
||||
dsn += " sslmode=verify-full"
|
||||
case "disable", "disabled":
|
||||
dsn += " sslmode=disable"
|
||||
default:
|
||||
dsn += " sslmode=require" // Safe default
|
||||
}
|
||||
} else if p.cfg.Insecure {
|
||||
dsn += " sslmode=disable"
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
// sanitizeDSN removes password from DSN for logging
|
||||
func sanitizeDSN(dsn string) string {
|
||||
// Simple password removal for logging
|
||||
parts := strings.Split(dsn, " ")
|
||||
var sanitized []string
|
||||
|
||||
for _, part := range parts {
|
||||
if strings.HasPrefix(part, "password=") {
|
||||
sanitized = append(sanitized, "password=***")
|
||||
} else {
|
||||
sanitized = append(sanitized, part)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(sanitized, " ")
|
||||
}
|
||||
Reference in New Issue
Block a user