From 232ba6b508ab86c21fc55a43506ed0465c415741 Mon Sep 17 00:00:00 2001 From: Tanner Kvarfordt Date: Mon, 6 Nov 2023 22:58:17 -0700 Subject: [PATCH] Add Support for Dall-e-3 (#9) Add support for Dall-e-3 --- README.md | 2 +- examples/images/images-example.go | 17 ++++-- images/images.go | 91 +++++++++++++++++++++++++++---- images/images_test.go | 23 +++++--- tools/build.sh | 5 ++ 5 files changed, 113 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 81b511b..266c8aa 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ If you use this library, you must conform to Open AI's [Usage Policies](https:// ## Other Language Bindings -For another great Go implementation, see [sashabaranov/go-gpt3](https://github.com/sashabaranov/go-gpt3). +For another great Go implementation, see [sashabaranov/go-openai](https://github.com/sashabaranov/go-openai). For other languages, see [Open AI's Website](https://beta.openai.com/docs/libraries/libraries). ## Contributing diff --git a/examples/images/images-example.go b/examples/images/images-example.go index 28bdd01..90a0a93 100644 --- a/examples/images/images-example.go +++ b/examples/images/images-example.go @@ -16,14 +16,15 @@ func init() { authentication.SetAPIKey(key) } -func create() (*images.Response, error) { +func create(model, size string) (*images.Response, error) { const prompt = "A cute baby sea otter" - fmt.Printf("Creating from prompt: %s\n", prompt) + fmt.Printf("Creating from model=\"%s\", prompt=\"%s\"\n", model, prompt) resp, _, err := images.MakeModeratedCreationRequest(&images.CreationRequest{ Prompt: prompt, - Size: images.SmallImage, + Size: size, User: "https://github.com/TannerKvarfordt/gopenai", + Model: model, }, nil) if err != nil { return nil, err @@ -39,7 +40,7 @@ func variation(imagename, image string) error { resp, err := images.MakeVariationRequest(&images.VariationRequest{ Image: image, ImageName: imagename, - Size: images.SmallImage, + Size: images.Dalle2SmallImage, User: "https://github.com/TannerKvarfordt/gopenai", }, nil) if err != nil { @@ -51,7 +52,7 @@ func variation(imagename, image string) error { } func main() { - resp, err := create() + resp, err := create(images.ModelDalle2, images.Dalle2SmallImage) if err != nil { fmt.Println(err) return @@ -62,4 +63,10 @@ func main() { fmt.Println(err) return } + + _, err = create(images.ModelDalle3, images.Dalle3SquareImage) + if err != nil { + fmt.Println(err) + return + } } diff --git a/images/images.go b/images/images.go index 95b2759..3e7e6b1 100644 --- a/images/images.go +++ b/images/images.go @@ -25,9 +25,20 @@ const ( ) const ( - SmallImage string = "256x256" - MediumImage string = "512x512" - LargeImage string = "1024x1024" + Dalle2SmallImage = "256x256" + Dalle2MediumImage = "512x512" + Dalle2LargeImage = "1024x1024" + + Dalle3SquareImage = "1024x1024" + Dalle3LandscapeImage = "1792x1024" + Dalle3PortraitImage = "1024x1792" + + // Deprecated: Use Dalle2SmallImage instead. + SmallImage = Dalle2SmallImage + // Deprecated: Use Dalle2MediumImage instead. + MediumImage = Dalle2MediumImage + // Deprecated: Use Dalle2LargeImage instead. + LargeImage = Dalle2LargeImage ) const ( @@ -35,30 +46,64 @@ const ( ResponseFormatB64JSON = "b64_json" ) +const ( + ModelDalle2 = "dall-e-2" + ModelDalle3 = "dall-e-3" +) + +const ( + QualityStandard = "standard" + QualityHD = "hd" +) + +const ( + StyleVivid = "vivid" + StyleNatural = "natural" +) + // Response structure for the image API endpoint. type Response struct { Created uint64 `json:"created"` Data []struct { - URL string `json:"url"` - B64JSON string `json:"b64_json"` + URL string `json:"url"` + B64JSON string `json:"b64_json"` + RevisedPrompt string `json:"revised_prompt"` } Error *common.ResponseError `json:"error,omitempty"` } // Request structure for the image creation API endpoint. type CreationRequest struct { - // A text description of the desired image(s). The maximum length is 1000 characters. + // A text description of the desired image(s). + // The maximum length is 1000 characters for dall-e-2 and 4000 characters for dall-e-3. Prompt string `json:"prompt,omitempty"` + // The model to use for image generation. + Model string `json:"model,omitempty"` + // The number of images to generate. Must be between 1 and 10. + // For dall-e-3, only n=1 is supported. N *uint64 `json:"n,omitempty"` - // The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024. - Size string `json:"size,omitempty"` + // The quality of the image that will be generated. + // "hd" creates images with finer details and greater consistency across the image. + // This param is only supported for dall-e-3. + Quality string `json:"quality,omitempty"` // The format in which the generated images are returned. Must be one of url or b64_json. ResponseFormat string `json:"response_format,omitempty"` + // The size of the generated images. + // Must be one of 256x256, 512x512, or 1024x1024 for dall-e-2. + // Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models. + Size string `json:"size,omitempty"` + + // The style of the generated images. Must be one of vivid or natural. + // Vivid causes the model to lean towards generating hyper-real and dramatic images. + // Natural causes the model to produce more natural, less hyper-real looking images. + // This param is only supported for dall-e-3. + Style string `json:"style,omitempty"` + // A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. User string `json:"user,omitempty"` } @@ -111,6 +156,9 @@ type EditRequest struct { // any path information. ImageName string `json:"-"` + // A text description of the desired image(s). The maximum length is 1000 characters. + Prompt string `json:"prompt,omitempty"` + // An additional image whose fully transparent areas (e.g. where alpha is zero) // indicate where image should be edited. Must be a valid PNG file, less than 4MB, // and have the same dimensions as image. @@ -120,8 +168,8 @@ type EditRequest struct { // path information. MaskName string `json:"-"` - // A text description of the desired image(s). The maximum length is 1000 characters. - Prompt string `json:"prompt,omitempty"` + // The model to use for image generation. Only dall-e-2 is supported at this time. + Model string `json:"model,omitempty"` // The number of images to generate. Must be between 1 and 10. N *uint64 `json:"n,omitempty"` @@ -145,14 +193,15 @@ func MakeEditRequest(request *EditRequest, organizationID *string) (*Response, e buf := new(bytes.Buffer) writer := multipart.NewWriter(buf) + var err error + if len(request.Prompt) > 0 { - err := common.CreateFormField("prompt", request.Prompt, writer) + err = common.CreateFormField("prompt", request.Prompt, writer) if err != nil { return nil, err } } - var err error if request.N != nil { err = common.CreateFormField("n", request.N, writer) if err != nil { @@ -181,6 +230,13 @@ func MakeEditRequest(request *EditRequest, organizationID *string) (*Response, e } } + if len(request.Model) > 0 { + err = common.CreateFormField("model", request.Model, writer) + if err != nil { + return nil, err + } + } + if len(request.Image) > 0 { err = common.CreateFormFile("image", request.ImageName, request.Image, writer) if err != nil { @@ -240,6 +296,9 @@ type VariationRequest struct { // any path information. ImageName string `json:"-"` + // The model to use for image generation. Only dall-e-2 is supported at this time. + Model string `json:"model,omitempty"` + // The number of images to generate. Must be between 1 and 10. N *uint64 `json:"n,omitempty"` @@ -263,6 +322,7 @@ func MakeVariationRequest(request *VariationRequest, organizationID *string) (*R writer := multipart.NewWriter(buf) var err error + if request.N != nil { err = common.CreateFormField("n", request.N, writer) if err != nil { @@ -298,6 +358,13 @@ func MakeVariationRequest(request *VariationRequest, organizationID *string) (*R } } + if len(request.Model) > 0 { + err = common.CreateFormField("model", request.Model, writer) + if err != nil { + return nil, err + } + } + writer.Close() r, err := common.MakeRequestWithForm[Response](buf, VariationEndpoint, http.MethodPost, writer.FormDataContentType(), organizationID) if err != nil { diff --git a/images/images_test.go b/images/images_test.go index f775901..18fa86b 100644 --- a/images/images_test.go +++ b/images/images_test.go @@ -18,14 +18,15 @@ func init() { authentication.SetAPIKey(key) } -func create() (*images.Response, error) { +func create(model, size string) (*images.Response, error) { const prompt = "A cute baby sea otter" fmt.Printf("Creating from prompt: %s\n", prompt) resp, err := images.MakeCreationRequest(&images.CreationRequest{ Prompt: prompt, - Size: images.SmallImage, + Size: size, User: "https://github.com/TannerKvarfordt/gopenai", + Model: model, }, nil) if err != nil { return nil, err @@ -38,14 +39,15 @@ func create() (*images.Response, error) { return resp, nil } -func variation(imagename, image string) error { +func variation(model, imagename, image string) error { fmt.Printf("Generating a variation...") resp, err := images.MakeVariationRequest(&images.VariationRequest{ Image: image, ImageName: imagename, - Size: images.SmallImage, + Size: images.Dalle2SmallImage, User: "https://github.com/TannerKvarfordt/gopenai", + Model: model, }, nil) if err != nil { return err @@ -58,13 +60,20 @@ func variation(imagename, image string) error { return nil } -func TestImages(t *testing.T) { - resp, err := create() +func TestImagesDalle2(t *testing.T) { + resp, err := create(images.ModelDalle2, images.Dalle2SmallImage) if err != nil { t.Fatal(err) } - err = variation("Original", resp.Data[0].URL) + err = variation(images.ModelDalle2, "Original", resp.Data[0].URL) + if err != nil { + t.Fatal(err) + } +} + +func TestImagesDalle3(t *testing.T) { + _, err := create(images.ModelDalle3, images.Dalle3SquareImage) if err != nil { t.Fatal(err) } diff --git a/tools/build.sh b/tools/build.sh index f824585..0b4fffa 100755 --- a/tools/build.sh +++ b/tools/build.sh @@ -7,7 +7,12 @@ pushd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null ./examples-build.sh pushd .. >/dev/null +echo "Formatting code..." go fmt ./... +echo "Running tests..." go test ./... +echo "Building $(basename "$(pwd)")..." go build ./... +echo "Vetting..." go vet ./... +echo "Done."