Source code for sqlpyhelper.async_helper

"""
sqlpyhelper.async_helper
~~~~~~~~~~~~~~~~~~~~~~~~
Async-native database helper supporting SQLite, PostgreSQL, MySQL,
SQL Server, and Oracle.

Uses async-native drivers:
- SQLite:     aiosqlite
- PostgreSQL: asyncpg
- MySQL:      aiomysql
- SQL Server: aioodbc
- Oracle:     python-oracledb (async mode)

Example usage::

    import asyncio
    from sqlpyhelper.async_helper import AsyncSQLPyHelper

    async def main():
        async with AsyncSQLPyHelper(db_type="sqlite", database="my.db") as db:
            await db.execute("CREATE TABLE IF NOT EXISTS users (id INTEGER, name TEXT)")
            await db.execute("INSERT INTO users VALUES ($1, $2)", 1, "Alice")
            rows = await db.fetch_all("SELECT * FROM users")
            print(rows)

    asyncio.run(main())
"""

import logging
import os
import re
from typing import Any, Optional

from dotenv import load_dotenv

# Import-time side effect: load a .env file via python-dotenv so the DB_* /
# ORACLE_* variables that AsyncSQLPyHelper reads with os.getenv() can be
# supplied from it.
load_dotenv()

# Module logger used by AsyncSQLPyHelper (connection/transaction events at
# INFO, executed statements at DEBUG).
logger = logging.getLogger("sqlpyhelper.async")


class AsyncConnectionError(Exception):
    """
    Signal that an async database connection could not be used.

    Raised when opening or closing a connection fails, when a connection
    pool cannot be created, or when a query is attempted with no open
    connection.
    """
class AsyncQueryError(Exception):
    """
    Signal that an async query or transaction operation failed.

    When the failure came from the driver, the original exception is
    chained as ``__cause__``.
    """
