From 466cc088fd86261c1de4dcba7167953374dbe8b8 Mon Sep 17 00:00:00 2001 From: Mike Date: Mon, 16 Dec 2024 13:49:31 -0700 Subject: [PATCH] refactor: add validation to apf --- pkg/apf/processor.go | 34 +++- pkg/apf/vaildator.go | 72 +++++++ pkg/apf/validator_test.go | 397 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 495 insertions(+), 8 deletions(-) create mode 100644 pkg/apf/vaildator.go create mode 100644 pkg/apf/validator_test.go diff --git a/pkg/apf/processor.go b/pkg/apf/processor.go index 6008c697..c8d666cd 100644 --- a/pkg/apf/processor.go +++ b/pkg/apf/processor.go @@ -23,7 +23,9 @@ func Process(data []byte, session *Session) bytes.Buffer { case APF_GLOBAL_REQUEST: // 80 log.Debug("received APF_GLOBAL_REQUEST") - dataToSend = ProcessGlobalRequest(data) + if ValidateGlobalRequest(data) { + dataToSend = ProcessGlobalRequest(data) + } case APF_CHANNEL_OPEN: // (90) Sent by Intel AMT when a channel needs to be open from Intel AMT. This is not common, but WSMAN events are a good example of channel coming from AMT. log.Debug("received APF_CHANNEL_OPEN") case APF_DISCONNECT: // (1) Intel AMT wants to completely disconnect. Not sure when this happens. @@ -31,29 +33,45 @@ func Process(data []byte, session *Session) bytes.Buffer { case APF_SERVICE_REQUEST: // (5) log.Debug("received APF_SERVICE_REQUEST") - dataToSend = ProcessServiceRequest(data) + if ValidateServiceRequest(data) { + dataToSend = ProcessServiceRequest(data) + } case APF_CHANNEL_OPEN_CONFIRMATION: // (91) Intel AMT confirmation to an APF_CHANNEL_OPEN request. log.Debug("received APF_CHANNEL_OPEN_CONFIRMATION") - ProcessChannelOpenConfirmation(data, session) + if ValidateChannelOpenConfirmation(data) { + ProcessChannelOpenConfirmation(data, session) + } case APF_CHANNEL_OPEN_FAILURE: // (92) Intel AMT rejected our connection attempt. log.Debug("received APF_CHANNEL_OPEN_FAILURE") - ProcessChannelOpenFailure(data, session) + if ValidateChannelOpenFailure(data) { + ProcessChannelOpenFailure(data, session) + } case APF_CHANNEL_CLOSE: // (97) Intel AMT is closing this channel, we need to disconnect the LMS TCP connection log.Debug("received APF_CHANNEL_CLOSE") - ProcessChannelClose(data, session) + if ValidateChannelClose(data) { + ProcessChannelClose(data, session) + } case APF_CHANNEL_DATA: // (94) Intel AMT is sending data that we must relay into an LMS TCP connection. - ProcessChannelData(data, session) + log.Debug("received APF_CHANNEL_DATA") + + if ValidateChannelData(data) { + ProcessChannelData(data, session) + } case APF_CHANNEL_WINDOW_ADJUST: // 93 log.Debug("received APF_CHANNEL_WINDOW_ADJUST") - ProcessChannelWindowAdjust(data, session) + if ValidateChannelWindowAdjust(data) { + ProcessChannelWindowAdjust(data, session) + } case APF_PROTOCOLVERSION: // 192 log.Debug("received APF PROTOCOL VERSION") - dataToSend = ProcessProtocolVersion(data) + if ValidateProtocolVersion(data) { + dataToSend = ProcessProtocolVersion(data) + } case APF_USERAUTH_REQUEST: // 50 default: } diff --git a/pkg/apf/vaildator.go b/pkg/apf/vaildator.go new file mode 100644 index 00000000..bcc4ff24 --- /dev/null +++ b/pkg/apf/vaildator.go @@ -0,0 +1,72 @@ +package apf + +import ( + "encoding/binary" +) + +// ValidateProtocolVersion checks if the data length is at least 93 bytes for APF_PROTOCOLVERSION. +func ValidateProtocolVersion(data []byte) bool { + return len(data) >= 93 +} + +// ValidateServiceRequest checks if the data length is sufficient for APF_SERVICE_REQUEST. +func ValidateServiceRequest(data []byte) bool { + if len(data) < 5 { + return false + } + + serviceLen := int(binary.BigEndian.Uint32(data[1:5])) + + return len(data) >= 5+serviceLen +} + +// ValidateGlobalRequest checks if the data length is sufficient for APF_GLOBAL_REQUEST. +func ValidateGlobalRequest(data []byte) bool { + if len(data) < 5 { + return false + } + + globalReqLen := int(binary.BigEndian.Uint32(data[1:5])) + if len(data) < 5+globalReqLen+1 { + return false + } + + serviceName := string(data[5 : 5+globalReqLen]) + + if serviceName == APF_GLOBAL_REQUEST_STR_TCP_FORWARD_REQUEST || serviceName == APF_GLOBAL_REQUEST_STR_TCP_FORWARD_CANCEL_REQUEST { + if len(data) < 6+globalReqLen+4 { + return false + } + + addrLen := int(binary.BigEndian.Uint32(data[6+globalReqLen : 10+globalReqLen])) + + return len(data) >= 14+globalReqLen+addrLen + } + + return false +} + +// ValidateChannelOpenConfirmation checks if the data length is at least 17 bytes for APF_CHANNEL_OPEN_CONFIRMATION. +func ValidateChannelOpenConfirmation(data []byte) bool { + return len(data) >= 17 +} + +// ValidateChannelOpenFailure checks if the data length is at least 17 bytes for APF_CHANNEL_OPEN_FAILURE. +func ValidateChannelOpenFailure(data []byte) bool { + return len(data) >= 17 +} + +// ValidateChannelClose checks if the data length is at least 5 bytes for APF_CHANNEL_CLOSE. +func ValidateChannelClose(data []byte) bool { + return len(data) >= 5 +} + +// ValidateChannelData checks if the data length is sufficient for APF_CHANNEL_DATA. +func ValidateChannelData(data []byte) bool { + return len(data) >= 9 && len(data) >= 9+int(binary.BigEndian.Uint32(data[5:9])) +} + +// ValidateChannelWindowAdjust checks if the data length is at least 9 bytes for APF_CHANNEL_WINDOW_ADJUST. +func ValidateChannelWindowAdjust(data []byte) bool { + return len(data) >= 9 +} diff --git a/pkg/apf/validator_test.go b/pkg/apf/validator_test.go new file mode 100644 index 00000000..89806f19 --- /dev/null +++ b/pkg/apf/validator_test.go @@ -0,0 +1,397 @@ +package apf + +import ( + "encoding/binary" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateProtocolVersion(t *testing.T) { + testCases := []struct { + name string + len int + want bool + }{ + { + name: "length < 93", + len: 92, + want: false, + }, + { + name: "length = 93", + len: 93, + want: true, + }, + { + name: "length > 93", + len: 100, + want: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data := make([]byte, tc.len) + got := ValidateProtocolVersion(data) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestValidateServiceRequest(t *testing.T) { + testCases := []struct { + name string + dataLen int + serviceLen uint32 + want bool + }{ + { + name: "length < 5", + dataLen: 4, + want: false, + }, + { + name: "exactly 5, serviceLen=0", + dataLen: 5, + serviceLen: 0, + want: true, + }, + { + name: "exactly 5, serviceLen=1 (not enough)", + dataLen: 5, + serviceLen: 1, + want: false, + }, + { + name: "length=6, serviceLen=1 (enough)", + dataLen: 6, + serviceLen: 1, + want: true, + }, + { + name: "length=10, serviceLen=5 (enough)", + dataLen: 10, + serviceLen: 5, + want: true, + }, + { + name: "length=9, serviceLen=5 (not enough)", + dataLen: 9, + serviceLen: 5, + want: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data := make([]byte, tc.dataLen) + if tc.dataLen >= 5 { + binary.BigEndian.PutUint32(data[1:5], tc.serviceLen) + } + + got := ValidateServiceRequest(data) + + assert.Equal(t, tc.want, got) + }) + } +} + +func TestValidateGlobalRequest(t *testing.T) { + testCases := []struct { + name string + serviceName string + addrLen int + totalLen int + dataSetup func([]byte, int, string, int) + want bool + }{ + { + name: "length < 5", + serviceName: "", + totalLen: 4, // Not enough length to even read the globalReqLen + dataSetup: func(data []byte, totalLen int, service string, addrLen int) { + // No writes, as we know this will fail due to length anyway. + }, + want: false, + }, + { + name: "serviceName empty", + serviceName: "", + totalLen: 5, // Enough to write [1:5] + dataSetup: func(data []byte, totalLen int, service string, addrLen int) { + binary.BigEndian.PutUint32(data[1:5], 0) + }, + want: false, + }, + { + name: "non-matching service name", + serviceName: "unknown", + // globalReqLen = len("unknown") = 7 + // Minimum length for reading globalReqLen and serviceName: 5 + 7 = 12 + totalLen: 12, + dataSetup: func(data []byte, totalLen int, service string, addrLen int) { + globalReqLen := len(service) + binary.BigEndian.PutUint32(data[1:5], uint32(globalReqLen)) + copy(data[5:5+globalReqLen], service) + }, + want: false, + }, + { + name: "tcpip-forward insufficient for addrLen", + serviceName: APF_GLOBAL_REQUEST_STR_TCP_FORWARD_REQUEST, + // globalReqLen = 13 + // minimum to copy service name: 5 + 13 = 18 + // less than required for addrLen parsing (23) + totalLen: 18, + dataSetup: func(data []byte, totalLen int, service string, addrLen int) { + globalReqLen := len(service) + binary.BigEndian.PutUint32(data[1:5], uint32(globalReqLen)) + copy(data[5:5+globalReqLen], service) + }, + want: false, + }, + { + name: "tcpip-forward with enough length", + serviceName: APF_GLOBAL_REQUEST_STR_TCP_FORWARD_REQUEST, + addrLen: 4, + // For success: + // globalReqLen = 13 + // Need at least: 14 + 13 + 4 = 31 + totalLen: 31, + dataSetup: func(data []byte, totalLen int, service string, addrLen int) { + globalReqLen := len(service) + binary.BigEndian.PutUint32(data[1:5], uint32(globalReqLen)) + copy(data[5:5+globalReqLen], service) + binary.BigEndian.PutUint32(data[6+globalReqLen:10+globalReqLen], uint32(addrLen)) + for i := 10 + globalReqLen; i < 10+globalReqLen+addrLen; i++ { + data[i] = 0x01 + } + }, + want: true, + }, + { + name: "cancel-tcpip-forward with enough length", + serviceName: APF_GLOBAL_REQUEST_STR_TCP_FORWARD_CANCEL_REQUEST, + addrLen: 4, + // globalReqLen = len("cancel-tcpip-forward")=22 + // needed: 14+22+4=40 + totalLen: 40, + dataSetup: func(data []byte, totalLen int, service string, addrLen int) { + globalReqLen := len(service) + binary.BigEndian.PutUint32(data[1:5], uint32(globalReqLen)) + copy(data[5:5+globalReqLen], service) + binary.BigEndian.PutUint32(data[6+globalReqLen:10+globalReqLen], uint32(addrLen)) + for i := 10 + globalReqLen; i < 10+globalReqLen+addrLen; i++ { + data[i] = 0x01 + } + }, + want: true, + }, + { + name: "tcpip-forward insufficient for large addrLen", + serviceName: APF_GLOBAL_REQUEST_STR_TCP_FORWARD_REQUEST, + addrLen: 10, + // For success: 14+11+10=35 required + // Give 34 to fail + totalLen: 34, + dataSetup: func(data []byte, totalLen int, service string, addrLen int) { + globalReqLen := len(service) + binary.BigEndian.PutUint32(data[1:5], uint32(globalReqLen)) + copy(data[5:5+globalReqLen], service) + binary.BigEndian.PutUint32(data[6+globalReqLen:10+globalReqLen], uint32(addrLen)) + }, + want: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data := make([]byte, tc.totalLen) + tc.dataSetup(data, tc.totalLen, tc.serviceName, tc.addrLen) + got := ValidateGlobalRequest(data) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestValidateChannelOpenConfirmation(t *testing.T) { + testCases := []struct { + name string + len int + want bool + }{ + { + name: "length < 17", + len: 16, + want: false, + }, + { + name: "length = 17", + len: 17, + want: true, + }, + { + name: "length > 17", + len: 20, + want: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data := make([]byte, tc.len) + got := ValidateChannelOpenConfirmation(data) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestValidateChannelOpenFailure(t *testing.T) { + testCases := []struct { + name string + len int + want bool + }{ + { + name: "length < 17", + len: 16, + want: false, + }, + { + name: "length = 17", + len: 17, + want: true, + }, + { + name: "length > 17", + len: 20, + want: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data := make([]byte, tc.len) + got := ValidateChannelOpenFailure(data) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestValidateChannelClose(t *testing.T) { + testCases := []struct { + name string + len int + want bool + }{ + { + name: "length < 5", + len: 4, + want: false, + }, + { + name: "length = 5", + len: 5, + want: true, + }, + { + name: "length > 5", + len: 10, + want: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data := make([]byte, tc.len) + got := ValidateChannelClose(data) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestValidateChannelData(t *testing.T) { + testCases := []struct { + name string + dataLen int + dataField uint32 + want bool + }{ + { + name: "length < 9", + dataLen: 8, + dataField: 0, + want: false, + }, + { + name: "length = 9, dataField=0", + dataLen: 9, + dataField: 0, + want: true, + }, + { + name: "length=10, dataField=1", + dataLen: 10, + dataField: 1, + want: true, + }, + { + name: "length=13, dataField=5 (needs 14)", + dataLen: 13, + dataField: 5, + want: false, + }, + { + name: "length=14, dataField=5", + dataLen: 14, + dataField: 5, + want: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data := make([]byte, tc.dataLen) + if tc.dataLen >= 9 { + binary.BigEndian.PutUint32(data[5:9], tc.dataField) + } + + got := ValidateChannelData(data) + + assert.Equal(t, tc.want, got) + }) + } +} + +func TestValidateChannelWindowAdjust(t *testing.T) { + testCases := []struct { + name string + len int + want bool + }{ + { + name: "length < 9", + len: 8, + want: false, + }, + { + name: "length = 9", + len: 9, + want: true, + }, + { + name: "length > 9", + len: 10, + want: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data := make([]byte, tc.len) + got := ValidateChannelWindowAdjust(data) + assert.Equal(t, tc.want, got) + }) + } +}