diff --git a/services/httpoverrpc/client/client_test.go b/services/httpoverrpc/client/client_test.go index dfc5cef6..9f253610 100644 --- a/services/httpoverrpc/client/client_test.go +++ b/services/httpoverrpc/client/client_test.go @@ -19,6 +19,7 @@ package client import ( "context" "flag" + "fmt" "io" "log" "net" @@ -175,3 +176,186 @@ func TestGet(t *testing.T) { t.Errorf("got %q, want %q", got, want) } } + +func TestHTTPTransporter(t *testing.T) { + ctx := context.Background() + + // Set up web server + m := http.NewServeMux() + m.HandleFunc("/helloworld", func(httpResp http.ResponseWriter, httpReq *http.Request) { + _, _ = httpResp.Write([]byte("hello world")) + }) + l, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + go func() { _ = http.Serve(l, m) }() + + // Dial out to sansshell server set up in TestMain + conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { conn.Close() }) + + // setup http transporter + transporter := NewHTTPTransporter(conn) + + httpClient := http.Client{ + Transport: transporter, + } + + addr := l.Addr().String() + resp, err := httpClient.Get(fmt.Sprintf("http://%s/helloworld", addr)) + if err != nil { + t.Fatal(err) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + want := "hello world" + if string(body) != want { + t.Errorf("got %q, want %q", body, want) + } +} + +func TestHTTPTransporterBody(t *testing.T) { + ctx := context.Background() + + // Set up web server + m := http.NewServeMux() + m.HandleFunc("/returnbody", func(httpResp http.ResponseWriter, httpReq *http.Request) { + body := []byte{} + if httpReq.Body != nil { + var err error + body, err = io.ReadAll(httpReq.Body) + if err != nil { + _, _ = httpResp.Write([]byte(err.Error())) + } + } + _, _ = httpResp.Write(body) + }) + l, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + go func() { _ = http.Serve(l, m) }() + + // Dial out to sansshell server set up in TestMain + conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { conn.Close() }) + + // setup http transporter + transporter := NewHTTPTransporter(conn) + + httpClient := http.Client{ + Transport: transporter, + } + + addr := l.Addr().String() + reqBody := "hello sansshell" + resp, err := httpClient.Post(fmt.Sprintf("http://%s/returnbody", addr), "", strings.NewReader(reqBody)) + if err != nil { + t.Fatal(err) + } + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + want := reqBody // should receive the sent request body + if string(respBody) != want { + t.Errorf("got %q, want %q", respBody, want) + } +} + +func TestHTTPTransporterMissingScheme(t *testing.T) { + ctx := context.Background() + + // Dial out to sansshell server set up in TestMain + conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { conn.Close() }) + + // setup http transporter + transporter := NewHTTPTransporter(conn) + + httpClient := http.Client{ + Transport: transporter, + } + + _, errGet := httpClient.Get("localhost:9090") + if !strings.Contains(errGet.Error(), errInvalidURLScheme.Error()) { + t.Fatal("must return error with descriptive message when there's no scheme in the request URL") + } +} + +func TestHTTPTransporterMissingHost(t *testing.T) { + ctx := context.Background() + + // Dial out to sansshell server set up in TestMain + conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { conn.Close() }) + + // setup http transporter + transporter := NewHTTPTransporter(conn) + + httpClient := http.Client{ + Transport: transporter, + } + + _, errGet := httpClient.Get("http://:9090") + if !strings.Contains(errGet.Error(), errInvalidURLMissingHost.Error()) { + t.Fatal("must return error with descriptive message when there's no hostname in the request URL") + } +} + +func TestGetPort(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:9999", nil) + if err != nil { + t.Fatal(err) + } + result, err := getPort(req, "http") + if err != nil { + t.Fatal(err) + } + if result != 9999 { + t.Fatalf("got wrong port: %d. Expected: %d", result, 9999) + } +} + +func TestGetPortDefaultHTTP(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost", nil) + if err != nil { + t.Fatal(err) + } + result, err := getPort(req, "http") + if err != nil { + t.Fatal(err) + } + if result != defaultHTTPPort { + t.Fatalf("got wrong port: %d. Expected: %d", result, defaultHTTPPort) + } +} + +func TestGetPortDefaultHTTPS(t *testing.T) { + req, err := http.NewRequest("GET", "https://localhost", nil) + if err != nil { + t.Fatal(err) + } + result, err := getPort(req, "https") + if err != nil { + t.Fatal(err) + } + if result != defaultHTTPSPort { + t.Fatalf("got wrong port: %d. Expected: %d", result, defaultHTTPSPort) + } +} diff --git a/services/httpoverrpc/client/utils.go b/services/httpoverrpc/client/utils.go new file mode 100644 index 00000000..a1fef6ab --- /dev/null +++ b/services/httpoverrpc/client/utils.go @@ -0,0 +1,153 @@ +/* Copyright (c) 2023 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +// Package client provides the client interface for 'httpoverrpc' +package client + +import ( + "bytes" + "fmt" + "io" + "net/http" + "strconv" + + "github.com/Snowflake-Labs/sansshell/proxy/proxy" + pb "github.com/Snowflake-Labs/sansshell/services/httpoverrpc" +) + +const ( + defaultHTTPPort = 80 + defaultHTTPSPort = 443 +) + +var ( + errInvalidURLScheme = fmt.Errorf("invalid URL scheme. Use either 'http' or 'https'") + errInvalidURLMissingHost = fmt.Errorf("no host in the request URL") +) + +type HTTPTransporter struct { + conn *proxy.Conn +} + +func NewHTTPTransporter(conn *proxy.Conn) *HTTPTransporter { + return &HTTPTransporter{ + conn, + } +} + +func httpHeaderToPbHeader(h *http.Header) []*pb.Header { + result := []*pb.Header{} + for k, v := range *h { + result = append(result, &pb.Header{ + Key: k, + Values: v, + }) + } + + return result +} + +func pbHeaderToHTTPHeader(header []*pb.Header) http.Header { + result := http.Header{} + for _, h := range header { + result[h.Key] = h.Values + } + + return result +} + +func pbReplytoHTTPResponse(rep *pb.HTTPReply) *http.Response { + reader := bytes.NewReader(rep.Body) + body := io.NopCloser(reader) + header := pbHeaderToHTTPHeader(rep.Headers) + result := &http.Response{ + Body: body, + StatusCode: int(rep.StatusCode), + Header: header, + } + + return result +} + +// getPort retrieves the port number from the request URL. +// If the URL doesn't contain a port number, it returns the +// default port associated with the HTTP protocol. +func getPort(req *http.Request, protocol string) (int32, error) { + var ret int32 + if req.URL.Port() != "" { + port, err := strconv.Atoi(req.URL.Port()) + if err != nil { + return 0, err + } + ret = int32(port) + } else { + // No port in URL, add default port + if protocol == "http" { + ret = defaultHTTPPort + } else { + ret = defaultHTTPSPort + } + } + + return ret, nil +} + +func (c *HTTPTransporter) RoundTrip(req *http.Request) (*http.Response, error) { + if req.URL.Scheme != "http" && req.URL.Scheme != "https" { + return nil, errInvalidURLScheme + } + + if req.URL.Hostname() == "" { + return nil, errInvalidURLMissingHost + } + + proxy := pb.NewHTTPOverRPCClientProxy(c.conn) + body := []byte{} + if req.Body != nil { + var err error + body, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + } + reqPb := &pb.HostHTTPRequest{ + Request: &pb.HTTPRequest{ + RequestUri: req.URL.Path, + Method: req.Method, + Headers: httpHeaderToPbHeader(&req.Header), + Body: body, + }, + Protocol: req.URL.Scheme, + Hostname: req.URL.Hostname(), + } + + port, errPort := getPort(req, reqPb.Protocol) + if errPort != nil { + return nil, fmt.Errorf("error getting port: %v", errPort) + } + reqPb.Port = port + + respChan, err := proxy.HostOneMany(req.Context(), reqPb) + if err != nil { + return nil, err + } + resp := <-respChan + if resp.Error != nil { + return nil, fmt.Errorf("httpOverRPC failed: %v", resp.Error) + } + result := pbReplytoHTTPResponse(resp.Resp) + return result, nil +}