From 141d5675b3212443dcce87402439cf223c4da460 Mon Sep 17 00:00:00 2001 From: zuoxiupeng Date: Sat, 30 Dec 2017 14:10:51 +0800 Subject: [PATCH] =?UTF-8?q?=E9=A6=96=E6=AC=A1=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 + README.md | 10 +++- client.go | 89 +++++++++++++++++++++++++++++ conn.go | 81 ++++++++++++++++++++++++++ convert.go | 108 +++++++++++++++++++++++++++++++++++ default_packet.go | 100 ++++++++++++++++++++++++++++++++ handle_websocket.go | 39 +++++++++++++ interface.go | 30 ++++++++++ logger.go | 61 ++++++++++++++++++++ send_message.go | 53 +++++++++++++++++ socket.go | 122 +++++++++++++++++++++++++++++++++++++++ tcplibrary.go | 135 ++++++++++++++++++++++++++++++++++++++++++++ websocket.go | 113 ++++++++++++++++++++++++++++++++++++ 13 files changed, 942 insertions(+), 2 deletions(-) create mode 100644 client.go create mode 100644 conn.go create mode 100644 convert.go create mode 100644 default_packet.go create mode 100644 handle_websocket.go create mode 100644 interface.go create mode 100644 logger.go create mode 100644 send_message.go create mode 100644 socket.go create mode 100644 tcplibrary.go create mode 100644 websocket.go diff --git a/.gitignore b/.gitignore index a1338d6..3a19ac4 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ # Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 .glide/ + +examples/chat/client/client* +examples/chat/server/server* diff --git a/README.md b/README.md index ca545c6..9d8121d 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,8 @@ -# tcplibrary -golang 版的tcp通讯库,让tcp开发更简单 +# golang的tcp通讯库 + +使用此库你只需要创建一个server或client的结构体,实现接口,再写两句话就可以搭建起tcp通讯 + +## 备注 +自定义协议的小伙伴如果需要websocket通讯时,请参考default_packet.go中的`GetPayload()`函数,包内容问题 + +当需要获取连接列表时 可以调用`GetClients()`方法获取连接`*sync.Map` diff --git a/client.go b/client.go new file mode 100644 index 0000000..c95ad6f --- /dev/null +++ b/client.go @@ -0,0 +1,89 @@ +/* + * @Author: 时光弧线 + * @Date: 2017-12-30 11:54:57 + * @Last Modified by: 时光弧线 + * @Last Modified time: 2017-12-30 13:11:34 + */ +package tcplibrary + +import ( + "errors" + "net" +) + +/* tcp golang客户端 */ + +// TCPClient tcp客户端 +type TCPClient struct { + *TCPLibrary + conn *Conn // 连接对象 +} + +// NewTCPClient 创建一个tcp客户端 +func NewTCPClient(debug bool, socket Socket, packets ...Packet) (*TCPClient, error) { + if socket == nil { + return nil, errors.New("Socket参数不能是nil") + } + // 封包解包对象 + var packet Packet + if len(packets) == 0 { + packet = new(DefaultPacket) + } else { + packet = packets[0] + } + // 标记为客户端 + isServer = false + + return &TCPClient{ + TCPLibrary: &TCPLibrary{ + packet: packet, + socket: socket, + readDeadline: DefaultReadDeadline, + readBufferSize: DefaultBufferSize, + }, + }, nil +} + +// DialAndStart 连接到服务器,并开始读取信息 +func (c *TCPClient) DialAndStart(address string) error { + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + globalLogger.Errorf(err.Error()) + return err + } + conn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + globalLogger.Errorf(err.Error()) + return err + } + // 判断是否设置读超时 + if c.readDeadline == 0 { + c.readDeadline = DefaultReadDeadline + } + // 赋值给当前连接对象 + c.conn = &Conn{ + Conn: conn, + connType: TCPSocketType, + packet: c.packet, + } + // 通知建立连接 + err = c.socket.OnConnect(c.conn) + if err != nil { + globalLogger.Errorf(err.Error()) + // 如果建立连接函数返回false,则关闭连接 + c.socket.OnClose(c.conn, err) // 通知关闭 + err = conn.Close() // 关闭连接 + if err != nil { + globalLogger.Errorf(err.Error()) + } + return err + } + // 开启一个协程处理数据接收 + go c.handleConn(c.conn) + return nil +} + +// GetConn 获取连接对象 +func (c *TCPClient) GetConn() *Conn { + return c.conn +} diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..5efc88a --- /dev/null +++ b/conn.go @@ -0,0 +1,81 @@ +/* + * @Author: 时光弧线 + * @Date: 2017-12-30 11:55:02 + * @Last Modified by: 时光弧线 + * @Last Modified time: 2017-12-30 13:09:56 + */ +package tcplibrary + +import ( + "fmt" + "net" +) + +/* tcp 连接定义 */ + +// TCPType tcp连接类型 +type TCPType = int + +const ( + // TCPSocketType tcp连接 + TCPSocketType = iota + // WebSocketType WebSocket连接 + WebSocketType +) + +// Conn 自定义连接对象结构体,可以存储tcp或webSocket连接对象 +type Conn struct { + net.Conn + connType TCPType // 连接对象类型 + clientID string // 客户端id + packet Packet // 封闭解包对象 +} + +// SendMessage 发送消息,参数为自己报文结构体 +func (c *Conn) SendMessage(v interface{}) (int, error) { + // 判断是tcp还是websocket + if c.connType == TCPSocketType { // 二进制协议 + // 先封包,再发送数据 + data, err := c.packet.Marshal(v) + if err != nil { + globalLogger.Errorf(err.Error()) + return 0, err + } + return c.Write(data) + } else if c.connType == WebSocketType { // json方式 + data, err := c.packet.MarshalToJSON(v) + if err != nil { + globalLogger.Errorf(err.Error()) + return 0, err + } + c.Write(data) + } else { + globalLogger.Errorf("不支持的连接方式") + } + return 0, nil +} + +// GetClientID 获取当前连接id +func (c *Conn) GetClientID() string { + return c.clientID +} + +// GetConnType 获取连接类型 +func (c *Conn) GetConnType() TCPType { + return c.connType +} + +// CloseForClientID 根据clientID关闭连接 +func CloseForClientID(clientID string) error { + // log.Println(clientID) + connInterface, ok := clients.Load(clientID) + if ok == false { + return fmt.Errorf("踢人失败,没有这样的连接1:%s", clientID) + } + if conn, ok := connInterface.(*Conn); ok == true { + conn.Close() + } else { + return fmt.Errorf("踢人失败,没有这样的连接2:%s", clientID) + } + return nil +} diff --git a/convert.go b/convert.go new file mode 100644 index 0000000..17a16f7 --- /dev/null +++ b/convert.go @@ -0,0 +1,108 @@ +/* + * @Author: 时光弧线 + * @Date: 2017-12-30 11:55:06 + * @Last Modified by: 时光弧线 + * @Last Modified time: 2017-12-30 11:55:06 + */ +package tcplibrary + +import ( + "encoding/binary" + "math" +) + +func ByteToBool(i byte) bool { + if i == 1 { + return true + } + return false +} + +func BytesToUint16(data []byte) uint16 { + return binary.LittleEndian.Uint16(data) +} + +func BytesToUint32(data []byte) uint32 { + return binary.LittleEndian.Uint32(data) +} + +func BytesToUint64(data []byte) uint64 { + return binary.LittleEndian.Uint64(data) +} + +func BytesToInt16(data []byte) int16 { + return int16(BytesToUint16(data)) +} + +func BytesToInt32(data []byte) int32 { + return int32(BytesToUint16(data)) +} + +func BytesToInt64(data []byte) int64 { + return int64(BytesToUint64(data)) +} + +func BytesToInt(data []byte) int { + switch len(data) { + case 2: + return int(BytesToUint16(data)) + case 4: + return int(BytesToUint32(data)) + case 8: + return int(BytesToUint64(data)) + } + return 0 +} + +//IntToBytes 整形转换成byte数组 +func IntToBytes(data interface{}) []byte { + var buf []byte + switch data.(type) { + case int: + buf = make([]byte, 8) + target, _ := data.(int) + binary.LittleEndian.PutUint64(buf, uint64(target)) + case int16: + buf = make([]byte, 2) + target, _ := data.(int16) + binary.LittleEndian.PutUint16(buf, uint16(target)) + case int32: + buf = make([]byte, 4) + target, _ := data.(int32) + binary.LittleEndian.PutUint32(buf, uint32(target)) + case int64: + buf = make([]byte, 8) + target, _ := data.(int64) + binary.LittleEndian.PutUint64(buf, uint64(target)) + case uint: + buf = make([]byte, 8) + target, _ := data.(uint) + binary.LittleEndian.PutUint64(buf, uint64(target)) + case uint16: + buf = make([]byte, 2) + target, _ := data.(uint16) + binary.LittleEndian.PutUint16(buf, target) + case uint32: + buf = make([]byte, 4) + target, _ := data.(uint32) + binary.LittleEndian.PutUint32(buf, target) + case uint64: + buf = make([]byte, 8) + target, _ := data.(uint64) + binary.LittleEndian.PutUint64(buf, target) + } + return buf +} + +func Float64frombytes(bytes []byte) float64 { + bits := binary.LittleEndian.Uint64(bytes) + float := math.Float64frombits(bits) + return float +} + +func Float32bytes(float float32) []byte { + bits := math.Float32bits(float) + bytes := make([]byte, 4) + binary.LittleEndian.PutUint32(bytes, bits) + return bytes +} diff --git a/default_packet.go b/default_packet.go new file mode 100644 index 0000000..2414c77 --- /dev/null +++ b/default_packet.go @@ -0,0 +1,100 @@ +/* + * @Author: 时光弧线 + * @Date: 2017-12-30 11:55:15 + * @Last Modified by: 时光弧线 + * @Last Modified time: 2017-12-30 14:02:48 + */ +package tcplibrary + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" +) + +// DefaultPacket 协议包 +type DefaultPacket struct { + Length int32 `json:"Length"` // Payload 包长度,4字节 + Payload interface{} `json:"Payload"` // 报文内容,n字节 +} + +// GetPayload 获取包内容 +// 之所以有此函数,是为了兼容websocket,websocket使用此库时可直接使用json传输数据 +func (dp *DefaultPacket) GetPayload() []byte { + switch dp.Payload.(type) { + case string: + return []byte(dp.Payload.(string)) + case []byte: + return dp.Payload.([]byte) + default: + js, err := json.Marshal(dp.Payload) + if err == nil { + return js + } + globalLogger.Errorf("默认包结构获取错误:%v", err) + } + return make([]byte, 0) +} + +// Unmarshal 默认解包 +func (dp *DefaultPacket) Unmarshal(data []byte, c chan interface{}) (outData []byte, err error) { + // 捕获异常 + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%T", r) + globalLogger.Fatalf("默认解包错误:%v", err) + } + }() + // 长度不足4个字节无法获取包长度 + if len(data) < 4 { + return data, err + } + // 获取包长度 + packetLength := BytesToInt32(data[0:4]) + 4 + // 判断是否达到一个包长,没有达到直接返回 + if len(data) < int(packetLength) { + return data, err + } + // 截取一个包的长度,解包 + packetData := data[:packetLength] + // 解析内容和长度 + packet := new(DefaultPacket) + packet.Length = int32(packetLength) + packet.Payload = packetData[4:] + // 写入管道数据,用于通知实际业务逻辑 + c <- packet + // 递归调用解包 + return dp.Unmarshal(data[packetLength:], c) +} + +// Marshal 默认封包 +func (dp *DefaultPacket) Marshal(v interface{}) ([]byte, error) { + packet, ok := v.(*DefaultPacket) + if ok == false { + return nil, errors.New("封包参数不是*DefaultPacket") + } + // 获取内容 + payload := packet.GetPayload() + packet.Length = int32(len(payload)) + + /* 创建Buffer对象,写入头和数据 */ + packetData := bytes.NewBuffer([]byte{}) + lengthByte := IntToBytes(packet.Length) // 长度转byte + packetData.Write(lengthByte) + packetData.Write(payload) + + // 返回编码后的字节数组 + return packetData.Bytes(), nil +} + +// MarshalToJSON 编码到json, 同时将Payload转为字符串 +func (dp *DefaultPacket) MarshalToJSON(v interface{}) ([]byte, error) { + packet, ok := v.(*DefaultPacket) + if ok == false { + return nil, errors.New("封包参数不是*DefaultPacket") + } + packet.Payload = string(packet.GetPayload()) + // 直接转json返回 + return json.Marshal(packet) +} diff --git a/handle_websocket.go b/handle_websocket.go new file mode 100644 index 0000000..0fc7939 --- /dev/null +++ b/handle_websocket.go @@ -0,0 +1,39 @@ +/* + * @Author: 时光弧线 + * @Date: 2017-12-30 11:55:19 + * @Last Modified by: 时光弧线 + * @Last Modified time: 2017-12-30 13:13:07 + */ +package tcplibrary + +import ( + "golang.org/x/net/websocket" +) + +/* websocket的连接处理,涉及到包内容解析,所以单独新建文件 */ + +// 用于websocket的连接处理函数 +func (ws *WebSocketServer) handleConn(conn *Conn) { + defer func() { + if r := recover(); r != nil { + globalLogger.Fatalf("%T", r) + } + }() + // 收到消息的管道 + messageChannel := make(chan interface{}, DefaultMessageChannelSize) + go ws.handleMessage(conn, messageChannel) + // 循环读取 websocket + for { + // 解析websocket传输的包 + defaultPacket := new(DefaultPacket) + err := websocket.JSON.Receive(conn.Conn.(*websocket.Conn), defaultPacket) + if err != nil { + globalLogger.Errorf(err.Error()) + // 关闭连接,并通知错误 + ws.closeConn(conn, err) + break + } + // 向管道写入数据 + messageChannel <- defaultPacket + } +} diff --git a/interface.go b/interface.go new file mode 100644 index 0000000..0ed9542 --- /dev/null +++ b/interface.go @@ -0,0 +1,30 @@ +/* + * @Author: 时光弧线 + * @Date: 2017-12-30 11:55:26 + * @Last Modified by: 时光弧线 + * @Last Modified time: 2017-12-30 13:22:12 + */ +package tcplibrary + +/* tcp库用到的接口定义 */ + +// Socket tcp通讯需要的一些回调函数 +type Socket interface { + OnConnect(*Conn) error // 连接建立时 + OnError(error) // 连接发生错误 + OnClose(*Conn, error) // 关闭连接时 + OnRecMessage(*Conn, interface{}) // 接收消息时 +} + +// ServerSocket 服务接口,实例化tcp server时传次参数 +type ServerSocket interface { + Socket + GetClientID() string // 获取session id生成规则 +} + +// Packet 封包和解包 +type Packet interface { + Unmarshal(data []byte, c chan interface{}) ([]byte, error) // 解包 + Marshal(v interface{}) ([]byte, error) // 封包 + MarshalToJSON(v interface{}) ([]byte, error) // 封包为json字符串形式 +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..64112a3 --- /dev/null +++ b/logger.go @@ -0,0 +1,61 @@ +/* + * @Author: 时光弧线 + * @Date: 2017-12-30 11:55:33 + * @Last Modified by: 时光弧线 + * @Last Modified time: 2017-12-30 13:02:01 + */ +package tcplibrary + +import ( + "fmt" + "log" +) + +/* tcp 内部日志打印 */ + +// Logger 日志记录 +type Logger struct { + debug bool +} + +var globalLogger *Logger + +func init() { + globalLogger = new(Logger) + globalLogger.debug = true +} + +// SetDebug 设置是否为debug模式 +func (l *Logger) SetDebug(debug bool) { + l.debug = debug +} + +// Infof 打印错误 Info +func (l *Logger) Infof(format string, a ...interface{}) { + if l.debug == false { + return + } + format = fmt.Sprintf("INFO: %s\n", format) + log.Printf(format, a...) +} + +// Warnf 打印错误 Warn +func (l *Logger) Warnf(format string, a ...interface{}) { + if l.debug == false { + return + } + format = fmt.Sprintf("WARN: %s\n", format) + log.Printf(format, a...) +} + +// Errorf 打印错误 Error +func (l *Logger) Errorf(format string, a ...interface{}) { + format = fmt.Sprintf("ERROR: %s\n", format) + log.Printf(format, a...) +} + +// Fatalf 打印错误 Fatal +func (l *Logger) Fatalf(format string, a ...interface{}) { + format = fmt.Sprintf("FATAL: %s\n", format) + log.Printf(format, a...) +} diff --git a/send_message.go b/send_message.go new file mode 100644 index 0000000..a02ebf4 --- /dev/null +++ b/send_message.go @@ -0,0 +1,53 @@ +/* + * @Author: 时光弧线 + * @Date: 2017-12-30 11:55:38 + * @Last Modified by: 时光弧线 + * @Last Modified time: 2017-12-30 13:14:28 + */ +package tcplibrary + +import ( + "errors" +) + +/* 公共发送数据函数 */ + +// SendMessageToClients 发送数据给指定客户端 +// 返回值,第一个值为发送成功几个连接 +// 只有服务端可调用 +func SendMessageToClients(v interface{}, clientIDs ...string) (sendCount int, err error) { + if isServer == false { + return 0, errors.New("客户端不允许调用此函数") + } + for _, vv := range clientIDs { + if val, ok := clients.Load(vv); ok == true { + if conn, ok := val.(*Conn); ok == true { + _, err = conn.SendMessage(v) + if err == nil { + sendCount++ + } + } + } + } + return sendCount, err +} + +// SendMessageToAll 发送给所有客户端 +// 只有服务端可调用 +func SendMessageToAll(v interface{}) (int, error) { + if isServer == false { + return 0, errors.New("客户端不允许调用此函数") + } + sendCount := 0 + clients.Range(func(key, val interface{}) bool { + if conn, ok := val.(*Conn); ok == true { + _, err := conn.SendMessage(v) + if err != nil { + return true + } + sendCount++ + } + return true + }) + return sendCount, nil +} diff --git a/socket.go b/socket.go new file mode 100644 index 0000000..6ecaa68 --- /dev/null +++ b/socket.go @@ -0,0 +1,122 @@ +/* + * @Author: 时光弧线 + * @Date: 2017-12-30 11:55:41 + * @Last Modified by: 时光弧线 + * @Last Modified time: 2017-12-30 13:22:40 + */ +package tcplibrary + +import ( + "errors" + "fmt" + "net" + "time" +) + +/* tcp 连接 */ + +// TCPServer tcp服务端对象 +type TCPServer struct { + *TCPLibrary + listener *net.TCPListener // tcp监听 + isListener bool // 是否已监听 +} + +// NewTCPServer 创建一个server实例 +func NewTCPServer(debug bool, socket ServerSocket, packets ...Packet) (*TCPServer, error) { + if socket == nil { + return nil, errors.New("ServerSocket参数不能是nil") + } + // 封包解包对象 + var packet Packet + if len(packets) == 0 { + packet = new(DefaultPacket) + } else { + packet = packets[0] + } + // 标记为服务端 + isServer = true + + return &TCPServer{ + TCPLibrary: &TCPLibrary{ + packet: packet, + socket: socket, + readDeadline: DefaultReadDeadline, + readBufferSize: DefaultBufferSize, + }, + isListener: false, + }, nil +} + +// ListenAndServe 开始tcp监听 +func (tcp *TCPServer) ListenAndServe(address string) error { + if tcp.isListener == true { + return errors.New("已调用监听端口") + } + if address == "" { + return errors.New("监听地址不能为空") + } + // 开启tcp监听 + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + globalLogger.Errorf(err.Error()) + return err + } + listen, err := net.ListenTCP("tcp", addr) + if err != nil { + return err + } + // 判断是否设置读超时 + if tcp.readDeadline == 0 { + tcp.readDeadline = DefaultReadDeadline + } + // 监听对象赋值给当前对象,并将isListener设为true + tcp.listener = listen + tcp.isListener = true + // 打印开启tcp服务 + globalLogger.Infof("tcp socket start, net %s addr %s", listen.Addr().Network(), listen.Addr().String()) + // 开始接收客户端连接 + for { + tcpConn, err := tcp.listener.Accept() + if err != nil { + globalLogger.Errorf(err.Error()) + continue + } + // 创建一个Conn对象 + conn := &Conn{ + Conn: tcpConn, + connType: TCPSocketType, + packet: tcp.packet, + } + // 获取客户端id + serverSocket, ok := tcp.socket.(ServerSocket) + if ok == false { + // 如果建立连接函数返回false,则关闭连接 + tcp.socket.OnClose(conn, fmt.Errorf("%s", "转换为ServerSocket错误")) // 通知关闭 + err = conn.Close() // 关闭连接 + if err != nil { + globalLogger.Errorf(err.Error()) + } + break + } + clientID := serverSocket.GetClientID() + conn.clientID = clientID + // 通知连接创建后函数 + err = tcp.socket.OnConnect(conn) + if err != nil { + // 如果建立连接函数返回false,则关闭连接 + tcp.socket.OnClose(conn, err) // 通知关闭 + err = conn.Close() // 关闭连接 + if err != nil { + globalLogger.Errorf(err.Error()) + } + break + } + clients.Store(clientID, conn) + // 设置超时 + conn.SetReadDeadline(time.Now().Add(tcp.readDeadline)) + // 开启一个协程处理数据接收 + go tcp.handleConn(conn) + } + return nil +} diff --git a/tcplibrary.go b/tcplibrary.go new file mode 100644 index 0000000..d863f78 --- /dev/null +++ b/tcplibrary.go @@ -0,0 +1,135 @@ +/* + * @Author: 时光弧线 + * @Date: 2017-12-30 11:55:47 + * @Last Modified by: 时光弧线 + * @Last Modified time: 2017-12-30 13:48:57 + */ +package tcplibrary + +import ( + "errors" + "io" + "sync" + "time" +) + +/* 通讯库父类 */ + +// 定义的tcp读缓存区大小 +var ( + DefaultBufferSize = 1024 + DefaultMessageChannelSize = 32 + DefaultReadDeadline = 15 * time.Second +) + +var ( + // 保存所有 + clients *sync.Map + // 是否是服务端 + isServer bool +) + +func init() { + // 初始化客户端存储map + clients = new(sync.Map) + // 默认非服务端 + isServer = false +} + +// GetClients 获取客户端列表,在自己的业务中使用,使用时切记小心操作 +func GetClients() *sync.Map { + return clients +} + +// TCPLibrary tcp库父类 +type TCPLibrary struct { + socket Socket // socket 需要实现的几个方法 + packet Packet // 解包和封包 + readDeadline time.Duration // 读超时 + readBufferSize int // 读数据时的字节缓冲 +} + +// SetReadDeadline 设置参数 readDeadline +func (t *TCPLibrary) SetReadDeadline(duration time.Duration) { + t.readDeadline = duration +} + +// SetReadBufferSize 设置参数 readBufferSize +func (t *TCPLibrary) SetReadBufferSize(readBufferSize int) { + t.readBufferSize = readBufferSize +} + +// 收到消息时处理 +func (t *TCPLibrary) handleMessage(conn *Conn, message chan interface{}) { + for { + select { + case v := <-message: + if isServer == true { + // 设置超时 + conn.SetReadDeadline(time.Now().Add(t.readDeadline)) + } + // 调用消息回调 + go t.socket.OnRecMessage(conn, v) + } + } +} + +// DelClients 删除一个客户端对象 +func (t *TCPLibrary) delClients(keys ...interface{}) { + for _, v := range keys { + clients.Delete(v) + } +} + +// 处理从连接中读取数据 +func (t *TCPLibrary) handleConn(conn *Conn) { + defer func() { + if r := recover(); r != nil { + globalLogger.Fatalf("%T", r) + } + }() + // 收到消息的管道 + messageChannel := make(chan interface{}, DefaultMessageChannelSize) + go t.handleMessage(conn, messageChannel) + // 缓冲区大小 + bufferSize := t.readBufferSize + if bufferSize == 0 { + bufferSize = DefaultBufferSize + } + data := make([]byte, 0) + buf := make([]byte, bufferSize) + for { + n, err := conn.Read(buf) + if err != nil { + // 关闭连接,并通知错误 + t.closeConn(conn, err) + break + } + // 解包 + data, err = conn.packet.Unmarshal(append(data, buf[:n]...), messageChannel) + if err != nil { + globalLogger.Errorf("%s", err.Error()) + t.socket.OnError(err) + } + } +} + +// 关闭连接,并通知错误 +func (t *TCPLibrary) closeConn(conn *Conn, err error) { + // 判断错误是不是nil和io.EOF + if err != nil && err != io.EOF { + globalLogger.Errorf(err.Error()) + // 通知错误 + t.socket.OnError(err) + } else { + err = errors.New("") + } + // 关闭连接 + t.socket.OnClose(conn, err) // 通知关闭 + // 删除客户端连接 + t.delClients(conn.clientID) + err = conn.Close() + if err != nil { + globalLogger.Errorf("%s", err.Error()) + } +} diff --git a/websocket.go b/websocket.go new file mode 100644 index 0000000..e7a162c --- /dev/null +++ b/websocket.go @@ -0,0 +1,113 @@ +/* + * @Author: 时光弧线 + * @Date: 2017-12-30 11:55:50 + * @Last Modified by: 时光弧线 + * @Last Modified time: 2017-12-30 13:23:42 + */ +package tcplibrary + +import ( + "errors" + "fmt" + "net" + "net/http" + "time" + + "golang.org/x/net/websocket" +) + +/* websocket 服务端 */ + +// WebSocketServer websocket 服务端操作对象 +type WebSocketServer struct { + *TCPLibrary + listener *net.TCPListener // tcp监听 + isListener bool // 是否已监听 +} + +// NewWebSocketServer 创建一个websocket监听 +func NewWebSocketServer(debug bool, socket ServerSocket, packets ...Packet) (*WebSocketServer, error) { + if socket == nil { + return nil, errors.New("ServerSocket参数不能是nil") + } + // 封包解包对象 + var packet Packet + if len(packets) == 0 { + packet = new(DefaultPacket) + } else { + packet = packets[0] + } + // 标记为服务端 + isServer = true + + return &WebSocketServer{ + TCPLibrary: &TCPLibrary{ + packet: packet, + socket: socket, + readDeadline: DefaultReadDeadline, + readBufferSize: DefaultBufferSize, + }, + isListener: false, + }, nil +} + +// ListenAndServe 开始tcp监听 +// address 监听的地址和端口 +// route 监听的路由(url) +func (ws *WebSocketServer) ListenAndServe(address, route string) error { + if address == "" { + return errors.New("监听地址不能为空") + } + if route == "" { + route = "/" + } + // 判断是否设置读超时 + if ws.readDeadline == 0 { + ws.readDeadline = DefaultReadDeadline + } + http.Handle(route, websocket.Handler(ws.handleWebSocketConn)) + globalLogger.Infof("web socket start, net websocket addr %s", address) + err := http.ListenAndServe(address, nil) + return err +} + +// 处理WebSocket数据 +func (ws *WebSocketServer) handleWebSocketConn(wsConn *websocket.Conn) { + // 构建Conn对象 + conn := &Conn{ + Conn: wsConn, + connType: WebSocketType, + packet: ws.packet, + } + // 保存连接到客户端数组 + serverSocket, ok := ws.socket.(ServerSocket) + if ok == false { + // 如果建立连接函数返回false,则关闭连接 + ws.socket.OnClose(conn, fmt.Errorf("%s", "转换为ServerSocket错误")) // 通知关闭 + err := conn.Close() // 关闭连接 + if err != nil { + globalLogger.Errorf("%s", err.Error()) + } + return + } + // 补上客户端id和封包解包对象,并存入服务端客户端对象 + clientID := serverSocket.GetClientID() + conn.clientID = clientID + clients.Store(clientID, conn) + // 设置超时 + conn.SetReadDeadline(time.Now().Add(ws.readDeadline)) + // 调用OnConnect + // 通知连接创建后函数 + err := ws.socket.OnConnect(conn) + if err != nil { + // 如果建立连接函数返回false,则关闭连接 + ws.socket.OnClose(conn, err) // 通知关闭 + err = conn.Close() // 关闭连接 + if err != nil { + globalLogger.Errorf("%s", err.Error()) + } + return + } + // 调用websocket连接处理方法 + ws.handleConn(conn) +}