diff --git a/cassandra_test.go b/cassandra_test.go index 02eed613e..9f44ce784 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -2179,7 +2179,7 @@ func TestNegativeStream(t *testing.T) { conn := getRandomConn(t, session) const stream = -50 - writer := frameWriterFunc(func(f *framer, streamID int) error { + writer := frameWriterFunc(func(f *framer, streamID int) (outFrameInfo, error) { f.writeHeader(0, opOptions, stream) return f.finish() }) diff --git a/conn.go b/conn.go index 1a7eb787d..7c2ce1cf2 100644 --- a/conn.go +++ b/conn.go @@ -1099,13 +1099,19 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram framer.trace() } + ofi, err := req.buildFrame(framer, stream) + // The error is handled after we call the stream observer. + if call.streamObserverContext != nil { call.streamObserverContext.StreamStarted(ObservedStream{ - Host: c.host, + Host: c.host, + FramePayloadUncompressedSize: ofi.uncompressedSize, + FramePayloadCompressedSize: ofi.compressedSize, + QueryValuesSize: ofi.queryValuesSize, + QueryCount: ofi.queryCount, }) } - err := req.buildFrame(framer, stream) if err != nil { // closeWithError will block waiting for this stream to either receive a response // or for us to timeout. @@ -1217,6 +1223,20 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram type ObservedStream struct { // Host of the connection used to send the stream. Host *HostInfo + // FramePayloadUncompressedSize is the uncompressed size of the frame payload (without frame header). + // This field is only available in StreamStarted. + FramePayloadUncompressedSize int + // FramePayloadCompressedSize is the compressed size of the frame payload (without frame header). + // This field is only available in StreamStarted. + FramePayloadCompressedSize int + // QueryValuesSize is the total uncompressed size of query values in the frame (without other query options). + // For a batch, it is the sum for all queries in the batch. + // For frames that contain no query values QueryValuesSize is zero. + // This field is only available in StreamStarted. + QueryValuesSize int + // QueryCount is 1 for EXECUTE/QUERY and size of the batch for BATCH frames. + // This field is only available in StreamStarted. + QueryCount int } // StreamObserver is notified about request/response pairs. diff --git a/conn_test.go b/conn_test.go index 69d775664..f1a0bc19f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -679,7 +679,7 @@ func TestStream0(t *testing.T) { f.writeHeader(0, opResult, 0) f.writeInt(resultKindVoid) f.buf[0] |= 0x80 - if err := f.finish(); err != nil { + if _, err := f.finish(); err != nil { t.Fatal(err) } if err := f.writeTo(&buf); err != nil { @@ -1285,7 +1285,7 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer, exts map[string] respFrame.buf[0] = srv.protocol | 0x80 - if err := respFrame.finish(); err != nil { + if _, err := respFrame.finish(); err != nil { srv.errorLocked(err) } diff --git a/frame.go b/frame.go index 22d722205..218c97ce8 100644 --- a/frame.go +++ b/frame.go @@ -764,11 +764,26 @@ func (f *framer) setLength(length int) { f.buf[p+3] = byte(length) } -func (f *framer) finish() error { +type outFrameInfo struct { + // compressedSize of the frame payload (without header). + compressedSize int + // uncompressedSize of the frame payload (without header). + uncompressedSize int + // queryValuesSize is sum of sizes of query values. + queryValuesSize int + // queryCount is number of queries executed by the query/execute/batch frame. + queryCount int +} + +func (f *framer) finish() (outFrameInfo, error) { if len(f.buf) > maxFrameSize { // huge app frame, lets remove it so it doesn't bloat the heap f.buf = make([]byte, defaultBufSize) - return ErrFrameTooBig + return outFrameInfo{}, ErrFrameTooBig + } + + info := outFrameInfo{ + uncompressedSize: len(f.buf) - f.headSize, } if f.buf[1]&flagCompress == flagCompress { @@ -779,15 +794,16 @@ func (f *framer) finish() error { // TODO: only compress frames which are big enough compressed, err := f.compres.Encode(f.buf[f.headSize:]) if err != nil { - return err + return info, err } f.buf = append(f.buf[:f.headSize], compressed...) } length := len(f.buf) - f.headSize + info.compressedSize = length f.setLength(length) - return nil + return info, nil } func (f *framer) writeTo(w io.Writer) error { @@ -833,7 +849,7 @@ func (w writeStartupFrame) String() string { return fmt.Sprintf("[startup opts=%+v]", w.opts) } -func (w *writeStartupFrame) buildFrame(f *framer, streamID int) error { +func (w *writeStartupFrame) buildFrame(f *framer, streamID int) (outFrameInfo, error) { f.writeHeader(f.flags&^flagCompress, opStartup, streamID) f.writeStringMap(w.opts) @@ -846,7 +862,7 @@ type writePrepareFrame struct { customPayload map[string][]byte } -func (w *writePrepareFrame) buildFrame(f *framer, streamID int) error { +func (w *writePrepareFrame) buildFrame(f *framer, streamID int) (outFrameInfo, error) { if len(w.customPayload) > 0 { f.payload() } @@ -1436,11 +1452,11 @@ func (a *writeAuthResponseFrame) String() string { return fmt.Sprintf("[auth_response data=%q]", a.data) } -func (a *writeAuthResponseFrame) buildFrame(framer *framer, streamID int) error { +func (a *writeAuthResponseFrame) buildFrame(framer *framer, streamID int) (outFrameInfo, error) { return framer.writeAuthResponseFrame(streamID, a.data) } -func (f *framer) writeAuthResponseFrame(streamID int, data []byte) error { +func (f *framer) writeAuthResponseFrame(streamID int, data []byte) (outFrameInfo, error) { f.writeHeader(f.flags, opAuthResponse, streamID) f.writeBytes(data) return f.finish() @@ -1474,11 +1490,13 @@ func (q queryParams) String() string { q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, q.defaultTimestamp, q.values, q.keyspace) } -func (f *framer) writeQueryParams(opts *queryParams) { +// writeQueryParams writes the queryParameters to the buffer. +// It returns the total size of the values. +func (f *framer) writeQueryParams(opts *queryParams) int { f.writeConsistency(opts.consistency) if f.proto == protoVersion1 { - return + return 0 } var flags byte @@ -1526,6 +1544,7 @@ func (f *framer) writeQueryParams(opts *queryParams) { f.writeByte(flags) } + startIdx := len(f.buf) if n := len(opts.values); n > 0 { f.writeShort(uint16(n)) @@ -1540,6 +1559,7 @@ func (f *framer) writeQueryParams(opts *queryParams) { } } } + valuesSize := len(f.buf) - startIdx if opts.pageSize > 0 { f.writeInt(int32(opts.pageSize)) @@ -1567,6 +1587,8 @@ func (f *framer) writeQueryParams(opts *queryParams) { if opts.keyspace != "" { f.writeString(opts.keyspace) } + + return valuesSize } type writeQueryFrame struct { @@ -1581,29 +1603,32 @@ func (w *writeQueryFrame) String() string { return fmt.Sprintf("[query statement=%q params=%v]", w.statement, w.params) } -func (w *writeQueryFrame) buildFrame(framer *framer, streamID int) error { +func (w *writeQueryFrame) buildFrame(framer *framer, streamID int) (outFrameInfo, error) { return framer.writeQueryFrame(streamID, w.statement, &w.params, w.customPayload) } -func (f *framer) writeQueryFrame(streamID int, statement string, params *queryParams, customPayload map[string][]byte) error { +func (f *framer) writeQueryFrame(streamID int, statement string, params *queryParams, customPayload map[string][]byte) (outFrameInfo, error) { if len(customPayload) > 0 { f.payload() } f.writeHeader(f.flags, opQuery, streamID) f.writeCustomPayload(&customPayload) f.writeLongString(statement) - f.writeQueryParams(params) + valuesSize := f.writeQueryParams(params) - return f.finish() + ofi, err := f.finish() + ofi.queryValuesSize = valuesSize + ofi.queryCount = 1 + return ofi, err } type frameBuilder interface { - buildFrame(framer *framer, streamID int) error + buildFrame(framer *framer, streamID int) (outFrameInfo, error) } -type frameWriterFunc func(framer *framer, streamID int) error +type frameWriterFunc func(framer *framer, streamID int) (outFrameInfo, error) -func (f frameWriterFunc) buildFrame(framer *framer, streamID int) error { +func (f frameWriterFunc) buildFrame(framer *framer, streamID int) (outFrameInfo, error) { return f(framer, streamID) } @@ -1619,20 +1644,22 @@ func (e *writeExecuteFrame) String() string { return fmt.Sprintf("[execute id=% X params=%v]", e.preparedID, &e.params) } -func (e *writeExecuteFrame) buildFrame(fr *framer, streamID int) error { +func (e *writeExecuteFrame) buildFrame(fr *framer, streamID int) (outFrameInfo, error) { return fr.writeExecuteFrame(streamID, e.preparedID, &e.params, &e.customPayload) } -func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams, customPayload *map[string][]byte) error { +func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams, customPayload *map[string][]byte) (outFrameInfo, error) { if len(*customPayload) > 0 { f.payload() } f.writeHeader(f.flags, opExecute, streamID) f.writeCustomPayload(customPayload) f.writeShortBytes(preparedID) + var valuesSize int if f.proto > protoVersion1 { - f.writeQueryParams(params) + valuesSize = f.writeQueryParams(params) } else { + startIdx := len(f.buf) n := len(params.values) f.writeShort(uint16(n)) for i := 0; i < n; i++ { @@ -1642,10 +1669,14 @@ func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *quer f.writeBytes(params.values[i].value) } } + valuesSize = len(f.buf) - startIdx f.writeConsistency(params.consistency) } - return f.finish() + ofi, err := f.finish() + ofi.queryValuesSize = valuesSize + ofi.queryCount = 1 + return ofi, err } // TODO: can we replace BatchStatemt with batchStatement? As they prety much @@ -1670,11 +1701,11 @@ type writeBatchFrame struct { customPayload map[string][]byte } -func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error { +func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) (outFrameInfo, error) { return framer.writeBatchFrame(streamID, w, w.customPayload) } -func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload map[string][]byte) error { +func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload map[string][]byte) (outFrameInfo, error) { if len(customPayload) > 0 { f.payload() } @@ -1687,6 +1718,8 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload var flags byte + var queryParamsSize int + for i := 0; i < n; i++ { b := &w.statements[i] if len(b.preparedID) == 0 { @@ -1697,6 +1730,8 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload f.writeShortBytes(b.preparedID) } + startIdx := len(f.buf) + f.writeShort(uint16(len(b.values))) for j := range b.values { col := b.values[j] @@ -1704,7 +1739,7 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload // TODO: move this check into the caller and set a flag on writeBatchFrame // to indicate using named values if f.proto <= protoVersion5 { - return fmt.Errorf("gocql: named query values are not supported in batches, please see https://issues.apache.org/jira/browse/CASSANDRA-10246") + return outFrameInfo{}, fmt.Errorf("gocql: named query values are not supported in batches, please see https://issues.apache.org/jira/browse/CASSANDRA-10246") } flags |= flagWithNameValues f.writeString(col.name) @@ -1715,6 +1750,8 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload f.writeBytes(col.value) } } + + queryParamsSize += len(f.buf) - startIdx } f.writeConsistency(w.consistency) @@ -1748,16 +1785,19 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload } } - return f.finish() + ofi, err := f.finish() + ofi.queryValuesSize = queryParamsSize + ofi.queryCount = n + return ofi, err } type writeOptionsFrame struct{} -func (w *writeOptionsFrame) buildFrame(framer *framer, streamID int) error { +func (w *writeOptionsFrame) buildFrame(framer *framer, streamID int) (outFrameInfo, error) { return framer.writeOptionsFrame(streamID, w) } -func (f *framer) writeOptionsFrame(stream int, _ *writeOptionsFrame) error { +func (f *framer) writeOptionsFrame(stream int, _ *writeOptionsFrame) (outFrameInfo, error) { f.writeHeader(f.flags&^flagCompress, opOptions, stream) return f.finish() } @@ -1766,11 +1806,11 @@ type writeRegisterFrame struct { events []string } -func (w *writeRegisterFrame) buildFrame(framer *framer, streamID int) error { +func (w *writeRegisterFrame) buildFrame(framer *framer, streamID int) (outFrameInfo, error) { return framer.writeRegisterFrame(streamID, w) } -func (f *framer) writeRegisterFrame(streamID int, w *writeRegisterFrame) error { +func (f *framer) writeRegisterFrame(streamID int, w *writeRegisterFrame) (outFrameInfo, error) { f.writeHeader(f.flags, opRegister, streamID) f.writeStringList(w.events) diff --git a/frame_test.go b/frame_test.go index 6b8eb228e..9fa379ada 100644 --- a/frame_test.go +++ b/frame_test.go @@ -66,7 +66,7 @@ func TestFrameWriteTooLong(t *testing.T) { framer.writeHeader(0, opStartup, 1) framer.writeBytes(make([]byte, maxFrameSize+1)) - err := framer.finish() + _, err := framer.finish() if err != ErrFrameTooBig { t.Fatalf("expected to get %v got %v", ErrFrameTooBig, err) } @@ -103,3 +103,146 @@ func TestFrameReadTooLong(t *testing.T) { t.Fatalf("expected to get header %v got %v", opReady, head.op) } } + +func TestOutFrameInfo(t *testing.T) { + tests := map[string]struct { + frame frameBuilder + expectedInfo outFrameInfo + }{ + "query": { + frame: &writeQueryFrame{ + statement: "SELECT * FROM mytable WHERE id=? AND x=?", + params: queryParams{ + consistency: One, + skipMeta: false, + values: []queryValues{ + { + value: []byte{'H', 'e', 'l', 'l', 'o', 'W', 'o', 'r', 'l', 'd'}, + }, + { + value: []byte{'H', 'e', 'l', 'l', 'o', 'W', 'o', 'r', 'l', 'd'}, + }, + }, + pageSize: 5000, + pagingState: nil, + serialConsistency: 0, + defaultTimestamp: false, + defaultTimestampValue: 0, + keyspace: "", + }, + customPayload: nil, + }, + expectedInfo: outFrameInfo{ + uncompressedSize: 81, + compressedSize: 72, + queryValuesSize: 30, + queryCount: 1, + }, + }, + "execute": { + frame: &writeExecuteFrame{ + preparedID: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5}, + params: queryParams{ + values: []queryValues{ + { + value: []byte{'H', 'e', 'l', 'l', 'o', 'W', 'o', 'r', 'l', 'd'}, + }, + { + value: []byte{'H', 'e', 'l', 'l', 'o', 'W', 'o', 'r', 'l', 'd'}, + }, + }, + }, + customPayload: nil, + }, + expectedInfo: outFrameInfo{ + compressedSize: 50, + uncompressedSize: 51, + queryValuesSize: 30, + queryCount: 1, + }, + }, + "batch": { + frame: &writeBatchFrame{ + typ: UnloggedBatch, + statements: []batchStatment{ + { + preparedID: nil, + statement: "SELECT * FROM mytable WHERE id=? AND x=?", + values: []queryValues{ + { + value: []byte{'H', 'e', 'l', 'l', 'o', 'W', 'o', 'r', 'l', 'd'}, + }, + { + value: []byte{'H', 'e', 'l', 'l', 'o', 'W', 'o', 'r', 'l', 'd'}, + }, + }, + }, + { + preparedID: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5}, + statement: "", + values: []queryValues{ + { + value: []byte{'H', 'e', 'l', 'l', 'o', 'W', 'o', 'r', 'l', 'd'}, + }, + { + value: []byte{'H', 'e', 'l', 'l', 'o', 'W', 'o', 'r', 'l', 'd'}, + }, + }, + }, + }, + consistency: One, + serialConsistency: 0, + defaultTimestamp: false, + defaultTimestampValue: 0, + customPayload: nil, + }, + expectedInfo: outFrameInfo{ + compressedSize: 96, + uncompressedSize: 130, + queryValuesSize: 60, + queryCount: 2, + }, + }, + "options": { + frame: &writeOptionsFrame{}, + expectedInfo: outFrameInfo{ + compressedSize: 0, + uncompressedSize: 0, + queryValuesSize: 0, + queryCount: 0, + }, + }, + "register": { + frame: &writeRegisterFrame{ + events: []string{"event1", "event2"}, + }, + expectedInfo: outFrameInfo{ + compressedSize: 20, + uncompressedSize: 18, + queryValuesSize: 0, + queryCount: 0, + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + fr := newFramer(SnappyCompressor{}, 4) + ofi, err := test.frame.buildFrame(fr, 42) + if err != nil { + t.Fatal(err) + } + if ofi.queryCount != test.expectedInfo.queryCount { + t.Errorf("expected queryCount %d, but got %d", test.expectedInfo.queryCount, ofi.queryCount) + } + if ofi.queryValuesSize != test.expectedInfo.queryValuesSize { + t.Errorf("expected queryValuesSize %d, but got %d", test.expectedInfo.queryValuesSize, ofi.queryValuesSize) + } + if ofi.uncompressedSize != test.expectedInfo.uncompressedSize { + t.Errorf("expected uncompressedSize %d, but got %d", test.expectedInfo.uncompressedSize, ofi.uncompressedSize) + } + if ofi.compressedSize != test.expectedInfo.compressedSize { + t.Errorf("expected compressedSize %d, but got %d", test.expectedInfo.compressedSize, ofi.compressedSize) + } + }) + } +}