-
Notifications
You must be signed in to change notification settings - Fork 0
/
staticFileServerHandler.go
200 lines (170 loc) · 5.69 KB
/
staticFileServerHandler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
package spaserve
import (
"errors"
"io/fs"
"log/slog"
"net/http"
"os"
"path"
"strings"
"github.com/psanford/memfs"
)
// StaticFilesHandler is an http.Handler that serves a single-page application
// from an in-memory copy of a filesystem. Construct it with
// NewStaticFilesHandler; the zero value has no file server and is not usable.
type StaticFilesHandler struct {
// opts holds the resolved configuration (defaults plus applied options).
opts staticFilesHandlerOpts
// fileServer is an http.FileServer backed by mfilesys.
fileServer http.Handler
// mfilesys is the in-memory filesystem that is served; ServeHTTP also opens
// paths in it to choose between serving the file, a 404, or index.html.
mfilesys *memfs.FS
// logger is the wrapper built by newLogger from opts.logger.
logger *servespaLogger
// muxErrHandler writes an error response (404/500) for a status code.
muxErrHandler func(int, http.ResponseWriter, *http.Request)
}
// staticFilesHandlerOpts collects the configuration for NewStaticFilesHandler.
// It starts from defaultStaticFilesHandlerOpts and is modified by the With*
// option functions.
type staticFilesHandlerOpts struct {
// ns is the namespace passed to InjectWebEnv (set by WithInjectWebEnv).
ns string
// basePath is the URL prefix stripped from request paths before file lookup;
// WithBasePath ensures it starts and ends with "/".
basePath string
// logger is handed to newLogger; nil unless WithLogger was used.
logger *slog.Logger
// muxErrHandler optionally supplies a custom error response per status code.
muxErrHandler func(int) http.Handler
// webEnv, when non-nil, is injected into the served files via InjectWebEnv.
webEnv any
}
// staticFilesHandlerFunc is a functional option: it receives the current
// options by value and returns the modified copy.
type staticFilesHandlerFunc func(staticFilesHandlerOpts) staticFilesHandlerOpts
// defaultStaticFilesHandlerOpts is the starting configuration for
// NewStaticFilesHandler: web-env namespace "APP_ENV" and base path "/".
// The logger, custom error handler and injected web environment are left at
// their zero values (nil), i.e. not configured.
var defaultStaticFilesHandlerOpts = staticFilesHandlerOpts{
	basePath: "/",
	ns:       "APP_ENV",
}
// WithLogger sets the *slog.Logger used by the static file server.
//
// When this option is not supplied the configured logger stays nil and the
// fallback is whatever newLogger chooses for a nil logger (presumably slog's
// default logger — confirm in newLogger).
func WithLogger(logger *slog.Logger) staticFilesHandlerFunc {
	setLogger := func(o staticFilesHandlerOpts) staticFilesHandlerOpts {
		o.logger = logger
		return o
	}
	return setLogger
}
// WithBasePath sets the URL prefix under which the SPA is mounted. The prefix
// is stripped from request paths before files are looked up.
//
// The value is normalized with path.Clean and always stored rooted and with a
// trailing slash, so "app", "/app", "/app/" and "//app//" all become "/app/".
// Normalizing the same way request paths are cleaned guarantees the prefix can
// actually match (e.g. "/a/../b" becomes "/b/"). An empty string falls back to
// the default base path.
func WithBasePath(basePath string) staticFilesHandlerFunc {
	if basePath == "" {
		basePath = defaultStaticFilesHandlerOpts.basePath
	}

	// Rooting before Clean also makes ".." unable to climb above "/".
	normalized := path.Clean("/" + basePath)
	// The trailing slash makes prefix trimming respect segment boundaries
	// ("/app/" must not match "/application").
	if normalized != "/" {
		normalized += "/"
	}

	return func(c staticFilesHandlerOpts) staticFilesHandlerOpts {
		c.basePath = normalized
		return c
	}
}
// WithMuxErrorHandler installs a factory for custom error responses.
//
// handler is called with the HTTP status code (the static file server uses
// http.StatusNotFound and http.StatusInternalServerError) and the returned
// http.Handler writes the response. Without this option a plain-text
// http.Error response is written instead.
func WithMuxErrorHandler(handler func(int) http.Handler) staticFilesHandlerFunc {
	setErrHandler := func(o staticFilesHandlerOpts) staticFilesHandlerOpts {
		o.muxErrHandler = handler
		return o
	}
	return setErrHandler
}
// WithInjectWebEnv makes the static file server inject env into the served
// files (via InjectWebEnv) when the handler is built.
//
//   - env: the value to inject; its json struct tags drive the marshalling.
//   - namespace: the name the environment is exposed under; an empty string
//     selects the default namespace ("APP_ENV").
func WithInjectWebEnv(env any, namespace string) staticFilesHandlerFunc {
	ns := namespace
	if ns == "" {
		ns = defaultStaticFilesHandlerOpts.ns
	}

	return func(o staticFilesHandlerOpts) staticFilesHandlerOpts {
		o.ns, o.webEnv = ns, env
		return o
	}
}
// NewStaticFilesHandler creates an http.Handler that serves a single-page
// application from filesys.
//
// The contents of filesys are copied into an in-memory filesystem when the
// handler is built: via InjectWebEnv when WithInjectWebEnv supplied a non-nil
// environment, otherwise via CopyFileSys. The handler serves only from that
// copy. It serves index.html for the root path and for missing paths without a
// file extension (client-side routes), and responds 404 for missing paths that
// look like static files (they have an extension).
//
//   - filesys: the filesystem to serve; must not be nil.
//   - fn: options applied in order on top of the defaults (WithLogger,
//     WithBasePath, WithMuxErrorHandler, WithInjectWebEnv); nil options are
//     ignored.
//
// It returns an error if filesys is nil or if copying/injecting fails.
func NewStaticFilesHandler(filesys fs.FS, fn ...staticFilesHandlerFunc) (http.Handler, error) {
	// Fail fast with a clear error instead of a nil-interface panic while
	// walking the filesystem.
	if filesys == nil {
		return nil, errors.New("spaserve: filesys must not be nil")
	}

	opts := defaultStaticFilesHandlerOpts
	for _, apply := range fn {
		if apply != nil {
			opts = apply(opts)
		}
	}

	var (
		mfilesys *memfs.FS
		err      error
	)
	if opts.webEnv != nil {
		mfilesys, err = InjectWebEnv(filesys, opts.webEnv, opts.ns)
	} else {
		mfilesys, err = CopyFileSys(filesys, nil)
	}
	if err != nil {
		return nil, err
	}

	return &StaticFilesHandler{
		opts:          opts,
		mfilesys:      mfilesys,
		fileServer:    http.FileServer(http.FS(mfilesys)),
		logger:        newLogger(opts.logger),
		muxErrHandler: newMuxErrorHandler(opts.muxErrHandler),
	}, nil
}
// ServeHTTP serves a request for the single-page application.
//
// The request path is rooted, cleaned and stripped of the configured base
// path, then:
//   - the SPA root ("/" or "/index.html") is served as index.html;
//   - a path that exists in the in-memory filesystem is served as-is;
//   - a missing path with a file extension (e.g. "/app.js") gets a 404 from
//     the configured error handler;
//   - any other missing path is treated as a client-side route and
//     index.html is served so the SPA router can handle it;
//   - any other error opening the path yields a 500.
//
// The caller's *http.Request is not modified; a shallow copy with its own URL
// carries the rewritten path to the file server and the error handler.
func (h *StaticFilesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	cleanedPath := staticRelPath(r.URL.Path, h.opts.basePath)
	h.logger.logContext(ctx, slog.LevelDebug, "request", slog.String("cleanedPath", cleanedPath))

	// net/http handlers must not modify the request they are given, so rewrite
	// the path on a shallow copy with its own URL (as http.StripPrefix does).
	// RawPath is cleared because it would no longer match Path.
	req := r.WithContext(ctx)
	u := *r.URL
	u.Path = "/" + cleanedPath
	u.RawPath = ""
	req.URL = &u

	// http.FileServer answers ".../index.html" with a redirect to "./";
	// serve the root directly instead to avoid the extra round trip.
	if req.URL.Path == "/index.html" {
		req.URL.Path = "/"
	}

	if req.URL.Path != "/" {
		file, err := h.mfilesys.Open(cleanedPath)
		if err == nil {
			// Only probing for existence; a close error is irrelevant here.
			_ = file.Close()
		}

		switch {
		case err == nil:
			// The path exists: let the file server serve it.
		case !errors.Is(err, os.ErrNotExist):
			// An empty-key group is inlined by slog handlers, so this single
			// Attr carries both the path and the underlying error.
			h.logger.logContext(ctx, slog.LevelError, "could not open file",
				slog.Group("", slog.String("cleanedPath", cleanedPath), slog.Any("error", err)))
			h.muxErrHandler(http.StatusInternalServerError, w, req)
			return
		case path.Ext(cleanedPath) != "":
			// A missing path with an extension is a genuine asset request.
			h.logger.logContext(ctx, slog.LevelDebug, "not found, static file", slog.String("cleanedPath", cleanedPath))
			h.muxErrHandler(http.StatusNotFound, w, req)
			return
		default:
			// A missing extension-less path is a client-side route: serve
			// index.html and let the SPA handle it.
			h.logger.logContext(ctx, slog.LevelDebug, "not found, serve index", slog.String("cleanedPath", cleanedPath))
			req.URL.Path = "/"
		}
	}

	h.fileServer.ServeHTTP(w, req)
}

