Skip to content

Commit

Permalink
feat: add option to generate command to only generate allowed messa…
Browse files Browse the repository at this point in the history
…ge ids
  • Loading branch information
aschiffmann committed Nov 4, 2024
1 parent 7798b54 commit 95a09c7
Show file tree
Hide file tree
Showing 4 changed files with 359 additions and 287 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func main() {
It is possible to generate Go code from a `.dbc` file.

```
$ go run go.einride.tech/can/cmd/cantool generate <dbc file root folder> <output folder>
$ go run go.einride.tech/can/cmd/cantool generate <dbc file root folder> <output folder> [<allowed-message-ids>...]
```

In order to generate Go code that makes sense, we currently perform some
Expand Down
9 changes: 6 additions & 3 deletions cmd/cantool/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func generateCommand(app *kingpin.Application) {
Arg("output-dir", "output directory").
Required().
String()
allowedMessageIds := command.
Arg("allowed-message-ids", "optional filter of message-ids to compile").
Uint32List()
command.Action(func(_ *kingpin.ParseContext) error {
return filepath.Walk(*inputDir, func(p string, i os.FileInfo, err error) error {
if err != nil {
Expand All @@ -66,7 +69,7 @@ func generateCommand(app *kingpin.Application) {
}
outputFile := relPath + ".go"
outputPath := filepath.Join(*outputDir, outputFile)
return genGo(p, outputPath)
return genGo(p, outputPath, *allowedMessageIds)
})
})
}
Expand Down Expand Up @@ -143,15 +146,15 @@ func analyzers() []*analysis.Analyzer {
}
}

func genGo(inputFile, outputFile string) error {
func genGo(inputFile, outputFile string, allowedMessageIds []uint32) error {
if err := os.MkdirAll(filepath.Dir(outputFile), 0o755); err != nil {
return err
}
input, err := os.ReadFile(inputFile)
if err != nil {
return err
}
result, err := generate.Compile(inputFile, input)
result, err := generate.Compile(inputFile, input, generate.WithAllowedMessageIds(allowedMessageIds))
if err != nil {
return err
}
Expand Down
56 changes: 47 additions & 9 deletions internal/generate/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ type CompileResult struct {
Warnings []error
}

func Compile(sourceFile string, data []byte) (result *CompileResult, err error) {
type CompileOption func(*compiler)

func Compile(sourceFile string, data []byte, options ...CompileOption) (result *CompileResult, err error) {
p := dbc.NewParser(sourceFile, data)
if err := p.Parse(); err != nil {
return nil, fmt.Errorf("failed to parse DBC source file: %w", err)
Expand All @@ -24,12 +26,30 @@ func Compile(sourceFile string, data []byte) (result *CompileResult, err error)
db: &descriptor.Database{SourceFile: sourceFile},
defs: defs,
}

for _, opt := range options {
opt(c)
}

c.collectDescriptors()
c.addMetadata()
c.sortDescriptors()
return &CompileResult{Database: c.db, Warnings: c.warnings}, nil
}

func WithAllowedMessageIds(ids []uint32) CompileOption {
return func(c *compiler) {
if ids == nil {
return
}

c.onlyCompileIds = make([]dbc.MessageID, 0, len(ids))
for _, id := range ids {
c.onlyCompileIds = append(c.onlyCompileIds, dbc.MessageID(id))
}
}
}

type compileError struct {
def dbc.Def
reason string
Expand All @@ -40,9 +60,10 @@ func (e *compileError) Error() string {
}

type compiler struct {
db *descriptor.Database
defs []dbc.Def
warnings []error
onlyCompileIds []dbc.MessageID
db *descriptor.Database
defs []dbc.Def
warnings []error
}

func (c *compiler) addWarning(warning error) {
Expand All @@ -55,7 +76,7 @@ func (c *compiler) collectDescriptors() {
case *dbc.VersionDef:
c.db.Version = def.Version
case *dbc.MessageDef:
if def.MessageID == dbc.IndependentSignalsMessageID {
if c.skipCompile(def.MessageID) {
continue // don't compile
}
message := &descriptor.Message{
Expand Down Expand Up @@ -101,7 +122,9 @@ func (c *compiler) addMetadata() {
case *dbc.SignalValueTypeDef:
signal, ok := c.db.Signal(def.MessageID.ToCAN(), string(def.SignalName))
if !ok {
c.addWarning(&compileError{def: def, reason: "no declared signal"})
if !c.skipCompile(def.MessageID) {
c.addWarning(&compileError{def: def, reason: "no declared signal"})
}
continue
}
switch def.SignalValueType {
Expand All @@ -121,7 +144,7 @@ func (c *compiler) addMetadata() {
case *dbc.CommentDef:
switch def.ObjectType {
case dbc.ObjectTypeMessage:
if def.MessageID == dbc.IndependentSignalsMessageID {
if c.skipCompile(def.MessageID) {
continue // don't compile
}
message, ok := c.db.Message(def.MessageID.ToCAN())
Expand All @@ -131,7 +154,7 @@ func (c *compiler) addMetadata() {
}
message.Description = def.Comment
case dbc.ObjectTypeSignal:
if def.MessageID == dbc.IndependentSignalsMessageID {
if c.skipCompile(def.MessageID) {
continue // don't compile
}
signal, ok := c.db.Signal(def.MessageID.ToCAN(), string(def.SignalName))
Expand All @@ -149,7 +172,7 @@ func (c *compiler) addMetadata() {
node.Description = def.Comment
}
case *dbc.ValueDescriptionsDef:
if def.MessageID == dbc.IndependentSignalsMessageID {
if c.skipCompile(def.MessageID) {
continue // don't compile
}
if def.ObjectType != dbc.ObjectTypeSignal {
Expand All @@ -167,6 +190,9 @@ func (c *compiler) addMetadata() {
})
}
case *dbc.AttributeValueForObjectDef:
if c.skipCompile(def.MessageID) {
continue // don't compile
}
switch def.ObjectType {
case dbc.ObjectTypeMessage:
msg, ok := c.db.Message(def.MessageID.ToCAN())
Expand Down Expand Up @@ -225,3 +251,15 @@ func (c *compiler) sortDescriptors() {
}
}
}

func (c *compiler) skipCompile(id dbc.MessageID) bool {
if id == dbc.IndependentSignalsMessageID {
return true
}
for _, allowedMessageId := range c.onlyCompileIds {
if allowedMessageId == id {
return false
}
}
return len(c.onlyCompileIds) > 0
}
Loading

0 comments on commit 95a09c7

Please sign in to comment.