-
Notifications
You must be signed in to change notification settings - Fork 358
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add dataset generators in DP Auditorium
DP Auditorium: - Add dataset generator for classifications tasks - Add dataset generator for Pipeline DP tests - Remove duplicated `VizierGeneratorConfig` Change-Id: Id7e54f8c9c248192f7da035afa86b9ac2f1ecec0 GitOrigin-RevId: 8865b02e81bd6141c9511a38cd5145a598fb46ca
- Loading branch information
1 parent
dc6456b
commit bd6d07c
Showing
10 changed files
with
524 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
166 changes: 166 additions & 0 deletions
166
python/dp_auditorium/dp_auditorium/generators/classification_dataset_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Copyright 2024 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""JAX-based dataset generator for binary classification problems.""" | ||
|
||
import numpy as np | ||
from vizier.service import clients | ||
from vizier.service import pyvizier as vz | ||
|
||
from dp_auditorium import interfaces | ||
from dp_auditorium.configs import dataset_generator_config | ||
from dp_auditorium.generators import vizier_dataset_generator | ||
|
||
|
||
def _initialize_vizier_problem_and_params_lookup_table( | ||
num_float_params: int, num_int_params: int, features_min_value: float, | ||
features_max_value: float, num_classes: int | ||
) -> tuple[vz.ProblemStatement, dict[str, int]]: | ||
"""Returns Vizier problem and parameters lookup table. | ||
This function initializes a Vizier problem and adds named parameters to it. | ||
`VizierDatasetGenerator` retrieves Vizier parameters (names and values) and | ||
organizes them into a one-dimensional array called `vizier_params`, which will | ||
serve as the input for the function | ||
`get_neighboring_datasets_from_vizier_params`. This function creates mappings | ||
between parameter names and their corresponding indices in `vizier_params` | ||
to ensure consistency between `_extract_params_from_trial` and | ||
`get_neighboring_datasets_from_vizier_params`, and that features and labels | ||
are correctly utilized by classification mechanisms. | ||
Args: | ||
num_float_params: Number of Vizier parameters of type float. These will be | ||
used as features. | ||
num_int_params: Number of Vizier parameters of type int. These will be | ||
used as labels. | ||
features_min_value: Minimum value for feature parameters. | ||
features_max_value: Maximum value for feature parameters. | ||
num_classes: Number of class labels. | ||
Returns: | ||
A tuple where the first element is a Vizier problem and the second element a | ||
dictionary where keys are the problem's parameters names and values are the | ||
corresponding index in a flattened array `vizier_params`. | ||
""" | ||
idx_to_str = {idx: f'float{idx}' for idx in range(num_float_params)} | ||
str_to_idx = { | ||
name: int(name.split('float', 2)[1]) for name in idx_to_str.values() | ||
} | ||
idx_to_str_int = { | ||
idx: f'int{idx}' | ||
for idx in range(num_float_params, num_float_params + num_int_params) | ||
} | ||
str_to_idx_int = { | ||
name: int(name.split('int', 2)[1]) for name in idx_to_str_int.values() | ||
} | ||
idx_to_str.update(idx_to_str_int) | ||
str_to_idx.update(str_to_idx_int) | ||
|
||
# Define problem parameters | ||
problem = vz.ProblemStatement() | ||
|
||
for i in range(num_float_params): | ||
problem.search_space.root.add_float_param( | ||
idx_to_str[i], | ||
min_value=features_min_value, | ||
max_value=features_max_value, | ||
) | ||
for i in range(num_float_params, num_float_params+num_int_params): | ||
problem.search_space.root.add_int_param( | ||
idx_to_str[i], | ||
min_value=0, | ||
max_value=num_classes, | ||
) | ||
return problem, str_to_idx | ||
|
||
|
||
class ClassificationDatasetGenerator( | ||
vizier_dataset_generator.VizierDatasetGenerator | ||
): | ||
"""Classification dataset generators supported by Vizier. | ||
This class generate pairs of add/remove-neighboring datasets, where each | ||
record has the form `(features, label)`. `features` is a float array and label | ||
an integer indicating a class. Both features and labels are generated by | ||
Vizier using the search algorithm specified in the configuration. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
config: dataset_generator_config.ClassificationDatasetGeneratorConfig, | ||
): | ||
"""Initializes a dataset generator where records have features and labels. | ||
Args: | ||
config: A configuration proto for classification dataset generator. | ||
""" | ||
|
||
self._num_samples = config.num_samples | ||
self._sample_dim = config.sample_dim | ||
self._num_float_params = config.sample_dim * config.num_samples | ||
self._num_int_params = config.num_samples | ||
self._num_vizier_params = self._num_float_params + self._num_int_params | ||
|
||
# Get indices to parameter names mapping assigned in Vizier and its inverse. | ||
problem, str_to_idx = _initialize_vizier_problem_and_params_lookup_table( | ||
num_float_params=self._num_float_params, | ||
num_int_params=self._num_int_params, | ||
features_min_value=config.min_value, | ||
features_max_value=config.max_value, | ||
num_classes=config.num_classes, | ||
) | ||
|
||
self._str_to_idx = str_to_idx | ||
|
||
# Define metric. | ||
self._metric_name = config.metric_name | ||
problem.metric_information = [ | ||
vz.MetricInformation( | ||
name=config.metric_name, goal=vz.ObjectiveMetricGoal.MAXIMIZE | ||
) | ||
] | ||
|
||
study_config = vz.StudyConfig.from_problem(problem) | ||
study_config.algorithm = config.search_algorithm | ||
|
||
self._study_client = clients.Study.from_study_config( | ||
study_config, owner=config.study_owner, study_id=config.study_name | ||
) | ||
|
||
self._trial_loaded = False | ||
self._last_trial = None | ||
|
||
def get_neighboring_datasets_from_vizier_params( | ||
self, vizier_params: np.ndarray | ||
) -> interfaces.NeighboringDatasetsType: | ||
"""Transforms a one-dimensional numpy array to neighboring datasets. | ||
Generates neighboring datasets where `data = (features, labels)`. `features` | ||
is an array with shape (self._num_samples, self._sample_dim) and labels an | ||
array with shape (self._num_samples,). | ||
Args: | ||
vizier_params: array with parameters generated by Vizier. | ||
Returns: | ||
Pair `(data1, data2)` of neighboring datasets under the add/remove | ||
definition. `data2` will have one record less than `data1`. | ||
""" | ||
|
||
features = vizier_params[: self._num_float_params] | ||
features = features.reshape(self._num_samples, self._sample_dim) | ||
labels = vizier_params[self._num_float_params :] | ||
|
||
data1 = features, labels | ||
data2 = data1[0][:-1], data1[1][:-1] | ||
return data1, data2 |
83 changes: 83 additions & 0 deletions
83
python/dp_auditorium/dp_auditorium/generators/classification_dataset_generator_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright 2024 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for classification_dataset_generator.""" | ||
from absl.testing import absltest | ||
import tensorflow as tf | ||
from vizier.service import clients | ||
|
||
from dp_auditorium.configs import dataset_generator_config | ||
from dp_auditorium.generators import classification_dataset_generator | ||
|
||
|
||
clients.environment_variables.servicer_use_sql_ram() | ||
|
||
|
||
class ClassificationDatasetGeneratorTest(tf.test.TestCase): | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self.sample_dim = 5 | ||
self.num_samples = 4 | ||
self.min_value = 13.7 | ||
self.max_value = 15.9 | ||
self.num_classes = 3 | ||
|
||
self.generator_config = ( | ||
dataset_generator_config.ClassificationDatasetGeneratorConfig( | ||
sample_dim=self.sample_dim, | ||
num_samples=self.num_samples, | ||
num_classes=self.num_classes, | ||
min_value=self.min_value, | ||
max_value=self.max_value, | ||
study_name='stub_study', | ||
study_owner='stub_owner', | ||
metric_name='stub_metric', | ||
search_algorithm='RANDOM_SEARCH', | ||
) | ||
) | ||
self.generator = ( | ||
classification_dataset_generator.ClassificationDatasetGenerator( | ||
config=self.generator_config, | ||
) | ||
) | ||
|
||
def test_get_neighboring_datasets_from_vizier_params_produces_correct_pair( | ||
self, | ||
): | ||
"""Tests datasets have correct shapes and are adjacent.""" | ||
|
||
data1, data2 = self.generator(None) | ||
# Check output shape | ||
with self.subTest('data1-images-have-correct-shape'): | ||
self.assertEqual(data1[0].shape, (self.num_samples, self.sample_dim)) | ||
with self.subTest('data1-labels-have-correct-shape'): | ||
self.assertEqual(data1[1].shape, (self.num_samples,)) | ||
with self.subTest('data2-images-have-correct-shape'): | ||
self.assertEqual(data2[0].shape, (self.num_samples - 1, self.sample_dim)) | ||
with self.subTest('data2-labels-have-correct-shape'): | ||
self.assertEqual(data2[1].shape, (self.num_samples - 1,)) | ||
|
||
# Check output values range. | ||
with self.subTest('data1-labels-in-range'): | ||
self.assertAllInRange(data1[1], 0, self.num_classes) | ||
with self.subTest('data2-labels-in-range'): | ||
self.assertAllInRange(data2[1], 0, self.num_classes) | ||
with self.subTest('data1-features-in-range'): | ||
self.assertAllInRange(data1[0], self.min_value, self.max_value) | ||
with self.subTest('data2-features-in-range'): | ||
self.assertAllInRange(data2[0], self.min_value, self.max_value) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |
Oops, something went wrong.