diff --git a/client/main.go b/client/main.go index 55839b39..bf25fde6 100644 --- a/client/main.go +++ b/client/main.go @@ -162,27 +162,32 @@ func logFinalReport(good, bad, bytes int64, latencies *hdrhistogram.Histogram, j } func sendNonStreamingRequests(client pb.ResponderClient, - requests int64, lengthDistribution distribution.Distribution, + shutdownChannel <-chan struct{}, + lengthDistribution distribution.Distribution, latencyDistribution distribution.Distribution, r *rand.Rand, received chan *MeasuredResponse) error { - for j := int64(0); j < requests; j++ { - start := time.Now() - resp, err := client.Get(context.Background(), - &pb.ResponseSpec{ - Length: int32(lengthDistribution.Get(r.Int31() % 1000)), - Latency: latencyDistribution.Get(r.Int31() % 1000)}) - - bytes := int64(len([]byte(resp.Body))) - latency := time.Since(start) - promLatencyHistogram.Observe(float64(latency)) - received <- &MeasuredResponse{0, latency, bytes, err} - - if err != nil { - return err + for { + select { + case <-shutdownChannel: + return nil + default: + + start := time.Now() + resp, err := client.Get(context.Background(), + &pb.ResponseSpec{ + Length: int32(lengthDistribution.Get(r.Int31() % 1000)), + Latency: latencyDistribution.Get(r.Int31() % 1000)}) + + bytes := int64(len([]byte(resp.Body))) + latency := time.Since(start) + promLatencyHistogram.Observe(float64(latency)) + received <- &MeasuredResponse{0, latency, bytes, err} + + if err != nil { + return err + } } } - - return nil } // parseStreamingRatio takes a string formatted as integer:integer, @@ -209,13 +214,14 @@ func parseStreamingRatio(streamingRatio string) (int64, int64) { } func sendStreamingRequests(client pb.ResponderClient, - requests int64, lengthDistribution distribution.Distribution, + shutdownChannel <-chan struct{}, lengthDistribution distribution.Distribution, latencyDistribution distribution.Distribution, streamingRatio string, r *rand.Rand, received chan *MeasuredResponse) error { stream, err := client.StreamingGet(context.Background()) if err != nil { log.Fatalf("%v.StreamingGet(_) = _, %v", client, err) } + defer stream.CloseSend() latencyMap := latencyDistribution.ToMap() lengthMap := lengthDistribution.ToMap() @@ -250,35 +256,38 @@ func sendStreamingRequests(client pb.ResponderClient, }() requestRatioM, requestRatioN := parseStreamingRatio(streamingRatio) - var numRequests = int64(0) - for j := int64(0); j < requests; j++ { - if (j % requestRatioM) == 0 { - numRequests = requestRatioN - } + numRequests := int64(0) + currentRequest := int64(0) + for { + select { + case <-shutdownChannel: + return nil + default: + if (currentRequest % requestRatioM) == 0 { + numRequests = requestRatioN + } - err := stream.Send(&pb.StreamingResponseSpec{ - Count: int32(numRequests), - LatencyPercentiles: latencyMap, - LengthPercentiles: lengthMap, - }) + err := stream.Send(&pb.StreamingResponseSpec{ + Count: int32(numRequests), + LatencyPercentiles: latencyMap, + LengthPercentiles: lengthMap, + }) - if err != nil { - log.Fatalf("Failed to Send ResponseSpec: %v", err) - } + if err != nil { + log.Fatalf("Failed to Send ResponseSpec: %v", err) + } - numRequests = 0 + numRequests = 0 + currentRequest++ + } } - - stream.CloseSend() - <-waitc - return nil } func main() { var ( address = flag.String("address", "localhost:11111", "hostname:port of strest-grpc service or intermediary") concurrency = flag.Int("concurrency", 1, "client concurrency level") - requests = flag.Int64("requests", 10000, "number of requests per connection") + totalRequests = flag.Int64("totalRequests", 0, "total number of requests to send. default: infinite") interval = flag.Duration("interval", 10*time.Second, "reporting interval") latencyPercentileFlag = flag.String("latencyPercentiles", "50=10,100=100", "response latency percentile distribution.") lengthPercentileFlag = flag.String("lengthPercentiles", "50=100,100=1000", "response body length percentile distribution.") @@ -324,8 +333,9 @@ func main() { log.Fatalf("unable to create length distribution: %v", err) } - cleanup := make(chan os.Signal) - signal.Notify(cleanup, syscall.SIGINT) + cleanup := make(chan struct{}, 2) + interrupt := make(chan os.Signal, 2) + signal.Notify(interrupt, syscall.SIGINT) var bytes, totalBytes, count, totalCount, good, totalGood, bad, totalBad, max, min int64 min = math.MaxInt64 @@ -350,8 +360,12 @@ func main() { var wg sync.WaitGroup wg.Add(*concurrency) + shutdownChannels := make([]chan struct{}, *concurrency) for i := int(0); i < *concurrency; i++ { + shutdownChannel := make(chan struct{}, 2) + shutdownChannels = append(shutdownChannels, shutdownChannel) + go func() { r := rand.New(rand.NewSource(time.Now().UnixNano())) // Set up a connection to the server. @@ -363,33 +377,40 @@ func main() { client := pb.NewResponderClient(conn) if !*streaming { - err := sendNonStreamingRequests(client, - *requests, lengthDistribution, latencyDistribution, r, received) + err := sendNonStreamingRequests(client, shutdownChannel, + lengthDistribution, latencyDistribution, r, received) if err != nil { log.Fatalf("could not send a request: %v", err) } } else { - err := sendStreamingRequests(client, - *requests, lengthDistribution, latencyDistribution, *streamingRatio, r, received) + err := sendStreamingRequests(client, shutdownChannel, + lengthDistribution, latencyDistribution, *streamingRatio, r, received) if err != nil { log.Fatalf("could not send a request: %v", err) } } - wg.Done() }() } go func() { + wg.Add(1) for { select { + case <-interrupt: + cleanup <- struct{}{} case <-cleanup: - // FIX: how can we close the client - // connection properly here? + for _, c := range shutdownChannels { + if c != nil { + c <- struct{}{} + } + } + if !*disableFinalReport { logFinalReport(totalGood, totalBad, totalBytes, globalLatencyHist, globalJitterHist) } - os.Exit(0) + wg.Done() + return case resp := <-received: count++ @@ -431,12 +452,17 @@ func main() { latencyHist.Reset() jitterHist.Reset() timeout = time.After(*interval) + if totalCount > 0 && totalCount > *totalRequests { + cleanup <- struct{}{} + } } } }() wg.Wait() + if !*disableFinalReport { logFinalReport(totalGood, totalBad, totalBytes, globalLatencyHist, globalJitterHist) } + os.Exit(0) }