diff --git a/jflux/modules/conditioner.py b/jflux/modules/conditioner.py index 74f1af4..846bd88 100644 --- a/jflux/modules/conditioner.py +++ b/jflux/modules/conditioner.py @@ -2,9 +2,9 @@ from flax import nnx from transformers import ( CLIPTokenizer, + T5Tokenizer, FlaxCLIPTextModel, FlaxT5EncoderModel, - T5Tokenizer, ) diff --git a/tests/modules/test_conditioner.py b/tests/modules/test_conditioner.py new file mode 100644 index 0000000..1d8c381 --- /dev/null +++ b/tests/modules/test_conditioner.py @@ -0,0 +1,10 @@ +import numpy as np + +from flux.modules.conditioner import HFEmbedder as TorchHFEmbedder +from jflux.modules.conditioner import HFEmbedder as JaxHFEmbedder + + +class HFEmbedderTestCase(np.testing.TestCase): + def test_hf_embed(self): + # initialize layers + TorchHFEmbedder() \ No newline at end of file