from typing import Any, Dict, List, Optional, Tuple, override
from uuid import UUID

from fastapi_async_sqlalchemy import db
from sqlalchemy import (
    Integer,
    Numeric,
    String,
    Unicode,
    UnicodeText,
    and_,
    case,
    cast,
    desc,
    exists,
    func,
    or_,
    select,
    text,
    update as sqlalchemy_update,
)

from app.models.classification_model import ClassificationModel
from app.models.map_access_model import MapAccessModel
from app.models.mapset_model import MapsetModel
from app.models.organization_model import OrganizationModel
from app.schemas.user_schema import UserSchema

from . import BaseRepository


class OrganizationRepository(BaseRepository[OrganizationModel]):
    """Data access for organizations, including per-user counts of visible mapsets."""

    # Roles that see every active, non-deleted mapset regardless of classification
    # (outside of landing-page queries).
    _PRIVILEGED_ROLES = frozenset({"administrator", "data_validator"})

    def __init__(self, model, mapset_model: MapsetModel):
        super().__init__(model)
        self.mapset_model = mapset_model

    async def flag_delete_organization(self, id: UUID) -> bool:
        """Soft-delete an organization by setting its ``is_deleted`` flag to True.

        Args:
            id: Primary key of the organization.

        Returns:
            True if a row was flagged, False if no organization has ``id``.
        """
        query = (
            sqlalchemy_update(self.model)
            .where(self.model.id == id)
            .values(is_deleted=True)
            .execution_options(synchronize_session="fetch")
        )
        result = await db.session.execute(query)
        await db.session.commit()
        return result.rowcount > 0

    async def find_by_name(self, name: str, sensitive: bool = False):
        """Return the organization whose name equals ``name``, or None.

        Args:
            name: Organization name to look up.
            sensitive: When False (default) the comparison ignores case; when
                True the name must match exactly, including case.

        Returns:
            The matching ``OrganizationModel`` instance, or None.
        """
        if sensitive:
            condition = self.model.name == name
        else:
            condition = func.lower(self.model.name) == name.lower()
        # limit(1): names differing only in case must not make the
        # case-insensitive lookup raise MultipleResultsFound.
        query = select(self.model).where(condition).limit(1)
        result = await db.session.execute(query)
        return result.scalar_one_or_none()

    def _summary_columns(self, mapset_count) -> tuple:
        """Return the organization columns selected by list/detail queries.

        ``mapset_count`` is the labelled aggregate placed among them.
        """
        return (
            self.model.id,
            self.model.name,
            self.model.description,
            self.model.thumbnail,
            self.model.address,
            self.model.phone_number,
            self.model.email,
            self.model.website,
            mapset_count,
            self.model.is_active,
            self.model.is_deleted,
            self.model.created_at,
            self.model.modified_at,
        )

    def _visible_mapset_condition(self, user: UserSchema | None, landing: bool = False):
        """Build the per-row predicate selecting mapsets that ``user`` may see.

        Rules:
        * anonymous user or ``landing``: active, not deleted, approved, open classification;
        * administrator / data_validator: active and not deleted;
        * any other user: active, not deleted, and a limited or open classification, or
          a secret one that the user's organization produced or was granted via
          ``MapAccessModel`` (to the organization or to the user).

        The predicate references ``self.mapset_model`` and ``ClassificationModel``,
        so the enclosing query must outer-join both. It is meant to be used inside an
        aggregate, not in WHERE, so organizations are never dropped from results.
        """
        active = and_(
            self.mapset_model.is_active.is_(True),
            self.mapset_model.is_deleted.is_(False),
        )
        if user is None or landing:
            return and_(
                active,
                self.mapset_model.status_validation == "approved",
                ClassificationModel.is_open.is_(True),
            )
        if user.role in self._PRIVILEGED_ROLES:
            return active

        # Correlated EXISTS instead of a join, so multiple access rows for one
        # mapset cannot multiply it in the count.
        has_access_grant = exists().where(
            MapAccessModel.mapset_id == self.mapset_model.id,
            or_(
                MapAccessModel.organization_id == user.organization.id,
                MapAccessModel.user_id == user.id,
            ),
        )
        return and_(
            active,
            or_(
                ClassificationModel.is_limited.is_(True),
                ClassificationModel.is_open.is_(True),
                and_(
                    ClassificationModel.is_secret.is_(True),
                    or_(
                        self.mapset_model.producer_id == user.organization.id,
                        has_access_grant,
                    ),
                ),
            ),
        )

    def _visible_mapset_count(self, user: UserSchema | None, landing: bool = False):
        """Return ``count_mapset``: the number of joined mapsets visible to ``user``.

        CASE yields NULL for hidden mapsets (and for the NULL row of an
        organization without mapsets), and COUNT skips NULLs.
        """
        visible = self._visible_mapset_condition(user, landing)
        return func.count(case((visible, self.mapset_model.id))).label("count_mapset")

    def _with_mapset_joins(self, query):
        """Outer-join mapsets and their classification onto an organization query."""
        return query.outerjoin(
            self.mapset_model, self.model.id == self.mapset_model.producer_id
        ).outerjoin(
            ClassificationModel,
            self.mapset_model.classification_id == ClassificationModel.id,
        )

    def _search_conditions(self, search: str) -> list:
        """Build OR-able predicates matching ``search`` against the model's columns.

        Text columns use a case-insensitive substring match. Integer/Numeric
        columns are compared as text, only when ``search`` parses as a number.
        """
        try:
            number = float(search)
        except (ValueError, TypeError):
            number_text = None
        else:
            # "5" must match integer 5; str(5.0) would be "5.0" and never match.
            number_text = str(int(number)) if number.is_integer() else str(number)

        conditions = []
        for key in self.model.__table__.columns.keys():
            column = getattr(self.model, key)
            if isinstance(column.type, (String, Unicode, UnicodeText)):
                conditions.append(column.ilike(f"%{search}%"))
            elif number_text is not None and isinstance(column.type, (Integer, Numeric)):
                conditions.append(cast(column, String) == number_text)
        return conditions

    async def find_all(
        self,
        user: UserSchema | None,
        filters: list,
        sort: list | None = None,
        search: str = "",
        group_by: str | None = None,
        limit: int = 100,
        offset: int = 0,
        landing: bool = False,
    ) -> Tuple[List[OrganizationModel], int]:
        """Find organizations with their visible-mapset counts, with pagination.

        Every non-deleted organization matching ``filters`` and ``search`` is
        returned, including those with no visible mapsets (``count_mapset`` is 0).

        Args:
            user: Requesting user, or None for anonymous access.
            filters: Extra SQLAlchemy predicates applied to the organization query.
            sort: ORDER BY clauses; defaults to mapset count descending.
            search: Free-text search over the model's text (and numeric) columns.
            group_by: Optional extra model attribute added to GROUP BY.
            limit: Page size.
            offset: Number of rows to skip.
            landing: Use the anonymous visibility rules regardless of ``user``.

        Returns:
            ``(rows, total)``: row mappings for the page, and the number of
            matching organizations before pagination.
        """
        mapset_count = self._visible_mapset_count(user, landing)
        base_query = self._with_mapset_joins(
            select(*self._summary_columns(mapset_count)).select_from(self.model)
        )
        count_query = select(func.count(self.model.id)).select_from(self.model)

        # Organization-level predicates only; both queries get the same set so
        # ``total`` agrees with the listed rows.
        org_conditions = []
        if hasattr(self.model, "is_deleted"):
            org_conditions.append(self.model.is_deleted.is_(False))
        if filters:
            org_conditions.extend(filters)
        if search:
            search_filters = self._search_conditions(search)
            if search_filters:
                org_conditions.append(or_(*search_filters))
        if org_conditions:
            base_query = base_query.where(*org_conditions)
            count_query = count_query.where(*org_conditions)

        group_columns = [self.model.id]
        if group_by and hasattr(self.model, group_by):
            group_col = getattr(self.model, group_by)
            if group_col not in group_columns:
                group_columns.append(group_col)
        base_query = base_query.group_by(*group_columns)

        if sort:
            base_query = base_query.order_by(*sort)
        else:
            # id tie-breaker keeps page boundaries stable among equal counts.
            base_query = base_query.order_by(desc(mapset_count), self.model.id)
        base_query = base_query.limit(limit).offset(offset)

        total = await db.session.scalar(count_query)
        result = await db.session.execute(base_query)
        return result.mappings().all(), total

    @override
    async def update(
        self, id: UUID, data: Dict[str, Any], refresh: bool = True
    ) -> Optional[OrganizationModel]:
        """Update an organization's fields, ignoring keys whose value is None.

        Args:
            id: Primary key of the organization.
            data: Field values to set; entries with a None value are skipped.
            refresh: When True, return the row re-read via ``find_by_id`` with
                ``user=None`` (anonymous visibility rules for ``count_mapset``).

        Returns:
            The refreshed row mapping when ``refresh`` is True, otherwise None.
            None as well when no organization has ``id``.
        """
        clean_data = {k: v for k, v in data.items() if v is not None}
        if not clean_data:
            return await self.find_by_id(None, id) if refresh else None

        query = (
            sqlalchemy_update(self.model)
            .where(self.model.id == id)
            .values(**clean_data)
            .execution_options(synchronize_session="fetch")
        )
        result = await db.session.execute(query)
        await db.session.commit()

        if result.rowcount == 0:
            return None
        return await self.find_by_id(None, id) if refresh else None

    @override
    async def find_by_id(self, user: UserSchema | None, id: UUID) -> Optional[OrganizationModel]:
        """Return one non-deleted organization with its visible-mapset count.

        Uses the same visibility rules as ``find_all``. An existing organization
        is returned even if none of its mapsets are visible (``count_mapset`` 0).

        Args:
            user: Requesting user, or None for anonymous access.
            id: Primary key of the organization.

        Returns:
            A row mapping with the organization columns and ``count_mapset``,
            or None if the organization does not exist or is deleted.
        """
        mapset_count = self._visible_mapset_count(user)
        query = self._with_mapset_joins(
            select(*self._summary_columns(mapset_count)).select_from(self.model)
        ).where(self.model.id == id)
        if hasattr(self.model, "is_deleted"):
            query = query.where(self.model.is_deleted.is_(False))
        query = query.group_by(self.model.id)

        result = await db.session.execute(query)
        return result.mappings().one_or_none()