Source code for sqlpyhelper.db_helper

import csv
import logging
import os
import re
from typing import Any, Literal, Optional

from dotenv import load_dotenv

# Populate os.environ from a .env file (when one is present) so that the
# DB_* / ORACLE_* settings read by SQLPyHelper can live outside the code.
load_dotenv()

# NOTE(review): basicConfig at import time configures the *root* logger of
# whichever application imports this module; libraries usually leave handler
# setup to the application. Kept unchanged for backward compatibility.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)

# Named logger used by the transaction helpers in SQLPyHelper.
logger = logging.getLogger("sqlpyhelper")


def _validate_identifier(name: str) -> str:
    """
    Validate a SQL identifier (table or column name).
    Allows only alphanumeric characters and underscores.
    Raises ValueError for anything else, preventing SQL injection via identifiers.
    """
    if not re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", name):
        raise ValueError(
            f"Invalid SQL identifier: {name!r}. "
            "Only letters, digits, and underscores are allowed."
        )
    return name


class SQLPyHelperError(Exception):
    """Common base class for the custom exceptions defined in this module."""


# NOTE(review): this name shadows the builtin ``ConnectionError`` within this
# module (and for star-importers); kept because it is part of the public API.
class ConnectionError(SQLPyHelperError):
    """Closing, pooling, or re-establishing a database connection failed."""


class QueryError(SQLPyHelperError):
    """Executing a statement, fetching rows, or transaction control failed."""


class BackupError(SQLPyHelperError):
    """Exporting a table to a CSV backup file failed."""
