# Source code for rail_pz_service.db.row

"""Mixin functionality for Database tables"""

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, TypeVar

from pydantic import BaseModel, TypeAdapter
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, StatementError
from sqlalchemy.ext.asyncio import async_scoped_session
from structlog import get_logger

from ..common.errors import (
    RAILIDMismatchError,
    RAILIntegrityError,
    RAILMissingIDError,
    RAILMissingNameError,
    RAILStatementError,
)

logger = get_logger(__name__)

T = TypeVar("T", bound="RowMixin")


class RowMixin:
    """Mixin class to define common features of database rows
    for all the tables we use in rail_server

    Here we are just defining the interface to manipulate
    any sort of table.
    """

    id: Any  # Primary Key, typically an int
    name: Any  # Human-readable name for row
    class_string: str  # Name to use for help functions and descriptions
    pydantic_mode_class: type[BaseModel]  # Pydantic model class used by to_model

    @classmethod
    async def get_rows(
        cls: type[T],
        session: async_scoped_session,
        skip: int = 0,
        limit: int = 100,
    ) -> Sequence[T]:
        """Get rows associated to a particular table

        Rows are ordered by primary key so that ``skip`` / ``limit``
        pagination is stable from one call to the next.

        Parameters
        ----------
        session
            DB session manager

        skip
            Number of rows to skip before returning results

        limit
            Number of rows to return

        Returns
        -------
        Sequence[T]
            All the matching rows
        """
        # OFFSET/LIMIT without ORDER BY leaves row order up to the database,
        # so successive pages could overlap or silently skip rows.
        q = select(cls).order_by(cls.id).offset(skip).limit(limit)
        results = await session.scalars(q)
        return results.all()

    @classmethod
    async def get_row(
        cls: type[T],
        session: async_scoped_session,
        row_id: int,
    ) -> T:
        """Get a single row, matching row.id == row_id

        Parameters
        ----------
        session
            DB session manager

        row_id
            PrimaryKey of the row to return

        Returns
        -------
        T
            The matching row

        Raises
        ------
        RAILMissingIDError
            Row with ID does not exist
        """
        result = await session.get(cls, row_id)
        if result is None:
            raise RAILMissingIDError(f"{cls} {row_id} not found")
        return result

    @classmethod
    async def get_row_by_name(
        cls: type[T],
        session: async_scoped_session,
        name: str,
    ) -> T:
        """Get a single row, with row.name == name

        If several rows share the name, the first one returned by the
        query is used.

        Parameters
        ----------
        session
            DB session manager

        name
            name of the row to return

        Returns
        -------
        T
            Matching row

        Raises
        ------
        RAILMissingNameError
            Row with name does not exist
        """
        query = select(cls).where(cls.name == name)
        rows = await session.scalars(query)
        row = rows.first()
        if row is None:
            raise RAILMissingNameError(f"{cls} {name} not found")
        return row

    @classmethod
    async def delete_row(
        cls,
        session: async_scoped_session,
        row_id: int,
    ) -> None:
        """Delete a single row, matching row.id == row_id

        After the row is deleted, ``_delete_hook`` is called so that
        sub-classes can perform extra cleanup.

        Parameters
        ----------
        session
            DB session manager

        row_id
            PrimaryKey of the row to delete

        Raises
        ------
        RAILMissingIDError
            Row does not exist

        RAILIntegrityError
            sqlalchemy.IntegrityError raised while flushing the delete
        """
        row = await session.get(cls, row_id)
        if row is None:
            raise RAILMissingIDError(f"{cls} {row_id} not found")
        try:
            # session.delete() only marks the row for deletion; constraint
            # violations surface when the change is flushed. Running it in a
            # savepoint (as create_row / update_row do) forces that flush here,
            # so the IntegrityError is actually caught and translated.
            async with session.begin_nested():
                await session.delete(row)
        except IntegrityError as msg:
            if TYPE_CHECKING:
                assert msg.orig  # for mypy
            raise RAILIntegrityError(msg) from msg
        await cls._delete_hook(session, row_id)

    @classmethod
    async def _delete_hook(
        cls,
        session: async_scoped_session,  # pylint: disable=unused-argument
        row_id: int,  # pylint: disable=unused-argument
    ) -> None:
        """Hook called during delete_row

        The default implementation does nothing; sub-classes may override it.

        Parameters
        ----------
        session
            DB session manager

        row_id
            PrimaryKey of the row to delete
        """
        return

    @classmethod
    async def update_row(
        cls: type[T],
        session: async_scoped_session,
        row_id: int,
        **kwargs: Any,
    ) -> T:
        """Update a single row, matching row.id == row_id

        Dict-valued keyword arguments are merged into the row's existing
        dict (a missing / None value is treated as an empty dict) instead
        of replacing it. All other values are assigned directly.

        Parameters
        ----------
        session
            DB session manager

        row_id
            PrimaryKey of the row to return

        **kwargs
            Columns and associated new values

        Returns
        -------
        T:
            Updated row

        Raises
        ------
        RAILIDMismatchError
            ID mismatch between row IDs

        RAILMissingIDError
            Could not find row

        RAILStatementError
            catching a sqlalchemy StatementError (this includes
            IntegrityError) raised while applying the update
        """
        if kwargs.get("id", row_id) != row_id:
            raise RAILIDMismatchError("ID mismatch between URL and body")
        row = await session.get(cls, row_id)
        if row is None:
            raise RAILMissingIDError(f"{cls} {row_id} not found")
        try:
            async with session.begin_nested():
                for var, value in kwargs.items():
                    if isinstance(value, dict):  # pragma: no cover
                        current = getattr(row, var)
                        # Copy so the assignment is seen as a change to the
                        # attribute; a column that was never set starts empty.
                        the_dict = {} if current is None else current.copy()
                        the_dict.update(value)
                        setattr(row, var, the_dict)
                    else:
                        setattr(row, var, value)
        except StatementError as msg:
            if TYPE_CHECKING:
                assert msg.orig  # for mypy
            raise RAILStatementError(msg) from msg
        return row

    @classmethod
    async def create_row(
        cls: type[T],
        session: async_scoped_session,
        **kwargs: Any,
    ) -> T:
        """Create a single row

        Parameters
        ----------
        session
            DB session manager

        **kwargs: Any
            Columns and associated values for the new row

        Returns
        -------
        T
            Newly created row

        Raises
        ------
        RAILIntegrityError
            catching a IntegrityError
        """
        create_kwargs = await cls.get_create_kwargs(session, **kwargs)
        row = cls(**create_kwargs)
        try:
            async with session.begin_nested():
                session.add(row)
        except IntegrityError as msg:
            if TYPE_CHECKING:
                assert msg.orig  # for mypy
            raise RAILIntegrityError(msg) from msg
        # Reload the row so DB-generated values (e.g. the primary key) are set.
        await session.refresh(row)
        return row

    @classmethod
    async def get_create_kwargs(
        cls: type[T],
        session: async_scoped_session,  # pylint: disable=unused-argument
        **kwargs: Any,
    ) -> dict:
        """Get additional keywords needed to create a row

        This should be overridden by sub-classes as needed

        The default is to just return the original keywords

        Parameters
        ----------
        session
            DB session manager

        **kwargs
            Columns and associated values for the new row

        Returns
        -------
        dict
            Keywords needed to create a new row
        """
        return kwargs

    async def update_values(
        self: T,
        session: async_scoped_session,
        **kwargs: Any,
    ) -> T:
        """Update values in a row

        Thin wrapper around ``update_row`` using this row's ``id``.

        Parameters
        ----------
        session
            DB session manager

        **kwargs
            Columns and associated new values

        Returns
        -------
        T
            Updated row

        Raises
        ------
        RAILIDMismatchError
            ``id`` in kwargs does not match this row

        RAILMissingIDError
            Row could not be found

        RAILStatementError
            Catching a sqlalchemy StatementError (including IntegrityError)
        """
        return await self.update_row(session, self.id, **kwargs)

    def to_model(self) -> BaseModel:
        """Return a row as a pydantic model

        Validates this ORM row against ``pydantic_mode_class``.
        """
        return_obj = TypeAdapter(self.pydantic_mode_class).validate_python(self)
        return return_obj