From c415e654c7c47cc6af120fbb471a47d66b166d7e Mon Sep 17 00:00:00 2001 From: Borengar Date: Sat, 16 Oct 2021 10:32:27 +0200 Subject: [PATCH] cache get_follow and get_block in stats get prediction (#64) * cache get_follow and get_block in stats get prediction * create class for cached responses * use buffered channel in stats get prediction queue * move follow and block cache out of stats get prediction * use generic handle func for cached requests --- UNSAFE_SPEEDUPS.md | 12 +++++ main.go | 5 +- proxy/proxy.go | 92 +++++++++++++++++++++++++---------- proxy/response_cache.go | 33 +++++++++++++ proxy/stats_get_prediction.go | 56 +++++++++++++++------ 5 files changed, 157 insertions(+), 41 deletions(-) create mode 100644 proxy/response_cache.go diff --git a/UNSAFE_SPEEDUPS.md b/UNSAFE_SPEEDUPS.md index b4e467f..7158838 100644 --- a/UNSAFE_SPEEDUPS.md +++ b/UNSAFE_SPEEDUPS.md @@ -97,3 +97,15 @@ The first batch of recommended replays will contain random replays when you open ### Speedup 2-3 seconds saved on the initial loading on title screen. + +## `-unsafe-cache-follow` ([@Borengar](https://github.com/Borengar)) + +(v1.8.0+) + +Totsugeki caches the `/api/catalog/get_follow` and `/api/catalog/get_block` calls on the first request. The same response is returned on subsequent requests. + +These requests return your follow/block list. Cache will be invalidated if you follow/unfollow/block/unblock other players. After that the next request will be cached again. + +### Speedup + +Up to 1 second every time when you look at your follow/block list, enter the tower, open replays, open the ranking list, etc. diff --git a/main.go b/main.go index 7c4b4d6..83e0934 100755 --- a/main.go +++ b/main.go @@ -287,6 +287,7 @@ func main() { var unsafeNoNews = flag.Bool("unsafe-no-news", false, "UNSAFE: Return an empty response for news.") var unsafePredictReplay = flag.Bool("unsafe-predict-replay", false, "UNSAFE: Asynchronously precache expected get_replay calls. Needs unsafe-predict-stats-get to work.") var unsafeCacheEnv = flag.Bool("unsafe-cache-env", false, "UNSAFE: Cache first get_env call and return cached version on subsequent calls.") + var unsafeCacheFollow = flag.Bool("unsafe-cache-follow", false, "UNSAFE: Cache first get_follow and get_block calls and return cached version on subsequent calls.") var ungaBunga = flag.Bool("unga-bunga", UngaBungaMode != "", "UNSAFE: Enable all unsafe speedups for maximum speed. Please read https://github.com/optix2000/totsugeki/blob/master/UNSAFE_SPEEDUPS.md") var iKnowWhatImDoing = flag.Bool("i-know-what-im-doing", false, "UNSAFE: Suppress any UNSAFE warnings. I hope you know what you're doing...") var ver = flag.Bool("version", false, "Print the version number and exit.") @@ -329,6 +330,7 @@ func main() { *unsafeNoNews = true *unsafePredictReplay = true *unsafeCacheEnv = true + *unsafeCacheFollow = true } // Drop process priority @@ -397,6 +399,7 @@ func main() { NoNews: *unsafeNoNews, PredictReplay: *unsafePredictReplay, CacheEnv: *unsafeCacheEnv, + CacheFollow: *unsafeCacheFollow, }) fmt.Println("Started Proxy Server on port 21611.") @@ -419,7 +422,7 @@ func main() { }() } - if !*iKnowWhatImDoing && (*unsafeAsyncStatsSet || *unsafePredictStatsGet || *unsafeCacheNews || *unsafeNoNews || *unsafeCacheEnv || *unsafePredictReplay) { + if !*iKnowWhatImDoing && (*unsafeAsyncStatsSet || *unsafePredictStatsGet || *unsafeCacheNews || *unsafeNoNews || *unsafeCacheEnv || *unsafePredictReplay || *unsafeCacheFollow) { fmt.Println("WARNING: Unsafe feature used. Make sure you understand the implications: https://github.com/optix2000/totsugeki/blob/master/UNSAFE_SPEEDUPS.md") } diff --git a/proxy/proxy.go b/proxy/proxy.go index d6cd728..341ea70 100755 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -21,12 +21,9 @@ type StriveAPIProxy struct { PatchedAPIURL string statsQueue chan<- *http.Request wg sync.WaitGroup - cachedNewsReq *http.Response - cachedNewsBody []byte prediction StatsGetPrediction CacheEnv bool - cachedEnvReq *http.Response - cachedEnvBody []byte + responseCache *ResponseCache } type StriveAPIProxyOptions struct { @@ -36,6 +33,7 @@ type StriveAPIProxyOptions struct { NoNews bool PredictReplay bool CacheEnv bool + CacheFollow bool } func (s *StriveAPIProxy) proxyRequest(r *http.Request) (*http.Response, error) { @@ -73,13 +71,31 @@ func (s *StriveAPIProxy) HandleCatchall(w http.ResponseWriter, r *http.Request) } } -// GGST uses the URL from this API after initial launch so we need to intercept this. -func (s *StriveAPIProxy) HandleGetEnv(w http.ResponseWriter, r *http.Request) { - if s.CacheEnv && s.cachedEnvReq != nil { - for name, values := range s.cachedEnvReq.Header { +// Invalidate cache if certain requests are used +func (s *StriveAPIProxy) CacheInvalidationHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + switch path { + case "/api/follow/follow_user", "/api/follow/unfollow_user": + s.responseCache.RemoveResponse("catalog/get_follow") + next.ServeHTTP(w, r) + case "/api/follow/block_user", "/api/follow/unblock_user": + s.responseCache.RemoveResponse("catalog/get_block") + next.ServeHTTP(w, r) + default: + next.ServeHTTP(w, r) + } + }) +} + +// Generic handler func for cached requests +func (s *StriveAPIProxy) HandleCachedRequest(request string, w http.ResponseWriter, r *http.Request) { + if s.responseCache.ResponseExists(request) { + resp, body := s.responseCache.GetResponse(request) + for name, values := range resp.Header { w.Header()[name] = values } - w.Write(s.cachedEnvBody) + w.Write(body) } else { resp, err := s.proxyRequest(r) if err != nil { @@ -93,22 +109,23 @@ func (s *StriveAPIProxy) HandleGetEnv(w http.ResponseWriter, r *http.Request) { w.Header()[name] = values } w.WriteHeader(resp.StatusCode) - buf, err := io.ReadAll(resp.Body) + reader := io.TeeReader(resp.Body, w) // For dumping API payloads + buf, err := io.ReadAll(reader) if err != nil { fmt.Println(err) } - buf = bytes.Replace(buf, []byte(s.GGStriveAPIURL), []byte(s.PatchedAPIURL), -1) - w.Write(buf) + s.responseCache.AddResponse(request, resp, buf) } } -// UNSAFE: Cache news on first request. On every other request return the cached value. -func (s *StriveAPIProxy) HandleGetNews(w http.ResponseWriter, r *http.Request) { - if s.cachedNewsReq != nil { - for name, values := range s.cachedNewsReq.Header { +// GGST uses the URL from this API after initial launch so we need to intercept this. +func (s *StriveAPIProxy) HandleGetEnv(w http.ResponseWriter, r *http.Request) { + if s.CacheEnv && s.responseCache.ResponseExists("sys/get_env") { + resp, body := s.responseCache.GetResponse("sys/get_env") + for name, values := range resp.Header { w.Header()[name] = values } - w.Write(s.cachedNewsBody) + w.Write(body) } else { resp, err := s.proxyRequest(r) if err != nil { @@ -122,16 +139,30 @@ func (s *StriveAPIProxy) HandleGetNews(w http.ResponseWriter, r *http.Request) { w.Header()[name] = values } w.WriteHeader(resp.StatusCode) - reader := io.TeeReader(resp.Body, w) // For dumping API payloads - buf, err := io.ReadAll(reader) + buf, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) } - s.cachedNewsReq = resp - s.cachedNewsBody = buf + buf = bytes.Replace(buf, []byte(s.GGStriveAPIURL), []byte(s.PatchedAPIURL), -1) + w.Write(buf) } } +// UNSAFE: Cache news on first request. On every other request return the cached value. +func (s *StriveAPIProxy) HandleGetNews(w http.ResponseWriter, r *http.Request) { + s.HandleCachedRequest("sys/get_news", w, r) +} + +// UNSAFE: Cache get_follow on first request. On every other request return the cached value. +func (s *StriveAPIProxy) HandleGetFollow(w http.ResponseWriter, r *http.Request) { + s.HandleCachedRequest("catalog/get_follow", w, r) +} + +// UNSAFE: Cache get_block on first request. On every other request return the cached value. +func (s *StriveAPIProxy) HandleGetBlock(w http.ResponseWriter, r *http.Request) { + s.HandleCachedRequest("catalog/get_block", w, r) +} + func (s *StriveAPIProxy) Shutdown() { fmt.Println("Shutting down proxy...") @@ -165,14 +196,20 @@ func CreateStriveProxy(listen string, GGStriveAPIURL string, PatchedAPIURL strin GGStriveAPIURL: GGStriveAPIURL, PatchedAPIURL: PatchedAPIURL, CacheEnv: false, + responseCache: &ResponseCache{ + responses: make(map[string]*CachedResponse), + }, } statsSet := proxy.HandleCatchall statsGet := proxy.HandleCatchall getNews := proxy.HandleCatchall getReplay := proxy.HandleCatchall + getFollow := proxy.HandleCatchall + getBlock := proxy.HandleCatchall r := chi.NewRouter() r.Use(middleware.Logger) + r.Use(proxy.CacheInvalidationHandler) if options.AsyncStatsSet { statsSet = proxy.HandleStatsSet @@ -187,7 +224,7 @@ func CreateStriveProxy(listen string, GGStriveAPIURL string, PatchedAPIURL strin predictStatsClient := client predictStatsClient.Transport = &predictStatsTransport - proxy.prediction = CreateStatsGetPrediction(GGStriveAPIURL, &predictStatsClient) + proxy.prediction = CreateStatsGetPrediction(GGStriveAPIURL, &predictStatsClient, proxy.responseCache) r.Use(proxy.prediction.StatsGetStateHandler) statsGet = func(w http.ResponseWriter, r *http.Request) { if !proxy.prediction.HandleGetStats(w, r) { @@ -207,6 +244,10 @@ func CreateStriveProxy(listen string, GGStriveAPIURL string, PatchedAPIURL strin } else if options.CacheNews { getNews = proxy.HandleGetNews } + if options.CacheFollow { + getFollow = proxy.HandleGetFollow + getBlock = proxy.HandleGetBlock + } if options.CacheEnv { proxy.CacheEnv = true @@ -219,8 +260,7 @@ func CreateStriveProxy(listen string, GGStriveAPIURL string, PatchedAPIURL strin fmt.Println(err) } buf = bytes.Replace(buf, []byte(GGStriveAPIURL), []byte(PatchedAPIURL), -1) - proxy.cachedEnvReq = resp - proxy.cachedEnvBody = buf + proxy.responseCache.AddResponse("sys/get_env", resp, buf) resp.Body.Close() } @@ -230,8 +270,8 @@ func CreateStriveProxy(listen string, GGStriveAPIURL string, PatchedAPIURL strin r.HandleFunc("/statistics/set", statsSet) r.HandleFunc("/tus/write", statsSet) r.HandleFunc("/sys/get_news", getNews) - r.HandleFunc("/catalog/get_follow", statsGet) - r.HandleFunc("/catalog/get_block", statsGet) + r.HandleFunc("/catalog/get_follow", getFollow) + r.HandleFunc("/catalog/get_block", getBlock) r.HandleFunc("/catalog/get_replay", getReplay) r.HandleFunc("/lobby/get_vip_status", statsGet) r.HandleFunc("/item/get_item", statsGet) diff --git a/proxy/response_cache.go b/proxy/response_cache.go new file mode 100644 index 0000000..9d6c39a --- /dev/null +++ b/proxy/response_cache.go @@ -0,0 +1,33 @@ +package proxy + +import "net/http" + +type CachedResponse struct { + response *http.Response + body []byte +} + +type ResponseCache struct { + responses map[string]*CachedResponse +} + +func (c *ResponseCache) ResponseExists(request string) bool { + _, exists := c.responses[request] + return exists +} + +func (c *ResponseCache) GetResponse(request string) (http.Response, []byte) { + response, _ := c.responses[request] + return *response.response, response.body +} + +func (c *ResponseCache) AddResponse(request string, response *http.Response, body []byte) { + c.responses[request] = &CachedResponse{ + response: response, + body: body, + } +} + +func (c *ResponseCache) RemoveResponse(request string) { + delete(c.responses, request) +} diff --git a/proxy/stats_get_prediction.go b/proxy/stats_get_prediction.go index b7dfb4d..6e30137 100755 --- a/proxy/stats_get_prediction.go +++ b/proxy/stats_get_prediction.go @@ -16,10 +16,11 @@ import ( const StatsGetWorkers = 5 type StatsGetTask struct { - data string - path string - request string - response chan *http.Response + data string + path string + request string + response chan *http.Response + responseBody []byte } type StatsGetPrediction struct { @@ -29,6 +30,7 @@ type StatsGetPrediction struct { client *http.Client PredictReplay bool skipNext bool + responseCache *ResponseCache } type PredictionState int @@ -66,6 +68,20 @@ func (rw *CachingResponseWriter) Write(data []byte) (int, error) { return rw.w.Write(data) } +func (s *StatsGetPrediction) proxyRequest(r *http.Request) (*http.Response, error) { + apiURL, err := url.Parse(s.GGStriveAPIURL) + if err != nil { + fmt.Println(err) + return nil, err + } + apiURL.Path = r.URL.Path + + r.URL = apiURL + r.Host = "" + r.RequestURI = "" + return s.client.Do(r) +} + func (s *StatsGetPrediction) StatsGetStateHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { path := r.URL.Path @@ -87,7 +103,6 @@ func (s *StatsGetPrediction) StatsGetStateHandler(next http.Handler) http.Handle next.ServeHTTP(w, r) } }) - } // Proxy getstats @@ -114,16 +129,12 @@ func (s *StatsGetPrediction) HandleGetStats(w http.ResponseWriter, r *http.Reque delete(s.statsGetTasks, req) return false } - defer resp.Body.Close() // Copy headers for name, values := range resp.Header { w.Header()[name] = values } w.WriteHeader(resp.StatusCode) - _, err := io.Copy(w, resp.Body) - if err != nil { - fmt.Println(err) - } + w.Write(task.responseBody) delete(s.statsGetTasks, req) if len(s.statsGetTasks) == 0 { s.predictionState = ready @@ -173,7 +184,24 @@ func (s *StatsGetPrediction) ProcessStatsQueue(queue chan *StatsGetTask) { fmt.Println(err) item.response <- nil } else { - item.response <- res + buf, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + fmt.Println(err) + item.response <- nil + } else { + //add get_follow and get_block to the generic response cache instead of the prediction queue + if strings.HasSuffix(req.URL.Path, "catalog/get_follow") { + s.responseCache.AddResponse("catalog/get_follow", res, buf) + delete(s.statsGetTasks, item.request) + } else if strings.HasSuffix(req.URL.Path, "catalog/get_block") { + s.responseCache.AddResponse("catalog/get_block", res, buf) + delete(s.statsGetTasks, item.request) + } else { + item.responseBody = buf + item.response <- res + } + } } default: fmt.Println("Empty queue, shutting down") @@ -181,7 +209,6 @@ func (s *StatsGetPrediction) ProcessStatsQueue(queue chan *StatsGetTask) { } } - } func (s *StatsGetPrediction) AsyncGetStats(body []byte, reqType StatsGetType) { @@ -213,7 +240,7 @@ func (s *StatsGetPrediction) AsyncGetStats(body []byte, reqType StatsGetType) { id := bodyConst + task.data + "\x00" task.request = id - task.response = make(chan *http.Response) + task.response = make(chan *http.Response, 1) s.statsGetTasks[id] = &task queue <- &task @@ -226,7 +253,7 @@ func (s *StatsGetPrediction) AsyncGetStats(body []byte, reqType StatsGetType) { } } -func CreateStatsGetPrediction(GGStriveAPIURL string, client *http.Client) StatsGetPrediction { +func CreateStatsGetPrediction(GGStriveAPIURL string, client *http.Client, responseCache *ResponseCache) StatsGetPrediction { return StatsGetPrediction{ GGStriveAPIURL: GGStriveAPIURL, predictionState: ready, @@ -234,6 +261,7 @@ func CreateStatsGetPrediction(GGStriveAPIURL string, client *http.Client) StatsG client: client, PredictReplay: false, skipNext: false, + responseCache: responseCache, } }