Skip to content

Commit

Permalink
Merge branch 'tensorflow:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
Corallus-Caninus authored Apr 20, 2022
2 parents 01e744d + c526e7e commit cee8013
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ rustversion = "1.0.5"
[dev-dependencies]
random = "0.12.2"
serial_test = "0.5.1"
image = "0.23.14"

[features]
tensorflow_gpu = ["tensorflow-sys/tensorflow_gpu"]
Expand Down Expand Up @@ -67,4 +66,5 @@ name = "regression_checkpoint"
name = "xor"

[[example]]
name = "mobilenetv3"
name = "mobilenetv3"
required-features = ["eager"]
28 changes: 17 additions & 11 deletions examples/mobilenetv3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ use tensorflow::Status;
use tensorflow::Tensor;
use tensorflow::DEFAULT_SERVING_SIGNATURE_DEF_KEY;

use image::io::Reader as ImageReader;
use image::GenericImageView;
use tensorflow::eager::{self, raw_ops, ToTensorHandle};

fn main() -> Result<(), Box<dyn Error>> {
let export_dir = "examples/mobilenetv3";
Expand All @@ -30,16 +29,23 @@ fn main() -> Result<(), Box<dyn Error>> {
));
}

// Create input variables for our addition
let mut x = Tensor::new(&[1, 224, 224, 3]);
let img = ImageReader::open("examples/mobilenetv3/sample.png")?.decode()?;
for (i, (_, _, pixel)) in img.pixels().enumerate() {
x[3 * i] = pixel.0[0] as f32;
x[3 * i + 1] = pixel.0[1] as f32;
x[3 * i + 2] = pixel.0[2] as f32;
}
// Create an eager execution context
let opts = eager::ContextOptions::new();
let ctx = eager::Context::new(opts)?;

// Load an input image.
let fname = "examples/mobilenetv3/sample.png".to_handle(&ctx)?;
let buf = raw_ops::read_file(&ctx, &fname)?;
let img = raw_ops::decode_image(&ctx, &buf)?;
let cast2float = raw_ops::Cast::new().DstT(tensorflow::DataType::Float);
let img = cast2float.call(&ctx, &img)?;
let batch = raw_ops::expand_dims(&ctx, &img, &0)?; // add batch dim
let readonly_x = batch.resolve()?;

// The current eager API implementation requires unsafe block to feed the tensor into a graph.
let x: Tensor<f32> = unsafe { readonly_x.into_tensor() };

// Load the saved model exported by zenn_savedmodel.py.
// Load the model.
let mut graph = Graph::new();
let bundle =
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?;
Expand Down

0 comments on commit cee8013

Please sign in to comment.