Skip to content

Commit

Permalink
Created security support.
Browse files Browse the repository at this point in the history
  • Loading branch information
sjohnsonaz committed Aug 8, 2022
1 parent b16a515 commit 9c7976c
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 21 deletions.
11 changes: 11 additions & 0 deletions entity/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type Endpoint struct {
Header map[string]HeaderProperty
Body BodyProperty
Response Response
Security []Security
}

type ParamProperty struct {
Expand Down Expand Up @@ -52,3 +53,13 @@ type Response struct {
Default bool
DefaultCode int
}

type Security struct {
Type SecurityType
}

type SecurityType string

const SECURITY_TYPE_BASIC SecurityType = "basic"
const SECURITY_TYPE_BEARER SecurityType = "bearer"
const SECURITY_TYPE_COOKIE SecurityType = "cookie"
10 changes: 9 additions & 1 deletion example/openapi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ paths:
/token:
get:
operationId: getToken
security:
- basicAuth: []
parameters:
- name: authorization
in: header
Expand Down Expand Up @@ -293,4 +295,10 @@ components:
type: integer
format: int32
message:
type: string
type: string
securitySchemes:
basicAuth:
type: http
scheme: basic
security:
- basicAuth: []
25 changes: 24 additions & 1 deletion parser/header_property.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import (
"github.com/getkin/kin-openapi/openapi3"
)

func GetHeader(operation *openapi3.Operation) map[string]entity.HeaderProperty {
func GetHeader(
operation *openapi3.Operation,
securitySchemes []entity.Security,
) map[string]entity.HeaderProperty {
header := make(map[string]entity.HeaderProperty)
parameters := array.Map(operation.Parameters, func(ref *openapi3.ParameterRef) *openapi3.Parameter {
return ref.Value
Expand All @@ -25,5 +28,25 @@ func GetHeader(operation *openapi3.Operation) map[string]entity.HeaderProperty {
}
}
})
array.ForEach(securitySchemes, func(securityScheme entity.Security) {
authorization := "authorization"
Authorization := "Authorization"
switch securityScheme.Type {
case entity.SECURITY_TYPE_BASIC:
header[Authorization] = entity.HeaderProperty{
Type: "string",
Name: Authorization,
Key: authorization,
}
case entity.SECURITY_TYPE_BEARER:
header[Authorization] = entity.HeaderProperty{
Type: "string",
Name: Authorization,
Key: authorization,
}
case entity.SECURITY_TYPE_COOKIE:
default:
}
})
return header
}
2 changes: 1 addition & 1 deletion parser/query_property.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func (p *SchemaParser) GetQuery(operation *openapi3.Operation) map[string]entity
array.ForEach(parameters, func(parameter *openapi3.Parameter) {
if parameter.Schema != nil && parameter.Schema.Value != nil {
name := GetPropertyName(parameter.Name)
schema := p.Add(name, parameter.Schema, false)
schema := p.add(name, parameter.Schema, false)
query[name] = entity.QueryProperty{
Schema: schema,
Name: parameter.Name,
Expand Down
2 changes: 1 addition & 1 deletion parser/response_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (p *SchemaParser) GetResponses(operation *openapi3.Operation, s map[*openap
schema, ok := s[mediaType.Schema.Value]
if !ok {
// Schema has not be logged
schema = p.Add(GetPropertyName(operation.OperationID+code), mediaType.Schema, true)
schema = p.add(GetPropertyName(operation.OperationID+code), mediaType.Schema, true)
}
if schema == nil {
// TODO: Handle this error
Expand Down
120 changes: 105 additions & 15 deletions parser/schema_parser.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package parser

import (
"fmt"
"sort"
"strings"

Expand Down Expand Up @@ -43,15 +44,29 @@ func (p *SchemaParser) GetEndpoints() []*entity.Endpoint {

func (p *SchemaParser) Parse(doc *openapi3.T) {
for name, schemaRef := range doc.Components.Schemas {
p.Add(name, schemaRef, true)
p.add(name, schemaRef, true)
}
security := p.addSecurity(doc)
for key, path := range doc.Paths {
p.AddEndpoint(key, path)
p.addEndpoint(doc, key, path, security)
}
p.sortSchemas()
p.sortEndpoints()
}

func (p *SchemaParser) addSecurity(doc *openapi3.T) []*openapi3.SecurityScheme {
security := make([]*openapi3.SecurityScheme, 0)
for _, securityRequirement := range doc.Security {
for name := range securityRequirement {
securitySchema, ok := doc.Components.SecuritySchemes[name]
if ok && securitySchema.Value != nil {
security = append(security, securitySchema.Value)
}
}
}
return security
}

func (p *SchemaParser) sortSchemas() {
p.schemas = make([]*entity.Schema, 0)

Expand All @@ -77,37 +92,42 @@ func (p *SchemaParser) sortEndpoints() {
})
}

func (p *SchemaParser) AddEndpoint(key string, path *openapi3.PathItem) {
func (p *SchemaParser) addEndpoint(
doc *openapi3.T,
key string,
path *openapi3.PathItem,
security []*openapi3.SecurityScheme,
) {
if path.Connect != nil {
p.CreateEndpoint(key, entity.VERB_CONNECT, path.Connect)
p.createEndpoint(doc, key, entity.VERB_CONNECT, path.Connect, security)
}
if path.Delete != nil {
p.CreateEndpoint(key, entity.VERB_DELETE, path.Delete)
p.createEndpoint(doc, key, entity.VERB_DELETE, path.Delete, security)
}
if path.Get != nil {
p.CreateEndpoint(key, entity.VERB_GET, path.Get)
p.createEndpoint(doc, key, entity.VERB_GET, path.Get, security)
}
if path.Head != nil {
p.CreateEndpoint(key, entity.VERB_HEAD, path.Head)
p.createEndpoint(doc, key, entity.VERB_HEAD, path.Head, security)
}
if path.Options != nil {
p.CreateEndpoint(key, entity.VERB_OPTIONS, path.Options)
p.createEndpoint(doc, key, entity.VERB_OPTIONS, path.Options, security)
}
if path.Patch != nil {
p.CreateEndpoint(key, entity.VERB_PATCH, path.Patch)
p.createEndpoint(doc, key, entity.VERB_PATCH, path.Patch, security)
}
if path.Post != nil {
p.CreateEndpoint(key, entity.VERB_POST, path.Post)
p.createEndpoint(doc, key, entity.VERB_POST, path.Post, security)
}
if path.Put != nil {
p.CreateEndpoint(key, entity.VERB_PUT, path.Put)
p.createEndpoint(doc, key, entity.VERB_PUT, path.Put, security)
}
if path.Trace != nil {
p.CreateEndpoint(key, entity.VERB_TRACE, path.Trace)
p.createEndpoint(doc, key, entity.VERB_TRACE, path.Trace, security)
}
}

func (p *SchemaParser) Add(key string, schemaRef *openapi3.SchemaRef, display bool) *entity.Schema {
func (p *SchemaParser) add(key string, schemaRef *openapi3.SchemaRef, display bool) *entity.Schema {
schema := schemaRef.Value
ref := schemaRef.Ref
name := GetSchemaName(ref)
Expand All @@ -133,21 +153,91 @@ func (p *SchemaParser) Add(key string, schemaRef *openapi3.SchemaRef, display bo
return nil
}

func (p *SchemaParser) CreateEndpoint(key string, verb entity.Verb, operation *openapi3.Operation) *entity.Endpoint {
func (p *SchemaParser) createEndpoint(
doc *openapi3.T,
key string,
verb entity.Verb,
operation *openapi3.Operation,
security []*openapi3.SecurityScheme,
) *entity.Endpoint {
securitySchemes := p.createSecuritySchemes(doc, operation, security)
endpoint := &entity.Endpoint{
Verb: verb,
Name: GetEndpointName(verb, operation.OperationID, key),
Path: KeyToPath(key),
Params: GetParams(operation),
Query: p.GetQuery(operation),
Header: GetHeader(operation),
Header: GetHeader(operation, securitySchemes),
Body: p.GetBody(operation),
Response: p.GetResponses(operation, p.schemasMap),
Security: securitySchemes,
}
p.endpoints = append(p.endpoints, endpoint)
return endpoint
}

func (p *SchemaParser) createSecuritySchemes(
doc *openapi3.T,
operation *openapi3.Operation,
security []*openapi3.SecurityScheme,
) []entity.Security {
schemes := make(map[*openapi3.SecurityScheme]entity.Security)

for _, scheme := range security {
schemes[scheme] = p.createSecurity(scheme)
}

if operation.Security != nil {
for _, securityRequirement := range *operation.Security {
for name := range securityRequirement {
securitySchema, ok := doc.Components.SecuritySchemes[name]
if ok && securitySchema.Value != nil {
scheme := securitySchema.Value
schemes[scheme] = p.createSecurity(scheme)
}
}
}
}

result := make([]entity.Security, len(schemes))
index := 0
for _, scheme := range schemes {
result[index] = scheme
index++
}
fmt.Printf("%v", result)

return result
}

func (p *SchemaParser) createSecurity(scheme *openapi3.SecurityScheme) entity.Security {
switch strings.ToLower(scheme.Type) {
case "http":
switch strings.ToLower(scheme.Scheme) {
case "basic":
return entity.Security{
Type: entity.SECURITY_TYPE_BASIC,
}
case "bearer":
return entity.Security{
Type: entity.SECURITY_TYPE_BEARER,
}
default:
return entity.Security{
Type: entity.SECURITY_TYPE_BASIC,
}
}
case "oauth2":
return entity.Security{
Type: entity.SECURITY_TYPE_BEARER,
}
default:
return entity.Security{
Type: entity.SECURITY_TYPE_BASIC,
}
}
}

func GetEndpointName(verb entity.Verb, operationId string, key string) string {
if operationId != "" {
return GetPropertyName(operationId)
Expand Down
4 changes: 2 additions & 2 deletions parser/schema_to_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (p *SchemaParser) AddObject(
// name := schemaRef.Ref
// if name is empty, this is not a true ref
name := GetPropertyName(key)
schema := p.Add(key, schemaRef, false)
schema := p.add(key, schemaRef, false)
fields[index] = entity.NewField(
name,
schema,
Expand All @@ -146,7 +146,7 @@ func (p *SchemaParser) AddArray(
return item
}

items := p.Add("", schema.Items, false)
items := p.add("", schema.Items, false)

newItem := entity.NewArraySchema(ref, name, items, display)

Expand Down

0 comments on commit 9c7976c

Please sign in to comment.