Skip to content

Commit

Permalink
Add a http.Roundtripper based on HTTPOverRPC (#302)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-elinardi authored Sep 1, 2023
1 parent 524bd41 commit 09d6ddb
Show file tree
Hide file tree
Showing 2 changed files with 337 additions and 0 deletions.
184 changes: 184 additions & 0 deletions services/httpoverrpc/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package client
import (
"context"
"flag"
"fmt"
"io"
"log"
"net"
Expand Down Expand Up @@ -175,3 +176,186 @@ func TestGet(t *testing.T) {
t.Errorf("got %q, want %q", got, want)
}
}

func TestHTTPTransporter(t *testing.T) {
ctx := context.Background()

// Set up web server
m := http.NewServeMux()
m.HandleFunc("/helloworld", func(httpResp http.ResponseWriter, httpReq *http.Request) {
_, _ = httpResp.Write([]byte("hello world"))
})
l, err := net.Listen("tcp4", "localhost:0")
if err != nil {
t.Fatal(err)
}
go func() { _ = http.Serve(l, m) }()

// Dial out to sansshell server set up in TestMain
conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { conn.Close() })

// setup http transporter
transporter := NewHTTPTransporter(conn)

httpClient := http.Client{
Transport: transporter,
}

addr := l.Addr().String()
resp, err := httpClient.Get(fmt.Sprintf("http://%s/helloworld", addr))
if err != nil {
t.Fatal(err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
want := "hello world"
if string(body) != want {
t.Errorf("got %q, want %q", body, want)
}
}

func TestHTTPTransporterBody(t *testing.T) {
ctx := context.Background()

// Set up web server
m := http.NewServeMux()
m.HandleFunc("/returnbody", func(httpResp http.ResponseWriter, httpReq *http.Request) {
body := []byte{}
if httpReq.Body != nil {
var err error
body, err = io.ReadAll(httpReq.Body)
if err != nil {
_, _ = httpResp.Write([]byte(err.Error()))
}
}
_, _ = httpResp.Write(body)
})
l, err := net.Listen("tcp4", "localhost:0")
if err != nil {
t.Fatal(err)
}
go func() { _ = http.Serve(l, m) }()

// Dial out to sansshell server set up in TestMain
conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { conn.Close() })

// setup http transporter
transporter := NewHTTPTransporter(conn)

httpClient := http.Client{
Transport: transporter,
}

addr := l.Addr().String()
reqBody := "hello sansshell"
resp, err := httpClient.Post(fmt.Sprintf("http://%s/returnbody", addr), "", strings.NewReader(reqBody))
if err != nil {
t.Fatal(err)
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
want := reqBody // should receive the sent request body
if string(respBody) != want {
t.Errorf("got %q, want %q", respBody, want)
}
}

func TestHTTPTransporterMissingScheme(t *testing.T) {
ctx := context.Background()

// Dial out to sansshell server set up in TestMain
conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { conn.Close() })

// setup http transporter
transporter := NewHTTPTransporter(conn)

httpClient := http.Client{
Transport: transporter,
}

_, errGet := httpClient.Get("localhost:9090")
if !strings.Contains(errGet.Error(), errInvalidURLScheme.Error()) {
t.Fatal("must return error with descriptive message when there's no scheme in the request URL")
}
}

func TestHTTPTransporterMissingHost(t *testing.T) {
ctx := context.Background()

// Dial out to sansshell server set up in TestMain
conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { conn.Close() })

// setup http transporter
transporter := NewHTTPTransporter(conn)

httpClient := http.Client{
Transport: transporter,
}

_, errGet := httpClient.Get("http://:9090")
if !strings.Contains(errGet.Error(), errInvalidURLMissingHost.Error()) {
t.Fatal("must return error with descriptive message when there's no hostname in the request URL")
}
}

func TestGetPort(t *testing.T) {
req, err := http.NewRequest("GET", "http://localhost:9999", nil)
if err != nil {
t.Fatal(err)
}
result, err := getPort(req, "http")
if err != nil {
t.Fatal(err)
}
if result != 9999 {
t.Fatalf("got wrong port: %d. Expected: %d", result, 9999)
}
}

