1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86import logging
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.errors.error import QuotaExceededError
from extensions.ext_database import db
from models import TenantCreditPool
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
class CreditPoolService:
    """Helpers for creating, reading and consuming a tenant's credit pool."""

    @classmethod
    def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
        """Create the default credit pool for a new tenant.

        The pool has type ``"trial"``, starts with ``quota_used=0`` and takes its
        ``quota_limit`` from ``dify_config.HOSTED_POOL_CREDITS``. The row is
        added and committed on the shared ``db.session``.

        Args:
            tenant_id: Tenant that will own the pool.

        Returns:
            The newly created, committed pool instance.
        """
        credit_pool = TenantCreditPool(
            tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
        )
        db.session.add(credit_pool)
        db.session.commit()
        return credit_pool

    @classmethod
    def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
        """Return the tenant's credit pool of the given type.

        Args:
            tenant_id: Tenant whose pool is looked up.
            pool_type: Pool type to match (defaults to ``"trial"``).

        Returns:
            The first matching pool loaded through ``db.session``, or ``None``
            if the tenant has no pool of that type.
        """
        return (
            db.session.query(TenantCreditPool)
            .filter_by(
                tenant_id=tenant_id,
                pool_type=pool_type,
            )
            .first()
        )

    @classmethod
    def check_credits_available(
        cls,
        tenant_id: str,
        credits_required: int,
        pool_type: str = "trial",
    ) -> bool:
        """Report whether the pool currently holds enough credits, without deducting.

        This is an advisory read only: nothing is reserved, so a later call to
        ``check_and_deduct_credits`` may still find fewer credits than this saw.

        Args:
            tenant_id: Tenant whose pool is checked.
            credits_required: Number of credits the caller intends to use.
            pool_type: Pool type to check (defaults to ``"trial"``).

        Returns:
            ``False`` if no pool exists; otherwise whether the pool's
            ``remaining_credits`` is at least ``credits_required``.
        """
        pool = cls.get_pool(tenant_id, pool_type)
        if not pool:
            return False
        return pool.remaining_credits >= credits_required

    @classmethod
    def check_and_deduct_credits(
        cls,
        tenant_id: str,
        credits_required: int,
        pool_type: str = "trial",
    ) -> int:
        """Atomically check the pool balance and deduct credits from it.

        If fewer than ``credits_required`` credits remain, all remaining credits
        are deducted instead (partial deduction). The pool row is read with
        ``SELECT ... FOR UPDATE`` and updated in the same transaction, so
        concurrent callers cannot both pass the balance check on the same
        snapshot and overdraw the pool.

        Args:
            tenant_id: Tenant whose pool is charged.
            credits_required: Credits requested for this operation.
                NOTE(review): a negative value would decrease ``quota_used``;
                callers presumably never pass one — confirm.
            pool_type: Pool type to charge (defaults to ``"trial"``).

        Returns:
            The number of credits actually deducted.

        Raises:
            QuotaExceededError: if the pool does not exist, has no credits left,
                or the database operation fails.
        """
        try:
            with Session(db.engine) as session:
                # Lock the pool row for the rest of this transaction so the
                # balance we check is the balance we deduct from.
                pool = (
                    session.query(TenantCreditPool)
                    .filter_by(
                        tenant_id=tenant_id,
                        pool_type=pool_type,
                    )
                    .with_for_update()
                    .first()
                )
                if not pool:
                    raise QuotaExceededError("Credit pool not found")

                remaining = pool.remaining_credits
                if remaining <= 0:
                    raise QuotaExceededError("No credits remaining")

                # deduct all remaining credits if less than required
                actual_credits = min(credits_required, remaining)

                stmt = (
                    update(TenantCreditPool)
                    .where(
                        TenantCreditPool.tenant_id == tenant_id,
                        TenantCreditPool.pool_type == pool_type,
                    )
                    .values(quota_used=TenantCreditPool.quota_used + actual_credits)
                )
                session.execute(stmt)
                session.commit()
                return actual_credits
        except QuotaExceededError:
            # Domain errors (missing pool / empty pool) propagate unchanged;
            # leaving the `with` block already rolled back and released the lock.
            raise
        except Exception as e:
            logger.exception("Failed to deduct credits for tenant %s", tenant_id)
            raise QuotaExceededError("Failed to deduct credits") from e