Skip to content

Commit

Permalink
Add Support for Dall-e-3 (#9)
Browse files Browse the repository at this point in the history
Add support for Dall-e-3
  • Loading branch information
Kardbord authored Nov 7, 2023
1 parent 3816ebc commit 232ba6b
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 25 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ If you use this library, you must conform to Open AI's [Usage Policies](https://

## Other Language Bindings

For another great Go implementation, see [sashabaranov/go-gpt3](https://github.com/sashabaranov/go-gpt3).
For another great Go implementation, see [sashabaranov/go-openai](https://github.com/sashabaranov/go-openai).
For other languages, see [Open AI's Website](https://beta.openai.com/docs/libraries/libraries).

## Contributing
Expand Down
17 changes: 12 additions & 5 deletions examples/images/images-example.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ func init() {
authentication.SetAPIKey(key)
}

func create() (*images.Response, error) {
func create(model, size string) (*images.Response, error) {
const prompt = "A cute baby sea otter"

fmt.Printf("Creating from prompt: %s\n", prompt)
fmt.Printf("Creating from model=\"%s\", prompt=\"%s\"\n", model, prompt)
resp, _, err := images.MakeModeratedCreationRequest(&images.CreationRequest{
Prompt: prompt,
Size: images.SmallImage,
Size: size,
User: "https://github.com/TannerKvarfordt/gopenai",
Model: model,
}, nil)
if err != nil {
return nil, err
Expand All @@ -39,7 +40,7 @@ func variation(imagename, image string) error {
resp, err := images.MakeVariationRequest(&images.VariationRequest{
Image: image,
ImageName: imagename,
Size: images.SmallImage,
Size: images.Dalle2SmallImage,
User: "https://github.com/TannerKvarfordt/gopenai",
}, nil)
if err != nil {
Expand All @@ -51,7 +52,7 @@ func variation(imagename, image string) error {
}

func main() {
resp, err := create()
resp, err := create(images.ModelDalle2, images.Dalle2SmallImage)
if err != nil {
fmt.Println(err)
return
Expand All @@ -62,4 +63,10 @@ func main() {
fmt.Println(err)
return
}

_, err = create(images.ModelDalle3, images.Dalle3SquareImage)
if err != nil {
fmt.Println(err)
return
}
}
91 changes: 79 additions & 12 deletions images/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,85 @@ const (
)

const (
SmallImage string = "256x256"
MediumImage string = "512x512"
LargeImage string = "1024x1024"
Dalle2SmallImage = "256x256"
Dalle2MediumImage = "512x512"
Dalle2LargeImage = "1024x1024"

Dalle3SquareImage = "1024x1024"
Dalle3LandscapeImage = "1792x1024"
Dalle3PortraitImage = "1024x1792"

// Deprecated: Use Dalle2SmallImage instead.
SmallImage = Dalle2SmallImage
// Deprecated: Use Dalle2MediumImage instead.
MediumImage = Dalle2MediumImage
// Deprecated: Use Dalle2LargeImage instead.
LargeImage = Dalle2LargeImage
)

const (
ResponseFormatURL = "url"
ResponseFormatB64JSON = "b64_json"
)

const (
ModelDalle2 = "dall-e-2"
ModelDalle3 = "dall-e-3"
)

const (
QualityStandard = "standard"
QualityHD = "hd"
)

const (
StyleVivid = "vivid"
StyleNatural = "natural"
)

// Response structure for the image API endpoint.
type Response struct {
Created uint64 `json:"created"`
Data []struct {
URL string `json:"url"`
B64JSON string `json:"b64_json"`
URL string `json:"url"`
B64JSON string `json:"b64_json"`
RevisedPrompt string `json:"revised_prompt"`
}
Error *common.ResponseError `json:"error,omitempty"`
}

// Request structure for the image creation API endpoint.
type CreationRequest struct {
// A text description of the desired image(s). The maximum length is 1000 characters.
// A text description of the desired image(s).
// The maximum length is 1000 characters for dall-e-2 and 4000 characters for dall-e-3.
Prompt string `json:"prompt,omitempty"`

// The model to use for image generation.
Model string `json:"model,omitempty"`

// The number of images to generate. Must be between 1 and 10.
// For dall-e-3, only n=1 is supported.
N *uint64 `json:"n,omitempty"`

// The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024.
Size string `json:"size,omitempty"`
// The quality of the image that will be generated.
// "hd" creates images with finer details and greater consistency across the image.
// This param is only supported for dall-e-3.
Quality string `json:"quality,omitempty"`

// The format in which the generated images are returned. Must be one of url or b64_json.
ResponseFormat string `json:"response_format,omitempty"`

// The size of the generated images.
// Must be one of 256x256, 512x512, or 1024x1024 for dall-e-2.
// Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models.
Size string `json:"size,omitempty"`

// The style of the generated images. Must be one of vivid or natural.
// Vivid causes the model to lean towards generating hyper-real and dramatic images.
// Natural causes the model to produce more natural, less hyper-real looking images.
// This param is only supported for dall-e-3.
Style string `json:"style,omitempty"`

// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
User string `json:"user,omitempty"`
}
Expand Down Expand Up @@ -111,6 +156,9 @@ type EditRequest struct {
// any path information.
ImageName string `json:"-"`

// A text description of the desired image(s). The maximum length is 1000 characters.
Prompt string `json:"prompt,omitempty"`

// An additional image whose fully transparent areas (e.g. where alpha is zero)
// indicate where image should be edited. Must be a valid PNG file, less than 4MB,
// and have the same dimensions as image.
Expand All @@ -120,8 +168,8 @@ type EditRequest struct {
// path information.
MaskName string `json:"-"`

// A text description of the desired image(s). The maximum length is 1000 characters.
Prompt string `json:"prompt,omitempty"`
// The model to use for image generation. Only dall-e-2 is supported at this time.
Model string `json:"model,omitempty"`

// The number of images to generate. Must be between 1 and 10.
N *uint64 `json:"n,omitempty"`
Expand All @@ -145,14 +193,15 @@ func MakeEditRequest(request *EditRequest, organizationID *string) (*Response, e
buf := new(bytes.Buffer)
writer := multipart.NewWriter(buf)

var err error

if len(request.Prompt) > 0 {
err := common.CreateFormField("prompt", request.Prompt, writer)
err = common.CreateFormField("prompt", request.Prompt, writer)
if err != nil {
return nil, err
}
}

var err error
if request.N != nil {
err = common.CreateFormField("n", request.N, writer)
if err != nil {
Expand Down Expand Up @@ -181,6 +230,13 @@ func MakeEditRequest(request *EditRequest, organizationID *string) (*Response, e
}
}

if len(request.Model) > 0 {
err = common.CreateFormField("model", request.Model, writer)
if err != nil {
return nil, err
}
}

if len(request.Image) > 0 {
err = common.CreateFormFile("image", request.ImageName, request.Image, writer)
if err != nil {
Expand Down Expand Up @@ -240,6 +296,9 @@ type VariationRequest struct {
// any path information.
ImageName string `json:"-"`

// The model to use for image generation. Only dall-e-2 is supported at this time.
Model string `json:"model,omitempty"`

// The number of images to generate. Must be between 1 and 10.
N *uint64 `json:"n,omitempty"`

Expand All @@ -263,6 +322,7 @@ func MakeVariationRequest(request *VariationRequest, organizationID *string) (*R
writer := multipart.NewWriter(buf)

var err error

if request.N != nil {
err = common.CreateFormField("n", request.N, writer)
if err != nil {
Expand Down Expand Up @@ -298,6 +358,13 @@ func MakeVariationRequest(request *VariationRequest, organizationID *string) (*R
}
}

if len(request.Model) > 0 {
err = common.CreateFormField("model", request.Model, writer)
if err != nil {
return nil, err
}
}

writer.Close()
r, err := common.MakeRequestWithForm[Response](buf, VariationEndpoint, http.MethodPost, writer.FormDataContentType(), organizationID)
if err != nil {
Expand Down
23 changes: 16 additions & 7 deletions images/images_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ func init() {
authentication.SetAPIKey(key)
}

func create() (*images.Response, error) {
func create(model, size string) (*images.Response, error) {
const prompt = "A cute baby sea otter"

fmt.Printf("Creating from prompt: %s\n", prompt)
resp, err := images.MakeCreationRequest(&images.CreationRequest{
Prompt: prompt,
Size: images.SmallImage,
Size: size,
User: "https://github.com/TannerKvarfordt/gopenai",
Model: model,
}, nil)
if err != nil {
return nil, err
Expand All @@ -38,14 +39,15 @@ func create() (*images.Response, error) {
return resp, nil
}

func variation(imagename, image string) error {
func variation(model, imagename, image string) error {

fmt.Printf("Generating a variation...")
resp, err := images.MakeVariationRequest(&images.VariationRequest{
Image: image,
ImageName: imagename,
Size: images.SmallImage,
Size: images.Dalle2SmallImage,
User: "https://github.com/TannerKvarfordt/gopenai",
Model: model,
}, nil)
if err != nil {
return err
Expand All @@ -58,13 +60,20 @@ func variation(imagename, image string) error {
return nil
}

func TestImages(t *testing.T) {
resp, err := create()
func TestImagesDalle2(t *testing.T) {
resp, err := create(images.ModelDalle2, images.Dalle2SmallImage)
if err != nil {
t.Fatal(err)
}

err = variation("Original", resp.Data[0].URL)
err = variation(images.ModelDalle2, "Original", resp.Data[0].URL)
if err != nil {
t.Fatal(err)
}
}

func TestImagesDalle3(t *testing.T) {
_, err := create(images.ModelDalle3, images.Dalle3SquareImage)
if err != nil {
t.Fatal(err)
}
Expand Down
5 changes: 5 additions & 0 deletions tools/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ pushd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null
./examples-build.sh

pushd .. >/dev/null
echo "Formatting code..."
go fmt ./...
echo "Running tests..."
go test ./...
echo "Building $(basename "$(pwd)")..."
go build ./...
echo "Vetting..."
go vet ./...
echo "Done."

0 comments on commit 232ba6b

Please sign in to comment.