From 1639541304af01e59650f69cfc90f3abe0a2a0af Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Mon, 6 Nov 2023 23:07:07 +0900 Subject: [PATCH] separate msgState for reader/writer --- connection.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/connection.go b/connection.go index 5385f35..8ac057a 100644 --- a/connection.go +++ b/connection.go @@ -24,7 +24,8 @@ type Connection struct { reader *wsutil.Reader flateWriter *wsflate.Writer writer *wsutil.Writer - msgState *wsflate.MessageState + msgStateR *wsflate.MessageState + msgStateW *wsflate.MessageState } func NewConnection(ctx context.Context, url string, requestHeader http.Header) (*Connection, error) { @@ -51,9 +52,9 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) ( // reader var flateReader *wsflate.Reader - var msgState wsflate.MessageState + var msgStateR wsflate.MessageState if enableCompression { - msgState.SetCompressed(true) + msgStateR.SetCompressed(true) flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor { return flate.NewReader(r) @@ -67,13 +68,16 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) ( OnIntermediate: controlHandler, CheckUTF8: false, Extensions: []wsutil.RecvExtension{ - &msgState, + &msgStateR, }, } // writer var flateWriter *wsflate.Writer + var msgStateW wsflate.MessageState if enableCompression { + msgStateW.SetCompressed(true) + flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor { fw, err := flate.NewWriter(w, 4) if err != nil { @@ -84,7 +88,7 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) ( } writer := wsutil.NewWriter(conn, state, ws.OpText) - writer.SetExtensions(&msgState) + writer.SetExtensions(&msgStateW) return &Connection{ conn: conn, @@ -92,14 +96,15 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) ( controlHandler: controlHandler, flateReader: flateReader, reader: reader, + msgStateR: &msgStateR, flateWriter: flateWriter, - msgState: &msgState, writer: writer, + msgStateW: &msgStateW, }, nil } func (c *Connection) WriteMessage(data []byte) error { - if c.msgState.IsCompressed() && c.enableCompression { + if c.msgStateW.IsCompressed() && c.enableCompression { c.flateWriter.Reset(c.writer) if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil { return fmt.Errorf("failed to write message: %w", err) @@ -149,7 +154,7 @@ func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error { } } - if c.msgState.IsCompressed() && c.enableCompression { + if c.msgStateR.IsCompressed() && c.enableCompression { c.flateReader.Reset(c.reader) if _, err := io.Copy(buf, c.flateReader); err != nil { return fmt.Errorf("failed to read message: %w", err)