Skip to content

Commit

Permalink
Fix agent_test.go
Browse files Browse the repository at this point in the history
* Make it useful for testing streaming connections.
  • Loading branch information
jlewi committed Nov 22, 2024
1 parent ddf0b24 commit f562f35
Showing 1 changed file with 117 additions and 23 deletions.
140 changes: 117 additions & 23 deletions app/pkg/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ import (
"io"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"testing"
"time"

"connectrpc.com/connect"
"github.com/jlewi/foyle/app/pkg/runme/ulid"

"github.com/google/go-cmp/cmp/cmpopts"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -161,6 +165,31 @@ func Test_Generate(t *testing.T) {
}
}

func Test_StreamingClient(t *testing.T) {
// Test the streaming client
// Useful when testing with a service version
if os.Getenv("GITHUB_ACTIONS") != "" {
t.Skipf("Test is skipped in GitHub actions")
}

// Setup logs
c := zap.NewDevelopmentConfig()
newLog, err := c.Build()
if err != nil {
t.Fatalf("Error creating logger; %v", err)
}
zap.ReplaceGlobals(newLog)
log := zapr.NewLogger(newLog)
// This is code to help us test streaming with the connect protocol
addr := "http://127.0.0.1:8877/api"
//addr := "http://127.0.0.1:9977/api"

log.Info("Server started")
if err := runClient(addr); err != nil {
t.Fatalf("Error running client for addres %v; %v", addr, err)
}
}

func Test_Streaming(t *testing.T) {
if os.Getenv("GITHUB_ACTIONS") != "" {
t.Skipf("Test is skipped in GitHub actions")
Expand All @@ -177,7 +206,7 @@ func Test_Streaming(t *testing.T) {
// This is code to help us test streaming with the connect protocol
a := &Agent{}

addr := "localhost:8088"
addr := "http://localhost:8088"
go func() {
if err := setupAndRunServer(addr, a); err != nil {
log.Error(err, "Error running server")
Expand All @@ -190,7 +219,9 @@ func Test_Streaming(t *testing.T) {
t.Fatalf("Error waiting for server; %v", err)
}
log.Info("Server started")
runClient(addr)
if err := runClient(addr); err != nil {
t.Fatalf("Error running client for addres %v; %v", addr, err)
}
}

func waitForServer(addr string) error {
Expand All @@ -209,38 +240,99 @@ func waitForServer(addr string) error {
}
return errors.Errorf("Server didn't start in time")
}
func runClient(addr string) {

func runClient(baseURL string) error {
log := zapr.NewLogger(zap.L())
baseURL := fmt.Sprintf("http://%s", addr)
client := v1alpha1connect.NewAIServiceClient(
&http.Client{
Transport: &http2.Transport{
AllowHTTP: true,
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
// Use the standard Dial function to create a plain TCP connection
return net.Dial(network, addr)
u, err := url.Parse(baseURL)
if err != nil {
log.Error(err, "Failed to parse URL")
panic(err)
}

var client v1alpha1connect.AIServiceClient

if u.Scheme == "https" {
// Configure the TLS settings
tlsConfig := &tls.Config{
InsecureSkipVerify: true, // Set to true only for testing; otherwise validate the server's certificate
}

client = v1alpha1connect.NewAIServiceClient(
&http.Client{
Transport: &http2.Transport{
TLSClientConfig: tlsConfig,
DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) {
// Create a secure connection with TLS
return tls.Dial(network, addr, config)
},
},
},
},
baseURL,
)
baseURL,
)
} else {
client = v1alpha1connect.NewAIServiceClient(
&http.Client{
Transport: &http2.Transport{
AllowHTTP: true,
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
// Use the standard Dial function to create a plain TCP connection
return net.Dial(network, u.Host)
},
},
},
baseURL,
)
}

// Make sure we can send a status request
statusReq := &v1alpha1.StatusRequest{}
statusResp, err := client.Status(context.Background(), connect.NewRequest(statusReq))
if err != nil {
log.Error(err, "Failed to send status request")
return errors.Wrapf(err, "Failed to send status request")
}

log.Info("Status response", "response", statusResp)

ctx := context.Background()
stream := client.StreamGenerate(ctx)

// Send requests
requests := []string{"Hello", "How are you?", "Goodbye"}
for _, prompt := range requests {

req := &v1alpha1.StreamGenerateRequest{
Request: &v1alpha1.StreamGenerateRequest_Update{
Update: &v1alpha1.UpdateContext{
Cell: &parserv1.Cell{
Kind: parserv1.CellKind_CELL_KIND_MARKUP,
Value: prompt,

contextId := ulid.GenerateID()
for i, prompt := range requests {

var req *v1alpha1.StreamGenerateRequest
if i == 0 {
req = &v1alpha1.StreamGenerateRequest{
ContextId: contextId,
Request: &v1alpha1.StreamGenerateRequest_FullContext{
FullContext: &v1alpha1.FullContext{
Notebook: &parserv1.Notebook{
Cells: []*parserv1.Cell{
{
Kind: parserv1.CellKind_CELL_KIND_MARKUP,
Value: prompt,
},
},
},
NotebookUri: "/path/to/notebook",
},
},
},
}
} else {
req = &v1alpha1.StreamGenerateRequest{
ContextId: contextId,
Request: &v1alpha1.StreamGenerateRequest_Update{
Update: &v1alpha1.UpdateContext{
Cell: &parserv1.Cell{
Kind: parserv1.CellKind_CELL_KIND_MARKUP,
Value: prompt,
},
},
},
}
}
err := stream.Send(req)

Expand All @@ -264,9 +356,11 @@ func runClient(addr string) {
}
if err != nil {
log.Error(err, "Failed to receive response")
return errors.Wrapf(err, "Failed to receive response")
}
log.Info("Received response", "response", response)
}
return nil
}

func setupAndRunServer(addr string, a *Agent) error {
Expand Down

0 comments on commit f562f35

Please sign in to comment.