Skip to content

Commit

Permalink
Add list and detail endpoints for known model metadata
Browse files Browse the repository at this point in the history
Add `metadata` arg to /v2/status/models endpoint to request model metadata for available models

Refactor model_reference.py
  • Loading branch information
ceruleandeep committed Sep 27, 2024
1 parent 9ce5b55 commit 0032bd6
Show file tree
Hide file tree
Showing 6 changed files with 470 additions and 76 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ venv/
ENV/
env.bak/
venv.bak/
.env_docker

# Spyder project settings
.spyderproject
Expand Down
248 changes: 242 additions & 6 deletions horde/apis/models/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
#
# SPDX-License-Identifier: AGPL-3.0-or-later

from flask_restx import fields, reqparse
from flask_restx import Namespace, fields, reqparse

from horde.enums import WarningMessage
from horde.exceptions import KNOWN_RC
from horde.model_reference import KnownImageModelRef, KnownTextModelRef
from horde.vars import horde_noun, horde_title


Expand Down Expand Up @@ -260,7 +261,7 @@ def __init__(self):


class Models:
def __init__(self, api):
def __init__(self, api: Namespace):
self.response_model_wp_status_lite = api.model(
"RequestStatusCheck",
{
Expand Down Expand Up @@ -406,7 +407,7 @@ def __init__(self, api):
min=0,
),
"untrusted": fields.Integer(
description=("How many waiting requests were skipped because they demanded a trusted worker which this worker is not."),
description="How many waiting requests were skipped because they demanded a trusted worker which this worker is not.",
min=0,
),
"models": fields.Integer(
Expand Down Expand Up @@ -618,11 +619,11 @@ def __init__(self, api):
"forms": fields.List(fields.String(description="Which forms this worker if offering.")),
"team": fields.Nested(
self.response_model_team_details_lite,
"The Team to which this worker is dedicated.",
description="The Team to which this worker is dedicated.",
),
"contact": fields.String(
example="[email protected]",
description=("(Privileged) Contact details for the horde admins to reach the owner of this worker in emergencies."),
description="(Privileged) Contact details for the horde admins to reach the owner of this worker in emergencies.",
min_length=5,
max_length=500,
),
Expand Down Expand Up @@ -1053,7 +1054,7 @@ def __init__(self, api):
max=10,
),
"worker_invited": fields.Integer(
description=("Set to the amount of workers this user is allowed to join to the horde when in worker invite-only mode."),
description="Set to the amount of workers this user is allowed to join to the horde when in worker invite-only mode.",
),
"moderator": fields.Boolean(
example=False,
Expand Down Expand Up @@ -1278,6 +1279,232 @@ def __init__(self, api):
),
},
)
self.response_model_known_model_md = api.model(
"KnownModelMetadata",
{
"name": fields.String(description="The name of this model."),
"description": fields.String(description="The description of this model."),
"version": fields.String(description="The version of this model."),
"style": fields.String(description="The style of this model."),
"nsfw": fields.Boolean(description="Whether this model can generate NSFW content."),
"baseline": fields.String(description="The baseline model used for this model."),
},
)

settings = api.model(
"KnownTextModelSettings",
{
"n": fields.Integer(example=1, min=1, max=20),
"frmtadsnsp": fields.Boolean(
example=False,
description=(
"Input formatting option. When enabled, adds a leading space to your input "
"if there is no trailing whitespace at the end of the previous action."
),
),
"frmtrmblln": fields.Boolean(
example=False,
description=(
"Output formatting option. When enabled, replaces all occurrences of two or more consecutive newlines "
"in the output with one newline."
),
),
"frmtrmspch": fields.Boolean(
example=False,
description=r"Output formatting option. When enabled, removes #/@%}{+=~|\^<> from the output.",
),
"frmttriminc": fields.Boolean(
example=False,
description=(
"Output formatting option. When enabled, removes some characters from the end of the output such "
"that the output doesn't end in the middle of a sentence. "
"If the output is less than one sentence long, does nothing."
),
),
"max_context_length": fields.Integer(
min=80,
example=1024,
max=32000,
description="Maximum number of tokens to send to the model.",
),
"max_length": fields.Integer(
min=16,
max=1024,
example=80,
description="Number of tokens to generate.",
),
"rep_pen": fields.Float(description="Base repetition penalty value.", min=1, max=3),
"rep_pen_range": fields.Integer(description="Repetition penalty range.", min=0, max=4096),
"rep_pen_slope": fields.Float(description="Repetition penalty slope.", min=0, max=10),
"singleline": fields.Boolean(
example=False,
description=(
"Output formatting option. When enabled, removes everything after the first line of the output, "
"including the newline."
),
),
"temperature": fields.Float(description="Temperature value.", min=0, max=5.0),
"tfs": fields.Float(description="Tail free sampling value.", min=0.0, max=1.0),
"top_a": fields.Float(description="Top-a sampling value.", min=0.0, max=1.0),
"top_k": fields.Integer(description="Top-k sampling value.", min=0, max=100),
"top_p": fields.Float(description="Top-p sampling value.", min=0.001, max=1.0),
"typical": fields.Float(description="Typical sampling value.", min=0.0, max=1.0),
"sampler_order": fields.List(
fields.Integer(description="Array of integers representing the sampler order to be used."),
),
"use_default_badwordsids": fields.Boolean(
example=True,
description="When True, uses the default KoboldAI bad word IDs.",
),
"stop_sequence": fields.List(
fields.String(
description=(
"An array of string sequences whereby the model will stop generating further tokens. "
"The returned text WILL contain the stop sequence."
),
),
),
"min_p": fields.Float(description="Min-p sampling value.", min=0.0, example=0.0, max=1.0),
"smoothing_factor": fields.Float(
description="Quadratic sampling value.",
min=0.0,
example=0.0,
max=10.0,
),
"dynatemp_range": fields.Float(
description="Dynamic temperature range value.",
min=0.0,
example=0.0,
max=5.0,
),
"dynatemp_exponent": fields.Float(
description="Dynamic temperature exponent value.",
min=0.0,
example=1.0,
max=5.0,
),
},
)

