diff --git a/src/models_builder/gnn_constructor.py b/src/models_builder/gnn_constructor.py index 9c82197..edc2dc5 100644 --- a/src/models_builder/gnn_constructor.py +++ b/src/models_builder/gnn_constructor.py @@ -108,7 +108,7 @@ def get_hash(self): gnn_name_hash = hash_data_sha256(json_object.encode('utf-8')) return gnn_name_hash - def get_full_info(self): + def get_full_info(self, tensor_size_limit=None): """ Get available info about model for frontend """ # FIXMe architecture and weights can be not accessible @@ -118,7 +118,7 @@ def get_full_info(self): except (AttributeError, NotImplementedError): pass try: - result["weights"] = self.get_weights() + result["weights"] = self.get_weights(tensor_size_limit=tensor_size_limit) except (AttributeError, NotImplementedError): pass try: @@ -165,9 +165,8 @@ def get_neurons(self): neurons.append(n_neurons) return neurons - def get_weights(self): - """ - Get model weights calling torch.nn.Module.state_dict() to draw them on the frontend. + def get_weights(self, tensor_size_limit=None): + """ Get model weights calling torch.nn.Module.state_dict() to draw them on the frontend. """ try: state_dict = self.state_dict() @@ -185,8 +184,15 @@ def get_weights(self): k = sub_keys[-1] if type(value) == UninitializedParameter: - continue - part[k] = value.numpy().tolist() + part[k] = '?' + else: + size = 1 + for dim in value.shape: + size *= dim + if tensor_size_limit and size > tensor_size_limit: # Tensor is too big - return just its shape + part[k] = 'x'.join(str(d) for d in value.shape) + else: + part[k] = value.numpy().tolist() return model_data diff --git a/web_interface/back_front/model_blocks.py b/web_interface/back_front/model_blocks.py index 5d71215..b9c11b1 100644 --- a/web_interface/back_front/model_blocks.py +++ b/web_interface/back_front/model_blocks.py @@ -17,6 +17,8 @@ from web_interface.back_front.block import Block, WrapperBlock from web_interface.back_front.utils import WebInterfaceError, json_dumps, get_config_keys +TENSOR_SIZE_LIMIT = 1024 # Max size of weights tensor we sent to frontend + class ModelWBlock(WrapperBlock): def __init__(self, name, blocks, *args, **kwargs): @@ -58,7 +60,7 @@ def _submit(self): self._object = self.model_manager self._result = self._object.get_full_info() - self._result.update(self._object.gnn.get_full_info()) + self._result.update(self._object.gnn.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT)) def get_index(self): """ Get all available models with respect to current dataset @@ -71,9 +73,8 @@ def get_index(self): keys_list, full_keys_list, dir_structure, _ = DataInfo.take_keys_etc_by_prefix( prefix=("data_root", "data_prepared") ) - values_info = DataInfo.values_list_by_path_and_keys(path=path, - full_keys_list=full_keys_list, - dir_structure=dir_structure) + values_info = DataInfo.values_list_by_path_and_keys( + path=path, full_keys_list=full_keys_list, dir_structure=dir_structure) ps = index.filter(dict(zip(keys_list, values_info))) return [ps.to_json(), json_dumps(info)] @@ -105,7 +106,7 @@ def _finalize(self): def _submit(self): self._object = FrameworkGNNConstructor(self.model_config) - self._result = self._object.get_full_info() + self._result = self._object.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT) class ModelCustomBlock(Block): @@ -140,7 +141,7 @@ def _submit(self): assert cm_path self._object = UserCodeInfo.take_user_model_obj(cm_path, self.model_name["model"]) - self._result = self._object.get_full_info() + self._result = self._object.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT) def get_index(self): """ Get all available models with respect to current dataset diff --git a/web_interface/static/js/presentation/right_panel/model/panelModelArchView.js b/web_interface/static/js/presentation/right_panel/model/panelModelArchView.js index 687eea2..ae50c50 100644 --- a/web_interface/static/js/presentation/right_panel/model/panelModelArchView.js +++ b/web_interface/static/js/presentation/right_panel/model/panelModelArchView.js @@ -101,9 +101,9 @@ class PanelModelArchView extends PanelView { // Init all SVG primitives this.primitives = {} - // Function drawing an parameters data object let marginText = 5 let marginBlocks = 140 // TODO make it = rightmost bound of all texts + // Function drawing an parameters data object let draw = (kv, primitives, offsets, keysList=[], depth=0) => { for (let [key, value] of Object.entries(kv)) { let text @@ -173,6 +173,15 @@ class PanelModelArchView extends PanelView { primitives[key] = text offsets[1] += 25 } + else if (typeof(value) === 'string') { // String, indicates tensor size is over limit + this.svgPanel.$svg.append(Svg.text( + `[Tensor ${value}]`, + marginBlocks + 12*depth, offsets[1], + 'middle', '20px', + 'normal', "#000000" + )) + offsets[1] += this.size + 20 + } else if (value.constructor === Array) { let arrayPrimitives = [] primitives[key] = arrayPrimitives @@ -248,7 +257,7 @@ class PanelModelArchView extends PanelView { } } } - else console.error("Unknown type") + else console.error("Model data contains unknown data type.") } }