From cf729859b9e7f061a1cb635ca420d8f511ba47cc Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 13 Nov 2024 06:40:45 +0000 Subject: [PATCH] add test --- test/pytest/test_transpose_concat.py | 34 ++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/pytest/test_transpose_concat.py b/test/pytest/test_transpose_concat.py index 884d5859d5..642455f2c8 100644 --- a/test/pytest/test_transpose_concat.py +++ b/test/pytest/test_transpose_concat.py @@ -54,3 +54,37 @@ def test_accuracy(data, keras_model, hls_model): y_hls4ml = hls_model.predict(X).reshape(y_keras.shape) # "accuracy" of hls4ml predictions vs keras np.testing.assert_allclose(y_keras, y_hls4ml, rtol=0, atol=1e-04, verbose=True) + + +@pytest.fixture(scope='module') +def keras_model_highdim(): + inp = Input(shape=(2, 3, 4, 5, 6), name='input_1') + out = Permute((3, 5, 4, 1, 2))(inp) + model = Model(inputs=inp, outputs=out) + return model + + +@pytest.fixture(scope='module') +def data_highdim(): + X = np.random.randint(-128, 127, (100, 2, 3, 4, 5, 6)) / 128 + X = X.astype(np.float32) + return X + + +@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +def test_highdim_permute(data_highdim, keras_model_highdim, io_type, backend): + X = data_highdim + model = keras_model_highdim + + model_hls = hls4ml.converters.convert_from_keras_model( + model, + io_type=io_type, + backend=backend, + output_dir=str(test_root_path / f'hls4mlprj_highdim_transpose_{backend}_{io_type}'), + ) + model_hls.compile() + y_keras = model.predict(X) + y_hls4ml = model_hls.predict(X).reshape(y_keras.shape) # type: ignore + + assert np.all(y_keras == y_hls4ml)