Skip to content

Commit

Permalink
feat: support binary data handling
Browse files Browse the repository at this point in the history
  • Loading branch information
kardolus committed Jan 4, 2025
1 parent 5c6bd63 commit 7229f1d
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 32 deletions.
110 changes: 78 additions & 32 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ func (c *Client) ListModels() ([]string, error) {
// and the method will split it into messages, preserving punctuation and special
// characters.
func (c *Client) ProvideContext(context string) {
if len(c.Config.Binary) > 0 {
return
}

c.initHistory()
historyEntries := c.createHistoryEntriesFromString(context)
c.History = append(c.History, historyEntries...)
Expand Down Expand Up @@ -253,39 +257,20 @@ func (c *Client) createBody(stream bool) ([]byte, error) {
Stream: stream,
}

if c.Config.Image != "" {
var content api.ImageContent

if isValidURL(c.Config.Image) {
content = api.ImageContent{
Type: imageURLType,
ImageURL: struct {
URL string `json:"url"`
}{
URL: c.Config.Image,
},
}
} else {
mime, err := c.getMimeTypeFromFileContent(c.Config.Image)
if err != nil {
return nil, err
}

image, err := c.base64EncodeImage(c.Config.Image)
if err != nil {
return nil, err
}

content = api.ImageContent{
Type: imageURLType,
ImageURL: struct {
URL string `json:"url"`
}{
URL: fmt.Sprintf(imageContent, mime, image),
},
}
if len(c.Config.Binary) > 0 {
content, err := c.createImageContentFromBinary(c.Config.Binary)
if err != nil {
return nil, err
}
body.Messages = append(body.Messages, api.Message{
Role: UserRole,
Content: []api.ImageContent{content},
})
} else if c.Config.Image != "" {
content, err := c.createImageContentFromURLOrFile(c.Config.Image)
if err != nil {
return nil, err
}

body.Messages = append(body.Messages, api.Message{
Role: UserRole,
Content: []api.ImageContent{content},
Expand All @@ -295,6 +280,61 @@ func (c *Client) createBody(stream bool) ([]byte, error) {
return json.Marshal(body)
}

func (c *Client) createImageContentFromBinary(binary []byte) (api.ImageContent, error) {
mime, err := c.getMimeTypeFromBytes(binary)
if err != nil {
return api.ImageContent{}, err
}

encoded := base64.StdEncoding.EncodeToString(binary)
content := api.ImageContent{
Type: imageURLType,
ImageURL: struct {
URL string `json:"url"`
}{
URL: fmt.Sprintf(imageContent, mime, encoded),
},
}

return content, nil
}

func (c *Client) createImageContentFromURLOrFile(image string) (api.ImageContent, error) {
var content api.ImageContent

if isValidURL(image) {
content = api.ImageContent{
Type: imageURLType,
ImageURL: struct {
URL string `json:"url"`
}{
URL: image,
},
}
} else {
mime, err := c.getMimeTypeFromFileContent(image)
if err != nil {
return content, err
}

encodedImage, err := c.base64EncodeImage(image)
if err != nil {
return content, err
}

content = api.ImageContent{
Type: imageURLType,
ImageURL: struct {
URL string `json:"url"`
}{
URL: fmt.Sprintf(imageContent, mime, encodedImage),
},
}
}

return content, nil
}

func (c *Client) initHistory() {
if len(c.History) != 0 {
return
Expand Down Expand Up @@ -422,6 +462,12 @@ func (c *Client) createHistoryEntriesFromString(input string) []history.History
return result
}

func (c *Client) getMimeTypeFromBytes(data []byte) (string, error) {
mimeType := stdhttp.DetectContentType(data)

return mimeType, nil
}

func (c *Client) getMimeTypeFromFileContent(path string) (string, error) {
file, err := c.reader.Open(path)
if err != nil {
Expand Down
13 changes: 13 additions & 0 deletions api/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,19 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(contextMessage.Role).To(Equal(client.UserRole))
Expect(contextMessage.Content).To(Equal(context))
})
it("does not update history if Config.Binary is provided", func() {
subject := factory.buildClientWithoutConfig()

subject.Config.Binary = []byte("binary data")

mockHistoryStore.EXPECT().Read().Times(0) // No read should be called, early return happens
mockTimer.EXPECT().Now().Times(0) // No need to mock time since we should not enter the function body

initialHistoryLength := len(subject.History)
subject.ProvideContext("some context")

Expect(len(subject.History)).To(Equal(initialHistoryLength))
})
})
}

Expand Down
4 changes: 4 additions & 0 deletions cmd/chatgpt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ func run(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to read from pipe: %w", err)
}

if utils.IsBinary(pipeContent) {
c.Config.Binary = pipeContent
}

context := string(pipeContent)

if strings.Trim(context, "\n ") != "" {
Expand Down
34 changes: 34 additions & 0 deletions cmd/chatgpt/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"strings"
"time"
"unicode/utf8"
)

func ColorToAnsi(color string) (string, string) {
Expand Down Expand Up @@ -64,3 +65,36 @@ func FormatPrompt(str string, counter, usage int, now time.Time) string {

return str
}

func IsBinary(data []byte) bool {
if len(data) == 0 {
return false
}

// Only check up to 512KB to avoid memory issues with large files
const maxBytes = 512 * 1024
checkSize := len(data)
if checkSize > maxBytes {
checkSize = maxBytes
}

// Check if the sample is valid UTF-8
if !utf8.Valid(data[:checkSize]) {
return true
}

// Count suspicious bytes in the sample
binaryCount := 0
for _, b := range data[:checkSize] {
if b == 0 {
return true
}

if b < 32 && b != 9 && b != 10 && b != 13 {
binaryCount++
}
}

threshold := checkSize * 10 / 100
return binaryCount > threshold
}
47 changes: 47 additions & 0 deletions cmd/chatgpt/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,51 @@ func testUtils(t *testing.T, when spec.G, it spec.S) {
Expect(utils.FormatPrompt(input, counter, usage, now)).To(Equal(expected))
})
})

when("IsBinary()", func() {
it("should return false for a regular string", func() {
Expect(utils.IsBinary([]byte("regular string"))).To(BeFalse())
})
it("should return false for a string containing emojis", func() {
Expect(utils.IsBinary([]byte("☮️✅❤️"))).To(BeFalse())
})
it("should return true for a binary string", func() {
Expect(utils.IsBinary([]byte{0xFF, 0xFE, 0xFD, 0xFC, 0xFB})).To(BeTrue())
})
it("should return false when the data is empty", func() {
Expect(utils.IsBinary([]byte{})).To(BeFalse())
})
it("should handle large text files correctly", func() {
// Create a large slice > 512KB with normal text
largeText := make([]byte, 1024*1024) // 1MB
for i := range largeText {
largeText[i] = 'a'
}

Expect(utils.IsBinary(largeText)).To(BeFalse())
})
it("should return true when data contains null bytes", func() {
Expect(utils.IsBinary([]byte{'h', 'e', 'l', 'l', 0x00, 'o'})).To(BeTrue())
})

it("should return true for invalid UTF-8 sequences", func() {
// Invalid UTF-8: 0xED 0xA0 0x80 is a surrogate pair which is invalid in UTF-8
Expect(utils.IsBinary([]byte{0xED, 0xA0, 0x80})).To(BeTrue())
})

it("should return false for valid UTF-8 special characters", func() {
// Testing with Chinese characters, Arabic, and other non-ASCII but valid UTF-8
Expect(utils.IsBinary([]byte("你好世界مرحبا"))).To(BeFalse())
})

it("should handle control characters correctly", func() {
// Test with allowed control characters (tab, newline, carriage return)
Expect(utils.IsBinary([]byte("Hello\tWorld\r\nTest"))).To(BeFalse())

// Test with other control characters that should trigger binary detection
data := []byte{0x01, 0x02, 0x03, 0x04}
Expect(utils.IsBinary(data)).To(BeTrue())
})

})
}
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type Config struct {
ContextWindow int `yaml:"context_window"`
Role string `yaml:"role"`
Image string `yaml:"image"`
Binary []byte `yaml:"binary"`
Temperature float64 `yaml:"temperature"`
TopP float64 `yaml:"top_p"`
FrequencyPenalty float64 `yaml:"frequency_penalty"`
Expand Down
7 changes: 7 additions & 0 deletions scripts/all-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ if ! golangci-lint run; then
exit 1
fi

# Search for TODOs in the codebase, excluding vendor and scripts directories.
log "Searching for TODOs..."
if ag TODO --ignore-dir vendor --ignore scripts; then
log "Error: Found TODOs in the codebase. Please address them before proceeding."
exit 1
fi

# Run tests in parallel for faster execution
log "Running unit tests..."
cd "$( dirname "${BASH_SOURCE[0]}" )/.."
Expand Down

0 comments on commit 7229f1d

Please sign in to comment.