diff --git a/pkg/ast/json.go b/pkg/ast/json.go index 89a42d4..a96d7aa 100644 --- a/pkg/ast/json.go +++ b/pkg/ast/json.go @@ -134,73 +134,44 @@ func (n *Node[Stmt]) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the json.Unmarshaler interface func (n *Node[T]) UnmarshalJSON(data []byte) error { - var objMap map[string]json.RawMessage - if err := json.Unmarshal(data, &objMap); err != nil { - return err - } - - if data, ok := objMap["id"]; ok { - if err := json.Unmarshal(data, &n.ID); err != nil { - return err - } - } - if data, ok := objMap["filename"]; ok { - if err := json.Unmarshal(data, &n.Pos.Filename); err != nil { - return err - } - } - if data, ok := objMap["line"]; ok { - if err := json.Unmarshal(data, &n.Pos.Line); err != nil { - return err - } + var node struct { + ID AstIndex `json:"id,omitempty"` + NodeData json.RawMessage `json:"node,omitempty"` + Pos } - if data, ok := objMap["column"]; ok { - if err := json.Unmarshal(data, &n.Pos.Column); err != nil { - return err - } - } - if data, ok := objMap["end_line"]; ok { - if err := json.Unmarshal(data, &n.Pos.EndLine); err != nil { - return err - } - } - if data, ok := objMap["end_column"]; ok { - if err := json.Unmarshal(data, &n.Pos.EndColumn); err != nil { - return err - } - } - - nodeData, ok := objMap["node"] - if !ok { - return fmt.Errorf("missing 'node' field") + if err := json.Unmarshal(data, &node); err != nil { + return err } - if expr, err := UnmarshalExpr(nodeData); err == nil { + n.Pos = node.Pos + n.ID = node.ID + var ok bool + if expr, err := UnmarshalExpr(node.NodeData); err == nil { if n.Node, ok = expr.(T); ok { return nil } } - if stmt, err := UnmarshalStmt(nodeData); err == nil { + if stmt, err := UnmarshalStmt(node.NodeData); err == nil { if n.Node, ok = stmt.(T); ok { return nil } } - if memberOrIndex, err := UnmarshalMemberOrIndex(nodeData); err == nil { + if memberOrIndex, err := UnmarshalMemberOrIndex(node.NodeData); err == nil { if n.Node, ok = memberOrIndex.(T); ok { return nil } } - if numberLit, err := UnmarshalNumberLitValue(nodeData); err == nil { + if numberLit, err := UnmarshalNumberLitValue(node.NodeData); err == nil { if n.Node, ok = numberLit.(T); ok { return nil } } - if ty, err := UnmarshalType(nodeData); err == nil { + if ty, err := UnmarshalType(node.NodeData); err == nil { if n.Node, ok = ty.(T); ok { return nil } } else { otherNode := new(T) - if err := json.Unmarshal(nodeData, otherNode); err != nil { + if err := json.Unmarshal(node.NodeData, otherNode); err != nil { return err } n.Node = *otherNode