Skip to content

Commit

Permalink
feat: claude paid optimize (langgenius#890)
Browse files: browse the repository at this point in the history
  • Loading branch information
takatost authored Aug 17, 2023
1 parent 2f7b234 commit 9adbead
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 25 deletions.
6 changes: 4 additions & 2 deletions api/.env.example
Original file line number | Diff line number | Diff line change
Expand Up @@ -117,10 +117,12 @@ HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
HOSTED_ANTHROPIC_ENABLED=false
HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100

STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
7 changes: 5 additions & 2 deletions api/commands.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -258,23 +258,26 @@ def sync_anthropic_hosted_providers():
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
count = 0

new_quota_limit = hosted_model_providers.anthropic.quota_limit

page = 1
while True:
try:
providers = db.session.query(Provider).filter(
Provider.provider_name == 'anthropic',
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.quota_limit != new_quota_limit
).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
except NotFound:
break

page += 1
for provider in providers:
try:
click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
.format(provider.tenant_id, provider.quota_limit, provider.quota_used))
original_quota_limit = provider.quota_limit
new_quota_limit = hosted_model_providers.anthropic.quota_limit
division = math.ceil(new_quota_limit / 1000)

provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
Expand Down
16 changes: 10 additions & 6 deletions api/config.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -57,10 +57,12 @@
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
'HOSTED_ANTHROPIC_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'TENANT_DOCUMENT_COUNT': 100,
'CLEAN_DAY_SETTING': 30,
'UPLOAD_FILE_SIZE_LIMIT': 15,
Expand Down Expand Up @@ -211,23 +213,25 @@ def __init__(self):
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))

self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT'))

self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))

self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
Expand Down
10 changes: 9 additions & 1 deletion api/controllers/console/webhook/stripe.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,12 +38,20 @@ def post(self):
logging.debug(event['data']['object']['payment_status'])
logging.debug(event['data']['object']['metadata'])

session = stripe.checkout.Session.retrieve(
event['data']['object']['id'],
expand=['line_items'],
)

logging.debug(session.line_items['data'][0]['quantity'])

# Fulfill the purchase...
provider_checkout_service = ProviderCheckoutService()

try:
provider_checkout_service.fulfill_provider_order(event)
provider_checkout_service.fulfill_provider_order(event, session.line_items)
except Exception as e:

logging.debug(str(e))
return 'success', 200

