Skip to content

Commit

Permalink
Merge pull request #17 from tetratelabs/move-internal
Browse files Browse the repository at this point in the history
Refactor HTTP healthcheck
  • Loading branch information
chirauki authored Feb 20, 2024
2 parents 052d2f8 + 809bc9a commit e46aad7
Show file tree
Hide file tree
Showing 14 changed files with 241 additions and 115 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (

"github.com/hashicorp/terraform-plugin-framework/providerserver"

"github.com/tetratelabs/terraform-provider-checkmate/internal/provider"
"github.com/tetratelabs/terraform-provider-checkmate/pkg/provider"
)

// Run "go generate" to format example terraform files and generate the docs for the registry/website
Expand Down
214 changes: 214 additions & 0 deletions pkg/healthcheck/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
// Copyright 2024 Tetrate
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package healthcheck

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/hashicorp/go-multierror"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-log/tflog"
"github.com/tetratelabs/terraform-provider-checkmate/pkg/helpers"
)

type HttpHealthArgs struct {
URL string
Method string
Timeout int64
RequestTimeout int64
Interval int64
StatusCode string
ConsecutiveSuccesses int64
Headers map[string]string
IgnoreFailure bool
Passed bool
RequestBody string
ResultBody string
CABundle string
InsecureTLS bool
}

func HealthCheck(ctx context.Context, data *HttpHealthArgs, diag *diag.Diagnostics) error {
var err error

data.Passed = false
endpoint, err := url.Parse(data.URL)
if err != nil {
diagAddError(diag, "Client Error", fmt.Sprintf("Unable to parse url %q, got error %s", data.URL, err))
return fmt.Errorf("parse url %q: %w", data.URL, err)
}

var checkCode func(int) (bool, error)
// check the pattern once
checkStatusCode(data.StatusCode, 0, diag)
if diag.HasError() {
return fmt.Errorf("bad status code pattern")
}
checkCode = func(c int) (bool, error) { return checkStatusCode(data.StatusCode, c, diag) }

// normalize headers
headers := make(map[string][]string)
if data.Headers != nil {
for k, v := range data.Headers {
headers[k] = []string{v}
}
}

window := helpers.RetryWindow{
Timeout: time.Duration(data.Timeout) * time.Millisecond,
Interval: time.Duration(data.Interval) * time.Millisecond,
ConsecutiveSuccesses: int(data.ConsecutiveSuccesses),
}
data.ResultBody = ""

if data.CABundle != "" && data.InsecureTLS {
diagAddError(diag, "Conflicting configuration", "You cannot specify both custom CA and insecure TLS. Please use only one of them.")
}
tlsConfig := &tls.Config{}
if data.CABundle != "" {
caCertPool := x509.NewCertPool()
if ok := caCertPool.AppendCertsFromPEM([]byte(data.CABundle)); !ok {
diagAddError(diag, "Building CA cert pool", err.Error())
multierror.Append(err, fmt.Errorf("build CA cert pool: %w", err))
}
tlsConfig.RootCAs = caCertPool
}
tlsConfig.InsecureSkipVerify = data.InsecureTLS

client := http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
Timeout: time.Duration(data.RequestTimeout) * time.Millisecond,
}

tflog.Debug(ctx, fmt.Sprintf("Starting HTTP health check. Overall timeout: %d ms, request timeout: %d ms", data.Timeout, data.RequestTimeout))
for h, v := range headers {
tflog.Debug(ctx, fmt.Sprintf("%s: %s", h, v))
}

result := window.Do(func(attempt int, successes int) bool {
if successes != 0 {
tflog.Trace(ctx, fmt.Sprintf("SUCCESS [%d/%d] http %s %s", successes, data.ConsecutiveSuccesses, data.Method, endpoint))
} else {
tflog.Trace(ctx, fmt.Sprintf("ATTEMPT #%d http %s %s", attempt, data.Method, endpoint))
}

httpResponse, err := client.Do(&http.Request{
URL: endpoint,
Method: data.Method,
Header: headers,
Body: io.NopCloser(strings.NewReader(data.RequestBody)),
})
if err != nil {
tflog.Warn(ctx, fmt.Sprintf("CONNECTION FAILURE %v", err))
return false
}

success, err := checkCode(httpResponse.StatusCode)
if err != nil {
diagAddError(diag, "check status code", err.Error())
multierror.Append(err, fmt.Errorf("check status code: %w", err))
}
if success {
tflog.Trace(ctx, fmt.Sprintf("SUCCESS CODE %d", httpResponse.StatusCode))
body, err := io.ReadAll(httpResponse.Body)
if err != nil {
tflog.Warn(ctx, fmt.Sprintf("ERROR READING BODY %v", err))
data.ResultBody = ""
} else {
tflog.Warn(ctx, fmt.Sprintf("READ %d BYTES", len(body)))
data.ResultBody = string(body)
}
} else {
tflog.Trace(ctx, fmt.Sprintf("FAILURE CODE %d", httpResponse.StatusCode))
}
return success
})

