Skip to content

Commit

Permalink
Merge pull request #359 from Consensys/356-change-permutation-to-defp…
Browse files Browse the repository at this point in the history
…ermutation

feat: update syntax for `permutation` to `defpermutation`
  • Loading branch information
DavePearce authored Oct 23, 2024
2 parents b6032f0 + 8922452 commit aa02631
Show file tree
Hide file tree
Showing 17 changed files with 106 additions and 57 deletions.
4 changes: 2 additions & 2 deletions pkg/hir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ func (p *Schema) LowerToMir() *mir.Schema {
mirSchema.AddDataColumn(col.Context(), col.Name(), col.Type())
}
// Lower assignments (nothing to do here)
for _, asn := range p.assignments {
mirSchema.AddAssignment(asn)
for _, a := range p.assignments {
mirSchema.AddAssignment(a)
}
// Lower constraints
for _, c := range p.constraints {
Expand Down
126 changes: 86 additions & 40 deletions pkg/hir/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ func (p *hirParser) parseDeclaration(s sexp.SExp) error {
return p.parseConstraintDeclaration(e.Elements)
} else if e.Len() == 3 && e.MatchSymbols(2, "assert") {
return p.parseAssertionDeclaration(e.Elements)
} else if e.Len() == 3 && e.MatchSymbols(1, "permute") {
return p.parseSortedPermutationDeclaration(e)
} else if e.Len() == 3 && e.MatchSymbols(1, "defpermutation") {
return p.parsePermutationDeclaration(e)
} else if e.Len() == 4 && e.MatchSymbols(1, "deflookup") {
return p.parseLookupDeclaration(e)
} else if e.Len() == 3 && e.MatchSymbols(1, "definterleaved") {
Expand Down Expand Up @@ -182,13 +182,18 @@ func (p *hirParser) parseColumnDeclaration(e sexp.SExp) error {
}

// Parse a sorted permutation declaration
func (p *hirParser) parseSortedPermutationDeclaration(l *sexp.List) error {
func (p *hirParser) parsePermutationDeclaration(l *sexp.List) error {
// Target columns are (sorted) permutations of source columns.
sexpTargets := l.Elements[1].AsList()
// Source columns.
sexpSources := l.Elements[2].AsList()
// Sanity check
if sexpTargets == nil {
return p.translator.SyntaxError(l.Elements[1], "malformed target columns")
} else if sexpSources == nil {
return p.translator.SyntaxError(l.Elements[2], "malformed source columns")
}
// Convert into appropriate form.
targets := make([]sc.Column, sexpTargets.Len())
sources := make([]uint, sexpSources.Len())
signs := make([]bool, sexpSources.Len())
//
Expand All @@ -199,40 +204,9 @@ func (p *hirParser) parseSortedPermutationDeclaration(l *sexp.List) error {
ctx := trace.VoidContext()
//
for i := 0; i < sexpSources.Len(); i++ {
source := sexpSources.Get(i).AsSymbol()
target := sexpTargets.Get(i).AsSymbol()
// Sanity check syntax as expected
if source == nil {
return p.translator.SyntaxError(sexpSources.Get(i), "malformed column")
} else if target == nil {
return p.translator.SyntaxError(sexpTargets.Get(i), "malformed column")
}
// Determine source column sign (i.e. sort direction)
sortName := source.Value
if strings.HasPrefix(sortName, "+") {
signs[i] = true
} else if strings.HasPrefix(sortName, "-") {
if i == 0 {
return p.translator.SyntaxError(source, "sorted permutation requires ascending first column")
}

signs[i] = false
} else {
return p.translator.SyntaxError(source, "malformed sort direction")
}

sourceName := sortName[1:]
targetName := target.Value
// Determine index for source column
sourceIndex, ok := p.env.LookupColumn(p.module, sourceName)
if !ok {
// Column doesn't exist!
return p.translator.SyntaxError(sexpSources.Get(i), fmt.Sprintf("unknown column %s", sourceName))
}
// Sanity check that target column *doesn't* exist.
if p.env.HasColumn(p.module, targetName) {
// No, it doesn't.
return p.translator.SyntaxError(sexpTargets.Get(i), fmt.Sprintf("duplicate column %s", targetName))
sourceIndex, sourceSign, err := p.parsePermutationSource(sexpSources.Get(i))
if err != nil {
return err
}
// Check source context
sourceCol := p.env.schema.Columns().Nth(sourceIndex)
Expand All @@ -244,16 +218,88 @@ func (p *hirParser) parseSortedPermutationDeclaration(l *sexp.List) error {
return p.translator.SyntaxError(sexpSources.Get(i), "empty evaluation context")
}
// Copy over column name
signs[i] = sourceSign
sources[i] = sourceIndex
// FIXME: determine source column type
targets[i] = sc.NewColumn(ctx, targetName, &sc.FieldType{})
}
// Parse targets
targets := make([]sc.Column, sexpTargets.Len())
// Parse targets
for i := 0; i < sexpTargets.Len(); i++ {
targetName, err := p.parsePermutationTarget(sexpTargets.Get(i))
//
if err != nil {
return err
}
// Lookup corresponding source
source := p.env.schema.Columns().Nth(sources[i])
// Done
targets[i] = sc.NewColumn(ctx, targetName, source.Type())
}
//
p.env.AddAssignment(assignment.NewSortedPermutation(ctx, targets, signs, sources))
//
return nil
}

func (p *hirParser) parsePermutationSource(source sexp.SExp) (uint, bool, error) {
var (
name string
sign bool
err error
)

if source.AsList() != nil {
l := source.AsList()
// Check whether sort direction provided
if l.Len() != 2 || l.Get(0).AsSymbol() == nil || l.Get(1).AsSymbol() == nil {
return 0, false, p.translator.SyntaxError(source, "malformed column")
}
// Parser sorting direction
if sign, err = p.parseSortDirection(l.Get(0).AsSymbol()); err != nil {
return 0, false, err
}
// Extract column name
name = l.Get(1).AsSymbol().Value
} else {
name = source.AsSymbol().Value
sign = true // default
}
// Determine index for source column
index, ok := p.env.LookupColumn(p.module, name)
if !ok {
// Column doesn't exist!
return 0, false, p.translator.SyntaxError(source, "unknown column")
}
// Done
return index, sign, nil
}

func (p *hirParser) parsePermutationTarget(target sexp.SExp) (string, error) {
if target.AsSymbol() == nil {
return "", p.translator.SyntaxError(target, "malformed target column")
}
//
targetName := target.AsSymbol().Value
// Sanity check that target column *doesn't* exist.
if p.env.HasColumn(p.module, targetName) {
// No, it doesn't.
return "", p.translator.SyntaxError(target, "duplicate column")
}
// Done
return targetName, nil
}

func (p *hirParser) parseSortDirection(l *sexp.Symbol) (bool, error) {
switch l.Value {
case "+", "↓":
return true, nil
case "-", "↑":
return false, nil
}
// Unknown sort
return false, p.translator.SyntaxError(l, "malformed sort direction")
}

// Parse a lookup declaration
func (p *hirParser) parseLookupDeclaration(l *sexp.List) error {
handle := l.Elements[1].AsSymbol().Value
Expand Down
2 changes: 1 addition & 1 deletion pkg/schema/assignment/sorted_permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (p *SortedPermutation) Lisp(schema sc.Schema) sexp.SExp {
}

return sexp.NewList([]sexp.SExp{
sexp.NewSymbol("sort"),
sexp.NewSymbol("defpermutation"),
targets,
sources,
})
Expand Down
2 changes: 1 addition & 1 deletion pkg/schema/constraint/permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (p *PermutationConstraint) Lisp(schema sc.Schema) sexp.SExp {
}

return sexp.NewList([]sexp.SExp{
sexp.NewSymbol("permutation"),
sexp.NewSymbol("defpermutation"),
targets,
sources,
})
Expand Down
2 changes: 1 addition & 1 deletion testdata/memory.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
;; Value being Read/Written
(defcolumns (VAL :u8))
;; Permutation
(permute (ADDR' PC' RW' VAL') (+ADDR +PC +RW +VAL))
(defpermutation (ADDR' PC' RW' VAL') ((+ ADDR) (+ PC) (+ RW) (+ VAL)))
;; PC[0]=0
(defconstraint heartbeat_1 (:domain {0}) PC)
;; PC[k]=0 || PC[k]=PC[k-1]+1
Expand Down
2 changes: 1 addition & 1 deletion testdata/module_06.lisp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
(defcolumns X)
(module m1)
(defcolumns ST (X :u16))
(permute (Y) (+X))
(defpermutation (Y) ((+ X)))
;; Ensure sorted column increments by 1
(defconstraint increment () (* ST (- (shift Y 1) (+ 1 Y))))
2 changes: 1 addition & 1 deletion testdata/module_07.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
;;
(module m1)
(defcolumns (X :u8) (Y :u8))
(permute (A B) (+X +Y))
(defpermutation (A B) ((+ X) (+ Y)))
(defconstraint diag_ab () (- (shift A 1) B))
2 changes: 1 addition & 1 deletion testdata/mxp.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
(mxp:DEPLOYS :u1)
(mxp:MXP_TYPE_2 :u1))

(permute (mxp:CN_perm mxp:STAMP_perm mxp:C_MEM_perm mxp:C_MEM_NEW_perm mxp:WORDS_perm mxp:WORDS_NEW_perm) (+mxp:CN +mxp:STAMP +mxp:C_MEM +mxp:C_MEM_NEW +mxp:WORDS +mxp:WORDS_NEW))
(defpermutation (mxp:CN_perm mxp:STAMP_perm mxp:C_MEM_perm mxp:C_MEM_NEW_perm mxp:WORDS_perm mxp:WORDS_NEW_perm) ((+ mxp:CN) (+ mxp:STAMP) (+ mxp:C_MEM) (+ mxp:C_MEM_NEW) (+ mxp:WORDS) (+ mxp:WORDS_NEW)))

(defconstraint mxp:counter-constancy () (begin (ifnot mxp:CT (- mxp:INST (shift mxp:INST -1))) (ifnot mxp:CT (- mxp:OFFSET_1_LO (shift mxp:OFFSET_1_LO -1))) (ifnot mxp:CT (- mxp:OFFSET_1_HI (shift mxp:OFFSET_1_HI -1))) (ifnot mxp:CT (- mxp:OFFSET_2_LO (shift mxp:OFFSET_2_LO -1))) (ifnot mxp:CT (- mxp:OFFSET_2_HI (shift mxp:OFFSET_2_HI -1))) (ifnot mxp:CT (- mxp:SIZE_1_LO (shift mxp:SIZE_1_LO -1))) (ifnot mxp:CT (- mxp:SIZE_1_HI (shift mxp:SIZE_1_HI -1))) (ifnot mxp:CT (- mxp:SIZE_2_LO (shift mxp:SIZE_2_LO -1))) (ifnot mxp:CT (- mxp:SIZE_2_HI (shift mxp:SIZE_2_HI -1))) (ifnot mxp:CT (- mxp:WORDS (shift mxp:WORDS -1))) (ifnot mxp:CT (- mxp:WORDS_NEW (shift mxp:WORDS_NEW -1))) (ifnot mxp:CT (- mxp:C_MEM (shift mxp:C_MEM -1))) (ifnot mxp:CT (- mxp:C_MEM_NEW (shift mxp:C_MEM_NEW -1))) (ifnot mxp:CT (- mxp:COMP (shift mxp:COMP -1))) (ifnot mxp:CT (- mxp:MXPX (shift mxp:MXPX -1))) (ifnot mxp:CT (- mxp:EXPANDS (shift mxp:EXPANDS -1))) (ifnot mxp:CT (- mxp:QUAD_COST (shift mxp:QUAD_COST -1))) (ifnot mxp:CT (- mxp:LIN_COST (shift mxp:LIN_COST -1))) (ifnot mxp:CT (- mxp:GAS_MXP (shift mxp:GAS_MXP -1)))))

Expand Down
5 changes: 4 additions & 1 deletion testdata/permute_01.lisp
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
(defcolumns (X :u16))
(permute (Y) (+X))
(defpermutation (Y) ((↓ X)))
(defpermutation (Z) ((+ X)))
;; Y == Z
(defconstraint eq () (- Y Z))
2 changes: 1 addition & 1 deletion testdata/permute_02.lisp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
(defcolumns (X :u16))
(permute (Y) (+X))
(defpermutation (Y) ((+ X)))
(defconstraint first-row (:domain {0}) Y)
2 changes: 1 addition & 1 deletion testdata/permute_03.lisp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
(defcolumns ST (X :u16))
(permute (Y) (+X))
(defpermutation (Y) ((↓ X)))
;; Ensure sorted column increments by 1
(defconstraint increment () (* ST (- (shift Y 1) (+ 1 Y))))
2 changes: 1 addition & 1 deletion testdata/permute_04.lisp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
(defcolumns
(ST :u16)
(X :u16))
(permute (ST' Y) (+ST -X))
(defpermutation (ST' Y) ((↓ ST) (↑ X)))
(defconstraint first-row (:domain {-1}) (- Y 5))
2 changes: 1 addition & 1 deletion testdata/permute_05.lisp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
(defcolumns
(X :u8)
(Y :u8))
(permute (A B) (+X +Y))
(defpermutation (A B) ((+ X) (+ Y)))
(defconstraint diag_ab () (- (shift A 1) B))
2 changes: 1 addition & 1 deletion testdata/permute_06.lisp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
(defcolumns
(X :u16)
(Y :u16))
(permute (A B) (+X +Y))
(defpermutation (A B) ((+ X) (+ Y)))
(defconstraint diag_ab () (- (shift A 1) B))
2 changes: 1 addition & 1 deletion testdata/permute_07.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
(ST :u16)
(X :u16)
(Y :u16))
(permute (ST' A B) (+ST -X +Y))
(defpermutation (ST' A B) ((+ ST) (- X) (+ Y)))
(defconstraint diag_ab () (* ST' (- (shift A 1) B)))
2 changes: 1 addition & 1 deletion testdata/permute_08.lisp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
(defcolumns
(X :u16)
(Y :u16))
(permute (A B) (+X -Y))
(defpermutation (A B) ((+ X) (- Y)))
(defconstraint diag_ab () (* A (- (shift A 1) B)))
2 changes: 1 addition & 1 deletion testdata/permute_09.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
(ST :u16)
(X :u16)
(Y :u16))
(permute (ST' A B) (+ST -X -Y))
(defpermutation (ST' A B) ((+ ST) (- X) (- Y)))
(defconstraint diag_ab () (* ST' (- (shift A 1) B)))

0 comments on commit aa02631

Please sign in to comment.