-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Demonstrate RMF and serialization support
- Loading branch information
Showing
6 changed files
with
311 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#ifndef IMPFOO_MY_RESTRAINT2_H | ||
#define IMPFOO_MY_RESTRAINT2_H | ||
|
||
#include <IMP/foo/foo_config.h> | ||
#include <IMP/Restraint.h> | ||
#include <cereal/access.hpp> | ||
|
||
IMPFOO_BEGIN_NAMESPACE | ||
|
||
class IMPFOOEXPORT MyRestraint2 : public Restraint { | ||
ParticleIndex p_; | ||
double k_; | ||
|
||
public: | ||
MyRestraint2(Model *m, ParticleIndex p, double k); | ||
void do_add_score_and_derivatives(ScoreAccumulator sa) const override; | ||
ModelObjectsTemp do_get_inputs() const override; | ||
IMP_OBJECT_METHODS(MyRestraint2); | ||
|
||
// RMF output support | ||
RestraintInfo *get_static_info() const override; | ||
|
||
// Serialization support | ||
MyRestraint2() {} | ||
private: | ||
friend class cereal::access; | ||
template<class Archive> void serialize(Archive &ar) { | ||
ar(cereal::base_class<Restraint>(this), p_, k_); | ||
} | ||
IMP_OBJECT_SERIALIZE_DECL(MyRestraint2); | ||
|
||
}; | ||
|
||
IMPFOO_END_NAMESPACE | ||
|
||
#endif /* IMPFOO_MY_RESTRAINT2_H */ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#include <IMP/foo/MyRestraint2.h> | ||
#include <IMP/core/XYZ.h> | ||
|
||
IMPFOO_BEGIN_NAMESPACE | ||
|
||
MyRestraint2::MyRestraint2(Model *m, ParticleIndex p, double k) | ||
: Restraint(m, "MyRestraint%1%"), p_(p), k_(k) {} | ||
|
||
void MyRestraint2::do_add_score_and_derivatives(ScoreAccumulator sa) const { | ||
core::XYZ d(get_model(), p_); | ||
double score = .5 * k_ * square(d.get_z()); | ||
if (sa.get_derivative_accumulator()) { | ||
double deriv = k_ * d.get_z(); | ||
d.add_to_derivative(2, deriv, *sa.get_derivative_accumulator()); | ||
} | ||
sa.add_score(score); | ||
} | ||
|
||
ModelObjectsTemp MyRestraint2::do_get_inputs() const { | ||
return ModelObjectsTemp(1, get_model()->get_particle(p_)); | ||
} | ||
|
||
// RMF output support | ||
RestraintInfo *MyRestraint2::get_static_info() const { | ||
IMP_NEW(RestraintInfo, ri, ()); | ||
ri->add_string("type", "IMP.foo.MyRestraint2"); | ||
ri->add_float("force constant", k_); | ||
return ri.release(); | ||
} | ||
|
||
// Serialization support | ||
IMP_OBJECT_SERIALIZE_IMPL(IMP::foo::MyRestraint2); | ||
|
||
IMPFOO_END_NAMESPACE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from __future__ import print_function, division | ||
import IMP | ||
import IMP.test | ||
import IMP.algebra | ||
import IMP.core | ||
import IMP.foo | ||
import pickle | ||
|
||
|
||
class Tests(IMP.test.TestCase): | ||
|
||
def test_my_restraint(self): | ||
"""Test scoring of MyRestraint2""" | ||
m = IMP.Model() | ||
p = m.add_particle("p") | ||
d = IMP.core.XYZ.setup_particle(m, p, IMP.algebra.Vector3D(1,2,3)) | ||
r = IMP.foo.MyRestraint2(m, p, 10.) | ||
self.assertAlmostEqual(r.evaluate(True), 45.0, delta=1e-4) | ||
self.assertLess(IMP.algebra.get_distance(d.get_derivatives(), | ||
IMP.algebra.Vector3D(0,0,30)), | ||
1e-4) | ||
self.assertEqual(len(r.get_inputs()), 1) | ||
|
||
def test_static_info(self): | ||
"""Test static info of MyRestraint2""" | ||
m = IMP.Model() | ||
p = m.add_particle("p") | ||
d = IMP.core.XYZ.setup_particle(m, p, IMP.algebra.Vector3D(1,2,3)) | ||
r = IMP.foo.MyRestraint2(m, p, 10.) | ||
info = r.get_static_info() | ||
self.assertEqual(info.get_number_of_string(), 1) | ||
self.assertEqual(info.get_string_key(0), "type") | ||
self.assertEqual(info.get_string_value(0), "IMP.foo.MyRestraint2") | ||
|
||
self.assertEqual(info.get_number_of_float(), 1) | ||
self.assertEqual(info.get_float_key(0), "force constant") | ||
self.assertAlmostEqual(info.get_float_value(0), 10.0, delta=0.001) | ||
|
||
def test_serialize(self): | ||
"""Test (un-)serialize of MyRestraint2""" | ||
m = IMP.Model() | ||
p = m.add_particle("p") | ||
d = IMP.core.XYZ.setup_particle(m, p, IMP.algebra.Vector3D(1,2,3)) | ||
r = IMP.foo.MyRestraint2(m, p, 10.) | ||
self.assertAlmostEqual(r.evaluate(False), 45.0, delta=1e-3) | ||
dump = pickle.dumps(r) | ||
newr = pickle.loads(dump) | ||
self.assertAlmostEqual(newr.evaluate(False), 45.0, delta=1e-3) | ||
|
||
def test_serialize_polymorphic(self): | ||
"""Test (un-)serialize of MyRestraint2 via polymorphic pointer""" | ||
m = IMP.Model() | ||
p = m.add_particle("p") | ||
d = IMP.core.XYZ.setup_particle(m, p, IMP.algebra.Vector3D(1,2,3)) | ||
r = IMP.foo.MyRestraint2(m, p, 10.) | ||
sf = IMP.core.RestraintsScoringFunction([r]) | ||
self.assertAlmostEqual(sf.evaluate(False), 45.0, delta=1e-3) | ||
dump = pickle.dumps(sf) | ||
newsf = pickle.loads(dump) | ||
self.assertAlmostEqual(newsf.evaluate(False), 45.0, delta=1e-3) | ||
|
||
|
||
if __name__ == '__main__': | ||
IMP.test.main() |