func TestGetPortDefaultHTTP(t *testing.T) {
req, err := http.NewRequest("GET", "http://localhost", nil)
if err != nil {
t.Fatal(err)
}
result, err := getPort(req, "http")
if err != nil {
t.Fatal(err)
}
if result != defaultHTTPPort {
t.Fatalf("got wrong port: %d. Expected: %d", result, defaultHTTPPort)
}
}

func TestGetPortDefaultHTTPS(t *testing.T) {
req, err := http.NewRequest("GET", "https://localhost", nil)
if err != nil {
t.Fatal(err)
}
result, err := getPort(req, "https")
if err != nil {
t.Fatal(err)
}
if result != defaultHTTPSPort {
t.Fatalf("got wrong port: %d. Expected: %d", result, defaultHTTPSPort)
}
}
153 changes: 153 additions & 0 deletions services/httpoverrpc/client/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/* Copyright (c) 2023 Snowflake Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
*/

// Package client provides the client interface for 'httpoverrpc'
package client

import (
"bytes"
"fmt"
"io"
"net/http"
"strconv"

"github.com/Snowflake-Labs/sansshell/proxy/proxy"
pb "github.com/Snowflake-Labs/sansshell/services/httpoverrpc"
)

const (
defaultHTTPPort = 80
defaultHTTPSPort = 443
)

var (
errInvalidURLScheme = fmt.Errorf("invalid URL scheme. Use either 'http' or 'https'")
errInvalidURLMissingHost = fmt.Errorf("no host in the request URL")
)

type HTTPTransporter struct {
conn *proxy.Conn
}

func NewHTTPTransporter(conn *proxy.Conn) *HTTPTransporter {
return &HTTPTransporter{
conn,
}
}

func httpHeaderToPbHeader(h *http.Header) []*pb.Header {
result := []*pb.Header{}
for k, v := range *h {
result = append(result, &pb.Header{
Key: k,
Values: v,
})
}

return result
}

func pbHeaderToHTTPHeader(header []*pb.Header) http.Header {
result := http.Header{}
for _, h := range header {
result[h.Key] = h.Values
}

return result
}

func pbReplytoHTTPResponse(rep *pb.HTTPReply) *http.Response {
reader := bytes.NewReader(rep.Body)
body := io.NopCloser(reader)
header := pbHeaderToHTTPHeader(rep.Headers)
result := &http.Response{
Body: body,
StatusCode: int(rep.StatusCode),
Header: header,
}

return result
}

// getPort retrieves the port number from the request URL.
// If the URL doesn't contain a port number, it returns the
// default port associated with the HTTP protocol.
func getPort(req *http.Request, protocol string) (int32, error) {
var ret int32
if req.URL.Port() != "" {
port, err := strconv.Atoi(req.URL.Port())
if err != nil {
return 0, err
}
ret = int32(port)
} else {
// No port in URL, add default port
if protocol == "http" {
ret = defaultHTTPPort
} else {
ret = defaultHTTPSPort
}
}

return ret, nil
}

func (c *HTTPTransporter) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
return nil, errInvalidURLScheme
}

if req.URL.Hostname() == "" {
return nil, errInvalidURLMissingHost
}

proxy := pb.NewHTTPOverRPCClientProxy(c.conn)
body := []byte{}
if req.Body != nil {
var err error
body, err = io.ReadAll(req.Body)
if err != nil {
return nil, err
}
}
reqPb := &pb.HostHTTPRequest{
Request: &pb.HTTPRequest{
RequestUri: req.URL.Path,
Method: req.Method,
Headers: httpHeaderToPbHeader(&req.Header),
Body: body,
},
Protocol: req.URL.Scheme,
Hostname: req.URL.Hostname(),
}

port, errPort := getPort(req, reqPb.Protocol)
if errPort != nil {
return nil, fmt.Errorf("error getting port: %v", errPort)
}
reqPb.Port = port

respChan, err := proxy.HostOneMany(req.Context(), reqPb)
if err != nil {
return nil, err
}
resp := <-respChan
if resp.Error != nil {
return nil, fmt.Errorf("httpOverRPC failed: %v", resp.Error)
}
result := pbReplytoHTTPResponse(resp.Resp)
return result, nil
}

0 comments on commit 09d6ddb

Please sign in to comment.