diff --git a/app/pkg/agent/agent_test.go b/app/pkg/agent/agent_test.go index 37101bc..628c0de 100644 --- a/app/pkg/agent/agent_test.go +++ b/app/pkg/agent/agent_test.go @@ -1,12 +1,15 @@ package agent import ( + "connectrpc.com/connect" "context" "crypto/tls" "fmt" + "github.com/jlewi/foyle/app/pkg/runme/ulid" "io" "net" "net/http" + "net/url" "os" "os/signal" "testing" @@ -161,6 +164,34 @@ func Test_Generate(t *testing.T) { } } +func Test_StreamingClient(t *testing.T) { + // Test the streaming client + // Useful when testing with a service version + if os.Getenv("GITHUB_ACTIONS") != "" { + t.Skipf("Test is skipped in GitHub actions") + } + + // Setup logs + c := zap.NewDevelopmentConfig() + newLog, err := c.Build() + if err != nil { + t.Fatalf("Error creating logger; %v", err) + } + zap.ReplaceGlobals(newLog) + log := zapr.NewLogger(newLog) + // This is code to help us test streaming with the connect protocol + + addr := "https://foyle.gateway.unified-0s.internal.api.openai.org/api" + //addr := "https://foyle.unified-0s.internal.api.openai.org/api" + //addr := "http://127.0.0.1:8877/api" + //addr := "http://127.0.0.1:9977/api" + + log.Info("Server started") + if err := runClient(addr); err != nil { + t.Fatalf("Error running client for addres %v; %v", addr, err) + } +} + func Test_Streaming(t *testing.T) { if os.Getenv("GITHUB_ACTIONS") != "" { t.Skipf("Test is skipped in GitHub actions") @@ -177,7 +208,7 @@ func Test_Streaming(t *testing.T) { // This is code to help us test streaming with the connect protocol a := &Agent{} - addr := "localhost:8088" + addr := "http://localhost:8088" go func() { if err := setupAndRunServer(addr, a); err != nil { log.Error(err, "Error running server") @@ -209,38 +240,99 @@ func waitForServer(addr string) error { } return errors.Errorf("Server didn't start in time") } -func runClient(addr string) { + +func runClient(baseURL string) error { log := zapr.NewLogger(zap.L()) - baseURL := fmt.Sprintf("http://%s", addr) - client := v1alpha1connect.NewAIServiceClient( - &http.Client{ - Transport: &http2.Transport{ - AllowHTTP: true, - DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { - // Use the standard Dial function to create a plain TCP connection - return net.Dial(network, addr) + u, err := url.Parse(baseURL) + if err != nil { + log.Error(err, "Failed to parse URL") + panic(err) + } + + var client v1alpha1connect.AIServiceClient + + if u.Scheme == "https" { + // Configure the TLS settings + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, // Set to true only for testing; otherwise validate the server's certificate + } + + client = v1alpha1connect.NewAIServiceClient( + &http.Client{ + Transport: &http2.Transport{ + TLSClientConfig: tlsConfig, + DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) { + // Create a secure connection with TLS + return tls.Dial(network, addr, config) + }, }, }, - }, - baseURL, - ) + baseURL, + ) + } else { + client = v1alpha1connect.NewAIServiceClient( + &http.Client{ + Transport: &http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + // Use the standard Dial function to create a plain TCP connection + return net.Dial(network, u.Host) + }, + }, + }, + baseURL, + ) + } + + // Make sure we can send a status request + statusReq := &v1alpha1.StatusRequest{} + statusResp, err := client.Status(context.Background(), connect.NewRequest(statusReq)) + if err != nil { + log.Error(err, "Failed to send status request") + return errors.Wrapf(err, "Failed to send status request") + } + + log.Info("Status response", "response", statusResp) ctx := context.Background() stream := client.StreamGenerate(ctx) // Send requests requests := []string{"Hello", "How are you?", "Goodbye"} - for _, prompt := range requests { - - req := &v1alpha1.StreamGenerateRequest{ - Request: &v1alpha1.StreamGenerateRequest_Update{ - Update: &v1alpha1.UpdateContext{ - Cell: &parserv1.Cell{ - Kind: parserv1.CellKind_CELL_KIND_MARKUP, - Value: prompt, + + contextId := ulid.GenerateID() + for i, prompt := range requests { + + var req *v1alpha1.StreamGenerateRequest + if i == 0 { + req = &v1alpha1.StreamGenerateRequest{ + ContextId: contextId, + Request: &v1alpha1.StreamGenerateRequest_FullContext{ + FullContext: &v1alpha1.FullContext{ + Notebook: &parserv1.Notebook{ + Cells: []*parserv1.Cell{ + { + Kind: parserv1.CellKind_CELL_KIND_MARKUP, + Value: prompt, + }, + }, + }, + NotebookUri: "/path/to/notebook", }, }, - }, + } + } else { + req = &v1alpha1.StreamGenerateRequest{ + ContextId: contextId, + Request: &v1alpha1.StreamGenerateRequest_Update{ + Update: &v1alpha1.UpdateContext{ + Cell: &parserv1.Cell{ + Kind: parserv1.CellKind_CELL_KIND_MARKUP, + Value: prompt, + }, + }, + }, + } } err := stream.Send(req) @@ -264,9 +356,11 @@ func runClient(addr string) { } if err != nil { log.Error(err, "Failed to receive response") + return errors.Wrapf(err, "Failed to receive response") } log.Info("Received response", "response", response) } + return nil } func setupAndRunServer(addr string, a *Agent) error {