Skip to content

Commit

Permalink
feat: add show-history flag
Browse files Browse the repository at this point in the history
  • Loading branch information
kardolus committed Oct 16, 2024
1 parent 782eff6 commit a91e32c
Show file tree
Hide file tree
Showing 8 changed files with 461 additions and 51 deletions.
15 changes: 15 additions & 0 deletions client/historymocks_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 29 additions & 2 deletions cmd/chatgpt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var (
GitVersion string
queryMode bool
clearHistory bool
showHistory bool
showVersion bool
newThread bool
showConfig bool
Expand Down Expand Up @@ -171,6 +172,30 @@ func run(cmd *cobra.Command, args []string) error {
return nil
}

if showHistory { // TODO integration test
var targetThread string
if len(args) > 0 {
targetThread = args[0]
} else {
targetThread = cfg.Thread
}

store, err := history.New()
if err != nil {
return err
}

h := history.NewHistory(store)

output, err := h.Print(targetThread)
if err != nil {
return err
}

fmt.Println(output)
return nil
}

if showConfig {
allSettings := viper.AllSettings()

Expand Down Expand Up @@ -531,8 +556,9 @@ func setCustomHelp(rootCmd *cobra.Command) {
printFlagWithPadding("-v, --version", "Display the version information")
printFlagWithPadding("-l, --list-models", "List available models")
printFlagWithPadding("--list-threads", "List available threads")
printFlagWithPadding("--clear-history", "Clear the history of the current thread")
printFlagWithPadding("--delete-thread", "Delete the specified thread")
printFlagWithPadding("--clear-history", "Clear the history of the current thread")
printFlagWithPadding("--show-history [thread]", "Show the human-readable conversation history")
printFlagWithPadding("--set-completions", "Generate autocompletion script for your current shell")
fmt.Println()

Expand Down Expand Up @@ -575,6 +601,7 @@ func setupFlags(rootCmd *cobra.Command) {
rootCmd.PersistentFlags().StringVarP(&promptFile, "prompt", "p", "", "Provide a prompt file")
rootCmd.PersistentFlags().BoolVarP(&listThreads, "list-threads", "", false, "List available threads")
rootCmd.PersistentFlags().StringVar(&threadName, "delete-thread", "", "Delete the specified thread")
rootCmd.PersistentFlags().BoolVar(&showHistory, "show-history", false, "Show the human-readable conversation history")
rootCmd.PersistentFlags().StringVar(&shell, "set-completions", "", "Generate autocompletion script for your current shell")
}

Expand Down Expand Up @@ -608,7 +635,7 @@ func isNonConfigSetter(name string) bool {

func isGeneralFlag(name string) bool {
switch name {
case "query", "interactive", "config", "version", "new-thread", "list-models", "list-threads", "clear-history", "delete-thread", "prompt", "set-completions", "help":
case "query", "interactive", "config", "version", "new-thread", "list-models", "list-threads", "clear-history", "delete-thread", "show-history", "prompt", "set-completions", "help":
return true
default:
return false
Expand Down
88 changes: 44 additions & 44 deletions configmanager/configmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,28 +178,28 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
})

it("should prioritize environment variables over default config", func() {
os.Setenv(envPrefix+"API_KEY", "env-api-key")
os.Setenv(envPrefix+"MODEL", "env-model")
os.Setenv(envPrefix+"MAX_TOKENS", "15")
os.Setenv(envPrefix+"CONTEXT_WINDOW", "25")
os.Setenv(envPrefix+"URL", "env-url")
os.Setenv(envPrefix+"COMPLETIONS_PATH", "env-completions-path")
os.Setenv(envPrefix+"MODELS_PATH", "env-models-path")
os.Setenv(envPrefix+"AUTH_HEADER", "env-auth-header")
os.Setenv(envPrefix+"AUTH_TOKEN_PREFIX", "env-auth-token-prefix")
os.Setenv(envPrefix+"OMIT_HISTORY", "true")
os.Setenv(envPrefix+"AUTO_CREATE_NEW_THREAD", "true")
os.Setenv(envPrefix+"TRACK_TOKEN_USAGE", "true")
os.Setenv(envPrefix+"DEBUG", "true")
os.Setenv(envPrefix+"SKIP_TLS_VERIFY", "true")
os.Setenv(envPrefix+"ROLE", "env-role")
os.Setenv(envPrefix+"THREAD", "env-thread")
os.Setenv(envPrefix+"TEMPERATURE", "2.2")
os.Setenv(envPrefix+"TOP_P", "3.3")
os.Setenv(envPrefix+"FREQUENCY_PENALTY", "4.4")
os.Setenv(envPrefix+"PRESENCE_PENALTY", "5.5")
os.Setenv(envPrefix+"COMMAND_PROMPT", "env-command-prompt")
os.Setenv(envPrefix+"OUTPUT_PROMPT", "env-output-prompt")
Expect(os.Setenv(envPrefix+"API_KEY", "env-api-key")).To(Succeed())
Expect(os.Setenv(envPrefix+"MODEL", "env-model")).To(Succeed())
Expect(os.Setenv(envPrefix+"MAX_TOKENS", "15")).To(Succeed())
Expect(os.Setenv(envPrefix+"CONTEXT_WINDOW", "25")).To(Succeed())
Expect(os.Setenv(envPrefix+"URL", "env-url")).To(Succeed())
Expect(os.Setenv(envPrefix+"COMPLETIONS_PATH", "env-completions-path")).To(Succeed())
Expect(os.Setenv(envPrefix+"MODELS_PATH", "env-models-path")).To(Succeed())
Expect(os.Setenv(envPrefix+"AUTH_HEADER", "env-auth-header")).To(Succeed())
Expect(os.Setenv(envPrefix+"AUTH_TOKEN_PREFIX", "env-auth-token-prefix")).To(Succeed())
Expect(os.Setenv(envPrefix+"OMIT_HISTORY", "true")).To(Succeed())
Expect(os.Setenv(envPrefix+"AUTO_CREATE_NEW_THREAD", "true")).To(Succeed())
Expect(os.Setenv(envPrefix+"TRACK_TOKEN_USAGE", "true")).To(Succeed())
Expect(os.Setenv(envPrefix+"DEBUG", "true")).To(Succeed())
Expect(os.Setenv(envPrefix+"SKIP_TLS_VERIFY", "true")).To(Succeed())
Expect(os.Setenv(envPrefix+"ROLE", "env-role")).To(Succeed())
Expect(os.Setenv(envPrefix+"THREAD", "env-thread")).To(Succeed())
Expect(os.Setenv(envPrefix+"TEMPERATURE", "2.2")).To(Succeed())
Expect(os.Setenv(envPrefix+"TOP_P", "3.3")).To(Succeed())
Expect(os.Setenv(envPrefix+"FREQUENCY_PENALTY", "4.4")).To(Succeed())
Expect(os.Setenv(envPrefix+"PRESENCE_PENALTY", "5.5")).To(Succeed())
Expect(os.Setenv(envPrefix+"COMMAND_PROMPT", "env-command-prompt")).To(Succeed())
Expect(os.Setenv(envPrefix+"OUTPUT_PROMPT", "env-output-prompt")).To(Succeed())

mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{}, errors.New("config error")).Times(1)
Expand Down Expand Up @@ -231,28 +231,28 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
})

it("should prioritize environment variables over user-provided config", func() {
os.Setenv(envPrefix+"API_KEY", "env-api-key")
os.Setenv(envPrefix+"MODEL", "env-model")
os.Setenv(envPrefix+"MAX_TOKENS", "15")
os.Setenv(envPrefix+"CONTEXT_WINDOW", "25")
os.Setenv(envPrefix+"URL", "env-url")
os.Setenv(envPrefix+"COMPLETIONS_PATH", "env-completions-path")
os.Setenv(envPrefix+"MODELS_PATH", "env-models-path")
os.Setenv(envPrefix+"AUTH_HEADER", "env-auth-header")
os.Setenv(envPrefix+"AUTH_TOKEN_PREFIX", "env-auth-token-prefix")
os.Setenv(envPrefix+"OMIT_HISTORY", "true")
os.Setenv(envPrefix+"AUTO_CREATE_NEW_THREAD", "true")
os.Setenv(envPrefix+"TRACK_TOKEN_USAGE", "true")
os.Setenv(envPrefix+"SKIP_TLS_VERIFY", "true")
os.Setenv(envPrefix+"DEBUG", "false")
os.Setenv(envPrefix+"ROLE", "env-role")
os.Setenv(envPrefix+"THREAD", "env-thread")
os.Setenv(envPrefix+"TEMPERATURE", "2.2")
os.Setenv(envPrefix+"TOP_P", "3.3")
os.Setenv(envPrefix+"FREQUENCY_PENALTY", "4.4")
os.Setenv(envPrefix+"PRESENCE_PENALTY", "5.5")
os.Setenv(envPrefix+"COMMAND_PROMPT", "env-command-prompt")
os.Setenv(envPrefix+"OUTPUT_PROMPT", "env-output-prompt")
Expect(os.Setenv(envPrefix+"API_KEY", "env-api-key")).To(Succeed())
Expect(os.Setenv(envPrefix+"MODEL", "env-model")).To(Succeed())
Expect(os.Setenv(envPrefix+"MAX_TOKENS", "15")).To(Succeed())
Expect(os.Setenv(envPrefix+"CONTEXT_WINDOW", "25")).To(Succeed())
Expect(os.Setenv(envPrefix+"URL", "env-url")).To(Succeed())
Expect(os.Setenv(envPrefix+"COMPLETIONS_PATH", "env-completions-path")).To(Succeed())
Expect(os.Setenv(envPrefix+"MODELS_PATH", "env-models-path")).To(Succeed())
Expect(os.Setenv(envPrefix+"AUTH_HEADER", "env-auth-header")).To(Succeed())
Expect(os.Setenv(envPrefix+"AUTH_TOKEN_PREFIX", "env-auth-token-prefix")).To(Succeed())
Expect(os.Setenv(envPrefix+"OMIT_HISTORY", "true")).To(Succeed())
Expect(os.Setenv(envPrefix+"AUTO_CREATE_NEW_THREAD", "true")).To(Succeed())
Expect(os.Setenv(envPrefix+"TRACK_TOKEN_USAGE", "true")).To(Succeed())
Expect(os.Setenv(envPrefix+"SKIP_TLS_VERIFY", "true")).To(Succeed())
Expect(os.Setenv(envPrefix+"DEBUG", "false")).To(Succeed())
Expect(os.Setenv(envPrefix+"ROLE", "env-role")).To(Succeed())
Expect(os.Setenv(envPrefix+"THREAD", "env-thread")).To(Succeed())
Expect(os.Setenv(envPrefix+"TEMPERATURE", "2.2")).To(Succeed())
Expect(os.Setenv(envPrefix+"TOP_P", "3.3")).To(Succeed())
Expect(os.Setenv(envPrefix+"FREQUENCY_PENALTY", "4.4")).To(Succeed())
Expect(os.Setenv(envPrefix+"PRESENCE_PENALTY", "5.5")).To(Succeed())
Expect(os.Setenv(envPrefix+"COMMAND_PROMPT", "env-command-prompt")).To(Succeed())
Expect(os.Setenv(envPrefix+"OUTPUT_PROMPT", "env-output-prompt")).To(Succeed())

userConfig := types.Config{
APIKey: "user-api-key",
Expand Down
82 changes: 82 additions & 0 deletions history/history.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package history

import (
"fmt"
"github.com/kardolus/chatgpt-cli/types"
"strings"
)

const (
assistantRole = "assistant"
systemRole = "system"
userRole = "user"
)

type History struct {
store HistoryStore
}

func NewHistory(store HistoryStore) *History {
return &History{store: store}
}

func (h *History) Print(thread string) (string, error) {
var result string

messages, err := h.store.ReadThread(thread)
if err != nil {
return "", err
}

var (
lastRole string
concatenatedMessage string
)

for _, message := range messages {
if message.Role == userRole && lastRole == userRole {
concatenatedMessage += message.Content
} else {
if lastRole == userRole && concatenatedMessage != "" {
result += formatMessage(types.Message{Role: userRole, Content: concatenatedMessage})
concatenatedMessage = ""
}

if message.Role == userRole {
concatenatedMessage = message.Content
} else {
result += formatMessage(message)
}
}

lastRole = message.Role
}

// Handle the case where the last message is a user message and was concatenated
if lastRole == userRole && concatenatedMessage != "" {
result += formatMessage(types.Message{Role: userRole, Content: concatenatedMessage})
}

return result, nil
}

func formatMessage(msg types.Message) string {
var (
emoji string
prefix string
)

switch msg.Role {
case systemRole:
emoji = "💻"
prefix = "\n"
case userRole:
emoji = "👤"
prefix = "---\n"
case assistantRole:
emoji = "🤖"
prefix = "\n"
}

return fmt.Sprintf("%s**%s** %s:\n%s\n", prefix, strings.ToUpper(msg.Role), emoji, msg.Content)
}
92 changes: 92 additions & 0 deletions history/history_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package history_test

import (
"errors"
"github.com/golang/mock/gomock"
"github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/types"
. "github.com/onsi/gomega"
"github.com/sclevine/spec"
"github.com/sclevine/spec/report"
"testing"
)

//go:generate mockgen -destination=historymocks_test.go -package=history_test github.com/kardolus/chatgpt-cli/history HistoryStore

var (
mockCtrl *gomock.Controller
mockHistoryStore *MockHistoryStore
subject *history.History
)

func TestUnitHistory(t *testing.T) {
spec.Run(t, "Testing the History", testHistory, spec.Report(report.Terminal{}))
}

func testHistory(t *testing.T, when spec.G, it spec.S) {
it.Before(func() {
RegisterTestingT(t)
mockCtrl = gomock.NewController(t)
mockHistoryStore = NewMockHistoryStore(mockCtrl)
subject = history.NewHistory(mockHistoryStore)
})

it.After(func() {
mockCtrl.Finish()
})

when("Print()", func() {
const threadName = "threadName"

it("throws an error when there is a problem talking to the store", func() {
mockHistoryStore.EXPECT().ReadThread(threadName).Return(nil, errors.New("nope")).Times(1)

_, err := subject.Print(threadName)
Expect(err).To(HaveOccurred())
})

it("concatenates multiple user messages", func() {
messages := []types.Message{
{Role: "user", Content: "first message"},
{Role: "user", Content: " second message"},
{Role: "assistant", Content: "response"},
}

mockHistoryStore.EXPECT().ReadThread(threadName).Return(messages, nil).Times(1)

result, err := subject.Print(threadName)
Expect(err).NotTo(HaveOccurred())
Expect(result).To(ContainSubstring("**USER** 👤:\nfirst message second message\n"))
Expect(result).To(ContainSubstring("**ASSISTANT** 🤖:\nresponse\n"))
})

it("prints all roles correctly", func() {
messages := []types.Message{
{Role: "system", Content: "system message"},
{Role: "user", Content: "user message"},
{Role: "assistant", Content: "assistant message"},
}

mockHistoryStore.EXPECT().ReadThread(threadName).Return(messages, nil).Times(1)

result, err := subject.Print(threadName)
Expect(err).NotTo(HaveOccurred())
Expect(result).To(ContainSubstring("**SYSTEM** 💻:\nsystem message\n"))
Expect(result).To(ContainSubstring("\n---\n**USER** 👤:\nuser message\n"))
Expect(result).To(ContainSubstring("**ASSISTANT** 🤖:\nassistant message\n"))
})

it("handles the final user message concatenation", func() {
messages := []types.Message{
{Role: "user", Content: "first message"},
{Role: "user", Content: " second message"},
}

mockHistoryStore.EXPECT().ReadThread(threadName).Return(messages, nil).Times(1)

result, err := subject.Print(threadName)
Expect(err).NotTo(HaveOccurred())
Expect(result).To(ContainSubstring("**USER** 👤:\nfirst message second message\n"))
})
})
}
Loading

0 comments on commit a91e32c

Please sign in to comment.