class SQLPyHelper:
    """Convenience wrapper around one DB-API connection and its cursor.

    Supported ``db_type`` values: ``"sqlite"`` (sqlite3), ``"postgres"``
    (psycopg2), ``"mysql"`` (mysql-connector), ``"sqlserver"`` (pyodbc) and
    ``"oracle"`` (oracledb). Any constructor argument left unset (``None`` or
    empty) falls back to an environment variable: ``DB_TYPE``, ``DB_HOST``,
    ``DB_USER``, ``DB_PASSWORD``, ``DB_NAME``, ``DB_DRIVER``, ``DB_PORT``,
    ``ORACLE_SID``.

    Instances are context managers; leaving the ``with`` block closes the
    cursor and the connection.
    """

    def __init__(
        self,
        db_type: Optional[str] = None,
        host: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        database: Optional[str] = None,
        driver: Optional[str] = None,
        port: Optional[str] = None,
        oracle_sid: Optional[str] = None,
    ) -> None:
        """Resolve settings (arguments first, then env vars) and connect.

        Raises:
            ValueError: if no database type or database name can be resolved,
                if the type is unsupported, or if a MySQL port is not numeric.
            Driver-specific connection errors propagate unwrapped.
        """
        # The arguments exactly as passed (before env fallback). Used to tell
        # an explicit ``port=`` apart from the DB_PORT fallback for Oracle.
        self._init_kwargs = {
            "db_type": db_type,
            "host": host,
            "user": user,
            "password": password,
            "database": database,
            "driver": driver,
            "port": port,
            "oracle_sid": oracle_sid,
        }
        self.db_type: str = (db_type or os.getenv("DB_TYPE") or "").lower()
        self.host: Optional[str] = host or os.getenv("DB_HOST")
        self.user: Optional[str] = user or os.getenv("DB_USER")
        self.password: Optional[str] = password or os.getenv("DB_PASSWORD")
        self.database: Optional[str] = database or os.getenv("DB_NAME")
        self.driver: Optional[str] = driver or os.getenv("DB_DRIVER")
        self.port: Optional[str] = port or os.getenv("DB_PORT")
        self.oracle_sid: Optional[str] = oracle_sid or os.getenv("ORACLE_SID")
        self.pool: Any = None
        self._in_transaction: bool = False

        if not self.db_type or not self.database:
            raise ValueError("Missing required database configuration.")

        self.connection: Any = self._open_connection()
        try:
            self.cursor = self.connection.cursor()
        except Exception:
            # Don't leak the freshly opened connection if no cursor can be made.
            self.connection.close()
            raise

    def __enter__(self) -> "SQLPyHelper":
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
        self.close()
        return False  # never suppress exceptions raised inside the with-block

    # ------------------------------------------------------------------
    # Connection helpers
    # ------------------------------------------------------------------

    def _open_connection(self) -> Any:
        """Open and return a new connection for ``self.db_type``.

        Drivers are imported lazily so only the selected backend's package
        needs to be installed.

        Raises:
            ValueError: if ``self.db_type`` is not supported.
        """
        if self.db_type == "sqlite":
            import sqlite3

            return sqlite3.connect(self.database)
        if self.db_type == "postgres":
            import psycopg2

            return psycopg2.connect(
                host=self.host,
                user=self.user,
                password=self.password,
                dbname=self.database,
                port=self.port,
            )
        if self.db_type == "mysql":
            import mysql.connector

            return mysql.connector.connect(
                host=self.host,
                user=self.user,
                password=self.password,
                database=self.database,
                **self._mysql_port_kwargs(),
            )
        if self.db_type == "sqlserver":
            import pyodbc

            return pyodbc.connect(self._sqlserver_conn_str())
        if self.db_type == "oracle":
            import oracledb

            return oracledb.connect(
                user=self.user, password=self.password, dsn=self._oracle_dsn()
            )
        raise ValueError("Unsupported database type")

    def _mysql_port_kwargs(self) -> dict[str, Any]:
        """Return ``{"port": int}`` when a port is configured, else ``{}``.

        mysql-connector documents ``port`` as an integer, while the DB_PORT
        env fallback yields a string, hence the conversion.
        """
        return {"port": int(self.port)} if self.port else {}

    def _sqlserver_conn_str(self, extra: str = "") -> str:
        """Build the pyodbc connection string; ``extra`` is appended verbatim."""
        return (
            f"DRIVER={self.driver};SERVER={self.host};DATABASE={self.database};"
            f"UID={self.user};PWD={self.password}{extra}"
        )

    def _oracle_dsn(self) -> Any:
        """Build the Oracle DSN from host, port and SID.

        Port precedence: an explicit ``port=`` constructor argument, then the
        ``ORACLE_DB_PORT`` env var, then 1521. DB_PORT is intentionally not
        consulted for Oracle, preserving the historical behaviour.
        """
        import oracledb

        explicit_port = self._init_kwargs.get("port")
        oracle_port = explicit_port or os.getenv("ORACLE_DB_PORT", 1521)
        return oracledb.makedsn(self.host, oracle_port, sid=self.oracle_sid)  # type: ignore[arg-type]

    def _placeholders(self, count: int) -> str:
        """Return ``count`` comma-separated bind markers for this driver.

        sqlite3 and pyodbc use ``?``; oracledb binds positional parameters
        to ``:1, :2, ...``; psycopg2 and mysql-connector use ``%s``.
        """
        if self.db_type in ("sqlite", "sqlserver"):
            return ", ".join(["?"] * count)
        if self.db_type == "oracle":
            return ", ".join(f":{i}" for i in range(1, count + 1))
        return ", ".join(["%s"] * count)

    # ------------------------------------------------------------------
    # Query execution
    # ------------------------------------------------------------------

    def _execute_and_commit(self, query: str, params: Optional[tuple]) -> None:
        """Run one statement and commit unless an explicit transaction is open."""
        if params:
            self.cursor.execute(query, params)
        else:
            self.cursor.execute(query)
        if not self._in_transaction:
            self.connection.commit()

    def execute_query(self, query: str, params: Optional[tuple] = None) -> None:
        """Execute ``query`` with optional bound ``params``.

        Commits immediately unless :meth:`begin_transaction` is in effect.
        If the error text contains "server has gone away" (MySQL's wording
        for a dropped connection), the connection is re-opened and the
        statement retried once. Inside a transaction, no retry is made,
        because the open transaction died with the old connection. Replaying
        one statement would commit a partial transaction, so
        :class:`QueryError` is raised instead.

        Raises:
            QueryError: if the statement (or its single retry) fails.
            ConnectionError: if re-opening the connection fails.
        """
        try:
            self._execute_and_commit(query, params)
        except Exception as e:
            if "server has gone away" not in str(e):
                raise QueryError(f"Query failed: {e}") from e
            was_in_transaction = self._in_transaction
            self.reconnect()  # also clears _in_transaction
            if was_in_transaction:
                raise QueryError(
                    f"Connection lost during transaction; transaction discarded: {e}"
                ) from e
            try:
                self._execute_and_commit(query, params)
            except Exception as retry_error:
                raise QueryError(f"Query failed: {retry_error}") from retry_error

    def fetch_one(self) -> Optional[tuple]:
        """Fetches a single row from the last executed query (``None`` if exhausted)."""
        try:
            return self.cursor.fetchone()
        except Exception as e:
            raise QueryError(f"Failed to fetch row: {e}") from e

    def fetch_all(self) -> list[tuple]:
        """Fetches all remaining rows from the last executed query."""
        try:
            return self.cursor.fetchall()
        except Exception as e:
            raise QueryError(f"Failed to fetch rows: {e}") from e

    def fetch_by_param(
        self, table_name: str, column_name: str, value: Any
    ) -> list[tuple]:
        """Fetches rows from ``table_name`` where ``column_name`` equals ``value``.

        Identifiers are validated; ``value`` is passed as a bound parameter
        using the driver's placeholder style.

        Raises:
            QueryError: on invalid identifiers or any execution/fetch failure.
        """
        try:
            table_name = _validate_identifier(table_name)
            column_name = _validate_identifier(column_name)
            query = (
                f"SELECT * FROM {table_name} "
                f"WHERE {column_name} = {self._placeholders(1)}"
            )
            self.cursor.execute(query, (value,))
            return self.cursor.fetchall()
        except Exception as e:
            raise QueryError(f"Failed to fetch by param: {e}") from e

    def close(self) -> None:
        """Closes the cursor and the database connection.

        The connection is closed even if closing the cursor fails.

        Raises:
            ConnectionError: if either close fails.
        """
        try:
            try:
                self.cursor.close()
            finally:
                self.connection.close()
        except Exception as e:
            raise ConnectionError(f"Failed to close connection: {e}") from e

    # ------------------------------------------------------------------
    # Schema / data helpers
    # ------------------------------------------------------------------

    def create_table(self, table_name: str, columns: dict[str, str]) -> None:
        """
        Creates a table (if it does not exist) from a ``{column: type}`` dict.

        Example:
            columns = {'id': 'INTEGER PRIMARY KEY', 'name': 'TEXT', 'age': 'INTEGER'}

        Raises:
            QueryError: on invalid identifiers or execution failure.
        """
        try:
            table_name = _validate_identifier(table_name)
            # NOTE: only names are validated; the type strings are inserted
            # verbatim into the DDL and must come from trusted code.
            validated_cols = {
                _validate_identifier(col): dtype for col, dtype in columns.items()
            }
            columns_def = ", ".join(
                f"{col} {dtype}" for col, dtype in validated_cols.items()
            )
            query = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_def})"
            self.execute_query(query)
        except Exception as e:
            raise QueryError(f"Failed to create table: {e}") from e

    def insert_bulk(self, table_name: str, data: list[dict[str, Any]]) -> None:
        """
        Inserts multiple rows at once with a single ``executemany``.

        Columns are taken from the first row; every row must have exactly
        the same keys (in any order). An empty ``data`` list is a no-op.
        Commits unless an explicit transaction is open.

        Example:
            data = [{'id': 1, 'name': 'Alice'}, {'id': 2, 'name': 'Bob'}]

        Raises:
            QueryError: on invalid identifiers, mismatched row keys, or
                execution failure.
        """
        if not data:
            return
        try:
            table_name = _validate_identifier(table_name)
            col_names = [_validate_identifier(col) for col in data[0]]
            expected = set(col_names)
            for index, row in enumerate(data):
                if set(row) != expected:
                    raise ValueError(
                        f"Row {index} has columns {sorted(row)}, "
                        f"expected {sorted(expected)}"
                    )
            query = (
                f"INSERT INTO {table_name} ({', '.join(col_names)}) "
                f"VALUES ({self._placeholders(len(col_names))})"
            )
            # Look values up by name so rows listing keys in a different
            # order still line up with the column list.
            values = [tuple(row[col] for col in col_names) for row in data]
            self.cursor.executemany(query, values)
            if not self._in_transaction:
                self.connection.commit()
        except Exception as e:
            raise QueryError(f"Failed to insert bulk rows: {e}") from e

    def backup_table(self, table_name: str, backup_file: str) -> None:
        """
        Exports all rows of a table, with a header row, into a UTF-8 CSV file.

        Example:
            backup_table('users', 'users_backup.csv')

        Raises:
            BackupError: if the query, fetch, or file write fails.
        """
        try:
            table_name = _validate_identifier(table_name)
            self.execute_query(f"SELECT * FROM {table_name}")
            rows = self.fetch_all()
            # Read column names before opening (and truncating) the file.
            header = [desc[0] for desc in self.cursor.description]
            with open(backup_file, mode="w", newline="", encoding="utf-8") as file:
                writer = csv.writer(file)
                writer.writerow(header)
                writer.writerows(rows)
        except Exception as e:
            raise BackupError(f"Failed to backup table: {e}") from e

    # ------------------------------------------------------------------
    # Connection pooling
    # ------------------------------------------------------------------

    def setup_connection_pool(
        self, min_conn: int = 1, max_conn: int = 5, pool_size: int = 5
    ) -> None:
        """Sets up connection pooling for the configured database type.

        * postgres: ``psycopg2.pool.SimpleConnectionPool(min_conn, max_conn)``
        * mysql: ``MySQLConnectionPool`` of ``pool_size``
        * sqlserver: a plain list of ``pool_size`` pre-opened connections
        * oracle: ``oracledb.create_pool(min=min_conn, max=max_conn)``

        Raises:
            ConnectionError: if the type has no pool support (e.g. sqlite) or
                pool creation fails.
        """
        try:
            if self.db_type == "postgres":
                from psycopg2 import pool

                self.pool = pool.SimpleConnectionPool(
                    min_conn,
                    max_conn,
                    host=self.host,
                    user=self.user,
                    password=self.password,
                    dbname=self.database,
                    port=self.port,
                )
            elif self.db_type == "mysql":
                import mysql.connector.pooling

                self.pool = mysql.connector.pooling.MySQLConnectionPool(
                    pool_name="mypool",
                    pool_size=pool_size,
                    host=self.host,
                    user=self.user,
                    password=self.password,
                    database=self.database,
                    **self._mysql_port_kwargs(),
                )
            elif self.db_type == "sqlserver":
                import pyodbc

                conn_str = self._sqlserver_conn_str(";ConnectionPooling=Yes")
                self.pool = [pyodbc.connect(conn_str) for _ in range(pool_size)]
            elif self.db_type == "oracle":
                import oracledb

                self.pool = oracledb.create_pool(
                    user=self.user,
                    password=self.password,
                    dsn=self._oracle_dsn(),
                    min=min_conn,
                    max=max_conn,
                    increment=1,
                )
            else:
                raise ValueError(f"Connection pooling not supported for {self.db_type}")
        except Exception as e:
            raise ConnectionError(f"Failed to set up connection pool: {e}") from e

    def get_connection_from_pool(self) -> Any:
        """Fetches a connection from the pool using the pool type's own API.

        Raises:
            RuntimeError: if no pool was set up, or the SQL Server list is empty.
        """
        if self.pool is None:
            raise RuntimeError(
                "No connection pool initialised. Call setup_connection_pool() first."
            )
        if self.db_type == "postgres":
            return self.pool.getconn()
        if self.db_type == "oracle":
            return self.pool.acquire()
        if self.db_type == "sqlserver":
            # The "pool" is a plain list of pre-opened connections.
            if not self.pool:
                raise RuntimeError("SQL Server connection pool is exhausted.")
            return self.pool.pop()
        return self.pool.get_connection()  # mysql

    def return_connection_to_pool(self, connection: Any = None) -> None:
        """Returns a connection back to the pool.

        Defaults to ``self.connection`` when ``connection`` is not given.
        NOTE(review): ``self.connection`` was not taken from the pool, so
        pools that track ownership (e.g. psycopg2's ``putconn``) may reject it.
        For mysql and sqlserver the connection is closed; for a pooled MySQL
        connection that is how it is handed back to its pool.

        Raises:
            RuntimeError: if no pool was set up.
        """
        conn = connection or self.connection
        if self.pool is None:
            raise RuntimeError(
                "No connection pool initialised. Call setup_connection_pool() first."
            )
        if self.db_type == "postgres":
            self.pool.putconn(conn)
        elif self.db_type == "oracle":
            self.pool.release(conn)
        else:  # mysql, sqlserver
            conn.close()

    def reconnect(self) -> None:
        """Replaces the connection and cursor with freshly opened ones.

        Uses the already-resolved settings, keeps any pool set up earlier,
        and clears the transaction flag (a transaction cannot survive its
        connection).

        Raises:
            ConnectionError: if the new connection or cursor cannot be opened.
        """
        try:
            try:
                self.connection.close()
            except Exception:
                # Deliberately ignored: the old connection is usually already
                # dead, and failing to close it must not block reconnecting.
                pass
            self.connection = self._open_connection()
            self.cursor = self.connection.cursor()
            self._in_transaction = False
            logger.info("Database reconnected successfully.")
        except Exception as e:
            raise ConnectionError(f"Reconnection failed: {e}") from e

    # ------------------------------------------------------------------
    # Transactions
    # ------------------------------------------------------------------

    def begin_transaction(self) -> None:
        """Begin an explicit transaction; execute_query stops auto-committing.

        Raises:
            QueryError: if the BEGIN statement fails.
        """
        try:
            if self.db_type == "sqlite":
                self.execute_query("BEGIN")
            elif self.db_type in ("postgres", "mysql"):
                self.execute_query("START TRANSACTION")
            elif self.db_type == "sqlserver":
                self.execute_query("BEGIN TRANSACTION")
            elif self.db_type == "oracle":
                pass  # Oracle starts transactions implicitly on first DML statement
            # NOTE(review): the flag is set only after the BEGIN statement ran,
            # so execute_query commits right after issuing it; the drivers'
            # implicit (non-autocommit) transactions then group later
            # statements. Verify per backend before reordering. With pyodbc's
            # manual-commit mode, BEGIN TRANSACTION presumably nests.
            self._in_transaction = True
            logger.info("Transaction started on %s database", self.db_type)
        except Exception as e:
            raise QueryError(f"Failed to begin transaction: {e}") from e

    def commit_transaction(self) -> None:
        """Commit the current transaction and resume auto-commit mode."""
        try:
            self.connection.commit()
            self._in_transaction = False
            logger.info("Transaction committed on %s database", self.db_type)
        except Exception as e:
            raise QueryError(f"Failed to commit transaction: {e}") from e

    def rollback_transaction(self) -> None:
        """Roll back the current transaction and resume auto-commit mode."""
        try:
            self.connection.rollback()
            self._in_transaction = False
            logger.info("Transaction rolled back on %s database", self.db_type)
        except Exception as e:
            raise QueryError(f"Failed to rollback transaction: {e}") from e

    def insert_dynamic(self, table: str, data: dict[str, Any]) -> None:
        """
        Builds and executes a single-row INSERT from a ``{column: value}`` dict,
        using the driver's placeholder style.

        Raises:
            ValueError: on an invalid table or column name (not wrapped).
            QueryError: if execution fails.
        """
        table = _validate_identifier(table)
        columns = ", ".join(_validate_identifier(col) for col in data)
        values = tuple(data.values())
        sql = f"INSERT INTO {table} ({columns}) VALUES ({self._placeholders(len(data))})"
        self.execute_query(sql, values)