Add option to begin training from initial models #53

Merged · 5 commits · Sep 11, 2024
9 changes: 7 additions & 2 deletions 2_proxima/0_run-serial-proxima.py
@@ -54,11 +54,14 @@
group.add_argument('--initial-model', help='Path to initial model in message format. Code will generate a network with default settings if none provided')
group.add_argument('--initial-data', nargs='*', default=(), help='Path to data files (e.g., ASE .traj and .db) containing initial training data')
group.add_argument('--ensemble-size', type=int, default=2, help='Number of models to train on different data segments')
group.add_argument('--online-training', action='store_true', help='Whether to use the weights from the last training step as starting for the next')
group.add_argument('--training-mode', choices=['online', 'reset', 'finetune'],
                   help='How to begin each training step: random weights (reset), weights from the last '
                        'training step (online), or weights from the original model (finetune)')
group.add_argument('--training-epochs', type=int, default=32, help='Number of epochs per training event')
group.add_argument('--training-batch-size', type=int, default=32, help='Batch size to use when training models')
group.add_argument('--training-max-size', type=int, default=None, help='Maximum training set size to use when updating models')
group.add_argument('--training-recency-bias', type=float, default=1., help='Factor by which to favor recent data when reducing training set size')
group.add_argument('--training-learning-rate', type=float, default=1e-3, help='Initial learning rate for the optimizer.')
group.add_argument('--training-device', default='cuda', help='Which device to use for training models')

group = parser.add_argument_group(title='Proxima', description="Settings for learning on the fly")
@@ -164,11 +167,13 @@
train_kwargs={
'num_epochs': args.training_epochs,
'batch_size': args.training_batch_size,
'reset_weights': not args.online_training,
'reset_weights': args.training_mode == 'reset',
'learning_rate': args.training_learning_rate,
'device': args.training_device}, # Configuration for the training routines
train_freq=args.retrain_freq,
train_max_size=args.training_max_size,
train_recency_bias=args.training_recency_bias,
train_from_original=args.training_mode == 'finetune',
target_ferr=args.target_error,
history_length=args.error_history,
min_target_fraction=args.min_target_frac,
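For reference, a hypothetical invocation of the driver script with the new flag (the model path is a placeholder, not taken from this PR):

python 2_proxima/0_run-serial-proxima.py \
    --initial-model model.msg \
    --training-mode finetune \
    --training-epochs 32 \
    --training-learning-rate 1e-3

With finetune, every retraining event starts from the weights of the original model; reset starts from random weights, and online continues from the weights produced by the previous training step.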
35 changes: 26 additions & 9 deletions cascade/proxima/__init__.py
@@ -56,6 +56,10 @@ class SerialLearningCalculator(Calculator):
target_ferr: float
Target maximum difference between the forces predicted by the target
calculator and the learnable surrogate
train_from_original: bool
Whether to use the original models provided when creating the class as the starting
point for training, rather than the models produced by the most recent training.
The calculator preserves the original models only if ``True``.
history_length: int
The number of previous observations of the error between target and surrogate
function to use when establishing a link between uncertainty metric
@@ -67,7 +71,7 @@
even if it need not be used based on the UQ metric.
n_blending_steps: int
How many timesteps to smoothly combine target and surrogate forces.
When the threshold is satisy we apply an increasing mixture of ML and
When the threshold is satisfied we apply an increasing mixture of ML and
target forces.
db_path: Path or str
Database in which to store the results of running the target calculator,
@@ -83,6 +87,7 @@
'train_freq': 1,
'train_max_size': None,
'train_recency_bias': 1.,
'train_from_original': False,
'target_ferr': 0.1, # TODO (wardlt): Make the error metric configurable
'min_target_fraction': 0.,
'n_blending_steps': 0,
@@ -115,6 +120,8 @@ class SerialLearningCalculator(Calculator):
"""Ranges from 0-1, describing mixture between surrogate and physics"""
model_version: int = 0
"""How many times the model has been retrained"""
models: Optional[list] = None
"""Ensemble of models from the latest training invocation. The same as ``parameters['models']`` if `train_from_original`"""

def set(self, **kwargs):
# TODO (wardlt): Fix ASE such that it does not try to do a numpy comparison on everything
@@ -132,10 +139,16 @@ def learner(self) -> BaseLearnableForcefield:
@staticmethod
def smoothing_function(x):
"""Smoothing used for blending surrogate with physics"""
return 0.5*((np.cos(np.pi*x)) + 1)
return 0.5 * ((np.cos(np.pi * x)) + 1)

def retrain_surrogate(self):
"""Retrain the surrogate models using the currently-available data"""
# Determine where the updated models will be stored
model_list = self.models = (
[None] * len(self.parameters['models']) # Create a new list
if self.parameters['train_from_original'] else
self.parameters['models'] # Edit it in place
)

# Load in the data from the db
db_path = self.parameters['db_path']
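A minimal standalone sketch (with hypothetical stand-in values) of the aliasing behavior set up at the top of this hunk: when train_from_original is false, model_list is the same object as parameters['models'], so retraining overwrites the stored models in place; when true, a fresh list keeps the originals untouched.

params = {'models': ['orig_a', 'orig_b'], 'train_from_original': True}
model_list = (
    [None] * len(params['models'])  # fresh list: originals stay untouched
    if params['train_from_original'] else
    params['models']  # alias: retraining overwrites the parameters in place
)
model_list[0] = 'retrained_a'
assert params['models'][0] == 'orig_a'  # holds only because a new list was made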
@@ -154,7 +167,6 @@ def retrain_surrogate(self):
return

# Train each model using a different, randomly-selected subset of the data
model_list = self.parameters['models'] # Edit it in place
self.train_logs = []
for i, model_msg in enumerate(self.parameters['models']):
# Assign splits such that the same entries do not switch between train/validation as the dataset grows
@@ -166,7 +178,8 @@
# Downselect training set if it is larger than the fixed maximum
train_max_size = self.parameters['train_max_size']
if train_max_size is not None and len(train_atoms) > train_max_size:
valid_size = train_max_size * len(valid_atoms) // len(train_atoms) # Decrease the validation size proportionally
# Decrease the validation size proportionally
valid_size = train_max_size * len(valid_atoms) // len(train_atoms)

train_weights = np.geomspace(1, self.parameters['train_recency_bias'], len(train_atoms))
train_ids = rng.choice(len(train_atoms), size=(train_max_size,), p=train_weights / train_weights.sum(), replace=False)
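For intuition, a standalone sketch (hypothetical sizes, not from this PR) of the recency-biased downselection above: the geometric weights grow toward the newest frames, so recent data is sampled preferentially, and replace=False keeps the chosen indices unique.

import numpy as np

rng = np.random.default_rng(1)
train_recency_bias, train_max_size = 4.0, 3
train_weights = np.geomspace(1, train_recency_bias, 6)  # oldest -> newest: 1.0 ... 4.0
train_ids = rng.choice(6, size=(train_max_size,),
                       p=train_weights / train_weights.sum(), replace=False)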
@@ -193,7 +206,7 @@ def calculate(
if self.surrogate_calc is None:
self.retrain_surrogate()
self.surrogate_calc = EnsembleCalculator(
calculators=[self.learner.make_calculator(m, self.parameters['device']) for m in self.parameters['models']]
calculators=[self.learner.make_calculator(m, self.parameters['device']) for m in self.models]
)
self.surrogate_calc.calculate(atoms, properties + ['forces'], system_changes) # Make sure forces are computed too

@@ -236,7 +249,7 @@ def calculate(
# handle differences in voigt vs (3,3) stress convention
if k == 'stress' and r_target.shape != r_surrogate.shape:
r_target, r_surrogate = map(to_voigt, [r_target, r_surrogate])
self.results[k] = self.lambda_target*r_target + (1-self.lambda_target)*r_surrogate
self.results[k] = self.lambda_target * r_target + (1 - self.lambda_target) * r_surrogate
else:
# the surrogate may have some extra results which we store
self.results[k] = results_surrogate[k]
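For intuition, a standalone sketch of the blending above. The diff does not show how lambda_target is derived from the blending step, so assume here that it comes from smoothing_function, which decays smoothly from 1 at x=0 to 0 at x=1; each result is then a convex combination of target and surrogate.

import numpy as np

def smoothing_function(x):
    return 0.5 * (np.cos(np.pi * x) + 1)  # 1 at x=0, 0 at x=1

lambda_target = smoothing_function(0.25)  # partway through the blending window
r_target, r_surrogate = 1.0, 0.8  # hypothetical per-component results
blended = lambda_target * r_target + (1 - lambda_target) * r_surrogate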
@@ -305,15 +318,15 @@ def get_state(self) -> dict[str, Any]:
'threshold': self.threshold,
'alpha': self.alpha,
'blending_step': int(self.blending_step),
'error_history': list(self.error_history),
'error_history': list(self.error_history) if self.error_history is not None else [],
'new_points': self.new_points,
'train_logs': self.train_logs,
'total_invocations': self.total_invocations,
'target_invocations': self.target_invocations,
'model_version': self.model_version
}
if self.surrogate_calc is not None:
output['models'] = [self.learner.serialize_model(s) for s in self.parameters['models']]
output['models'] = [self.learner.serialize_model(s) for s in self.models]
return output

def set_state(self, state: dict[str, Any]):
@@ -337,7 +350,11 @@ def set_state(self, state: dict[str, Any]):

# Remake the surrogate calculator, if available
if 'models' in state:
self.parameters['models'] = state['models']
# Store in a different place depending on whether we are training from original or latest
if self.parameters['train_from_original']:
self.models = state['models']
else:
self.models = self.parameters['models'] = state['models'] # Both are the same
self.surrogate_calc = EnsembleCalculator(
calculators=[self.learner.make_calculator(m, self.parameters['device']) for m in state['models']]
)
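As a usage sketch (calc stands in for a hypothetical SerialLearningCalculator instance), the state dictionary survives a pickle round trip, and restoring it rebuilds the surrogate ensemble from the serialized models; the test below exercises the same path:

import pickle as pkl

state = pkl.loads(pkl.dumps(calc.get_state()))  # picklable state snapshot
calc.set_state(state)  # rebuilds surrogate_calc from state['models']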
17 changes: 17 additions & 0 deletions tests/test_proxima.py
@@ -214,3 +214,20 @@ def test_blending(starting_frame, simple_model, target_calc, tmpdir):
assert not calc.used_surrogate
assert calc.blending_step == 0
assert calc.lambda_target == 1


def test_train_from_original(starting_frame, simple_proxima, initialized_db):
# Set the train_from_original flag
simple_proxima.parameters['train_from_original'] = True
assert simple_proxima.models is None # Should start unset

# Ensure the original models are unchanged
orig_models = simple_proxima.parameters['models'].copy()
simple_proxima.retrain_surrogate()
assert all(x is y for x, y in zip(simple_proxima.parameters['models'], orig_models))
assert all(x is not y for x, y in zip(simple_proxima.models, orig_models))

# Ensure that pickling does not alter the original models
state = pkl.loads(pkl.dumps(simple_proxima.get_state()))
simple_proxima.set_state(state)
assert all(x is y for x, y in zip(simple_proxima.parameters['models'], orig_models))