diff --git a/main.go b/main.go index 1671f4f..5e8a7af 100644 --- a/main.go +++ b/main.go @@ -133,18 +133,27 @@ func proxy(w http.ResponseWriter, r *http.Request) { } } + setTimeHeader(w, "Last-Modified", attrs.Updated) + setStrHeader(w, "Content-Type", attrs.ContentType) + setStrHeader(w, "Content-Language", attrs.ContentLanguage) + setStrHeader(w, "Cache-Control", attrs.CacheControl) + setStrHeader(w, "Content-Disposition", attrs.ContentDisposition) + gzipAcceptable := strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") - objr, err := client.Bucket(attrs.Bucket).Object(attrs.Name).ReadCompressed(gzipAcceptable).NewReader(r.Context()) + + var objr *storage.Reader + if rangeHeader := r.Header.Get("Range"); rangeHeader != "" { + objH := client.Bucket(attrs.Bucket).Object(attrs.Name).ReadCompressed(gzipAcceptable) + serveRange(w, r, attrs.Updated, attrs.Size, objH) + return + } else { + objr, err = client.Bucket(attrs.Bucket).Object(attrs.Name).ReadCompressed(gzipAcceptable).NewReader(r.Context()) + } if err != nil { handleError(w, err) return } - setTimeHeader(w, "Last-Modified", attrs.Updated) - setStrHeader(w, "Content-Type", attrs.ContentType) - setStrHeader(w, "Content-Language", attrs.ContentLanguage) - setStrHeader(w, "Cache-Control", attrs.CacheControl) setStrHeader(w, "Content-Encoding", objr.Attrs.ContentEncoding) - setStrHeader(w, "Content-Disposition", attrs.ContentDisposition) setIntHeader(w, "Content-Length", objr.Attrs.Size) io.Copy(w, objr) } diff --git a/range.go b/range.go new file mode 100644 index 0000000..823e364 --- /dev/null +++ b/range.go @@ -0,0 +1,510 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP file system request handler + +package main + +import ( + "cloud.google.com/go/storage" + "context" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/textproto" + "strconv" + "strings" + "time" +) + +// errSeeker is returned by ServeContent's sizeFunc when the content +// doesn't seek properly. The underlying Seeker's error text isn't +// included in the sizeFunc reply so it's not sent over HTTP to end +// users. +var errSeeker = errors.New("seeker can't seek") + +// errNoOverlap is returned by serveContent's parseRange if first-byte-pos of +// all of the byte-range-spec values is greater than the content size. +var errNoOverlap = errors.New("invalid range: failed to overlap") + +// if name is empty, filename is unknown. (used for mime type, before sniffing) +// if modtime.IsZero(), modtime is unknown. +// content must be seeked to the beginning of the file. +// The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response. +func serveRange(w http.ResponseWriter, r *http.Request, modtime time.Time, size int64, objH *storage.ObjectHandle) { + setLastModified(w, modtime) + done, rangeReq := checkPreconditions(w, r, modtime) + if done { + return + } + + code := http.StatusOK + + if size < 0 { + // Should never happen but just to be sure + http.Error(w, "negative content size", http.StatusInternalServerError) + return + } + + ctype := w.Header().Get("Content-Type") + // handle Content-Range header. + sendSize := size + var sendContent io.Reader + ranges, err := parseRange(rangeReq, size) + switch err { + case nil: + case errNoOverlap: + if size == 0 { + // Some clients add a Range header to all requests to + // limit the size of the response. If the file is empty, + // ignore the range header and respond with a 200 rather + // than a 416. + ranges = nil + break + } + w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + fallthrough + default: + http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable) + return + } + + if sumRangesSize(ranges) > size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 7233, Section 4.1: + // "If a single part is being transferred, the server + // generating the 206 response MUST generate a + // Content-Range header field, describing what range + // of the selected representation is enclosed, and a + // payload consisting of the range. + // ... + // A server MUST NOT generate a multipart response to + // a request for a single range, since a client that + // does not request multiple parts might not support + // multipart responses." + ra := ranges[0] + if sendContent, err = objH.NewRangeReader(context.Background(), ra.start, ra.length); err != nil { + http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable) + return + } + sendSize = ra.length + code = http.StatusPartialContent + w.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + sendSize = rangesMIMESize(ranges, ctype, size) + code = http.StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. + go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return + } + var content *storage.Reader + if content, err = objH.NewRangeReader(context.Background(), ra.start, ra.length); err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.CopyN(part, content, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + mw.Close() + pw.Close() + }() + } + + w.Header().Set("Accept-Ranges", "bytes") + if w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) + } + + w.WriteHeader(code) + + if r.Method != "HEAD" { + io.CopyN(w, sendContent, sendSize) + } +} + +// scanETag determines if a syntactically valid ETag is present at s. If so, +// the ETag and remaining text after consuming ETag is returned. Otherwise, +// it returns "", "". +func scanETag(s string) (etag string, remain string) { + s = textproto.TrimString(s) + start := 0 + if strings.HasPrefix(s, "W/") { + start = 2 + } + if len(s[start:]) < 2 || s[start] != '"' { + return "", "" + } + // ETag is either W/"text" or "text". + // See RFC 7232 2.3. + for i := start + 1; i < len(s); i++ { + c := s[i] + switch { + // Character values allowed in ETags. + case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: + case c == '"': + return s[:i+1], s[i+1:] + default: + return "", "" + } + } + return "", "" +} + +// etagStrongMatch reports whether a and b match using strong ETag comparison. +// Assumes a and b are valid ETags. +func etagStrongMatch(a, b string) bool { + return a == b && a != "" && a[0] == '"' +} + +// etagWeakMatch reports whether a and b match using weak ETag comparison. +// Assumes a and b are valid ETags. +func etagWeakMatch(a, b string) bool { + return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/") +} + +// condResult is the result of an HTTP request precondition check. +// See https://tools.ietf.org/html/rfc7232 section 3. +type condResult int + +const ( + condNone condResult = iota + condTrue + condFalse +) + +func checkIfMatch(w http.ResponseWriter, r *http.Request) condResult { + im := r.Header.Get("If-Match") + if im == "" { + return condNone + } + for { + im = textproto.TrimString(im) + if len(im) == 0 { + break + } + if im[0] == ',' { + im = im[1:] + continue + } + if im[0] == '*' { + return condTrue + } + etag, remain := scanETag(im) + if etag == "" { + break + } + if etagStrongMatch(etag, w.Header().Get("Etag")) { + return condTrue + } + im = remain + } + + return condFalse +} + +func checkIfUnmodifiedSince(r *http.Request, modtime time.Time) condResult { + ius := r.Header.Get("If-Unmodified-Since") + if ius == "" || isZeroTime(modtime) { + return condNone + } + t, err := http.ParseTime(ius) + if err != nil { + return condNone + } + + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if ret := modtime.Compare(t); ret <= 0 { + return condTrue + } + return condFalse +} + +func checkIfNoneMatch(w http.ResponseWriter, r *http.Request) condResult { + inm := r.Header.Get("If-None-Match") + if inm == "" { + return condNone + } + buf := inm + for { + buf = textproto.TrimString(buf) + if len(buf) == 0 { + break + } + if buf[0] == ',' { + buf = buf[1:] + continue + } + if buf[0] == '*' { + return condFalse + } + etag, remain := scanETag(buf) + if etag == "" { + break + } + if etagWeakMatch(etag, w.Header().Get("Etag")) { + return condFalse + } + buf = remain + } + return condTrue +} + +func checkIfModifiedSince(r *http.Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ims := r.Header.Get("If-Modified-Since") + if ims == "" || isZeroTime(modtime) { + return condNone + } + t, err := http.ParseTime(ims) + if err != nil { + return condNone + } + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if ret := modtime.Compare(t); ret <= 0 { + return condFalse + } + return condTrue +} + +func checkIfRange(w http.ResponseWriter, r *http.Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ir := r.Header.Get("If-Range") + if ir == "" { + return condNone + } + etag, _ := scanETag(ir) + if etag != "" { + if etagStrongMatch(etag, w.Header().Get("Etag")) { + return condTrue + } else { + return condFalse + } + } + // The If-Range value is typically the ETag value, but it may also be + // the modtime date. See golang.org/issue/8367. + if modtime.IsZero() { + return condFalse + } + t, err := http.ParseTime(ir) + if err != nil { + return condFalse + } + if t.Unix() == modtime.Unix() { + return condTrue + } + return condFalse +} + +var unixEpochTime = time.Unix(0, 0) + +// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0). +func isZeroTime(t time.Time) bool { + return t.IsZero() || t.Equal(unixEpochTime) +} + +func setLastModified(w http.ResponseWriter, modtime time.Time) { + if !isZeroTime(modtime) { + w.Header().Set("Last-Modified", modtime.UTC().Format(http.TimeFormat)) + } +} + +func writeNotModified(w http.ResponseWriter) { + // RFC 7232 section 4.1: + // a sender SHOULD NOT generate representation metadata other than the + // above listed fields unless said metadata exists for the purpose of + // guiding cache updates (e.g., Last-Modified might be useful if the + // response does not have an ETag field). + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + delete(h, "Content-Encoding") + if h.Get("Etag") != "" { + delete(h, "Last-Modified") + } + w.WriteHeader(http.StatusNotModified) +} + +// checkPreconditions evaluates request preconditions and reports whether a precondition +// resulted in sending StatusNotModified or StatusPreconditionFailed. +func checkPreconditions(w http.ResponseWriter, r *http.Request, modtime time.Time) (done bool, rangeHeader string) { + // This function carefully follows RFC 7232 section 6. + ch := checkIfMatch(w, r) + if ch == condNone { + ch = checkIfUnmodifiedSince(r, modtime) + } + if ch == condFalse { + w.WriteHeader(http.StatusPreconditionFailed) + return true, "" + } + switch checkIfNoneMatch(w, r) { + case condFalse: + if r.Method == "GET" || r.Method == "HEAD" { + writeNotModified(w) + return true, "" + } else { + w.WriteHeader(http.StatusPreconditionFailed) + return true, "" + } + case condNone: + if checkIfModifiedSince(r, modtime) == condFalse { + writeNotModified(w) + return true, "" + } + } + + rangeHeader = r.Header.Get("Range") + if rangeHeader != "" && checkIfRange(w, r, modtime) == condFalse { + rangeHeader = "" + } + return false, rangeHeader +} + +// httpRange specifies the byte range to be sent to the client. +type httpRange struct { + start, length int64 +} + +func (r httpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + +// parseRange parses a Range header string as per RFC 7233. +// errNoOverlap is returned if none of the ranges overlap. +func parseRange(s string, size int64) ([]httpRange, error) { + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, errors.New("invalid range") + } + var ranges []httpRange + noOverlap := false + for _, ra := range strings.Split(s[len(b):], ",") { + ra = textproto.TrimString(ra) + if ra == "" { + continue + } + start, end, ok := strings.Cut(ra, "-") + if !ok { + return nil, errors.New("invalid range") + } + start, end = textproto.TrimString(start), textproto.TrimString(end) + var r httpRange + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file, + // and we are dealing with + // which has to be a non-negative integer as per + // RFC 7233 Section 2.1 "Byte-Ranges". + if end == "" || end[0] == '-' { + return nil, errors.New("invalid range") + } + i, err := strconv.ParseInt(end, 10, 64) + if i < 0 || err != nil { + return nil, errors.New("invalid range") + } + if i > size { + i = size + } + r.start = size - i + r.length = size - r.start + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i < 0 { + return nil, errors.New("invalid range") + } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + noOverlap = true + continue + } + r.start = i + if end == "" { + // If no end is specified, range extends to end of the file. + r.length = size - r.start + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r.start > i { + return nil, errors.New("invalid range") + } + if i >= size { + i = size - 1 + } + r.length = i - r.start + 1 + } + } + ranges = append(ranges, r) + } + if noOverlap && len(ranges) == 0 { + // The specified ranges did not overlap with the content. + return nil, errNoOverlap + } + return ranges, nil +} + +// countingWriter counts how many bytes have been written to it. +type countingWriter int64 + +func (w *countingWriter) Write(p []byte) (n int, err error) { + *w += countingWriter(len(p)) + return len(p), nil +} + +// rangesMIMESize returns the number of bytes it takes to encode the +// provided ranges as a multipart response. +func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { + var w countingWriter + mw := multipart.NewWriter(&w) + for _, ra := range ranges { + mw.CreatePart(ra.mimeHeader(contentType, contentSize)) + encSize += ra.length + } + mw.Close() + encSize += int64(w) + return +} + +func sumRangesSize(ranges []httpRange) (size int64) { + for _, ra := range ranges { + size += ra.length + } + return +}