-
Notifications
You must be signed in to change notification settings - Fork 443
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[StableHLO][API] Add API to get StableHLO version from PortableArtifact
PiperOrigin-RevId: 702481032
- Loading branch information
1 parent
af11c20
commit 40398df
Showing
1 changed file
with
153 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,154 @@ | ||
diff --ruN a/stablehlo/stablehlo/dialect/Serialization.cpp b/stablehlo/stablehlo/dialect/Serialization.cpp | ||
--- stablehlo/stablehlo/dialect/Serialization.cpp | ||
+++ stablehlo/stablehlo/dialect/Serialization.cpp | ||
@@ -15,6 +15,8 @@ | ||
|
||
#include "stablehlo/dialect/Serialization.h" | ||
|
||
+#include "llvm/Support/MemoryBuffer.h" | ||
+#include "mlir/Bytecode/BytecodeReader.h" | ||
#include "mlir/Bytecode/BytecodeWriter.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/Diagnostics.h" | ||
@@ -28,6 +30,9 @@ | ||
#include "stablehlo/dialect/Version.h" | ||
#include "stablehlo/dialect/VhloOps.h" | ||
#include "stablehlo/transforms/Passes.h" | ||
+#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Debug.h" | ||
+ | ||
+#define DEBUG_TYPE "compat-passes" | ||
|
||
namespace mlir { | ||
namespace stablehlo { | ||
@@ -89,5 +94,36 @@ | ||
return module; | ||
} | ||
|
||
+FailureOr<vhlo::Version> getPortableArtifactVersion(llvm::StringRef bytecode) { | ||
+ auto logFailure = [&](llvm::StringRef message) { | ||
+ LLVM_DEBUG(llvm::dbgs() << "Failed to get portable artifact version: " | ||
+ << message << "\n"); | ||
+ return failure(); | ||
+ }; | ||
+ // Must start with MLiRxStableHLO_vX.Y.Z, minimum length of 19. | ||
+ constexpr size_t minHeaderLength = 19; | ||
+ if (bytecode.size() < minHeaderLength) return logFailure("min header"); | ||
+ | ||
+ // Truncate to the end of the null-terminated producer string. | ||
+ size_t pos = bytecode.find('\0'); | ||
+ if (pos == llvm::StringRef::npos) return logFailure("no terminator"); | ||
+ bytecode = bytecode.substr(0, pos); | ||
+ | ||
+ // Check if the bytecode is valid, starts with MLiR magic number. | ||
+ if (!isBytecode( | ||
+ llvm::MemoryBuffer::getMemBuffer(bytecode)->getMemBufferRef())) | ||
+ return logFailure("not bytecode"); | ||
+ | ||
+ // Skip 4 bytes for the magic number. | ||
+ std::string stablehloHeader = "StableHLO_v"; | ||
+ size_t stablehloPos = bytecode.find(stablehloHeader); | ||
+ if (stablehloPos == llvm::StringRef::npos) | ||
+ return logFailure("not a StableHLO portable artifact"); | ||
+ | ||
+ // Skip the 11 bytes for StableHLO_v to get the StableHLO version to parse. | ||
+ StringRef version = bytecode.substr(stablehloPos + stablehloHeader.size()); | ||
+ return vhlo::Version::fromString(version); | ||
+} | ||
+ | ||
} // namespace stablehlo | ||
} // namespace mlir | ||
diff --ruN a/stablehlo/stablehlo/dialect/Serialization.h b/stablehlo/stablehlo/dialect/Serialization.h | ||
--- stablehlo/stablehlo/dialect/Serialization.h | ||
+++ stablehlo/stablehlo/dialect/Serialization.h | ||
@@ -19,6 +19,7 @@ | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/MLIRContext.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
+#include "stablehlo/dialect/Version.h" | ||
|
||
namespace mlir { | ||
namespace stablehlo { | ||
@@ -43,6 +44,17 @@ | ||
OwningOpRef<ModuleOp> deserializePortableArtifact(StringRef sourceStr, | ||
MLIRContext* context); | ||
|
||
+// Get portable artifact version from the producer string after the MLIR | ||
+// Bytecode magic number `MLïRStableHLO_vX.Y.Z` -> X.Y.Z | ||
+// Returns failure if input string is not a valid portable artifact produced by | ||
+// serializePortableArtifact APIs, which would cause the bytecode artifact to | ||
+// not have the proper producer string. | ||
+// | ||
+// This method should be safe, since any changes to the bytecode format would | ||
+// warrant a bytecode version bump, and MLIR bytecode gives the option to | ||
+// specify a forward compatible bytecode version to target. | ||
+FailureOr<vhlo::Version> getPortableArtifactVersion(llvm::StringRef bytecode); | ||
+ | ||
} // namespace stablehlo | ||
} // namespace mlir | ||
|
||
diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir | ||
--- stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir | ||
+++ stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir | ||
@@ -0,0 +1,19 @@ | ||
+// RUN: stablehlo-translate --deserialize --print-stablehlo-version %s.bc | FileCheck %s --check-prefix=CHECK-VERSION | ||
+// RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize --print-stablehlo-version | FileCheck %s --check-prefix=CHECK-VERSION-LATEST | ||
+// RUN: stablehlo-translate --deserialize --print-stablehlo-version %s | FileCheck %s --check-prefix=CHECK-VERSION-NOT-BYTECODE | ||
+ | ||
+// This file tests the `getPortableArtifactVersion` Serialization API. | ||
+// Any breakages to this file likely indicate that the MLIR Bytecode Format | ||
+// has changed, or that the StableHLO producer string emit by | ||
+// `serializePortableArtifact` has changed. | ||
+// | ||
+// See the `getPortableArtifactVersion` doc comments for more details. | ||
+ | ||
+// CHECK-VERSION: // Reading portable artifact with StableHLO version: 1.1.0 | ||
+// CHECK-VERSION-NOT-BYTECODE: // Failed parsing StableHLO version from portable artifact | ||
+// CHECK-VERSION-LATEST: // Reading portable artifact with StableHLO version: {{.*}} | ||
+ | ||
+func.func @main(%arg0: tensor<f32>) -> tensor<f32> { | ||
+ %0 = stablehlo.add %arg0, %arg0 : tensor<f32> | ||
+ func.return %0 : tensor<f32> | ||
+} | ||
diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp | ||
--- stablehlo/stablehlo/tools/StablehloTranslateMain.cpp | ||
+++ stablehlo/stablehlo/tools/StablehloTranslateMain.cpp | ||
@@ -53,6 +53,7 @@ | ||
#include "stablehlo/reference/Tensor.h" | ||
#include "stablehlo/reference/Value.h" | ||
#include "stablehlo/tests/CheckOps.h" | ||
+#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/raw_ostream.h" | ||
|
||
namespace mlir { | ||
|
||
@@ -63,6 +64,12 @@ | ||
|
||
llvm::cl::opt<bool> stripDebuginfoOption( | ||
"strip-debuginfo", llvm::cl::desc("Strip debug info from all operations"), | ||
+ llvm::cl::init(false)); | ||
+ | ||
+llvm::cl::opt<bool> printStablehloVersion( | ||
+ "print-stablehlo-version", | ||
+ llvm::cl::desc( | ||
+ "When deserializing a portable artifact, print the StableHLO version"), | ||
llvm::cl::init(false)); | ||
|
||
llvm::cl::opt<std::string> targetOption( | ||
@@ -306,6 +313,17 @@ | ||
TranslateToMLIRRegistration deserializeRegistration( | ||
"deserialize", "Deserialize a portable artifact into a StableHLO program", | ||
[](llvm::StringRef input, mlir::MLIRContext *context) { | ||
+ if (printStablehloVersion.getValue()) { | ||
+ auto version = stablehlo::getPortableArtifactVersion(input); | ||
+ if (failed(version)) { | ||
+ llvm::outs() | ||
+ << "// Failed parsing StableHLO version from portable artifact\n"; | ||
+ } else { | ||
+ llvm::outs() | ||
+ << "// Reading portable artifact with StableHLO version: " | ||
+ << *version << "\n"; | ||
+ } | ||
+ } | ||
return stablehlo::deserializePortableArtifact(input, context); | ||
}, | ||
[](DialectRegistry ®istry) { | ||
|