Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OV JS]: Expose ov.saveModel() functionality #27148

Merged
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/bindings/js/node/include/addon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,19 @@ void init_class(Napi::Env env,
Prototype func,
Napi::FunctionReference& reference);

template <typename Callable>
void init_function(Napi::Env env,
Napi::Object exports,
std::string func_name,
Callable func) {
const auto& napi_func = Napi::Function::New(env, func, func_name);

exports.Set(func_name, napi_func);
}

Napi::Object init_module(Napi::Env env, Napi::Object exports);

/**
* @brief Saves model in a specified path.
*/
Napi::Value save_model_sync(const Napi::CallbackInfo& info);
6 changes: 6 additions & 0 deletions src/bindings/js/node/include/type_validation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ const char* get_attr_type<Napi::String>();
template <>
const char* get_attr_type<Napi::Object>();

template <>
const char* get_attr_type<Napi::Boolean>();

template <>
const char* get_attr_type<Napi::Buffer<uint8_t>>();

Expand All @@ -52,6 +55,9 @@ bool validate_value<Napi::String>(const Napi::Env& env, const Napi::Value& value
template <>
bool validate_value<Napi::Object>(const Napi::Env& env, const Napi::Value& value);

template <>
bool validate_value<Napi::Boolean>(const Napi::Env& env, const Napi::Value& value);

template <>
bool validate_value<Napi::Buffer<uint8_t>>(const Napi::Env& env, const Napi::Value& value);

Expand Down
16 changes: 16 additions & 0 deletions src/bindings/js/node/lib/addon.ts
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,22 @@ export interface NodeAddon {
resizeAlgorithm: typeof resizeAlgorithm;
PrePostProcessor: PrePostProcessorConstructor;
};

/**
* It saves a model into IR files (xml and bin).
* Floating point weights are compressed to FP16 by default.
* This method saves a model to IR applying all necessary transformations
* that usually applied in model conversion flow provided by mo tool.
* Particularly, floating point weights are compressed to FP16,
* debug information in model nodes are cleaned up, etc.
* @param model The model which will be
* converted to IR representation and saved.
* @param path The path for saving the model.
* @param compressToFp16 Whether to compress
* floating point weights to FP16. Default is set to `true`.
*/
saveModelSync(model: Model, path: string, compressToFp16?: boolean): void;

element: typeof element;
}

Expand Down
29 changes: 29 additions & 0 deletions src/bindings/js/node/src/addon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
#include "node/include/compiled_model.hpp"
#include "node/include/core_wrap.hpp"
#include "node/include/element_type.hpp"
#include "node/include/errors.hpp"
#include "node/include/helper.hpp"
#include "node/include/infer_request.hpp"
#include "node/include/model_wrap.hpp"
#include "node/include/node_output.hpp"
#include "node/include/partial_shape_wrap.hpp"
#include "node/include/preprocess/preprocess.hpp"
#include "node/include/tensor.hpp"
#include "node/include/type_validation.hpp"
#include "openvino/openvino.hpp"

