Cannot import graph generated with Tensorflow 2.x #388
Comments
ping
The example code you tried does not work in TF 2.x. Please refer to the other examples checked in #309.
@dskkato I am not sure you are right. I tried writing the code in TF 2.x style. Did you try importing the PB file and creating a session with the imported graph?
In that code, the node names used to read the variables are different.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
b = tf.Variable(tf.zeros([1]), name='b')
y_hat = w * x + b
loss = tf.reduce_mean(tf.square(y_hat - y))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss, name='train')
init = tf.variables_initializer(tf.global_variables(), name='init')
definition = tf.Session().graph_def
directory = 'examples/regression'
tf.train.write_graph(definition, directory, 'model.pb', as_text=False)
# for debug
tf.train.write_graph(definition, directory, 'model.pbtxt', as_text=True)
As in the following diff, the variables are read through their ReadVariableOp nodes:
diff --git a/examples/regression.rs b/examples/regression.rs
index 5393b7ac..66f472d4 100644
--- a/examples/regression.rs
+++ b/examples/regression.rs
@@ -55,8 +55,8 @@ fn main() -> Result<(), Box<dyn Error>> {
let op_y = graph.operation_by_name_required("y")?;
let op_init = graph.operation_by_name_required("init")?;
let op_train = graph.operation_by_name_required("train")?;
- let op_w = graph.operation_by_name_required("w")?;
- let op_b = graph.operation_by_name_required("b")?;
+ let op_w = graph.operation_by_name_required("w/Read/ReadVariableOp")?;
+ let op_b = graph.operation_by_name_required("b/Read/ReadVariableOp")?;
// Load the test data into the session.
let mut init_step = SessionRunArgs::new();
Please note that I'm not sure this is the canonical way to get the above node names.
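Instead of hard-coding these names, one option is to scan the exported GraphDef for ReadVariableOp nodes. The following is a minimal sketch, not a documented TensorFlow or crate API. It assumes that the variable handle is the first input of each read op and that the read op you want lives inside the variable's name scope, as "w/Read/ReadVariableOp" does in the diff above. FakeNode and find_read_ops are hypothetical names. With a real graph you would pass definition.node from the Python script above, because NodeDef protos expose the same name, op and input fields.

```python
from collections import namedtuple

# Hypothetical stand-in for a NodeDef: real tf.compat.v1 NodeDef protos expose
# the same three fields (name, op, input), so the helper accepts either.
FakeNode = namedtuple("FakeNode", ["name", "op", "input"])


def find_read_ops(nodes, var_names):
    """Return {variable name: name of its scoped ReadVariableOp node, or None}.

    Heuristic (assumption): pick the ReadVariableOp whose first input is the
    variable handle and whose name lies inside the variable's name scope.
    """
    found = {name: None for name in var_names}
    for node in nodes:
        if node.op != "ReadVariableOp" or not node.input:
            continue
        # Strip a control-dependency marker ("^w") or an output index ("w:0").
        source = node.input[0].lstrip("^").split(":")[0]
        if source in found and node.name.startswith(source + "/"):
            found[source] = node.name
    return found


# Usage on a hand-written node list shaped like the regression graph.
nodes = [
    FakeNode("w", "VarHandleOp", []),
    FakeNode("w/Read/ReadVariableOp", "ReadVariableOp", ["w"]),
    FakeNode("b", "VarHandleOp", []),
    FakeNode("b/Read/ReadVariableOp", "ReadVariableOp", ["b"]),
    FakeNode("ReadVariableOp", "ReadVariableOp", ["w"]),  # another read of w, outside its scope
]
print(find_read_ops(nodes, ["w", "b"]))
# {'w': 'w/Read/ReadVariableOp', 'b': 'b/Read/ReadVariableOp'}
```

The returned names are the ones you would then pass to operation_by_name_required in the Rust example. Restricting the match to the variable's own name scope avoids picking up other reads of the same variable, for example those created when computing y_hat or the gradients.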
Thanks, this is exactly what I was looking for.
The examples, for instance examples/regression.rs, include code to load a graph from a .pb file.
The code works with a .pb file generated with TensorFlow 1.x for Python.
However, if we run this Python file under TensorFlow 2.x after making some small changes for compatibility, and then run the example, the following error appears after the TensorFlow session is initialized:
Error: {inner:0x55665fcf34b0, InvalidArgument: Requested tensor type does not match actual tensor type: Resource vs Float}