diff --git a/pkg/workceptor/workceptor_test.go b/pkg/workceptor/workceptor_test.go index b4ada1290..9a5a4a007 100644 --- a/pkg/workceptor/workceptor_test.go +++ b/pkg/workceptor/workceptor_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "os" "testing" "github.com/ansible/receptor/pkg/logger" @@ -302,6 +303,7 @@ func TestAllocateRemoteUnit(t *testing.T) { testCases := []struct { name string + workUnitID string tlsClient string ttl string signWork bool @@ -310,37 +312,52 @@ func TestAllocateRemoteUnit(t *testing.T) { expectedCalls func() }{ { - name: "get client tls config error", - tlsClient: "something", - errorMsg: "terminated", + name: "get client tls config error", + workUnitID: "", + tlsClient: "something", + errorMsg: "terminated", expectedCalls: func() { mockNetceptor.EXPECT().GetClientTLSConfig(gomock.Any(), gomock.Any(), gomock.Any()).Return(&tls.Config{}, errors.New("terminated")) }, }, { - name: "sending secrets over non tls connection error", - tlsClient: "", - params: map[string]string{"secret_": "secret"}, - errorMsg: "cannot send secrets over a non-TLS connection", + name: "sending secrets over non tls connection error", + workUnitID: "", + tlsClient: "", + params: map[string]string{"secret_": "secret"}, + errorMsg: "cannot send secrets over a non-TLS connection", expectedCalls: func() { // For testing purposes }, }, { - name: "invalid duration error", - tlsClient: "", - ttl: "ttl", - errorMsg: "time: invalid duration \"ttl\"", + name: "invalid duration error", + workUnitID: "", + tlsClient: "", + ttl: "ttl", + errorMsg: "time: invalid duration \"ttl\"", expectedCalls: func() { // For testing purposes }, }, { - name: "normal case", - tlsClient: "", - ttl: "1.5h", - errorMsg: "", - signWork: true, + name: "normal case", + workUnitID: "", + tlsClient: "", + ttl: "1.5h", + errorMsg: "", + signWork: true, + expectedCalls: func() { + // For testing purposes + }, + }, + { + name: "pass workUnitID", + workUnitID: "testID12345678", + tlsClient: "", + ttl: "1.5h", + errorMsg: "", + signWork: true, expectedCalls: func() { // For testing purposes }, @@ -350,8 +367,7 @@ func TestAllocateRemoteUnit(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { tc.expectedCalls() - _, err := w.AllocateRemoteUnit("", "", "", tc.tlsClient, tc.ttl, tc.signWork, tc.params) - + wu, err := w.AllocateRemoteUnit("", "", tc.workUnitID, tc.tlsClient, tc.ttl, tc.signWork, tc.params) if tc.errorMsg != "" && tc.errorMsg != err.Error() && err != nil { t.Errorf("expected: %s, received: %s", tc.errorMsg, err) } @@ -359,6 +375,20 @@ func TestAllocateRemoteUnit(t *testing.T) { if tc.errorMsg == "" && err != nil { t.Error(err) } + if tc.workUnitID != "" { + wuID := wu.ID() + if tc.workUnitID != wuID { + t.Errorf("expected workUnitID to equal %s but got %s", tc.workUnitID, wuID) + } + } + }) + t.Cleanup(func() { + if tc.workUnitID != "" { + err := os.RemoveAll(fmt.Sprintf("/tmp/test/%s", tc.workUnitID)) + if err != nil { + t.Errorf("removal of test directory /tmp/test/%s failed", tc.workUnitID) + } + } }) } } diff --git a/tests/functional/mesh/work_test.go b/tests/functional/mesh/work_test.go index 50b8fb0bc..736f5769e 100644 --- a/tests/functional/mesh/work_test.go +++ b/tests/functional/mesh/work_test.go @@ -45,6 +45,10 @@ func TestWorkSubmitWithTLSClient(t *testing.T) { if err != nil { t.Fatal(err, m.GetDataDir()) } + err = controllers["node2"].AssertWorkResults(unitID, expectedResults) + if err != nil { + t.Fatal(err, m.GetDataDir()) + } }) } }