from typing import Dict, List, Optional, Tuple
from uuid import UUID

from fastapi_async_sqlalchemy import db
from sqlalchemy import Integer, String, and_, cast, func, or_, select, update
from sqlalchemy.orm import selectinload

from app.models import (
    ClassificationModel,
    MapAccessModel,
    MapsetModel,
    OrganizationModel,
)
from app.schemas.user_schema import UserSchema

from . import BaseRepository


class MapsetRepository(BaseRepository[MapsetModel]):
    """Query and update helpers for mapsets, including access filtering."""

    # Role names that skip the per-mapset access filters in both listing
    # methods. NOTE(review): the spelling "data_validator" follows find_all;
    # confirm it matches the stored role name.
    _PRIVILEGED_ROLES = frozenset({"administrator", "data_validator"})

    def __init__(self, model):
        super().__init__(model)

    @classmethod
    def _is_privileged(cls, user: UserSchema) -> bool:
        """Return True when the user's role name is in ``_PRIVILEGED_ROLES``."""
        return user.role.name in cls._PRIVILEGED_ROLES

    async def find_all(
        self,
        user: Optional[UserSchema] = None,
        filters: Optional[list] = None,
        sort: list = ...,
        search: str = "",
        group_by: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
        landing: bool = False,
    ) -> Tuple[List[MapsetModel], int]:
        """Return one page of mapsets plus the total number of matches.

        Args:
            user: Requesting user, or ``None`` for an anonymous request.
            filters: Extra SQLAlchemy criteria applied to the query.
            sort: ORDER BY expressions; when empty or left at the ``...``
                default, rows are ordered by ``model.order`` ascending.
            search: When non-empty, matched case-insensitively as a substring
                against every mapset column (cast to string) and against the
                producer organization's name.
            group_by: Optional model attribute name to GROUP BY.
            limit: Maximum number of rows returned.
            offset: Number of rows skipped.
            landing: When true, apply the anonymous filters even if ``user``
                is given.

        Returns:
            ``(rows, total)`` where ``total`` is counted before ordering,
            limit and offset are applied.
        """
        base_query = select(self.model).distinct()
        base_query = base_query.join(
            ClassificationModel,
            self.model.classification_id == ClassificationModel.id,
        )

        restricted = user is not None and not self._is_privileged(user)
        if restricted:
            # Outer join so mapsets without any access row are not dropped
            # by the join itself.
            base_query = base_query.join(
                MapAccessModel,
                self.model.id == MapAccessModel.mapset_id,
                isouter=True,
            )

        if user is None or landing:
            # Anonymous / landing view: open, active and approved only.
            base_query = base_query.filter(
                ClassificationModel.is_open.is_(True),
                self.model.is_active.is_(True),
                self.model.status_validation == "approved",
            )
        elif restricted:
            # NOTE(review): classification (open/limited/secret) is not part
            # of this filter, unlike find_all_group_by_organization; confirm
            # that the difference is intended.
            access_conditions = [MapAccessModel.user_id == user.id]
            org_id = user.organization.id if user.organization else None
            if org_id is not None:
                # Only compare against the organization when there is one;
                # comparing with None would render as "IS NULL".
                access_conditions.append(self.model.producer_id == org_id)
                access_conditions.append(MapAccessModel.organization_id == org_id)
            base_query = base_query.filter(or_(*access_conditions))

        if filters:
            base_query = base_query.filter(*filters)

        if search:
            pattern = f"%{search}%"
            column_matches = [
                cast(getattr(self.model, col), String).ilike(pattern)
                for col in self.model.__table__.columns.keys()
            ]
            base_query = base_query.join(
                OrganizationModel,
                self.model.producer_id == OrganizationModel.id,
                isouter=True,
            ).filter(or_(*column_matches, OrganizationModel.name.ilike(pattern)))

        if group_by:
            base_query = base_query.group_by(getattr(self.model, group_by))

        # Count before ordering and pagination are applied.
        count_query = select(func.count()).select_from(base_query.subquery())
        total = await db.session.scalar(count_query)

        if sort is ... or not sort:
            base_query = base_query.order_by(self.model.order.asc())
        else:
            base_query = base_query.order_by(*sort)

        base_query = base_query.limit(limit).offset(offset)
        result = await db.session.execute(base_query)
        return result.scalars().all(), total

    async def find_all_group_by_organization(
        self,
        user: Optional[UserSchema] = None,
        mapset_filters: Optional[list] = None,
        organization_filters: Optional[list] = None,
        sort: Optional[list] = None,
        search: str = "",
        limit: int = 100,
        offset: int = 0,
    ) -> Tuple[List[Dict], int]:
        """Return one page of producer organizations with their mapsets.

        Args:
            user: Requesting user, or ``None`` for an anonymous request.
            mapset_filters: Extra criteria applied to the mapset query.
            organization_filters: Extra criteria applied to the organization
                query.
            sort: ORDER BY expressions for organizations; defaults to the
                organization name ascending.
            search: When non-empty, matched case-insensitively as a substring
                against every mapset column and every organization column.
            limit: Page size; a falsy value disables the limit.
            offset: Rows to skip; a falsy value disables the offset.

        Returns:
            ``(items, total)`` where each item is a dict with keys ``id``,
            ``name``, ``mapsets`` and ``found``, and ``total`` is the number
            of matching organizations before pagination.
        """
        mapset_filters = mapset_filters or []
        organization_filters = organization_filters or []
        sort = sort or [OrganizationModel.name.asc()]

        if user is None:
            # Anonymous: open-classified mapsets only.
            base_mapset_query = (
                select(self.model)
                .join(
                    ClassificationModel,
                    self.model.classification_id == ClassificationModel.id,
                )
                .filter(ClassificationModel.is_open.is_(True))
            )
        elif self._is_privileged(user):
            base_mapset_query = select(self.model)
        else:
            user_org_id = user.organization.id if user.organization else None

            # Ways a non-privileged user may see a secret mapset.
            secret_access = [MapAccessModel.user_id == user.id]
            if user_org_id is not None:
                # producer_id refers to an organization, so it is compared
                # with the user's organization id rather than the user id.
                secret_access.append(self.model.producer_id == user_org_id)
                secret_access.append(MapAccessModel.organization_id == user_org_id)

            base_mapset_query = (
                select(self.model)
                .join(
                    ClassificationModel,
                    self.model.classification_id == ClassificationModel.id,
                )
                .outerjoin(MapAccessModel, self.model.id == MapAccessModel.mapset_id)
                .filter(
                    or_(
                        ClassificationModel.is_open.is_(True),
                        ClassificationModel.is_limited.is_(True),
                        and_(
                            ClassificationModel.is_secret.is_(True),
                            or_(*secret_access),
                        ),
                    )
                )
            )

        filtered_mapset_query = base_mapset_query
        if mapset_filters:
            filtered_mapset_query = filtered_mapset_query.filter(*mapset_filters)

        pattern = f"%{search}%"
        if search:
            mapset_matches = [
                cast(getattr(self.model, col), String).ilike(pattern)
                for col in self.model.__table__.columns.keys()
                if hasattr(self.model, col)
            ]
            if mapset_matches:
                filtered_mapset_query = filtered_mapset_query.filter(or_(*mapset_matches))

        # Reuse the filtered query itself (joins and criteria included) and
        # project only producer_id. Selecting the model column from a
        # separate subquery would add the mapset table to FROM a second time,
        # unjoined, and ignore the filters.
        producer_ids_query = filtered_mapset_query.with_only_columns(
            self.model.producer_id
        ).distinct()

        org_query = select(OrganizationModel).filter(
            OrganizationModel.id.in_(producer_ids_query)
        )
        if organization_filters:
            org_query = org_query.filter(*organization_filters)

        if search:
            # NOTE(review): the search term is applied to the mapset query
            # above and to the organization query here, so an organization is
            # listed only when both match; confirm that this is intended.
            org_matches = [
                cast(getattr(OrganizationModel, col), String).ilike(pattern)
                for col in OrganizationModel.__table__.columns.keys()
                if hasattr(OrganizationModel, col)
            ]
            if org_matches:
                org_query = org_query.filter(or_(*org_matches))

        # Count the filtered organization rows before pagination.
        count_query = select(func.count()).select_from(org_query.subquery())
        total = await db.session.scalar(count_query)

        org_query = org_query.order_by(*sort)
        if limit:
            org_query = org_query.limit(limit)
        if offset:
            org_query = org_query.offset(offset)

        org_result = await db.session.execute(org_query)
        organizations = org_result.scalars().unique().all()
        org_ids = [org.id for org in organizations]
        if not org_ids:
            return [], total

        all_mapsets_query = filtered_mapset_query.filter(
            self.model.producer_id.in_(org_ids)
        ).options(selectinload(self.model.classification))
        all_mapsets_result = await db.session.execute(all_mapsets_query)
        # unique() collapses duplicate rows produced by the access outer join.
        all_mapsets = all_mapsets_result.scalars().unique().all()

        mapsets_by_org: dict = {}
        for mapset in all_mapsets:
            mapsets_by_org.setdefault(mapset.producer_id, []).append(mapset)

        result_data = []
        for org in organizations:
            org_mapsets = mapsets_by_org.get(org.id, [])
            result_data.append(
                {
                    "id": org.id,
                    "name": org.name,
                    "mapsets": org_mapsets,
                    "found": len(org_mapsets),
                }
            )
        return result_data, total

    async def bulk_update_activation(self, mapset_ids: List[UUID], is_active: bool) -> None:
        """Set ``is_active`` on every mapset in ``mapset_ids`` and commit."""
        if mapset_ids:
            # One UPDATE ... WHERE id IN (...) instead of one statement per id.
            await db.session.execute(
                update(self.model)
                .where(self.model.id.in_(mapset_ids))
                .values(is_active=is_active)
            )
        await db.session.commit()

    async def _increment_counter(self, mapset_id: UUID, column_name: str) -> None:
        """Add one to the named counter column of a mapset and commit."""
        counter = getattr(self.model, column_name)
        query = (
            update(self.model)
            .where(self.model.id == mapset_id)
            # The increment is expressed in SQL (column + 1), not computed
            # from a value loaded into Python.
            .values(**{column_name: counter + 1})
        )
        await db.session.execute(query)
        await db.session.commit()

    async def increment_view_count(self, mapset_id: UUID) -> None:
        """Increment ``view_count`` of the given mapset by one and commit."""
        await self._increment_counter(mapset_id, "view_count")

    async def increment_download_count(self, mapset_id: UUID) -> None:
        """Increment ``download_count`` of the given mapset by one and commit."""
        await self._increment_counter(mapset_id, "download_count")