diff --git a/internal/msgs/action_msgs/srv/CancelGoal_Request.gen.go b/internal/msgs/action_msgs/srv/CancelGoal_Request.gen.go index 0521c39..8a557eb 100644 --- a/internal/msgs/action_msgs/srv/CancelGoal_Request.gen.go +++ b/internal/msgs/action_msgs/srv/CancelGoal_Request.gen.go @@ -60,6 +60,13 @@ func (t *CancelGoal_Request) SetDefaults() { func (t *CancelGoal_Request) GetTypeSupport() types.MessageTypeSupport { return CancelGoal_RequestTypeSupport } +func (t *CancelGoal_Request) GetGoalID() *types.GoalID { + return (*types.GoalID)(&t.GoalInfo.GoalId.Uuid) +} + +func (t *CancelGoal_Request) SetGoalID(id *types.GoalID) { + t.GoalInfo.GoalId.Uuid = *id +} // CancelGoal_RequestPublisher wraps rclgo.Publisher to provide type safe helper // functions diff --git a/internal/msgs/example_interfaces/action/Fibonacci.gen.go b/internal/msgs/example_interfaces/action/Fibonacci.gen.go index 002d1fd..1c581be 100644 --- a/internal/msgs/example_interfaces/action/Fibonacci.gen.go +++ b/internal/msgs/example_interfaces/action/Fibonacci.gen.go @@ -186,20 +186,21 @@ func NewFibonacciClient(node *rclgo.Node, name string, opts *rclgo.ActionClientO return &FibonacciClient{client}, nil } -func (c *FibonacciClient) WatchGoal(ctx context.Context, goal *Fibonacci_Goal, onFeedback FibonacciFeedbackHandler) (*Fibonacci_GetResult_Response, error) { +func (c *FibonacciClient) WatchGoal(ctx context.Context, goal *Fibonacci_Goal, onFeedback FibonacciFeedbackHandler) (*Fibonacci_GetResult_Response, *types.GoalID, error) { var resp types.Message + var goalID *types.GoalID var err error if onFeedback == nil { - resp, err = c.ActionClient.WatchGoal(ctx, goal, nil) + resp, goalID, err = c.ActionClient.WatchGoal(ctx, goal, nil) } else { - resp, err = c.ActionClient.WatchGoal(ctx, goal, func(ctx context.Context, msg types.Message) { + resp, goalID, err = c.ActionClient.WatchGoal(ctx, goal, func(ctx context.Context, msg types.Message) { onFeedback(ctx, msg.(*Fibonacci_FeedbackMessage)) }) } if r, ok := resp.(*Fibonacci_GetResult_Response); ok { - return r, err + return r, goalID, err } - return nil, err + return nil, goalID, err } func (c *FibonacciClient) SendGoal(ctx context.Context, goal *Fibonacci_Goal) (*Fibonacci_SendGoal_Response, *types.GoalID, error) { diff --git a/internal/msgs/test_msgs/action/Fibonacci.gen.go b/internal/msgs/test_msgs/action/Fibonacci.gen.go index c44e0a3..6346aa3 100644 --- a/internal/msgs/test_msgs/action/Fibonacci.gen.go +++ b/internal/msgs/test_msgs/action/Fibonacci.gen.go @@ -186,20 +186,21 @@ func NewFibonacciClient(node *rclgo.Node, name string, opts *rclgo.ActionClientO return &FibonacciClient{client}, nil } -func (c *FibonacciClient) WatchGoal(ctx context.Context, goal *Fibonacci_Goal, onFeedback FibonacciFeedbackHandler) (*Fibonacci_GetResult_Response, error) { +func (c *FibonacciClient) WatchGoal(ctx context.Context, goal *Fibonacci_Goal, onFeedback FibonacciFeedbackHandler) (*Fibonacci_GetResult_Response, *types.GoalID, error) { var resp types.Message + var goalID *types.GoalID var err error if onFeedback == nil { - resp, err = c.ActionClient.WatchGoal(ctx, goal, nil) + resp, goalID, err = c.ActionClient.WatchGoal(ctx, goal, nil) } else { - resp, err = c.ActionClient.WatchGoal(ctx, goal, func(ctx context.Context, msg types.Message) { + resp, goalID, err = c.ActionClient.WatchGoal(ctx, goal, func(ctx context.Context, msg types.Message) { onFeedback(ctx, msg.(*Fibonacci_FeedbackMessage)) }) } if r, ok := resp.(*Fibonacci_GetResult_Response); ok { - return r, err + return r, goalID, err } - return nil, err + return nil, goalID, err } func (c *FibonacciClient) SendGoal(ctx context.Context, goal *Fibonacci_Goal) (*Fibonacci_SendGoal_Response, *types.GoalID, error) { diff --git a/internal/msgs/test_msgs/action/NestedMessage.gen.go b/internal/msgs/test_msgs/action/NestedMessage.gen.go index c19896a..1c1beb1 100644 --- a/internal/msgs/test_msgs/action/NestedMessage.gen.go +++ b/internal/msgs/test_msgs/action/NestedMessage.gen.go @@ -186,20 +186,21 @@ func NewNestedMessageClient(node *rclgo.Node, name string, opts *rclgo.ActionCli return &NestedMessageClient{client}, nil } -func (c *NestedMessageClient) WatchGoal(ctx context.Context, goal *NestedMessage_Goal, onFeedback NestedMessageFeedbackHandler) (*NestedMessage_GetResult_Response, error) { +func (c *NestedMessageClient) WatchGoal(ctx context.Context, goal *NestedMessage_Goal, onFeedback NestedMessageFeedbackHandler) (*NestedMessage_GetResult_Response, *types.GoalID, error) { var resp types.Message + var goalID *types.GoalID var err error if onFeedback == nil { - resp, err = c.ActionClient.WatchGoal(ctx, goal, nil) + resp, goalID, err = c.ActionClient.WatchGoal(ctx, goal, nil) } else { - resp, err = c.ActionClient.WatchGoal(ctx, goal, func(ctx context.Context, msg types.Message) { + resp, goalID, err = c.ActionClient.WatchGoal(ctx, goal, func(ctx context.Context, msg types.Message) { onFeedback(ctx, msg.(*NestedMessage_FeedbackMessage)) }) } if r, ok := resp.(*NestedMessage_GetResult_Response); ok { - return r, err + return r, goalID, err } - return nil, err + return nil, goalID, err } func (c *NestedMessageClient) SendGoal(ctx context.Context, goal *NestedMessage_Goal) (*NestedMessage_SendGoal_Response, *types.GoalID, error) { diff --git a/pkg/gogen/templates.go b/pkg/gogen/templates.go index 62683ea..4782507 100644 --- a/pkg/gogen/templates.go +++ b/pkg/gogen/templates.go @@ -180,7 +180,15 @@ func (t *{{$Md.Name}}) GetGoalAccepted() bool { } {{- end -}} -{{- if matchMsg $Md "action_msgs_srv" "CancelGoal_Response" }} +{{ if matchMsg $Md "action_msgs_srv" "CancelGoal_Request" }} +func (t *{{$Md.Name}}) GetGoalID() *types.GoalID { + return (*types.GoalID)(&t.GoalInfo.GoalId.Uuid) +} + +func (t *{{$Md.Name}}) SetGoalID(id *types.GoalID) { + t.GoalInfo.GoalId.Uuid = *id +} +{{- else if matchMsg $Md "action_msgs_srv" "CancelGoal_Response" }} func (t *{{$Md.Name}}) CallForEach(f func(interface{})) { for i := range t.GoalsCanceling { f((*types.GoalID)(&t.GoalsCanceling[i].GoalId.Uuid)) @@ -636,20 +644,21 @@ func New{{.Action.Name}}Client(node *rclgo.Node, name string, opts *rclgo.Action return &{{.Action.Name}}Client{client}, nil } -func (c *{{.Action.Name}}Client) WatchGoal(ctx context.Context, goal *{{.Action.Name}}_Goal, onFeedback {{.Action.Name}}FeedbackHandler) (*{{.Action.Name}}_GetResult_Response, error) { +func (c *{{.Action.Name}}Client) WatchGoal(ctx context.Context, goal *{{.Action.Name}}_Goal, onFeedback {{.Action.Name}}FeedbackHandler) (*{{.Action.Name}}_GetResult_Response, *types.GoalID, error) { var resp types.Message + var goalID *types.GoalID var err error if onFeedback == nil { - resp, err = c.ActionClient.WatchGoal(ctx, goal, nil) + resp, goalID, err = c.ActionClient.WatchGoal(ctx, goal, nil) } else { - resp, err = c.ActionClient.WatchGoal(ctx, goal, func(ctx context.Context, msg types.Message) { + resp, goalID, err = c.ActionClient.WatchGoal(ctx, goal, func(ctx context.Context, msg types.Message) { onFeedback(ctx, msg.(*{{.Action.Name}}_FeedbackMessage)) }) } if r, ok := resp.(*{{.Action.Name}}_GetResult_Response); ok { - return r, err + return r, goalID, err } - return nil, err + return nil, goalID, err } func (c *{{.Action.Name}}Client) SendGoal(ctx context.Context, goal *{{.Action.Name}}_Goal) (*{{.Action.Name}}_SendGoal_Response, *types.GoalID, error) { diff --git a/pkg/rclgo/action.go b/pkg/rclgo/action.go index d4cb063..23d6aa9 100644 --- a/pkg/rclgo/action.go +++ b/pkg/rclgo/action.go @@ -913,25 +913,39 @@ func (c *ActionClient) Node() *Node { // // The type support of the message passed to onFeedback is // types.ActionTypeSupport.FeedbackMessage(). -func (c *ActionClient) WatchGoal(ctx context.Context, goal types.Message, onFeedback FeedbackHandler) (types.Message, error) { +func (c *ActionClient) WatchGoal(ctx context.Context, goal types.Message, onFeedback FeedbackHandler) (result types.Message, goalID *types.GoalID, retErr error) { req, err := c.newSendGoalRequest(goal) if err != nil { - return nil, err + return nil, nil, err } if onFeedback != nil { - ctx, cancel := context.WithCancel(ctx) + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) defer cancel() unsub := c.subscribe(ctx, &c.feedbackSubs, req.GetGoalID(), onFeedback) defer unsub() } resp, err := c.SendGoalRequest(ctx, req) if err != nil { - return nil, err + return nil, req.GetGoalID(), err } if !resp.(goalResponseMessage).GetGoalAccepted() { - return nil, errors.New("goal was rejected") - } - return c.GetResult(ctx, req.GetGoalID()) + return nil, req.GetGoalID(), errors.New("goal was rejected") + } + defer func() { + if ctx.Err() != nil { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + cancelReq := c.typeSupport.CancelGoal().Request().New() + cancelReq.(goalIDMessage).SetGoalID(req.GetGoalID()) + _, err := c.CancelGoal(ctx, cancelReq) //nolint:contextcheck + if err != nil { + retErr = errors.Join(err, retErr) + } + } + }() + result, err = c.GetResult(ctx, req.GetGoalID()) + return result, req.GetGoalID(), err } // SendGoal sends a new goal to the server and returns the status message of the diff --git a/pkg/rclgo/action_test.go b/pkg/rclgo/action_test.go index 0fe2fb5..f4a763d 100644 --- a/pkg/rclgo/action_test.go +++ b/pkg/rclgo/action_test.go @@ -10,6 +10,7 @@ import ( "github.com/bradleyjkemp/cupaloy/v2" . "github.com/smartystreets/goconvey/convey" //nolint:revive + "github.com/stretchr/testify/require" action_msgs_msg "github.com/tiiuae/rclgo/internal/msgs/action_msgs/msg" action_msgs_srv "github.com/tiiuae/rclgo/internal/msgs/action_msgs/srv" test_msgs_action "github.com/tiiuae/rclgo/internal/msgs/test_msgs/action" @@ -154,7 +155,7 @@ func TestActionExecution(t *testing.T) { goal.Order = 10 feedbacks := make(fibonacciFeedbacks, 0) var feedbacksMu sync.Mutex - result, err := client.WatchGoal(ctx, goal, func(c context.Context, m types.Message) { + result, _, err := client.WatchGoal(ctx, goal, func(c context.Context, m types.Message) { fb := m.(*test_msgs_action.Fibonacci_FeedbackMessage) feedbacksMu.Lock() feedbacks = append(feedbacks, &fb.Feedback) @@ -179,7 +180,7 @@ func TestActionExecution(t *testing.T) { close(action.continueChan) goal := test_msgs_action.NewFibonacci_Goal() goal.Order = -1 - resp, err := client.WatchGoal(ctx, goal, func(c context.Context, m types.Message) { + resp, _, err := client.WatchGoal(ctx, goal, func(c context.Context, m types.Message) { panic("no feedback should be sent") }) So(err, ShouldNotBeNil) @@ -292,6 +293,65 @@ func TestActionCanceling(t *testing.T) { }) } +func TestWatchGoalCanceling(t *testing.T) { + _, cancelingAction := newWaitAction() + var ( + clientCtx, clientCancel = context.WithCancel(context.Background()) + serverCtx, serverCancel = context.WithCancel(context.Background()) + + rclctx *rclgo.Context + err error + ) + defer func() { + clientCancel() + serverCancel() + if rclctx != nil { + rclctx.Close() + } + }() + rclctx, err = newDefaultRCLContext() + require.NoError(t, err) + serverNode, err := rclctx.NewNode("server", "actions_test") + require.NoError(t, err) + _, err = serverNode.NewActionServer("canceling", cancelingAction, actionServerOpts) + require.NoError(t, err) + + clientNode, err := rclctx.NewNode("client", "actions_test") + require.NoError(t, err) + client, err := clientNode.NewActionClient("canceling", test_msgs_action.FibonacciTypeSupport, actionClientOpts) + require.NoError(t, err) + + spinErrC := make(chan error, 1) + go func() { spinErrC <- rclctx.Spin(serverCtx) }() + + goalErrC := make(chan error, 1) + var goalID *types.GoalID + go func() { + var goalErr error + _, goalID, goalErr = client.WatchGoal(clientCtx, test_msgs_action.NewFibonacci_Goal(), nil) + goalErrC <- goalErr + }() + + time.Sleep(time.Second) + clientCancel() + + err = waitChan(t, time.Second, goalErrC, "Waiting for goal watching to stop") + require.ErrorIs(t, err, context.Canceled) + + resultCtx, resultCancel := context.WithTimeout(context.Background(), time.Second) + defer resultCancel() + resp, err := client.GetResult(resultCtx, goalID) + require.NoError(t, err) + require.NotNil(t, resp) + result := resp.(*test_msgs_action.Fibonacci_GetResult_Response) + require.Equal(t, rclgo.GoalCanceled, rclgo.GoalStatus(result.Status)) + + serverCancel() + err = waitChan(t, time.Second, spinErrC, "Waiting for spinning to stop") + require.Error(t, err) + require.NoError(t, rclctx.Close()) +} + type goalStatus struct { ID types.GoalID Status rclgo.GoalStatus diff --git a/pkg/rclgo/example_action_client_test.go b/pkg/rclgo/example_action_client_test.go index defa41e..b01a8c2 100644 --- a/pkg/rclgo/example_action_client_test.go +++ b/pkg/rclgo/example_action_client_test.go @@ -31,7 +31,7 @@ func ExampleActionClient() { ctx := context.Background() goal := example_interfaces_action.NewFibonacci_Goal() goal.Order = 10 - result, err := client.WatchGoal(ctx, goal, func(ctx context.Context, feedback types.Message) { + result, _, err := client.WatchGoal(ctx, goal, func(ctx context.Context, feedback types.Message) { fmt.Println("Got feedback:", feedback) }) if err != nil { @@ -61,7 +61,7 @@ func ExampleActionClient_type_safe_wrapper() { ctx := context.Background() goal := example_interfaces_action.NewFibonacci_Goal() goal.Order = 10 - result, err := client.WatchGoal(ctx, goal, func(ctx context.Context, feedback *example_interfaces_action.Fibonacci_FeedbackMessage) { + result, _, err := client.WatchGoal(ctx, goal, func(ctx context.Context, feedback *example_interfaces_action.Fibonacci_FeedbackMessage) { fmt.Println("Got feedback:", feedback) }) if err != nil { diff --git a/pkg/rclgo/pubsub_test.go b/pkg/rclgo/pubsub_test.go index 4eec573..2669f0e 100644 --- a/pkg/rclgo/pubsub_test.go +++ b/pkg/rclgo/pubsub_test.go @@ -472,6 +472,18 @@ func timeOut(timeoutMs int, f func(), testDescription string) { } } +func waitChan[T any](t *testing.T, timeout time.Duration, ch <-chan T, testDescription string) (recv T) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + select { + case <-ctx.Done(): + t.Fatalf("%s: timeout", testDescription) + case recv = <-ch: + } + return recv +} + func publishString(pub *rclgo.Publisher, s string) { msg := std_msgs.NewString() msg.Data = s