diff --git a/sdks/go/pkg/beam/options/resource/hint.go b/sdks/go/pkg/beam/options/resource/hint.go index 983014ccf8c8..d823f4feafa9 100644 --- a/sdks/go/pkg/beam/options/resource/hint.go +++ b/sdks/go/pkg/beam/options/resource/hint.go @@ -221,7 +221,7 @@ func (h CPUCountHint) Payload() []byte { return []byte(strconv.FormatUint(h.value, 10)) } -// MergeWith an outer CPUCountHints by keeping the maximum of the two cpu counts. +// MergeWithOuter by keeping the maximum of the two cpu counts. func (h CPUCountHint) MergeWithOuter(outer Hint) Hint { // Intentional runtime panic from type assertion to catch hint merge errors. if outer.(CPUCountHint).value > h.value { diff --git a/sdks/go/pkg/beam/options/resource/hint_test.go b/sdks/go/pkg/beam/options/resource/hint_test.go index cf24b47b6c91..7c2a1df79294 100644 --- a/sdks/go/pkg/beam/options/resource/hint_test.go +++ b/sdks/go/pkg/beam/options/resource/hint_test.go @@ -111,6 +111,38 @@ func TestParseMinRAMHint_panic(t *testing.T) { ParseMinRAM("a bad byte string") } +func TestCPUCountHint_MergeWith(t *testing.T) { + low := CPUCountHint{value: 2} + high := CPUCountHint{value: 128} + + if got, want := low.MergeWithOuter(high), high; got != want { + t.Errorf("%v.MergeWith(%v) = %v, want %v", low, high, got, want) + } + if got, want := high.MergeWithOuter(low), high; got != want { + t.Errorf("%v.MergeWith(%v) = %v, want %v", high, low, got, want) + } +} + +func TestCPUCountHint_Payload(t *testing.T) { + tests := []struct { + value uint64 + payload string + }{ + {0, "0"}, + {2, "2"}, + {11, "11"}, + {2003, "2003"}, + {1.2e7, "12000000"}, + } + + for _, test := range tests { + h := CPUCountHint{value: test.value} + if got, want := h.Payload(), []byte(test.payload); !bytes.Equal(got, want) { + t.Errorf("%v.Payload() = %v, want %v", h, got, want) + } + } +} + // We copy the URN from the proto for use as a constant rather than perform a direct look up // each time, or increase initialization time. However we do need to validate that they are // correct, and match the standard hint urns, so that's done here. @@ -130,7 +162,11 @@ func TestStandardHintUrns(t *testing.T) { }, { h: MinRAMBytes(2e9), urn: getStandardURN(pipepb.StandardResourceHints_MIN_RAM_BYTES), + }, { + h: CPUCount(4), + urn: getStandardURN(pipepb.StandardResourceHints_CPU_COUNT), }} + for _, test := range tests { if got, want := test.h.URN(), test.urn; got != want { t.Errorf("Checked urn for %T, got %q, want %q", test.h, got, want) @@ -154,12 +190,12 @@ func (h customHint) MergeWithOuter(outer Hint) Hint { } func TestHints_Equal(t *testing.T) { - hs := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas")) + hs := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas"), CPUCount(4)) if got, want := hs.Equal(hs), true; got != want { t.Errorf("Self equal test: hs.Equal(hs) = %v, want %v", got, want) } - eq := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas")) + eq := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas"), CPUCount(4)) if got, want := hs.Equal(eq), true; got != want { t.Errorf("identical equal test: hs.Equal(eq) = %v, want %v", got, want) } @@ -223,12 +259,13 @@ func TestHints_MergeWithOuter(t *testing.T) { func TestHints_Payloads(t *testing.T) { { - hs := NewHints(MinRAMBytes(2e9), Accelerator("type:jeans;count1;")) + hs := NewHints(MinRAMBytes(2e9), Accelerator("type:jeans;count1;"), CPUCount(4)) got := hs.Payloads() want := map[string][]byte{ "beam:resources:min_ram_bytes:v1": []byte("2000000000"), "beam:resources:accelerator:v1": []byte("type:jeans;count1;"), + "beam:resources:cpu_count:v1": []byte("4"), } if !reflect.DeepEqual(got, want) { t.Errorf("hs.Payloads() = %v, want %v", got, want) @@ -248,7 +285,7 @@ func TestHints_Payloads(t *testing.T) { func TestHints_NilHints(t *testing.T) { var hs1, hs2 Hints - hs := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas")) + hs := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas"), CPUCount(4)) if got, want := hs1.Equal(hs2), true; got != want { t.Errorf("nils equal test: (nil).Equal(nil) = %v, want %v", got, want) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/resourcehints/ResourceHintsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/resourcehints/ResourceHintsTest.java index 3cc522176374..4e2aeab9b15e 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/resourcehints/ResourceHintsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/resourcehints/ResourceHintsTest.java @@ -92,10 +92,11 @@ public void testFromOptions() { .withHint("beam:resources:bar", new ResourceHints.StringHint("foo"))); options = PipelineOptionsFactory.fromArgs( - "--resourceHints=min_ram=1KB", "--resourceHints=accelerator=foo") + "--resourceHints=min_ram=1KB", "--resourceHints=accelerator=foo", + "--resourceHints=cpu_count=4") .as(ResourceHintsOptions.class); assertEquals( ResourceHints.fromOptions(options), - ResourceHints.create().withMinRam(1000).withAccelerator("foo")); + ResourceHints.create().withMinRam(1000).withAccelerator("foo").withCpuCount(4)); } } diff --git a/sdks/python/apache_beam/transforms/resources.py b/sdks/python/apache_beam/transforms/resources.py index 73592f56202f..7c4160df8edd 100644 --- a/sdks/python/apache_beam/transforms/resources.py +++ b/sdks/python/apache_beam/transforms/resources.py @@ -179,16 +179,17 @@ def get_merged_value( class CpuCountHint(ResourceHint): - """Describes desired hardware accelerators in execution environment.""" + """Describes number of CPUs available in transform's execution environment.""" urn = resource_hints.CPU_COUNT.urn @classmethod def get_merged_value( - cls, outer_value, inner_value): # type: (int, int) -> int + cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes return ResourceHint._use_max(outer_value, inner_value) ResourceHint.register_resource_hint('cpu_count', CpuCountHint) +# Alias for interoperability with SDKs preferring camelCase. ResourceHint.register_resource_hint('cpuCount', CpuCountHint) diff --git a/sdks/python/apache_beam/transforms/resources_test.py b/sdks/python/apache_beam/transforms/resources_test.py index 939391b7adcb..939bdcd62651 100644 --- a/sdks/python/apache_beam/transforms/resources_test.py +++ b/sdks/python/apache_beam/transforms/resources_test.py @@ -46,6 +46,11 @@ class ResourcesTest(unittest.TestCase): val='gpu', urn='beam:resources:accelerator:v1', bytestr=b'gpu'), + param( + name='cpu_count', + val='4', + urn='beam:resources:cpu_count:v1', + bytestr=b'4'), ]) def test_known_resource_hints(self, name, val, urn, bytestr): t = PTransform() @@ -56,6 +61,7 @@ def test_known_resource_hints(self, name, val, urn, bytestr): @parameterized.expand([ param(name='min_ram', val='3,500G'), param(name='accelerator', val=1), + param(name='cpu_count', val=1), param(name='unknown_hint', val=1) ]) def test_resource_hint_parsing_fails_early(self, name, val):