Skip to content

Commit

Permalink
Merge pull request #34 from nopeless/main
Browse files Browse the repository at this point in the history
[Fix] Add missing max_token in request query
  • Loading branch information
kardolus authored Mar 20, 2024
2 parents 5895a9d + 5e2a53e commit c78bcb7
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 9 deletions.
6 changes: 4 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"unicode/utf8"

"github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/configmanager"
"github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/http"
"github.com/kardolus/chatgpt-cli/types"
"strings"
"unicode/utf8"
)

const (
Expand Down Expand Up @@ -163,6 +164,7 @@ func (c *Client) createBody(stream bool) ([]byte, error) {
body := types.CompletionsRequest{
Messages: c.History,
Model: c.Config.Model,
MaxTokens: c.Config.MaxTokens,
Temperature: c.Config.Temperature,
TopP: c.Config.TopP,
FrequencyPenalty: c.Config.FrequencyPenalty,
Expand Down
14 changes: 9 additions & 5 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
//go:generate mockgen -destination=configmocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/config ConfigStore

const (
defaultMaxTokens = 4096
defaultMaxTokens = 50
defaultURL = "https://default.openai.com"
defaultName = "default-name"
defaultModel = "gpt-3.5-turbo"
Expand Down Expand Up @@ -204,6 +204,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
topP = 200.2
frequencyPenalty = 300.3
presencePenalty = 400.4
maxTokens = 12345
)

messages = createMessages(nil, query)
Expand All @@ -214,9 +215,10 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
TopP: topP,
FrequencyPenalty: frequencyPenalty,
PresencePenalty: presencePenalty,
MaxTokens: maxTokens,
})

body, err = createBodyWithConfig(messages, false, model, temperature, topP, frequencyPenalty, presencePenalty)
body, err = createBodyWithConfig(messages, false, model, maxTokens, temperature, topP, frequencyPenalty, presencePenalty)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body, false)
})
Expand Down Expand Up @@ -464,18 +466,20 @@ func createBody(messages []types.Message, stream bool) ([]byte, error) {
Temperature: defaultTemperature,
TopP: defaultTopP,
FrequencyPenalty: defaultFrequencyPenalty,
MaxTokens: defaultMaxTokens,
PresencePenalty: defaultPresencePenalty,
}

return json.Marshal(req)
}

func createBodyWithConfig(messages []types.Message, stream bool, model string, temperature float64, topP float64, frequencyPenalty float64, presencePenalty float64) ([]byte, error) {
func createBodyWithConfig(messages []types.Message, stream bool, model string, maxTokens int, temperature float64, topP float64, frequencyPenalty float64, presencePenalty float64) ([]byte, error) {
req := types.CompletionsRequest{
Model: model,
Messages: messages,
Stream: stream,
Temperature: temperature,
MaxTokens: maxTokens,
TopP: topP,
FrequencyPenalty: frequencyPenalty,
PresencePenalty: presencePenalty,
Expand Down Expand Up @@ -538,7 +542,7 @@ func (f *clientFactory) buildClientWithoutConfig() *client.Client {
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore)
Expect(err).NotTo(HaveOccurred())

return c.WithCapacity(50)
return c
}

func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Client {
Expand All @@ -548,7 +552,7 @@ func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Clien
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore)
Expect(err).NotTo(HaveOccurred())

return c.WithCapacity(50)
return c
}

func (f *clientFactory) withoutHistory() {
Expand Down
5 changes: 3 additions & 2 deletions integration/contract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ func testContract(t *testing.T, when spec.G, it spec.S) {
Role: client.SystemRole,
Content: cfg.Role,
}},
Model: cfg.Model,
Stream: false,
MaxTokens: 1234,
Model: cfg.Model,
Stream: false,
}

bytes, err := json.Marshal(body)
Expand Down
1 change: 1 addition & 0 deletions types/completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ type CompletionsRequest struct {
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
FrequencyPenalty float64 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
PresencePenalty float64 `json:"presence_penalty"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
Expand Down

0 comments on commit c78bcb7

Please sign in to comment.