forked from quic-go/webtransport-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.go
217 lines (194 loc) · 5.5 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
package webtransport
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"sync"
"time"
"unicode/utf8"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/http3"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// Protocol constants used by the WebTransport (draft-02) handshake and by the
// HTTP/3 stream hijackers installed in init.
const (
	// webTransportDraftOfferHeaderKey is the request header through which the
	// client offers draft-02; Upgrade requires exactly one value, "1".
	webTransportDraftOfferHeaderKey = "Sec-Webtransport-Http3-Draft02"
	// webTransportDraftHeaderKey and webTransportDraftHeaderValue are added to
	// the successful Upgrade response to announce the draft version spoken.
	webTransportDraftHeaderKey   = "Sec-Webtransport-Http3-Draft"
	webTransportDraftHeaderValue = "draft02"

	// webTransportFrameType is the frame type that makes the StreamHijacker
	// claim a bidirectional stream for WebTransport.
	webTransportFrameType = 0x41
	// webTransportUniStreamType is the stream type that makes the
	// UniStreamHijacker claim a unidirectional stream for WebTransport.
	webTransportUniStreamType = 0x54
)
// Server is a WebTransport server built on top of an HTTP/3 server (H3).
//
// On first use (Serve, ListenAndServe, ListenAndServeTLS) the Server configures
// H3 for WebTransport. It adds the WebTransport SETTINGS entry, enables datagrams,
// and installs its own StreamHijacker and UniStreamHijacker. H3.StreamHijacker
// must therefore be left nil by the caller, otherwise initialization fails.
type Server struct {
	// H3 is the underlying HTTP/3 server that carries the WebTransport sessions.
	H3 http3.Server
	// StreamReorderingTimeout is the time an incoming WebTransport stream that cannot be associated
	// with a session is buffered.
	// This can happen if the CONNECT request (that creates a new session) is reordered, and arrives
	// after the first WebTransport stream(s) for that session.
	// A zero value means the default of 5 seconds.
	StreamReorderingTimeout time.Duration
	// CheckOrigin is used to validate the request origin, thereby preventing cross-site request forgery.
	// CheckOrigin returns true if the request Origin header is acceptable.
	// If unset, a safe default is used: If the Origin header is set, it is checked that it
	// matches the request's Host header.
	CheckOrigin func(r *http.Request) bool
	ctx context.Context // cancelled when Close is called
	ctxCancel context.CancelFunc
	// refCount is waited on by Close after H3 has been closed.
	// NOTE(review): the matching Add/Done calls are not in this part of the file — confirm what they track.
	refCount sync.WaitGroup
	// initOnce guards init; initErr stores its result for initialize.
	initOnce sync.Once
	initErr error
	// conns receives hijacked streams (AddStream / AddUniStream) and new sessions (AddSession).
	conns *sessionManager
}
// initialize runs init at most once (guarded by s.initOnce) and returns the
// stored result s.initErr. Every caller after the first gets the same error,
// including nil.
func (s *Server) initialize() error {
	s.initOnce.Do(func() { s.initErr = s.init() })
	return s.initErr
}
// init performs the one-time setup of s:
//   - it creates the server-wide context (cancelled by Close);
//   - it creates the session manager, using a 5 second reordering timeout when
//     StreamReorderingTimeout is zero;
//   - it installs checkSameOrigin when CheckOrigin is nil;
//   - it configures s.H3 for WebTransport: the settingsEnableWebtransport
//     SETTINGS entry, datagrams, and the bidirectional and unidirectional
//     stream hijackers that forward WebTransport streams to the session manager.
//
// It returns an error if s.H3.StreamHijacker was already set by the caller.
// In that case s.H3 is left unmodified.
func (s *Server) init() error {
	s.ctx, s.ctxCancel = context.WithCancel(context.Background())
	timeout := s.StreamReorderingTimeout
	if timeout == 0 {
		timeout = 5 * time.Second
	}
	s.conns = newSessionManager(timeout)
	if s.CheckOrigin == nil {
		s.CheckOrigin = checkSameOrigin
	}

	// Reject a conflicting configuration before touching s.H3, so a failed
	// initialization does not leave the caller's http3.Server half-configured.
	if s.H3.StreamHijacker != nil {
		return errors.New("StreamHijacker already set")
	}

	// configure the http3.Server
	if s.H3.AdditionalSettings == nil {
		s.H3.AdditionalSettings = make(map[uint64]uint64)
	}
	s.H3.AdditionalSettings[settingsEnableWebtransport] = 1
	s.H3.EnableDatagrams = true
	s.H3.StreamHijacker = func(ft http3.FrameType, qconn quic.Connection, str quic.Stream, err error) (bool /* hijacked */, error) {
		// NOTE(review): errors recognised by isWebTransportError cause the
		// stream to be claimed without being handed to the session manager.
		if isWebTransportError(err) {
			return true, nil
		}
		if ft != webTransportFrameType {
			return false, nil
		}
		// The frame type is followed by the session ID, encoded as a QUIC varint.
		id, readErr := quicvarint.Read(quicvarint.NewReader(str))
		if readErr != nil {
			if isWebTransportError(readErr) {
				return true, nil
			}
			return false, readErr
		}
		s.conns.AddStream(qconn, str, sessionID(id))
		return true, nil
	}
	// NOTE(review): unlike StreamHijacker, a caller-provided UniStreamHijacker
	// is silently replaced here — confirm this is intended.
	s.H3.UniStreamHijacker = func(st http3.StreamType, qconn quic.Connection, str quic.ReceiveStream, err error) (hijacked bool) {
		// Claim WebTransport uni streams, and also streams whose type read
		// failed with an error that isWebTransportError recognises.
		if st != webTransportUniStreamType && !isWebTransportError(err) {
			return false
		}
		s.conns.AddUniStream(qconn, str)
		return true
	}
	return nil
}
// Serve initializes the WebTransport server (once) and then serves HTTP/3 on
// conn via s.H3.Serve. An initialization error is returned without serving.
func (s *Server) Serve(conn net.PacketConn) error {
	initErr := s.initialize()
	if initErr == nil {
		return s.H3.Serve(conn)
	}
	return initErr
}
// ListenAndServe initializes the WebTransport server (once) and then delegates
// to s.H3.ListenAndServe. An initialization error is returned without listening.
func (s *Server) ListenAndServe() error {
	initErr := s.initialize()
	if initErr == nil {
		return s.H3.ListenAndServe()
	}
	return initErr
}
// ListenAndServeTLS initializes the WebTransport server (once) and then
// delegates to s.H3.ListenAndServeTLS with the given certificate and key files.
// An initialization error is returned without listening.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
	initErr := s.initialize()
	if initErr == nil {
		return s.H3.ListenAndServeTLS(certFile, keyFile)
	}
	return initErr
}
// Close shuts the server down. It cancels the server context, closes the
// session manager and then the underlying HTTP/3 server, and finally blocks
// until s.refCount drops to zero. It returns the error from s.H3.Close.
// Close may be called even if the server was never started.
func (s *Server) Close() error {
	// A no-op Do waits for a concurrently running init to finish. If init never
	// ran, the no-op marks the Once as done so init will not run later.
	// Either way, the fields read below are settled afterwards.
	s.initOnce.Do(func() {})

	if cancel := s.ctxCancel; cancel != nil {
		cancel()
	}
	if mgr := s.conns; mgr != nil {
		mgr.Close()
	}
	closeErr := s.H3.Close()
	s.refCount.Wait()
	return closeErr
}
// Upgrade completes the WebTransport handshake for an incoming extended
// CONNECT request and returns the resulting Session.
//
// The request is accepted only if all of the following hold:
//   - it uses the CONNECT method;
//   - its protocol equals protocolHeader;
//   - it carries exactly one Sec-Webtransport-Http3-Draft02 header with the value "1";
//   - it passes s.CheckOrigin.
//
// On success a 200 response that advertises the draft version is written and
// flushed, and the request stream is registered with the session manager.
// Every error is returned before anything is written to w, so the caller is
// still free to send its own error response.
func (s *Server) Upgrade(w http.ResponseWriter, r *http.Request) (*Session, error) {
	if r.Method != http.MethodConnect {
		return nil, fmt.Errorf("expected CONNECT request, got %s", r.Method)
	}
	if r.Proto != protocolHeader {
		return nil, fmt.Errorf("unexpected protocol: %s", r.Proto)
	}
	if v, ok := r.Header[webTransportDraftOfferHeaderKey]; !ok || len(v) != 1 || v[0] != "1" {
		return nil, fmt.Errorf("missing or invalid %s header", webTransportDraftOfferHeaderKey)
	}
	if !s.CheckOrigin(r) {
		return nil, errors.New("webtransport: request origin not allowed")
	}

	// Resolve every interface we depend on before committing to a 200
	// response. Once the status has been sent the client treats the session as
	// established, so a failure afterwards would leave it dangling.
	httpStreamer, ok := r.Body.(http3.HTTPStreamer)
	if !ok { // should never happen, unless quic-go changed the API
		return nil, errors.New("failed to take over HTTP stream")
	}
	hijacker, ok := w.(http3.Hijacker)
	if !ok { // should never happen, unless quic-go changed the API
		return nil, errors.New("failed to hijack")
	}
	flusher, ok := w.(http.Flusher)
	if !ok { // previously an unchecked assertion that would panic here
		return nil, errors.New("response writer does not implement http.Flusher")
	}

	w.Header().Add(webTransportDraftHeaderKey, webTransportDraftHeaderValue)
	w.WriteHeader(http.StatusOK)
	flusher.Flush()

	// The request stream's ID doubles as the session ID.
	str := httpStreamer.HTTPStream()
	sID := sessionID(str.StreamID())
	return s.conns.AddSession(hijacker.StreamCreator(), sID, str), nil
}
// checkSameOrigin is the default CheckOrigin policy (copied from
// https://github.com/gorilla/websocket). A request without an Origin header is
// accepted. Otherwise the Origin is parsed as a URL, and its host (including
// any port) must equal r.Host, ignoring ASCII case. An Origin that fails to
// parse is rejected.
func checkSameOrigin(r *http.Request) bool {
	origin := r.Header.Get("Origin")
	if len(origin) == 0 {
		return true
	}
	parsed, err := url.Parse(origin)
	if err == nil {
		return equalASCIIFold(parsed.Host, r.Host)
	}
	return false
}
// equalASCIIFold reports whether s and t are equal under ASCII case folding.
// The letters A-Z match their lowercase counterparts; every other character,
// including non-ASCII runes, must match exactly.
//
// Adapted from https://github.com/gorilla/websocket. utf8.DecodeRuneInString
// maps every invalid byte (and a literal U+FFFD) to utf8.RuneError, so
// comparing decoded runes alone would treat distinct invalid bytes such as
// "\xff" and "\xfe" as equal. Such positions are compared by their raw bytes
// instead.
func equalASCIIFold(s, t string) bool {
	for s != "" && t != "" {
		sr, sn := utf8.DecodeRuneInString(s)
		tr, tn := utf8.DecodeRuneInString(t)
		sEnc, tEnc := s[:sn], t[:tn]
		s, t = s[sn:], t[tn:]
		if sr == utf8.RuneError || tr == utf8.RuneError {
			// The decoded rune is ambiguous here; compare the encodings.
			if sEnc != tEnc {
				return false
			}
			continue
		}
		if sr == tr {
			continue
		}
		if 'A' <= sr && sr <= 'Z' {
			sr += 'a' - 'A'
		}
		if 'A' <= tr && tr <= 'Z' {
			tr += 'a' - 'A'
		}
		if sr != tr {
			return false
		}
	}
	return s == t
}