diff --git a/shai/src/main/protobuf/cmdExecutor.proto b/shai/src/main/protobuf/cmdExecutor.proto index 62c3f87e09..6c6532a3b5 100644 --- a/shai/src/main/protobuf/cmdExecutor.proto +++ b/shai/src/main/protobuf/cmdExecutor.proto @@ -34,6 +34,7 @@ enum Cmd { PARSE = 0; CHECK = 1; TYPECHECK = 3; + TLA = 4; // Parse input and reply with a TLA+ string } enum CmdErrorType { diff --git a/shai/src/main/scala/at/forsyte/apalache/shai/v1/CmdExecutorService.scala b/shai/src/main/scala/at/forsyte/apalache/shai/v1/CmdExecutorService.scala index 09333fd122..c89abd7235 100644 --- a/shai/src/main/scala/at/forsyte/apalache/shai/v1/CmdExecutorService.scala +++ b/shai/src/main/scala/at/forsyte/apalache/shai/v1/CmdExecutorService.scala @@ -15,6 +15,8 @@ import at.forsyte.apalache.shai.v1.cmdExecutor.{ import at.forsyte.apalache.tla.bmcmt.config.CheckerModule import at.forsyte.apalache.tla.passes.imp.ParserModule import at.forsyte.apalache.tla.passes.typecheck.TypeCheckerModule +import at.forsyte.apalache.tla.lir.TlaModule +import at.forsyte.apalache.io.lir.PrettyWriter /** * Provides the [[CmdExecutorService]] @@ -77,9 +79,9 @@ class CmdExecutorService(logger: Logger) extends ZioCmdExecutor.ZCmdExecutor[ZEn toolModule <- { import OptionGroup._ cmd match { - case Cmd.PARSE => WithIO(cfg).map(new ParserModule(_)).toCmdResult - case Cmd.CHECK => WithCheckerPreds(cfg).map(new CheckerModule(_)).toCmdResult - case Cmd.TYPECHECK => WithTypechecker(cfg).map(new TypeCheckerModule(_)).toCmdResult + case Cmd.PARSE | Cmd.TLA => WithIO(cfg).map(new ParserModule(_)).toCmdResult + case Cmd.CHECK => WithCheckerPreds(cfg).map(new CheckerModule(_)).toCmdResult + case Cmd.TYPECHECK => WithTypechecker(cfg).map(new TypeCheckerModule(_)).toCmdResult case Cmd.Unrecognized(_) => throw new IllegalArgumentException("programmer error: executeCmd applied before validateCmd") } @@ -90,7 +92,14 @@ class CmdExecutorService(logger: Logger) extends ZioCmdExecutor.ZCmdExecutor[ZEn catch { case err: Throwable => Left(throwableErr(err)) } - } yield TlaToUJson(tlaModule) + } yield cmd match { + case Cmd.TLA => tlaModuleToJsonString(tlaModule) + case _ => TlaToUJson(tlaModule) + } + } + + private def tlaModuleToJsonString(module: TlaModule): ujson.Value = { + ujson.Str(PrettyWriter.writeAsString(module)) } // Allows us to handle invalid protobuf messages on the ZIO level, before diff --git a/shai/src/test/scala/at/forsyte/apalache/shai/v1/TestCmdExecutorService.scala b/shai/src/test/scala/at/forsyte/apalache/shai/v1/TestCmdExecutorService.scala index 8cabb2c498..4574daa405 100644 --- a/shai/src/test/scala/at/forsyte/apalache/shai/v1/TestCmdExecutorService.scala +++ b/shai/src/test/scala/at/forsyte/apalache/shai/v1/TestCmdExecutorService.scala @@ -122,6 +122,22 @@ object TestCmdExecutorService extends DefaultRunnableSpec { assert(data("error_data").arr)(isNonEmpty) } }, + testM("can use TLA command to receive formatted TLA") { + val expectedPayload = + """|----------------------------------- MODULE M ----------------------------------- + | + |EXTENDS Integers, Sequences, FiniteSets, TLC, Apalache + | + |Foo == TRUE + | + |================================================================================ + |""".stripMargin + for { + s <- ZIO.service[CmdExecutorService] + resp <- s.run(runCmd(Cmd.TLA, trivialSpec)) + actualPayload = ujson.read(resp.result.success.get).str + } yield assert(actualPayload)(equalTo(expectedPayload)) + }, ) // Create the single shared service for use in our tests, allowing us to run // all tests as if they were against the same service this accurately diff --git a/tla-io/src/main/scala/at/forsyte/apalache/io/lir/PrettyWriter.scala b/tla-io/src/main/scala/at/forsyte/apalache/io/lir/PrettyWriter.scala index 527a070fdc..c90ce1064b 100644 --- a/tla-io/src/main/scala/at/forsyte/apalache/io/lir/PrettyWriter.scala +++ b/tla-io/src/main/scala/at/forsyte/apalache/io/lir/PrettyWriter.scala @@ -10,6 +10,7 @@ import org.bitbucket.inkytonik.kiama.output.PrettyPrinter import java.io.{File, FileWriter, PrintWriter} import scala.collection.immutable.{HashMap, HashSet} +import java.io.StringWriter /** *

A pretty printer to a file that formats a TLA+ expression to a given text width (normally, 80 characters). As @@ -691,6 +692,13 @@ object PrettyWriter { } } + def writeAsString(module: TlaModule, extendedModules: List[String] = TlaWriter.STANDARD_MODULES): String = { + val buf = new StringWriter() + val prettyWriter = new PrettyWriter(new PrintWriter(buf)) + prettyWriter.write(module, extendedModules) + buf.toString() + } + protected val unaryOps = HashMap( TlaBoolOper.not -> "~", TlaArithOper.uminus -> "-",