// staticRelPath converts a request URL path into a path relative to the root
// of the served filesystem, with no leading or trailing slash ("" for the
// root). basePath is stripped when present; it is expected to start and end
// with "/" (WithBasePath and the default guarantee this).
func staticRelPath(urlPath, basePath string) string {
	// Rooting before Clean guarantees ".." cannot climb above "/".
	cleaned := path.Clean("/" + urlPath)
	// Clean drops trailing slashes while basePath keeps one; compare with a
	// trailing slash so a request for "/app" matches base path "/app/".
	rel, _ := strings.CutPrefix(cleaned+"/", basePath)
	return strings.Trim(rel, "/")
}
// newMuxErrorHandler adapts an optional user-supplied error-handler factory
// into a function that writes an error response for a status code.
//
// When muxErrHandler is nil, or returns a nil http.Handler for a particular
// status code (e.g. because it only customizes 404s), the response falls back
// to http.Error with the standard status text such as "Not Found".
func newMuxErrorHandler(muxErrHandler func(int) http.Handler) func(int, http.ResponseWriter, *http.Request) {
	return func(statusCode int, w http.ResponseWriter, r *http.Request) {
		if muxErrHandler != nil {
			// Guard against a nil handler, which would otherwise panic here.
			if handler := muxErrHandler(statusCode); handler != nil {
				handler.ServeHTTP(w, r)
				return
			}
		}
		http.Error(w, http.StatusText(statusCode), statusCode)
	}
}