Expand Down
2 changes: 2 additions & 0 deletions api/core/model_providers/models/llm/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -125,6 +125,8 @@ def run(self, messages: List[PromptMessage],
completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
total_tokens = prompt_tokens + completion_tokens

self.model_provider.update_last_used()

if self.deduct_quota:
self.model_provider.deduct_quota(total_tokens)

Expand Down
2 changes: 2 additions & 0 deletions api/core/model_providers/providers/anthropic_provider.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -183,6 +183,8 @@ def get_payment_info(self) -> Optional[dict]:
return {
'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
}

return None
Expand Down
6 changes: 5 additions & 1 deletion api/core/model_providers/providers/hosted.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -31,7 +31,9 @@ class HostedAnthropic(BaseModel):
"""Quota limit for the anthropic hosted model. 0 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1
paid_increase_quota: int = 1000000
paid_min_quantity: int = 20
paid_max_quantity: int = 100


class HostedModelProviders(BaseModel):
Expand Down Expand Up @@ -73,4 +75,6 @@ def init_app(app: Flask):
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
)
5 changes: 3 additions & 2 deletions api/core/model_providers/rules/anthropic.json
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,10 +5,11 @@
],
"system_config": {
"supported_quota_types": [
"paid",
"trial"
],
"quota_unit": "times",
"quota_limit": 1000
"quota_unit": "tokens",
"quota_limit": 600000
},
"model_flexibility": "fixed"
}
34 changes: 25 additions & 9 deletions api/services/provider_checkout_service.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,6 +39,8 @@ def create_checkout(self, tenant_id: str, provider_name: str, account: Account)
raise ValueError(f'provider name {provider_name} not support payment')

payment_product_id = payment_info['product_id']
payment_min_quantity = payment_info['min_quantity']
payment_max_quantity = payment_info['max_quantity']

# create provider order
provider_order = ProviderOrder(
Expand All @@ -53,18 +55,29 @@ def create_checkout(self, tenant_id: str, provider_name: str, account: Account)
db.session.add(provider_order)
db.session.flush()

line_item = {
'price': f'{payment_product_id}',
'quantity': payment_min_quantity
}

if payment_min_quantity > 1 and payment_max_quantity != payment_min_quantity:
line_item['adjustable_quantity'] = {
'enabled': True,
'minimum': payment_min_quantity,
'maximum': payment_max_quantity
}

try:
# create stripe checkout session
checkout_session = stripe.checkout.Session.create(
line_items=[
{
'price': f'{payment_product_id}',
'quantity': 1,
},
line_item
],
mode='payment',
success_url=current_app.config.get("CONSOLE_WEB_URL") + '?provider_payment=succeeded',
cancel_url=current_app.config.get("CONSOLE_WEB_URL") + '?provider_payment=cancelled',
success_url=current_app.config.get("CONSOLE_WEB_URL")
+ f'?provider_name={provider_name}&payment_result=succeeded',
cancel_url=current_app.config.get("CONSOLE_WEB_URL")
+ f'?provider_name={provider_name}&payment_result=cancelled',
automatic_tax={'enabled': True},
)
except Exception as e:
Expand All @@ -76,7 +89,7 @@ def create_checkout(self, tenant_id: str, provider_name: str, account: Account)

return ProviderCheckout(checkout_session)

def fulfill_provider_order(self, event):
def fulfill_provider_order(self, event, line_items):
provider_order = db.session.query(ProviderOrder) \
.filter(ProviderOrder.payment_id == event['data']['object']['id']) \
.first()
Expand All @@ -85,7 +98,8 @@ def fulfill_provider_order(self, event):
raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}')

if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value:
raise ValueError(f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}')
raise ValueError(
f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}')

provider_order.transaction_id = event['data']['object']['payment_intent']
provider_order.currency = event['data']['object']['currency']
Expand All @@ -110,10 +124,12 @@ def fulfill_provider_order(self, event):
model_provider = model_provider_class(provider=provider)
payment_info = model_provider.get_payment_info()

quantity = line_items['data'][0]['quantity']

if not payment_info:
increase_quota = 0
else:
increase_quota = int(payment_info['increase_quota'])
increase_quota = int(payment_info['increase_quota']) * quantity

if increase_quota > 0:
provider.quota_limit += increase_quota
Expand Down
6 changes: 4 additions & 2 deletions api/services/provider_service.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -133,12 +133,14 @@ def get_provider_list(self, tenant_id: str):
provider_parameter_dict[key]['is_valid'] = provider.is_valid
provider_parameter_dict[key]['quota_used'] = provider.quota_used
provider_parameter_dict[key]['quota_limit'] = provider.quota_limit
provider_parameter_dict[key]['last_used'] = provider.last_used
provider_parameter_dict[key]['last_used'] = int(provider.last_used.timestamp()) \
if provider.last_used else None
elif provider.provider_type == ProviderType.CUSTOM.value \
and ProviderType.CUSTOM.value in provider_parameter_dict:
# if custom
key = ProviderType.CUSTOM.value
provider_parameter_dict[key]['last_used'] = provider.last_used
provider_parameter_dict[key]['last_used'] = int(provider.last_used.timestamp()) \
if provider.last_used else None
provider_parameter_dict[key]['is_valid'] = provider.is_valid

if model_provider_rule['model_flexibility'] == 'fixed':
Expand Down

0 comments on commit 9adbead

Please sign in to comment.