self.response_model_known_text_model_md = api.inherit(
"KnownTextModelMetadata",
self.response_model_known_model_md,
{
"parameters": fields.Integer(description="The number of parameters in this model."),
"display_name": fields.String(description="The display name of this model."),
"homepage": fields.String(description="The homepage of the model.", attribute="url"),
"tags": fields.List(fields.String(description="The tags of this model.")),
"instruct_format": fields.String(description="Instruct format template to use for this model."),
"settings": fields.Nested(
settings,
description="Recommended settings for this model.",
allow_null=False,
skip_none=True,
),
},
)

requirements = api.model(
"KnownImageModelRequirements",
{
"clip_skip": fields.Integer(
description="The number of steps to skip in CLIP.",
min=1,
example=1,
),
"min_steps": fields.Integer(
description="The minimum number of steps to take.",
min=1,
example=30,
),
"max_steps": fields.Integer(
description="The maximum number of steps to take.",
min=1,
example=30,
),
"cfg_scale": fields.Float(
description="Classifier-free guidance scale.",
min=0.0,
example=7.5,
),
"min_cfg_scale": fields.Float(
description="Minimum classifier-free guidance scale.",
min=0.0,
example=7.5,
),
"max_cfg_scale": fields.Float(
description="Maximum classifier-free guidance scale.",
min=0.0,
example=7.5,
),
"samplers": fields.List(
fields.String(
description="The samplers to use for this model.",
example="k_euler_a",
),
),
},
)

download = api.model(
"KnownImageModelDownload",
{
"file_name": fields.String(description="The filename of the file to download."),
"file_path": fields.String(description="The path to the file to download."),
"file_url": fields.String(description="The URL to download the file from."),
},
)

config = api.model(
"KnownImageModelConfig",
{
"files": fields.List(fields.Nested(api.model("KnownImageModelFile", {"path": fields.String, "sha256sum": fields.String}))),
"download": fields.List(fields.Nested(download)),
},
)
self.response_model_known_image_model_md = api.inherit(
"KnownImageModelMetadata",
self.response_model_known_model_md,
{
"homepage": fields.String(description="The URL of the model's page."),
"weight_type": fields.String(description="Storage format of the model weights.", attribute="type"),
"inpainting": fields.Boolean(description="Whether this model can generate inpainting content."),
"requirements": fields.Nested(
requirements,
description="Generation settings requirements for this model.",
allow_null=False,
skip_none=True,
),
"config": fields.Nested(
config,
description="The configuration of the model.",
allow_null=False,
skip_none=True,
),
"features_not_supported": fields.List(fields.String(description="The features not supported by the model.")),
"size_on_disk_bytes": fields.Integer(description="The size of the model on disk in bytes."),
},
)

self.response_model_known_model = api.model(
"KnownModel",
{
"name": fields.String(description="The name of this model."),
"type": fields.String(
description="Model type (text or image).",
enum=["text", "image"],
),
"metadata": fields.Polymorph(
{
KnownImageModelRef: self.response_model_known_image_model_md,
KnownTextModelRef: self.response_model_known_text_model_md,
},
description="The metadata of the model.",
skip_none=True,
),
},
)

self.response_model_active_model = api.inherit(
"ActiveModel",
self.response_model_active_model_lite,
Expand All @@ -1291,8 +1518,17 @@ def __init__(self, api):
description="The model type (text or image).",
enum=["image", "text"],
),
"metadata": fields.Polymorph(
{
KnownImageModelRef: self.response_model_known_image_model_md,
KnownTextModelRef: self.response_model_known_text_model_md,
},
description="The metadata of the model.",
skip_none=True,
),
},
)

self.response_model_deleted_worker = api.model(
"DeletedWorker",
{
Expand Down
2 changes: 2 additions & 0 deletions horde/apis/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,5 @@
api.add_resource(base.DocsTerms, "/documents/terms")
api.add_resource(base.DocsPrivacy, "/documents/privacy")
api.add_resource(base.DocsSponsors, "/documents/sponsors")
api.add_resource(base.KnownModels, "/knownmodels")
api.add_resource(base.KnownModelSingle, "/knownmodels/<string:model_name>")
Loading

0 comments on commit 0032bd6

Please sign in to comment.