from app.models import CategoryModel, MapsetModel, ClassificationModel
from sqlalchemy import func, or_, cast, String, select
from sqlalchemy.orm import RelationshipProperty, joinedload, selectinload
from fastapi_async_sqlalchemy import db
from typing import List, Tuple, Optional
from uuid import UUID

from . import BaseRepository


def _visible_mapsets_select(*columns):
    """Build a SELECT of *columns* over the mapsets that are counted as visible.

    A mapset is counted when it is not deleted, is active, has
    ``status_validation == "approved"`` and belongs to a classification with
    ``is_open`` set. The classification table is inner-joined, so mapsets
    without a matching classification are excluded too.
    """
    # ``== False`` / ``== True`` build SQL comparisons; they are kept (rather
    # than ``.is_()``) so the generated SQL is unchanged.
    return (
        select(*columns)
        .join(ClassificationModel, MapsetModel.classification_id == ClassificationModel.id)
        .where(
            MapsetModel.is_deleted == False,  # noqa: E712
            MapsetModel.is_active == True,  # noqa: E712
            MapsetModel.status_validation == "approved",
            ClassificationModel.is_open == True,  # noqa: E712
        )
    )


class CategoryRepository(BaseRepository[CategoryModel]):
    """Category repository that annotates results with a visible-mapset count.

    Every category returned by ``find_by_id`` / ``find_all`` gets a
    ``count_mapset`` attribute: the number of mapsets in that category that
    match ``_visible_mapsets_select``.
    """

    def __init__(self, model):
        super().__init__(model)

    def _with_relationship_loaders(self, query, relationships: Optional[List[str]]):
        """Return *query* with eager-load options for the named relationships.

        Names that are not relationships on ``self.model`` (unknown names,
        plain columns, methods) are ignored. Collection relationships use
        ``selectinload``; relationships known to be scalar use ``joinedload``.
        """
        for name in relationships or ():
            attr = getattr(self.model, name, None)
            prop = getattr(attr, "property", None)
            if not isinstance(prop, RelationshipProperty):
                continue
            # Every RelationshipProperty has a ``collection_class`` attribute,
            # so it cannot tell collections from scalars; ``uselist`` can.
            # ``uselist`` may still be None before mappers are configured, so
            # anything but an explicit False uses selectinload, which is valid
            # for both kinds.
            loader = joinedload if prop.uselist is False else selectinload
            query = query.options(loader(attr))
        return query

    def _search_conditions(self, search: str, searchable_columns: Optional[List[str]]):
        """Build one case-insensitive substring match of *search* per column.

        Uses *searchable_columns* when given, otherwise every table column
        whose name does not start with ``_``; names that are not attributes of
        the model are skipped. Columns are cast to ``String`` so non-text
        columns can be matched as well.

        NOTE(review): ``%`` and ``_`` inside *search* are not escaped and act
        as LIKE wildcards.
        """
        if searchable_columns:
            names = searchable_columns
        else:
            names = [
                col for col in self.model.__table__.columns.keys() if not col.startswith("_")
            ]
        pattern = f"%{search}%"
        return [
            cast(getattr(self.model, name), String).ilike(pattern)
            for name in names
            if hasattr(self.model, name)
        ]

    async def find_by_id(
        self, id: UUID, relationships: Optional[List[str]] = None
    ) -> Optional[CategoryModel]:
        """Find a category by ID, with its visible-mapset count.

        Args:
            id: Primary key of the category.
            relationships: Relationship names to eager-load; names that are not
                relationships on the model are ignored.

        Returns:
            The category with ``count_mapset`` set, or ``None`` when no row
            matches (when the model has ``is_deleted``, soft-deleted rows do
            not match).
        """
        # Filtered by the bound ``id`` value, not correlated to the outer row.
        mapset_count = (
            _visible_mapsets_select(func.count(MapsetModel.id))
            .where(MapsetModel.category_id == id)
            .scalar_subquery()
        )
        query = select(self.model, mapset_count.label("mapset_count")).where(
            self.model.id == id
        )
        if hasattr(self.model, "is_deleted"):
            query = query.where(self.model.is_deleted.is_(False))
        query = self._with_relationship_loaders(query, relationships)

        row = (await db.session.execute(query)).first()
        if row is None:
            return None
        category, count = row[0], row[1]
        category.count_mapset = count if count is not None else 0
        return category

    async def find_all(
        self,
        filters: Optional[list] = None,
        sort: Optional[list] = None,
        search: str = "",
        group_by: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
        relationships: Optional[List[str]] = None,
        searchable_columns: Optional[List[str]] = None,
    ) -> Tuple[List[CategoryModel], int]:
        """List categories with their visible-mapset counts.

        Args:
            filters: SQLAlchemy criteria applied with ``.filter``.
            sort: ORDER BY clauses; defaults to ``self.model.order`` ascending.
            search: Substring matched case-insensitively (see
                ``_search_conditions``); empty disables search.
            group_by: Name of a model attribute to GROUP BY.
            limit: Maximum number of rows returned.
            offset: Number of rows skipped.
            relationships: Relationship names to eager-load.
            searchable_columns: Columns searched; defaults to all table columns
                not starting with ``_``.

        Returns:
            ``(records, total)`` where ``total`` counts all matching rows
            before ``limit``/``offset`` are applied.

        NOTE(review): unlike ``find_by_id``, soft-deleted categories are not
        excluded here unless *filters* does so — confirm this is intended.
        """
        counts = (
            _visible_mapsets_select(
                MapsetModel.category_id,
                func.count(MapsetModel.id).label("mapset_count"),
            )
            .group_by(MapsetModel.category_id)
            .subquery()
        )
        # Outer join + coalesce so categories without visible mapsets get 0.
        query = (
            select(self.model, func.coalesce(counts.c.mapset_count, 0).label("mapset_count"))
            .outerjoin(counts, self.model.id == counts.c.category_id)
            .filter(*(filters or ()))
        )

        if search:
            conditions = self._search_conditions(search, searchable_columns)
            if conditions:
                query = query.where(or_(*conditions))

        if group_by:
            query = query.group_by(getattr(self.model, group_by))

        # Total is computed before ordering, eager loading and pagination.
        total = await db.session.scalar(select(func.count()).select_from(query.subquery()))

        if sort:
            query = query.order_by(*sort)
        else:
            query = query.order_by(self.model.order.asc())
        query = self._with_relationship_loaders(query, relationships)
        query = query.limit(limit).offset(offset)

        result = await db.session.execute(query)
        records = []
        for category, count in result:
            category.count_mapset = count
            records.append(category)
        return records, total