package main import ( "flag" "math" "net" "net/http" _ "net/http/pprof" "sync" "time" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/benchmark" testpb "google.golang.org/grpc/benchmark/grpc_testing" "google.golang.org/grpc/benchmark/stats" "google.golang.org/grpc/grpclog" ) var ( server = flag.String("server", "", "The server address") maxConcurrentRPCs = flag.Int("max_concurrent_rpcs", 1, "The max number of concurrent RPCs") duration = flag.Int("duration", math.MaxInt32, "The duration in seconds to run the benchmark client") trace = flag.Bool("trace", true, "Whether tracing is on") rpcType = flag.Int("rpc_type", 0, `Configure different client rpc type. Valid options are: 0 : unary call; 1 : streaming call.`) ) func unaryCaller(client testpb.BenchmarkServiceClient) { benchmark.DoUnaryCall(client, 1, 1) } func streamCaller(stream testpb.BenchmarkService_StreamingCallClient) { benchmark.DoStreamingRoundTrip(stream, 1, 1) } func buildConnection() (s *stats.Stats, conn *grpc.ClientConn, tc testpb.BenchmarkServiceClient) { s = stats.NewStats(256) conn = benchmark.NewClientConn(*server) tc = testpb.NewBenchmarkServiceClient(conn) return s, conn, tc } func closeLoopUnary() { s, conn, tc := buildConnection() for i := 0; i < 100; i++ { unaryCaller(tc) } ch := make(chan int, *maxConcurrentRPCs*4) var ( mu sync.Mutex wg sync.WaitGroup ) wg.Add(*maxConcurrentRPCs) for i := 0; i < *maxConcurrentRPCs; i++ { go func() { for _ = range ch { start := time.Now() unaryCaller(tc) elapse := time.Since(start) mu.Lock() s.Add(elapse) mu.Unlock() } wg.Done() }() } // Stop the client when time is up. done := make(chan struct{}) go func() { <-time.After(time.Duration(*duration) * time.Second) close(done) }() ok := true for ok { select { case ch <- 0: case <-done: ok = false } } close(ch) wg.Wait() conn.Close() grpclog.Println(s.String()) } func closeLoopStream() { s, conn, tc := buildConnection() ch := make(chan int, *maxConcurrentRPCs*4) var ( mu sync.Mutex wg sync.WaitGroup ) wg.Add(*maxConcurrentRPCs) // Distribute RPCs over maxConcurrentCalls workers. for i := 0; i < *maxConcurrentRPCs; i++ { go func() { stream, err := tc.StreamingCall(context.Background()) if err != nil { grpclog.Fatalf("%v.StreamingCall(_) = _, %v", tc, err) } // Do some warm up. for i := 0; i < 100; i++ { streamCaller(stream) } for range ch { start := time.Now() streamCaller(stream) elapse := time.Since(start) mu.Lock() s.Add(elapse) mu.Unlock() } wg.Done() }() } // Stop the client when time is up. done := make(chan struct{}) go func() { <-time.After(time.Duration(*duration) * time.Second) close(done) }() ok := true for ok { select { case ch <- 0: case <-done: ok = false } } close(ch) wg.Wait() conn.Close() grpclog.Println(s.String()) } func main() { flag.Parse() grpc.EnableTracing = *trace go func() { lis, err := net.Listen("tcp", ":0") if err != nil { grpclog.Fatalf("Failed to listen: %v", err) } grpclog.Println("Client profiling address: ", lis.Addr().String()) if err := http.Serve(lis, nil); err != nil { grpclog.Fatalf("Failed to serve: %v", err) } }() switch *rpcType { case 0: closeLoopUnary() case 1: closeLoopStream() } }