class AsyncSQLPyHelper:
    """
    Async-native database helper with a unified API across SQLite,
    PostgreSQL, MySQL, SQL Server, and Oracle.

    Use as an async context manager::

        async with AsyncSQLPyHelper(db_type="postgres", ...) as db:
            rows = await db.fetch_all("SELECT * FROM users")

    Or manage the connection lifecycle manually::

        db = AsyncSQLPyHelper(db_type="sqlite", database="my.db")
        await db.connect()
        try:
            rows = await db.fetch_all("SELECT * FROM users")
        finally:
            await db.close()

    Placeholders: write queries with PostgreSQL-style ``$1, $2, ...``. They are
    translated for the active backend (``?`` for SQLite and SQL Server, ``%s``
    for MySQL, ``:1, :2, ...`` for Oracle). A placeholder may repeat or appear
    out of numeric order; the arguments are re-ordered to match. A query that
    contains no ``$n`` placeholder is handed to the driver unchanged, so native
    driver placeholder styles keep working.

    Transactions: outside an explicit transaction, ``execute()`` and
    ``execute_many()`` commit after each statement. Between
    ``begin_transaction()`` and ``commit_transaction()`` /
    ``rollback_transaction()`` they do not commit, so the work can be rolled back.
    """

    # Matches one portable positional placeholder (e.g. "$12"); group 1 is the
    # 1-based argument number.
    _PLACEHOLDER_RE = re.compile(r"\$(\d+)")

    def __init__(
        self,
        db_type: Optional[str] = None,
        host: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        database: Optional[str] = None,
        driver: Optional[str] = None,
        port: Optional[str] = None,
        oracle_sid: Optional[str] = None,
    ) -> None:
        """
        Store connection settings. No connection is opened here.

        Each omitted argument falls back to an environment variable: DB_TYPE,
        DB_HOST, DB_USER, DB_PASSWORD, DB_NAME, DB_DRIVER, DB_PORT, ORACLE_SID.

        Raises:
            ValueError: If db_type or database is missing, or db_type is not
                one of sqlite, postgres, mysql, sqlserver, oracle.
        """
        self.db_type: str = (db_type or os.getenv("DB_TYPE") or "").lower()
        self.host: Optional[str] = host or os.getenv("DB_HOST")
        self.user: Optional[str] = user or os.getenv("DB_USER")
        self.password: Optional[str] = password or os.getenv("DB_PASSWORD")
        self.database: Optional[str] = database or os.getenv("DB_NAME")
        self.driver: Optional[str] = driver or os.getenv("DB_DRIVER")
        self.port: Optional[str] = port or os.getenv("DB_PORT")
        self.oracle_sid: Optional[str] = oracle_sid or os.getenv("ORACLE_SID")
        self._connection: Any = None
        self._pool: Any = None
        # asyncpg Transaction started by begin_transaction() (postgres only).
        self._pg_transaction: Any = None
        # True between begin_transaction() and commit/rollback; suppresses the
        # per-statement commit done by execute()/execute_many().
        self._in_transaction: bool = False

        if not self.db_type or not self.database:
            raise ValueError("Missing required database configuration.")
        if self.db_type not in ("sqlite", "postgres", "mysql", "sqlserver", "oracle"):
            raise ValueError(f"Unsupported database type: {self.db_type!r}")

    # -----------------------------------------------------------------------
    # Connection lifecycle
    # -----------------------------------------------------------------------

    async def connect(self) -> None:
        """
        Open the database connection for the configured backend.

        The driver module is imported inside the matching branch, so only the
        driver for the selected backend has to be installed.

        Raises:
            AsyncConnectionError: If the driver import or the connection
                attempt fails.
        """
        try:
            if self.db_type == "sqlite":
                import aiosqlite

                self._connection = await aiosqlite.connect(self.database or "")  # type: ignore[arg-type]
                logger.info("Connected to SQLite database: %s", self.database)

            elif self.db_type == "postgres":
                import asyncpg

                self._connection = await asyncpg.connect(
                    host=self.host,
                    port=int(self.port or 5432),
                    user=self.user,
                    password=self.password,
                    database=self.database,
                )
                logger.info("Connected to PostgreSQL database: %s", self.database)

            elif self.db_type == "mysql":
                import aiomysql

                self._connection = await aiomysql.connect(
                    host=self.host or "localhost",
                    port=int(self.port or 3306),
                    user=self.user,
                    password=self.password or "",
                    db=self.database,
                    autocommit=False,
                )
                logger.info("Connected to MySQL database: %s", self.database)

            elif self.db_type == "sqlserver":
                import aioodbc

                dsn = (
                    f"DRIVER={self.driver};"
                    f"SERVER={self.host};"
                    f"DATABASE={self.database};"
                    f"UID={self.user};"
                    f"PWD={self.password}"
                )
                self._connection = await aioodbc.connect(dsn=dsn)
                logger.info("Connected to SQL Server database: %s", self.database)

            elif self.db_type == "oracle":
                import oracledb

                # NOTE(review): the Oracle port comes from ORACLE_DB_PORT, not
                # from the `port` argument / DB_PORT; kept for compatibility.
                oracle_port = int(os.getenv("ORACLE_DB_PORT", "1521"))
                dsn = oracledb.makedsn(
                    self.host, oracle_port, sid=self.oracle_sid  # type: ignore[arg-type]
                )
                self._connection = await oracledb.connect_async(
                    user=self.user, password=self.password, dsn=dsn
                )
                logger.info("Connected to Oracle database: %s", self.oracle_sid)

        except Exception as e:
            raise AsyncConnectionError(
                f"Failed to connect to {self.db_type}: {e}"
            ) from e

        # A fresh connection starts outside any explicit transaction.
        self._pg_transaction = None
        self._in_transaction = False

    async def close(self) -> None:
        """
        Close the database connection; a no-op when none is open.

        The helper forgets the connection and any transaction state even if
        closing fails, because a half-closed connection cannot be reused.

        Raises:
            AsyncConnectionError: If the driver reports an error while closing.
        """
        if self._connection is None:
            return
        try:
            if self.db_type == "mysql":
                # aiomysql's Connection.close() is synchronous (returns None) and
                # cannot be awaited; ensure_closed() is its coroutine counterpart
                # that sends QUIT and then closes the socket.
                await self._connection.ensure_closed()
            else:
                await self._connection.close()
        except Exception as e:
            raise AsyncConnectionError(f"Failed to close connection: {e}") from e
        finally:
            self._connection = None
            self._pg_transaction = None
            self._in_transaction = False
        logger.info("Closed %s connection", self.db_type)

    async def __aenter__(self) -> "AsyncSQLPyHelper":
        """Connect on entry to ``async with`` and return the helper."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        """Close the connection on exit; exceptions are never suppressed."""
        await self.close()
        return False

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------

    def _check_connection(self) -> None:
        """Raise AsyncConnectionError if connect() has not been called."""
        if self._connection is None:
            raise AsyncConnectionError(
                "No active connection. Call connect() or use async with."
            )

    def _translate_placeholders(
        self, query: str, arg_count: int
    ) -> Optional[tuple[str, list[int]]]:
        """
        Rewrite portable ``$n`` placeholders into the active driver's style.

        Args:
            query: SQL text that may contain ``$1, $2, ...`` placeholders.
            arg_count: Number of arguments that will be bound.

        Returns:
            None when the query should go to the driver unchanged (PostgreSQL,
            or no ``$n`` present). Otherwise a tuple of the rewritten query and,
            for each placeholder in order of appearance, the zero-based index
            of the argument it refers to.

        Raises:
            AsyncQueryError: If the placeholder numbers are not exactly
                ``1..arg_count``. Every argument must be referenced, which
                matches asyncpg's own requirement.
        """
        if self.db_type not in ("sqlite", "mysql", "sqlserver", "oracle"):
            # asyncpg understands $n natively.
            return None
        numbers = [int(n) for n in self._PLACEHOLDER_RE.findall(query)]
        if not numbers:
            # No portable placeholders: assume native driver style.
            return None
        if set(numbers) != set(range(1, arg_count + 1)):
            raise AsyncQueryError(
                f"Query placeholders {sorted(set(numbers))} do not match "
                f"{arg_count} supplied argument(s)"
            )

        if self.db_type == "oracle":
            # python-oracledb binds a sequence by position, so number the
            # markers by occurrence; arguments are re-ordered to match.
            positions = iter(range(1, len(numbers) + 1))
            adapted = self._PLACEHOLDER_RE.sub(
                lambda _match: f":{next(positions)}", query
            )
        elif self.db_type == "mysql":
            # aiomysql %-formats the query with the arguments, so any literal
            # '%' (e.g. in LIKE 'a%') must be doubled first.
            adapted = self._PLACEHOLDER_RE.sub("%s", query.replace("%", "%%"))
        else:
            # sqlite (aiosqlite) and sqlserver (aioodbc/pyodbc) use qmark "?".
            adapted = self._PLACEHOLDER_RE.sub("?", query)
        return adapted, [number - 1 for number in numbers]

    def _adapt_query(self, query: str, args: tuple) -> tuple[str, tuple]:
        """
        Adapt a query and its positional arguments for the active driver.

        Callers write ``$1, $2, ...``. The result has driver-specific markers,
        and the arguments are laid out in marker order, so repeated or
        out-of-order placeholders bind correctly.
        """
        if not args:
            return query, args
        plan = self._translate_placeholders(query, len(args))
        if plan is None:
            return query, args
        adapted, order = plan
        return adapted, tuple(args[i] for i in order)

    def _adapt_many(self, query: str, args_list: list[tuple]) -> tuple[str, list]:
        """
        Adapt a query plus a list of argument rows for executemany().

        Every row must have the same length as the first.
        """
        plan = self._translate_placeholders(query, len(args_list[0]))
        if plan is None:
            return query, args_list
        adapted, order = plan
        width = len(args_list[0])
        if any(len(row) != width for row in args_list):
            raise AsyncQueryError(
                f"All parameter rows must have {width} value(s)"
            )
        return adapted, [tuple(row[i] for i in order) for row in args_list]

    async def _commit_unless_in_transaction(self) -> None:
        """Commit after a write unless an explicit transaction is open.

        PostgreSQL (asyncpg) needs no explicit commit outside a transaction.
        """
        if self.db_type == "postgres" or self._in_transaction:
            return
        await self._connection.commit()

    async def _select(self, query: str, args: tuple, *, first_only: bool) -> Any:
        """
        Run an already-adapted SELECT on a non-PostgreSQL backend.

        Returns fetchone() when first_only is true, otherwise fetchall(). The
        cursor is always closed. Returns None for an unsupported db_type.
        """
        if self.db_type == "sqlite":
            async with self._connection.execute(query, args) as cursor:
                return await (cursor.fetchone() if first_only else cursor.fetchall())
        if self.db_type in ("mysql", "sqlserver"):
            async with self._connection.cursor() as cursor:
                await cursor.execute(query, args)
                return await (cursor.fetchone() if first_only else cursor.fetchall())
        if self.db_type == "oracle":
            with self._connection.cursor() as cursor:
                await cursor.execute(query, args)
                return await (cursor.fetchone() if first_only else cursor.fetchall())
        return None

    # -----------------------------------------------------------------------
    # Query execution
    # -----------------------------------------------------------------------

    async def execute(self, query: str, *args: Any) -> None:
        """
        Execute a SQL statement (INSERT, UPDATE, DELETE, DDL).

        Use $1, $2, ... for parameterised values::

            await db.execute(
                "INSERT INTO users (id, name) VALUES ($1, $2)", 1, "Alice"
            )

        The statement is committed immediately unless begin_transaction() is active.

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the query fails or the placeholders do not
                match the arguments.
        """
        self._check_connection()
        adapted_query, adapted_args = self._adapt_query(query, args)
        try:
            if self.db_type == "postgres":
                await self._connection.execute(adapted_query, *adapted_args)
            elif self.db_type == "sqlite":
                cursor = await self._connection.execute(adapted_query, adapted_args)
                await cursor.close()
            elif self.db_type in ("mysql", "sqlserver"):
                async with self._connection.cursor() as cursor:
                    await cursor.execute(adapted_query, adapted_args)
            elif self.db_type == "oracle":
                with self._connection.cursor() as cursor:
                    await cursor.execute(adapted_query, adapted_args)
            await self._commit_unless_in_transaction()
            logger.debug("Executed: %s", query)
        except Exception as e:
            raise AsyncQueryError(f"Query failed: {e}") from e

    async def fetch_one(self, query: str, *args: Any) -> Optional[Any]:
        """
        Execute a SELECT query and return a single row, or None.

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Returns:
            A single row, or None if no rows matched.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the query fails.
        """
        self._check_connection()
        adapted_query, adapted_args = self._adapt_query(query, args)
        try:
            if self.db_type == "postgres":
                return await self._connection.fetchrow(adapted_query, *adapted_args)
            return await self._select(adapted_query, adapted_args, first_only=True)
        except Exception as e:
            raise AsyncQueryError(f"fetch_one failed: {e}") from e

    async def fetch_all(self, query: str, *args: Any) -> list[Any]:
        """
        Execute a SELECT query and return all rows.

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Returns:
            All rows as returned by the driver; empty if nothing matched.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the query fails.
        """
        self._check_connection()
        adapted_query, adapted_args = self._adapt_query(query, args)
        try:
            if self.db_type == "postgres":
                return await self._connection.fetch(adapted_query, *adapted_args)
            rows = await self._select(adapted_query, adapted_args, first_only=False)
            return [] if rows is None else rows
        except Exception as e:
            raise AsyncQueryError(f"fetch_all failed: {e}") from e

    async def fetch_val(self, query: str, *args: Any) -> Optional[Any]:
        """
        Execute a SELECT query and return a single scalar value.

        Useful for COUNT, SUM, or any query returning one value::

            count = await db.fetch_val("SELECT COUNT(*) FROM users")

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Returns:
            The first column of the first row, or None if there is no row.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the query fails.
        """
        self._check_connection()
        adapted_query, adapted_args = self._adapt_query(query, args)
        try:
            if self.db_type == "postgres":
                return await self._connection.fetchval(adapted_query, *adapted_args)
            row = await self._select(adapted_query, adapted_args, first_only=True)
            return row[0] if row else None
        except Exception as e:
            raise AsyncQueryError(f"fetch_val failed: {e}") from e

    async def execute_many(self, query: str, args_list: list[tuple]) -> None:
        """
        Execute a SQL statement multiple times with different parameters.

        Efficient for bulk inserts::

            await db.execute_many(
                "INSERT INTO users (id, name) VALUES ($1, $2)",
                [(1, "Alice"), (2, "Bob"), (3, "Charlie")]
            )

        Commits once at the end unless begin_transaction() is active. An empty
        args_list does nothing.

        Args:
            query: SQL query string using $1, $2 placeholders.
            args_list: List of parameter tuples, all of the same length.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the operation fails.
        """
        self._check_connection()
        if not args_list:
            return
        try:
            if self.db_type == "postgres":
                await self._connection.executemany(query, args_list)
            else:
                adapted, rows = self._adapt_many(query, args_list)
                if self.db_type == "sqlite":
                    cursor = await self._connection.executemany(adapted, rows)
                    await cursor.close()
                elif self.db_type in ("mysql", "sqlserver"):
                    async with self._connection.cursor() as cursor:
                        await cursor.executemany(adapted, rows)
                elif self.db_type == "oracle":
                    with self._connection.cursor() as cursor:
                        await cursor.executemany(adapted, rows)
                await self._commit_unless_in_transaction()
            logger.debug("execute_many: %d rows", len(args_list))
        except Exception as e:
            raise AsyncQueryError(f"execute_many failed: {e}") from e

    # -----------------------------------------------------------------------
    # Transaction management
    # -----------------------------------------------------------------------

    async def begin_transaction(self) -> None:
        """
        Begin an explicit transaction.

        Until commit_transaction() or rollback_transaction() is called,
        execute() and execute_many() do not commit.

        For PostgreSQL, asyncpg's connection.transaction() context manager is
        the more idiomatic approach.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the transaction cannot be started.
        """
        self._check_connection()
        try:
            if self.db_type == "sqlite":
                cursor = await self._connection.execute("BEGIN")
                await cursor.close()
            elif self.db_type == "mysql":
                await self._connection.begin()
            elif self.db_type == "sqlserver":
                async with self._connection.cursor() as cursor:
                    await cursor.execute("BEGIN TRANSACTION")
            elif self.db_type == "oracle":
                pass  # Oracle starts transactions implicitly
            elif self.db_type == "postgres":
                transaction = self._connection.transaction()
                await transaction.start()
                # Only remember the transaction once it actually started.
                self._pg_transaction = transaction
            self._in_transaction = True
            logger.info("Transaction started on %s", self.db_type)
        except Exception as e:
            raise AsyncQueryError(f"Failed to begin transaction: {e}") from e

    async def commit_transaction(self) -> None:
        """
        Commit the current transaction and return to per-statement commits.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the commit fails, or (PostgreSQL) no
                transaction was started with begin_transaction().
        """
        self._check_connection()
        try:
            if self.db_type == "postgres":
                if self._pg_transaction is None:
                    raise AsyncQueryError(
                        "No active PostgreSQL transaction; call begin_transaction() first."
                    )
                await self._pg_transaction.commit()
                self._pg_transaction = None
            else:
                await self._connection.commit()
            self._in_transaction = False
            logger.info("Transaction committed on %s", self.db_type)
        except Exception as e:
            raise AsyncQueryError(f"Failed to commit transaction: {e}") from e

    async def rollback_transaction(self) -> None:
        """
        Roll back the current transaction and return to per-statement commits.

        The helper's transaction state is cleared even if the rollback fails.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the rollback fails, or (PostgreSQL) no
                transaction was started with begin_transaction().
        """
        self._check_connection()
        try:
            if self.db_type == "postgres":
                if self._pg_transaction is None:
                    raise AsyncQueryError(
                        "No active PostgreSQL transaction; call begin_transaction() first."
                    )
                await self._pg_transaction.rollback()
            else:
                await self._connection.rollback()
            logger.info("Transaction rolled back on %s", self.db_type)
        except Exception as e:
            raise AsyncQueryError(f"Failed to rollback transaction: {e}") from e
        finally:
            self._pg_transaction = None
            self._in_transaction = False

    # -----------------------------------------------------------------------
    # Connection pooling
    # -----------------------------------------------------------------------

    async def setup_pool(self, min_size: int = 1, max_size: int = 10) -> None:
        """
        Set up an async connection pool.

        Supported for PostgreSQL and MySQL only. After calling this, use
        get_connection_from_pool() to acquire connections.

        Args:
            min_size: Minimum number of connections in the pool.
            max_size: Maximum number of connections in the pool.

        Raises:
            AsyncConnectionError: If pool setup fails or db_type does not
                support pooling.
        """
        try:
            if self.db_type == "postgres":
                import asyncpg

                self._pool = await asyncpg.create_pool(
                    host=self.host,
                    port=int(self.port or 5432),
                    user=self.user,
                    password=self.password,
                    database=self.database,
                    min_size=min_size,
                    max_size=max_size,
                )
                logger.info(
                    "PostgreSQL async pool created (min=%d, max=%d)", min_size, max_size
                )

            elif self.db_type == "mysql":
                import aiomysql

                self._pool = await aiomysql.create_pool(
                    host=self.host or "localhost",
                    port=int(self.port or 3306),
                    user=self.user,
                    password=self.password or "",
                    db=self.database,
                    minsize=min_size,
                    maxsize=max_size,
                )
                logger.info(
                    "MySQL async pool created (min=%d, max=%d)", min_size, max_size
                )

            else:
                raise AsyncConnectionError(
                    f"Async connection pooling not supported for {self.db_type!r}. "
                    "Supported: postgres, mysql."
                )

        except AsyncConnectionError:
            raise
        except Exception as e:
            raise AsyncConnectionError(f"Failed to create async pool: {e}") from e

    async def close_pool(self) -> None:
        """Close the async connection pool, if one was created."""
        if self._pool is not None:
            if self.db_type == "mysql":
                # aiomysql: close() is synchronous; wait_closed() awaits shutdown.
                self._pool.close()
                await self._pool.wait_closed()
            else:
                await self._pool.close()
            self._pool = None
            logger.info("Async pool closed for %s", self.db_type)