switch result {
case helpers.Success:
data.Passed = true
case helpers.TimeoutExceeded:
diagAddWarning(diag, "Timeout exceeded", fmt.Sprintf("Timeout of %d milliseconds exceeded", data.Timeout))
if !data.IgnoreFailure {
diagAddError(diag, "Check failed", "The check did not pass within the timeout and create_anyway_on_check_failure is false")
multierror.Append(err, fmt.Errorf("the check did not pass within the timeout and create_anyway_on_check_failure is false"))
}
}

return err
}

func checkStatusCode(pattern string, code int, diag *diag.Diagnostics) (bool, error) {
ranges := strings.Split(pattern, ",")
for _, r := range ranges {
bounds := strings.Split(r, "-")
if len(bounds) == 2 {
left, err := strconv.Atoi(bounds[0])
if err != nil {
diagAddError(diag, "Bad status code pattern", fmt.Sprintf("Can't convert %s to integer. %s", bounds[0], err))
return false, fmt.Errorf("convert %q to integer: %w", bounds[0], err)
}
right, err := strconv.Atoi(bounds[1])
if err != nil {
diagAddError(diag, "Bad status code pattern", fmt.Sprintf("Can't convert %s to integer. %s", bounds[1], err))
return false, fmt.Errorf("convert %q to integer: %w", bounds[0], err)
}
if left > right {
diagAddError(diag, "Bad status code pattern", fmt.Sprintf("Left bound %d is greater than right bound %d", left, right))
return false, fmt.Errorf("left bound %d is greater than right bound %d", left, right)
}
if left <= code && right >= code {
return true, nil
}
} else if len(bounds) == 1 {
val, err := strconv.Atoi(bounds[0])
if err != nil {
diagAddError(diag, "Bad status code pattern", fmt.Sprintf("Can't convert %s to integer. %s", bounds[0], err))
return false, fmt.Errorf("convert %q to integer: %w", bounds[0], err)
}
if val == code {
return true, nil
}
} else {
diagAddError(diag, "Bad status code pattern", "Too many dashes in range pattern")
return false, errors.New("too many dashes in range pattern")
}
}
return false, errors.New("status code does not match pattern")
}

func diagAddError(diag *diag.Diagnostics, summary string, details string) {
if diag != nil {
diag.AddError(summary, details)
}
}

