satupeta-main/app/repositories/mapset_repository.py

259 lines
9.2 KiB
Python
Raw Normal View History

2026-01-27 02:11:58 +00:00
from ast import Dict
from typing import List, Optional, Tuple
from uuid import UUID
from fastapi_async_sqlalchemy import db
from sqlalchemy import Integer, String, and_, cast, func, or_, select, update
from sqlalchemy.orm import selectinload
from app.models import (
ClassificationModel,
MapAccessModel,
MapsetModel,
OrganizationModel,
)
from app.schemas.user_schema import UserSchema
from . import BaseRepository
class MapsetRepository(BaseRepository[MapsetModel]):
    """Repository for ``MapsetModel`` with role-aware listing queries.

    All queries run through ``db.session`` from ``fastapi_async_sqlalchemy``;
    the write methods (bulk activation, counter increments) commit the session.
    """

    # Role names that bypass per-organization / per-user access filtering.
    _PRIVILEGED_ROLES = frozenset({"administrator", "data_validator"})

    def __init__(self, model):
        super().__init__(model)

    @classmethod
    def _is_privileged(cls, user: Optional[UserSchema]) -> bool:
        """Return True when ``user`` has a role allowed to see every mapset."""
        return user is not None and user.role.name in cls._PRIVILEGED_ROLES

    @staticmethod
    def _column_search_clauses(model, search: str) -> list:
        """Build one case-insensitive substring clause per mapped table column.

        Each column is cast to ``String`` so non-text columns can be matched too.
        Column keys without a same-named attribute on ``model`` are skipped.
        """
        pattern = f"%{search}%"
        return [
            cast(getattr(model, col), String).ilike(pattern)
            for col in model.__table__.columns.keys()
            if hasattr(model, col)
        ]

    async def find_all(
        self,
        user: Optional[UserSchema] = None,
        filters: Optional[list] = None,
        sort: list = ...,
        search: str = "",
        group_by: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
        landing: bool = False,
    ) -> Tuple[List[MapsetModel], int]:
        """Return one page of mapsets visible to ``user`` plus the total match count.

        Visibility rules:
          * anonymous callers, or any caller when ``landing`` is true, see only
            mapsets whose classification is open, that are active, and whose
            ``status_validation`` is ``"approved"``;
          * administrators and data validators see every mapset;
          * other users see mapsets produced by their organization, or shared
            through ``MapAccessModel`` with their organization or with them
            directly (users without an organization only get direct shares).

        Args:
            user: The requesting user, or ``None`` for anonymous access.
            filters: Extra SQLAlchemy criteria ANDed into the WHERE clause.
            sort: ORDER BY clauses; when omitted (``...``) or empty, results are
                ordered by ``order`` ascending.
            search: Substring matched case-insensitively against every mapset
                column (cast to text) and the producer organization's name.
            group_by: Name of a mapset attribute to GROUP BY.
            limit: Maximum number of rows returned.
            offset: Number of rows skipped before the page starts.
            landing: Force the public (anonymous) visibility rules.

        Returns:
            ``(mapsets, total)``, where ``total`` counts all matching rows
            before ``limit`` / ``offset`` are applied.
        """
        query = (
            select(self.model)
            .distinct()
            .join(ClassificationModel, self.model.classification_id == ClassificationModel.id)
        )

        privileged = self._is_privileged(user)
        if user is not None and not privileged:
            # Outer join so mapsets without any access grant still reach the filter.
            query = query.join(MapAccessModel, self.model.id == MapAccessModel.mapset_id, isouter=True)

        if user is None or landing:
            query = query.filter(
                ClassificationModel.is_open == True,  # noqa: E712 - SQL expression
                self.model.is_active == True,  # noqa: E712 - SQL expression
                self.model.status_validation == "approved",
            )
        elif not privileged:
            org_id = user.organization.id if user.organization else None
            grants = [MapAccessModel.user_id == user.id]
            if org_id is not None:
                # Only add org-based grants when the user has an organization:
                # comparing against None would render "IS NULL" and match
                # unrelated rows (e.g. mapsets with no producer).
                grants = [
                    self.model.producer_id == org_id,
                    MapAccessModel.organization_id == org_id,
                    *grants,
                ]
            query = query.filter(or_(*grants))

        if filters:
            query = query.filter(*filters)

        if search:
            query = query.join(
                OrganizationModel,
                self.model.producer_id == OrganizationModel.id,
                isouter=True,
            ).filter(
                or_(
                    *self._column_search_clauses(self.model, search),
                    OrganizationModel.name.ilike(f"%{search}%"),
                )
            )

        if group_by:
            query = query.group_by(getattr(self.model, group_by))

        # Count before ordering / pagination so ``total`` covers every match.
        total = await db.session.scalar(select(func.count()).select_from(query.subquery()))

        if sort is ... or not sort:
            query = query.order_by(self.model.order.asc())
        else:
            query = query.order_by(*sort)

        query = query.limit(limit).offset(offset)
        result = await db.session.execute(query)
        return result.scalars().all(), total

    async def find_all_group_by_organization(
        self,
        user: Optional[UserSchema] = None,
        mapset_filters: Optional[list] = None,
        organization_filters: Optional[list] = None,
        sort: Optional[list] = None,
        search: str = "",
        limit: int = 100,
        offset: int = 0,
    ) -> Tuple[List[dict], int]:
        """Return producer organizations that have visible mapsets, with those mapsets.

        Visibility rules for the mapsets:
          * anonymous callers see mapsets whose classification is open;
          * administrators and data validators see every mapset;
          * other users see open or limited mapsets, plus secret ones produced
            by their organization or shared via ``MapAccessModel`` with them or
            with their organization.

        Args:
            user: The requesting user, or ``None`` for anonymous access.
            mapset_filters: Extra criteria ANDed into the mapset query.
            organization_filters: Extra criteria ANDed into the organization query.
            sort: ORDER BY clauses for organizations; defaults to name ascending.
            search: Substring matched case-insensitively against mapset columns
                and, separately, organization columns.
            limit: Maximum number of organizations returned (falsy = no limit).
            offset: Number of organizations skipped (falsy = none).

        Returns:
            ``(rows, total)`` where each row is a dict with keys ``"id"``,
            ``"name"``, ``"mapsets"`` (the visible mapsets of that organization,
            with ``classification`` eagerly loaded) and ``"found"`` (their count),
            and ``total`` is the number of matching organizations before
            ``limit`` / ``offset``.
        """
        mapset_filters = mapset_filters or []
        organization_filters = organization_filters or []
        sort = sort or [OrganizationModel.name.asc()]

        if user is None:
            # NOTE(review): unlike find_all, anonymous callers are not restricted
            # to active / approved mapsets here - confirm this is intended.
            mapset_query = (
                select(self.model)
                .join(ClassificationModel, self.model.classification_id == ClassificationModel.id)
                .filter(ClassificationModel.is_open.is_(True))
            )
        elif self._is_privileged(user):
            mapset_query = select(self.model)
        else:
            user_org_id = user.organization.id if user.organization else None
            # Ways a secret mapset can be granted to this user. ``producer_id``
            # is an organization id, so it is compared with the user's org.
            secret_grants = [MapAccessModel.user_id == user.id]
            if user_org_id is not None:
                secret_grants += [
                    self.model.producer_id == user_org_id,
                    MapAccessModel.organization_id == user_org_id,
                ]
            mapset_query = (
                select(self.model)
                .join(ClassificationModel, self.model.classification_id == ClassificationModel.id)
                .outerjoin(MapAccessModel, self.model.id == MapAccessModel.mapset_id)
                .filter(
                    or_(
                        ClassificationModel.is_open.is_(True),
                        ClassificationModel.is_limited.is_(True),
                        and_(ClassificationModel.is_secret.is_(True), or_(*secret_grants)),
                    )
                )
            )

        if mapset_filters:
            mapset_query = mapset_query.filter(*mapset_filters)
        if search:
            mapset_search = self._column_search_clauses(self.model, search)
            if mapset_search:
                mapset_query = mapset_query.filter(or_(*mapset_search))

        # Select producer ids from the subquery's own columns. Referencing the
        # mapped table column here would add the base table to FROM (a cartesian
        # product) and bypass every filter applied above.
        visible_mapsets = mapset_query.subquery()
        producer_ids = select(visible_mapsets.c.producer_id).distinct()

        org_query = select(OrganizationModel).filter(OrganizationModel.id.in_(producer_ids))
        if organization_filters:
            org_query = org_query.filter(*organization_filters)
        if search:
            # NOTE(review): the organization must ALSO match the search term in
            # one of its own columns (AND semantics), whereas find_all ORs mapset
            # and organization-name matches - confirm this is intended.
            org_search = self._column_search_clauses(OrganizationModel, search)
            if org_search:
                org_query = org_query.filter(or_(*org_search))

        # org_query has no joins, so each organization appears at most once.
        total = await db.session.scalar(select(func.count()).select_from(org_query.subquery()))

        org_query = org_query.order_by(*sort)
        if limit:
            org_query = org_query.limit(limit)
        if offset:
            org_query = org_query.offset(offset)

        org_result = await db.session.execute(org_query)
        organizations = org_result.scalars().unique().all()
        org_ids = [org.id for org in organizations]
        if not org_ids:
            return [], total

        all_mapsets_query = mapset_query.filter(self.model.producer_id.in_(org_ids)).options(
            selectinload(self.model.classification)
        )
        all_mapsets_result = await db.session.execute(all_mapsets_query)
        # unique(): the MapAccessModel outer join can yield a mapset more than once.
        all_mapsets = all_mapsets_result.scalars().unique().all()

        mapsets_by_org: dict = {}
        for mapset in all_mapsets:
            mapsets_by_org.setdefault(mapset.producer_id, []).append(mapset)

        result_data = []
        for org in organizations:
            org_mapsets = mapsets_by_org.get(org.id, [])
            result_data.append(
                {
                    "id": org.id,
                    "name": org.name,
                    "mapsets": org_mapsets,
                    "found": len(org_mapsets),
                }
            )
        return result_data, total

    async def bulk_update_activation(self, mapset_ids: List[UUID], is_active: bool) -> None:
        """Set ``is_active`` on every mapset in ``mapset_ids`` and commit.

        Issues a single ``UPDATE ... WHERE id IN (...)`` instead of one
        statement per id; an empty list skips the UPDATE but still commits.
        """
        ids = list(mapset_ids)
        if ids:
            await db.session.execute(
                update(self.model).where(self.model.id.in_(ids)).values(is_active=is_active)
            )
        await db.session.commit()

    async def _increment_counter(self, mapset_id: UUID, column_name: str) -> None:
        """Add one to the integer column ``column_name`` of one mapset and commit.

        The increment is expressed in SQL (``col = col + 1``), so no value is
        read back into Python first.
        """
        column = getattr(self.model, column_name)
        await db.session.execute(
            update(self.model).where(self.model.id == mapset_id).values(**{column_name: column + 1})
        )
        await db.session.commit()

    async def increment_view_count(self, mapset_id: UUID) -> None:
        """Increase ``view_count`` of the given mapset by one and commit."""
        await self._increment_counter(mapset_id, "view_count")

    async def increment_download_count(self, mapset_id: UUID) -> None:
        """Increase ``download_count`` of the given mapset by one and commit."""
        await self._increment_counter(mapset_id, "download_count")