Skip to content

Commit

Permalink
fixes tmc#1033
Browse files Browse the repository at this point in the history
  • Loading branch information
lonelycode committed Sep 22, 2024
1 parent 6a05e1d commit d00050c
Showing 1 changed file with 98 additions and 10 deletions.
108 changes: 98 additions & 10 deletions llms/anthropic/internal/anthropicclient/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,44 @@ func (tc TextContent) GetType() string {
return tc.Type
}

type PartialJSONContent struct {
Type string `json:"type"`
PartialJSON string `json:"partial_json"`
}

func (tc PartialJSONContent) GetType() string {
return tc.Type
}

type ToolUseContent struct {
Type string `json:"type"`
ID string `json:"id"`
Name string `json:"name"`
Input map[string]interface{} `json:"input"`
Type string `json:"type"`
ID string `json:"id"`
Name string `json:"name"`
Input map[string]interface{} `json:"input"`
rawStreamInput string
}

func (tuc *ToolUseContent) AppendStreamChunk(chunk string) {
tuc.rawStreamInput += chunk
}

func (tuc *ToolUseContent) GetStreamInput() string {
return tuc.rawStreamInput
}

func (tuc *ToolUseContent) DecodeStream() error {
if tuc.rawStreamInput == "" {
return nil
}

m := make(map[string]interface{})
err := json.Unmarshal([]byte(tuc.rawStreamInput), &m)
if err != nil {
return err
}

tuc.Input = m
return nil
}

func (tuc ToolUseContent) GetType() string {
Expand Down Expand Up @@ -261,7 +294,20 @@ func processStreamEvent(ctx context.Context, event map[string]interface{}, paylo
case "content_block_delta":
return handleContentBlockDeltaEvent(ctx, event, response, payload)
case "content_block_stop":
// Nothing to do here
for _, content := range response.Content {
if content == nil {
continue
}
tuc, ok := content.(*ToolUseContent)
if !ok {
continue
}

err := tuc.DecodeStream()
if err != nil {
return response, fmt.Errorf("error decoding stream tool data: %w", err)
}
}
case "message_delta":
return handleMessageDeltaEvent(event, response)
case "message_stop":
Expand Down Expand Up @@ -307,15 +353,38 @@ func handleContentBlockStartEvent(event map[string]interface{}, response Message
index := int(indexValue)

var eventType string
if cb, ok := event["content_block"].(map[string]any); ok {
cb, ok := event["content_block"].(map[string]any)
if ok {
typ, _ := cb["type"].(string)
eventType = typ
}

if len(response.Content) <= index {
response.Content = append(response.Content, &TextContent{
Type: eventType,
})
switch eventType {
case "text":
response.Content = append(response.Content, &TextContent{
Type: eventType,
})
case "tool_use":
toolID, ok := cb["id"].(string)
if !ok {
return response, fmt.Errorf("missing tool id field in content block [start]")
}

toolName, ok := cb["name"].(string)
if !ok {
return response, fmt.Errorf("missing name field in content block [start]")
}

response.Content = append(response.Content, &ToolUseContent{
Type: eventType,
Input: make(map[string]interface{}),
ID: toolID,
Name: toolName,
})
default:
return response, fmt.Errorf("unknown content block type: %s", eventType)
}
}
return response, nil
}
Expand Down Expand Up @@ -351,7 +420,26 @@ func handleContentBlockDeltaEvent(ctx context.Context, event map[string]interfac
textContent.Text += text
}

if payload.StreamingFunc != nil {
streamOutput := true
if deltaType == "input_json_delta" {
streamOutput = false
partial, ok := delta["partial_json"].(string)
if !ok {
return response, fmt.Errorf("partial_json field missing")
}
if len(response.Content) <= index {
return response, ErrContentIndexOutOfRange
}
tuc, ok := response.Content[index].(*ToolUseContent)
if !ok {
asJson, _ := json.MarshalIndent(response, "", " ")
return response, fmt.Errorf("failed to cast index %v to ToolUseContent: \n%s", index, string(asJson))
}

tuc.AppendStreamChunk(partial)
}

if payload.StreamingFunc != nil && streamOutput {
text, ok := delta["text"].(string)
if !ok {
return response, ErrInvalidDeltaTextField
Expand Down

0 comments on commit d00050c

Please sign in to comment.