func diagAddWarning(diag *diag.Diagnostics, summary string, details string) {
if diag != nil {
diag.AddWarning(summary, details)
}
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion pkg/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package pkg
import (
framework "github.com/hashicorp/terraform-plugin-framework/provider"

"github.com/tetratelabs/terraform-provider-checkmate/internal/provider"
"github.com/tetratelabs/terraform-provider-checkmate/pkg/provider"
)

var version string = "1.6.0"
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,9 @@ package provider

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/google/uuid"
"github.com/hashicorp/terraform-plugin-framework/diag"
Expand All @@ -34,10 +28,9 @@ import (
"github.com/hashicorp/terraform-plugin-framework/resource/schema/planmodifier"
"github.com/hashicorp/terraform-plugin-framework/resource/schema/stringplanmodifier"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-log/tflog"

"github.com/tetratelabs/terraform-provider-checkmate/internal/helpers"
"github.com/tetratelabs/terraform-provider-checkmate/internal/modifiers"
"github.com/tetratelabs/terraform-provider-checkmate/pkg/healthcheck"
"github.com/tetratelabs/terraform-provider-checkmate/pkg/modifiers"
)

// Ensure provider defined types fully satisfy framework interfaces
Expand Down Expand Up @@ -182,113 +175,32 @@ func (r *HttpHealthResource) Create(ctx context.Context, req resource.CreateRequ
}

func (r *HttpHealthResource) HealthCheck(ctx context.Context, data *HttpHealthResourceModel, diag *diag.Diagnostics) {
data.Passed = types.BoolValue(false)
endpoint, err := url.Parse(data.URL.ValueString())
if err != nil {
diag.AddError("Client Error", fmt.Sprintf("Unable to parse url %q, got error %s", data.URL.ValueString(), err))
return
}

var checkCode func(int) bool
// check the pattern once
checkStatusCode(data.StatusCode.ValueString(), 0, diag)
if diag.HasError() {
return
}
checkCode = func(c int) bool { return checkStatusCode(data.StatusCode.ValueString(), c, diag) }

// normalize headers
headers := make(map[string][]string)
var tmp map[string]string
if !data.Headers.IsNull() {
tmp := make(map[string]string)
diag.Append(data.Headers.ElementsAs(ctx, &tmp, false)...)
if diag.HasError() {
return
}

for k, v := range tmp {
headers[k] = []string{v}
}
}

window := helpers.RetryWindow{
Timeout: time.Duration(data.Timeout.ValueInt64()) * time.Millisecond,
Interval: time.Duration(data.Interval.ValueInt64()) * time.Millisecond,
ConsecutiveSuccesses: int(data.ConsecutiveSuccesses.ValueInt64()),
}
data.ResultBody = types.StringValue("")

if !data.CABundle.IsNull() && data.InsecureTLS.ValueBool() {
diag.AddError("Conflicting configuration", "You cannot specify both custom CA and insecure TLS. Please use only one of them.")
}
tlsConfig := &tls.Config{}
if !data.CABundle.IsNull() {
caCertPool := x509.NewCertPool()
if ok := caCertPool.AppendCertsFromPEM([]byte(data.CABundle.ValueString())); !ok {
diag.AddError("Building CA cert pool", err.Error())
}
tlsConfig.RootCAs = caCertPool
args := healthcheck.HttpHealthArgs{
URL: data.URL.ValueString(),
Method: data.Method.ValueString(),
Timeout: data.Timeout.ValueInt64(),
RequestTimeout: data.RequestTimeout.ValueInt64(),
Interval: data.Interval.ValueInt64(),
StatusCode: data.StatusCode.ValueString(),
ConsecutiveSuccesses: data.ConsecutiveSuccesses.ValueInt64(),
Headers: tmp,
IgnoreFailure: data.IgnoreFailure.ValueBool(),
RequestBody: data.RequestBody.ValueString(),
CABundle: data.CABundle.ValueString(),
InsecureTLS: data.InsecureTLS.ValueBool(),
}
tlsConfig.InsecureSkipVerify = data.InsecureTLS.ValueBool()

client := http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
Timeout: time.Duration(data.RequestTimeout.ValueInt64()) * time.Millisecond,
}

tflog.Debug(ctx, fmt.Sprintf("Starting HTTP health check. Overall timeout: %d ms, request timeout: %d ms", data.Timeout.ValueInt64(), data.RequestTimeout.ValueInt64()))
for h, v := range headers {
tflog.Debug(ctx, fmt.Sprintf("%s: %s", h, v))
}

result := window.Do(func(attempt int, successes int) bool {
if successes != 0 {
tflog.Trace(ctx, fmt.Sprintf("SUCCESS [%d/%d] http %s %s", successes, data.ConsecutiveSuccesses.ValueInt64(), data.Method.ValueString(), endpoint))
} else {
tflog.Trace(ctx, fmt.Sprintf("ATTEMPT #%d http %s %s", attempt, data.Method.ValueString(), endpoint))
}

httpResponse, err := client.Do(&http.Request{
URL: endpoint,
Method: data.Method.ValueString(),
Header: headers,
Body: io.NopCloser(strings.NewReader(data.RequestBody.ValueString())),
})
if err != nil {
tflog.Warn(ctx, fmt.Sprintf("CONNECTION FAILURE %v", err))
return false
}

success := checkCode(httpResponse.StatusCode)
if success {
tflog.Trace(ctx, fmt.Sprintf("SUCCESS CODE %d", httpResponse.StatusCode))
body, err := io.ReadAll(httpResponse.Body)
if err != nil {
tflog.Warn(ctx, fmt.Sprintf("ERROR READING BODY %v", err))
data.ResultBody = types.StringValue("")
} else {
tflog.Warn(ctx, fmt.Sprintf("READ %d BYTES", len(body)))
data.ResultBody = types.StringValue(string(body))
}
} else {
tflog.Trace(ctx, fmt.Sprintf("FAILURE CODE %d", httpResponse.StatusCode))
}
return success
})

switch result {
case helpers.Success:
data.Passed = types.BoolValue(true)
case helpers.TimeoutExceeded:
diag.AddWarning("Timeout exceeded", fmt.Sprintf("Timeout of %d milliseconds exceeded", data.Timeout.ValueInt64()))
if !data.IgnoreFailure.ValueBool() {
diag.AddError("Check failed", "The check did not pass within the timeout and create_anyway_on_check_failure is false")
return
}
err := healthcheck.HealthCheck(ctx, &args, diag)
if err != nil {
diag.AddError("Health Check Error", fmt.Sprintf("Error during health check: %s", err))
}

data.Passed = types.BoolValue(args.Passed)
data.ResultBody = types.StringValue(args.ResultBody)
}

func (r *HttpHealthResource) Read(ctx context.Context, req resource.ReadRequest, resp *resource.ReadResponse) {
Expand Down
File renamed without changes.
Loading

0 comments on commit e46aad7

Please sign in to comment.