Cannot import graph generated with Tensorflow 2.x #388

Open
ramon-garcia opened this issue Nov 22, 2022 · 6 comments

Comments

@ramon-garcia
Contributor

The examples, for instance examples/regression.rs, include code to load a graph from a .pb file:

```rust
use std::fs::File;
use std::io::Read;
use tensorflow::{Graph, ImportGraphDefOptions};

let mut graph = Graph::new();
let mut proto = Vec::new();
File::open(filename)?.read_to_end(&mut proto)?;
graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?;
```

The code works with a .pb file generated by this TensorFlow 1.x Python script:

```python
import os
import tensorflow as tf

x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')

w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
b = tf.Variable(tf.zeros([1]), name='b')
y_hat = w * x + b

loss = tf.reduce_mean(tf.square(y_hat - y))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss, name='train')

init = tf.variables_initializer(tf.global_variables(), name='init')

definition = tf.Session().graph_def
directory = 'examples/regression'
tf.train.write_graph(definition, directory, 'model.pb', as_text=False)
```

However, suppose we run this script under TensorFlow 2.x after making a few small changes for compatibility:

```python
import os
import tensorflow as tf
import tensorflow.compat.v1 as tf1
tf1.disable_eager_execution()

x = tf1.placeholder(tf.float32, name='x')
y = tf1.placeholder(tf.float32, name='y')

w = tf.Variable(tf.random_uniform_initializer(minval=-1.0, maxval=1.0)(shape=[1]), name='w')
b = tf.Variable(tf.zeros([1]), name='b')
y_hat = w * x + b

loss = tf.reduce_mean(tf.square(y_hat - y))
optimizer = tf1.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss, name='train')

init = tf1.variables_initializer(tf1.global_variables(), name='init')

definition = tf1.Session().graph_def
directory = 'examples/regression'
tf.io.write_graph(definition, directory, 'model.pb', as_text=False)
```

Running the Rust example against the resulting model.pb then fails after the TensorFlow session is initialized, with the following error:

```text
Error: {inner:0x55665fcf34b0, InvalidArgument: Requested tensor type does not match actual tensor type: Resource vs Float}
```

@ramon-garcia
Contributor Author

ping

@dskkato
Contributor

dskkato commented Jan 6, 2023

The example code you tried does not work in TF 2.x. Please refer to the other examples checked in with #309.

@ramon-garcia
Contributor Author

@dskkato I am not sure you are right. I tried writing the code in TF 2.x style (@tf.function) and the same error appeared.

Did you try importing the PB file and creating a session with the imported graph?

@dskkato
Contributor

dskkato commented Jan 8, 2023

In that code, the names of the nodes that read the variables differ between TF1 and TF2.

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')

w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
b = tf.Variable(tf.zeros([1]), name='b')
y_hat = w * x + b

loss = tf.reduce_mean(tf.square(y_hat - y))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss, name='train')

init = tf.variables_initializer(tf.global_variables(), name='init')

definition = tf.Session().graph_def
directory = 'examples/regression'
tf.train.write_graph(definition, directory, 'model.pb', as_text=False)

# for debug
tf.train.write_graph(definition, directory, 'model.pbtxt', as_text=True)
```

Judging from model.pbtxt, the nodes that read the variables are named "w/Read/ReadVariableOp" for "w" and "b/Read/ReadVariableOp" for "b", so the example needs this change:

```diff
diff --git a/examples/regression.rs b/examples/regression.rs
index 5393b7ac..66f472d4 100644
--- a/examples/regression.rs
+++ b/examples/regression.rs
@@ -55,8 +55,8 @@ fn main() -> Result<(), Box<dyn Error>> {
     let op_y = graph.operation_by_name_required("y")?;
     let op_init = graph.operation_by_name_required("init")?;
     let op_train = graph.operation_by_name_required("train")?;
-    let op_w = graph.operation_by_name_required("w")?;
-    let op_b = graph.operation_by_name_required("b")?;
+    let op_w = graph.operation_by_name_required("w/Read/ReadVariableOp")?;
+    let op_b = graph.operation_by_name_required("b/Read/ReadVariableOp")?;
 
     // Load the test data into the session.
     let mut init_step = SessionRunArgs::new();
```
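The diff above hard-codes node names copied by hand from model.pbtxt. The fetch on "w" presumably fails because, with resource variables (the TF 2.x default), the op named "w" is a VarHandleOp whose output is a DT_RESOURCE handle, while the float value comes from a separate ReadVariableOp node. As a minimal, stdlib-only sketch (not part of the tensorflow crate or of TensorFlow's API), the helper below scans the debug model.pbtxt and lists, for each VarHandleOp, the ReadVariableOp nodes whose first input is that variable. The names `parse_nodes` and `variable_read_ops` are hypothetical, and the parser assumes the one-field-per-line layout that `tf.train.write_graph(..., as_text=True)` writes; it is not a general protobuf text-format parser.

```python
import re

# Hypothetical helper. Matches one scalar field per line, e.g.  name: "w/Read/ReadVariableOp"
_FIELD_RE = re.compile(r'^(name|op|input):\s*"([^"]*)"$')


def parse_nodes(pbtxt):
    """Collect name, op and inputs of every top-level `node { ... }` block.

    Assumption: multi-line layout as written by tf.train.write_graph(..., as_text=True),
    i.e. every opening brace ends its line and every closing brace is on its own line.
    """
    nodes = []
    depth = 0
    current = None
    for raw in pbtxt.splitlines():
        line = raw.strip()
        if line.endswith("{"):
            # Start of a block: a top-level "node {" opens a new node record.
            if depth == 0 and line == "node {":
                current = {"name": "", "op": "", "inputs": []}
            depth += 1
        elif line == "}":
            depth -= 1
            if depth == 0 and current is not None:
                nodes.append(current)
                current = None
        elif depth == 1 and current is not None:
            # Only fields directly inside the node; attr { ... } blocks are nested deeper.
            match = _FIELD_RE.match(line)
            if match:
                key, value = match.groups()
                if key == "input":
                    current["inputs"].append(value)
                else:
                    current[key] = value
    return nodes


def variable_read_ops(pbtxt):
    """Map each resource variable (VarHandleOp) to the ReadVariableOp nodes reading it."""
    nodes = parse_nodes(pbtxt)
    readers = {n["name"]: [] for n in nodes if n["op"] == "VarHandleOp"}
    for n in nodes:
        if n["op"] == "ReadVariableOp" and n["inputs"] and n["inputs"][0] in readers:
            readers[n["inputs"][0]].append(n["name"])
    return readers
```

For the graph in this thread, a call such as `variable_read_ops(open('examples/regression/model.pbtxt').read())` should list readers under the keys `"w"` and `"b"`. A variable may have more than one reader, since other ops in the graph can read it too; the entries ending in `/Read/ReadVariableOp` match the names used in the diff above.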

@dskkato
Contributor

dskkato commented Jan 8, 2023

Please note that I'm not sure this is the canonical way to get the above node names.

@ramon-garcia
Contributor Author

Thanks, this is exactly what I was looking for.
