Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[StableHLO][API] Add API to get StableHLO version from PortableArtifact #20105

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
@@ -1 +1,165 @@
diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel
--- stablehlo/BUILD.bazel
+++ stablehlo/BUILD.bazel
@@ -925,6 +925,7 @@
":stablehlo_serialization",
":version",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:BytecodeReader",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
diff --ruN a/stablehlo/stablehlo/dialect/Serialization.cpp b/stablehlo/stablehlo/dialect/Serialization.cpp
--- stablehlo/stablehlo/dialect/Serialization.cpp
+++ stablehlo/stablehlo/dialect/Serialization.cpp
@@ -15,6 +15,9 @@

#include "stablehlo/dialect/Serialization.h"

+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
@@ -28,6 +31,8 @@
#include "stablehlo/dialect/Version.h"
#include "stablehlo/dialect/VhloOps.h"
#include "stablehlo/transforms/Passes.h"
+
+#define DEBUG_TYPE "compat-passes"

namespace mlir {
namespace stablehlo {
@@ -89,5 +94,36 @@
return module;
}

+FailureOr<vhlo::Version> getPortableArtifactVersion(llvm::StringRef bytecode) {
+ auto logFailure = [&](llvm::StringRef message) {
+ LLVM_DEBUG(llvm::dbgs() << "Failed to get portable artifact version: "
+ << message << "\n");
+ return failure();
+ };
+ // Must start with MLiRxStableHLO_vX.Y.Z, minimum length of 19.
+ constexpr size_t minHeaderLength = 19;
+ if (bytecode.size() < minHeaderLength) return logFailure("min header");
+
+ // Truncate to the end of the null-terminated producer string.
+ size_t pos = bytecode.find('\0');
+ if (pos == llvm::StringRef::npos) return logFailure("no terminator");
+ bytecode = bytecode.substr(0, pos);
+
+ // Check if the bytecode is valid, starts with MLiR magic number.
+ if (!isBytecode(
+ llvm::MemoryBuffer::getMemBuffer(bytecode)->getMemBufferRef()))
+ return logFailure("not bytecode");
+
+ // Skip 4 bytes for the magic number.
+ std::string stablehloHeader = "StableHLO_v";
+ size_t stablehloPos = bytecode.find(stablehloHeader);
+ if (stablehloPos == llvm::StringRef::npos)
+ return logFailure("not a StableHLO portable artifact");
+
+ // Skip the 11 bytes for StableHLO_v to get the StableHLO version to parse.
+ StringRef version = bytecode.substr(stablehloPos + stablehloHeader.size());
+ return vhlo::Version::fromString(version);
+}
+
} // namespace stablehlo
} // namespace mlir
diff --ruN a/stablehlo/stablehlo/dialect/Serialization.h b/stablehlo/stablehlo/dialect/Serialization.h
--- stablehlo/stablehlo/dialect/Serialization.h
+++ stablehlo/stablehlo/dialect/Serialization.h
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LogicalResult.h"
+#include "stablehlo/dialect/Version.h"

namespace mlir {
namespace stablehlo {
@@ -43,6 +44,17 @@
OwningOpRef<ModuleOp> deserializePortableArtifact(StringRef sourceStr,
MLIRContext* context);

+// Get portable artifact version from the producer string after the MLIR
+// Bytecode magic number `MLïRStableHLO_vX.Y.Z` -> X.Y.Z
+// Returns failure if input string is not a valid portable artifact produced by
+// serializePortableArtifact APIs, which would cause the bytecode artifact to
+// not have the proper producer string.
+//
+// This method should be safe, since any changes to the bytecode format would
+// warrant a bytecode version bump, and MLIR bytecode gives the option to
+// specify a forward compatible bytecode version to target.
+FailureOr<vhlo::Version> getPortableArtifactVersion(llvm::StringRef bytecode);
+
} // namespace stablehlo
} // namespace mlir

diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir
--- stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir
+++ stablehlo/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir
@@ -0,0 +1,19 @@
+// RUN: stablehlo-translate --deserialize --print-stablehlo-version %s.bc | FileCheck %s --check-prefix=CHECK-VERSION
+// RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize --print-stablehlo-version | FileCheck %s --check-prefix=CHECK-VERSION-LATEST
+// RUN: stablehlo-translate --deserialize --print-stablehlo-version %s | FileCheck %s --check-prefix=CHECK-VERSION-NOT-BYTECODE
+
+// This file tests the `getPortableArtifactVersion` Serialization API.
+// Any breakages to this file likely indicate that the MLIR Bytecode Format
+// has changed, or that the StableHLO producer string emit by
+// `serializePortableArtifact` has changed.
+//
+// See the `getPortableArtifactVersion` doc comments for more details.
+
+// CHECK-VERSION: // Reading portable artifact with StableHLO version: 1.1.0
+// CHECK-VERSION-NOT-BYTECODE: // Failed parsing StableHLO version from portable artifact
+// CHECK-VERSION-LATEST: // Reading portable artifact with StableHLO version: {{.*}}
+
+func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = stablehlo.add %arg0, %arg0 : tensor<f32>
+ func.return %0 : tensor<f32>
+}
diff --ruN a/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
--- stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
+++ stablehlo/stablehlo/tools/StablehloTranslateMain.cpp
@@ -23,6 +23,7 @@
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/raw_ostream.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
@@ -65,6 +66,12 @@
"strip-debuginfo", llvm::cl::desc("Strip debug info from all operations"),
llvm::cl::init(false));

+llvm::cl::opt<bool> printStablehloVersion(
+ "print-stablehlo-version",
+ llvm::cl::desc(
+ "When deserializing a portable artifact, print the StableHLO version"),
+ llvm::cl::init(false));
+
llvm::cl::opt<std::string> targetOption(
"target", llvm::cl::desc("Target version for serialization"),
llvm::cl::init(""));
@@ -306,6 +313,17 @@
TranslateToMLIRRegistration deserializeRegistration(
"deserialize", "Deserialize a portable artifact into a StableHLO program",
[](llvm::StringRef input, mlir::MLIRContext *context) {
+ if (printStablehloVersion.getValue()) {
+ auto version = stablehlo::getPortableArtifactVersion(input);
+ if (failed(version)) {
+ llvm::outs()
+ << "// Failed parsing StableHLO version from portable artifact\n";
+ } else {
+ llvm::outs()
+ << "// Reading portable artifact with StableHLO version: "
+ << *version << "\n";
+ }
+ }
return stablehlo::deserializePortableArtifact(input, context);
},
[](DialectRegistry &registry) {

Loading