259 lines
9.2 KiB
Python
259 lines
9.2 KiB
Python
from ast import Dict
|
|
from typing import List, Optional, Tuple
|
|
from uuid import UUID
|
|
|
|
from fastapi_async_sqlalchemy import db
|
|
from sqlalchemy import Integer, String, and_, cast, func, or_, select, update
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.models import (
|
|
ClassificationModel,
|
|
MapAccessModel,
|
|
MapsetModel,
|
|
OrganizationModel,
|
|
)
|
|
from app.schemas.user_schema import UserSchema
|
|
|
|
from . import BaseRepository
|
|
|
|
|
|
class MapsetRepository(BaseRepository[MapsetModel]):
    """Repository for mapsets: visibility-aware listings, bulk activation and usage counters."""

    # Role names that see every mapset without per-mapset access checks.
    # NOTE(review): one call site previously spelled this "data-validator";
    # the underscore spelling used by `find_all` is kept as the single source of truth.
    _PRIVILEGED_ROLES = frozenset({"administrator", "data_validator"})

    def __init__(self, model):
        super().__init__(model)

    @classmethod
    def _is_privileged(cls, user) -> bool:
        """Return True when the user's role name is one of `_PRIVILEGED_ROLES`."""
        role = getattr(user, "role", None)
        # Accept both a role object exposing `.name` and a plain role-name string.
        role_name = getattr(role, "name", role)
        return isinstance(role_name, str) and role_name in cls._PRIVILEGED_ROLES

    @staticmethod
    def _text_search_clauses(model, search: str) -> list:
        """Build one case-insensitive substring clause per table column of `model`."""
        pattern = f"%{search}%"
        return [
            cast(getattr(model, column_name), String).ilike(pattern)
            for column_name in model.__table__.columns.keys()
            if hasattr(model, column_name)
        ]

    def _grant_clauses(self, user) -> list:
        """Build clauses matching mapsets produced by, or explicitly shared with, the user.

        Organization-based clauses are only emitted when the user has an
        organization; comparing against ``None`` would render ``IS NULL`` and
        match rows that have no producer / no organization grant.
        """
        clauses = [MapAccessModel.user_id == user.id]
        organization = user.organization
        if organization:
            clauses.append(self.model.producer_id == organization.id)
            clauses.append(MapAccessModel.organization_id == organization.id)
        return clauses

    async def find_all(
        self,
        user: Optional[UserSchema] = None,
        filters: Optional[list] = None,
        sort: list = ...,
        search: str = "",
        group_by: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
        landing: bool = False,
    ) -> Tuple[List[MapsetModel], int]:
        """List mapsets visible to `user`, with filtering, search, sorting and paging.

        Args:
            user: Requesting user, or None for an anonymous request.
            filters: Extra SQLAlchemy criteria applied to the query.
            sort: ORDER BY expressions; when empty or left at the ``...`` default,
                rows are ordered by the model's ``order`` column ascending.
            search: Substring matched case-insensitively against every mapset
                column (cast to text) and the producer organization's name.
            group_by: Name of a model attribute to GROUP BY.
            limit: Maximum number of rows returned.
            offset: Number of rows skipped.
            landing: When True, apply the anonymous visibility rules even if a
                user is supplied.

        Returns:
            A ``(rows, total)`` tuple where ``total`` is the row count before
            sorting and paging are applied.
        """
        restricted = user is not None and not self._is_privileged(user)

        query = (
            select(self.model)
            .distinct()
            .join(ClassificationModel, self.model.classification_id == ClassificationModel.id)
        )
        if restricted:
            # Outer join so mapsets without any access row are still candidates;
            # DISTINCT above collapses the duplicates the join can introduce.
            query = query.join(MapAccessModel, self.model.id == MapAccessModel.mapset_id, isouter=True)

        if user is None or landing:
            query = query.filter(
                ClassificationModel.is_open.is_(True),
                self.model.is_active.is_(True),
                self.model.status_validation == "approved",
            )
        elif restricted:
            # Classification flags (open/limited/secret) are not consulted here:
            # access comes only from producing the mapset or an explicit grant.
            query = query.filter(or_(*self._grant_clauses(user)))

        if filters:
            query = query.filter(*filters)

        if search:
            query = query.join(
                OrganizationModel,
                self.model.producer_id == OrganizationModel.id,
                isouter=True,
            ).filter(
                or_(
                    *self._text_search_clauses(self.model, search),
                    OrganizationModel.name.ilike(f"%{search}%"),
                )
            )

        if group_by:
            query = query.group_by(getattr(self.model, group_by))

        # Count before ORDER BY / LIMIT so `total` reflects the full match set.
        total = await db.session.scalar(select(func.count()).select_from(query.subquery()))

        if not sort or sort is ...:
            query = query.order_by(self.model.order.asc())
        else:
            query = query.order_by(*sort)

        rows = await db.session.execute(query.limit(limit).offset(offset))
        return rows.scalars().all(), total

    async def find_all_group_by_organization(
        self,
        user: Optional[UserSchema] = None,
        mapset_filters: Optional[list] = None,
        organization_filters: Optional[list] = None,
        sort: Optional[list] = None,
        search: str = "",
        limit: int = 100,
        offset: int = 0,
    ) -> Tuple[List[dict], int]:
        """List producer organizations together with the mapsets `user` may see.

        Args:
            user: Requesting user, or None for an anonymous request.
            mapset_filters: Extra criteria applied to the mapset query.
            organization_filters: Extra criteria applied to the organization query.
            sort: ORDER BY expressions for organizations; defaults to name ascending.
            search: Substring matched case-insensitively against every mapset
                column and every organization column (both cast to text).
            limit: Maximum number of organizations returned; falsy disables the limit.
            offset: Number of organizations skipped; falsy disables the offset.

        Returns:
            A ``(items, total)`` tuple. Each item is a dict with keys ``id``,
            ``name``, ``mapsets`` and ``found``; ``total`` is the number of
            matching organizations before paging.
        """
        mapset_filters = mapset_filters or []
        organization_filters = organization_filters or []
        sort = sort or [OrganizationModel.name.asc()]

        classification_join = self.model.classification_id == ClassificationModel.id

        if user is None:
            # NOTE(review): unlike `find_all`, anonymous requests here are not
            # restricted to active/approved mapsets - confirm this is intended.
            mapset_query = (
                select(self.model)
                .join(ClassificationModel, classification_join)
                .filter(ClassificationModel.is_open.is_(True))
            )
        elif self._is_privileged(user):
            mapset_query = select(self.model)
        else:
            mapset_query = (
                select(self.model)
                .join(ClassificationModel, classification_join)
                .outerjoin(MapAccessModel, self.model.id == MapAccessModel.mapset_id)
                .filter(
                    or_(
                        ClassificationModel.is_open.is_(True),
                        ClassificationModel.is_limited.is_(True),
                        # Secret mapsets need ownership or an explicit grant.
                        and_(
                            ClassificationModel.is_secret.is_(True),
                            or_(*self._grant_clauses(user)),
                        ),
                    )
                )
            )

        if mapset_filters:
            mapset_query = mapset_query.filter(*mapset_filters)

        if search:
            mapset_search = self._text_search_clauses(self.model, search)
            if mapset_search:
                mapset_query = mapset_query.filter(or_(*mapset_search))

        # Select the column from the subquery itself; referencing the mapped
        # column would add the base table to FROM and create a cartesian product.
        visible_mapsets = mapset_query.subquery()
        producer_ids = select(visible_mapsets.c.producer_id).distinct()

        org_query = select(OrganizationModel).filter(OrganizationModel.id.in_(producer_ids))

        if organization_filters:
            org_query = org_query.filter(*organization_filters)

        if search:
            # NOTE(review): the search term must match both a mapset column and
            # an organization column for an organization to be listed.
            org_search = self._text_search_clauses(OrganizationModel, search)
            if org_search:
                org_query = org_query.filter(or_(*org_search))

        # org_query has no joins, so it yields one row per organization.
        total = await db.session.scalar(select(func.count()).select_from(org_query.subquery()))

        org_query = org_query.order_by(*sort)
        if limit:
            org_query = org_query.limit(limit)
        if offset:
            org_query = org_query.offset(offset)

        org_rows = await db.session.execute(org_query)
        organizations = org_rows.scalars().unique().all()
        if not organizations:
            return [], total

        org_ids = [organization.id for organization in organizations]
        page_mapsets_query = mapset_query.filter(self.model.producer_id.in_(org_ids)).options(
            selectinload(self.model.classification)
        )
        mapset_rows = await db.session.execute(page_mapsets_query)

        mapsets_by_org: dict = {}
        # unique() drops duplicates produced by the outer join on access rows.
        for mapset in mapset_rows.scalars().unique().all():
            mapsets_by_org.setdefault(mapset.producer_id, []).append(mapset)

        items = []
        for organization in organizations:
            org_mapsets = mapsets_by_org.get(organization.id, [])
            items.append(
                {
                    "id": organization.id,
                    "name": organization.name,
                    "mapsets": org_mapsets,
                    "found": len(org_mapsets),
                }
            )

        return items, total

    async def bulk_update_activation(self, mapset_ids: List[UUID], is_active: bool) -> None:
        """Set ``is_active`` on every mapset in `mapset_ids` and commit the session."""
        if mapset_ids:
            statement = (
                update(self.model)
                .where(self.model.id.in_(mapset_ids))
                .values(is_active=is_active)
                # "fetch" keeps already-loaded objects in sync without relying on
                # in-Python evaluation of the IN criteria.
                .execution_options(synchronize_session="fetch")
            )
            await db.session.execute(statement)
        await db.session.commit()

    async def _increment_counter(self, mapset_id: UUID, counter: str) -> None:
        """Atomically add 1 to the named counter column of one mapset and commit."""
        column = getattr(self.model, counter)
        statement = update(self.model).where(self.model.id == mapset_id).values({counter: column + 1})
        await db.session.execute(statement)
        await db.session.commit()

    async def increment_view_count(self, mapset_id: UUID) -> None:
        """Increase ``view_count`` of the given mapset by one."""
        await self._increment_counter(mapset_id, "view_count")

    async def increment_download_count(self, mapset_id: UUID) -> None:
        """Increase ``download_count`` of the given mapset by one."""
        await self._increment_counter(mapset_id, "download_count")
|