forked from YiQiu1984/lightsocks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
securetcp.go
127 lines (115 loc) · 2.89 KB
/
securetcp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package lightsocks
import (
	"io"
	"log"
	"net"
	"time"
)
// bufSize is the length, in bytes, of the scratch buffer that EncodeCopy and
// DecodeCopy each allocate; a single Read in those loops moves at most this
// many bytes.
const bufSize = 1024
// SecureTCPConn pairs a byte stream with the Cipher used to encode and
// decode the data moved over it. In this file it is always built around a
// *net.TCPConn (see DialEncryptedTCP and ListenEncryptedTCP).
//
// The promoted Read/Write/Close methods of the embedded stream pass bytes
// through untouched; the Cipher is only applied by DecodeRead, EncodeWrite,
// DecodeCopy and EncodeCopy.
type SecureTCPConn struct {
// ReadWriteCloser is the underlying stream.
io.ReadWriteCloser
// Cipher is used by the Encode*/Decode* methods to transform data.
Cipher *Cipher
}
// DecodeRead reads encoded data from the underlying stream into bs, decodes
// it in place with the connection's Cipher and returns the number of
// plaintext bytes now held in bs[:n] together with the Read error.
//
// io.Reader allows a Read to return n > 0 and a non-nil error at the same
// time (for example the last chunk together with io.EOF). Those bytes are
// decoded as well, so callers can always consume bs[:n] before inspecting
// err, and the cipher sees every byte that was actually consumed.
func (secureSocket *SecureTCPConn) DecodeRead(bs []byte) (n int, err error) {
	n, err = secureSocket.Read(bs)
	if n > 0 {
		secureSocket.Cipher.Decode(bs[:n])
	}
	return n, err
}
// EncodeWrite encodes bs with the connection's Cipher and hands the result to
// the underlying stream in a single Write call, returning that call's
// results.
//
// bs is encoded in place: after the call it holds the encoded bytes, not the
// caller's original data.
func (secureSocket *SecureTCPConn) EncodeWrite(bs []byte) (int, error) {
	c := secureSocket.Cipher
	c.Encode(bs)
	written, err := secureSocket.Write(bs)
	return written, err
}
// EncodeCopy reads plain data from secureSocket until the stream is
// exhausted, encodes each chunk with secureSocket.Cipher and writes it to
// dst.
//
// It returns nil once the source reports io.EOF. Otherwise it returns the
// first read or write error, or io.ErrShortWrite if dst reports writing
// fewer bytes than it was given without returning an error.
//
// Bytes that Read delivers together with an error (for example the final
// chunk alongside io.EOF) are encoded and forwarded before the error is acted
// on, as the io.Reader contract requires.
func (secureSocket *SecureTCPConn) EncodeCopy(dst io.ReadWriteCloser) error {
	buf := make([]byte, bufSize)
	for {
		readCount, errRead := secureSocket.Read(buf)
		if readCount > 0 {
			chunk := buf[:readCount]
			// Encode in place: buf is scratch space owned by this loop.
			secureSocket.Cipher.Encode(chunk)
			writeCount, errWrite := dst.Write(chunk)
			if errWrite != nil {
				return errWrite
			}
			if writeCount != readCount {
				return io.ErrShortWrite
			}
		}
		if errRead != nil {
			if errRead == io.EOF {
				return nil
			}
			return errRead
		}
	}
}
// DecodeCopy reads encoded data from secureSocket until the stream is
// exhausted, decodes each chunk with secureSocket.Cipher and writes the
// plain result to dst.
//
// It returns nil once the source reports io.EOF. Otherwise it returns the
// first read or write error, or io.ErrShortWrite if dst reports writing
// fewer bytes than it was given without returning an error.
//
// Bytes that Read delivers together with an error (for example the final
// chunk alongside io.EOF) are decoded and forwarded before the error is acted
// on, as the io.Reader contract requires.
func (secureSocket *SecureTCPConn) DecodeCopy(dst io.Writer) error {
	buf := make([]byte, bufSize)
	for {
		readCount, errRead := secureSocket.Read(buf)
		if readCount > 0 {
			chunk := buf[:readCount]
			// Decode in place: buf is scratch space owned by this loop.
			secureSocket.Cipher.Decode(chunk)
			writeCount, errWrite := dst.Write(chunk)
			if errWrite != nil {
				return errWrite
			}
			if writeCount != readCount {
				return io.ErrShortWrite
			}
		}
		if errRead != nil {
			if errRead == io.EOF {
				return nil
			}
			return errRead
		}
	}
}
// DialEncryptedTCP opens a TCP connection to raddr (see net.DialTCP) and
// returns it wrapped in a SecureTCPConn whose Cipher is cipher.
//
// SO_LINGER is set to 0 on the connection, so closing it discards any data
// that has not been sent yet instead of waiting for it to be delivered.
func DialEncryptedTCP(raddr *net.TCPAddr, cipher *Cipher) (*SecureTCPConn, error) {
	conn, dialErr := net.DialTCP("tcp", nil, raddr)
	if dialErr != nil {
		return nil, dialErr
	}
	// Best effort: the SetLinger result is deliberately not checked.
	_ = conn.SetLinger(0)
	secure := &SecureTCPConn{Cipher: cipher}
	secure.ReadWriteCloser = conn
	return secure, nil
}
// ListenEncryptedTCP listens on laddr (see net.ListenTCP) and, for every
// accepted connection, calls handleConn in a new goroutine with the
// connection wrapped in a SecureTCPConn whose Cipher is cipher.
//
// didListen, if non-nil, is called in its own goroutine with the address the
// listener actually bound (useful when laddr uses port 0). It runs
// separately because it may block.
//
// The function only returns when the listener cannot be created. Accept
// errors are logged and the loop retries. Consecutive failures back off
// exponentially (5ms doubling up to 1s, reset after a successful accept) so
// that a persistent condition such as running out of file descriptors does
// not turn the loop into a busy spin that floods the log.
//
// Each accepted connection has SO_LINGER set to 0, so closing it discards
// any data that has not been sent yet.
func ListenEncryptedTCP(laddr *net.TCPAddr, cipher *Cipher, handleConn func(localConn *SecureTCPConn), didListen func(listenAddr *net.TCPAddr)) error {
	listener, err := net.ListenTCP("tcp", laddr)
	if err != nil {
		return err
	}
	defer listener.Close()

	if didListen != nil {
		// didListen may block; don't let it delay the accept loop.
		go didListen(listener.Addr().(*net.TCPAddr))
	}

	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)
	var backoff time.Duration
	for {
		localConn, err := listener.AcceptTCP()
		if err != nil {
			log.Println(err)
			if backoff == 0 {
				backoff = minAcceptBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}
			time.Sleep(backoff)
			continue
		}
		backoff = 0

		// Best effort: the SetLinger result is deliberately not checked.
		_ = localConn.SetLinger(0)
		go handleConn(&SecureTCPConn{
			ReadWriteCloser: localConn,
			Cipher:          cipher,
		})
	}
}