Skip to content

Commit

Permalink
[StableHLO][API] Add API to get StableHLO version from PortableArtifact
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702481032
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Dec 4, 2024
1 parent af11c20 commit 40398df
Showing 1 changed file with 153 additions and 0 deletions.
153 changes: 153 additions & 0 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
@@ -1 +1,154 @@
diff --ruN a/stablehlo/stablehlo/dialect/Serialization.cpp b/stablehlo/stablehlo/dialect/Serialization.cpp
--- stablehlo/stablehlo/dialect/Serialization.cpp
+++ stablehlo/stablehlo/dialect/Serialization.cpp
@@ -15,6 +15,8 @@

#include "stablehlo/dialect/Serialization.h"

+#include "llvm/Support/MemoryBuffer.h"
+#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
@@ -28,6 +30,9 @@
#include "stablehlo/dialect/Version.h"
#include "stablehlo/dialect/VhloOps.h"
#include "stablehlo/transforms/Passes.h"
+#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "compat-passes"

namespace mlir {
namespace stablehlo {
@@ -89,5 +94,36 @@
return module;
}

+FailureOr<vhlo::Version> getPortableArtifactVersion(llvm::StringRef bytecode) {
+ auto logFailure = [&](llvm::StringRef message) {
+ LLVM_DEBUG(llvm::dbgs() << "Failed to get portable artifact version: "
+ << message << "\n");
+ return failure();
+ };
+ // Must start with MLiRxStableHLO_vX.Y.Z, minimum length of 19.
+ constexpr size_t minHeaderLength = 19;
+ if (bytecode.size() < minHeaderLength) return logFailure("min header");
+
+ // Truncate to the end of the null-terminated producer string.
+ size_t pos = bytecode.find('\0');
+ if (pos == llvm::StringRef::npos) return logFailure("no terminator");
+ bytecode = bytecode.substr(0, pos);
+
+ // Check if the bytecode is valid, starts with MLiR magic number.
+ if (!isBytecode(
+ llvm::MemoryBuffer::getMemBuffer(bytecode)->getMemBufferRef()))
+ return logFailure("not bytecode");
+
+ // Skip 4 bytes for the magic number.
+ std::string stablehloHeader = "StableHLO_v";
+ size_t stablehloPos = bytecode.find(stablehloHeader);
+ if (stablehloPos == llvm::StringRef::npos)
+ return logFailure("not a StableHLO portable artifact");
+
+ // Skip the 11 bytes for StableHLO_v to get the StableHLO version to parse.
+ StringRef version = bytecode.substr(stablehloPos + stablehloHeader.size());
+ return vhlo::Version::fromString(version);
+}
+
} // namespace stablehlo
} // namespace mlir
diff --ruN a/stablehlo/stablehlo/dialect/Serialization.h b/stablehlo/stablehlo/dialect/Serialization.h
--- stablehlo/stablehlo/dialect/Serialization.h
+++ stablehlo/stablehlo/dialect/Serialization.h
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LogicalResult.h"
+#include "stablehlo/dialect/Version.h"

namespace mlir {
namespace stablehlo {
@@ -43,6 +44,17 @@
OwningOpRef<ModuleOp> deserializePortableArtifact(StringRef sourceStr,
MLIRContext* context);

+// Get portable artifact version from the producer string after the MLIR
+// Bytecode magic number `MLïRStableHLO_vX.Y.Z` -> X.Y.Z
+// Returns failure if input string is not a valid portable artifact produced by
+// serializePortableArtifact APIs, which would cause the bytecode artifact to
+// not have the proper producer string.
+//
+// This method should be safe, since any changes to the bytecode format would
+// warrant a bytecode version bump, and MLIR bytecode gives the option to
+// specify a forward compatible bytecode version to target.
+FailureOr<vhlo::Version> getPortableArtifactVersion(llvm::StringRef bytecode);
+
} // namespace stablehlo
} // namespace mlir

diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir
--- stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir
+++ stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir
@@ -0,0 +1,19 @@
+// RUN: stablehlo-translate --deserialize --print-stablehlo-version %s.bc | FileCheck %s --check-prefix=CHECK-VERSION
+// RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize --print-stablehlo-version | FileCheck %s --check-prefix=CHECK-VERSION-LATEST
+// RUN: stablehlo-translate --deserialize --print-stablehlo-version %s | FileCheck %s --check-prefix=CHECK-VERSION-NOT-BYTECODE
+
+// This file tests the `getPortableArtifactVersion` Serialization API.
+// Any breakages to this file likely indicate that the MLIR Bytecode Format
+// has changed, or that the StableHLO producer string emit by
+// `serializePortableArtifact` has changed.
+//
+// See the `getPortableArtifactVersion` doc comments for more details.
+
+// CHECK-VERSION: // Reading portable artifact with StableHLO version: 1.1.0
+// CHECK-VERSION-NOT-BYTECODE: // Failed parsing StableHLO version from portable artifact
+// CHECK-VERSION-LATEST: // Reading portable artifact with StableHLO version: {{.*}}
+
+func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = stablehlo.add %arg0, %arg0 : tensor<f32>
+ func.return %0 : tensor<f32>
+}
diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
--- stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
+++ stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
@@ -53,6 +53,7 @@
#include "stablehlo/reference/Tensor.h"
#include "stablehlo/reference/Value.h"
#include "stablehlo/tests/CheckOps.h"
+#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/raw_ostream.h"

namespace mlir {

@@ -63,6 +64,12 @@

llvm::cl::opt<bool> stripDebuginfoOption(
"strip-debuginfo", llvm::cl::desc("Strip debug info from all operations"),
+ llvm::cl::init(false));
+
+llvm::cl::opt<bool> printStablehloVersion(
+ "print-stablehlo-version",
+ llvm::cl::desc(
+ "When deserializing a portable artifact, print the StableHLO version"),
llvm::cl::init(false));

llvm::cl::opt<std::string> targetOption(
@@ -306,6 +313,17 @@
TranslateToMLIRRegistration deserializeRegistration(
"deserialize", "Deserialize a portable artifact into a StableHLO program",
[](llvm::StringRef input, mlir::MLIRContext *context) {
+ if (printStablehloVersion.getValue()) {
+ auto version = stablehlo::getPortableArtifactVersion(input);
+ if (failed(version)) {
+ llvm::outs()
+ << "// Failed parsing StableHLO version from portable artifact\n";
+ } else {
+ llvm::outs()
+ << "// Reading portable artifact with StableHLO version: "
+ << *version << "\n";
+ }
+ }
return stablehlo::deserializePortableArtifact(input, context);
},
[](DialectRegistry &registry) {

0 comments on commit 40398df

Please sign in to comment.