diff --git a/server.go b/server.go index 70fe23f55022..1da2a542acde 100644 --- a/server.go +++ b/server.go @@ -1598,6 +1598,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv s: stream, p: &parser{r: stream, bufferPool: s.opts.bufferPool}, codec: s.getCodec(stream.ContentSubtype()), + desc: sd, maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, diff --git a/stream.go b/stream.go index ca6948926f93..626af2349904 100644 --- a/stream.go +++ b/stream.go @@ -1580,6 +1580,7 @@ type serverStream struct { s *transport.ServerStream p *parser codec baseCodec + desc *StreamDesc compressorV0 Compressor compressorV1 encoding.Compressor @@ -1588,6 +1589,8 @@ type serverStream struct { sendCompressorName string + recvFirstMsg bool // set after the first message is received + maxReceiveMessageSize int maxSendMessageSize int trInfo *traceInfo @@ -1774,6 +1777,10 @@ func (ss *serverStream) RecvMsg(m any) (err error) { binlog.Log(ss.ctx, chc) } } + // Received no request msg for non-client streaming rpcs. + if !ss.desc.ClientStreams && !ss.recvFirstMsg { + return status.Error(codes.Internal, "cardinality violation: received no request message from non-client-streaming RPC") + } return err } if err == io.ErrUnexpectedEOF { @@ -1781,6 +1788,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) { } return toRPCErr(err) } + ss.recvFirstMsg = true if len(ss.statsHandler) != 0 { for _, sh := range ss.statsHandler { sh.HandleRPC(ss.s.Context(), &stats.InPayload{ @@ -1800,7 +1808,19 @@ func (ss *serverStream) RecvMsg(m any) (err error) { binlog.Log(ss.ctx, cm) } } - return nil + + if ss.desc.ClientStreams { + // Subsequent messages should be received by subsequent RecvMsg calls. + return nil + } + // Special handling for non-client-stream rpcs. + // This recv expects EOF or errors, so we don't collect inPayload. + if err := recv(ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF { + return nil + } else if err != nil { + return err + } + return status.Error(codes.Internal, "cardinality violation: received multiple request messages for non-client-streaming RPC") } // MethodFromServerStream returns the method string for the input stream. diff --git a/test/end2end_test.go b/test/end2end_test.go index 584c90ca3b15..ce1c5dbec70f 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -51,6 +51,7 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials/local" "google.golang.org/grpc/health" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/binarylog" @@ -3740,6 +3741,245 @@ func (s) TestClientStreaming_ReturnErrorAfterSendAndClose(t *testing.T) { } } +// Tests the behavior for server-side streaming when client calls SendMsg twice. +// Second call to SendMsg should fail with Internal error and result in closing +// the connection with a RST_STREAM. +func (s) TestServerStreaming_ClientCallSendMsgTwice(t *testing.T) { + // To ensure initial call to server.recvMsg() made by the generated code is successfully + // completed. Otherwise, if the client attempts to send a second request message, that + // will trigger a RST_STREAM from the client due to the application violating the RPC's + // protocol. The RST_STREAM could cause the server’s first RecvMsg to fail and will prevent + // the method handler from being called. + recvDoneOnServer := make(chan struct{}) + // To ensure goroutine for test does not end before RPC handler performs error + // checking. + handlerDone := make(chan struct{}) + ss := stubserver.StubServer{ + StreamingOutputCallF: func(_ *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error { + close(recvDoneOnServer) + // Block until the stream’s context is done. Second call to client.SendMsg + // triggers a RST_STREAM which cancels the stream context on the server. + <-stream.Context().Done() + if err := stream.SendMsg(&testpb.StreamingOutputCallRequest{}); status.Code(err) != codes.Canceled { + t.Errorf("stream.SendMsg() = %v, want error %v", err, codes.Canceled) + } + close(handlerDone) + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatal("Error starting server:", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(local.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed unexpectedly: %v", ss.Address, err) + } + defer cc.Close() + + desc := &grpc.StreamDesc{ + StreamName: "StreamingOutputCall", + ServerStreams: true, + ClientStreams: false, + } + + stream, err := cc.NewStream(ctx, desc, "/grpc.testing.TestService/StreamingOutputCall") + if err != nil { + t.Fatalf("cc.NewStream() failed unexpectedly: %v", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); err != nil { + t.Errorf("stream.SendMsg() = %v, want ", err) + } + + <-recvDoneOnServer + if err := stream.SendMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.SendMsg() = %v, want error %v", err, codes.Internal) + } + <-handlerDone +} + +// TODO(i/7286) : Add tests to check server-side behavior for Unary RPC. +// Tests the behavior for unary RPC when client calls SendMsg twice. Second call +// to SendMsg should fail with Internal error. +func (s) TestUnaryRPC_ClientCallSendMsgTwice(t *testing.T) { + ss := stubserver.StubServer{ + UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatal("Error starting server:", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(local.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed unexpectedly: %v", ss.Address, err) + } + defer cc.Close() + + desc := &grpc.StreamDesc{ + StreamName: "UnaryCall", + ServerStreams: false, + ClientStreams: false, + } + + stream, err := cc.NewStream(ctx, desc, "/grpc.testing.TestService/UnaryCall") + if err != nil { + t.Fatalf("cc.NewStream() failed unexpectedly: %v", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); err != nil { + t.Errorf("stream.SendMsg() = %v, want ", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.SendMsg() = %v, want error %v", status.Code(err), codes.Internal) + } +} + +// Tests the behavior for server-side streaming RPC when client misbehaves as Bidi-streaming +// and sends multiple messages. +func (s) TestServerStreaming_ClientSendsMultipleMessages(t *testing.T) { + // The initial call to recvMsg made by the generated code, will return the error. + ss := stubserver.StubServer{} + if err := ss.Start(nil); err != nil { + t.Fatal("Error starting server:", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(local.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed unexpectedly: %v", ss.Address, err) + } + defer cc.Close() + + // Making the client bi-di to bypass the client side checks that stop a non-streaming client + // from sending multiple messages. + desc := &grpc.StreamDesc{ + StreamName: "StreamingOutputCall", + ServerStreams: true, + ClientStreams: true, + } + + stream, err := cc.NewStream(ctx, desc, "/grpc.testing.TestService/StreamingOutputCall") + if err != nil { + t.Fatalf("cc.NewStream() failed unexpectedly: %v", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); err != nil { + t.Errorf("stream.SendMsg() = %v, want ", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); err != nil { + t.Errorf("stream.SendMsg() = %v, want ", err) + } + + if err := stream.RecvMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } +} + +// Tests the behavior of server for server-side streaming RPC when client sends zero request messages. +func (s) TestServerStreaming_ServerRecvZeroRequests(t *testing.T) { + testCases := []struct { + name string + desc *grpc.StreamDesc + wantCode codes.Code + }{ + { + name: "BidiStreaming", + desc: &grpc.StreamDesc{ + StreamName: "StreamingOutputCall", + ServerStreams: true, + ClientStreams: true, + }, + wantCode: codes.Internal, + }, + { + name: "ClientStreaming", + desc: &grpc.StreamDesc{ + StreamName: "StreamingOutputCall", + ServerStreams: false, + ClientStreams: true, + }, + wantCode: codes.Internal, + }, + } + + for _, tc := range testCases { + // The initial call to recvMsg made by the generated code, will return the error. + ss := stubserver.StubServer{} + if err := ss.Start(nil); err != nil { + t.Fatal("Error starting server:", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(local.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed unexpectedly: %v", ss.Address, err) + } + defer cc.Close() + + stream, err := cc.NewStream(ctx, tc.desc, "/grpc.testing.TestService/StreamingOutputCall") + if err != nil { + t.Fatalf("cc.NewStream() failed unexpectedly: %v", err) + } + + if err := stream.CloseSend(); err != nil { + t.Errorf("stream.CloseSend() = %v, want ", err) + } + + if err := stream.RecvMsg(&testpb.Empty{}); status.Code(err) != tc.wantCode { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), tc.wantCode) + } + } +} + +// Tests the behavior of client for server-side streaming RPC when client sends zero request messages. +func (s) TestServerStreaming_ClientSendsZeroRequests(t *testing.T) { + t.Skip("blocked on i/7286") + // The initial call to recvMsg made by the generated code, will return the error. + ss := stubserver.StubServer{} + if err := ss.Start(nil); err != nil { + t.Fatal("Error starting server:", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(local.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed unexpectedly: %v", ss.Address, err) + } + defer cc.Close() + + desc := &grpc.StreamDesc{ + StreamName: "StreamingOutputCall", + ServerStreams: true, + ClientStreams: false, + } + + stream, err := cc.NewStream(ctx, desc, "/grpc.testing.TestService/StreamingOutputCall") + if err != nil { + t.Fatalf("cc.NewStream() failed unexpectedly: %v", err) + } + + if err := stream.CloseSend(); status.Code(err) != codes.Internal { + t.Errorf("stream.CloseSend() = %v, want error %v", status.Code(err), codes.Internal) + } +} + // Tests that a client receives a cardinality violation error for client-streaming // RPCs if the server call SendMsg multiple times. func (s) TestClientStreaming_ServerHandlerSendMsgAfterSendMsg(t *testing.T) {