diff --git a/plugin/plugin.go b/plugin/plugin.go index 790a1263..e58e8f27 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -921,6 +921,12 @@ func (b *ORMBuilder) parseBasicFields(msg *protogen.Message, g *protogen.Generat } else { continue } + } else { + goType, pointer := fieldGoType(g, field) + if pointer { + goType = "*" + goType + } + fieldType = goType } switch fieldType { @@ -967,6 +973,49 @@ func (b *ORMBuilder) parseBasicFields(msg *protogen.Message, g *protogen.Generat } } +func fieldGoType(g *protogen.GeneratedFile, field *protogen.Field) (goType string, pointer bool) { + if field.Desc.IsWeak() { + return "struct{}", false + } + + pointer = field.Desc.HasPresence() + switch field.Desc.Kind() { + case protoreflect.BoolKind: + goType = "bool" + case protoreflect.EnumKind: + goType = g.QualifiedGoIdent(field.Enum.GoIdent) + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + goType = "int32" + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + goType = "uint32" + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + goType = "int64" + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + goType = "uint64" + case protoreflect.FloatKind: + goType = "float32" + case protoreflect.DoubleKind: + goType = "float64" + case protoreflect.StringKind: + goType = "string" + case protoreflect.BytesKind: + goType = "[]byte" + pointer = false // rely on nullability of slices for presence + case protoreflect.MessageKind, protoreflect.GroupKind: + goType = "*" + g.QualifiedGoIdent(field.Message.GoIdent) + pointer = false // pointer captured as part of the type + } + switch { + case field.Desc.IsList(): + return "[]" + goType, false + case field.Desc.IsMap(): + keyType, _ := fieldGoType(g, field.Message.Fields[0]) + valType, _ := fieldGoType(g, field.Message.Fields[1]) + return fmt.Sprintf("map[%v]%v", keyType, valType), false + } + return goType, pointer +} + func (b *ORMBuilder) addIncludedField(ormable *OrmableType, field *gorm.ExtraField, g *protogen.GeneratedFile) { fieldName := camelCase(field.GetName()) isPtr := strings.HasPrefix(field.GetType(), "*")