Skip to content

Commit

Permalink
Adds tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
kerrydc committed Oct 6, 2023
1 parent 0a9b237 commit f2b17a1
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 9 deletions.
2 changes: 1 addition & 1 deletion sdks/go/pkg/beam/options/resource/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func (h CPUCountHint) Payload() []byte {
return []byte(strconv.FormatUint(h.value, 10))
}

// MergeWith an outer CPUCountHints by keeping the maximum of the two cpu counts.
// MergeWithOuter by keeping the maximum of the two cpu counts.
func (h CPUCountHint) MergeWithOuter(outer Hint) Hint {
// Intentional runtime panic from type assertion to catch hint merge errors.
if outer.(CPUCountHint).value > h.value {
Expand Down
45 changes: 41 additions & 4 deletions sdks/go/pkg/beam/options/resource/hint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,38 @@ func TestParseMinRAMHint_panic(t *testing.T) {
ParseMinRAM("a bad byte string")
}

func TestCPUCountHint_MergeWith(t *testing.T) {
low := CPUCountHint{value: 2}
high := CPUCountHint{value: 128}

if got, want := low.MergeWithOuter(high), high; got != want {
t.Errorf("%v.MergeWith(%v) = %v, want %v", low, high, got, want)
}
if got, want := high.MergeWithOuter(low), high; got != want {
t.Errorf("%v.MergeWith(%v) = %v, want %v", high, low, got, want)
}
}

func TestCPUCountHint_Payload(t *testing.T) {
tests := []struct {
value uint64
payload string
}{
{0, "0"},
{2, "2"},
{11, "11"},
{2003, "2003"},
{1.2e7, "12000000"},
}

for _, test := range tests {
h := CPUCountHint{value: test.value}
if got, want := h.Payload(), []byte(test.payload); !bytes.Equal(got, want) {
t.Errorf("%v.Payload() = %v, want %v", h, got, want)
}
}
}

// We copy the URN from the proto for use as a constant rather than perform a direct look up
// each time, or increase initialization time. However we do need to validate that they are
// correct, and match the standard hint urns, so that's done here.
Expand All @@ -130,7 +162,11 @@ func TestStandardHintUrns(t *testing.T) {
}, {
h: MinRAMBytes(2e9),
urn: getStandardURN(pipepb.StandardResourceHints_MIN_RAM_BYTES),
}, {
h: CPUCount(4),
urn: getStandardURN(pipepb.StandardResourceHints_CPU_COUNT),
}}

for _, test := range tests {
if got, want := test.h.URN(), test.urn; got != want {
t.Errorf("Checked urn for %T, got %q, want %q", test.h, got, want)
Expand All @@ -154,12 +190,12 @@ func (h customHint) MergeWithOuter(outer Hint) Hint {
}

func TestHints_Equal(t *testing.T) {
hs := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas"))
hs := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas"), CPUCount(4))

if got, want := hs.Equal(hs), true; got != want {
t.Errorf("Self equal test: hs.Equal(hs) = %v, want %v", got, want)
}
eq := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas"))
eq := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas"), CPUCount(4))
if got, want := hs.Equal(eq), true; got != want {
t.Errorf("identical equal test: hs.Equal(eq) = %v, want %v", got, want)
}
Expand Down Expand Up @@ -223,12 +259,13 @@ func TestHints_MergeWithOuter(t *testing.T) {

func TestHints_Payloads(t *testing.T) {
{
hs := NewHints(MinRAMBytes(2e9), Accelerator("type:jeans;count1;"))
hs := NewHints(MinRAMBytes(2e9), Accelerator("type:jeans;count1;"), CPUCount(4))

got := hs.Payloads()
want := map[string][]byte{
"beam:resources:min_ram_bytes:v1": []byte("2000000000"),
"beam:resources:accelerator:v1": []byte("type:jeans;count1;"),
"beam:resources:cpu_count:v1": []byte("4"),
}
if !reflect.DeepEqual(got, want) {
t.Errorf("hs.Payloads() = %v, want %v", got, want)
Expand All @@ -248,7 +285,7 @@ func TestHints_Payloads(t *testing.T) {
func TestHints_NilHints(t *testing.T) {
var hs1, hs2 Hints

hs := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas"))
hs := NewHints(MinRAMBytes(2e9), Accelerator("type:pants;count1;install-pajamas"), CPUCount(4))

if got, want := hs1.Equal(hs2), true; got != want {
t.Errorf("nils equal test: (nil).Equal(nil) = %v, want %v", got, want)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ public void testFromOptions() {
.withHint("beam:resources:bar", new ResourceHints.StringHint("foo")));
options =
PipelineOptionsFactory.fromArgs(
"--resourceHints=min_ram=1KB", "--resourceHints=accelerator=foo")
"--resourceHints=min_ram=1KB", "--resourceHints=accelerator=foo",
"--resourceHints=cpu_count=4")
.as(ResourceHintsOptions.class);
assertEquals(
ResourceHints.fromOptions(options),
ResourceHints.create().withMinRam(1000).withAccelerator("foo"));
ResourceHints.create().withMinRam(1000).withAccelerator("foo").withCpuCount(4));
}
}
5 changes: 3 additions & 2 deletions sdks/python/apache_beam/transforms/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,17 @@ def get_merged_value(


class CpuCountHint(ResourceHint):
"""Describes desired hardware accelerators in execution environment."""
"""Describes number of CPUs available in transform's execution environment."""
urn = resource_hints.CPU_COUNT.urn

@classmethod
def get_merged_value(
cls, outer_value, inner_value): # type: (int, int) -> int
cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes
return ResourceHint._use_max(outer_value, inner_value)


ResourceHint.register_resource_hint('cpu_count', CpuCountHint)
# Alias for interoperability with SDKs preferring camelCase.
ResourceHint.register_resource_hint('cpuCount', CpuCountHint)


Expand Down
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/transforms/resources_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class ResourcesTest(unittest.TestCase):
val='gpu',
urn='beam:resources:accelerator:v1',
bytestr=b'gpu'),
param(
name='cpu_count',
val='4',
urn='beam:resources:cpu_count:v1',
bytestr=b'4'),
])
def test_known_resource_hints(self, name, val, urn, bytestr):
t = PTransform()
Expand All @@ -56,6 +61,7 @@ def test_known_resource_hints(self, name, val, urn, bytestr):
@parameterized.expand([
param(name='min_ram', val='3,500G'),
param(name='accelerator', val=1),
param(name='cpu_count', val=1),
param(name='unknown_hint', val=1)
])
def test_resource_hint_parsing_fails_early(self, name, val):
Expand Down

0 comments on commit f2b17a1

Please sign in to comment.