Skip to content

Commit

Permalink
Add SignWithContext method to authority and mocks
Browse files Browse the repository at this point in the history
  • Loading branch information
hslatman committed Sep 19, 2023
1 parent b2301ea commit 4e06bdb
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 31 deletions.
4 changes: 4 additions & 0 deletions acme/api/revoke_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...prov
return nil, nil
}

func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) {
return nil, nil
}

func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.MockAreSANsallowed != nil {
return m.MockAreSANsallowed(ctx, sans)
Expand Down
1 change: 1 addition & 0 deletions acme/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ var clock Clock
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error
Expand Down
10 changes: 10 additions & 0 deletions acme/order_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ func TestOrder_UpdateStatus(t *testing.T) {

type mockSignAuth struct {
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
areSANsAllowed func(ctx context.Context, sans []string) error
loadProvisionerByName func(string) (provisioner.Interface, error)
ret1, ret2 interface{}
Expand All @@ -287,6 +288,15 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}

func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.signWithContext != nil {
return m.signWithContext(ctx, csr, signOpts, extraOpts...)
} else if m.err != nil {
return nil, m.err
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}

func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.areSANsAllowed != nil {
return m.areSANsAllowed(ctx, sans)
Expand Down
1 change: 1 addition & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type Authority interface {
GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Expand Down
8 changes: 8 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ type mockAuthority struct {
getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Expand Down Expand Up @@ -261,6 +262,13 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignO
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}

func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.signWithContext != nil {
return m.signWithContext(ctx, cr, opts, signOpts...)
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}

func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) {
if m.renew != nil {
return m.renew(cert)
Expand Down
19 changes: 9 additions & 10 deletions authority/provisioner/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type WebhookController struct {

// Enrich fetches data from remote servers and adds returned data to the
// templateData
func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil {
return nil
}
Expand All @@ -56,11 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
if !wc.isCertTypeOK(wh) {
continue
}
// TODO(hs): propagate context from above
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData)
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout

resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil {
return err
}
Expand All @@ -73,7 +73,7 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
}

// Authorize checks that all remote servers allow the request
func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil {
return nil
}
Expand All @@ -93,11 +93,10 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
continue
}

// TODO(hs): propagate context from above
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout

resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData)
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions authority/provisioner/webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func TestWebhookController_Enrich(t *testing.T) {
wh.URL = ts.URL
}

err := test.ctl.Enrich(test.req)
err := test.ctl.Enrich(context.Background(), test.req)
if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr)
}
Expand Down Expand Up @@ -352,7 +352,7 @@ func TestWebhookController_Authorize(t *testing.T) {
wh.URL = ts.URL
}

err := test.ctl.Authorize(test.req)
err := test.ctl.Authorize(context.Background(), test.req)
if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr)
}
Expand Down
14 changes: 7 additions & 7 deletions authority/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (*
}

// SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var (
certOptions []sshutil.Option
mods []provisioner.SSHCertModifier
Expand Down Expand Up @@ -205,7 +205,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision
}

// Call enriching webhooks
if err := callEnrichingWebhooksSSH(webhookCtl, cr); err != nil {
if err := callEnrichingWebhooksSSH(ctx, webhookCtl, cr); err != nil {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts),
Expand Down Expand Up @@ -277,7 +277,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision
}

// Send certificate to webhooks for authorization
if err := callAuthorizingWebhooksSSH(webhookCtl, certificate, certTpl); err != nil {
if err := callAuthorizingWebhooksSSH(ctx, webhookCtl, certificate, certTpl); err != nil {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"),
)
Expand Down Expand Up @@ -653,7 +653,7 @@ func (a *Authority) getAddUserCommand(principal string) string {
return strings.ReplaceAll(cmd, "<principal>", principal)
}

func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.CertificateRequest) error {
func callEnrichingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cr sshutil.CertificateRequest) error {
if webhookCtl == nil {
return nil
}
Expand All @@ -663,10 +663,10 @@ func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.Certifica
if err != nil {
return err
}
return webhookCtl.Enrich(whEnrichReq)
return webhookCtl.Enrich(ctx, whEnrichReq)
}

func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error {
func callAuthorizingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error {
if webhookCtl == nil {
return nil
}
Expand All @@ -676,5 +676,5 @@ func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Cert
if err != nil {
return err
}
return webhookCtl.Authorize(whAuthBody)
return webhookCtl.Authorize(ctx, whAuthBody)
}
23 changes: 16 additions & 7 deletions authority/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,17 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
}
}

// Sign creates a signed certificate from a certificate signing request.
// Sign creates a signed certificate from a certificate signing request. It
// creates a new context.Context, and calls into SignWithContext.
//
// Deprecated: Use authority.SignWithContext with an actual context.Context.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return a.SignWithContext(context.Background(), csr, signOpts, extraOpts...)
}

// SignWithContext creates a signed certificate from a certificate signing request,
// taking the provided context.Context.
func (a *Authority) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
var (
certOptions []x509util.Option
certValidators []provisioner.CertificateValidator
Expand Down Expand Up @@ -163,7 +172,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
}
}

if err := callEnrichingWebhooksX509(webhookCtl, attData, csr); err != nil {
if err := callEnrichingWebhooksX509(ctx, webhookCtl, attData, csr); err != nil {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("csr", csr),
Expand Down Expand Up @@ -256,7 +265,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
}

// Send certificate to webhooks for authorization
if err := callAuthorizingWebhooksX509(webhookCtl, cert, leaf, attData); err != nil {
if err := callAuthorizingWebhooksX509(ctx, webhookCtl, cert, leaf, attData); err != nil {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "error creating certificate"),
opts...,
Expand Down Expand Up @@ -952,7 +961,7 @@ func templatingError(err error) error {
return errors.Wrap(cause, "error applying certificate template")
}

func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error {
func callEnrichingWebhooksX509(ctx context.Context, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error {
if webhookCtl == nil {
return nil
}
Expand All @@ -969,10 +978,10 @@ func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisione
if err != nil {
return err
}
return webhookCtl.Enrich(whEnrichReq)
return webhookCtl.Enrich(ctx, whEnrichReq)
}

func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error {
func callAuthorizingWebhooksX509(ctx context.Context, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error {
if webhookCtl == nil {
return nil
}
Expand All @@ -989,5 +998,5 @@ func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Ce
if err != nil {
return err
}
return webhookCtl.Authorize(whAuthBody)
return webhookCtl.Authorize(ctx, whAuthBody)
}
10 changes: 7 additions & 3 deletions authority/webhook.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package authority

import "github.com/smallstep/certificates/webhook"
import (
"context"

"github.com/smallstep/certificates/webhook"
)

type webhookController interface {
Enrich(*webhook.RequestBody) error
Authorize(*webhook.RequestBody) error
Enrich(context.Context, *webhook.RequestBody) error
Authorize(context.Context, *webhook.RequestBody) error
}
6 changes: 4 additions & 2 deletions authority/webhook_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package authority

import (
"context"

"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/webhook"
)
Expand All @@ -14,14 +16,14 @@ type mockWebhookController struct {

var _ webhookController = &mockWebhookController{}

func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error {
func (wc *mockWebhookController) Enrich(context.Context, *webhook.RequestBody) error {
for key, data := range wc.respData {
wc.templateData.SetWebhook(key, data)
}

return wc.enrichErr
}

func (wc *mockWebhookController) Authorize(*webhook.RequestBody) error {
func (wc *mockWebhookController) Authorize(context.Context, *webhook.RequestBody) error {
return wc.authorizeErr
}

0 comments on commit 4e06bdb

Please sign in to comment.