-
Notifications
You must be signed in to change notification settings - Fork 256
/
bf-staged.scala
94 lines (77 loc) · 2.69 KB
/
bf-staged.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
//> using dep org.scala-lang::scala3-staging::3.5.2
import scala.quoted.*
import scala.collection.mutable.ArrayBuffer
/** Brainfuck tape of `Int` cells that grows to the right on demand.
  *
  * The tape starts with a single zero cell and the head at position 0.
  * Cells are plain `Int`s (no 8-bit wrap-around).
  */
class Tape:
  private var tape: Array[Int] = Array(0)
  private var pos: Int = 0

  /** Returns the value of the cell under the head. */
  def get: Int = tape(pos)

  /** Adds `x` to the cell under the head. */
  def inc(x: Int): Unit = tape(pos) += x

  /** Moves the head by `x` cells, growing the tape when it moves past the end.
    *
    * Growth at least doubles the capacity and always covers the new position, so a
    * move larger than the current length is safe. Moving left of cell 0 is not
    * guarded: the next `get`/`inc` throws `ArrayIndexOutOfBoundsException`.
    */
  def move(x: Int): Unit =
    pos += x
    if pos >= tape.length then
      // A single doubling is not enough when x exceeds the current length.
      tape = Array.copyOf(tape, math.max(tape.length * 2, pos + 1))
/** Output sink for a Brainfuck program.
  *
  * When `quiet` is false each written value is printed as a character. When `quiet`
  * is true nothing is printed; instead each value is folded into two running sums
  * modulo 255, which `checksum` packs into one `Int`.
  * NOTE(review): `%` keeps the sign of its left operand, so negative values make the
  * partial sums negative.
  *
  * @param quiet whether to checksum output instead of printing it
  */
class Printer(val quiet: Boolean):
  // First-order sum of all written values (mod 255).
  private var sum1: Int = 0
  // Sum of the successive values of sum1 (mod 255).
  private var sum2: Int = 0

  /** Prints `n` as a character, or accumulates it into the checksum in quiet mode. */
  def write(n: Int): Unit =
    if !quiet then print(n.toChar)
    else
      sum1 = (sum1 + n) % 255
      sum2 = (sum2 + sum1) % 255

  /** Packs the two running sums: `sum2` in bits 8 and up, `sum1` in the low bits. */
  def checksum: Int = sum1 | (sum2 << 8)
/** A Brainfuck program turned into a JVM function at runtime via Scala 3 staging.
  *
  * `runOn` is a strict `val`, so the whole program is parsed into a quoted expression
  * and handed to `staging.run` as soon as a `Program` is constructed; `run` only
  * invokes the resulting function.
  *
  * @param text Brainfuck source; characters other than `+-<>.[]` are ignored
  * @param p    printer that receives the program's output when `run` is called
  */
class Program(text: String, p: Printer):

  /** Translates Brainfuck commands read from `iter` into a quoted block of tape/printer calls.
    *
    * Reads until `iter` is exhausted or a `]` is consumed. Each `[` recursively parses
    * its body from the same iterator and wraps it in a loop that repeats while the
    * current cell is greater than zero.
    * NOTE(review): an unmatched `]` at top level stops parsing early and silently drops
    * the rest of the program; an unmatched `[` is closed implicitly at end of input.
    *
    * @param iter source characters, shared with (and advanced by) recursive calls
    * @param t    quoted tape the generated code operates on
    * @param p    quoted printer the generated code writes to
    * @return a quoted block executing the parsed commands in order
    */
  def parse(using Quotes)(iter: Iterator[Char], t: Expr[Tape], p: Expr[Printer]): Expr[Unit] =
    val ops = ArrayBuffer.empty[Expr[Unit]]
    var atLoopEnd = false
    while !atLoopEnd && iter.hasNext do
      iter.next() match
        case '+' => ops += '{ $t.inc(1) }
        case '-' => ops += '{ $t.inc(-1) }
        case '>' => ops += '{ $t.move(1) }
        case '<' => ops += '{ $t.move(-1) }
        case '.' => ops += '{ $p.write($t.get) }
        case '[' =>
          // The splice is evaluated while this quote is built, so the nested parse
          // consumes the loop body (through its `]`) from `iter` right here.
          ops += '{
            def body() = ${ parse(iter, t, p) }
            while $t.get > 0 do body()
          }
        case ']' => atLoopEnd = true
        case _   => ()
    Expr.block(ops.toList, '{})

  // Staging compiler for this program, created from this class's class loader.
  given staging.Compiler = staging.Compiler.make(getClass.getClassLoader)

  /** The staged program: runs the parsed commands against the given tape and printer. */
  val runOn: (Tape, Printer) => Unit =
    staging.run('{ (t: Tape, p: Printer) => ${ parse(text.iterator, 't, 'p) } })

  /** Runs the staged program on a fresh tape, writing to this program's printer. */
  def run: Unit = runOn(Tape(), p)
/** Command-line entry point for the staged Brainfuck benchmark. */
object BrainFuckStaged {

  /** Best-effort notification to a listener on localhost TCP port 9001.
    *
    * Opens a fresh connection, sends `msg` as UTF-8 bytes and closes it. Any failure
    * (e.g. nothing listening) is captured in the returned `Try` instead of being thrown,
    * so a missing listener does not abort the caller.
    */
  def notify(msg: String): scala.util.Try[Unit] = {
    // Manage the socket itself (not just its stream) so it is closed even if
    // getOutputStream() fails after the connection was established.
    scala.util.Using(java.net.Socket("localhost", 9001)) { socket =>
      socket.getOutputStream().write(msg.getBytes(java.nio.charset.StandardCharsets.UTF_8))
    }
  }

  /** Self-check run before the benchmark.
    *
    * The embedded program's quiet-mode checksum must equal the checksum of writing
    * "Hello World!\n" directly; on mismatch both values go to stderr and the process
    * exits with status 1.
    */
  def verify: Unit = {
    val text = """++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.>
---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++."""
    val pLeft = Printer(true)
    Program(text, pLeft).run
    val left = pLeft.checksum
    val pRight = Printer(true)
    for (c <- "Hello World!\n") {
      pRight.write(c)
    }
    val right = pRight.checksum
    if (left != right) {
      System.err.println(s"${left} != ${right}")
      System.exit(1)
    }
  }

  /** Runs the Brainfuck file named by the first argument.
    *
    * The timed region covers constructing the `Program` (parsing and staging) and
    * running it; the elapsed seconds go to stderr. If the `QUIET` environment variable
    * is set, output is checksummed instead of printed and the checksum goes to stdout.
    * Exits with status 1 and a usage message when no file argument is given.
    */
  def main(args: Array[String]): Unit = {
    val filename = args.headOption.getOrElse {
      System.err.println("usage: BrainFuckStaged <file>")
      sys.exit(1)
    }
    verify
    val text = scala.util.Using(scala.io.Source.fromFile(filename)) { _.mkString }.get
    val p = Printer(sys.env.get("QUIET").isDefined)
    notify(s"Scala (Staged)\t${ProcessHandle.current().pid()}")
    val s = System.nanoTime
    Program(text, p).run
    val elapsed = (System.nanoTime - s) / 1e9
    notify("stop")
    System.err.println(s"time: $elapsed s")
    if (p.quiet) {
      System.out.println(s"Output checksum: ${p.checksum}")
    }
  }
}