From b02048a36da85070c4e2b1cb6bea5a9f5e0e2814 Mon Sep 17 00:00:00 2001 From: David Gregory <42326109+hunoz@users.noreply.github.com> Date: Tue, 22 Oct 2024 21:42:27 -0500 Subject: [PATCH 1/3] Add function to get token for renewing certificates --- ca/provisioner.go | 10 +++ ca/provisioner_test.go | 198 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 208 insertions(+) diff --git a/ca/provisioner.go b/ca/provisioner.go index d5b23f38c..3fb264a58 100644 --- a/ca/provisioner.go +++ b/ca/provisioner.go @@ -85,6 +85,7 @@ func (p *Provisioner) SetFingerprint(sum string) { // Token generates a bootstrap token for a subject. func (p *Provisioner) Token(subject string, sans ...string) (string, error) { + p.endpoint.ResolveReference(&url.URL{Path: "/1.0/renew"}) if len(sans) == 0 { sans = []string{subject} } @@ -118,6 +119,15 @@ func (p *Provisioner) Token(subject string, sans ...string) (string, error) { return tok.SignedString(p.jwk.Algorithm, p.jwk.Key) } +func (p *Provisioner) RenewalToken(subject string, sans ...string) (string, error) { + oldAudience := p.audience + u := p.endpoint.ResolveReference(&url.URL{Path: "/1.0/renew"}).String() + p.audience = u + response, err := p.Token(subject, sans...) + p.audience = oldAudience + return response, err +} + // SSHToken generates a SSH token. func (p *Provisioner) SSHToken(certType, keyID string, principals []string) (string, error) { jwtID, err := randutil.Hex(64) diff --git a/ca/provisioner_test.go b/ca/provisioner_test.go index 5a754f084..1a28608a2 100644 --- a/ca/provisioner_test.go +++ b/ca/provisioner_test.go @@ -296,6 +296,204 @@ func TestProvisioner_IPv6Token(t *testing.T) { } } +func TestProvisioner_RenewalToken(t *testing.T) { + p := getTestProvisioner(t, "https://127.0.0.1:9000") + sha := "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7" + + type fields struct { + name string + kid string + fingerprint string + jwk *jose.JSONWebKey + tokenLifetime time.Duration + } + type args struct { + subject string + sans []string + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", nil}, false}, + {"ok-with-san", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com"}}, false}, + {"ok-with-sans", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com", "127.0.0.1"}}, false}, + {"fail-no-subject", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"", []string{"foo.smallstep.com"}}, true}, + {"fail-no-key", fields{p.name, p.kid, sha, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject", nil}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Provisioner{ + name: tt.fields.name, + kid: tt.fields.kid, + audience: "https://127.0.0.1:9000/1.0/sign", + fingerprint: tt.fields.fingerprint, + jwk: tt.fields.jwk, + tokenLifetime: tt.fields.tokenLifetime, + } + got, err := p.RenewalToken(tt.args.subject, tt.args.sans...) + if (err != nil) != tt.wantErr { + t.Errorf("Provisioner.Token() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr == false { + jwt, err := jose.ParseSigned(got) + if err != nil { + t.Error(err) + return + } + var claims jose.Claims + if err := jwt.Claims(tt.fields.jwk.Public(), &claims); err != nil { + t.Error(err) + return + } + if err := claims.ValidateWithLeeway(jose.Expected{ + Audience: []string{"https://127.0.0.1:9000/1.0/renew"}, + Issuer: tt.fields.name, + Subject: tt.args.subject, + Time: time.Now().UTC(), + }, time.Minute); err != nil { + t.Error(err) + return + } + lifetime := claims.Expiry.Time().Sub(claims.NotBefore.Time()) + if lifetime != tt.fields.tokenLifetime { + t.Errorf("Claims token life time = %s, want %s", lifetime, tt.fields.tokenLifetime) + } + allClaims := make(map[string]interface{}) + if err := jwt.Claims(tt.fields.jwk.Public(), &allClaims); err != nil { + t.Error(err) + return + } + if v, ok := allClaims["sha"].(string); !ok || v != sha { + t.Errorf("Claim sha = %s, want %s", v, sha) + } + if len(tt.args.sans) == 0 { + if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) { + t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject}) + } + } else { + want := []interface{}{} + for _, s := range tt.args.sans { + want = append(want, s) + } + if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, want) { + t.Errorf("Claim sans = %s, want %s", v, want) + } + } + if v, ok := allClaims["jti"].(string); !ok || v == "" { + t.Errorf("Claim jti = %s, want not blank", v) + } + if p.audience != "https://127.0.0.1:9000/1.0/sign" { + t.Errorf("Provisioner audience = %s, want %s", p.audience, "https://127.0.0.1:9000/1.0/sign") + } + } + }) + } +} + +func TestProvisioner_RenewalIPv6Token(t *testing.T) { + p := getTestProvisioner(t, "https://[::1]:9000") + sha := "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7" + + type fields struct { + name string + kid string + fingerprint string + jwk *jose.JSONWebKey + tokenLifetime time.Duration + } + type args struct { + subject string + sans []string + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", nil}, false}, + {"ok-with-san", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com"}}, false}, + {"ok-with-sans", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com", "127.0.0.1"}}, false}, + {"fail-no-subject", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"", []string{"foo.smallstep.com"}}, true}, + {"fail-no-key", fields{p.name, p.kid, sha, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject", nil}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Provisioner{ + name: tt.fields.name, + kid: tt.fields.kid, + audience: "https://[::1]:9000/1.0/sign", + fingerprint: tt.fields.fingerprint, + jwk: tt.fields.jwk, + tokenLifetime: tt.fields.tokenLifetime, + } + got, err := p.RenewalToken(tt.args.subject, tt.args.sans...) + if (err != nil) != tt.wantErr { + t.Errorf("Provisioner.Token() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr == false { + jwt, err := jose.ParseSigned(got) + if err != nil { + t.Error(err) + return + } + var claims jose.Claims + if err := jwt.Claims(tt.fields.jwk.Public(), &claims); err != nil { + t.Error(err) + return + } + if err := claims.ValidateWithLeeway(jose.Expected{ + Audience: []string{"https://[::1]:9000/1.0/renew"}, + Issuer: tt.fields.name, + Subject: tt.args.subject, + Time: time.Now().UTC(), + }, time.Minute); err != nil { + t.Error(err) + return + } + lifetime := claims.Expiry.Time().Sub(claims.NotBefore.Time()) + if lifetime != tt.fields.tokenLifetime { + t.Errorf("Claims token life time = %s, want %s", lifetime, tt.fields.tokenLifetime) + } + allClaims := make(map[string]interface{}) + if err := jwt.Claims(tt.fields.jwk.Public(), &allClaims); err != nil { + t.Error(err) + return + } + if v, ok := allClaims["sha"].(string); !ok || v != sha { + t.Errorf("Claim sha = %s, want %s", v, sha) + } + if len(tt.args.sans) == 0 { + if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) { + t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject}) + } + } else { + want := []interface{}{} + for _, s := range tt.args.sans { + want = append(want, s) + } + if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, want) { + t.Errorf("Claim sans = %s, want %s", v, want) + } + } + if v, ok := allClaims["jti"].(string); !ok || v == "" { + t.Errorf("Claim jti = %s, want not blank", v) + } + if p.audience != "https://[::1]:9000/1.0/sign" { + t.Errorf("Provisioner audience = %s, want %s", p.audience, "https://[::1]:9000/1.0/sign") + } + } + }) + } +} + func TestProvisioner_SSHToken(t *testing.T) { p := getTestProvisioner(t, "https://127.0.0.1:9000") sha := "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7" From 39746aef3e51d2036ea2a050736443b2d10de364 Mon Sep 17 00:00:00 2001 From: David Gregory <42326109+hunoz@users.noreply.github.com> Date: Tue, 22 Oct 2024 21:55:50 -0500 Subject: [PATCH 2/3] Removing code accidentally left in --- ca/provisioner.go | 1 - 1 file changed, 1 deletion(-) diff --git a/ca/provisioner.go b/ca/provisioner.go index 3fb264a58..7ffa28a6f 100644 --- a/ca/provisioner.go +++ b/ca/provisioner.go @@ -85,7 +85,6 @@ func (p *Provisioner) SetFingerprint(sum string) { // Token generates a bootstrap token for a subject. func (p *Provisioner) Token(subject string, sans ...string) (string, error) { - p.endpoint.ResolveReference(&url.URL{Path: "/1.0/renew"}) if len(sans) == 0 { sans = []string{subject} } From a393e32cb28618b4e09cb1b81351a6fa3c4e89fe Mon Sep 17 00:00:00 2001 From: David Gregory <42326109+hunoz@users.noreply.github.com> Date: Tue, 22 Oct 2024 22:35:30 -0500 Subject: [PATCH 3/3] Add documentation to new function --- ca/provisioner.go | 1 + 1 file changed, 1 insertion(+) diff --git a/ca/provisioner.go b/ca/provisioner.go index 7ffa28a6f..469e7d28c 100644 --- a/ca/provisioner.go +++ b/ca/provisioner.go @@ -118,6 +118,7 @@ func (p *Provisioner) Token(subject string, sans ...string) (string, error) { return tok.SignedString(p.jwk.Algorithm, p.jwk.Key) } +// RenewalToken generates a token for use with renewing certificates func (p *Provisioner) RenewalToken(subject string, sans ...string) (string, error) { oldAudience := p.audience u := p.endpoint.ResolveReference(&url.URL{Path: "/1.0/renew"}).String()