Skip to content

Commit

Permalink
fixed error when too many model weights are sent to frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
mishadr committed Oct 31, 2024
1 parent 1fdd771 commit 9615a8c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 15 deletions.
20 changes: 13 additions & 7 deletions src/models_builder/gnn_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_hash(self):
gnn_name_hash = hash_data_sha256(json_object.encode('utf-8'))
return gnn_name_hash

def get_full_info(self):
def get_full_info(self, tensor_size_limit=None):
""" Get available info about model for frontend
"""
# FIXMe architecture and weights can be not accessible
Expand All @@ -118,7 +118,7 @@ def get_full_info(self):
except (AttributeError, NotImplementedError):
pass
try:
result["weights"] = self.get_weights()
result["weights"] = self.get_weights(tensor_size_limit=tensor_size_limit)
except (AttributeError, NotImplementedError):
pass
try:
Expand Down Expand Up @@ -165,9 +165,8 @@ def get_neurons(self):
neurons.append(n_neurons)
return neurons

def get_weights(self):
"""
Get model weights calling torch.nn.Module.state_dict() to draw them on the frontend.
def get_weights(self, tensor_size_limit=None):
""" Get model weights calling torch.nn.Module.state_dict() to draw them on the frontend.
"""
try:
state_dict = self.state_dict()
Expand All @@ -185,8 +184,15 @@ def get_weights(self):

k = sub_keys[-1]
if type(value) == UninitializedParameter:
continue
part[k] = value.numpy().tolist()
part[k] = '?'
else:
size = 1
for dim in value.shape:
size *= dim
if tensor_size_limit and size > tensor_size_limit: # Tensor is too big - return just its shape
part[k] = 'x'.join(str(d) for d in value.shape)
else:
part[k] = value.numpy().tolist()
return model_data


Expand Down
13 changes: 7 additions & 6 deletions web_interface/back_front/model_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from web_interface.back_front.block import Block, WrapperBlock
from web_interface.back_front.utils import WebInterfaceError, json_dumps, get_config_keys

TENSOR_SIZE_LIMIT = 1024 # Max size of weights tensor we sent to frontend


class ModelWBlock(WrapperBlock):
def __init__(self, name, blocks, *args, **kwargs):
Expand Down Expand Up @@ -58,7 +60,7 @@ def _submit(self):

self._object = self.model_manager
self._result = self._object.get_full_info()
self._result.update(self._object.gnn.get_full_info())
self._result.update(self._object.gnn.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT))

def get_index(self):
""" Get all available models with respect to current dataset
Expand All @@ -71,9 +73,8 @@ def get_index(self):
keys_list, full_keys_list, dir_structure, _ = DataInfo.take_keys_etc_by_prefix(
prefix=("data_root", "data_prepared")
)
values_info = DataInfo.values_list_by_path_and_keys(path=path,
full_keys_list=full_keys_list,
dir_structure=dir_structure)
values_info = DataInfo.values_list_by_path_and_keys(
path=path, full_keys_list=full_keys_list, dir_structure=dir_structure)
ps = index.filter(dict(zip(keys_list, values_info)))
return [ps.to_json(), json_dumps(info)]

Expand Down Expand Up @@ -105,7 +106,7 @@ def _finalize(self):

def _submit(self):
self._object = FrameworkGNNConstructor(self.model_config)
self._result = self._object.get_full_info()
self._result = self._object.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT)


class ModelCustomBlock(Block):
Expand Down Expand Up @@ -140,7 +141,7 @@ def _submit(self):
assert cm_path

self._object = UserCodeInfo.take_user_model_obj(cm_path, self.model_name["model"])
self._result = self._object.get_full_info()
self._result = self._object.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT)

def get_index(self):
""" Get all available models with respect to current dataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ class PanelModelArchView extends PanelView {
// Init all SVG primitives
this.primitives = {}

// Function drawing an parameters data object
let marginText = 5
let marginBlocks = 140 // TODO make it = rightmost bound of all texts
// Function drawing an parameters data object
let draw = (kv, primitives, offsets, keysList=[], depth=0) => {
for (let [key, value] of Object.entries(kv)) {
let text
Expand Down Expand Up @@ -173,6 +173,15 @@ class PanelModelArchView extends PanelView {
primitives[key] = text
offsets[1] += 25
}
else if (typeof(value) === 'string') { // String, indicates tensor size is over limit
this.svgPanel.$svg.append(Svg.text(
`[Tensor ${value}]`,
marginBlocks + 12*depth, offsets[1],
'middle', '20px',
'normal', "#000000"
))
offsets[1] += this.size + 20
}
else if (value.constructor === Array) {
let arrayPrimitives = []
primitives[key] = arrayPrimitives
Expand Down Expand Up @@ -248,7 +257,7 @@ class PanelModelArchView extends PanelView {
}
}
}
else console.error("Unknown type")
else console.error("Model data contains unknown data type.")
}
}

Expand Down

0 comments on commit 9615a8c

Please sign in to comment.