108 lines
3.7 KiB
Python
108 lines
3.7 KiB
Python
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|
from sqlalchemy import BigInteger, String, select
|
|
from typing import Optional
|
|
from config import DB_TYPE, DB_URL
|
|
|
|
# Process-wide database handles. Both are rebound by init_db(); until that
# coroutine has been awaited they remain None.
engine = SessionLocal = None
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model in this module.

    ``init_db()`` creates the schema from ``Base.metadata``.
    """
|
|
|
|
class Ticket(Base):
    """ORM model for one ticket, stored in the ``tickets`` table."""

    __tablename__ = "tickets"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
    # External IDs are stored as BigInteger; presumably they can exceed the
    # 32-bit range (e.g. Discord snowflakes) — TODO confirm.
    guild_id: Mapped[int] = mapped_column(BigInteger)
    # One ticket per channel: lookups and deletes in this module key on it.
    channel_id: Mapped[int] = mapped_column(BigInteger, unique=True)
    user_id: Mapped[int] = mapped_column(BigInteger)
    # Per-guild sequence number assigned by create_ticket() (max + 1).
    # Annotation-only: SQLAlchemy derives a NOT NULL Integer column from Mapped[int].
    ticket_number: Mapped[int]
|
|
|
|
class Config(Base):
    """ORM model for per-guild settings, stored in the ``config`` table."""

    __tablename__ = "config"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
    # At most one settings row per guild; get_config()/set_config() key on it.
    guild_id: Mapped[int] = mapped_column(BigInteger, unique=True)
    # The remaining fields are optional (NULL until set via set_config()).
    # Meanings inferred from names only — presumably where the ticket panel
    # lives and which category new ticket channels go into; verify with callers.
    panel_channel_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
    panel_message_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
    category_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
|
|
|
|
async def init_db():
    """Build the async engine and session factory, then ensure the schema exists.

    Rebinds the module-level ``engine`` and ``SessionLocal`` globals, so it
    must be awaited before any helper that opens a session via get_session().
    """
    global engine, SessionLocal

    engine = create_async_engine(DB_URL, echo=False)
    # expire_on_commit=False keeps ORM objects readable after commit, which
    # the helpers below rely on when they hand rows back to callers.
    SessionLocal = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )

    # create_all runs synchronously, so hop through run_sync; with its default
    # checkfirst behaviour it only creates tables that are missing.
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
|
|
|
|
def get_session():
    """Return a new ``AsyncSession`` from the factory created by ``init_db()``.

    Use it as ``async with get_session() as session: ...`` so the session is
    closed on exit.

    Raises:
        RuntimeError: If ``init_db()`` has not been awaited yet.
    """
    if SessionLocal is None:
        # Without this guard the call below fails with the opaque
        # "'NoneType' object is not callable".
        raise RuntimeError("Database not initialised: await init_db() first.")
    return SessionLocal()
|
|
|
|
async def create_ticket(guild_id: int, channel_id: int, user_id: int) -> Ticket:
    """Persist a new ticket for *guild_id* and return it.

    Tickets are numbered per guild: the new ticket gets the guild's highest
    existing ``ticket_number`` plus one, or 1 if the guild has none yet.

    Args:
        guild_id: Guild the ticket belongs to.
        channel_id: Channel bound to the ticket (unique across tickets).
        user_id: User who opened the ticket.

    Returns:
        The committed ``Ticket`` (exposes ``ticket_number``, ``channel_id``,
        ``user_id``, ``guild_id`` and ``id``). It is refreshed before the
        session closes, and the session factory uses
        ``expire_on_commit=False``, so its attributes stay readable afterwards.
    """
    async with get_session() as session:
        # Only the single highest number is needed: fetch one scalar with
        # LIMIT 1 instead of loading whole Ticket rows.
        last_number = await session.scalar(
            select(Ticket.ticket_number)
            .where(Ticket.guild_id == guild_id)
            .order_by(Ticket.ticket_number.desc())
            .limit(1)
        )
        ticket_num = last_number + 1 if last_number is not None else 1

        # NOTE(review): read-then-insert is not atomic; two concurrent calls for
        # the same guild can compute the same number. A unique constraint on
        # (guild_id, ticket_number) or a per-guild lock would close that gap.
        ticket = Ticket(
            guild_id=guild_id,
            channel_id=channel_id,
            user_id=user_id,
            ticket_number=ticket_num,
        )
        session.add(ticket)
        await session.commit()
        await session.refresh(ticket)

    # Return the real model instance, matching the annotation, rather than an
    # ad-hoc dynamically created class.
    return ticket
|
|
|
|
async def get_ticket(channel_id: int) -> Optional[Ticket]:
    """Look up the ticket bound to *channel_id*.

    Returns the matching ``Ticket``, or ``None`` if no ticket uses that
    channel. ``channel_id`` is unique on the model, so at most one row matches.
    """
    query = select(Ticket).where(Ticket.channel_id == channel_id)
    async with get_session() as session:
        return await session.scalar(query)
|
|
|
|
async def delete_ticket(channel_id: int):
    """Delete the ticket bound to *channel_id*, if one exists.

    Silently does nothing (and commits nothing) when no ticket matches.
    """
    query = select(Ticket).where(Ticket.channel_id == channel_id)
    async with get_session() as session:
        doomed = await session.scalar(query)
        if doomed is None:
            return
        await session.delete(doomed)
        await session.commit()
|
|
|
|
async def get_config(guild_id: int) -> Optional[Config]:
    """Fetch the settings row for *guild_id*.

    Returns the ``Config`` row, or ``None`` if the guild has never been
    configured (``guild_id`` is unique, so at most one row matches).
    """
    query = select(Config).where(Config.guild_id == guild_id)
    async with get_session() as session:
        return await session.scalar(query)
|
|
|
|
async def set_config(guild_id: int, **kwargs) -> Config:
    """Create or update the settings row for *guild_id*.

    Each keyword argument names a ``Config`` column to overwrite, e.g.
    ``await set_config(gid, category_id=123)``. Columns that are not passed
    keep their current value; a new row is inserted if the guild has none.

    Args:
        guild_id: Guild whose settings are written.
        **kwargs: Column name -> new value.

    Returns:
        The committed, refreshed ``Config`` row.

    Raises:
        TypeError: If a keyword does not name an updatable ``Config`` column.
    """
    # Validate before touching the database. setattr() with an unmapped name
    # would just attach a plain Python attribute that is never persisted, so
    # a typo'd field would otherwise be dropped without any error. "id" is
    # excluded so callers cannot rewrite the primary key.
    updatable = set(Config.__table__.columns.keys()) - {"id", "guild_id"}
    unknown = sorted(set(kwargs) - updatable)
    if unknown:
        raise TypeError(
            f"set_config() got unknown config field(s): {', '.join(unknown)}"
        )

    async with get_session() as session:
        config = await session.scalar(
            select(Config).where(Config.guild_id == guild_id)
        )

        if config is None:
            # First write for this guild: start a fresh row.
            config = Config(guild_id=guild_id)
            session.add(config)

        for key, value in kwargs.items():
            setattr(config, key, value)

        await session.commit()
        await session.refresh(config)
        return config