diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d50b0bf..ab54358d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,30 @@ # Changelog +## v0.6.0 - 1 Feb 2015 + +There are some major changes to the driver with this release that are not related to the RethinkDB v1.16 release. Please have a read through them: +- Improvements to result decoding by caching reflection calls. +- Finished implementing the `Marshaler`/`Unmarshaler` interfaces +- Connection pool overhauled. There were a couple of issues with connections in the previous releases, so this release replaces the `fatih/pool` package with a connection pool based on the `database/sql` connection pool. +- Removed the prefetching mechanism, as the connection+cursor logic was becoming quite complex and causing bugs. Hopefully this will be added back in the near future, but for now I am focusing my efforts on ensuring the driver is as stable as possible (#130, #137) +- Due to the above change the API for connecting has changed slightly (the API is now closer to the `database/sql` API). `ConnectOpts` changes: + - `MaxActive` renamed to `MaxOpen` + - `IdleTimeout` renamed to `Timeout` +- `Cursor`s are now only closed automatically when calling either `All` or `One` +- `Exec` now takes `ExecOpts` instead of `RunOpts`. The only difference is that `Exec` has the `NoReply` field + +With that out of the way, here are the v1.16 changes: + +- Added `Range`, which generates all numbers in a given range +- Added an optional squash argument to the changes command, which lets the server combine multiple changes to the same document (defaults to true) +- Added new admin functions (`Config`, `Rebalance`, `Reconfigure`, `Status`, `Wait`) +- Added support for `SUCCESS_ATOM_FEED` +- Added `MinIndex` + `MaxIndex` functions +- Added `ToJSON` function +- Updated `WriteResponse` type + +This release contains a lot of changes, and although I have tested them, sometimes things fall through the gaps.
If you discover any bugs, please let me know and I will try to fix them as soon as possible. + ## Hotfix - 14 Dec 2014 - Fixed empty slices being returned as `[]T(nil)` not `[]T{}` #138 diff --git a/README.md b/README.md index 4e39ecec..9f6ad4d8 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,17 @@ -GoRethink - RethinkDB Driver for Go [![wercker status](https://app.wercker.com/status/e315e764041af8e80f0c68280d4b4de2/s/master "wercker status")](https://app.wercker.com/project/bykey/e315e764041af8e80f0c68280d4b4de2) [![GoDoc](https://godoc.org/github.com/dancannon/gorethink?status.png)](https://godoc.org/github.com/dancannon/gorethink) -===================== +# GoRethink - RethinkDB Driver for Go -[Go](http://golang.org/) driver for [RethinkDB](http://www.rethinkdb.com/) made by [Daniel Cannon](http://github.com/dancannon) and based off of Christopher Hesse's [RethinkGo](https://github.com/christopherhesse/rethinkgo) driver. +[![GitHub tag](https://img.shields.io/github/tag/dancannon/gorethink.svg?style=flat)]() +[![GoDoc](https://godoc.org/github.com/dancannon/gorethink?status.png)](https://godoc.org/github.com/dancannon/gorethink) +[![wercker status](https://app.wercker.com/status/e315e764041af8e80f0c68280d4b4de2/s/master "wercker status")](https://app.wercker.com/project/bykey/e315e764041af8e80f0c68280d4b4de2) +[Go](http://golang.org/) driver for [RethinkDB](http://www.rethinkdb.com/) -Current version: v0.5.0 (RethinkDB v1.15.1) -**Version 0.3 introduced some API changes, for more information check the [change log](CHANGELOG.md)** +Current version: v0.6.0 (RethinkDB v1.16.0) + +**Version 0.6 introduced some small API changes and some significant internal changes; for more information, check the [change log](CHANGELOG.md), and please be aware that the driver is not yet stable** + +[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/dancannon/gorethink?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) ## Installation @@ -41,7 +46,9 @@ See the
[documentation](http://godoc.org/github.com/dancannon/gorethink#Connect) ### Connection Pool -The driver uses a connection pool at all times, however by default there is only a single connection available. In order to turn this into a proper connection pool, we need to pass the `maxIdle`, `maxActive` and/or `idleTimeout` parameters to Connect(): +The driver uses a connection pool at all times; by default it creates and frees connections automatically. It's safe for concurrent use by multiple goroutines. + +To configure the connection pool, `MaxIdle`, `MaxOpen` and `IdleTimeout` can be specified when connecting. If you wish to change the value of `MaxIdle` or `MaxOpen` at runtime, the functions `SetMaxIdleConns` and `SetMaxOpenConns` can be used. ```go import ( @@ -54,16 +61,18 @@ session, err := r.Connect(r.ConnectOpts{ Address: "localhost:28015", Database: "test", MaxIdle: 10, - IdleTimeout: time.Second * 10, + MaxOpen: 10, }) - if err != nil { log.Fatalln(err.Error()) } + +session.SetMaxOpenConns(5) ``` A pre-configured [Pool](http://godoc.org/github.com/dancannon/gorethink#Pool) instance can also be passed to Connect(). + ## Query Functions This library is based on the official drivers so the code on the [API](http://www.rethinkdb.com/api/) page should require very few changes to work. @@ -112,7 +121,7 @@ Different result types are returned depending on what function is used to execut - `Run` returns a cursor which can be used to view all rows returned. - `RunWrite` returns a WriteResponse and should be used for queries such as Insert,Update,etc... -- `Exec` sends a query to the server with the noreply flag set and returns immediately +- `Exec` sends a query to the server and closes the connection immediately after reading the response from the database. If you do not wish to wait for the response, you can set the `NoReply` flag.
Example: diff --git a/connection.go b/connection.go index bc97c24b..991599a6 100644 --- a/connection.go +++ b/connection.go @@ -7,287 +7,308 @@ import ( "fmt" "io" "net" - "sync" + "sync/atomic" "time" - "gopkg.in/fatih/pool.v2" - p "github.com/dancannon/gorethink/ql2" ) type Response struct { Token int64 Type p.Response_ResponseType `json:"t"` - Responses []interface{} `json:"r"` + Responses []json.RawMessage `json:"r"` Backtrace []interface{} `json:"b"` Profile interface{} `json:"p"` } -type Conn interface { - SendQuery(s *Session, q *p.Query, t Term, opts map[string]interface{}, async bool) (*Cursor, error) - ReadResponse(s *Session, token int64) (*Response, error) - Close() error -} - -// connection is a connection to a rethinkdb database +// Connection is a connection to a rethinkdb database. Connection is not thread +// safe and should only be accessed by a single goroutine type Connection struct { - // embed the net.Conn type, so that we can effectively define new methods on - // it (interfaces do not allow that) - net.Conn - s *Session - - sync.Mutex + conn net.Conn + opts *ConnectOpts + token int64 + cursors map[int64]*Cursor + bad bool } // Dial closes the previous connection and attempts to connect again.
-func Dial(s *Session) pool.Factory { - return func() (net.Conn, error) { - conn, err := net.Dial("tcp", s.address) - if err != nil { - return nil, RqlConnectionError{err.Error()} - } - - // Send the protocol version to the server as a 4-byte little-endian-encoded integer - if err := binary.Write(conn, binary.LittleEndian, p.VersionDummy_V0_3); err != nil { - return nil, RqlConnectionError{err.Error()} - } +func NewConnection(opts *ConnectOpts) (*Connection, error) { + conn, err := net.Dial("tcp", opts.Address) + if err != nil { + return nil, RqlConnectionError{err.Error()} + } - // Send the length of the auth key to the server as a 4-byte little-endian-encoded integer - if err := binary.Write(conn, binary.LittleEndian, uint32(len(s.authkey))); err != nil { - return nil, RqlConnectionError{err.Error()} - } + // Send the protocol version to the server as a 4-byte little-endian-encoded integer + if err := binary.Write(conn, binary.LittleEndian, p.VersionDummy_V0_3); err != nil { + return nil, RqlConnectionError{err.Error()} + } - // Send the auth key as an ASCII string - // If there is no auth key, skip this step - if s.authkey != "" { - if _, err := io.WriteString(conn, s.authkey); err != nil { - return nil, RqlConnectionError{err.Error()} - } - } + // Send the length of the auth key to the server as a 4-byte little-endian-encoded integer + if err := binary.Write(conn, binary.LittleEndian, uint32(len(opts.AuthKey))); err != nil { + return nil, RqlConnectionError{err.Error()} + } - // Send the protocol type as a 4-byte little-endian-encoded integer - if err := binary.Write(conn, binary.LittleEndian, p.VersionDummy_JSON); err != nil { + // Send the auth key as an ASCII string + // If there is no auth key, skip this step + if opts.AuthKey != "" { + if _, err := io.WriteString(conn, opts.AuthKey); err != nil { return nil, RqlConnectionError{err.Error()} } + } - // read server response to authorization key (terminated by NUL) - reader := bufio.NewReader(conn) - line, err 
:= reader.ReadBytes('\x00') - if err != nil { - if err == io.EOF { - return nil, fmt.Errorf("Unexpected EOF: %s", string(line)) - } - return nil, RqlDriverError{err.Error()} - } - // convert to string and remove trailing NUL byte - response := string(line[:len(line)-1]) - if response != "SUCCESS" { - // we failed authorization or something else terrible happened - return nil, RqlDriverError{fmt.Sprintf("Server dropped connection with message: \"%s\"", response)} + // Send the protocol type as a 4-byte little-endian-encoded integer + if err := binary.Write(conn, binary.LittleEndian, p.VersionDummy_JSON); err != nil { + return nil, RqlConnectionError{err.Error()} + } + + // read server response to authorization key (terminated by NUL) + reader := bufio.NewReader(conn) + line, err := reader.ReadBytes('\x00') + if err != nil { + if err == io.EOF { + return nil, fmt.Errorf("Unexpected EOF: %s", string(line)) } + return nil, RqlConnectionError{err.Error()} + } + // convert to string and remove trailing NUL byte + response := string(line[:len(line)-1]) + if response != "SUCCESS" { + // we failed authorization or something else terrible happened + return nil, RqlDriverError{fmt.Sprintf("Server dropped connection with message: \"%s\"", response)} + } - return conn, nil + c := &Connection{ + opts: opts, + conn: conn, + cursors: make(map[int64]*Cursor), } -} -func TestOnBorrow(c *Connection, t time.Time) error { - c.SetReadDeadline(t) + c.conn.SetDeadline(time.Time{}) - data := make([]byte, 1) - if _, err := c.Read(data); err != nil { - e, ok := err.(net.Error) - if err != nil && !(ok && e.Timeout()) { - return err - } + return c, nil +} + +// Close closes the underlying net.Conn +func (c *Connection) Close() error { + if c.conn != nil { + c.conn.Close() + c.conn = nil } - c.SetReadDeadline(time.Time{}) + c.cursors = nil + c.opts = nil + return nil } -func (c *Connection) ReadResponse(s *Session, token int64) (*Response, error) { - for { - // Read the 8-byte token of the 
query the response corresponds to. - var responseToken int64 - if err := binary.Read(c, binary.LittleEndian, &responseToken); err != nil { - return nil, RqlConnectionError{err.Error()} - } +func (c *Connection) Query(q Query) (*Response, *Cursor, error) { + if c == nil { + return nil, nil, nil + } + if c.conn == nil { + c.bad = true + return nil, nil, nil + } - // Read the length of the JSON-encoded response as a 4-byte - // little-endian-encoded integer. - var messageLength uint32 - if err := binary.Read(c, binary.LittleEndian, &messageLength); err != nil { - return nil, RqlConnectionError{err.Error()} + // Add token if query is a START/NOREPLY_WAIT + if q.Type == p.Query_START || q.Type == p.Query_NOREPLY_WAIT { + q.Token = c.nextToken() + if c.opts.Database != "" { + q.Opts["db"] = Db(c.opts.Database).build() } + } - // Read the JSON encoding of the Response itself. - b := make([]byte, messageLength) - if _, err := io.ReadFull(c, b); err != nil { - return nil, RqlDriverError{err.Error()} - } + err := c.sendQuery(q) + if err != nil { + return nil, nil, err + } + + if noreply, ok := q.Opts["noreply"]; ok && noreply.(bool) { + return nil, nil, nil + } - // Decode the response - var response = new(Response) - response.Token = responseToken - err := json.Unmarshal(b, response) + var response *Response + for { + response, err = c.readResponse() if err != nil { - return nil, RqlDriverError{err.Error()} + return nil, nil, err } - if responseToken == token { - return response, nil - } else if cursor, ok := s.checkCache(token); ok { - // Handle batch response - s.handleBatchResponse(cursor, response) - } else { - return nil, RqlDriverError{"Unexpected response received"} + if response.Token == q.Token { + // If this was the requested response process and return + return c.processResponse(q, response) + } else if _, ok := c.cursors[response.Token]; ok { + // If the token is in the cursor cache then process the response + c.processResponse(q, response) } } } -func (c 
*Connection) SendQuery(s *Session, q Query, opts map[string]interface{}, async bool) (*Cursor, error) { - var err error - +func (c *Connection) sendQuery(q Query) error { // Build query b, err := json.Marshal(q.build()) if err != nil { - return nil, RqlDriverError{"Error building query"} + return RqlDriverError{"Error building query"} } // Set timeout - if s.timeout == 0 { - c.SetDeadline(time.Time{}) + if c.opts.Timeout == 0 { + c.conn.SetDeadline(time.Time{}) } else { - c.SetDeadline(time.Now().Add(s.timeout)) + c.conn.SetDeadline(time.Now().Add(c.opts.Timeout)) } // Send a unique 8-byte token - if err = binary.Write(c, binary.LittleEndian, q.Token); err != nil { - return nil, RqlConnectionError{err.Error()} + if err = binary.Write(c.conn, binary.LittleEndian, q.Token); err != nil { + c.bad = true + return RqlConnectionError{err.Error()} } // Send the length of the JSON-encoded query as a 4-byte // little-endian-encoded integer. - if err = binary.Write(c, binary.LittleEndian, uint32(len(b))); err != nil { - return nil, RqlConnectionError{err.Error()} + if err = binary.Write(c.conn, binary.LittleEndian, uint32(len(b))); err != nil { + c.bad = true + return RqlConnectionError{err.Error()} } // Send the JSON encoding of the query itself. - if err = binary.Write(c, binary.BigEndian, b); err != nil { + if err = binary.Write(c.conn, binary.BigEndian, b); err != nil { + c.bad = true + return RqlConnectionError{err.Error()} + } + + return nil +} + +// nextToken generates the next query token, used to number requests and match +// responses with requests. +func (c *Connection) nextToken() int64 { + return atomic.AddInt64(&c.token, 1) +} + +func (c *Connection) readResponse() (*Response, error) { + // Read the 8-byte token of the query the response corresponds to.
+ var responseToken int64 + if err := binary.Read(c.conn, binary.LittleEndian, &responseToken); err != nil { + c.bad = true return nil, RqlConnectionError{err.Error()} } - // Return immediately if the noreply option was set - if noreply, ok := opts["noreply"]; (ok && noreply.(bool)) || async { - return nil, nil + // Read the length of the JSON-encoded response as a 4-byte + // little-endian-encoded integer. + var messageLength uint32 + if err := binary.Read(c.conn, binary.LittleEndian, &messageLength); err != nil { + c.bad = true + return nil, RqlConnectionError{err.Error()} } - // Get response - response, err := c.ReadResponse(s, q.Token) - if err != nil { - return nil, err + // Read the JSON encoding of the Response itself. + b := make([]byte, messageLength) + if _, err := io.ReadFull(c.conn, b); err != nil { + c.bad = true + return nil, RqlConnectionError{err.Error()} } - err = checkErrorResponse(response, q.Term) - if err != nil { - return nil, err + // Decode the response + var response = new(Response) + if err := json.Unmarshal(b, response); err != nil { + c.bad = true + return nil, RqlDriverError{err.Error()} } + response.Token = responseToken - // De-construct datum and return a cursor + return response, nil +} + +func (c *Connection) processResponse(q Query, response *Response) (*Response, *Cursor, error) { switch response.Type { - case p.Response_SUCCESS_PARTIAL, p.Response_SUCCESS_SEQUENCE, p.Response_SUCCESS_FEED: - cursor := &Cursor{ - session: s, - conn: c, - query: q, - term: *q.Term, - opts: opts, - profile: response.Profile, - } + case p.Response_CLIENT_ERROR: + return c.processErrorResponse(q, response, RqlClientError{rqlResponseError{response, q.Term}}) + case p.Response_COMPILE_ERROR: + return c.processErrorResponse(q, response, RqlCompileError{rqlResponseError{response, q.Term}}) + case p.Response_RUNTIME_ERROR: + return c.processErrorResponse(q, response, RqlRuntimeError{rqlResponseError{response, q.Term}}) + case p.Response_SUCCESS_ATOM: + 
return c.processAtomResponse(q, response) + case p.Response_SUCCESS_FEED, p.Response_SUCCESS_ATOM_FEED: + return c.processFeedResponse(q, response) + case p.Response_SUCCESS_PARTIAL: + return c.processPartialResponse(q, response) + case p.Response_SUCCESS_SEQUENCE: + return c.processSequenceResponse(q, response) + case p.Response_WAIT_COMPLETE: + return c.processWaitResponse(q, response) + default: + return nil, nil, RqlDriverError{"Unexpected response type"} + } +} - s.setCache(q.Token, cursor) +func (c *Connection) processErrorResponse(q Query, response *Response, err error) (*Response, *Cursor, error) { + cursor := c.cursors[response.Token] - cursor.extend(response) + delete(c.cursors, response.Token) - return cursor, nil - case p.Response_SUCCESS_ATOM: - var value []interface{} - var err error - - if len(response.Responses) < 1 { - value = []interface{}{} - } else { - var v interface{} - - v, err = recursivelyConvertPseudotype(response.Responses[0], opts) - if err != nil { - return nil, err - } - if err != nil { - return nil, RqlDriverError{err.Error()} - } - - if sv, ok := v.([]interface{}); ok { - value = sv - } else if v == nil { - value = []interface{}{nil} - } else { - value = []interface{}{v} - } - } + return response, cursor, err +} - cursor := &Cursor{ - session: s, - conn: c, - query: q, - term: *q.Term, - opts: opts, - profile: response.Profile, - buffer: value, - finished: true, - } +func (c *Connection) processAtomResponse(q Query, response *Response) (*Response, *Cursor, error) { + // Create cursor + cursor := newCursor(c, response.Token, q.Term, q.Opts) + cursor.profile = response.Profile - return cursor, nil - case p.Response_WAIT_COMPLETE: - return nil, nil - default: - return nil, RqlDriverError{fmt.Sprintf("Unexpected response type received: %s", response.Type)} - } + cursor.extend(response) + + return response, cursor, nil } -func (c *Connection) Close() error { - err := c.NoreplyWait() - if err != nil { - return err +func (c *Connection) 
processFeedResponse(q Query, response *Response) (*Response, *Cursor, error) { + var cursor *Cursor + if _, ok := c.cursors[response.Token]; !ok { + // Create a new cursor if needed + cursor = newCursor(c, response.Token, q.Term, q.Opts) + cursor.profile = response.Profile + c.cursors[response.Token] = cursor + } else { + cursor = c.cursors[response.Token] } - return c.Conn.Close() + cursor.extend(response) + + return response, cursor, nil } -// noreplyWaitQuery sends the NOREPLY_WAIT query to the server. -func (c *Connection) NoreplyWait() error { - q := Query{ - Type: p.Query_NOREPLY_WAIT, - Token: c.s.nextToken(), - } +func (c *Connection) processPartialResponse(q Query, response *Response) (*Response, *Cursor, error) { + cursor, ok := c.cursors[response.Token] + if !ok { + // Create a new cursor if needed + cursor = newCursor(c, response.Token, q.Term, q.Opts) + cursor.profile = response.Profile - _, err := c.SendQuery(c.s, q, map[string]interface{}{}, false) - if err != nil { - return err + c.cursors[response.Token] = cursor } - return nil + cursor.extend(response) + + return response, cursor, nil } -func checkErrorResponse(response *Response, t *Term) error { - switch response.Type { - case p.Response_CLIENT_ERROR: - return RqlClientError{rqlResponseError{response, t}} - case p.Response_COMPILE_ERROR: - return RqlCompileError{rqlResponseError{response, t}} - case p.Response_RUNTIME_ERROR: - return RqlRuntimeError{rqlResponseError{response, t}} +func (c *Connection) processSequenceResponse(q Query, response *Response) (*Response, *Cursor, error) { + cursor, ok := c.cursors[response.Token] + if !ok { + // Create a new cursor if needed + cursor = newCursor(c, response.Token, q.Term, q.Opts) + cursor.profile = response.Profile } - return nil + delete(c.cursors, response.Token) + + cursor.extend(response) + + return response, cursor, nil +} + +func (c *Connection) processWaitResponse(q Query, response *Response) (*Response, *Cursor, error) { + delete(c.cursors, 
response.Token) + + return response, nil, nil } diff --git a/cursor.go b/cursor.go index 30f2c722..7a7ffe9c 100644 --- a/cursor.go +++ b/cursor.go @@ -1,75 +1,106 @@ package gorethink import ( + "encoding/json" "errors" "reflect" - "sync" "github.com/dancannon/gorethink/encoding" p "github.com/dancannon/gorethink/ql2" ) -// Cursors are used to represent data returned from the database. +var ( + errCursorClosed = errors.New("connection closed, cannot read cursor") +) + +func newCursor(conn *Connection, token int64, term *Term, opts map[string]interface{}) *Cursor { + cursor := &Cursor{ + conn: conn, + token: token, + term: term, + opts: opts, + } + + return cursor +} + +// Cursor is the result of a query. Its cursor starts before the first row +// of the result set. A Cursor is not thread safe and should only be accessed +// by a single goroutine at any given time. Use Next to advance through the +// rows: // -// The code for this struct is based off of mgo's Iter and the official -// python driver's cursor. +// cursor, err := query.Run(session) +// ... +// defer cursor.Close() +// +// var response interface{} +// for cursor.Next(&response) { +// ... +// } +// err = cursor.Err() // get any error encountered during iteration +// ... type Cursor struct { - mu sync.Mutex - session *Session - conn *Connection - query Query - term Term - opts map[string]interface{} - - err error - outstandingRequests int - closed bool - finished bool - responses []*Response - profile interface{} - buffer []interface{} + pc *poolConn + releaseConn func(error) + + conn *Connection + token int64 + query Query + term *Term + opts map[string]interface{} + + lastErr error + fetching bool + closed bool + finished bool + isAtom bool + buffer queue + responses queue + profile interface{} } // Profile returns the information returned from the query profiler. 
func (c *Cursor) Profile() interface{} { - c.mu.Lock() - defer c.mu.Unlock() - return c.profile } // Err returns nil if no errors happened during iteration, or the actual // error otherwise. func (c *Cursor) Err() error { - c.mu.Lock() - defer c.mu.Unlock() - - return c.err + return c.lastErr } // Close closes the cursor, preventing further enumeration. If the end is // encountered, the cursor is closed automatically. Close is idempotent. func (c *Cursor) Close() error { - c.mu.Lock() + var err error - if !c.closed && !c.finished { - c.mu.Unlock() - err := c.session.stopQuery(c) - c.mu.Lock() + if c.closed { + return nil + } - if err != nil && (c.err == nil || c.err == ErrEmptyResult) { - c.err = err - } - c.closed = true + conn := c.conn + if conn == nil { + return nil + } + if conn.conn == nil { + return nil } - err := c.conn.Close() - if err != nil { - return err + // Stop any unfinished queries + if !c.closed && !c.finished { + q := Query{ + Type: p.Query_STOP, + Token: c.token, + } + + _, _, err = conn.Query(q) } - err = c.err - c.mu.Unlock() + c.releaseConn(err) + + c.closed = true + c.conn = nil return err } @@ -83,91 +114,81 @@ func (c *Cursor) Close() error { // and false at the end of the result set or if an error happened. // When Next returns false, the Err method should be called to verify if // there was an error during iteration. 
-func (c *Cursor) Next(result interface{}) bool { - c.mu.Lock() +func (c *Cursor) Next(dest interface{}) bool { + if c.closed { + return false + } + + hasMore, err := c.loadNext(dest) + if c.handleError(err) != nil { + c.Close() + return false + } + + return hasMore +} - // Load more data if needed - for c.err == nil { +func (c *Cursor) loadNext(dest interface{}) (bool, error) { + for c.lastErr == nil { // Check if response is closed/finished - if len(c.buffer) == 0 && len(c.responses) == 0 && c.closed { - c.err = errors.New("connection closed, cannot read cursor") - c.mu.Unlock() - return false - } - if len(c.buffer) == 0 && len(c.responses) == 0 && c.finished { - c.mu.Unlock() - return false + if c.buffer.Len() == 0 && c.responses.Len() == 0 && c.closed { + + return false, errCursorClosed } - // Start precomputing next batch - if len(c.responses) == 1 && !c.finished { - c.mu.Unlock() - if err := c.session.asyncContinueQuery(c); err != nil { - c.err = err - return false + if c.buffer.Len() == 0 && c.responses.Len() == 0 && !c.finished { + + err := c.fetchMore() + if err != nil { + return false, err } - c.mu.Lock() } - // If the buffer is empty fetch more results - if len(c.buffer) == 0 { - if len(c.responses) == 0 && !c.finished { - c.mu.Unlock() - if err := c.session.continueQuery(c); err != nil { - c.err = err - return false + if c.buffer.Len() == 0 && c.responses.Len() == 0 && c.finished { + + return false, nil + } + + if c.buffer.Len() == 0 && c.responses.Len() > 0 { + if response, ok := c.responses.Pop().(json.RawMessage); ok { + + var value interface{} + err := json.Unmarshal(response, &value) + if err != nil { + return false, err } - c.mu.Lock() - } - // Load the new response into the buffer - if len(c.responses) > 0 { - var err error - c.buffer = c.responses[0].Responses + value, err = recursivelyConvertPseudotype(value, c.opts) if err != nil { - c.err = err - c.mu.Unlock() - return false + return false, err } - c.responses = c.responses[1:] - } - } - // 
If the buffer is no longer empty then move on otherwise - // try again - if len(c.buffer) > 0 { - break + // If response is an ATOM then try and convert to an array + if data, ok := value.([]interface{}); ok && c.isAtom { + for _, v := range data { + c.buffer.Push(v) + } + } else if value == nil { + c.buffer.Push(nil) + } else { + c.buffer.Push(value) + } + } } - } - - if c.err != nil { - c.mu.Unlock() - return false - } - var data interface{} - data, c.buffer = c.buffer[0], c.buffer[1:] + if c.buffer.Len() > 0 { + data := c.buffer.Pop() - data, err := recursivelyConvertPseudotype(data, c.opts) - if err != nil { - c.err = err - c.mu.Unlock() - return false - } + err := encoding.Decode(dest, data) + if err != nil { + return false, err + } - c.mu.Unlock() - err = encoding.Decode(result, data) - if err != nil { - c.mu.Lock() - if c.err == nil { - c.err = err + return true, nil } - c.mu.Unlock() - - return false } - return true + return false, c.lastErr } // All retrieves all documents from the result set into the provided slice @@ -200,7 +221,17 @@ func (c *Cursor) All(result interface{}) error { i++ } resultv.Elem().Set(slicev.Slice(0, i)) - return c.Close() + + if err := c.Err(); err != nil { + c.Close() + return err + } + + if err := c.Close(); err != nil { + return err + } + + return nil } // One retrieves a single document from the result set into the provided @@ -210,52 +241,165 @@ func (c *Cursor) One(result interface{}) error { return ErrEmptyResult } - var err error - ok := c.Next(result) - if !ok { - err = c.Err() - if err == nil { - err = ErrEmptyResult + hasResult := c.Next(result) + + if err := c.Err(); err != nil { + c.Close() + return err + } + + if err := c.Close(); err != nil { + return err + } + + if !hasResult { + return ErrEmptyResult + } + + return nil +} + +// IsNil tests if the current row is nil. 
+func (c *Cursor) IsNil() bool { + if c.buffer.Len() > 0 { + bufferedItem := c.buffer.Peek() + if bufferedItem == nil { + return true + } + + return false + } + + if c.responses.Len() > 0 { + response := c.responses.Peek() + if response == nil { + return true + } + + if response, ok := response.(json.RawMessage); ok { + if string(response) == "null" { + return true + } } + + return false } - if e := c.Close(); e != nil { - err = e + return true +} + +// fetchMore fetches more rows from the database by sending a continue query. +func (c *Cursor) fetchMore() error { + var err error + if !c.fetching { + c.fetching = true + + if c.closed { + return errCursorClosed + } + + q := Query{ + Type: p.Query_CONTINUE, + Token: c.token, + } + + _, _, err = c.conn.Query(q) + c.handleError(err) } return err } -// Tests if the current row is nil. -func (c *Cursor) IsNil() bool { - c.mu.Lock() - defer c.mu.Unlock() +// handleError sets the value of lastErr to err if lastErr is not yet set. +func (c *Cursor) handleError(err error) error { + return c.handleErrorLocked(err) +} + +// handleErrorLocked sets the value of lastErr to err if lastErr is not yet set. +func (c *Cursor) handleErrorLocked(err error) error { + if c.lastErr == nil { + c.lastErr = err + } - return (len(c.responses) == 0 && len(c.buffer) == 0) || (len(c.buffer) == 1 && c.buffer[0] == nil) + return c.lastErr } +// extend adds the result of a continue query to the cursor.
func (c *Cursor) extend(response *Response) { - c.mu.Lock() - c.finished = response.Type != p.Response_SUCCESS_PARTIAL && - response.Type != p.Response_SUCCESS_FEED - c.responses = append(c.responses, response) - - // Prefetch results if needed - if len(c.responses) == 1 && !c.finished { - if err := c.session.asyncContinueQuery(c); err != nil { - c.err = err - return - } + for _, response := range response.Responses { + c.responses.Push(response) } - // Load the new response into the buffer - var err error - c.buffer = c.responses[0].Responses - if err != nil { - c.err = err + c.finished = response.Type != p.Response_SUCCESS_PARTIAL && + response.Type != p.Response_SUCCESS_FEED && + response.Type != p.Response_SUCCESS_ATOM_FEED + c.fetching = false + c.isAtom = response.Type == p.Response_SUCCESS_ATOM +} - return +// Queue structure used for storing responses + +type queue struct { + elems []interface{} + nelems, popi, pushi int +} + +func (q *queue) Len() int { + return q.nelems +} +func (q *queue) Push(elem interface{}) { + if q.nelems == len(q.elems) { + q.expand() + } + q.elems[q.pushi] = elem + q.nelems++ + q.pushi = (q.pushi + 1) % len(q.elems) +} +func (q *queue) Pop() (elem interface{}) { + if q.nelems == 0 { + return nil + } + elem = q.elems[q.popi] + q.elems[q.popi] = nil // Help GC. 
+ q.nelems-- + q.popi = (q.popi + 1) % len(q.elems) + return elem +} +func (q *queue) Peek() (elem interface{}) { + if q.nelems == 0 { + return nil + } + return q.elems[q.popi] +} +func (q *queue) expand() { + curcap := len(q.elems) + var newcap int + if curcap == 0 { + newcap = 8 + } else if curcap < 1024 { + newcap = curcap * 2 + } else { + newcap = curcap + (curcap / 4) + } + elems := make([]interface{}, newcap) + if q.popi == 0 { + copy(elems, q.elems) + q.pushi = curcap + } else { + newpopi := newcap - (curcap - q.popi) + copy(elems, q.elems[:q.popi]) + copy(elems[newpopi:], q.elems[q.popi:]) + q.popi = newpopi + } + for i := range q.elems { + q.elems[i] = nil // Help GC. } - c.responses = c.responses[1:] - c.mu.Unlock() + q.elems = elems } diff --git a/cursor_test.go b/cursor_test.go index c7374e14..efc2891f 100644 --- a/cursor_test.go +++ b/cursor_test.go @@ -188,6 +188,10 @@ func (s *RethinkSuite) TestEmptyResults(c *test.C) { c.Assert(err, test.Equals, ErrEmptyResult) c.Assert(res.IsNil(), test.Equals, true) + res, err = Expr(nil).Run(sess) + c.Assert(err, test.IsNil) + c.Assert(res.IsNil(), test.Equals, true) + res, err = Db("test").Table("test").Get("missing value").Run(sess) c.Assert(err, test.IsNil) c.Assert(res.IsNil(), test.Equals, true) diff --git a/doc.go b/doc.go index 5255510a..7882ea23 100644 --- a/doc.go +++ b/doc.go @@ -1,6 +1,6 @@ // Go driver for RethinkDB // -// Current version: v0.5.0 (RethinkDB v1.15.1) +// Current version: v0.6.0 (RethinkDB v1.16.0) // For more in depth information on how to use RethinkDB check out the API docs // at http://rethinkdb.com/api package gorethink diff --git a/encoding/cache.go b/encoding/cache.go index 0c64bf3a..feb28f2a 100644 --- a/encoding/cache.go +++ b/encoding/cache.go @@ -12,6 +12,7 @@ import ( type field struct { name string nameBytes []byte // []byte(name) + equalFold func(s, t []byte) bool tag bool index []int @@ -22,6 +23,7 @@ type field struct { func fillField(f field) field { f.nameBytes = 
[]byte(f.name) + f.equalFold = foldFunc(f.nameBytes) return f } diff --git a/encoding/decoder.go b/encoding/decoder.go index 16d09895..bd10d710 100644 --- a/encoding/decoder.go +++ b/encoding/decoder.go @@ -1,21 +1,16 @@ -// // This code is based on encoding/json and gorilla/schema - package encoding import ( - - // "errors" "errors" "reflect" "runtime" - - // "runtime" - "strconv" - "strings" + "sync" ) var byteSliceType = reflect.TypeOf([]byte(nil)) +type decoderFunc func(dv reflect.Value, sv reflect.Value) + // Decode decodes map[string]interface{} into a struct. The first parameter // must be a pointer. func Decode(dst interface{}, src interface{}) (err error) { @@ -34,404 +29,83 @@ func Decode(dst interface{}, src interface{}) (err error) { dv := reflect.ValueOf(dst) sv := reflect.ValueOf(src) - if dv.Kind() != reflect.Ptr || dv.IsNil() { - return &InvalidDecodeError{reflect.TypeOf(dst)} - } - - decode(dv, sv) - - return nil -} - -// decode decodes the source value into the destination value -func decode(dv, sv reflect.Value) { - if dv.IsValid() { - val := indirect(dv, false) - val.Set(reflect.Zero(val.Type())) - } - - if dv.IsValid() && sv.IsValid() { - // Ensure that the source value has the correct type of parsing - if sv.Kind() == reflect.Interface { - sv = reflect.ValueOf(sv.Interface()) - } - - switch sv.Kind() { - default: - decodeLiteral(dv, sv) - case reflect.Slice, reflect.Array: - decodeArray(dv, sv) - case reflect.Map: - decodeObject(dv, sv) - case reflect.Struct: - dv = indirect(dv, false) - dv.Set(sv) + if dv.Kind() != reflect.Ptr { + return &DecodeTypeError{ + DestType: dv.Type(), + SrcType: sv.Type(), + Reason: "must be a pointer", } } -} - -// decodeLiteral decodes the source value into the destination value. This function -// is used to decode literal values. 
-func decodeLiteral(dv reflect.Value, sv reflect.Value) { - dv = indirect(dv, true) - - // Special case for if sv is nil: - switch sv.Kind() { - case reflect.Invalid: - dv.Set(reflect.Zero(dv.Type())) - return - } - - // Attempt to convert the value from the source type to the destination type - switch value := sv.Interface().(type) { - case nil: - switch dv.Kind() { - case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice: - dv.Set(reflect.Zero(dv.Type())) - } - case bool: - switch dv.Kind() { - default: - panic(&DecodeTypeError{"bool", dv.Type()}) - return - case reflect.Bool: - dv.SetBool(value) - case reflect.String: - dv.SetString(strconv.FormatBool(value)) - case reflect.Interface: - if dv.NumMethod() == 0 { - dv.Set(reflect.ValueOf(value)) - } else { - panic(&DecodeTypeError{"bool", dv.Type()}) - return - } - } - - case string: - switch dv.Kind() { - default: - panic(&DecodeTypeError{"string", dv.Type()}) - return - case reflect.String: - dv.SetString(value) - case reflect.Bool: - b, err := strconv.ParseBool(value) - if err != nil { - panic(&DecodeTypeError{"string", dv.Type()}) - return - } - dv.SetBool(b) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - n, err := strconv.ParseInt(value, 10, 64) - if err != nil || dv.OverflowInt(n) { - panic(&DecodeTypeError{"string", dv.Type()}) - return - } - dv.SetInt(n) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - n, err := strconv.ParseUint(value, 10, 64) - if err != nil || dv.OverflowUint(n) { - panic(&DecodeTypeError{"string", dv.Type()}) - return - } - dv.SetUint(n) - case reflect.Float32, reflect.Float64: - n, err := strconv.ParseFloat(value, 64) - if err != nil || dv.OverflowFloat(n) { - panic(&DecodeTypeError{"string", dv.Type()}) - return - } - dv.SetFloat(n) - case reflect.Interface: - if dv.NumMethod() == 0 { - dv.Set(reflect.ValueOf(string(value))) - } else { - panic(&DecodeTypeError{"string", dv.Type()}) - 
return - } - } - - case int, int8, int16, int32, int64: - switch dv.Kind() { - default: - panic(&DecodeTypeError{"int", dv.Type()}) - return - case reflect.Interface: - if dv.NumMethod() != 0 { - panic(&DecodeTypeError{"int", dv.Type()}) - return - } - dv.Set(reflect.ValueOf(value)) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - dv.SetInt(int64(reflect.ValueOf(value).Int())) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - dv.SetUint(uint64(reflect.ValueOf(value).Int())) - case reflect.Float32, reflect.Float64: - dv.SetFloat(float64(reflect.ValueOf(value).Int())) - case reflect.String: - dv.SetString(strconv.FormatInt(int64(reflect.ValueOf(value).Int()), 10)) + dv = dv.Elem() + if !dv.CanAddr() { + return &DecodeTypeError{ + DestType: dv.Type(), + SrcType: sv.Type(), + Reason: "must be addressable", } - case uint, uint8, uint16, uint32, uint64: - switch dv.Kind() { - default: - panic(&DecodeTypeError{"uint", dv.Type()}) - return - case reflect.Interface: - if dv.NumMethod() != 0 { - panic(&DecodeTypeError{"uint", dv.Type()}) - return - } - dv.Set(reflect.ValueOf(value)) - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - dv.SetInt(int64(reflect.ValueOf(value).Uint())) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - dv.SetUint(uint64(reflect.ValueOf(value).Uint())) - case reflect.Float32, reflect.Float64: - dv.SetFloat(float64(reflect.ValueOf(value).Uint())) - case reflect.String: - dv.SetString(strconv.FormatUint(uint64(reflect.ValueOf(value).Uint()), 10)) - } - case float32, float64: - switch dv.Kind() { - default: - panic(&DecodeTypeError{"float", dv.Type()}) - return - case reflect.Interface: - if dv.NumMethod() != 0 { - panic(&DecodeTypeError{"float", dv.Type()}) - return - } - dv.Set(reflect.ValueOf(value)) - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - 
dv.SetInt(int64(reflect.ValueOf(value).Float())) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - dv.SetUint(uint64(reflect.ValueOf(value).Float())) - case reflect.Float32, reflect.Float64: - dv.SetFloat(float64(reflect.ValueOf(value).Float())) - case reflect.String: - dv.SetString(strconv.FormatFloat(float64(reflect.ValueOf(value).Float()), 'g', -1, 64)) - } - default: - panic(&DecodeTypeError{sv.Type().String(), dv.Type()}) - return } - return + decode(dv, sv) + return nil } -// decodeArray decodes the source value into the destination value. This function -// is used when the source value is a slice or array. -func decodeArray(dv reflect.Value, sv reflect.Value) { - dv = indirect(dv, false) - dt := dv.Type() - - // Ensure that the dest is also a slice or array - switch dt.Kind() { - case reflect.Interface: - if dv.NumMethod() == 0 { - // Decoding into nil interface? Switch to non-reflect code. - dv.Set(reflect.ValueOf(decodeArrayInterface(sv))) - - return - } - // Otherwise it's invalid. - fallthrough - default: - panic(&DecodeTypeError{"array", dv.Type()}) - return - case reflect.Array: - case reflect.Slice: - if sv.Type() == byteSliceType { - dv.SetBytes(sv.Bytes()) - return - } - - newdv := reflect.MakeSlice(dv.Type(), dv.Len(), dv.Cap()) - dv.Set(newdv) - - break - } - - // Iterate through the slice/array and decode each element before adding it - // to the dest slice/array - i := 0 - for i < sv.Len() { - if dv.Kind() == reflect.Slice { - // Get element of array, growing if necessary. - if i >= dv.Cap() { - newcap := dv.Cap() + dv.Cap()/2 - if newcap < 4 { - newcap = 4 - } - newdv := reflect.MakeSlice(dv.Type(), dv.Len(), newcap) - reflect.Copy(newdv, dv) - dv.Set(newdv) - } - if i >= dv.Len() { - dv.SetLen(i + 1) - } - } - - if i < dv.Len() { - // Decode into element. - decode(dv.Index(i), sv.Index(i)) - } else { - // Ran out of fixed array: skip. 
- decode(reflect.Value{}, sv.Index(i)) - } - - i++ - } - - // Ensure that the destination is the correct size - if i < dv.Len() { - if dv.Kind() == reflect.Array { - // Array. Zero the rest. - z := reflect.Zero(dv.Type().Elem()) - for ; i < dv.Len(); i++ { - dv.Index(i).Set(z) - } - } else { - dv.SetLen(i) - } - } +// decode decodes the source value into the destination value +func decode(dv, sv reflect.Value) { + valueDecoder(dv, sv)(dv, sv) } -// decodeObject decodes the source value into the destination value. This function -// is used when the source value is a map or struct. -func decodeObject(dv reflect.Value, sv reflect.Value) (err error) { - dv = indirect(dv, false) - dt := dv.Type() - - // Decoding into nil interface? Switch to non-reflect code. - if dv.Kind() == reflect.Interface && dv.NumMethod() == 0 { - dv.Set(reflect.ValueOf(decodeObjectInterface(sv))) - return nil - } - - // Check type of target: struct or map[string]T - switch dv.Kind() { - case reflect.Map: - // map must have string kind - if dt.Key().Kind() != reflect.String { - panic(&DecodeTypeError{"object", dv.Type()}) - break - } - if dv.IsNil() { - dv.Set(reflect.MakeMap(dt)) - } - case reflect.Struct: - default: - panic(&DecodeTypeError{"object", dv.Type()}) - return - } - - var mapElem reflect.Value - - for _, key := range sv.MapKeys() { - var subdv reflect.Value - var subsv reflect.Value = sv.MapIndex(key) - - skey := key.Interface().(string) - - if dv.Kind() == reflect.Map { - elemType := dv.Type().Elem() - if !mapElem.IsValid() { - mapElem = reflect.New(elemType).Elem() - } else { - mapElem.Set(reflect.Zero(elemType)) - } - subdv = mapElem - } else { - var f *field - fields := cachedTypeFields(dv.Type()) - for i := range fields { - ff := &fields[i] - if ff.name == skey { - f = ff - break - } - if f == nil && strings.EqualFold(ff.name, skey) { - f = ff - } - } - if f != nil { - subdv = dv - for _, i := range f.index { - if subdv.Kind() == reflect.Ptr { - if subdv.IsNil() { - 
subdv.Set(reflect.New(subdv.Type().Elem())) - } - subdv = subdv.Elem() - } - subdv = subdv.Field(i) - } - } - } - - decode(subdv, subsv) - - if dv.Kind() == reflect.Map { - kv := reflect.ValueOf(skey) - dv.SetMapIndex(kv, subdv) - } - } - - return nil +type decoderCacheKey struct { + dt, st reflect.Type } -// The following methods are simplified versions of those above designed to use -// less reflection - -// decodeInterface decodes the source value into interface{} -func decodeInterface(sv reflect.Value) interface{} { - // Ensure that the source value has the correct type of parsing - if sv.Kind() == reflect.Interface { - sv = reflect.ValueOf(sv.Interface()) - } - - switch sv.Kind() { - case reflect.Slice, reflect.Array: - return decodeArrayInterface(sv) - case reflect.Map: - return decodeObjectInterface(sv) - default: - return decodeLiteralInterface(sv) - } +var decoderCache struct { + sync.RWMutex + m map[decoderCacheKey]decoderFunc } -// decodeArrayInterface decodes the source value into []interface{} -func decodeArrayInterface(sv reflect.Value) interface{} { - if sv.Type() == byteSliceType { - return sv.Bytes() +func valueDecoder(dv, sv reflect.Value) decoderFunc { + if !sv.IsValid() { + return invalidValueDecoder } - arr := []interface{}{} - for i := 0; i < sv.Len(); i++ { - arr = append(arr, decodeInterface(sv.Index(i))) + if dv.IsValid() { + val := indirect(dv, false) + val.Set(reflect.Zero(val.Type())) } - return arr -} -// decodeObjectInterface decodes the source value into map[string]interface{} -func decodeObjectInterface(sv reflect.Value) interface{} { - m := map[string]interface{}{} - for _, key := range sv.MapKeys() { - m[key.Interface().(string)] = decodeInterface(sv.MapIndex(key)) - } - return m + return typeDecoder(dv.Type(), sv.Type()) } -// decodeLiteralInterface returns the interface of the source value -func decodeLiteralInterface(sv reflect.Value) interface{} { - if !sv.IsValid() { - return nil - } - - return sv.Interface() +func 
typeDecoder(dt, st reflect.Type) decoderFunc { + decoderCache.RLock() + f := decoderCache.m[decoderCacheKey{dt, st}] + decoderCache.RUnlock() + if f != nil { + return f + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + decoderCache.Lock() + var wg sync.WaitGroup + wg.Add(1) + decoderCache.m[decoderCacheKey{dt, st}] = func(dv, sv reflect.Value) { + wg.Wait() + f(dv, sv) + } + decoderCache.Unlock() + + // Compute fields without lock. + // Might duplicate effort but won't hold other computations back. + f = newTypeDecoder(dt, st) + wg.Done() + decoderCache.Lock() + decoderCache.m[decoderCacheKey{dt, st}] = f + decoderCache.Unlock() + return f } // indirect walks down v allocating pointers as needed, diff --git a/encoding/decoder_test.go b/encoding/decoder_test.go index 9bcefada..90f065f5 100644 --- a/encoding/decoder_test.go +++ b/encoding/decoder_test.go @@ -1,6 +1,9 @@ package encoding import ( + "bytes" + "encoding/json" + "fmt" "image" "reflect" "testing" @@ -146,7 +149,7 @@ var decodeTests = []decodeTest{ {in: string("2"), ptr: new(interface{}), out: string("2")}, {in: "a\u1234", ptr: new(string), out: "a\u1234"}, {in: []interface{}{}, ptr: new([]string), out: []string{}}, - {in: map[string]interface{}{"X": []interface{}{1, 2, 3}, "Y": 4}, ptr: new(T), out: T{}, err: &DecodeTypeError{"array", reflect.TypeOf("")}}, + {in: map[string]interface{}{"X": []interface{}{1, 2, 3}, "Y": 4}, ptr: new(T), out: T{}, err: &DecodeTypeError{reflect.TypeOf([0]interface{}{}), reflect.TypeOf(""), ""}}, {in: map[string]interface{}{"x": 1}, ptr: new(tx), out: tx{}}, {in: map[string]interface{}{"F1": float64(1), "F2": 2, "F3": 3}, ptr: new(V), out: V{F1: float64(1), F2: int32(2), F3: string("3")}}, {in: map[string]interface{}{"F1": string("1"), "F2": 2, "F3": 3}, ptr: new(V), out: V{F1: string("1"), 
F2: int32(2), F3: string("3")}}, @@ -208,7 +211,6 @@ var decodeTests = []decodeTest{ Level1b: 9, Level1c: 10, Level1d: 11, - Level1e: 12, }, Loop: Loop{ Loop1: 13, @@ -227,7 +229,6 @@ var decodeTests = []decodeTest{ ptr: new(Ambig), out: Ambig{First: 1}, }, - { in: map[string]interface{}{"X": 1, "Y": 2}, ptr: new(S5), @@ -250,15 +251,14 @@ func TestDecode(t *testing.T) { v := reflect.New(reflect.TypeOf(tt.ptr).Elem()) err := Decode(v.Interface(), tt.in) - if tt.err != nil { - if !reflect.DeepEqual(err, tt.err) { - t.Errorf("#%d: got error %v want %v", i, err, tt.err) - } - + if !jsonEqual(err, tt.err) { + t.Errorf("#%d: got error %v want %v", i, err, tt.err) continue } - if !reflect.DeepEqual(v.Elem().Interface(), tt.out) { + if !jsonEqual(v.Elem().Interface(), tt.out) { + fmt.Printf("%#v\n", v.Elem().Interface()) + fmt.Printf("%#v\n", tt.out) t.Errorf("#%d: mismatch\nhave: %+v\nwant: %+v", i, v.Elem().Interface(), tt.out) continue } @@ -276,8 +276,7 @@ func TestDecode(t *testing.T) { t.Errorf("#%d: error re-decodeing: %v", i, err) continue } - - if !reflect.DeepEqual(v.Elem().Interface(), vv.Elem().Interface()) { + if !jsonEqual(v.Elem().Interface(), vv.Elem().Interface()) { t.Errorf("#%d: mismatch\nhave: %#+v\nwant: %#+v", i, v.Elem().Interface(), vv.Elem().Interface()) continue } @@ -303,34 +302,12 @@ func TestStringKind(t *testing.T) { t.Errorf("Unexpected error decoding: %v", err) } - if !reflect.DeepEqual(m1, m2) { + if !jsonEqual(m1, m2) { t.Error("Items should be equal after encoding and then decoding") } } -var decodeTypeErrorTests = []struct { - dest interface{} - src interface{} -}{ - {new(string), map[interface{}]interface{}{"user": "name"}}, - {new(error), map[interface{}]interface{}{}}, - {new(error), []interface{}{}}, - {new(error), ""}, - {new(error), 123}, - {new(error), true}, -} - -func TestDecodeTypeError(t *testing.T) { - for _, item := range decodeTypeErrorTests { - err := Decode(item.dest, item.src) - if _, ok := err.(*DecodeTypeError); !ok { 
- t.Errorf("expected type error for Decode(%q, type %T): got %T", - item.src, item.dest, err) - } - } -} - // Test handling of unexported fields that should be ignored. type unexportedFields struct { Name string @@ -358,7 +335,7 @@ func TestDecodeUnexported(t *testing.T) { if err != nil { t.Errorf("got error %v, expected nil", err) } - if !reflect.DeepEqual(out, want) { + if !jsonEqual(out, want) { t.Errorf("got %q, want %q", out, want) } } @@ -369,3 +346,16 @@ type Foo struct { type Bar struct { Baz int `gorethink:"baz"` } + +func jsonEqual(a, b interface{}) bool { + ba, err := json.Marshal(a) + if err != nil { + panic(err) + } + bb, err := json.Marshal(b) + if err != nil { + panic(err) + } + + return bytes.Compare(ba, bb) == 0 +} diff --git a/encoding/decoder_types.go b/encoding/decoder_types.go new file mode 100644 index 00000000..42dfde4b --- /dev/null +++ b/encoding/decoder_types.go @@ -0,0 +1,518 @@ +package encoding + +import ( + "bytes" + "fmt" + "reflect" + "strconv" +) + +// newTypeDecoder constructs a decoderFunc for a type.
+func newTypeDecoder(dt, st reflect.Type) decoderFunc { + if dt.Implements(unmarshalerType) { + return unmarshalerDecoder + } + + if st.Kind() == reflect.Interface { + return interfaceAsTypeDecoder + } + + switch dt.Kind() { + case reflect.Bool: + switch st.Kind() { + case reflect.Bool: + return boolAsBoolDecoder + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return intAsBoolDecoder + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return uintAsBoolDecoder + case reflect.Float32, reflect.Float64: + return floatAsBoolDecoder + case reflect.String: + return stringAsBoolDecoder + default: + return decodeTypeError + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch st.Kind() { + case reflect.Bool: + return boolAsIntDecoder + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return intAsIntDecoder + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return uintAsIntDecoder + case reflect.Float32, reflect.Float64: + return floatAsIntDecoder + case reflect.String: + return stringAsIntDecoder + default: + return decodeTypeError + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + switch st.Kind() { + case reflect.Bool: + return boolAsUintDecoder + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return intAsUintDecoder + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return uintAsUintDecoder + case reflect.Float32, reflect.Float64: + return floatAsUintDecoder + case reflect.String: + return stringAsUintDecoder + default: + return decodeTypeError + } + case reflect.Float32, reflect.Float64: + switch st.Kind() { + case reflect.Bool: + return boolAsFloatDecoder + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 
+ return intAsFloatDecoder + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return uintAsFloatDecoder + case reflect.Float32, reflect.Float64: + return floatAsFloatDecoder + case reflect.String: + return stringAsFloatDecoder + default: + return decodeTypeError + } + case reflect.String: + switch st.Kind() { + case reflect.Bool: + return boolAsStringDecoder + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return intAsStringDecoder + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return uintAsStringDecoder + case reflect.Float32, reflect.Float64: + return floatAsStringDecoder + case reflect.String: + return stringAsStringDecoder + default: + return decodeTypeError + } + case reflect.Interface: + if !st.AssignableTo(dt) { + return decodeTypeError + } + + return interfaceDecoder + case reflect.Ptr: + return newPtrDecoder(dt, st) + case reflect.Map: + if st.AssignableTo(dt) { + return interfaceDecoder + } + + switch st.Kind() { + case reflect.Map: + return newMapAsMapDecoder(dt, st) + default: + return decodeTypeError + } + case reflect.Struct: + if st.AssignableTo(dt) { + return interfaceDecoder + } + + switch st.Kind() { + case reflect.Map: + if kind := st.Key().Kind(); kind != reflect.String && kind != reflect.Interface { + return newDecodeTypeError(fmt.Errorf("map needs string keys")) + } + + return newMapAsStructDecoder(dt, st) + default: + return decodeTypeError + } + case reflect.Slice: + if st.AssignableTo(dt) { + return interfaceDecoder + } + + switch st.Kind() { + case reflect.Array, reflect.Slice: + return newSliceDecoder(dt, st) + default: + return decodeTypeError + } + case reflect.Array: + if st.AssignableTo(dt) { + return interfaceDecoder + } + + switch st.Kind() { + case reflect.Array, reflect.Slice: + return newArrayDecoder(dt, st) + default: + return decodeTypeError + } + default: + return unsupportedTypeDecoder + 
} +} + +func invalidValueDecoder(dv, sv reflect.Value) { + dv.Set(reflect.Zero(dv.Type())) +} + +func unsupportedTypeDecoder(dv, sv reflect.Value) { + panic(&UnsupportedTypeError{dv.Type()}) +} + +func decodeTypeError(dv, sv reflect.Value) { + panic(&DecodeTypeError{ + DestType: dv.Type(), + SrcType: sv.Type(), + }) +} + +func newDecodeTypeError(err error) decoderFunc { + return func(dv, sv reflect.Value) { + panic(&DecodeTypeError{ + DestType: dv.Type(), + SrcType: sv.Type(), + Reason: err.Error(), + }) + } +} + +func interfaceDecoder(dv, sv reflect.Value) { + dv.Set(sv) +} + +func interfaceAsTypeDecoder(dv, sv reflect.Value) { + decode(dv, sv.Elem()) +} + +type ptrDecoder struct { + elemDec decoderFunc +} + +func (d *ptrDecoder) decode(dv, sv reflect.Value) { + v := reflect.New(dv.Type().Elem()) + d.elemDec(v, sv) + dv.Set(v) +} + +func newPtrDecoder(dt, st reflect.Type) decoderFunc { + dec := &ptrDecoder{typeDecoder(dt.Elem(), st)} + + return dec.decode +} + +func unmarshalerDecoder(dv, sv reflect.Value) { + if dv.Kind() != reflect.Ptr || dv.IsNil() { + panic(&InvalidUnmarshalError{sv.Type()}) + } + + u := dv.Interface().(Unmarshaler) + err := u.UnmarshalRQL(sv.Interface()) + if err != nil { + panic(&DecodeTypeError{dv.Type(), sv.Type(), err.Error()}) + } +} + +// Boolean decoders + +func boolAsBoolDecoder(dv, sv reflect.Value) { + dv.SetBool(sv.Bool()) +} +func boolAsIntDecoder(dv, sv reflect.Value) { + if sv.Bool() { + dv.SetInt(1) + } else { + dv.SetInt(0) + } +} +func boolAsUintDecoder(dv, sv reflect.Value) { + if sv.Bool() { + dv.SetUint(1) + } else { + dv.SetUint(0) + } +} +func boolAsFloatDecoder(dv, sv reflect.Value) { + if sv.Bool() { + dv.SetFloat(1) + } else { + dv.SetFloat(0) + } +} +func boolAsStringDecoder(dv, sv reflect.Value) { + if sv.Bool() { + dv.SetString("1") + } else { + dv.SetString("0") + } +} + +// Int decoders + +func intAsBoolDecoder(dv, sv reflect.Value) { + dv.SetBool(sv.Int() != 0) +} +func intAsIntDecoder(dv, sv reflect.Value) { + 
dv.SetInt(sv.Int()) +} +func intAsUintDecoder(dv, sv reflect.Value) { + dv.SetUint(uint64(sv.Int())) +} +func intAsFloatDecoder(dv, sv reflect.Value) { + dv.SetFloat(float64(sv.Int())) +} +func intAsStringDecoder(dv, sv reflect.Value) { + dv.SetString(strconv.FormatInt(sv.Int(), 10)) +} + +// Uint decoders + +func uintAsBoolDecoder(dv, sv reflect.Value) { + dv.SetBool(sv.Uint() != 0) +} +func uintAsIntDecoder(dv, sv reflect.Value) { + dv.SetInt(int64(sv.Uint())) +} +func uintAsUintDecoder(dv, sv reflect.Value) { + dv.SetUint(sv.Uint()) +} +func uintAsFloatDecoder(dv, sv reflect.Value) { + dv.SetFloat(float64(sv.Uint())) +} +func uintAsStringDecoder(dv, sv reflect.Value) { + dv.SetString(strconv.FormatUint(sv.Uint(), 10)) +} + +// Float decoders + +func floatAsBoolDecoder(dv, sv reflect.Value) { + dv.SetBool(sv.Float() != 0) +} +func floatAsIntDecoder(dv, sv reflect.Value) { + dv.SetInt(int64(sv.Float())) +} +func floatAsUintDecoder(dv, sv reflect.Value) { + dv.SetUint(uint64(sv.Float())) +} +func floatAsFloatDecoder(dv, sv reflect.Value) { + dv.SetFloat(float64(sv.Float())) +} +func floatAsStringDecoder(dv, sv reflect.Value) { + dv.SetString(strconv.FormatFloat(sv.Float(), 'f', -1, 64)) +} + +// String decoders + +func stringAsBoolDecoder(dv, sv reflect.Value) { + b, err := strconv.ParseBool(sv.String()) + if err == nil { + dv.SetBool(b) + } else if sv.String() == "" { + dv.SetBool(false) + } else { + panic(&DecodeTypeError{dv.Type(), sv.Type(), err.Error()}) + } +} +func stringAsIntDecoder(dv, sv reflect.Value) { + i, err := strconv.ParseInt(sv.String(), 0, dv.Type().Bits()) + if err == nil { + dv.SetInt(i) + } else { + panic(&DecodeTypeError{dv.Type(), sv.Type(), err.Error()}) + } +} +func stringAsUintDecoder(dv, sv reflect.Value) { + i, err := strconv.ParseUint(sv.String(), 0, dv.Type().Bits()) + if err == nil { + dv.SetUint(i) + } else { + panic(&DecodeTypeError{dv.Type(), sv.Type(), err.Error()}) + } +} +func stringAsFloatDecoder(dv, sv reflect.Value) { + f, 
err := strconv.ParseFloat(sv.String(), dv.Type().Bits()) + if err == nil { + dv.SetFloat(f) + } else { + panic(&DecodeTypeError{dv.Type(), sv.Type(), err.Error()}) + } +} +func stringAsStringDecoder(dv, sv reflect.Value) { + dv.SetString(sv.String()) +} + +// Slice/Array decoder + +type sliceDecoder struct { + arrayDec decoderFunc +} + +func (d *sliceDecoder) decode(dv, sv reflect.Value) { + if dv.Kind() == reflect.Slice { + dv.Set(reflect.MakeSlice(dv.Type(), dv.Len(), dv.Cap())) + } + + if !sv.IsNil() { + d.arrayDec(dv, sv) + } +} + +func newSliceDecoder(dt, st reflect.Type) decoderFunc { + // Byte slices get special treatment; arrays don't. + // if t.Elem().Kind() == reflect.Uint8 { + // return decodeByteSlice + // } + dec := &sliceDecoder{newArrayDecoder(dt, st)} + return dec.decode +} + +type arrayDecoder struct { + elemDec decoderFunc +} + +func (d *arrayDecoder) decode(dv, sv reflect.Value) { + // Iterate through the slice/array and decode each element before adding it + // to the dest slice/array + i := 0 + for i < sv.Len() { + if dv.Kind() == reflect.Slice { + // Get element of array, growing if necessary. + if i >= dv.Cap() { + newcap := dv.Cap() + dv.Cap()/2 + if newcap < 4 { + newcap = 4 + } + newdv := reflect.MakeSlice(dv.Type(), dv.Len(), newcap) + reflect.Copy(newdv, dv) + dv.Set(newdv) + } + if i >= dv.Len() { + dv.SetLen(i + 1) + } + } + + if i < dv.Len() { + // Decode into element. + d.elemDec(dv.Index(i), sv.Index(i)) + } + + i++ + } + + // Ensure that the destination is the correct size + if i < dv.Len() { + if dv.Kind() == reflect.Array { + // Array. Zero the rest. 
+ z := reflect.Zero(dv.Type().Elem()) + for ; i < dv.Len(); i++ { + dv.Index(i).Set(z) + } + } else { + dv.SetLen(i) + } + } +} + +func newArrayDecoder(dt, st reflect.Type) decoderFunc { + dec := &arrayDecoder{typeDecoder(dt.Elem(), st.Elem())} + return dec.decode +} + +// Map decoder + +type mapAsMapDecoder struct { + keyDec, elemDec decoderFunc +} + +func (d *mapAsMapDecoder) decode(dv, sv reflect.Value) { + dt := dv.Type() + dv.Set(reflect.MakeMap(reflect.MapOf(dt.Key(), dt.Elem()))) + + var mapKey reflect.Value + var mapElem reflect.Value + + keyType := dv.Type().Key() + elemType := dv.Type().Elem() + + for _, sElemKey := range sv.MapKeys() { + var dElemKey reflect.Value + var dElemVal reflect.Value + + if !mapKey.IsValid() { + mapKey = reflect.New(keyType).Elem() + } else { + mapKey.Set(reflect.Zero(keyType)) + } + dElemKey = mapKey + + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.Set(reflect.Zero(elemType)) + } + dElemVal = mapElem + + d.keyDec(dElemKey, sElemKey) + d.elemDec(dElemVal, sv.MapIndex(sElemKey)) + + dv.SetMapIndex(dElemKey, dElemVal) + } +} + +func newMapAsMapDecoder(dt, st reflect.Type) decoderFunc { + d := &mapAsMapDecoder{typeDecoder(dt.Key(), st.Key()), typeDecoder(dt.Elem(), st.Elem())} + return d.decode +} + +type mapAsStructDecoder struct { + fields []field + fieldDecs []decoderFunc +} + +func (d *mapAsStructDecoder) decode(dv, sv reflect.Value) { + for _, kv := range sv.MapKeys() { + var f *field + var fieldDec decoderFunc + key := []byte(kv.String()) + for i := range d.fields { + ff := &d.fields[i] + ffd := d.fieldDecs[i] + if bytes.Equal(ff.nameBytes, key) { + f = ff + fieldDec = ffd + break + } + if f == nil && ff.equalFold(ff.nameBytes, key) { + f = ff + fieldDec = ffd + } + } + + if f == nil { + continue + } + + dElemVal := fieldByIndex(dv, f.index) + sElemVal := sv.MapIndex(kv) + + if !sElemVal.IsValid() || !dElemVal.CanSet() { + continue + } + + fieldDec(dElemVal, sElemVal) + } +} + +func 
newMapAsStructDecoder(dt, st reflect.Type) decoderFunc { + fields := cachedTypeFields(dt) + se := &mapAsStructDecoder{ + fields: fields, + fieldDecs: make([]decoderFunc, len(fields)), + } + for i, f := range fields { + se.fieldDecs[i] = typeDecoder(typeByIndex(dt, f.index), st.Elem()) + } + return se.decode +} diff --git a/encoding/encoder.go b/encoding/encoder.go index 2e3b4ad7..3b0d3508 100644 --- a/encoding/encoder.go +++ b/encoding/encoder.go @@ -83,7 +83,7 @@ func typeEncoder(t reflect.Type) encoderFunc { // IgnoreType causes the encoder to ignore a type when encoding func IgnoreType(t reflect.Type) { - encoderCache.RLock() + encoderCache.Lock() encoderCache.m[t] = doNothingEncoder - encoderCache.RUnlock() + encoderCache.Unlock() } diff --git a/encoding/encoder_types.go b/encoding/encoder_types.go index e11ffde1..77ef4af7 100644 --- a/encoding/encoder_types.go +++ b/encoding/encoder_types.go @@ -1,22 +1,9 @@ package encoding import ( - "encoding" "encoding/base64" - "math" "reflect" - "strconv" "time" - - "github.com/dancannon/gorethink/types" -) - -var ( - marshalerType = reflect.TypeOf(new(Marshaler)).Elem() - textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() - - timeType = reflect.TypeOf(new(time.Time)).Elem() - geometryType = reflect.TypeOf(new(types.Geometry)).Elem() ) // newTypeEncoder constructs an encoderFunc for a type. 
@@ -30,12 +17,11 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) } } + // Check for pseudo-types first switch t { case timeType: return timePseudoTypeEncoder - case geometryType: - return geometryPseudoTypeEncoder } switch t.Kind() { @@ -45,10 +31,8 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { return intEncoder case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: return uintEncoder - case reflect.Float32: - return float32Encoder - case reflect.Float64: - return float64Encoder + case reflect.Float32, reflect.Float64: + return floatEncoder case reflect.String: return stringEncoder case reflect.Interface: @@ -103,33 +87,6 @@ func addrMarshalerEncoder(v reflect.Value) interface{} { return ev } -func textMarshalerEncoder(v reflect.Value) interface{} { - if v.Kind() == reflect.Ptr && v.IsNil() { - return "" - } - m := v.Interface().(encoding.TextMarshaler) - b, err := m.MarshalText() - if err != nil { - panic(&MarshalerError{v.Type(), err}) - } - - return b -} - -func addrTextMarshalerEncoder(v reflect.Value) interface{} { - va := v.Addr() - if va.IsNil() { - return "" - } - m := va.Interface().(encoding.TextMarshaler) - b, err := m.MarshalText() - if err != nil { - panic(&MarshalerError{v.Type(), err}) - } - - return b -} - func boolEncoder(v reflect.Value) interface{} { if v.Bool() { return true @@ -146,21 +103,10 @@ func uintEncoder(v reflect.Value) interface{} { return v.Uint() } -type floatEncoder int // number of bits - -func (bits floatEncoder) encode(v reflect.Value) interface{} { - f := v.Float() - if math.IsInf(f, 0) || math.IsNaN(f) { - panic(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, int(bits))}) - } - return f +func floatEncoder(v reflect.Value) interface{} { + return v.Float() } -var ( - float32Encoder = (floatEncoder(32)).encode - float64Encoder = (floatEncoder(64)).encode -) -
func stringEncoder(v reflect.Value) interface{} { return v.String() } @@ -323,27 +269,6 @@ func timePseudoTypeEncoder(v reflect.Value) interface{} { } } -// Encode a time.Time value to the TIME RQL type -func geometryPseudoTypeEncoder(v reflect.Value) interface{} { - g := v.Interface().(types.Geometry) - - var coords interface{} - switch g.Type { - case "Point": - coords = g.Point.Marshal() - case "LineString": - coords = g.Line.Marshal() - case "Polygon": - coords = g.Lines.Marshal() - } - - return map[string]interface{}{ - "$reql_type$": "GEOMETRY", - "type": g.Type, - "coordinates": coords, - } -} - // Encode a byte slice to the BINARY RQL type func encodeByteSlice(v reflect.Value) interface{} { var b []byte diff --git a/encoding/encoding.go b/encoding/encoding.go index c61bf690..caa8fdef 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -1,6 +1,18 @@ package encoding -import "reflect" +import ( + "reflect" + "time" +) + +var ( + // type constants + stringType = reflect.TypeOf("") + timeType = reflect.TypeOf(new(time.Time)).Elem() + + marshalerType = reflect.TypeOf(new(Marshaler)).Elem() + unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem() +) // Marshaler is the interface implemented by objects that // can marshal themselves into a valid RQL psuedo-type. @@ -16,4 +28,5 @@ type Unmarshaler interface { func init() { encoderCache.m = make(map[reflect.Type]encoderFunc) + decoderCache.m = make(map[decoderCacheKey]decoderFunc) } diff --git a/encoding/errors.go b/encoding/errors.go index 7fd78f4f..8b9ac2c5 100644 --- a/encoding/errors.go +++ b/encoding/errors.go @@ -3,34 +3,47 @@ package encoding import ( "fmt" "reflect" - "strconv" "strings" ) -// An InvalidEncodeError describes an invalid argument passed to Encode. -// (The argument to Encode must be a non-nil pointer.) 
-type InvalidEncodeError struct {
+type MarshalerError struct {
+	Type reflect.Type
+	Err  error
+}
+
+func (e *MarshalerError) Error() string {
+	return "gorethink: error calling MarshalRQL for type " + e.Type.String() + ": " + e.Err.Error()
+}
+
+type InvalidUnmarshalError struct {
 	Type reflect.Type
 }
 
-func (e *InvalidEncodeError) Error() string {
+func (e *InvalidUnmarshalError) Error() string {
 	if e.Type == nil {
-		return "gorethink: Encode(nil)"
+		return "gorethink: UnmarshalRQL(nil)"
 	}
 
 	if e.Type.Kind() != reflect.Ptr {
-		return "gorethink: Encode(non-pointer " + e.Type.String() + ")"
+		return "gorethink: UnmarshalRQL(non-pointer " + e.Type.String() + ")"
 	}
-	return "gorethink: Encode(nil " + e.Type.String() + ")"
+	return "gorethink: UnmarshalRQL(nil " + e.Type.String() + ")"
 }
 
-type MarshalerError struct {
-	Type reflect.Type
-	Err  error
+// A DecodeTypeError describes a value that was
+// not appropriate for a value of a specific Go type.
+type DecodeTypeError struct {
+	DestType, SrcType reflect.Type
+	Reason            string
 }
 
-func (e *MarshalerError) Error() string {
-	return "gorethink: error calling MarshalRQL for type " + e.Type.String() + ": " + e.Err.Error()
+func (e *DecodeTypeError) Error() string {
+	if e.Reason != "" {
+		return "gorethink: could not decode type " + e.SrcType.String() + " into Go value of type " + e.DestType.String() + ": " + e.Reason
+	}
+	return "gorethink: could not decode type " + e.SrcType.String() + " into Go value of type " + e.DestType.String()
 }
 
 // An UnsupportedTypeError is returned by Marshal when attempting
@@ -43,6 +56,16 @@ func (e *UnsupportedTypeError) Error() string {
 	return "gorethink: unsupported type: " + e.Type.String()
 }
 
+// An UnexpectedTypeError is returned by Marshal when attempting
+// to encode an unexpected value type.
+type UnexpectedTypeError struct { + DestType, SrcType reflect.Type +} + +func (e *UnexpectedTypeError) Error() string { + return "gorethink: expected type: " + e.DestType.String() + ", got " + e.SrcType.String() +} + type UnsupportedValueError struct { Value reflect.Value Str string @@ -52,47 +75,6 @@ func (e *UnsupportedValueError) Error() string { return "gorethink: unsupported value: " + e.Str } -// An DecodeTypeError describes a value that was -// not appropriate for a value of a specific Go type. -type DecodeTypeError struct { - Value string // description of value - "bool", "array", "number -5" - Type reflect.Type // type of Go value it could not be assigned to -} - -func (e *DecodeTypeError) Error() string { - return "gorethink: cannot decode " + e.Value + " into Go value of type " + e.Type.String() -} - -// An DecodeFieldError describes a object key that -// led to an unexported (and therefore unwritable) struct field. -// (No longer used; kept for compatibility.) -type DecodeFieldError struct { - Key string - Type reflect.Type - Field reflect.StructField -} - -func (e *DecodeFieldError) Error() string { - return "gorethink: cannot decode object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() -} - -// An InvalidDecodeError describes an invalid argument passed to Decode. -// (The argument to Decode must be a non-nil pointer.) -type InvalidDecodeError struct { - Type reflect.Type -} - -func (e *InvalidDecodeError) Error() string { - if e.Type == nil { - return "gorethink: Decode(nil)" - } - - if e.Type.Kind() != reflect.Ptr { - return "gorethink: Decode(non-pointer " + e.Type.String() + ")" - } - return "gorethink: Decode(nil " + e.Type.String() + ")" -} - // Error implements the error interface and can represents multiple // errors that occur in the course of a single decode. 
type Error struct { diff --git a/encoding/fold.go b/encoding/fold.go new file mode 100644 index 00000000..21c9e68e --- /dev/null +++ b/encoding/fold.go @@ -0,0 +1,139 @@ +package encoding + +import ( + "bytes" + "unicode/utf8" +) + +const ( + caseMask = ^byte(0x20) // Mask to ignore case in ASCII. + kelvin = '\u212a' + smallLongEss = '\u017f' +) + +// foldFunc returns one of four different case folding equivalence +// functions, from most general (and slow) to fastest: +// +// 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8 +// 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S') +// 3) asciiEqualFold, no special, but includes non-letters (including _) +// 4) simpleLetterEqualFold, no specials, no non-letters. +// +// The letters S and K are special because they map to 3 runes, not just 2: +// * S maps to s and to U+017F 'ſ' Latin small letter long s +// * k maps to K and to U+212A 'K' Kelvin sign +// See http://play.golang.org/p/tTxjOc0OGo +// +// The returned function is specialized for matching against s and +// should only be given s. It's not curried for performance reasons. +func foldFunc(s []byte) func(s, t []byte) bool { + nonLetter := false + special := false // special letter + for _, b := range s { + if b >= utf8.RuneSelf { + return bytes.EqualFold + } + upper := b & caseMask + if upper < 'A' || upper > 'Z' { + nonLetter = true + } else if upper == 'K' || upper == 'S' { + // See above for why these letters are special. + special = true + } + } + if special { + return equalFoldRight + } + if nonLetter { + return asciiEqualFold + } + return simpleLetterEqualFold +} + +// equalFoldRight is a specialization of bytes.EqualFold when s is +// known to be all ASCII (including punctuation), but contains an 's', +// 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t. +// See comments on foldFunc. 
+func equalFoldRight(s, t []byte) bool { + for _, sb := range s { + if len(t) == 0 { + return false + } + tb := t[0] + if tb < utf8.RuneSelf { + if sb != tb { + sbUpper := sb & caseMask + if 'A' <= sbUpper && sbUpper <= 'Z' { + if sbUpper != tb&caseMask { + return false + } + } else { + return false + } + } + t = t[1:] + continue + } + // sb is ASCII and t is not. t must be either kelvin + // sign or long s; sb must be s, S, k, or K. + tr, size := utf8.DecodeRune(t) + switch sb { + case 's', 'S': + if tr != smallLongEss { + return false + } + case 'k', 'K': + if tr != kelvin { + return false + } + default: + return false + } + t = t[size:] + + } + if len(t) > 0 { + return false + } + return true +} + +// asciiEqualFold is a specialization of bytes.EqualFold for use when +// s is all ASCII (but may contain non-letters) and contains no +// special-folding letters. +// See comments on foldFunc. +func asciiEqualFold(s, t []byte) bool { + if len(s) != len(t) { + return false + } + for i, sb := range s { + tb := t[i] + if sb == tb { + continue + } + if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') { + if sb&caseMask != tb&caseMask { + return false + } + } else { + return false + } + } + return true +} + +// simpleLetterEqualFold is a specialization of bytes.EqualFold for +// use when s is all ASCII letters (no underscores, etc) and also +// doesn't contain 'k', 'K', 's', or 'S'. +// See comments on foldFunc. 
+func simpleLetterEqualFold(s, t []byte) bool { + if len(s) != len(t) { + return false + } + for i, b := range s { + if b&caseMask != t[i]&caseMask { + return false + } + } + return true +} diff --git a/encoding/utils.go b/encoding/utils.go index 43c736f8..efaaedcf 100644 --- a/encoding/utils.go +++ b/encoding/utils.go @@ -2,8 +2,8 @@ package encoding import "reflect" -func getKind(val reflect.Value) reflect.Kind { - kind := val.Kind() +func getTypeKind(t reflect.Type) reflect.Kind { + kind := t.Kind() switch { case kind >= reflect.Int && kind <= reflect.Int64: @@ -39,12 +39,13 @@ func fieldByIndex(v reflect.Value, index []int) reflect.Value { for _, i := range index { if v.Kind() == reflect.Ptr { if v.IsNil() { - return reflect.Value{} + v.Set(reflect.New(v.Type().Elem())) } v = v.Elem() } v = v.Field(i) } + return v } diff --git a/errors.go b/errors.go index c6dab8c4..4ad38c15 100644 --- a/errors.go +++ b/errors.go @@ -2,12 +2,18 @@ package gorethink import ( "bytes" + "encoding/json" "errors" "fmt" p "github.com/dancannon/gorethink/ql2" ) +var ( + ErrNoConnections = errors.New("gorethink: no connections were made when creating the session") + ErrConnectionClosed = errors.New("gorethink: the connection is closed") +) + func printCarrots(t Term, frames []*p.Frame) string { var frame *p.Frame if len(frames) > 1 { @@ -57,7 +63,17 @@ type rqlResponseError struct { } func (e rqlResponseError) Error() string { - return fmt.Sprintf("gorethink: %s in: \n%s", e.response.Responses[0], e.term.String()) + var err = "An error occurred" + if e.response != nil { + json.Unmarshal(e.response.Responses[0], &err) + } + + if e.term == nil { + return fmt.Sprintf("gorethink: %s", err) + } + + return fmt.Sprintf("gorethink: %s in: \n%s", err, e.term.String()) + } func (e rqlResponseError) String() string { diff --git a/example_query_table_test.go b/example_query_table_test.go index 2208323e..e1304711 100644 --- a/example_query_table_test.go +++ b/example_query_table_test.go @@ -24,7 
+24,7 @@ func Example_TableCreate() {
 		log.Fatalf("Error creating table: %s", err)
 	}
 
-	fmt.Printf("%d table created", response.Created)
+	fmt.Printf("%d table created", response.TablesCreated)
 
 	// Output:
 	// 1 table created
diff --git a/gorethink_test.go b/gorethink_test.go
index f0acb5fb..bb4df83f 100644
--- a/gorethink_test.go
+++ b/gorethink_test.go
@@ -45,10 +45,8 @@ var _ = test.Suite(&RethinkSuite{})
 func (s *RethinkSuite) SetUpSuite(c *test.C) {
 	var err error
 	sess, err = Connect(ConnectOpts{
-		Address:   url,
-		MaxIdle:   3,
-		MaxActive: 3,
-		AuthKey:   authKey,
+		Address: url,
+		AuthKey: authKey,
 	})
 	c.Assert(err, test.IsNil)
 }
@@ -230,7 +228,7 @@ func (s *RethinkSuite) BenchmarkNoReplyExpr(c *test.C) {
 	for i := 0; i < c.N; i++ {
 		// Test query
 		query := Expr(true)
-		err := query.Exec(sess, RunOpts{NoReply: true})
+		err := query.Exec(sess, ExecOpts{NoReply: true})
 		c.Assert(err, test.IsNil)
 	}
 }
diff --git a/pool.go b/pool.go
new file mode 100644
index 00000000..15bb116c
--- /dev/null
+++ b/pool.go
@@ -0,0 +1,532 @@
+package gorethink
+
+import (
+	"errors"
+	"fmt"
+	"runtime"
+	"sync"
+)
+
+const defaultMaxIdleConns = 1
+
+// maxBadConnRetries is the maximum number of retries if a connection
+// returns ErrBadConn to signal a broken connection.
+const maxBadConnRetries = 10
+
+var (
+	connectionRequestQueueSize = 1000000
+
+	errPoolClosed   = errors.New("gorethink: pool is closed")
+	errConnClosed   = errors.New("gorethink: conn is closed")
+	errConnBusy     = errors.New("gorethink: conn is busy")
+	errConnInactive = errors.New("gorethink: conn was never active")
+)
+
+// depSet is a finalCloser's outstanding dependencies
+type depSet map[interface{}]bool // set of true bools
+
+// The finalCloser interface is used by (*Pool).addDep and related
+// dependency reference counting.
+type finalCloser interface {
+	// finalClose is called when the reference count of an object
+	// goes to zero. (*Pool).mu is not held while calling it.
+ finalClose() error +} + +type Pool struct { + opts *ConnectOpts + + mu sync.Mutex // protects following fields + err error // the last error that occurred + freeConn []*poolConn + connRequests []chan connRequest + numOpen int + pendingOpens int + // Used to signal the need for new connections + // a goroutine running connectionOpener() reads on this chan and + // maybeOpenNewConnections sends on the chan (one send per needed connection) + // It is closed during p.Close(). The close tells the connectionOpener + // goroutine to exit. + openerCh chan struct{} + closed bool + dep map[finalCloser]depSet + lastPut map[*poolConn]string // stacktrace of last conn's put; debug only + maxIdle int // zero means defaultMaxIdleConns; negative means 0 + maxOpen int // <= 0 means unlimited +} + +func NewPool(opts *ConnectOpts) (*Pool, error) { + p := &Pool{ + opts: opts, + + openerCh: make(chan struct{}, connectionRequestQueueSize), + lastPut: make(map[*poolConn]string), + maxIdle: opts.MaxIdle, + } + go p.connectionOpener() + return p, nil +} + +// Ping verifies a connection to the database is still alive, +// establishing a connection if necessary. +func (p *Pool) Ping() error { + pc, err := p.conn() + if err != nil { + return err + } + p.putConn(pc, nil) + return nil +} + +// Close closes the database, releasing any open resources. +// +// It is rare to Close a Pool, as the Pool handle is meant to be +// long-lived and shared between many goroutines. 
+func (p *Pool) Close() error { + p.mu.Lock() + if p.closed { // Make Pool.Close idempotent + p.mu.Unlock() + return nil + } + close(p.openerCh) + var err error + fns := make([]func() error, 0, len(p.freeConn)) + for _, pc := range p.freeConn { + fns = append(fns, pc.closePoolLocked()) + } + p.freeConn = nil + p.closed = true + for _, req := range p.connRequests { + close(req) + } + p.mu.Unlock() + for _, fn := range fns { + err1 := fn() + if err1 != nil { + err = err1 + } + } + return err +} + +func (p *Pool) maxIdleConnsLocked() int { + n := p.maxIdle + switch { + case n == 0: + return defaultMaxIdleConns + case n < 0: + return 0 + default: + return n + } +} + +// SetMaxIdleConns sets the maximum number of connections in the idle +// connection pool. +// +// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns +// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit +// +// If n <= 0, no idle connections are retained. +func (p *Pool) SetMaxIdleConns(n int) { + p.mu.Lock() + if n > 0 { + p.maxIdle = n + } else { + // No idle connections. + p.maxIdle = -1 + } + // Make sure maxIdle doesn't exceed maxOpen + if p.maxOpen > 0 && p.maxIdleConnsLocked() > p.maxOpen { + p.maxIdle = p.maxOpen + } + var closing []*poolConn + idleCount := len(p.freeConn) + maxIdle := p.maxIdleConnsLocked() + if idleCount > maxIdle { + closing = p.freeConn[maxIdle:] + p.freeConn = p.freeConn[:maxIdle] + } + p.mu.Unlock() + for _, c := range closing { + c.Close() + } +} + +// SetMaxOpenConns sets the maximum number of open connections to the database. +// +// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than +// MaxIdleConns, then MaxIdleConns will be reduced to match the new +// MaxOpenConns limit +// +// If n <= 0, then there is no limit on the number of open connections. +// The default is 0 (unlimited). 
+func (p *Pool) SetMaxOpenConns(n int) {
+	p.mu.Lock()
+	p.maxOpen = n
+	if n < 0 {
+		p.maxOpen = 0
+	}
+	syncMaxIdle := p.maxOpen > 0 && p.maxIdleConnsLocked() > p.maxOpen
+	p.mu.Unlock()
+	if syncMaxIdle {
+		p.SetMaxIdleConns(n)
+	}
+}
+
+// Assumes p.mu is locked.
+// If there are connRequests and the connection limit hasn't been reached,
+// then tell the connectionOpener to open new connections.
+func (p *Pool) maybeOpenNewConnections() {
+	numRequests := len(p.connRequests) - p.pendingOpens
+	if p.maxOpen > 0 {
+		numCanOpen := p.maxOpen - (p.numOpen + p.pendingOpens)
+		if numRequests > numCanOpen {
+			numRequests = numCanOpen
+		}
+	}
+	for numRequests > 0 {
+		p.pendingOpens++
+		numRequests--
+		p.openerCh <- struct{}{}
+	}
+}
+
+// Runs in a separate goroutine, opens new connections when requested.
+func (p *Pool) connectionOpener() {
+	for _ = range p.openerCh {
+		p.openNewConnection()
+	}
+}
+
+// Open one new connection
+func (p *Pool) openNewConnection() {
+	ci, err := NewConnection(p.opts)
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if p.closed {
+		if err == nil {
+			ci.Close()
+		}
+		return
+	}
+	p.pendingOpens--
+	if err != nil {
+		p.putConnPoolLocked(nil, err)
+		return
+	}
+	pc := &poolConn{
+		p:  p,
+		ci: ci,
+	}
+	if p.putConnPoolLocked(pc, err) {
+		p.addDepLocked(pc, pc)
+		p.numOpen++
+	} else {
+		ci.Close()
+	}
+}
+
+// connRequest represents one request for a new connection
+// When there are no idle connections available, Pool.conn will create
+// a new connRequest and put it on the p.connRequests list.
+type connRequest struct {
+	conn *poolConn
+	err  error
+}
+
+// conn returns a newly-opened or cached *poolConn
+func (p *Pool) conn() (*poolConn, error) {
+	p.mu.Lock()
+	if p.closed {
+		p.mu.Unlock()
+		return nil, errPoolClosed
+	}
+	// If p.maxOpen > 0 and the number of open connections is over the limit
+	// and there are no free connections, make a request and wait.
+	if p.maxOpen > 0 && p.numOpen >= p.maxOpen && len(p.freeConn) == 0 {
+		// Make the connRequest channel. It's buffered so that the
+		// connectionOpener doesn't block while waiting for the req to be read.
+		req := make(chan connRequest, 1)
+		p.connRequests = append(p.connRequests, req)
+		p.mu.Unlock()
+		ret := <-req
+		return ret.conn, ret.err
+	}
+	if c := len(p.freeConn); c > 0 {
+		conn := p.freeConn[0]
+		copy(p.freeConn, p.freeConn[1:])
+		p.freeConn = p.freeConn[:c-1]
+		conn.inUse = true
+		p.mu.Unlock()
+		return conn, nil
+	}
+	p.numOpen++ // optimistically
+	p.mu.Unlock()
+	ci, err := NewConnection(p.opts)
+	if err != nil {
+		p.mu.Lock()
+		p.numOpen-- // correct for earlier optimism
+		p.mu.Unlock()
+		return nil, err
+	}
+	p.mu.Lock()
+	pc := &poolConn{
+		p:  p,
+		ci: ci,
+	}
+	p.addDepLocked(pc, pc)
+	pc.inUse = true
+	p.mu.Unlock()
+	return pc, nil
+}
+
+// connIfFree returns (wanted, nil) if wanted is still a valid conn and
+// isn't in use.
+//
+// The error is errConnClosed if the requested connection
+// is invalid because it's been closed.
+//
+// The error is errConnBusy if the connection is in use.
+func (p *Pool) connIfFree(wanted *poolConn) (*poolConn, error) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if wanted.pmuClosed {
+		return nil, errConnClosed
+	}
+	if wanted.inUse {
+		return nil, errConnBusy
+	}
+	idx := -1
+	for ii, v := range p.freeConn {
+		if v == wanted {
+			idx = ii
+			break
+		}
+	}
+	if idx >= 0 {
+		p.freeConn = append(p.freeConn[:idx], p.freeConn[idx+1:]...)
+		wanted.inUse = true
+		return wanted, nil
+	}
+
+	panic("connIfFree call requested a non-closed, non-busy, non-free conn")
+}
+
+// noteUnusedCursor notes that ci is no longer used and should
+// be closed whenever possible (when c is next not in use), unless c is
+// already closed.
+func (p *Pool) noteUnusedCursor(c *poolConn, ci *Cursor) { + p.mu.Lock() + defer p.mu.Unlock() + if c.inUse { + c.onPut = append(c.onPut, func() { + ci.Close() + }) + } else { + c.Lock() + defer c.Unlock() + if !c.finalClosed { + ci.Close() + } + } +} + +// debugGetPut determines whether getConn & putConn calls' stack traces +// are returned for more verbose crashes. +const debugGetPut = false + +// putConn adds a connection to the free pool. +// err is optionally the last error that occurred on this connection. +func (p *Pool) putConn(pc *poolConn, err error) { + p.mu.Lock() + if !pc.inUse { + if debugGetPut { + fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", pc, stack(), p.lastPut[pc]) + } + panic("gorethink: connection returned that was never out") + } + if debugGetPut { + p.lastPut[pc] = stack() + } + pc.inUse = false + for _, fn := range pc.onPut { + fn() + } + pc.onPut = nil + if err != nil && pc.ci.bad { + // Don't reuse bad connections. + // Since the conn is considered bad and is being discarded, treat it + // as closed. Don't decrement the open count here, finalClose will + // take care of that. + p.maybeOpenNewConnections() + p.mu.Unlock() + pc.Close() + return + } + added := p.putConnPoolLocked(pc, nil) + p.mu.Unlock() + if !added { + pc.Close() + } +} + +// Satisfy a connRequest or put the poolConn in the idle pool and return true +// or return false. +// putConnPoolLocked will satisfy a connRequest if there is one, or it will +// return the *poolConn to the freeConn list if err == nil and the idle +// connection limit will not be exceeded. +// If err != nil, the value of pc is ignored. +// If err == nil, then pc must not equal nil. +// If a connRequest was fulfilled or the *poolConn was placed in the +// freeConn list, then true is returned, otherwise false is returned. 
+func (p *Pool) putConnPoolLocked(pc *poolConn, err error) bool { + if c := len(p.connRequests); c > 0 { + req := p.connRequests[0] + // This copy is O(n) but in practice faster than a linked list. + // TODO: consider compacting it down less often and + // moving the base instead? + copy(p.connRequests, p.connRequests[1:]) + p.connRequests = p.connRequests[:c-1] + if err == nil { + pc.inUse = true + } + req <- connRequest{ + conn: pc, + err: err, + } + return true + } else if err == nil && !p.closed && p.maxIdleConnsLocked() > len(p.freeConn) { + p.freeConn = append(p.freeConn, pc) + return true + } + return false +} + +// addDep notes that x now depends on dep, and x's finalClose won't be +// called until all of x's dependencies are removed with removeDep. +func (p *Pool) addDep(x finalCloser, dep interface{}) { + //println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep)) + p.mu.Lock() + defer p.mu.Unlock() + p.addDepLocked(x, dep) +} + +func (p *Pool) addDepLocked(x finalCloser, dep interface{}) { + if p.dep == nil { + p.dep = make(map[finalCloser]depSet) + } + xdep := p.dep[x] + if xdep == nil { + xdep = make(depSet) + p.dep[x] = xdep + } + xdep[dep] = true +} + +// removeDep notes that x no longer depends on dep. +// If x still has dependencies, nil is returned. +// If x no longer has any dependencies, its finalClose method will be +// called and its error value will be returned. +func (p *Pool) removeDep(x finalCloser, dep interface{}) error { + p.mu.Lock() + fn := p.removeDepLocked(x, dep) + p.mu.Unlock() + return fn() +} + +func (p *Pool) removeDepLocked(x finalCloser, dep interface{}) func() error { + //println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep)) + xdep, ok := p.dep[x] + if !ok { + panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x)) + } + l0 := len(xdep) + delete(xdep, dep) + switch len(xdep) { + case l0: + // Nothing removed. Shouldn't happen. 
+ panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x)) + case 0: + // No more dependencies. + delete(p.dep, x) + return x.finalClose + default: + // Dependencies remain. + return func() error { return nil } + } +} + +// Query execution functions + +// Exec executes a query without waiting for any response. +func (p *Pool) Exec(q Query) error { + var err error + for i := 0; i < maxBadConnRetries; i++ { + err = p.exec(q) + if err != ErrBadConn { + break + } + } + return err +} +func (p *Pool) exec(q Query) (err error) { + pc, err := p.conn() + if err != nil { + return err + } + defer func() { + p.putConn(pc, err) + }() + + pc.Lock() + _, _, err = pc.ci.Query(q) + pc.Unlock() + + if err != nil { + return err + } + return nil +} + +// Query executes a query and waits for the response +func (p *Pool) Query(q Query) (*Cursor, error) { + var cursor *Cursor + var err error + for i := 0; i < maxBadConnRetries; i++ { + cursor, err = p.query(q) + if err != ErrBadConn { + break + } + } + return cursor, err +} +func (p *Pool) query(query Query) (*Cursor, error) { + ci, err := p.conn() + if err != nil { + return nil, err + } + return p.queryConn(ci, ci.releaseConn, query) +} + +// queryConn executes a query on the given connection. +// The connection gets released by the releaseConn function. 
+func (p *Pool) queryConn(pc *poolConn, releaseConn func(error), q Query) (*Cursor, error) { + pc.Lock() + _, cursor, err := pc.ci.Query(q) + pc.Unlock() + if err != nil { + releaseConn(err) + return nil, err + } + + cursor.releaseConn = releaseConn + + return cursor, nil +} + +// Helper functions + +func stack() string { + var buf [2 << 10]byte + return string(buf[:runtime.Stack(buf[:], false)]) +} diff --git a/pool_conn.go b/pool_conn.go new file mode 100644 index 00000000..300b7354 --- /dev/null +++ b/pool_conn.go @@ -0,0 +1,75 @@ +package gorethink + +import ( + "errors" + "sync" +) + +// ErrBadConn should be returned by a connection operation to signal to the +// pool that a driver.Conn is in a bad state (such as the server +// having earlier closed the connection) and the pool should retry on a +// new connection. +// +// To prevent duplicate operations, ErrBadConn should NOT be returned +// if there's a possibility that the database server might have +// performed the operation. Even if the server sends back an error, +// you shouldn't return ErrBadConn. +var ErrBadConn = errors.New("gorethink: bad connection") + +type poolConn struct { + p *Pool + + sync.Mutex // guards following + ci *Connection + closed bool + finalClosed bool // ci.Close has been called + + // guarded by p.mu + inUse bool + onPut []func() // code (with p.mu held) run when conn is next returned + pmuClosed bool // same as closed, but guarded by p.mu, for connIfFree +} + +func (pc *poolConn) releaseConn(err error) { + pc.p.putConn(pc, err) +} + +// the pc.p's Mutex is held. 
+func (pc *poolConn) closePoolLocked() func() error {
+	pc.Lock()
+	defer pc.Unlock()
+	if pc.closed {
+		return func() error { return errors.New("gorethink: duplicate poolConn close") }
+	}
+	pc.closed = true
+	return pc.p.removeDepLocked(pc, pc)
+}
+
+func (pc *poolConn) Close() error {
+	pc.Lock()
+	if pc.closed {
+		pc.Unlock()
+		return errors.New("gorethink: duplicate poolConn close")
+	}
+	pc.closed = true
+	pc.Unlock() // not defer; removeDep finalClose calls may need to lock
+	// And now updates that require holding pc.mu.Lock.
+	pc.p.mu.Lock()
+	pc.pmuClosed = true
+	fn := pc.p.removeDepLocked(pc, pc)
+	pc.p.mu.Unlock()
+	return fn()
+}
+
+func (pc *poolConn) finalClose() error {
+	pc.Lock()
+	err := pc.ci.Close()
+	pc.ci = nil
+	pc.finalClosed = true
+	pc.Unlock()
+	pc.p.mu.Lock()
+	pc.p.numOpen--
+	pc.p.maybeOpenNewConnections()
+	pc.p.mu.Unlock()
+	return err
+}
diff --git a/ql2/ql2.pb.go b/ql2/ql2.pb.go
index 2f72bec2..707354bd 100644
--- a/ql2/ql2.pb.go
+++ b/ql2/ql2.pb.go
@@ -2,33 +2,17 @@
 // source: ql2.proto
 // DO NOT EDIT!
 
-/*
-Package ql2 is a generated protocol buffer package.
-
-It is generated from these files:
-	ql2.proto
-
-It has these top-level messages:
-	VersionDummy
-	Query
-	Frame
-	Backtrace
-	Response
-	Datum
-	Term
-*/
 package ql2
 
 import proto "code.google.com/p/goprotobuf/proto"
+import json "encoding/json"
 import math "math"
 
-// Reference imports to suppress errors if they are not otherwise used.
+// Reference proto, json, and math imports to suppress error if they are not otherwise used.
 var _ = proto.Marshal
+var _ = &json.SyntaxError{}
 var _ = math.Inf
 
-// non-conforming protobuf libraries
-// This enum contains the magic numbers for your version. See **THE HIGH-LEVEL
-// VIEW** for what to do with it.
type VersionDummy_Version int32 const ( @@ -56,6 +40,9 @@ func (x VersionDummy_Version) Enum() *VersionDummy_Version { func (x VersionDummy_Version) String() string { return proto.EnumName(VersionDummy_Version_name, int32(x)) } +func (x VersionDummy_Version) MarshalJSON() ([]byte, error) { + return json.Marshal(x.String()) +} func (x *VersionDummy_Version) UnmarshalJSON(data []byte) error { value, err := proto.UnmarshalJSONEnum(VersionDummy_Version_value, data, "VersionDummy_Version") if err != nil { @@ -65,7 +52,6 @@ func (x *VersionDummy_Version) UnmarshalJSON(data []byte) error { return nil } -// The protocol to use after the handshake, specified in V0_3 type VersionDummy_Protocol int32 const ( @@ -90,6 +76,9 @@ func (x VersionDummy_Protocol) Enum() *VersionDummy_Protocol { func (x VersionDummy_Protocol) String() string { return proto.EnumName(VersionDummy_Protocol_name, int32(x)) } +func (x VersionDummy_Protocol) MarshalJSON() ([]byte, error) { + return json.Marshal(x.String()) +} func (x *VersionDummy_Protocol) UnmarshalJSON(data []byte) error { value, err := proto.UnmarshalJSONEnum(VersionDummy_Protocol_value, data, "VersionDummy_Protocol") if err != nil { @@ -102,9 +91,8 @@ func (x *VersionDummy_Protocol) UnmarshalJSON(data []byte) error { type Query_QueryType int32 const ( - Query_START Query_QueryType = 1 - Query_CONTINUE Query_QueryType = 2 - // (see [Response]). 
+ Query_START Query_QueryType = 1 + Query_CONTINUE Query_QueryType = 2 Query_STOP Query_QueryType = 3 Query_NOREPLY_WAIT Query_QueryType = 4 ) @@ -130,6 +118,9 @@ func (x Query_QueryType) Enum() *Query_QueryType { func (x Query_QueryType) String() string { return proto.EnumName(Query_QueryType_name, int32(x)) } +func (x Query_QueryType) MarshalJSON() ([]byte, error) { + return json.Marshal(x.String()) +} func (x *Query_QueryType) UnmarshalJSON(data []byte) error { value, err := proto.UnmarshalJSONEnum(Query_QueryType_value, data, "Query_QueryType") if err != nil { @@ -163,6 +154,9 @@ func (x Frame_FrameType) Enum() *Frame_FrameType { func (x Frame_FrameType) String() string { return proto.EnumName(Frame_FrameType_name, int32(x)) } +func (x Frame_FrameType) MarshalJSON() ([]byte, error) { + return json.Marshal(x.String()) +} func (x *Frame_FrameType) UnmarshalJSON(data []byte) error { value, err := proto.UnmarshalJSONEnum(Frame_FrameType_value, data, "Frame_FrameType") if err != nil { @@ -175,24 +169,15 @@ func (x *Frame_FrameType) UnmarshalJSON(data []byte) error { type Response_ResponseType int32 const ( - // These response types indicate success. - Response_SUCCESS_ATOM Response_ResponseType = 1 - Response_SUCCESS_SEQUENCE Response_ResponseType = 2 - Response_SUCCESS_PARTIAL Response_ResponseType = 3 - // datatypes. If you send a [CONTINUE] query with - // the same token as this response, you will get - // more of the sequence. Keep sending [CONTINUE] - // queries until you get back [SUCCESS_SEQUENCE]. - Response_SUCCESS_FEED Response_ResponseType = 5 - Response_WAIT_COMPLETE Response_ResponseType = 4 - // These response types indicate failure. - Response_CLIENT_ERROR Response_ResponseType = 16 - // client sends a malformed protobuf, or tries to - // send [CONTINUE] for an unknown token. - Response_COMPILE_ERROR Response_ResponseType = 17 - // checking. For example, if you pass too many - // arguments to a function. 
- Response_RUNTIME_ERROR Response_ResponseType = 18 + Response_SUCCESS_ATOM Response_ResponseType = 1 + Response_SUCCESS_SEQUENCE Response_ResponseType = 2 + Response_SUCCESS_PARTIAL Response_ResponseType = 3 + Response_SUCCESS_FEED Response_ResponseType = 5 + Response_WAIT_COMPLETE Response_ResponseType = 4 + Response_SUCCESS_ATOM_FEED Response_ResponseType = 6 + Response_CLIENT_ERROR Response_ResponseType = 16 + Response_COMPILE_ERROR Response_ResponseType = 17 + Response_RUNTIME_ERROR Response_ResponseType = 18 ) var Response_ResponseType_name = map[int32]string{ @@ -201,19 +186,21 @@ var Response_ResponseType_name = map[int32]string{ 3: "SUCCESS_PARTIAL", 5: "SUCCESS_FEED", 4: "WAIT_COMPLETE", + 6: "SUCCESS_ATOM_FEED", 16: "CLIENT_ERROR", 17: "COMPILE_ERROR", 18: "RUNTIME_ERROR", } var Response_ResponseType_value = map[string]int32{ - "SUCCESS_ATOM": 1, - "SUCCESS_SEQUENCE": 2, - "SUCCESS_PARTIAL": 3, - "SUCCESS_FEED": 5, - "WAIT_COMPLETE": 4, - "CLIENT_ERROR": 16, - "COMPILE_ERROR": 17, - "RUNTIME_ERROR": 18, + "SUCCESS_ATOM": 1, + "SUCCESS_SEQUENCE": 2, + "SUCCESS_PARTIAL": 3, + "SUCCESS_FEED": 5, + "WAIT_COMPLETE": 4, + "SUCCESS_ATOM_FEED": 6, + "CLIENT_ERROR": 16, + "COMPILE_ERROR": 17, + "RUNTIME_ERROR": 18, } func (x Response_ResponseType) Enum() *Response_ResponseType { @@ -224,6 +211,9 @@ func (x Response_ResponseType) Enum() *Response_ResponseType { func (x Response_ResponseType) String() string { return proto.EnumName(Response_ResponseType_name, int32(x)) } +func (x Response_ResponseType) MarshalJSON() ([]byte, error) { + return json.Marshal(x.String()) +} func (x *Response_ResponseType) UnmarshalJSON(data []byte) error { value, err := proto.UnmarshalJSONEnum(Response_ResponseType_value, data, "Response_ResponseType") if err != nil { @@ -242,10 +232,7 @@ const ( Datum_R_STR Datum_DatumType = 4 Datum_R_ARRAY Datum_DatumType = 5 Datum_R_OBJECT Datum_DatumType = 6 - // This [DatumType] will only be used if [accepts_r_json] is - // set to [true] in 
[Query]. [r_str] will be filled with a - // JSON encoding of the [Datum]. - Datum_R_JSON Datum_DatumType = 7 + Datum_R_JSON Datum_DatumType = 7 ) var Datum_DatumType_name = map[int32]string{ @@ -275,6 +262,9 @@ func (x Datum_DatumType) Enum() *Datum_DatumType { func (x Datum_DatumType) String() string { return proto.EnumName(Datum_DatumType_name, int32(x)) } +func (x Datum_DatumType) MarshalJSON() ([]byte, error) { + return json.Marshal(x.String()) +} func (x *Datum_DatumType) UnmarshalJSON(data []byte) error { value, err := proto.UnmarshalJSONEnum(Datum_DatumType_value, data, "Datum_DatumType") if err != nil { @@ -287,68 +277,34 @@ func (x *Datum_DatumType) UnmarshalJSON(data []byte) error { type Term_TermType int32 const ( - // A RQL datum, stored in `datum` below. - Term_DATUM Term_TermType = 1 - Term_MAKE_ARRAY Term_TermType = 2 - // Evaluate the terms in [optargs] and make an object - Term_MAKE_OBJ Term_TermType = 3 - // Takes an integer representing a variable and returns the value stored - // in that variable. It's the responsibility of the client to translate - // from their local representation of a variable to a unique _non-negative_ - // integer for that variable. (We do it this way instead of letting - // clients provide variable names as strings to discourage - // variable-capturing client libraries, and because it's more efficient - // on the wire.) - Term_VAR Term_TermType = 10 - // Takes some javascript code and executes it. - Term_JAVASCRIPT Term_TermType = 11 - // STRING {timeout: !NUMBER} -> Function(*) - Term_UUID Term_TermType = 169 - // Takes an HTTP URL and gets it. If the get succeeds and - // returns valid JSON, it is converted into a DATUM - Term_HTTP Term_TermType = 153 - // Takes a string and throws an error with that message. - // Inside of a `default` block, you can omit the first - // argument to rethrow whatever error you catch (this is most - // useful as an argument to the `default` filter optarg). 
- Term_ERROR Term_TermType = 12 - // Takes nothing and returns a reference to the implicit variable. - Term_IMPLICIT_VAR Term_TermType = 13 - // * Data Operators - // Returns a reference to a database. - Term_DB Term_TermType = 14 - // Returns a reference to a table. - Term_TABLE Term_TermType = 15 - // Gets a single element from a table by its primary or a secondary key. - Term_GET Term_TermType = 16 - // Table, STRING -> NULL | Table, NUMBER -> NULL | - Term_GET_ALL Term_TermType = 78 - // Simple DATUM Ops - Term_EQ Term_TermType = 17 - Term_NE Term_TermType = 18 - Term_LT Term_TermType = 19 - Term_LE Term_TermType = 20 - Term_GT Term_TermType = 21 - Term_GE Term_TermType = 22 - Term_NOT Term_TermType = 23 - // ADD can either add two numbers or concatenate two arrays. - Term_ADD Term_TermType = 24 - Term_SUB Term_TermType = 25 - Term_MUL Term_TermType = 26 - Term_DIV Term_TermType = 27 - Term_MOD Term_TermType = 28 - // DATUM Array Ops - // Append a single element to the end of an array (like `snoc`). - Term_APPEND Term_TermType = 29 - // Prepend a single element to the end of an array (like `cons`). - Term_PREPEND Term_TermType = 80 - // Remove the elements of one array from another array. - Term_DIFFERENCE Term_TermType = 95 - // DATUM Set Ops - // Set ops work on arrays. They don't use actual sets and thus have - // performance characteristics you would expect from arrays rather than - // from sets. All set operations have the post condition that they - // array they return contains no duplicate values. 
+ Term_DATUM Term_TermType = 1 + Term_MAKE_ARRAY Term_TermType = 2 + Term_MAKE_OBJ Term_TermType = 3 + Term_VAR Term_TermType = 10 + Term_JAVASCRIPT Term_TermType = 11 + Term_UUID Term_TermType = 169 + Term_HTTP Term_TermType = 153 + Term_ERROR Term_TermType = 12 + Term_IMPLICIT_VAR Term_TermType = 13 + Term_DB Term_TermType = 14 + Term_TABLE Term_TermType = 15 + Term_GET Term_TermType = 16 + Term_GET_ALL Term_TermType = 78 + Term_EQ Term_TermType = 17 + Term_NE Term_TermType = 18 + Term_LT Term_TermType = 19 + Term_LE Term_TermType = 20 + Term_GT Term_TermType = 21 + Term_GE Term_TermType = 22 + Term_NOT Term_TermType = 23 + Term_ADD Term_TermType = 24 + Term_SUB Term_TermType = 25 + Term_MUL Term_TermType = 26 + Term_DIV Term_TermType = 27 + Term_MOD Term_TermType = 28 + Term_APPEND Term_TermType = 29 + Term_PREPEND Term_TermType = 80 + Term_DIFFERENCE Term_TermType = 95 Term_SET_INSERT Term_TermType = 88 Term_SET_INTERSECTION Term_TermType = 89 Term_SET_UNION Term_TermType = 90 @@ -358,294 +314,124 @@ const ( Term_LIMIT Term_TermType = 71 Term_INDEXES_OF Term_TermType = 87 Term_CONTAINS Term_TermType = 93 - // Stream/Object Ops - // Get a particular field from an object, or map that over a - // sequence. - Term_GET_FIELD Term_TermType = 31 - // | Sequence, STRING -> Sequence - // Return an array containing the keys of the object. - Term_KEYS Term_TermType = 94 - // Creates an object - Term_OBJECT Term_TermType = 143 - // Check whether an object contains all the specified fields, - // or filters a sequence so that all objects inside of it - // contain all the specified fields. - Term_HAS_FIELDS Term_TermType = 32 - // x.with_fields(...) <=> x.has_fields(...).pluck(...) - Term_WITH_FIELDS Term_TermType = 96 - // Get a subset of an object by selecting some attributes to preserve, - // or map that over a sequence. (Both pick and pluck, polymorphic.) 
- Term_PLUCK Term_TermType = 33 - // Get a subset of an object by selecting some attributes to discard, or - // map that over a sequence. (Both unpick and without, polymorphic.) - Term_WITHOUT Term_TermType = 34 - // Merge objects (right-preferential) - Term_MERGE Term_TermType = 35 - // Sequence Ops - // Get all elements of a sequence between two values. - // Half-open by default, but the openness of either side can be - // changed by passing 'closed' or 'open for `right_bound` or - // `left_bound`. - Term_BETWEEN Term_TermType = 36 - Term_REDUCE Term_TermType = 37 - Term_MAP Term_TermType = 38 - // Filter a sequence with either a function or a shortcut - // object (see API docs for details). The body of FILTER is - // wrapped in an implicit `.default(false)`, and you can - // change the default value by specifying the `default` - // optarg. If you make the default `r.error`, all errors - // caught by `default` will be rethrown as if the `default` - // did not exist. - Term_FILTER Term_TermType = 39 - // Sequence, OBJECT, {default:DATUM} -> Sequence - // Map a function over a sequence and then concatenate the results together. - Term_CONCATMAP Term_TermType = 40 - // Order a sequence based on one or more attributes. - Term_ORDERBY Term_TermType = 41 - // Get all distinct elements of a sequence (like `uniq`). - Term_DISTINCT Term_TermType = 42 - // Count the number of elements in a sequence, or only the elements that match - // a given filter. - Term_COUNT Term_TermType = 43 - Term_IS_EMPTY Term_TermType = 86 - // Take the union of multiple sequences (preserves duplicate elements! (use distinct)). - Term_UNION Term_TermType = 44 - // Get the Nth element of a sequence. - Term_NTH Term_TermType = 45 - // do NTH or GET_FIELD depending on target object - Term_BRACKET Term_TermType = 170 - Term_INNER_JOIN Term_TermType = 48 - Term_OUTER_JOIN Term_TermType = 49 - // An inner-join that does an equality comparison on two attributes. 
- Term_EQ_JOIN Term_TermType = 50 - Term_ZIP Term_TermType = 72 - // Array Ops - // Insert an element in to an array at a given index. - Term_INSERT_AT Term_TermType = 82 - // Remove an element at a given index from an array. - Term_DELETE_AT Term_TermType = 83 - // ARRAY, NUMBER, NUMBER -> ARRAY - // Change the element at a given index of an array. - Term_CHANGE_AT Term_TermType = 84 - // Splice one array in to another array. - Term_SPLICE_AT Term_TermType = 85 - // * Type Ops - // Coerces a datum to a named type (e.g. "bool"). - // If you previously used `stream_to_array`, you should use this instead - // with the type "array". - Term_COERCE_TO Term_TermType = 51 - // Returns the named type of a datum (e.g. TYPEOF(true) = "BOOL") - Term_TYPEOF Term_TermType = 52 - // * Write Ops (the OBJECTs contain data about number of errors etc.) - // Updates all the rows in a selection. Calls its Function with the row - // to be updated, and then merges the result of that call. - Term_UPDATE Term_TermType = 53 - // SingleSelection, Function(1), {non_atomic:BOOL, durability:STRING, return_changes:BOOL} -> OBJECT | - // StreamSelection, OBJECT, {non_atomic:BOOL, durability:STRING, return_changes:BOOL} -> OBJECT | - // SingleSelection, OBJECT, {non_atomic:BOOL, durability:STRING, return_changes:BOOL} -> OBJECT - // Deletes all the rows in a selection. - Term_DELETE Term_TermType = 54 - // Replaces all the rows in a selection. Calls its Function with the row - // to be replaced, and then discards it and stores the result of that - // call. - Term_REPLACE Term_TermType = 55 - // Inserts into a table. If `conflict` is replace, overwrites - // entries with the same primary key. If `conflict` is - // update, does an update on the entry. If `conflict` is - // error, or is omitted, conflicts will trigger an error. - Term_INSERT Term_TermType = 56 - // * Administrative OPs - // Creates a database with a particular name. 
- Term_DB_CREATE Term_TermType = 57 - // Drops a database with a particular name. - Term_DB_DROP Term_TermType = 58 - // Lists all the databases by name. (Takes no arguments) - Term_DB_LIST Term_TermType = 59 - // Creates a table with a particular name in a particular - // database. (You may omit the first argument to use the - // default database.) - Term_TABLE_CREATE Term_TermType = 60 - // STRING, {datacenter:STRING, primary_key:STRING, durability:STRING} -> OBJECT - // Drops a table with a particular name from a particular - // database. (You may omit the first argument to use the - // default database.) - Term_TABLE_DROP Term_TermType = 61 - // STRING -> OBJECT - // Lists all the tables in a particular database. (You may - // omit the first argument to use the default database.) - Term_TABLE_LIST Term_TermType = 62 - // -> ARRAY - // Ensures that previously issued soft-durability writes are complete and - // written to disk. - Term_SYNC Term_TermType = 138 - // * Secondary indexes OPs - // Creates a new secondary index with a particular name and definition. - Term_INDEX_CREATE Term_TermType = 75 - // Drops a secondary index with a particular name from the specified table. - Term_INDEX_DROP Term_TermType = 76 - // Lists all secondary indexes on a particular table. - Term_INDEX_LIST Term_TermType = 77 - // Gets information about whether or not a set of indexes are ready to - // be accessed. Returns a list of objects that look like this: - // {index:STRING, ready:BOOL[, blocks_processed:NUMBER, blocks_total:NUMBER]} - Term_INDEX_STATUS Term_TermType = 139 - // Blocks until a set of indexes are ready to be accessed. Returns the - // same values INDEX_STATUS. 
- Term_INDEX_WAIT Term_TermType = 140 - // Renames the given index to a new name - Term_INDEX_RENAME Term_TermType = 156 - // * Control Operators - // Calls a function on data - Term_FUNCALL Term_TermType = 64 - // Executes its first argument, and returns its second argument if it - // got [true] or its third argument if it got [false] (like an `if` - // statement). - Term_BRANCH Term_TermType = 65 - // Returns true if any of its arguments returns true (short-circuits). - // (Like `or` in most languages.) - Term_ANY Term_TermType = 66 - // Returns true if all of its arguments return true (short-circuits). - // (Like `and` in most languages.) - Term_ALL Term_TermType = 67 - // Calls its Function with each entry in the sequence - // and executes the array of terms that Function returns. - Term_FOREACH Term_TermType = 68 - // An anonymous function. Takes an array of numbers representing - // variables (see [VAR] above), and a [Term] to execute with those in - // scope. Returns a function that may be passed an array of arguments, - // then executes the Term with those bound to the variable names. The - // user will never construct this directly. We use it internally for - // things like `map` which take a function. The "arity" of a [Function] is - // the number of arguments it takes. - // For example, here's what `_X_.map{|x| x+2}` turns into: - // Term { - // type = MAP; - // args = [_X_, - // Term { - // type = Function; - // args = [Term { - // type = DATUM; - // datum = Datum { - // type = R_ARRAY; - // r_array = [Datum { type = R_NUM; r_num = 1; }]; - // }; - // }, - // Term { - // type = ADD; - // args = [Term { - // type = VAR; - // args = [Term { - // type = DATUM; - // datum = Datum { type = R_NUM; - // r_num = 1}; - // }]; - // }, - // Term { - // type = DATUM; - // datum = Datum { type = R_NUM; r_num = 2; }; - // }]; - // }]; - // }]; - Term_FUNC Term_TermType = 69 - // Indicates to ORDER_BY that this attribute is to be sorted in ascending order. 
- Term_ASC Term_TermType = 73 - // Indicates to ORDER_BY that this attribute is to be sorted in descending order. - Term_DESC Term_TermType = 74 - // Gets info about anything. INFO is most commonly called on tables. - Term_INFO Term_TermType = 79 - // `a.match(b)` returns a match object if the string `a` - // matches the regular expression `b`. - Term_MATCH Term_TermType = 97 - // Change the case of a string. - Term_UPCASE Term_TermType = 141 - Term_DOWNCASE Term_TermType = 142 - // Select a number of elements from sequence with uniform distribution. - Term_SAMPLE Term_TermType = 81 - // Evaluates its first argument. If that argument returns - // NULL or throws an error related to the absence of an - // expected value (for instance, accessing a non-existent - // field or adding NULL to an integer), DEFAULT will either - // return its second argument or execute it if it's a - // function. If the second argument is a function, it will be - // passed either the text of the error or NULL as its - // argument. - Term_DEFAULT Term_TermType = 92 - // Parses its first argument as a json string and returns it as a - // datum. - Term_JSON Term_TermType = 98 - // Parses its first arguments as an ISO 8601 time and returns it as a - // datum. - Term_ISO8601 Term_TermType = 99 - // Prints a time as an ISO 8601 time. - Term_TO_ISO8601 Term_TermType = 100 - // Returns a time given seconds since epoch in UTC. - Term_EPOCH_TIME Term_TermType = 101 - // Returns seconds since epoch in UTC given a time. - Term_TO_EPOCH_TIME Term_TermType = 102 - // The time the query was received by the server. - Term_NOW Term_TermType = 103 - // Puts a time into an ISO 8601 timezone. - Term_IN_TIMEZONE Term_TermType = 104 - // a.during(b, c) returns whether a is in the range [b, c) - Term_DURING Term_TermType = 105 - // Retrieves the date portion of a time. - Term_DATE Term_TermType = 106 - // x.time_of_day == x.date - x - Term_TIME_OF_DAY Term_TermType = 126 - // Returns the timezone of a time. 
- Term_TIMEZONE Term_TermType = 127 - // These access the various components of a time. - Term_YEAR Term_TermType = 128 - Term_MONTH Term_TermType = 129 - Term_DAY Term_TermType = 130 - Term_DAY_OF_WEEK Term_TermType = 131 - Term_DAY_OF_YEAR Term_TermType = 132 - Term_HOURS Term_TermType = 133 - Term_MINUTES Term_TermType = 134 - Term_SECONDS Term_TermType = 135 - // Construct a time from a date and optional timezone or a - // date+time and optional timezone. - Term_TIME Term_TermType = 136 - // Constants for ISO 8601 days of the week. - Term_MONDAY Term_TermType = 107 - Term_TUESDAY Term_TermType = 108 - Term_WEDNESDAY Term_TermType = 109 - Term_THURSDAY Term_TermType = 110 - Term_FRIDAY Term_TermType = 111 - Term_SATURDAY Term_TermType = 112 - Term_SUNDAY Term_TermType = 113 - // Constants for ISO 8601 months. - Term_JANUARY Term_TermType = 114 - Term_FEBRUARY Term_TermType = 115 - Term_MARCH Term_TermType = 116 - Term_APRIL Term_TermType = 117 - Term_MAY Term_TermType = 118 - Term_JUNE Term_TermType = 119 - Term_JULY Term_TermType = 120 - Term_AUGUST Term_TermType = 121 - Term_SEPTEMBER Term_TermType = 122 - Term_OCTOBER Term_TermType = 123 - Term_NOVEMBER Term_TermType = 124 - Term_DECEMBER Term_TermType = 125 - // Indicates to MERGE to replace the other object rather than merge it. 
- Term_LITERAL Term_TermType = 137 - // SEQUENCE, STRING -> GROUPED_SEQUENCE | SEQUENCE, FUNCTION -> GROUPED_SEQUENCE - Term_GROUP Term_TermType = 144 - Term_SUM Term_TermType = 145 - Term_AVG Term_TermType = 146 - Term_MIN Term_TermType = 147 - Term_MAX Term_TermType = 148 - // `str.split()` splits on whitespace - // `str.split(" ")` splits on spaces only - // `str.split(" ", 5)` splits on spaces with at most 5 results - // `str.split(nil, 5)` splits on whitespace with at most 5 results - Term_SPLIT Term_TermType = 149 - Term_UNGROUP Term_TermType = 150 - // Takes a range of numbers and returns a random number within the range - Term_RANDOM Term_TermType = 151 - Term_CHANGES Term_TermType = 152 - Term_ARGS Term_TermType = 154 - // BINARY is client-only at the moment, it is not supported on the server + Term_GET_FIELD Term_TermType = 31 + Term_KEYS Term_TermType = 94 + Term_OBJECT Term_TermType = 143 + Term_HAS_FIELDS Term_TermType = 32 + Term_WITH_FIELDS Term_TermType = 96 + Term_PLUCK Term_TermType = 33 + Term_WITHOUT Term_TermType = 34 + Term_MERGE Term_TermType = 35 + Term_BETWEEN Term_TermType = 36 + Term_REDUCE Term_TermType = 37 + Term_MAP Term_TermType = 38 + Term_FILTER Term_TermType = 39 + Term_CONCAT_MAP Term_TermType = 40 + Term_ORDER_BY Term_TermType = 41 + Term_DISTINCT Term_TermType = 42 + Term_COUNT Term_TermType = 43 + Term_IS_EMPTY Term_TermType = 86 + Term_UNION Term_TermType = 44 + Term_NTH Term_TermType = 45 + Term_BRACKET Term_TermType = 170 + Term_INNER_JOIN Term_TermType = 48 + Term_OUTER_JOIN Term_TermType = 49 + Term_EQ_JOIN Term_TermType = 50 + Term_ZIP Term_TermType = 72 + Term_RANGE Term_TermType = 173 + Term_INSERT_AT Term_TermType = 82 + Term_DELETE_AT Term_TermType = 83 + Term_CHANGE_AT Term_TermType = 84 + Term_SPLICE_AT Term_TermType = 85 + Term_COERCE_TO Term_TermType = 51 + Term_TYPE_OF Term_TermType = 52 + Term_UPDATE Term_TermType = 53 + Term_DELETE Term_TermType = 54 + Term_REPLACE Term_TermType = 55 + Term_INSERT 
Term_TermType = 56 + Term_DB_CREATE Term_TermType = 57 + Term_DB_DROP Term_TermType = 58 + Term_DB_LIST Term_TermType = 59 + Term_TABLE_CREATE Term_TermType = 60 + Term_TABLE_DROP Term_TermType = 61 + Term_TABLE_LIST Term_TermType = 62 + Term_CONFIG Term_TermType = 174 + Term_STATUS Term_TermType = 175 + Term_WAIT Term_TermType = 177 + Term_RECONFIGURE Term_TermType = 176 + Term_REBALANCE Term_TermType = 179 + Term_SYNC Term_TermType = 138 + Term_INDEX_CREATE Term_TermType = 75 + Term_INDEX_DROP Term_TermType = 76 + Term_INDEX_LIST Term_TermType = 77 + Term_INDEX_STATUS Term_TermType = 139 + Term_INDEX_WAIT Term_TermType = 140 + Term_INDEX_RENAME Term_TermType = 156 + Term_FUNCALL Term_TermType = 64 + Term_BRANCH Term_TermType = 65 + Term_ANY Term_TermType = 66 + Term_ALL Term_TermType = 67 + Term_FOR_EACH Term_TermType = 68 + Term_FUNC Term_TermType = 69 + Term_ASC Term_TermType = 73 + Term_DESC Term_TermType = 74 + Term_INFO Term_TermType = 79 + Term_MATCH Term_TermType = 97 + Term_UPCASE Term_TermType = 141 + Term_DOWNCASE Term_TermType = 142 + Term_SAMPLE Term_TermType = 81 + Term_DEFAULT Term_TermType = 92 + Term_JSON Term_TermType = 98 + Term_TO_JSON_STRING Term_TermType = 172 + Term_ISO8601 Term_TermType = 99 + Term_TO_ISO8601 Term_TermType = 100 + Term_EPOCH_TIME Term_TermType = 101 + Term_TO_EPOCH_TIME Term_TermType = 102 + Term_NOW Term_TermType = 103 + Term_IN_TIMEZONE Term_TermType = 104 + Term_DURING Term_TermType = 105 + Term_DATE Term_TermType = 106 + Term_TIME_OF_DAY Term_TermType = 126 + Term_TIMEZONE Term_TermType = 127 + Term_YEAR Term_TermType = 128 + Term_MONTH Term_TermType = 129 + Term_DAY Term_TermType = 130 + Term_DAY_OF_WEEK Term_TermType = 131 + Term_DAY_OF_YEAR Term_TermType = 132 + Term_HOURS Term_TermType = 133 + Term_MINUTES Term_TermType = 134 + Term_SECONDS Term_TermType = 135 + Term_TIME Term_TermType = 136 + Term_MONDAY Term_TermType = 107 + Term_TUESDAY Term_TermType = 108 + Term_WEDNESDAY Term_TermType = 109 + Term_THURSDAY 
Term_TermType = 110 + Term_FRIDAY Term_TermType = 111 + Term_SATURDAY Term_TermType = 112 + Term_SUNDAY Term_TermType = 113 + Term_JANUARY Term_TermType = 114 + Term_FEBRUARY Term_TermType = 115 + Term_MARCH Term_TermType = 116 + Term_APRIL Term_TermType = 117 + Term_MAY Term_TermType = 118 + Term_JUNE Term_TermType = 119 + Term_JULY Term_TermType = 120 + Term_AUGUST Term_TermType = 121 + Term_SEPTEMBER Term_TermType = 122 + Term_OCTOBER Term_TermType = 123 + Term_NOVEMBER Term_TermType = 124 + Term_DECEMBER Term_TermType = 125 + Term_LITERAL Term_TermType = 137 + Term_GROUP Term_TermType = 144 + Term_SUM Term_TermType = 145 + Term_AVG Term_TermType = 146 + Term_MIN Term_TermType = 147 + Term_MAX Term_TermType = 148 + Term_SPLIT Term_TermType = 149 + Term_UNGROUP Term_TermType = 150 + Term_RANDOM Term_TermType = 151 + Term_CHANGES Term_TermType = 152 + Term_ARGS Term_TermType = 154 Term_BINARY Term_TermType = 155 Term_GEOJSON Term_TermType = 157 Term_TO_GEOJSON Term_TermType = 158 @@ -712,8 +498,8 @@ var Term_TermType_name = map[int32]string{ 37: "REDUCE", 38: "MAP", 39: "FILTER", - 40: "CONCATMAP", - 41: "ORDERBY", + 40: "CONCAT_MAP", + 41: "ORDER_BY", 42: "DISTINCT", 43: "COUNT", 86: "IS_EMPTY", @@ -724,12 +510,13 @@ var Term_TermType_name = map[int32]string{ 49: "OUTER_JOIN", 50: "EQ_JOIN", 72: "ZIP", + 173: "RANGE", 82: "INSERT_AT", 83: "DELETE_AT", 84: "CHANGE_AT", 85: "SPLICE_AT", 51: "COERCE_TO", - 52: "TYPEOF", + 52: "TYPE_OF", 53: "UPDATE", 54: "DELETE", 55: "REPLACE", @@ -740,6 +527,11 @@ var Term_TermType_name = map[int32]string{ 60: "TABLE_CREATE", 61: "TABLE_DROP", 62: "TABLE_LIST", + 174: "CONFIG", + 175: "STATUS", + 177: "WAIT", + 176: "RECONFIGURE", + 179: "REBALANCE", 138: "SYNC", 75: "INDEX_CREATE", 76: "INDEX_DROP", @@ -751,7 +543,7 @@ var Term_TermType_name = map[int32]string{ 65: "BRANCH", 66: "ANY", 67: "ALL", - 68: "FOREACH", + 68: "FOR_EACH", 69: "FUNC", 73: "ASC", 74: "DESC", @@ -762,6 +554,7 @@ var Term_TermType_name = map[int32]string{ 
81: "SAMPLE", 92: "DEFAULT", 98: "JSON", + 172: "TO_JSON_STRING", 99: "ISO8601", 100: "TO_ISO8601", 101: "EPOCH_TIME", @@ -876,8 +669,8 @@ var Term_TermType_value = map[string]int32{ "REDUCE": 37, "MAP": 38, "FILTER": 39, - "CONCATMAP": 40, - "ORDERBY": 41, + "CONCAT_MAP": 40, + "ORDER_BY": 41, "DISTINCT": 42, "COUNT": 43, "IS_EMPTY": 86, @@ -888,12 +681,13 @@ var Term_TermType_value = map[string]int32{ "OUTER_JOIN": 49, "EQ_JOIN": 50, "ZIP": 72, + "RANGE": 173, "INSERT_AT": 82, "DELETE_AT": 83, "CHANGE_AT": 84, "SPLICE_AT": 85, "COERCE_TO": 51, - "TYPEOF": 52, + "TYPE_OF": 52, "UPDATE": 53, "DELETE": 54, "REPLACE": 55, @@ -904,6 +698,11 @@ var Term_TermType_value = map[string]int32{ "TABLE_CREATE": 60, "TABLE_DROP": 61, "TABLE_LIST": 62, + "CONFIG": 174, + "STATUS": 175, + "WAIT": 177, + "RECONFIGURE": 176, + "REBALANCE": 179, "SYNC": 138, "INDEX_CREATE": 75, "INDEX_DROP": 76, @@ -915,7 +714,7 @@ var Term_TermType_value = map[string]int32{ "BRANCH": 65, "ANY": 66, "ALL": 67, - "FOREACH": 68, + "FOR_EACH": 68, "FUNC": 69, "ASC": 73, "DESC": 74, @@ -926,6 +725,7 @@ var Term_TermType_value = map[string]int32{ "SAMPLE": 81, "DEFAULT": 92, "JSON": 98, + "TO_JSON_STRING": 172, "ISO8601": 99, "TO_ISO8601": 100, "EPOCH_TIME": 101, @@ -999,6 +799,9 @@ func (x Term_TermType) Enum() *Term_TermType { func (x Term_TermType) String() string { return proto.EnumName(Term_TermType_name, int32(x)) } +func (x Term_TermType) MarshalJSON() ([]byte, error) { + return json.Marshal(x.String()) +} func (x *Term_TermType) UnmarshalJSON(data []byte) error { value, err := proto.UnmarshalJSONEnum(Term_TermType_value, data, "Term_TermType") if err != nil { @@ -1016,25 +819,11 @@ func (m *VersionDummy) Reset() { *m = VersionDummy{} } func (m *VersionDummy) String() string { return proto.CompactTextString(m) } func (*VersionDummy) ProtoMessage() {} -// You send one of: -// * A [START] query with a [Term] to evaluate and a unique-per-connection token. 
-// * A [CONTINUE] query with the same token as a [START] query that returned -// [SUCCESS_PARTIAL] in its [Response]. -// * A [STOP] query with the same token as a [START] query that you want to stop. -// * A [NOREPLY_WAIT] query with a unique per-connection token. The server answers -// with a [WAIT_COMPLETE] [Response]. type Query struct { - Type *Query_QueryType `protobuf:"varint,1,opt,name=type,enum=Query_QueryType" json:"type,omitempty"` - // A [Term] is how we represent the operations we want a query to perform. - Query *Term `protobuf:"bytes,2,opt,name=query" json:"query,omitempty"` - Token *int64 `protobuf:"varint,3,opt,name=token" json:"token,omitempty"` - // This flag is ignored on the server. `noreply` should be added - // to `global_optargs` instead (the key "noreply" should map to - // either true or false). - OBSOLETENoreply *bool `protobuf:"varint,4,opt,name=OBSOLETE_noreply,def=0" json:"OBSOLETE_noreply,omitempty"` - // If this is set to [true], then [Datum] values will sometimes be - // of [DatumType] [R_JSON] (see below). This can provide enormous - // speedups in languages with poor protobuf libraries. 
+ Type *Query_QueryType `protobuf:"varint,1,opt,name=type,enum=Query_QueryType" json:"type,omitempty"` + Query *Term `protobuf:"bytes,2,opt,name=query" json:"query,omitempty"` + Token *int64 `protobuf:"varint,3,opt,name=token" json:"token,omitempty"` + OBSOLETENoreply *bool `protobuf:"varint,4,opt,name=OBSOLETE_noreply,def=0" json:"OBSOLETE_noreply,omitempty"` AcceptsRJson *bool `protobuf:"varint,5,opt,name=accepts_r_json,def=0" json:"accepts_r_json,omitempty"` GlobalOptargs []*Query_AssocPair `protobuf:"bytes,6,rep,name=global_optargs" json:"global_optargs,omitempty"` XXX_unrecognized []byte `json:"-"` @@ -1051,7 +840,7 @@ func (m *Query) GetType() Query_QueryType { if m != nil && m.Type != nil { return *m.Type } - return Query_START + return 0 } func (m *Query) GetQuery() *Term { @@ -1113,7 +902,6 @@ func (m *Query_AssocPair) GetVal() *Term { return nil } -// A backtrace frame (see `backtrace` in Response below) type Frame struct { Type *Frame_FrameType `protobuf:"varint,1,opt,name=type,enum=Frame_FrameType" json:"type,omitempty"` Pos *int64 `protobuf:"varint,2,opt,name=pos" json:"pos,omitempty"` @@ -1129,7 +917,7 @@ func (m *Frame) GetType() Frame_FrameType { if m != nil && m.Type != nil { return *m.Type } - return Frame_POS + return 0 } func (m *Frame) GetPos() int64 { @@ -1162,33 +950,13 @@ func (m *Backtrace) GetFrames() []*Frame { return nil } -// You get back a response with the same [token] as your query. type Response struct { - Type *Response_ResponseType `protobuf:"varint,1,opt,name=type,enum=Response_ResponseType" json:"type,omitempty"` - Token *int64 `protobuf:"varint,2,opt,name=token" json:"token,omitempty"` - // [response] contains 1 RQL datum if [type] is [SUCCESS_ATOM], or many RQL - // data if [type] is [SUCCESS_SEQUENCE] or [SUCCESS_PARTIAL]. It contains 1 - // error message (of type [R_STR]) in all other cases. 
- Response []*Datum `protobuf:"bytes,3,rep,name=response" json:"response,omitempty"` - // If [type] is [CLIENT_ERROR], [TYPE_ERROR], or [RUNTIME_ERROR], then a - // backtrace will be provided. The backtrace says where in the query the - // error occured. Ideally this information will be presented to the user as - // a pretty-printed version of their query with the erroneous section - // underlined. A backtrace is a series of 0 or more [Frame]s, each of which - // specifies either the index of a positional argument or the name of an - // optional argument. (Those words will make more sense if you look at the - // [Term] message below.) - Backtrace *Backtrace `protobuf:"bytes,4,opt,name=backtrace" json:"backtrace,omitempty"` - // If the [global_optargs] in the [Query] that this [Response] is a - // response to contains a key "profile" which maps to a static value of - // true then [profile] will contain a [Datum] which provides profiling - // information about the execution of the query. This field should be - // returned to the user along with the result that would normally be - // returned (a datum or a cursor). In official drivers this is accomplished - // by putting them inside of an object with "value" mapping to the return - // value and "profile" mapping to the profile object. 
- Profile *Datum `protobuf:"bytes,5,opt,name=profile" json:"profile,omitempty"` - XXX_unrecognized []byte `json:"-"` + Type *Response_ResponseType `protobuf:"varint,1,opt,name=type,enum=Response_ResponseType" json:"type,omitempty"` + Token *int64 `protobuf:"varint,2,opt,name=token" json:"token,omitempty"` + Response []*Datum `protobuf:"bytes,3,rep,name=response" json:"response,omitempty"` + Backtrace *Backtrace `protobuf:"bytes,4,opt,name=backtrace" json:"backtrace,omitempty"` + Profile *Datum `protobuf:"bytes,5,opt,name=profile" json:"profile,omitempty"` + XXX_unrecognized []byte `json:"-"` } func (m *Response) Reset() { *m = Response{} } @@ -1199,7 +967,7 @@ func (m *Response) GetType() Response_ResponseType { if m != nil && m.Type != nil { return *m.Type } - return Response_SUCCESS_ATOM + return 0 } func (m *Response) GetToken() int64 { @@ -1230,9 +998,6 @@ func (m *Response) GetProfile() *Datum { return nil } -// A [Datum] is a chunk of data that can be serialized to disk or returned to -// the user in a Response. Currently we only support JSON types, but we may -// support other types in the future (e.g., a date type or an integer type). type Datum struct { Type *Datum_DatumType `protobuf:"varint,1,opt,name=type,enum=Datum_DatumType" json:"type,omitempty"` RBool *bool `protobuf:"varint,2,opt,name=r_bool" json:"r_bool,omitempty"` @@ -1266,7 +1031,7 @@ func (m *Datum) GetType() Datum_DatumType { if m != nil && m.Type != nil { return *m.Type } - return Datum_R_NULL + return 0 } func (m *Datum) GetRBool() bool { @@ -1328,52 +1093,8 @@ func (m *Datum_AssocPair) GetVal() *Datum { return nil } -// A [Term] is either a piece of data (see **Datum** above), or an operator and -// its operands. If you have a [Datum], it's stored in the member [datum]. If -// you have an operator, its positional arguments are stored in [args] and its -// optional arguments are stored in [optargs]. 
-// -// A note about type signatures: -// We use the following notation to denote types: -// arg1_type, arg2_type, argrest_type... -> result_type -// So, for example, if we have a function `avg` that takes any number of -// arguments and averages them, we might write: -// NUMBER... -> NUMBER -// Or if we had a function that took one number modulo another: -// NUMBER, NUMBER -> NUMBER -// Or a function that takes a table and a primary key of any Datum type, then -// retrieves the entry with that primary key: -// Table, DATUM -> OBJECT -// Some arguments must be provided as literal values (and not the results of sub -// terms). These are marked with a `!`. -// Optional arguments are specified within curly braces as argname `:` value -// type (e.x `{use_outdated:BOOL}`) -// Many RQL operations are polymorphic. For these, alterantive type signatures -// are separated by `|`. -// -// The RQL type hierarchy is as follows: -// Top -// DATUM -// NULL -// BOOL -// NUMBER -// STRING -// OBJECT -// SingleSelection -// ARRAY -// Sequence -// ARRAY -// Stream -// StreamSelection -// Table -// Database -// Function -// Ordering - used only by ORDER_BY -// Pathspec -- an object, string, or array that specifies a path -// Error type Term struct { - Type *Term_TermType `protobuf:"varint,1,opt,name=type,enum=Term_TermType" json:"type,omitempty"` - // This is only used when type is DATUM. 
+ Type *Term_TermType `protobuf:"varint,1,opt,name=type,enum=Term_TermType" json:"type,omitempty"` Datum *Datum `protobuf:"bytes,2,opt,name=datum" json:"datum,omitempty"` Args []*Term `protobuf:"bytes,3,rep,name=args" json:"args,omitempty"` Optargs []*Term_AssocPair `protobuf:"bytes,4,rep,name=optargs" json:"optargs,omitempty"` @@ -1403,7 +1124,7 @@ func (m *Term) GetType() Term_TermType { if m != nil && m.Type != nil { return *m.Type } - return Term_DATUM + return 0 } func (m *Term) GetDatum() *Datum { diff --git a/ql2/ql2.proto b/ql2/ql2.proto index 94eff39a..113ea7d1 100644 --- a/ql2/ql2.proto +++ b/ql2/ql2.proto @@ -111,15 +111,16 @@ message Backtrace { message Response { enum ResponseType { // These response types indicate success. - SUCCESS_ATOM = 1; // Query returned a single RQL datatype. - SUCCESS_SEQUENCE = 2; // Query returned a sequence of RQL datatypes. - SUCCESS_PARTIAL = 3; // Query returned a partial sequence of RQL - // datatypes. If you send a [CONTINUE] query with - // the same token as this response, you will get - // more of the sequence. Keep sending [CONTINUE] - // queries until you get back [SUCCESS_SEQUENCE]. - SUCCESS_FEED = 5; // Like [SUCCESS_PARTIAL] but for feeds. - WAIT_COMPLETE = 4; // A [NOREPLY_WAIT] query completed. + SUCCESS_ATOM = 1; // Query returned a single RQL datatype. + SUCCESS_SEQUENCE = 2; // Query returned a sequence of RQL datatypes. + SUCCESS_PARTIAL = 3; // Query returned a partial sequence of RQL + // datatypes. If you send a [CONTINUE] query with + // the same token as this response, you will get + // more of the sequence. Keep sending [CONTINUE] + // queries until you get back [SUCCESS_SEQUENCE]. + SUCCESS_FEED = 5; // Like [SUCCESS_PARTIAL] but for feeds. + WAIT_COMPLETE = 4; // A [NOREPLY_WAIT] query completed. + SUCCESS_ATOM_FEED = 6; // Like [SUCCESS_FEED] but a singleton. // These response types indicate failure. CLIENT_ERROR = 16; // Means the client is buggy. 
An example is if the @@ -288,7 +289,8 @@ message Term { // Returns a reference to a database. DB = 14; // STRING -> Database // Returns a reference to a table. - TABLE = 15; // Database, STRING, {use_outdated:BOOL} -> Table | STRING, {use_outdated:BOOL} -> Table + TABLE = 15; // Database, STRING, {use_outdated:BOOL, identifier_format:STRING} -> Table + // STRING, {use_outdated:BOOL, identifier_format:STRING} -> Table // Gets a single element from a table by its primary or a secondary key. GET = 16; // Table, STRING -> SingleSelection | Table, NUMBER -> SingleSelection | // Table, STRING -> NULL | Table, NUMBER -> NULL | @@ -365,6 +367,8 @@ message Term { BETWEEN = 36; // StreamSelection, DATUM, DATUM, {index:!STRING, right_bound:STRING, left_bound:STRING} -> StreamSelection REDUCE = 37; // Sequence, Function(2) -> DATUM MAP = 38; // Sequence, Function(1) -> Sequence + // The arity of the function should be + // Sequence..., Function(sizeof...(Sequence)) -> Sequence // Filter a sequence with either a function or a shortcut // object (see API docs for details). The body of FILTER is @@ -376,9 +380,9 @@ message Term { FILTER = 39; // Sequence, Function(1), {default:DATUM} -> Sequence | // Sequence, OBJECT, {default:DATUM} -> Sequence // Map a function over a sequence and then concatenate the results together. - CONCATMAP = 40; // Sequence, Function(1) -> Sequence + CONCAT_MAP = 40; // Sequence, Function(1) -> Sequence // Order a sequence based on one or more attributes. - ORDERBY = 41; // Sequence, (!STRING | Ordering)... -> Sequence + ORDER_BY = 41; // Sequence, (!STRING | Ordering)... -> Sequence // Get all distinct elements of a sequence (like `uniq`). DISTINCT = 42; // Sequence -> Sequence // Count the number of elements in a sequence, or only the elements that match @@ -399,6 +403,9 @@ message Term { // An inner-join that does an equality comparison on two attributes. 
EQ_JOIN = 50; // Sequence, !STRING, Sequence, {index:!STRING} -> Sequence ZIP = 72; // Sequence -> Sequence + RANGE = 173; // -> Sequence [0, +inf) + // NUMBER -> Sequence [0, a) + // NUMBER, NUMBER -> Sequence [a, b) // Array Ops // Insert an element in to an array at a given index. @@ -416,8 +423,8 @@ message Term { // If you previously used `stream_to_array`, you should use this instead // with the type "array". COERCE_TO = 51; // Top, STRING -> Top - // Returns the named type of a datum (e.g. TYPEOF(true) = "BOOL") - TYPEOF = 52; // Top -> STRING + // Returns the named type of a datum (e.g. TYPE_OF(true) = "BOOL") + TYPE_OF = 52; // Top -> STRING // * Write Ops (the OBJECTs contain data about number of errors etc.) // Updates all the rows in a selection. Calls its Function with the row @@ -440,28 +447,53 @@ message Term { // * Administrative OPs // Creates a database with a particular name. - DB_CREATE = 57; // STRING -> OBJECT + DB_CREATE = 57; // STRING -> OBJECT // Drops a database with a particular name. - DB_DROP = 58; // STRING -> OBJECT + DB_DROP = 58; // STRING -> OBJECT // Lists all the databases by name. (Takes no arguments) - DB_LIST = 59; // -> ARRAY + DB_LIST = 59; // -> ARRAY // Creates a table with a particular name in a particular // database. (You may omit the first argument to use the // default database.) 
- TABLE_CREATE = 60; // Database, STRING, {datacenter:STRING, primary_key:STRING, durability:STRING} -> OBJECT - // STRING, {datacenter:STRING, primary_key:STRING, durability:STRING} -> OBJECT + TABLE_CREATE = 60; // Database, STRING, {primary_key:STRING, shards:NUMBER, replicas:NUMBER, primary_replica_tag:STRING} -> OBJECT + // Database, STRING, {primary_key:STRING, shards:NUMBER, replicas:OBJECT, primary_replica_tag:STRING} -> OBJECT + // STRING, {primary_key:STRING, shards:NUMBER, replicas:NUMBER, primary_replica_tag:STRING} -> OBJECT + // STRING, {primary_key:STRING, shards:NUMBER, replicas:OBJECT, primary_replica_tag:STRING} -> OBJECT // Drops a table with a particular name from a particular // database. (You may omit the first argument to use the // default database.) - TABLE_DROP = 61; // Database, STRING -> OBJECT - // STRING -> OBJECT + TABLE_DROP = 61; // Database, STRING -> OBJECT + // STRING -> OBJECT // Lists all the tables in a particular database. (You may // omit the first argument to use the default database.) - TABLE_LIST = 62; // Database -> ARRAY - // -> ARRAY + TABLE_LIST = 62; // Database -> ARRAY + // -> ARRAY + // Returns the row in the `rethinkdb.table_config` or `rethinkdb.db_config` table + // that corresponds to the given database or table. + CONFIG = 174; // Database -> SingleSelection + // Table -> SingleSelection + // Returns the row in the `rethinkdb.table_status` table that corresponds to the + // given table. + STATUS = 175; // Table -> SingleSelection + // Called on a table, waits for that table to be ready for read/write operations. + // Called on a database, waits for all of the tables in the database to be ready. + // Returns the corresponding row or rows from the `rethinkdb.table_status` table. 
+ WAIT = 177; // Table -> OBJECT + // Database -> OBJECT + // Generates a new config for the given table, or all tables in the given database + // The `shards` and `replicas` arguments are required + RECONFIGURE = 176; // Database, {shards:NUMBER, replicas:NUMBER[, primary_replica_tag:STRING, dry_run:BOOLEAN]} -> OBJECT + // Database, {shards:NUMBER, replicas:OBJECT[, primary_replica_tag:STRING, dry_run:BOOLEAN]} -> OBJECT + // Table, {shards:NUMBER, replicas:NUMBER[, primary_replica_tag:STRING, dry_run:BOOLEAN]} -> OBJECT + // Table, {shards:NUMBER, replicas:OBJECT[, primary_replica_tag:STRING, dry_run:BOOLEAN]} -> OBJECT + // Balances the table's shards but leaves everything else the same. Can also be + // applied to an entire database at once. + REBALANCE = 179; // Table -> OBJECT + // Database -> OBJECT + // Ensures that previously issued soft-durability writes are complete and // written to disk. - SYNC = 138; // Table -> OBJECT + SYNC = 138; // Table -> OBJECT // * Secondary indexes OPs // Creates a new secondary index with a particular name and definition. @@ -495,7 +527,7 @@ message Term { ALL = 67; // BOOL... -> BOOL // Calls its Function with each entry in the sequence // and executes the array of terms that Function returns. - FOREACH = 68; // Sequence, Function(1) -> OBJECT + FOR_EACH = 68; // Sequence, Function(1) -> OBJECT //////////////////////////////////////////////////////////////////////////////// ////////// Special Terms @@ -571,6 +603,11 @@ message Term { // Parses its first argument as a json string and returns it as a // datum. JSON = 98; // STRING -> DATUM + // Returns the datum as a JSON string. + // N.B.: we would really prefer this be named TO_JSON and that exists as + // an alias in Python and JavaScript drivers; however it conflicts with the + // standard `to_json` method defined by Ruby's standard json library. + TO_JSON_STRING = 172; // DATUM -> STRING // Parses its first arguments as an ISO 8601 time and returns it as a // datum. 
@@ -608,9 +645,7 @@ message Term { // Construct a time from a date and optional timezone or a // date+time and optional timezone. - TIME = 136; // NUMBER, NUMBER, NUMBER -> PSEUDOTYPE(TIME) | - // NUMBER, NUMBER, NUMBER, STRING -> PSEUDOTYPE(TIME) | - // NUMBER, NUMBER, NUMBER, NUMBER, NUMBER, NUMBER -> PSEUDOTYPE(TIME) | + TIME = 136; // NUMBER, NUMBER, NUMBER, STRING -> PSEUDOTYPE(TIME) | // NUMBER, NUMBER, NUMBER, NUMBER, NUMBER, NUMBER, STRING -> PSEUDOTYPE(TIME) | // Constants for ISO 8601 days of the week. diff --git a/query.go b/query.go index 319a53b6..927871d1 100644 --- a/query.go +++ b/query.go @@ -8,9 +8,26 @@ import ( p "github.com/dancannon/gorethink/ql2" ) -type OptArgs interface { - toMap() map[string]interface{} +type Query struct { + Type p.Query_QueryType + Token int64 + Term *Term + Opts map[string]interface{} +} + +func (q *Query) build() []interface{} { + res := []interface{}{int(q.Type)} + if q.Term != nil { + res = append(res, q.Term.build()) + } + + if len(q.Opts) > 0 { + res = append(res, q.Opts) + } + + return res } + type termsList []Term type termsObj map[string]Term type Term struct { @@ -54,7 +71,16 @@ func (t Term) build() interface{} { optArgs[k] = v.build() } - return []interface{}{t.termType, args, optArgs} + ret := []interface{}{int(t.termType)} + + if len(args) > 0 { + ret = append(ret, args) + } + if len(optArgs) > 0 { + ret = append(ret, optArgs) + } + + return ret } // String returns a string representation of the query tree @@ -98,18 +124,28 @@ func (t Term) String() string { return fmt.Sprintf("%s.%s(%s)", t.args[0].String(), t.name, strings.Join(allArgsToStringSlice(t.args[1:], t.optArgs), ", ")) } +type OptArgs interface { + toMap() map[string]interface{} +} + type WriteResponse struct { - Errors int - Created int - Inserted int - Updated int - Unchanged int - Replaced int - Renamed int - Skipped int - Deleted int - GeneratedKeys []string `gorethink:"generated_keys"` - FirstError string `gorethink:"first_error"` // 
populated if Errors > 0 + Errors int `gorethink:"errors"` + Inserted int `gorethink:"inserted"` + Updated int `gorethink:"updated"` + Unchanged int `gorethink:"unchanged"` + Replaced int `gorethink:"replaced"` + Renamed int `gorethink:"renamed"` + Skipped int `gorethink:"skipped"` + Deleted int `gorethink:"deleted"` + Created int `gorethink:"created"` + DBsCreated int `gorethink:"dbs_created"` + TablesCreated int `gorethink:"tables_created"` + Dropped int `gorethink:"dropped"` + DBsDropped int `gorethink:"dbs_dropped"` + TablesDropped int `gorethink:"tables_dropped"` + GeneratedKeys []string `gorethink:"generated_keys"` + FirstError string `gorethink:"first_error"` // populated if Errors > 0 + ConfigChanges []WriteChanges `gorethink:"config_changes"` Changes []WriteChanges } @@ -122,16 +158,12 @@ type RunOpts struct { Db interface{} `gorethink:"db,omitempty"` Profile interface{} `gorethink:"profile,omitempty"` UseOutdated interface{} `gorethink:"use_outdated,omitempty"` - NoReply interface{} `gorethink:"noreply,omitempty"` ArrayLimit interface{} `gorethink:"array_limit,omitempty"` TimeFormat interface{} `gorethink:"time_format,omitempty"` GroupFormat interface{} `gorethink:"group_format,omitempty"` BinaryFormat interface{} `gorethink:"binary_format,omitempty"` GeometryFormat interface{} `gorethink:"geometry_format,omitempty"` - BatchConf BatchOpts `gorethink:"batch_conf,omitempty"` -} -type BatchOpts struct { MinBatchRows interface{} `gorethink:"min_batch_rows,omitempty"` MaxBatchRows interface{} `gorethink:"max_batch_rows,omitempty"` MaxBatchBytes interface{} `gorethink:"max_batch_bytes,omitempty"` @@ -159,16 +191,17 @@ func (t Term) Run(s *Session, optArgs ...RunOpts) (*Cursor, error) { if len(optArgs) >= 1 { opts = optArgs[0].toMap() } - return s.startQuery(t, opts) + + q := newStartQuery(s, t, opts) + + return s.pool.Query(q) } // RunWrite runs a query using the given connection but unlike Run automatically // scans the result into a variable of type 
WriteResponse. This function should be used // if you are running a write query (such as Insert, Update, TableCreate, etc...) // -// res, err := r.Db("database").Table("table").Insert(doc).RunWrite(sess, r.RunOpts{ -// NoReply: true, -// }) +// res, err := r.Db("database").Table("table").Insert(doc).RunWrite(sess) func (t Term) RunWrite(s *Session, optArgs ...RunOpts) (WriteResponse, error) { var response WriteResponse res, err := t.Run(s, optArgs...) @@ -178,20 +211,64 @@ func (t Term) RunWrite(s *Session, optArgs ...RunOpts) (WriteResponse, error) { return response, err } -// Exec runs the query but does not return the result. -func (t Term) Exec(s *Session, optArgs ...RunOpts) error { - res, err := t.Run(s, optArgs...) - if err != nil { - return err - } - if res == nil { - return nil +// ExecOpts inherits its options from RunOpts; the only difference is the +// addition of the NoReply field. +// +// When NoReply is true the driver does not wait to receive the result and +// returns immediately. 
+type ExecOpts struct { + Db interface{} `gorethink:"db,omitempty"` + Profile interface{} `gorethink:"profile,omitempty"` + UseOutdated interface{} `gorethink:"use_outdated,omitempty"` + ArrayLimit interface{} `gorethink:"array_limit,omitempty"` + TimeFormat interface{} `gorethink:"time_format,omitempty"` + GroupFormat interface{} `gorethink:"group_format,omitempty"` + BinaryFormat interface{} `gorethink:"binary_format,omitempty"` + GeometryFormat interface{} `gorethink:"geometry_format,omitempty"` + + MinBatchRows interface{} `gorethink:"min_batch_rows,omitempty"` + MaxBatchRows interface{} `gorethink:"max_batch_rows,omitempty"` + MaxBatchBytes interface{} `gorethink:"max_batch_bytes,omitempty"` + MaxBatchSeconds interface{} `gorethink:"max_batch_seconds,omitempty"` + FirstBatchScaledownFactor interface{} `gorethink:"first_batch_scaledown_factor,omitempty"` + + NoReply interface{} `gorethink:"noreply,omitempty"` +} + +func (o *ExecOpts) toMap() map[string]interface{} { + return optArgsToMap(o) +} + +// Exec runs the query but does not return the result. Exec will still wait for +// the response to be received unless the NoReply field is true. 
+// +// res, err := r.Db("database").Table("table").Insert(doc).Exec(sess, r.ExecOpts{ +// NoReply: true, +// }) +func (t Term) Exec(s *Session, optArgs ...ExecOpts) error { + opts := map[string]interface{}{} + if len(optArgs) >= 1 { + opts = optArgs[0].toMap() } - err = res.Close() - if err != nil { - return err + q := newStartQuery(s, t, opts) + + return s.pool.Exec(q) +} + +func newStartQuery(s *Session, t Term, opts map[string]interface{}) Query { + queryOpts := map[string]interface{}{} + for k, v := range opts { + queryOpts[k] = Expr(v).build() + } + if s.opts.Database != "" { + queryOpts["db"] = Db(s.opts.Database).build() } - return nil + // Construct query + return Query{ + Type: p.Query_START, + Term: &t, + Opts: queryOpts, + } } diff --git a/query_admin.go b/query_admin.go new file mode 100644 index 00000000..26b9a2a1 --- /dev/null +++ b/query_admin.go @@ -0,0 +1,52 @@ +package gorethink + +import ( + p "github.com/dancannon/gorethink/ql2" +) + +// Config can be used to read and/or update the configurations for individual +// tables or databases. +func (t Term) Config() Term { + return constructMethodTerm(t, "Config", p.Term_CONFIG, []interface{}{}, map[string]interface{}{}) +} + +// Rebalance rebalances the shards of a table. When called on a database, all +// the tables in that database will be rebalanced. +func (t Term) Rebalance() Term { + return constructMethodTerm(t, "Rebalance", p.Term_REBALANCE, []interface{}{}, map[string]interface{}{}) +} + +type ReconfigureOpts struct { + Shards interface{} `gorethink:"shards,omitempty"` + Replicas interface{} `gorethink:"replicas,omitempty"` + PrimaryTag interface{} `gorethink:"primary_replica_tag,omitempty"` + DryRun interface{} `gorethink:"dry_run,omitempty"` +} + +func (o *ReconfigureOpts) toMap() map[string]interface{} { + return optArgsToMap(o) +} + +// Reconfigure a table's sharding and replication. 
+func (t Term) Reconfigure(opts ReconfigureOpts) Term { + return constructMethodTerm(t, "Reconfigure", p.Term_RECONFIGURE, []interface{}{}, opts.toMap()) +} + +// Status returns the status of a table. +func (t Term) Status() Term { + return constructMethodTerm(t, "Status", p.Term_STATUS, []interface{}{}, map[string]interface{}{}) +} + +// Wait for a table or all the tables in a database to be ready. A table may be +// temporarily unavailable after creation, rebalancing or reconfiguring. The +// wait command blocks until the given table (or database) is fully up to date. +func Wait() Term { + return constructRootTerm("Wait", p.Term_WAIT, []interface{}{}, map[string]interface{}{}) +} + +// Wait for a table or all the tables in a database to be ready. A table may be +// temporarily unavailable after creation, rebalancing or reconfiguring. The +// wait command blocks until the given table (or database) is fully up to date. +func (t Term) Wait() Term { + return constructMethodTerm(t, "Wait", p.Term_WAIT, []interface{}{}, map[string]interface{}{}) +} diff --git a/query_admin_test.go b/query_admin_test.go new file mode 100644 index 00000000..a74d6fe0 --- /dev/null +++ b/query_admin_test.go @@ -0,0 +1,91 @@ +package gorethink + +import ( + test "gopkg.in/check.v1" +) + +func (s *RethinkSuite) TestAdminDbConfig(c *test.C) { + Db("test").TableDrop("test").Exec(sess) + Db("test").TableCreate("test").Exec(sess) + + // Test the database config + query := Db("test").Config() + + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + var response map[string]interface{} + err = res.One(&response) + c.Assert(err, test.IsNil) + + c.Assert(response["name"], test.Equals, "test") +} + +func (s *RethinkSuite) TestAdminTableConfig(c *test.C) { + Db("test").TableDrop("test").Exec(sess) + Db("test").TableCreate("test").Exec(sess) + + // Test the table config + query := Db("test").Table("test").Config() + + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + var response
map[string]interface{} + err = res.One(&response) + c.Assert(err, test.IsNil) + + c.Assert(response["name"], test.Equals, "test") +} + +func (s *RethinkSuite) TestAdminTableStatus(c *test.C) { + Db("test").TableDrop("test").Exec(sess) + Db("test").TableCreate("test").Exec(sess) + + // Test the table status + query := Db("test").Table("test").Status() + + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + var response map[string]interface{} + err = res.One(&response) + c.Assert(err, test.IsNil) + + c.Assert(response["name"], test.Equals, "test") + c.Assert(response["status"], test.NotNil) +} + +func (s *RethinkSuite) TestAdminWait(c *test.C) { + Db("test").TableDrop("test").Exec(sess) + Db("test").TableCreate("test").Exec(sess) + + // Test waiting for all tables to be ready + query := Wait() + + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + var response map[string]interface{} + err = res.One(&response) + c.Assert(err, test.IsNil) + + c.Assert(response["ready"].(float64) > 0, test.Equals, true) +} + +func (s *RethinkSuite) TestAdminTableWait(c *test.C) { + Db("test").TableDrop("test").Exec(sess) + Db("test").TableCreate("test").Exec(sess) + + // Test waiting for a single table to be ready + query := Db("test").Table("test").Wait() + + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + var response map[string]interface{} + err = res.One(&response) + c.Assert(err, test.IsNil) + + c.Assert(response["ready"], test.Equals, float64(1)) +} diff --git a/query_aggregation.go b/query_aggregation.go index 10ec737e..e2cc7c8e 100644 --- a/query_aggregation.go +++ b/query_aggregation.go @@ -94,6 +94,17 @@ func (t Term) Min(args ...interface{}) Term { return constructMethodTerm(t, "Min", p.Term_MIN, funcWrapArgs(args), map[string]interface{}{}) } +// Finds the minimum of a sequence. If called with a field name, finds the element +// of that sequence with the smallest value in that field.
If called with a function, +// calls that function on every element of the sequence and returns the element +// which produced the smallest value, ignoring any elements where the function +// returns null or produces a non-existence error. +func (t Term) MinIndex(index interface{}, args ...interface{}) Term { + return constructMethodTerm(t, "Min", p.Term_MIN, funcWrapArgs(args), map[string]interface{}{ + "index": index, + }) +} + // Finds the maximum of a sequence. If called with a field name, finds the element // of that sequence with the largest value in that field. If called with a function, // calls that function on every element of the sequence and returns the element @@ -102,3 +113,14 @@ func (t Term) Min(args ...interface{}) Term { func (t Term) Max(args ...interface{}) Term { return constructMethodTerm(t, "Max", p.Term_MAX, funcWrapArgs(args), map[string]interface{}{}) } + +// Finds the maximum of a sequence. If called with a field name, finds the element +// of that sequence with the largest value in that field. If called with a function, +// calls that function on every element of the sequence and returns the element +// which produced the largest value, ignoring any elements where the function +// returns null or produces a non-existence error. 
+func (t Term) MaxIndex(index interface{}, args ...interface{}) Term { + return constructMethodTerm(t, "Max", p.Term_MAX, funcWrapArgs(args), map[string]interface{}{ + "index": index, + }) +} diff --git a/query_aggregation_test.go b/query_aggregation_test.go index 0395d417..23ce5dcd 100644 --- a/query_aggregation_test.go +++ b/query_aggregation_test.go @@ -210,6 +210,48 @@ func (s *RethinkSuite) TestAggregationGroupMax(c *test.C) { }) } +func (s *RethinkSuite) TestAggregationMin(c *test.C) { + // Ensure table + database exist + DbCreate("test").Exec(sess) + Db("test").TableCreate("Table2").Exec(sess) + Db("test").Table("Table2").IndexCreate("num").Exec(sess) + + // Insert rows + Db("test").Table("Table2").Insert(objList).Exec(sess) + + // Test query + var response interface{} + query := Db("test").Table("Table2").MinIndex("num") + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + err = res.One(&response) + + c.Assert(err, test.IsNil) + c.Assert(response, JsonEquals, map[string]interface{}{"id": 1, "g1": 1, "g2": 1, "num": 0}) +} + +func (s *RethinkSuite) TestAggregationMaxIndex(c *test.C) { + // Ensure table + database exist + DbCreate("test").Exec(sess) + Db("test").TableCreate("Table2").Exec(sess) + Db("test").Table("Table2").IndexCreate("num").Exec(sess) + + // Insert rows + Db("test").Table("Table2").Insert(objList).Exec(sess) + + // Test query + var response interface{} + query := Db("test").Table("Table2").MaxIndex("num") + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + err = res.One(&response) + + c.Assert(err, test.IsNil) + c.Assert(response, JsonEquals, map[string]interface{}{"id": 5, "g1": 2, "g2": 3, "num": 100}) +} + func (s *RethinkSuite) TestAggregationMultipleGroupSum(c *test.C) { var response []interface{} query := Expr(objList).Group("g1", "g2").Sum("num") diff --git a/query_control.go b/query_control.go index bcf6854c..56b2796d 100644 --- a/query_control.go +++ b/query_control.go @@ -36,6 +36,20 @@ func expr(val interface{}, 
depth int) Term { switch val := val.(type) { case Term: return val + case []interface{}: + vals := make([]Term, len(val)) + for i, v := range val { + vals[i] = expr(v, depth) + } + + return makeArray(vals) + case map[string]interface{}: + vals := make(map[string]Term, len(val)) + for k, v := range val { + vals[k] = expr(v, depth) + } + + return makeObject(vals) default: // Use reflection to check for other types valType := reflect.TypeOf(val) @@ -70,15 +84,15 @@ func expr(val interface{}, depth int) Term { return expr(data, depth-1) } else { - vals := []Term{} + vals := make([]Term, valValue.Len()) for i := 0; i < valValue.Len(); i++ { - vals = append(vals, expr(valValue.Index(i).Interface(), depth)) + vals[i] = expr(valValue.Index(i).Interface(), depth) } return makeArray(vals) } case reflect.Map: - vals := map[string]Term{} + vals := make(map[string]Term, len(valValue.MapKeys())) for _, k := range valValue.MapKeys() { vals[k.String()] = expr(valValue.MapIndex(k).Interface(), depth) } @@ -208,7 +222,13 @@ func Branch(args ...interface{}) Term { // Loop over a sequence, evaluating the given write query for each element. func (t Term) ForEach(args ...interface{}) Term { - return constructMethodTerm(t, "Foreach", p.Term_FOREACH, funcWrapArgs(args), map[string]interface{}{}) + return constructMethodTerm(t, "Foreach", p.Term_FOR_EACH, funcWrapArgs(args), map[string]interface{}{}) +} + +// Range generates a stream of sequential integers in a specified range. It +// accepts 0, 1, or 2 arguments, all of which should be numbers. +func Range(args ...interface{}) Term { + return constructRootTerm("Range", p.Term_RANGE, args, map[string]interface{}{}) } // Handle non-existence errors. Tries to evaluate and return its first argument. @@ -230,7 +250,12 @@ func (t Term) CoerceTo(args ...interface{}) Term { // Gets the type of a value. 
func (t Term) TypeOf(args ...interface{}) Term { - return constructMethodTerm(t, "TypeOf", p.Term_TYPEOF, args, map[string]interface{}{}) + return constructMethodTerm(t, "TypeOf", p.Term_TYPE_OF, args, map[string]interface{}{}) +} + +// ToJSON converts a ReQL value or object to a JSON string. +func (t Term) ToJSON() Term { + return constructMethodTerm(t, "ToJSON", p.Term_TO_JSON_STRING, []interface{}{}, map[string]interface{}{}) } // Get information about a RQL value. diff --git a/query_control_test.go b/query_control_test.go index 93515935..1dfc0bdc 100644 --- a/query_control_test.go +++ b/query_control_test.go @@ -8,7 +8,7 @@ import ( test "gopkg.in/check.v1" ) -func (s *RethinkSuite) TestControlExecNil(c *test.C) { +func (s *RethinkSuite) TestControlExprNil(c *test.C) { var response interface{} query := Expr(nil) res, err := query.Run(sess) @@ -20,7 +20,7 @@ func (s *RethinkSuite) TestControlExecNil(c *test.C) { c.Assert(response, test.Equals, nil) } -func (s *RethinkSuite) TestControlExecSimple(c *test.C) { +func (s *RethinkSuite) TestControlExprSimple(c *test.C) { var response int query := Expr(1) res, err := query.Run(sess) @@ -32,7 +32,7 @@ func (s *RethinkSuite) TestControlExecSimple(c *test.C) { c.Assert(response, test.Equals, 1) } -func (s *RethinkSuite) TestControlExecList(c *test.C) { +func (s *RethinkSuite) TestControlExprList(c *test.C) { var response []interface{} query := Expr(narr) res, err := query.Run(sess) @@ -48,7 +48,7 @@ func (s *RethinkSuite) TestControlExecList(c *test.C) { }) } -func (s *RethinkSuite) TestControlExecObj(c *test.C) { +func (s *RethinkSuite) TestControlExprObj(c *test.C) { var response map[string]interface{} query := Expr(nobj) res, err := query.Run(sess) @@ -129,7 +129,7 @@ func (s *RethinkSuite) TestControlStringTypeAlias(c *test.C) { c.Assert(response, JsonEquals, TStr("Hello")) } -func (s *RethinkSuite) TestControlExecTypes(c *test.C) { +func (s *RethinkSuite) TestControlExprTypes(c *test.C) { var response []interface{} query := 
Expr([]interface{}{int64(1), uint64(1), float64(1.0), int32(1), uint32(1), float32(1), "1", true, false}) res, err := query.Run(sess) @@ -190,6 +190,7 @@ func (s *RethinkSuite) TestControlError(c *test.C) { c.Assert(err, test.NotNil) c.Assert(err, test.FitsTypeOf, RqlRuntimeError{}) + c.Assert(err.Error(), test.Equals, "gorethink: An error occurred in: \nr.Error(\"An error occurred\")") } @@ -397,3 +398,51 @@ func (s *RethinkSuite) TestControlTypeOf(c *test.C) { c.Assert(err, test.IsNil) c.Assert(response, test.Equals, "NUMBER") } + +func (s *RethinkSuite) TestControlRangeNoArgs(c *test.C) { + var response []int + query := Range().Limit(100) + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + err = res.All(&response) + + c.Assert(err, test.IsNil) + c.Assert(len(response), test.Equals, 100) +} + +func (s *RethinkSuite) TestControlRangeSingleArgs(c *test.C) { + var response []int + query := Range(4) + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + err = res.All(&response) + + c.Assert(err, test.IsNil) + c.Assert(response, test.DeepEquals, []int{0, 1, 2, 3}) +} + +func (s *RethinkSuite) TestControlRangeTwoArgs(c *test.C) { + var response []int + query := Range(4, 6) + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + err = res.All(&response) + + c.Assert(err, test.IsNil) + c.Assert(response, test.DeepEquals, []int{4, 5}) +} + +func (s *RethinkSuite) TestControlToJSON(c *test.C) { + var response string + query := Expr([]int{4, 5}).ToJSON() + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + err = res.One(&response) + + c.Assert(err, test.IsNil) + c.Assert(response, test.Equals, "[4,5]") +} diff --git a/query_db_test.go b/query_db_test.go index efca06c5..764222ba 100644 --- a/query_db_test.go +++ b/query_db_test.go @@ -5,21 +5,15 @@ import ( ) func (s *RethinkSuite) TestDbCreate(c *test.C) { - var response interface{} - // Delete the test2 database if it already exists DbDrop("test").Exec(sess) // Test database creation query 
:= DbCreate("test") - res, err := query.Run(sess) - c.Assert(err, test.IsNil) - - err = res.One(&response) - + response, err := query.RunWrite(sess) c.Assert(err, test.IsNil) - c.Assert(response, JsonEquals, map[string]interface{}{"created": 1}) + c.Assert(response.DBsCreated, JsonEquals, 1) } func (s *RethinkSuite) TestDbList(c *test.C) { @@ -48,21 +42,15 @@ func (s *RethinkSuite) TestDbList(c *test.C) { } func (s *RethinkSuite) TestDbDelete(c *test.C) { - var response interface{} - // Delete the test2 database if it already exists DbCreate("test").Exec(sess) // Test database creation query := DbDrop("test") - res, err := query.Run(sess) - c.Assert(err, test.IsNil) - - err = res.One(&response) - + response, err := query.RunWrite(sess) c.Assert(err, test.IsNil) - c.Assert(response, JsonEquals, map[string]interface{}{"dropped": 1}) + c.Assert(response.DBsDropped, JsonEquals, 1) // Ensure that there is still a test DB after the test has finished DbCreate("test").Exec(sess) diff --git a/query_select_test.go b/query_select_test.go index a64064cb..b26d5814 100644 --- a/query_select_test.go +++ b/query_select_test.go @@ -2,6 +2,9 @@ package gorethink import ( "fmt" + "math/rand" + "testing" + "time" test "gopkg.in/check.v1" ) @@ -269,14 +272,14 @@ func (s *RethinkSuite) TestSelectFilterFunc(c *test.C) { }) } -func (s *RethinkSuite) TestSelectMany(c *test.C) { +func (s *RethinkSuite) TestSelectManyRows(c *test.C) { // Ensure table + database exist DbCreate("test").RunWrite(sess) Db("test").TableCreate("TestMany").RunWrite(sess) Db("test").Table("TestMany").Delete().RunWrite(sess) // Insert rows - for i := 0; i < 1; i++ { + for i := 0; i < 100; i++ { data := []interface{}{} for j := 0; j < 100; j++ { @@ -291,9 +294,7 @@ func (s *RethinkSuite) TestSelectMany(c *test.C) { // Test query res, err := Db("test").Table("TestMany").Run(sess, RunOpts{ - BatchConf: BatchOpts{ - MaxBatchRows: 1, - }, + MaxBatchRows: 1, }) c.Assert(err, test.IsNil) @@ -304,27 +305,129 @@ func (s 
*RethinkSuite) TestSelectMany(c *test.C) { } c.Assert(res.Err(), test.IsNil) - c.Assert(n, test.Equals, 100) + c.Assert(n, test.Equals, 10000) } -func (s *RethinkSuite) TestConcurrentSelectMany(c *test.C) { +func (s *RethinkSuite) TestConcurrentSelectManyWorkers(c *test.C) { + if testing.Short() { + c.Skip("Skipping long test") + } + + rand.Seed(time.Now().UnixNano()) + sess, _ := Connect(ConnectOpts{ + Address: url, + AuthKey: authKey, + + MaxOpen: 200, + MaxIdle: 200, + }) + // Ensure table + database exist DbCreate("test").RunWrite(sess) - Db("test").TableCreate("TestMany").RunWrite(sess) - Db("test").Table("TestMany").Delete().RunWrite(sess) + Db("test").TableDrop("TestConcurrent").RunWrite(sess) + Db("test").TableCreate("TestConcurrent").RunWrite(sess) + Db("test").TableDrop("TestConcurrent2").RunWrite(sess) + Db("test").TableCreate("TestConcurrent2").RunWrite(sess) // Insert rows - for i := 0; i < 1; i++ { - data := []interface{}{} + for j := 0; j < 200; j++ { + Db("test").Table("TestConcurrent").Insert(map[string]interface{}{ + "id": j, + "i": j, + }).Run(sess) + Db("test").Table("TestConcurrent2").Insert(map[string]interface{}{ + "j": j, + "k": j * 2, + }).Run(sess) + } - for j := 0; j < 100; j++ { - data = append(data, map[string]interface{}{ - "i": i, - "j": j, - }) + // Test queries concurrently + numQueries := 1000 + numWorkers := 10 + queryChan := make(chan int) + doneChan := make(chan error) + + // Start workers + for i := 0; i < numWorkers; i++ { + go func() { + for _ = range queryChan { + res, err := Db("test").Table("TestConcurrent2").EqJoin("j", Db("test").Table("TestConcurrent")).Zip().Run(sess) + if err != nil { + doneChan <- err + return + } + + var response []map[string]interface{} + err = res.All(&response) + if err != nil { + doneChan <- err + return + } + if err := res.Close(); err != nil { + doneChan <- err + return + } + + if len(response) != 200 { + doneChan <- fmt.Errorf("expected response length 200, received %d", len(response)) + 
return + } + + res, err = Db("test").Table("TestConcurrent").Get(response[rand.Intn(len(response))]["id"]).Run(sess) + if err != nil { + doneChan <- err + return + } + + err = res.All(&response) + if err != nil { + doneChan <- err + return + } + if err := res.Close(); err != nil { + doneChan <- err + return + } + + if len(response) != 1 { + doneChan <- fmt.Errorf("expected response length 1, received %d", len(response)) + return + } + + doneChan <- nil + } + }() + } + + go func() { + for i := 0; i < numQueries; i++ { + queryChan <- i } + }() - Db("test").Table("TestMany").Insert(data).Run(sess) + for i := 0; i < numQueries; i++ { + ret := <-doneChan + if ret != nil { + c.Fatalf("non-nil error returned (%s)", ret) + } + } +} + +func (s *RethinkSuite) TestConcurrentSelectManyRows(c *test.C) { + if testing.Short() { + c.Skip("Skipping long test") + } + + // Ensure table + database exist + DbCreate("test").RunWrite(sess) + Db("test").TableCreate("TestMany").RunWrite(sess) + Db("test").Table("TestMany").Delete().RunWrite(sess) + + // Insert rows + for i := 0; i < 100; i++ { + Db("test").Table("TestMany").Insert(map[string]interface{}{ + "i": i, + }).Run(sess) } // Test queries concurrently @@ -333,23 +436,22 @@ func (s *RethinkSuite) TestConcurrentSelectMany(c *test.C) { for i := 0; i < attempts; i++ { go func(i int, c chan error) { - res, err := Db("test").Table("TestMany").Run(sess, RunOpts{ - BatchConf: BatchOpts{ - MaxBatchRows: 1, - }, - }) + res, err := Db("test").Table("TestMany").Run(sess) if err != nil { c <- err + return } var response []map[string]interface{} err = res.All(&response) if err != nil { c <- err + return } if len(response) != 100 { c <- fmt.Errorf("expected response length 100, received %d", len(response)) + return } c <- nil diff --git a/query_table.go b/query_table.go index 615b2abb..6df9f975 100644 --- a/query_table.go +++ b/query_table.go @@ -122,10 +122,22 @@ func (t Term) IndexWait(args ...interface{}) Term { return constructMethodTerm(t, 
"IndexWait", p.Term_INDEX_WAIT, args, map[string]interface{}{}) } +type ChangesOpts struct { + Squash interface{} `gorethink:"squash,omitempty"` +} + +func (o *ChangesOpts) toMap() map[string]interface{} { + return optArgsToMap(o) +} + // Takes a table and returns an infinite stream of objects representing changes to that table. // Whenever an insert, delete, update or replace is performed on the table, an object of the form // {old_val:..., new_val:...} will be added to the stream. For an insert, old_val will be // null, and for a delete, new_val will be null. -func (t Term) Changes() Term { - return constructMethodTerm(t, "Changes", p.Term_CHANGES, []interface{}{}, map[string]interface{}{}) +func (t Term) Changes(optArgs ...ChangesOpts) Term { + opts := map[string]interface{}{} + if len(optArgs) >= 1 { + opts = optArgs[0].toMap() + } + return constructMethodTerm(t, "Changes", p.Term_CHANGES, []interface{}{}, opts) } diff --git a/query_table_test.go b/query_table_test.go index 46fb675c..d72ad725 100644 --- a/query_table_test.go +++ b/query_table_test.go @@ -7,25 +7,17 @@ import ( ) func (s *RethinkSuite) TestTableCreate(c *test.C) { - var response interface{} - Db("test").TableDrop("test").Exec(sess) // Test database creation query := Db("test").TableCreate("test") - res, err := query.Run(sess) - c.Assert(err, test.IsNil) - - err = res.One(&response) - + response, err := query.RunWrite(sess) c.Assert(err, test.IsNil) - c.Assert(response, JsonEquals, map[string]interface{}{"created": 1}) + c.Assert(response.TablesCreated, JsonEquals, 1) } func (s *RethinkSuite) TestTableCreatePrimaryKey(c *test.C) { - var response interface{} - Db("test").TableDrop("testOpts").Exec(sess) // Test database creation @@ -33,18 +25,12 @@ func (s *RethinkSuite) TestTableCreatePrimaryKey(c *test.C) { PrimaryKey: "it", }) - res, err := query.Run(sess) - c.Assert(err, test.IsNil) - - err = res.One(&response) - + response, err := query.RunWrite(sess) c.Assert(err, test.IsNil) - 
c.Assert(response, JsonEquals, map[string]interface{}{"created": 1}) + c.Assert(response.TablesCreated, JsonEquals, 1) } func (s *RethinkSuite) TestTableCreateSoftDurability(c *test.C) { - var response interface{} - Db("test").TableDrop("testOpts").Exec(sess) // Test database creation @@ -52,18 +38,12 @@ func (s *RethinkSuite) TestTableCreateSoftDurability(c *test.C) { Durability: "soft", }) - res, err := query.Run(sess) - c.Assert(err, test.IsNil) - - err = res.One(&response) - + response, err := query.RunWrite(sess) c.Assert(err, test.IsNil) - c.Assert(response, JsonEquals, map[string]interface{}{"created": 1}) + c.Assert(response.TablesCreated, JsonEquals, 1) } func (s *RethinkSuite) TestTableCreateSoftMultipleOpts(c *test.C) { - var response interface{} - Db("test").TableDrop("testOpts").Exec(sess) // Test database creation @@ -72,13 +52,9 @@ func (s *RethinkSuite) TestTableCreateSoftMultipleOpts(c *test.C) { Durability: "soft", }) - res, err := query.Run(sess) + response, err := query.RunWrite(sess) c.Assert(err, test.IsNil) - - err = res.One(&response) - - c.Assert(err, test.IsNil) - c.Assert(response, JsonEquals, map[string]interface{}{"created": 1}) + c.Assert(response.TablesCreated, JsonEquals, 1) Db("test").TableDrop("test").Exec(sess) } @@ -108,25 +84,17 @@ func (s *RethinkSuite) TestTableList(c *test.C) { } func (s *RethinkSuite) TestTableDelete(c *test.C) { - var response interface{} - Db("test").TableCreate("test").Exec(sess) // Test database creation query := Db("test").TableDrop("test") - res, err := query.Run(sess) + response, err := query.RunWrite(sess) c.Assert(err, test.IsNil) - - err = res.One(&response) - - c.Assert(err, test.IsNil) - c.Assert(response, JsonEquals, map[string]interface{}{"dropped": 1}) + c.Assert(response.TablesDropped, JsonEquals, 1) } func (s *RethinkSuite) TestTableIndexCreate(c *test.C) { - var response interface{} - Db("test").TableCreate("test").Exec(sess) Db("test").Table("test").IndexDrop("test").Exec(sess) @@ -135,13 
+103,9 @@ func (s *RethinkSuite) TestTableIndexCreate(c *test.C) { Multi: true, }) - res, err := query.Run(sess) - c.Assert(err, test.IsNil) - - err = res.One(&response) - + response, err := query.RunWrite(sess) c.Assert(err, test.IsNil) - c.Assert(response, JsonEquals, map[string]interface{}{"created": 1}) + c.Assert(response.Created, JsonEquals, 1) } func (s *RethinkSuite) TestTableCompoundIndexCreate(c *test.C) { @@ -181,21 +145,15 @@ func (s *RethinkSuite) TestTableIndexList(c *test.C) { } func (s *RethinkSuite) TestTableIndexDelete(c *test.C) { - var response interface{} - Db("test").TableCreate("test").Exec(sess) Db("test").Table("test").IndexCreate("test").Exec(sess) // Test database creation query := Db("test").Table("test").IndexDrop("test") - res, err := query.Run(sess) + response, err := query.RunWrite(sess) c.Assert(err, test.IsNil) - - err = res.One(&response) - - c.Assert(err, test.IsNil) - c.Assert(response, JsonEquals, map[string]interface{}{"dropped": 1}) + c.Assert(response.Dropped, JsonEquals, 1) } func (s *RethinkSuite) TestTableIndexRename(c *test.C) { @@ -206,10 +164,9 @@ func (s *RethinkSuite) TestTableIndexRename(c *test.C) { // Test index rename query := Db("test").Table("test").IndexRename("test", "test2") - res, err := query.RunWrite(sess) + response, err := query.RunWrite(sess) c.Assert(err, test.IsNil) - - c.Assert(res.Renamed, JsonEquals, 1) + c.Assert(response.Renamed, JsonEquals, 1) } func (s *RethinkSuite) TestTableChanges(c *test.C) { @@ -230,7 +187,6 @@ func (s *RethinkSuite) TestTableChanges(c *test.C) { go func() { var response interface{} for n < 10 && res.Next(&response) { - // log.Println(response) n++ } diff --git a/query_test.go b/query_test.go index 938b74da..fb73bf59 100644 --- a/query_test.go +++ b/query_test.go @@ -14,6 +14,11 @@ func (s *RethinkSuite) TestQueryRun(c *test.C) { c.Assert(response, test.Equals, "Test") } +func (s *RethinkSuite) TestQueryExec(c *test.C) { + err := Expr("Test").Exec(sess) + c.Assert(err, 
test.IsNil)
+}
+
 func (s *RethinkSuite) TestQueryProfile(c *test.C) {
 	var response string
diff --git a/query_transformation.go b/query_transformation.go
index d244b9ab..517f724d 100644
--- a/query_transformation.go
+++ b/query_transformation.go
@@ -3,6 +3,15 @@ package gorethink
 import p "github.com/dancannon/gorethink/ql2"
 
 // Transform each element of the sequence by applying the given mapping function.
+func Map(args ...interface{}) Term {
+	if len(args) > 0 {
+		args = append(args[:len(args)-1], funcWrapArgs(args[len(args)-1:])...)
+	}
+
+	return constructRootTerm("Map", p.Term_MAP, funcWrapArgs(args), map[string]interface{}{})
+}
+
+// Transform each element of the sequence by applying the given mapping function.
 func (t Term) Map(args ...interface{}) Term {
 	return constructMethodTerm(t, "Map", p.Term_MAP, funcWrapArgs(args), map[string]interface{}{})
 }
@@ -18,7 +27,7 @@ func (t Term) WithFields(args ...interface{}) Term {
 
 // Flattens a sequence of arrays returned by the mapping function into a single
 // sequence.
func (t Term) ConcatMap(args ...interface{}) Term { - return constructMethodTerm(t, "ConcatMap", p.Term_CONCATMAP, funcWrapArgs(args), map[string]interface{}{}) + return constructMethodTerm(t, "ConcatMap", p.Term_CONCAT_MAP, funcWrapArgs(args), map[string]interface{}{}) } type OrderByOpts struct { @@ -57,7 +66,7 @@ func (t Term) OrderBy(args ...interface{}) Term { } } - return constructMethodTerm(t, "OrderBy", p.Term_ORDERBY, args, opts) + return constructMethodTerm(t, "OrderBy", p.Term_ORDER_BY, args, opts) } func Desc(args ...interface{}) Term { diff --git a/query_transformation_test.go b/query_transformation_test.go index 8ae24bcc..e430836c 100644 --- a/query_transformation_test.go +++ b/query_transformation_test.go @@ -70,6 +70,48 @@ func (s *RethinkSuite) TestTransformationConcatMap(c *test.C) { c.Assert(response, JsonEquals, []interface{}{0, 5, 10, 0, 100, 15, 0, 50, 25}) } +func (s *RethinkSuite) TestTransformationVariadicMap(c *test.C) { + query := Range(5).Map(Range(5), func(a, b Term) interface{} { + return []interface{}{a, b} + }) + + var response []interface{} + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + err = res.All(&response) + + c.Assert(err, test.IsNil) + c.Assert(response, JsonEquals, [][]int{ + {0, 0}, + {1, 1}, + {2, 2}, + {3, 3}, + {4, 4}, + }) +} + +func (s *RethinkSuite) TestTransformationVariadicRootMap(c *test.C) { + query := Map(Range(5), Range(5), func(a, b Term) interface{} { + return []interface{}{a, b} + }) + + var response []interface{} + res, err := query.Run(sess) + c.Assert(err, test.IsNil) + + err = res.All(&response) + + c.Assert(err, test.IsNil) + c.Assert(response, JsonEquals, [][]int{ + {0, 0}, + {1, 1}, + {2, 2}, + {3, 3}, + {4, 4}, + }) +} + func (s *RethinkSuite) TestTransformationOrderByDesc(c *test.C) { query := Expr(noDupNumObjList).OrderBy(Desc("num")) diff --git a/session.go b/session.go index d27d27ea..86d97115 100644 --- a/session.go +++ b/session.go @@ -2,106 +2,29 @@ package gorethink import ( 
"sync" - "sync/atomic" "time" - "gopkg.in/fatih/pool.v2" - p "github.com/dancannon/gorethink/ql2" ) -type Query struct { - Type p.Query_QueryType - Token int64 - Term *Term - GlobalOpts map[string]interface{} -} - -func (q *Query) build() []interface{} { - res := []interface{}{q.Type} - if q.Term != nil { - res = append(res, q.Term.build()) - } - - if len(q.GlobalOpts) > 0 { - res = append(res, q.GlobalOpts) - } - - return res -} - type Session struct { - token int64 - address string - database string - timeout time.Duration - authkey string - timeFormat string - - // Pool configuration options - initialCap int - maxCap int - idleTimeout time.Duration + opts ConnectOpts + pool *Pool // Response cache, used for batched responses sync.Mutex - cache map[int64]*Cursor - closed bool - - pool pool.Pool -} - -func newSession(args map[string]interface{}) *Session { - s := &Session{ - cache: map[int64]*Cursor{}, - } - - if token, ok := args["token"]; ok { - s.token = token.(int64) - } - if address, ok := args["address"]; ok { - s.address = address.(string) - } - if database, ok := args["database"]; ok { - s.database = database.(string) - } - if timeout, ok := args["timeout"]; ok { - s.timeout = timeout.(time.Duration) - } - if authkey, ok := args["authkey"]; ok { - s.authkey = authkey.(string) - } - - // Pool configuration options - if initialCap, ok := args["initialCap"]; ok { - s.initialCap = initialCap.(int) - } else { - s.initialCap = 5 - } - if maxCap, ok := args["maxCap"]; ok { - s.maxCap = maxCap.(int) - } else { - s.maxCap = 30 - } - if idleTimeout, ok := args["idleTimeout"]; ok { - s.idleTimeout = idleTimeout.(time.Duration) - } else { - s.idleTimeout = 10 * time.Second - } - - return s + token int64 } type ConnectOpts struct { - Token int64 `gorethink:"token,omitempty"` - Address string `gorethink:"address,omitempty"` - Database string `gorethink:"database,omitempty"` - Timeout time.Duration `gorethink:"timeout,omitempty"` - AuthKey string 
`gorethink:"authkey,omitempty"`
-	MaxIdle     int           `gorethink:"max_idle,omitempty"`
-	MaxActive   int           `gorethink:"max_active,omitempty"`
-	IdleTimeout time.Duration `gorethink:"idle_timeout,omitempty"`
+	Address  string        `gorethink:"address,omitempty"`
+	Database string        `gorethink:"database,omitempty"`
+	AuthKey  string        `gorethink:"authkey,omitempty"`
+	Timeout  time.Duration `gorethink:"timeout,omitempty"`
+
+	MaxIdle int `gorethink:"max_idle,omitempty"`
+	MaxOpen int `gorethink:"max_open,omitempty"`
 }
 
 func (o *ConnectOpts) toMap() map[string]interface{} {
@@ -110,11 +33,11 @@ func (o *ConnectOpts) toMap() map[string]interface{} {
 
 // Connect creates a new database session.
 //
-// Supported arguments include token, address, database, timeout, authkey,
-// and timeFormat. Pool options include maxIdle, maxActive and idleTimeout.
+// Supported arguments include Address, Database, Timeout and AuthKey. Pool
+// options include MaxIdle and MaxOpen.
 //
-// By default maxIdle and maxActive are set to 1: passing values greater
-// than the default (e.g. maxIdle: "10", maxActive: "20") will provide a
+// By default MaxIdle and MaxOpen are set to 1: passing values greater
+// than the default (e.g. MaxIdle: 10, MaxOpen: 20) will provide a
 // pool of re-usable connections.
 //
 // Basic connection example:
@@ -125,11 +48,17 @@ func (o *ConnectOpts) toMap() map[string]interface{} {
 //		Database: "test",
 //		AuthKey:  "14daak1cad13dj",
 //	})
-func Connect(args ConnectOpts) (*Session, error) {
-	s := newSession(args.toMap())
+func Connect(opts ConnectOpts) (*Session, error) {
+	// Connect
+	s := &Session{
+		opts: opts,
+	}
 	err := s.Reconnect()
+	if err != nil {
+		return nil, err
+	}
 
-	return s, err
+	return s, nil
 }
 
 type CloseOpts struct {
@@ -142,25 +71,26 @@ func (o *CloseOpts) toMap() map[string]interface{} {
 
 // Reconnect closes and re-opens a session.
func (s *Session) Reconnect(optArgs ...CloseOpts) error { - if err := s.Close(optArgs...); err != nil { + var err error + + if err = s.Close(optArgs...); err != nil { return err } - s.closed = false - if s.pool == nil { - cp, err := pool.NewChannelPool(s.initialCap, s.maxCap, Dial(s)) - s.pool = cp - if err != nil { - return err - } + s.pool, err = NewPool(&s.opts) + if err != nil { + return err + } - s.pool = cp + // Ping connection to check it is valid + err = s.pool.Ping() + if err != nil { + return err } - // Check the connection - _, err := s.getConn() + s.closed = false - return err + return nil } // Close closes the session @@ -178,200 +108,33 @@ func (s *Session) Close(optArgs ...CloseOpts) error { if s.pool != nil { s.pool.Close() } + s.pool = nil s.closed = true return nil } -// noreplyWait ensures that previous queries with the noreply flag have been -// processed by the server. Note that this guarantee only applies to queries -// run on the given connection -func (s *Session) NoReplyWait() { - s.noreplyWaitQuery() -} - -// Use changes the default database used -func (s *Session) Use(database string) { - s.database = database -} - -// SetTimeout causes any future queries that are run on this session to timeout -// after the given duration, returning a timeout error. Set to zero to disable. -func (s *Session) SetTimeout(timeout time.Duration) { - s.timeout = timeout -} - -// getToken generates the next query token, used to number requests and match -// responses with requests. -func (s *Session) nextToken() int64 { - return atomic.AddInt64(&s.token, 1) -} - -// startQuery creates a query from the term given and sends it to the server. 
-// The result from the server is returned as a cursor -func (s *Session) startQuery(t Term, opts map[string]interface{}) (*Cursor, error) { - token := s.nextToken() - - // Build global options - globalOpts := map[string]interface{}{} - for k, v := range opts { - globalOpts[k] = Expr(v).build() - } - - // If no DB option was set default to the value set in the connection - if _, ok := opts["db"]; !ok { - globalOpts["db"] = Db(s.database).build() - } - - // Construct query - q := Query{ - Type: p.Query_START, - Token: token, - Term: &t, - GlobalOpts: globalOpts, - } - - // Get a connection from the pool, do not close yet as it - // might be needed later if a partial response is returned - conn, err := s.getConn() - if err != nil { - return nil, err - } - - return conn.SendQuery(s, q, opts, false) -} - -func (s *Session) handleBatchResponse(cursor *Cursor, response *Response) { - cursor.extend(response) - - s.Lock() - cursor.outstandingRequests-- - - if response.Type != p.Response_SUCCESS_PARTIAL && - response.Type != p.Response_SUCCESS_FEED && - cursor.outstandingRequests == 0 { - delete(s.cache, response.Token) - } - s.Unlock() -} - -// continueQuery continues a previously run query. -// This is needed if a response is batched. -func (s *Session) continueQuery(cursor *Cursor) error { - err := s.asyncContinueQuery(cursor) - if err != nil { - return err - } - - response, err := cursor.conn.ReadResponse(s, cursor.query.Token) - if err != nil { - return err - } - - s.handleBatchResponse(cursor, response) - - return nil -} - -// asyncContinueQuery asynchronously continues a previously run query. -// This is needed if a response is batched. 
-func (s *Session) asyncContinueQuery(cursor *Cursor) error { - s.Lock() - if cursor.outstandingRequests != 0 { - - s.Unlock() - return nil - } - cursor.outstandingRequests = 1 - s.Unlock() - - q := Query{ - Type: p.Query_CONTINUE, - Token: cursor.query.Token, - } - - _, err := cursor.conn.SendQuery(s, q, cursor.opts, true) - if err != nil { - return err - } - - return nil +// SetMaxIdleConns sets the maximum number of connections in the idle +// connection pool. +func (s *Session) SetMaxIdleConns(n int) { + s.pool.SetMaxIdleConns(n) } -// stopQuery sends closes a query by sending Query_STOP to the server. -func (s *Session) stopQuery(cursor *Cursor) error { - cursor.mu.Lock() - cursor.outstandingRequests++ - cursor.mu.Unlock() - - q := Query{ - Type: p.Query_STOP, - Token: cursor.query.Token, - Term: &cursor.term, - } - - _, err := cursor.conn.SendQuery(s, q, cursor.opts, false) - if err != nil { - return err - } - - response, err := cursor.conn.ReadResponse(s, cursor.query.Token) - if err != nil { - return err - } - - s.handleBatchResponse(cursor, response) - - return nil -} - -// noreplyWaitQuery sends the NOREPLY_WAIT query to the server. -func (s *Session) noreplyWaitQuery() error { - conn, err := s.getConn() - if err != nil { - return err - } - - q := Query{ - Type: p.Query_NOREPLY_WAIT, - Token: s.nextToken(), - } - cur, err := conn.SendQuery(s, q, map[string]interface{}{}, false) - if err != nil { - return err - } - err = cur.Close() - if err != nil { - return err - } - - return nil +// SetMaxOpenConns sets the maximum number of open connections to the database. 
+func (s *Session) SetMaxOpenConns(n int) { + s.pool.SetMaxOpenConns(n) } -func (s *Session) getConn() (*Connection, error) { - if s.pool == nil { - return nil, pool.ErrClosed - } - - c, err := s.pool.Get() - if err != nil { - return nil, err - } - - return &Connection{Conn: c, s: s}, nil -} - -func (s *Session) checkCache(token int64) (*Cursor, bool) { - s.Lock() - defer s.Unlock() - - cursor, ok := s.cache[token] - return cursor, ok +// NoReplyWait ensures that previous queries with the noreply flag have been +// processed by the server. Note that this guarantee only applies to queries +// run on the given connection +func (s *Session) NoReplyWait() error { + return s.pool.Exec(Query{ + Type: p.Query_NOREPLY_WAIT, + }) } -func (s *Session) setCache(token int64, cursor *Cursor) { - s.Lock() - defer s.Unlock() - - s.cache[token] = cursor +// Use changes the default database used +func (s *Session) Use(database string) { + s.opts.Database = database } diff --git a/session_test.go b/session_test.go index 4ee9dcfa..1b090412 100644 --- a/session_test.go +++ b/session_test.go @@ -8,10 +8,8 @@ import ( func (s *RethinkSuite) TestSessionConnect(c *test.C) { session, err := Connect(ConnectOpts{ - Address: url, - AuthKey: os.Getenv("RETHINKDB_AUTHKEY"), - MaxIdle: 3, - MaxActive: 3, + Address: url, + AuthKey: os.Getenv("RETHINKDB_AUTHKEY"), }) c.Assert(err, test.IsNil) @@ -24,12 +22,49 @@ func (s *RethinkSuite) TestSessionConnect(c *test.C) { c.Assert(response, test.Equals, "Hello World") } +func (s *RethinkSuite) TestSessionReconnect(c *test.C) { + session, err := Connect(ConnectOpts{ + Address: url, + AuthKey: os.Getenv("RETHINKDB_AUTHKEY"), + }) + c.Assert(err, test.IsNil) + + row, err := Expr("Hello World").Run(session) + c.Assert(err, test.IsNil) + + var response string + err = row.One(&response) + c.Assert(err, test.IsNil) + c.Assert(response, test.Equals, "Hello World") + + err = session.Reconnect() + c.Assert(err, test.IsNil) + + row, err = Expr("Hello World 
2").Run(session) + c.Assert(err, test.IsNil) + + err = row.One(&response) + c.Assert(err, test.IsNil) + c.Assert(response, test.Equals, "Hello World 2") +} + func (s *RethinkSuite) TestSessionConnectError(c *test.C) { var err error _, err = Connect(ConnectOpts{ - Address: "nonexistanturl", - MaxIdle: 3, - MaxActive: 3, + Address: "nonexistanturl", }) c.Assert(err, test.NotNil) } + +func (s *RethinkSuite) TestSessionConnectDatabase(c *test.C) { + session, err := Connect(ConnectOpts{ + Address: url, + AuthKey: os.Getenv("RETHINKDB_AUTHKEY"), + Database: "test2", + }) + c.Assert(err, test.IsNil) + + _, err = Table("test2").Run(session) + c.Assert(err, test.NotNil) + c.Assert(err.Error(), test.Equals, "gorethink: Database `test2` does not exist. in: \nr.Table(\"test2\")") +} diff --git a/types/geometry.go b/types/geometry.go index a63c7f1e..93c220d7 100644 --- a/types/geometry.go +++ b/types/geometry.go @@ -9,6 +9,56 @@ type Geometry struct { Lines Lines } +func (g Geometry) MarshalRQL() (interface{}, error) { + switch g.Type { + case "Point": + return g.Point.MarshalRQL() + case "LineString": + return g.Line.MarshalRQL() + case "Polygon": + return g.Lines.MarshalRQL() + default: + return nil, fmt.Errorf("pseudo-type GEOMETRY object field 'type' %s is not valid", g.Type) + } +} + +func (g *Geometry) UnmarshalRQL(data interface{}) error { + m, ok := data.(map[string]interface{}) + if !ok { + return fmt.Errorf("pseudo-type GEOMETRY object is not valid") + } + + typ, ok := m["type"] + if !ok { + return fmt.Errorf("pseudo-type GEOMETRY object is not valid, expects 'type' field") + } + coords, ok := m["coordinates"] + if !ok { + return fmt.Errorf("pseudo-type GEOMETRY object is not valid, expects 'coordinates' field") + } + + var err error + switch typ { + case "Point": + g.Type = "Point" + g.Point, err = UnmarshalPoint(coords) + case "LineString": + g.Type = "LineString" + g.Line, err = UnmarshalLineString(coords) + case "Polygon": + g.Type = "Polygon" + g.Lines, err = 
UnmarshalPolygon(coords) + default: + return fmt.Errorf("pseudo-type GEOMETRY object has invalid type") + } + + if err != nil { + return err + } + + return nil +} + type Point struct { Lon float64 Lat float64 @@ -16,41 +66,111 @@ type Point struct { type Line []Point type Lines []Line -func (p Point) Marshal() interface{} { +func (p Point) Coords() interface{} { return []interface{}{p.Lon, p.Lat} } -func (l Line) Marshal() interface{} { +func (p Point) MarshalRQL() (interface{}, error) { + return map[string]interface{}{ + "$reql_type$": "GEOMETRY", + "coordinates": p.Coords(), + "type": "Point", + }, nil +} + +func (p *Point) UnmarshalRQL(data interface{}) error { + g := &Geometry{} + err := g.UnmarshalRQL(data) + if err != nil { + return err + } + if g.Type != "Point" { + return fmt.Errorf("pseudo-type GEOMETRY object has type %s, expected type %s", g.Type, "Point") + } + + p.Lat = g.Point.Lat + p.Lon = g.Point.Lon + + return nil +} + +func (l Line) Coords() interface{} { coords := make([]interface{}, len(l)) for i, point := range l { - coords[i] = point.Marshal() + coords[i] = point.Coords() } return coords } -func (l Lines) Marshal() interface{} { +func (l Line) MarshalRQL() (interface{}, error) { + return map[string]interface{}{ + "$reql_type$": "GEOMETRY", + "coordinates": l.Coords(), + "type": "LineString", + }, nil +} + +func (l *Line) UnmarshalRQL(data interface{}) error { + g := &Geometry{} + err := g.UnmarshalRQL(data) + if err != nil { + return err + } + if g.Type != "LineString" { + return fmt.Errorf("pseudo-type GEOMETRY object has type %s, expected type %s", g.Type, "LineString") + } + + *l = g.Line + + return nil +} + +func (l Lines) Coords() interface{} { coords := make([]interface{}, len(l)) for i, line := range l { - coords[i] = line.Marshal() + coords[i] = line.Coords() } return coords } +func (l Lines) MarshalRQL() (interface{}, error) { + return map[string]interface{}{ + "$reql_type$": "GEOMETRY", + "coordinates": l.Coords(), + "type": 
"Polygon", + }, nil +} + +func (l *Lines) UnmarshalRQL(data interface{}) error { + g := &Geometry{} + err := g.UnmarshalRQL(data) + if err != nil { + return err + } + if g.Type != "Polygon" { + return fmt.Errorf("pseudo-type GEOMETRY object has type %s, expected type %s", g.Type, "Polygon") + } + + *l = g.Lines + + return nil +} + func UnmarshalPoint(v interface{}) (Point, error) { coords, ok := v.([]interface{}) if !ok { - return Point{}, fmt.Errorf("pseudo-type GEOMETRY object %v field \"coordinates\" is not valid", v) + return Point{}, fmt.Errorf("pseudo-type GEOMETRY object field 'coordinates' is not valid") } if len(coords) != 2 { - return Point{}, fmt.Errorf("pseudo-type GEOMETRY object %v field \"coordinates\" is not valid", v) + return Point{}, fmt.Errorf("pseudo-type GEOMETRY object field 'coordinates' is not valid") } lon, ok := coords[0].(float64) if !ok { - return Point{}, fmt.Errorf("pseudo-type GEOMETRY object %v field \"coordinates\" is not valid", v) + return Point{}, fmt.Errorf("pseudo-type GEOMETRY object field 'coordinates' is not valid") } lat, ok := coords[1].(float64) if !ok { - return Point{}, fmt.Errorf("pseudo-type GEOMETRY object %v field \"coordinates\" is not valid", v) + return Point{}, fmt.Errorf("pseudo-type GEOMETRY object field 'coordinates' is not valid") } return Point{ @@ -62,7 +182,7 @@ func UnmarshalPoint(v interface{}) (Point, error) { func UnmarshalLineString(v interface{}) (Line, error) { points, ok := v.([]interface{}) if !ok { - return Line{}, fmt.Errorf("pseudo-type GEOMETRY object %v field \"coordinates\" is not valid", v) + return Line{}, fmt.Errorf("pseudo-type GEOMETRY object field 'coordinates' is not valid") } var err error @@ -79,7 +199,7 @@ func UnmarshalLineString(v interface{}) (Line, error) { func UnmarshalPolygon(v interface{}) (Lines, error) { lines, ok := v.([]interface{}) if !ok { - return Lines{}, fmt.Errorf("pseudo-type GEOMETRY object %v field \"coordinates\" is not valid", v) + return Lines{}, 
fmt.Errorf("pseudo-type GEOMETRY object field 'coordinates' is not valid") } var err error diff --git a/utils.go b/utils.go index a4cf3829..52367d2f 100644 --- a/utils.go +++ b/utils.go @@ -53,16 +53,10 @@ func makeArray(args termsList) Term { // makeObject takes a map of terms and produces a single MAKE_OBJECT term func makeObject(args termsObj) Term { - // First all evaluate all fields in the map - temp := make(termsObj) - for k, v := range args { - temp[k] = Expr(v) - } - return Term{ name: "{...}", termType: p.Term_MAKE_OBJ, - optArgs: temp, + optArgs: args, } } @@ -161,7 +155,7 @@ func convertTermList(l []interface{}) termsList { // Convert a map into a map of terms func convertTermObj(o map[string]interface{}) termsObj { - terms := termsObj{} + terms := make(termsObj, len(o)) for k, v := range o { terms[k] = Expr(v) } diff --git a/wercker.yml b/wercker.yml index 313f6d6c..f03209cb 100644 --- a/wercker.yml +++ b/wercker.yml @@ -1,36 +1,37 @@ box: wercker/golang # Services services: - - mies/rethinkdb@0.3.0 + - dancannon/rethinkdb@0.4.0 # Build definition build: # The steps that will be executed on build steps: - # Sets the go workspace and places you package - # at the right place in the workspace tree - - setup-go-workspace + # Sets the go workspace and places you package + # at the right place in the workspace tree + - pjvds/setup-go-workspace - # Gets the dependencies - - script: - name: get dependencies - code: | - cd $WERCKER_SOURCE_DIR - go version - go get ./... && go get -u gopkg.in/check.v1 && go test -i - - # Build the project - - script: - name: build - code: | - go build ./... - - # Test the project - - script: - name: test - code: | - go test -test.v=true ./... 
-  # - script:
-  #     name: test auth keys
-  #     code: |
-  #         sh -c "echo 'set auth test_key' | rethinkdb admin --join $$HOST$$:29015"
-  #         RETHINKDB_AUTHKEY=test_key go test -test.run="Test" -test.v=true -gocheck.f="TestConnectAuthKey"
+    - script:
+        name: Populate cache
+        code: |-
+          if test -d "$WERCKER_CACHE_DIR/go-pkg-cache"; then rsync -avzv --exclude "$WERCKER_SOURCE_DIR" "$WERCKER_CACHE_DIR/go-pkg-cache/" "$GOPATH/" ; fi
+    # Gets the dependencies
+    - script:
+        name: get dependencies
+        code: |
+          cd $WERCKER_SOURCE_DIR
+          go version
+          go get ./...
+    # Build the project
+    - script:
+        name: build
+        code: |
+          go build ./...
+    # Test the project
+    - script:
+        name: Test
+        code: |-
+          go get -u gopkg.in/check.v1
+          go test -test.v=true ./...
+    - script:
+        name: Store cache
+        code: |-
+          rsync -avzv --exclude "$WERCKER_SOURCE_DIR" "$GOPATH/" "$WERCKER_CACHE_DIR/go-pkg-cache/"
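Several APIs in this diff (`Changes(optArgs ...ChangesOpts)` in query_table.go, `Connect(opts ConnectOpts)` in session.go) share the same variadic optional-options pattern: the caller may pass zero or one options struct, and the non-zero fields are folded into the opt-args map sent to the server. Below is a minimal, self-contained sketch of that pattern; the type and function names mirror the diff but are toy stand-ins, not the real gorethink implementation (which also uses `gorethink:"...,omitempty"` struct tags and `optArgsToMap`).

```go
package main

import "fmt"

// ChangesOpts is a stand-in for the options struct added in this diff;
// only the Squash field is modeled here.
type ChangesOpts struct {
	Squash interface{}
}

// changesOptArgs sketches how Term.Changes(optArgs ...ChangesOpts) folds
// an optional options struct into the opt-args map: zero structs means
// no optional arguments, one struct contributes its non-nil fields.
func changesOptArgs(optArgs ...ChangesOpts) map[string]interface{} {
	opts := map[string]interface{}{}
	if len(optArgs) >= 1 {
		if optArgs[0].Squash != nil {
			opts["squash"] = optArgs[0].Squash
		}
	}
	return opts
}

func main() {
	fmt.Println(changesOptArgs())                          // no options passed
	fmt.Println(changesOptArgs(ChangesOpts{Squash: true})) // squash opt-arg set
}
```

The upside of this pattern over a plain `map[string]interface{}` parameter is that option names are checked at compile time, while callers that want the defaults can still write `table.Changes()` with no arguments at all.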