diff --git a/connection.go b/connection.go index f6b1d4f..f27d9e2 100644 --- a/connection.go +++ b/connection.go @@ -4,11 +4,13 @@ import ( "encoding/base64" "encoding/json" "errors" + "log" "net" "net/http" "net/url" "os" "strings" + "sync" "time" "github.com/gorilla/websocket" @@ -16,9 +18,15 @@ import ( // Clients include the necessary info to connect to the server and the underlying socket type Client struct { - Remote *url.URL - Ws *websocket.Conn - Auth []OptAuth + Remote *url.URL + Ws *websocket.Conn + Auth []OptAuth + Host string + sendCh chan *Request + requests map[string]*Request + quit chan bool + closeCh chan bool + lock sync.Mutex } func NewClient(urlStr string, options ...OptAuth) (*Client, error) { @@ -26,20 +34,133 @@ func NewClient(urlStr string, options ...OptAuth) (*Client, error) { if err != nil { return nil, err } - dialer := websocket.Dialer{} - ws, _, err := dialer.Dial(urlStr, http.Header{}) - if err != nil { - return nil, err - } - return &Client{Remote: r, Ws: ws, Auth: options}, nil + client := Client{Remote: r, Auth: options, Host: urlStr} + client.quit = make(chan bool, 1) + client.requests = make(map[string]*Request) + client.sendCh = make(chan *Request, 10) + + go client.loop() + return &client, nil } // Client executes the provided request func (c *Client) ExecQuery(query string) ([]byte, error) { req := Query(query) - return c.Exec(req) + //return c.Exec(req) + responseCh, err := c.queueRequest(req) + if err != nil { + return nil, err + } + + response := <-responseCh + if response.Err != nil { + return nil, response.Err + } + + return response.Result.Data, nil + +} +func (c *Client) queueRequest(req *Request) (<-chan *Response, error) { + requestMessage, err := GraphSONSerializer(req) + if err != nil { + return nil, err + } + req.Msg = requestMessage + req.responseCh = make(chan *Response, 1) + req.inBatchMode = false + req.dataItems = make([]json.RawMessage, 0) + select { + case <-c.closeCh: + return nil, ErrConnectionClosed + default: + } + c.sendCh <- req + return req.responseCh, nil } +func (c *Client) loop() { + for { + if err := c.createConnection(); err != nil { + return + } + c.closeCh = make(chan bool) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + err := c.sendLoop() + if err != nil { + log.Println(err) + } + c.closeConnection() + wg.Done() + }() + + wg.Add(1) + go func() { + err := c.recvLoop() + log.Println(err) + if err == nil { + panic("recvloop not get nil err") + } + close(c.closeCh) + wg.Done() + }() + + wg.Wait() + + select { + case <-c.quit: + c.flushRequest() + return + default: + } + + // server is not close,should flush request + c.flushRequest() + } +} +func (c *Client) flushRequest() { + c.lock.Lock() + defer c.lock.Unlock() + + for requestId, request := range c.requests { + response := Response{Err: ErrClosing} + request.responseCh <- &response + delete(c.requests, requestId) + } + + c.requests = make(map[string]*Request) +} +func (c *Client) sendLoop() error { + for { + select { + case request := <-c.sendCh: + err := c.Ws.WriteMessage(websocket.BinaryMessage, request.Msg) + // if fail, direct return + // send responseCh error + if err != nil { + response := Response{Err: ErrConnectionClosed} + responseCh := request.responseCh + responseCh <- &response + return err + } + // if success, put request in requests map + if request.Op != "authentication" { + c.lock.Lock() + c.requests[request.RequestId] = request + c.lock.Unlock() + } + + case <-c.closeCh: + return nil + case <-c.quit: + return nil + } + } +} + +/** func (c *Client) Exec(req *Request) ([]byte, error) { requestMessage, err := GraphSONSerializer(req) if err != nil { @@ -53,14 +174,31 @@ func (c *Client) Exec(req *Request) ([]byte, error) { } return c.ReadResponse() } +**/ -func (c *Client) ReadResponse() (data []byte, err error) { - // Data buffer - var message []byte - var dataItems []json.RawMessage - inBatchMode := false +func (c *Client) createConnection() error { + dialer := websocket.Dialer{} + ws, _, err := dialer.Dial(c.Host, http.Header{}) + if err != nil { + return err + } + c.Ws = ws + return nil +} +func (c *Client) closeConnection() { + if c.Ws != nil { + c.Ws.Close() + } +} +func (c *Client) Close() { + close(c.quit) +} +func (c *Client) recvLoop() (err error) { // Receive data for { + // Data buffer + var message []byte + if _, message, err = c.Ws.ReadMessage(); err != nil { return } @@ -71,28 +209,61 @@ func (c *Client) ReadResponse() (data []byte, err error) { var items []json.RawMessage switch res.Status.Code { case StatusNoContent: - return + res.Result.Data = make([]byte, 0) + c.lock.Lock() + if request, ok := c.requests[res.RequestId]; ok { + delete(c.requests, res.RequestId) + request.responseCh <- res + } + c.lock.Unlock() case StatusAuthenticate: - return c.Authenticate(res.RequestId) + if err = c.Authenticate(res.RequestId); err != nil { + return + } case StatusPartialContent: - inBatchMode = true if err = json.Unmarshal(res.Result.Data, &items); err != nil { + c.lock.Lock() + if request, ok := c.requests[res.RequestId]; ok { + delete(c.requests, res.RequestId) + res.Err = err + request.responseCh <- res + } + c.lock.Unlock() return } - dataItems = append(dataItems, items...) + + c.lock.Lock() + if request, ok := c.requests[res.RequestId]; ok { + request.inBatchMode = true + request.dataItems = append(request.dataItems, items...) + } + c.lock.Unlock() case StatusSuccess: - if inBatchMode { + c.lock.Lock() + request, ok := c.requests[res.RequestId] + // not find request + if !ok { + c.lock.Unlock() + continue + } + delete(c.requests, res.RequestId) + c.lock.Unlock() + + if request.inBatchMode { if err = json.Unmarshal(res.Result.Data, &items); err != nil { + res.Err = err + request.responseCh <- res return } - dataItems = append(dataItems, items...) - data, err = json.Marshal(dataItems) + request.dataItems = append(request.dataItems, items...) + + res.Result.Data, _ = json.Marshal(request.dataItems) + request.responseCh <- res } else { - data = res.Result.Data + request.responseCh <- res } - return default: if msg, exists := ErrorMsg[res.Status.Code]; exists { @@ -103,7 +274,6 @@ func (c *Client) ReadResponse() (data []byte, err error) { return } } - return } // AuthInfo includes all info related with SASL authentication with the Gremlin server @@ -156,10 +326,10 @@ func OptAuthUserPass(user, pass string) OptAuth { } // Authenticates the connection -func (c *Client) Authenticate(requestId string) ([]byte, error) { +func (c *Client) Authenticate(requestId string) error { auth, err := NewAuthInfo(c.Auth...) if err != nil { - return nil, err + return err } var sasl []byte sasl = append(sasl, 0) @@ -173,8 +343,19 @@ func (c *Client) Authenticate(requestId string) ([]byte, error) { Processor: "trasversal", Op: "authentication", Args: args, + // responseCh: make(chan *Response, nil), + } + return c.queueAuthRequest(authReq) +} +func (c *Client) queueAuthRequest(req *Request) error { + requestMessage, err := GraphSONSerializer(req) + if err != nil { + return err } - return c.Exec(authReq) + req.Msg = requestMessage + c.sendCh <- req + + return nil } var servers []*url.URL diff --git a/request.go b/request.go index 3e4c9fe..67134f8 100644 --- a/request.go +++ b/request.go @@ -3,14 +3,19 @@ package gremlin import ( "encoding/json" _ "fmt" - "github.com/satori/go.uuid" + + uuid "github.com/satori/go.uuid" ) type Request struct { - RequestId string `json:"requestId"` - Op string `json:"op"` - Processor string `json:"processor"` - Args *RequestArgs `json:"args"` + RequestId string `json:"requestId"` + Op string `json:"op"` + Processor string `json:"processor"` + Args *RequestArgs `json:"args"` + Msg []byte + responseCh chan *Response + inBatchMode bool + dataItems []json.RawMessage } type RequestArgs struct { @@ -28,7 +33,7 @@ type RequestArgs struct { // Formats the requests in the appropriate way type FormattedReq struct { Op string `json:"op"` - RequestId interface{} `json:"requestId"` + RequestId string `json:"requestId"` Args *RequestArgs `json:"args"` Processor string `json:"processor"` } @@ -48,8 +53,8 @@ func GraphSONSerializer(req *Request) ([]byte, error) { } func NewFormattedReq(req *Request) FormattedReq { - rId := map[string]string{"@type": "g:UUID", "@value": req.RequestId} - sr := FormattedReq{RequestId: rId, Processor: req.Processor, Op: req.Op, Args: req.Args} + // rId := map[string]string{"@type": "g:UUID", "@value": req.RequestId} + sr := FormattedReq{RequestId: req.RequestId, Processor: req.Processor, Op: req.Op, Args: req.Args} return sr } diff --git a/response-codes.go b/response-codes.go index 8f01c76..cfb8f2c 100644 --- a/response-codes.go +++ b/response-codes.go @@ -1,5 +1,7 @@ package gremlin +import "errors" + const ( StatusSuccess = 200 StatusNoContent = 204 @@ -24,3 +26,7 @@ var ErrorMsg = map[int]string{ StatusServerTimeout: "Server Timeout", StatusServerSerializationError: "Server Serialization Error", } +var ( + ErrConnectionClosed = errors.New("gremlin connection closed") + ErrClosing = errors.New("gremlin is closeing") +) diff --git a/response.go b/response.go index c980bf9..faccc63 100644 --- a/response.go +++ b/response.go @@ -9,6 +9,7 @@ type Response struct { RequestId string `json:"requestId"` Status *ResponseStatus `json:"status"` Result *ResponseResult `json:"result"` + Err error } type ResponseStatus struct {