void init_class(Napi::Env env,
Expand All @@ -27,6 +30,30 @@ void init_class(Napi::Env env,
exports.Set(class_name, prototype);
}

Napi::Value save_model_sync(const Napi::CallbackInfo& info) {
std::vector<std::string> allowed_signatures;
try {
if (ov::js::validate<ModelWrap, Napi::String>(info, allowed_signatures)) {
const auto& model = info[0].ToObject();
const auto m = Napi::ObjectWrap<ModelWrap>::Unwrap(model);
const auto path = js_to_cpp<std::string>(info, 1);
ov::save_model(m->get_model(), path);
} else if (ov::js::validate<ModelWrap, Napi::String, Napi::Boolean>(info, allowed_signatures)) {
const auto& model = info[0].ToObject();
const auto m = Napi::ObjectWrap<ModelWrap>::Unwrap(model);
const auto path = js_to_cpp<std::string>(info, 1);
const auto compress_to_fp16 = info[2].ToBoolean();
ov::save_model(m->get_model(), path, compress_to_fp16);
} else {
OPENVINO_THROW("'saveModelSync'", ov::js::get_parameters_error_msg(info, allowed_signatures));
}
} catch (const std::exception& e) {
reportError(info.Env(), e.what());
}

return info.Env().Undefined();
}

/** @brief Initialize native add-on */
Napi::Object init_module(Napi::Env env, Napi::Object exports) {
auto addon_data = new AddonData();
Expand All @@ -41,6 +68,8 @@ Napi::Object init_module(Napi::Env env, Napi::Object exports) {
init_class(env, exports, "ConstOutput", &Output<const ov::Node>::get_class, addon_data->const_output);
init_class(env, exports, "PartialShape", &PartialShapeWrap::get_class, addon_data->partial_shape);

init_function(env, exports, "saveModelSync", save_model_sync);

preprocess::init(env, exports);
element::init(env, exports);

Expand Down
10 changes: 10 additions & 0 deletions src/bindings/js/node/src/type_validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ const char* get_attr_type<Napi::Object>() {
return NapiArg::get_type_name(napi_object);
}

template <>
const char* get_attr_type<Napi::Boolean>() {
return NapiArg::get_type_name(napi_boolean);
}

template <>
const char* get_attr_type<Napi::Buffer<uint8_t>>() {
return BindingTypename::BUFFER;
Expand Down Expand Up @@ -115,6 +120,11 @@ bool validate_value<Napi::Object>(const Napi::Env& env, const Napi::Value& value
return napi_object == value.Type();
}

template <>
bool validate_value<Napi::Boolean>(const Napi::Env& env, const Napi::Value& value) {
return napi_boolean == value.Type();
}

template <>
bool validate_value<Napi::Buffer<uint8_t>>(const Napi::Env& env, const Napi::Value& value) {
return value.IsBuffer();
Expand Down
59 changes: 58 additions & 1 deletion src/bindings/js/node/tests/unit/basic.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@
const { addon: ov } = require('../..');
const assert = require('assert');
const { describe, it, before, beforeEach } = require('node:test');
const { testModels, getModelPath, isModelAvailable } = require('./utils.js');
const {
testModels,
compareModels,
getModelPath,
isModelAvailable,
hub-bla marked this conversation as resolved.
Show resolved Hide resolved
} = require('./utils.js');
const epsilon = 0.5;
const outDir = 'tests/unit/out/';
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does after unit tests run git has clear state?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I believe I should add tests/unit/out to .gitignore, similar to how it was done for tests/unit/test_models path.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use mkdtemp(). The prefix can be simple. You may also see this usage of os.tempdir()

Copy link
Contributor

@almilosz almilosz Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding files to .gitignore won't remove files in the CI's check. So you should make sure they are deleted


describe('ov basic tests.', () => {
let testXml = null;
Expand All @@ -33,6 +39,57 @@ describe('ov basic tests.', () => {
assert.ok(devices.includes('CPU'));
});

describe('ov.saveModelSync()', () => {
it('saveModelSync(model, path, compressToFp16=true)', () => {
const xmlPath = `${outDir}${model.getName()}_fp16.xml`;
assert.doesNotThrow(() => ov.saveModelSync(model, xmlPath, true));

const savedModel = core.readModelSync(xmlPath);
assert.doesNotThrow(() => compareModels(model, savedModel));
});

it('saveModelSync(model, path, compressToFp16)', () => {
const xmlPath = `${outDir}${model.getName()}_fp32.xml`;
assert.doesNotThrow(() => ov.saveModelSync(model, xmlPath));

const savedModel = core.readModelSync(xmlPath);
assert.doesNotThrow(() => compareModels(model, savedModel));
});
it('saveModelSync(model, path, compressToFp16=false)', () => {
const xmlPath = `${outDir}${model.getName()}_fp32.xml`;
assert.doesNotThrow(() => ov.saveModelSync(model, xmlPath, false));

const savedModel = core.readModelSync(xmlPath);
assert.doesNotThrow(() => compareModels(model, savedModel));
});

it('saveModelSync(model) throws', () => {
const expectedMsg = (
'\'saveModelSync\' method called with incorrect parameters.\n' +
'Provided signature: (object) \n' +
'Allowed signatures:\n' +
'- (Model, string)\n' +
'- (Model, string, boolean)\n'
).replace(/[()]/g, '\\$&');

assert.throws(
() => ov.saveModelSync(model),
new RegExp(expectedMsg));
});

it('saveModelSync(model, path) throws with incorrect path', () => {
const expectedMsg = (
'Path for xml file doesn\'t ' +
'contains file name with \'xml\' extension'
).replace(/[()]/g, '\\$&');

const noXmlPath = `${outDir}${model.getName()}_fp32`;
assert.throws(
() => ov.saveModelSync(model, noXmlPath),
new RegExp(expectedMsg));
});
});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Propose to add more test that check invalid parameters passing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely, it's just rough draft for now.


describe('Core.getVersions()', () => {
it('getVersions(validDeviceName: string)', () => {
const deviceVersion = core.getVersions('CPU');
Expand Down
26 changes: 26 additions & 0 deletions src/bindings/js/node/tests/unit/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,38 @@ const testModels = {
};

module.exports = {
compareModels,
hub-bla marked this conversation as resolved.
Show resolved Hide resolved
getModelPath,
downloadTestModel,
isModelAvailable,
testModels,
};

function compareModels(model1, model2) {
const differences = [];
if (model1.getFriendlyName() !== model2.getFriendlyName()) {
differences.push('Friendly names of models are not equal ' +
`model_one: ${model1.getFriendlyName()},` +
`model_two: ${model2.getFriendlyName()}`);
}

if (model1.inputs.length !== model2.inputs.length) {
differences.push('Number of models\' inputs are not equal ' +
`model_one: ${model1.inputs.length}, ` +
`model_two: ${model2.inputs.length}`);
}

if (model1.outputs.length !== model2.outputs.length) {
differences.push('Number of models\' outputs are not equal ' +
`model_one: ${model1.outputs.length}, ` +
`model_two: ${model2.outputs.length}`);
}

if (differences.length) {
throw new Error(differences.join('\n'));
}
}

hub-bla marked this conversation as resolved.
Show resolved Hide resolved
function getModelPath(isFP16 = false) {
const modelName = `test_model_fp${isFP16 ? 16 : 32}`;

Expand Down
Loading