diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index f44d0278a22aa..c7d950d766d7e 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -16,6 +16,7 @@ from google.protobuf.struct_pb2 import ListValue from google.protobuf.struct_pb2 import Value +from google.cloud import exceptions from google.cloud.proto.spanner.v1 import type_pb2 import six @@ -169,6 +170,48 @@ def __iter__(self): while iter_rows: yield iter_rows.pop(0) + def one(self): + """Return exactly one result, or raise an exception. + + :raises: :exc:`NotFound`: If there are no results. + :raises: :exc:`ValueError`: If there are multiple results. + :raises: :exc:`RuntimeError`: If consumption has already occurred, + in whole or in part. + """ + answer = self.one_or_none() + if answer is None: + raise exceptions.NotFound('No rows matched the given query.') + return answer + + def one_or_none(self): + """Return exactly one result, or None if there are no results. + + :raises: :exc:`ValueError`: If there are multiple results. + :raises: :exc:`RuntimeError`: If consumption has already occurred, + in whole or in part. + """ + # Sanity check: Has consumption of this query already started? + # If it has, then this is an exception. + if self._metadata is not None: + raise RuntimeError('Can not call `.one` or `.one_or_none` after ' + 'stream consumption has already started.') + + # Consume the first result of the stream. + # If there is no first result, then return None. + iterator = iter(self) + try: + answer = next(iterator) + except StopIteration: + return None + + # Attempt to consume more. This should no-op; if we get additional + # rows, then this is an error case. + try: + next(iterator) + raise ValueError('Expected one result; got more.') + except StopIteration: + return answer + class Unmergeable(ValueError): """Unable to merge two values. diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index 2e31f4dfad2cf..0e0bcb7aff6b3 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -53,7 +53,7 @@ def test_fields_unset(self): iterator = _MockCancellableIterator() streamed = self._make_one(iterator) with self.assertRaises(AttributeError): - _ = streamed.fields + streamed.fields @staticmethod def _make_scalar_field(name, type_): @@ -243,13 +243,24 @@ def test__merge_chunk_string_w_bytes(self): self._make_scalar_field('image', 'BYTES'), ] streamed._metadata = self._make_result_set_metadata(FIELDS) - streamed._pending_chunk = self._make_value(u'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAAAAAA6fptVAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA\n') - chunk = self._make_value(u'B3RJTUUH4QQGFwsBTL3HMwAAABJpVFh0Q29tbWVudAAAAAAAU0FNUExFMG3E+AAAAApJREFUCNdj\nYAAAAAIAAeIhvDMAAAAASUVORK5CYII=\n') + streamed._pending_chunk = self._make_value( + u'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAAAAAA' + u'6fptVAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA\n', + ) + chunk = self._make_value( + u'B3RJTUUH4QQGFwsBTL3HMwAAABJpVFh0Q29tbWVudAAAAAAAU0FNUExF' + u'MG3E+AAAAApJREFUCNdj\nYAAAAAIAAeIhvDMAAAAASUVORK5CYII=\n', + ) merged = streamed._merge_chunk(chunk) - self.assertEqual(merged.string_value, u'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAAAAAA6fptVAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA\nB3RJTUUH4QQGFwsBTL3HMwAAABJpVFh0Q29tbWVudAAAAAAAU0FNUExFMG3E+AAAAApJREFUCNdj\nYAAAAAIAAeIhvDMAAAAASUVORK5CYII=\n') - self.assertIsNone(streamed._pending_chunk) + self.assertEqual( + merged.string_value, + u'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAAAAAA6fptVAAAACXBIWXMAAAsTAAAL' + u'EwEAmpwYAAAA\nB3RJTUUH4QQGFwsBTL3HMwAAABJpVFh0Q29tbWVudAAAAAAAU0' + u'FNUExFMG3E+AAAAApJREFUCNdj\nYAAAAAIAAeIhvDMAAAAASUVORK5CYII=\n', + ) + self.assertIsNone(streamed._pending_chunk) def test__merge_chunk_array_of_bool(self): iterator = _MockCancellableIterator() @@ -591,6 +602,48 @@ def test_merge_values_partial_and_filled_plus(self): self.assertEqual(streamed.rows, [VALUES[0:3], VALUES[3:6]]) self.assertEqual(streamed._current_row, VALUES[6:]) + def test_one_or_none_no_value(self): + streamed = self._make_one(_MockCancellableIterator()) + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + self.assertIsNone(streamed.one_or_none()) + + def test_one_or_none_single_value(self): + streamed = self._make_one(_MockCancellableIterator()) + streamed._rows = ['foo'] + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + self.assertEqual(streamed.one_or_none(), 'foo') + + def test_one_or_none_multiple_values(self): + streamed = self._make_one(_MockCancellableIterator()) + streamed._rows = ['foo', 'bar'] + with self.assertRaises(ValueError): + streamed.one_or_none() + + def test_one_or_none_consumed_stream(self): + streamed = self._make_one(_MockCancellableIterator()) + streamed._metadata = object() + with self.assertRaises(RuntimeError): + streamed.one_or_none() + + def test_one_single_value(self): + streamed = self._make_one(_MockCancellableIterator()) + streamed._rows = ['foo'] + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + self.assertEqual(streamed.one(), 'foo') + + def test_one_no_value(self): + from google.cloud import exceptions + + iterator = _MockCancellableIterator(['foo']) + streamed = self._make_one(iterator) + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + with self.assertRaises(exceptions.NotFound): + streamed.one() + def test_consume_next_empty(self): iterator = _MockCancellableIterator() streamed = self._make_one(iterator)