From a9d53dc41b4b49132379f203de6cfdc6e66be8e6 Mon Sep 17 00:00:00 2001 From: Ronald Holshausen Date: Fri, 31 May 2024 11:57:58 +1000 Subject: [PATCH] feat: Update example JWT plugin --- plugins/jwt/inspect.lua | 373 +++++++++++++++++++++++++++++++++++ plugins/jwt/jwt.lua | 46 ++++- plugins/jwt/matching.lua | 87 ++++++++ plugins/jwt/pact-plugin.json | 6 +- plugins/jwt/plugin.lua | 78 +++++++- plugins/jwt/utils.lua | 48 +++++ 6 files changed, 625 insertions(+), 13 deletions(-) create mode 100644 plugins/jwt/inspect.lua create mode 100644 plugins/jwt/matching.lua create mode 100644 plugins/jwt/utils.lua diff --git a/plugins/jwt/inspect.lua b/plugins/jwt/inspect.lua new file mode 100644 index 00000000..e1dda1ab --- /dev/null +++ b/plugins/jwt/inspect.lua @@ -0,0 +1,373 @@ +-- Taken from https://github.com/kikito/inspect.lua + +local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local math = _tl_compat and _tl_compat.math or math; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table +local inspect = {Options = {}, } + + + + + + + + + + + + + + + + + +inspect._VERSION = 'inspect.lua 3.1.0' +inspect._URL = 'http://github.com/kikito/inspect.lua' +inspect._DESCRIPTION = 'human-readable representations of tables' +inspect._LICENSE = [[ + MIT LICENSE + + Copyright (c) 2022 Enrique GarcĂ­a Cota + + Permission is hereby granted, free of charge, to any person obtaining a + copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +]] +inspect.KEY = setmetatable({}, { __tostring = function() return 'inspect.KEY' end }) +inspect.METATABLE = setmetatable({}, { __tostring = function() return 'inspect.METATABLE' end }) + +local tostring = tostring +local rep = string.rep +local match = string.match +local char = string.char +local gsub = string.gsub +local fmt = string.format + +local _rawget +if rawget then + _rawget = rawget +else + _rawget = function(t, k) return t[k] end +end + +local function rawpairs(t) + return next, t, nil +end + + + +local function smartQuote(str) + if match(str, '"') and not match(str, "'") then + return "'" .. str .. "'" + end + return '"' .. gsub(str, '"', '\\"') .. '"' +end + + +local shortControlCharEscapes = { + ["\a"] = "\\a", ["\b"] = "\\b", ["\f"] = "\\f", ["\n"] = "\\n", + ["\r"] = "\\r", ["\t"] = "\\t", ["\v"] = "\\v", ["\127"] = "\\127", +} +local longControlCharEscapes = { ["\127"] = "\127" } +for i = 0, 31 do + local ch = char(i) + if not shortControlCharEscapes[ch] then + shortControlCharEscapes[ch] = "\\" .. i + longControlCharEscapes[ch] = fmt("\\%03d", i) + end +end + +local function escape(str) + return (gsub(gsub(gsub(str, "\\", "\\\\"), + "(%c)%f[0-9]", longControlCharEscapes), + "%c", shortControlCharEscapes)) +end + +local luaKeywords = { + ['and'] = true, + ['break'] = true, + ['do'] = true, + ['else'] = true, + ['elseif'] = true, + ['end'] = true, + ['false'] = true, + ['for'] = true, + ['function'] = true, + ['goto'] = true, + ['if'] = true, + ['in'] = true, + ['local'] = true, + ['nil'] = true, + ['not'] = true, + ['or'] = true, + ['repeat'] = true, + ['return'] = true, + ['then'] = true, + ['true'] = true, + ['until'] = true, + ['while'] = true, +} + +local function isIdentifier(str) + return type(str) == "string" and + not not str:match("^[_%a][_%a%d]*$") and + not luaKeywords[str] +end + +local flr = math.floor +local function isSequenceKey(k, sequenceLength) + return type(k) == "number" and + flr(k) == k and + 1 <= (k) and + k <= sequenceLength +end + +local defaultTypeOrders = { + ['number'] = 1, ['boolean'] = 2, ['string'] = 3, ['table'] = 4, + ['function'] = 5, ['userdata'] = 6, ['thread'] = 7, +} + +local function sortKeys(a, b) + local ta, tb = type(a), type(b) + + + if ta == tb and (ta == 'string' or ta == 'number') then + return (a) < (b) + end + + local dta = defaultTypeOrders[ta] or 100 + local dtb = defaultTypeOrders[tb] or 100 + + + return dta == dtb and ta < tb or dta < dtb +end + +local function getKeys(t) + + local seqLen = 1 + while _rawget(t, seqLen) ~= nil do + seqLen = seqLen + 1 + end + seqLen = seqLen - 1 + + local keys, keysLen = {}, 0 + for k in rawpairs(t) do + if not isSequenceKey(k, seqLen) then + keysLen = keysLen + 1 + keys[keysLen] = k + end + end + table.sort(keys, sortKeys) + return keys, keysLen, seqLen +end + +local function countCycles(x, cycles) + if type(x) == "table" then + if cycles[x] then + cycles[x] = cycles[x] + 1 + else + cycles[x] = 1 + for k, v in rawpairs(x) do + countCycles(k, cycles) + countCycles(v, cycles) + end + countCycles(getmetatable(x), cycles) + end + end +end + +local function makePath(path, a, b) + local newPath = {} + local len = #path + for i = 1, len do newPath[i] = path[i] end + + newPath[len + 1] = a + newPath[len + 2] = b + + return newPath +end + + +local function processRecursive(process, + item, + path, + visited) + if item == nil then return nil end + if visited[item] then return visited[item] end + + local processed = process(item, path) + if type(processed) == "table" then + local processedCopy = {} + visited[item] = processedCopy + local processedKey + + for k, v in rawpairs(processed) do + processedKey = processRecursive(process, k, makePath(path, k, inspect.KEY), visited) + if processedKey ~= nil then + processedCopy[processedKey] = processRecursive(process, v, makePath(path, processedKey), visited) + end + end + + local mt = processRecursive(process, getmetatable(processed), makePath(path, inspect.METATABLE), visited) + if type(mt) ~= 'table' then mt = nil end + setmetatable(processedCopy, mt) + processed = processedCopy + end + return processed +end + +local function puts(buf, str) + buf.n = buf.n + 1 + buf[buf.n] = str +end + + + +local Inspector = {} + + + + + + + + + + +local Inspector_mt = { __index = Inspector } + +local function tabify(inspector) + puts(inspector.buf, inspector.newline .. rep(inspector.indent, inspector.level)) +end + +function Inspector:getId(v) + local id = self.ids[v] + local ids = self.ids + if not id then + local tv = type(v) + id = (ids[tv] or 0) + 1 + ids[v], ids[tv] = id, id + end + return tostring(id) +end + +function Inspector:putValue(v) + local buf = self.buf + local tv = type(v) + if tv == 'string' then + puts(buf, smartQuote(escape(v))) + elseif tv == 'number' or tv == 'boolean' or tv == 'nil' or + tv == 'cdata' or tv == 'ctype' then + puts(buf, tostring(v)) + elseif tv == 'table' and not self.ids[v] then + local t = v + + if t == inspect.KEY or t == inspect.METATABLE then + puts(buf, tostring(t)) + elseif self.level >= self.depth then + puts(buf, '{...}') + else + if self.cycles[t] > 1 then puts(buf, fmt('<%d>', self:getId(t))) end + + local keys, keysLen, seqLen = getKeys(t) + + puts(buf, '{') + self.level = self.level + 1 + + for i = 1, seqLen + keysLen do + if i > 1 then puts(buf, ',') end + if i <= seqLen then + puts(buf, ' ') + self:putValue(t[i]) + else + local k = keys[i - seqLen] + tabify(self) + if isIdentifier(k) then + puts(buf, k) + else + puts(buf, "[") + self:putValue(k) + puts(buf, "]") + end + puts(buf, ' = ') + self:putValue(t[k]) + end + end + + local mt = getmetatable(t) + if type(mt) == 'table' then + if seqLen + keysLen > 0 then puts(buf, ',') end + tabify(self) + puts(buf, ' = ') + self:putValue(mt) + end + + self.level = self.level - 1 + + if keysLen > 0 or type(mt) == 'table' then + tabify(self) + elseif seqLen > 0 then + puts(buf, ' ') + end + + puts(buf, '}') + end + + else + puts(buf, fmt('<%s %d>', tv, self:getId(v))) + end +end + + + + +function inspect.inspect(root, options) + options = options or {} + + local depth = options.depth or (math.huge) + local newline = options.newline or '\n' + local indent = options.indent or ' ' + local process = options.process + + if process then + root = processRecursive(process, root, {}, {}) + end + + local cycles = {} + countCycles(root, cycles) + + local inspector = setmetatable({ + buf = { n = 0 }, + ids = {}, + cycles = cycles, + depth = depth, + level = 0, + newline = newline, + indent = indent, + }, Inspector_mt) + + inspector:putValue(root) + + return table.concat(inspector.buf) +end + +setmetatable(inspect, { + __call = function(_, root, options) + return inspect.inspect(root, options) + end, +}) + +return inspect diff --git a/plugins/jwt/jwt.lua b/plugins/jwt/jwt.lua index d5f535a9..db5d6a6b 100644 --- a/plugins/jwt/jwt.lua +++ b/plugins/jwt/jwt.lua @@ -1,10 +1,13 @@ local jwt = {} -local random_utils = require("random_utils") +local utils = require "utils" +local base64 = require "base64" +local json = require "json" +local inspect = require "inspect" function jwt.build_header(config) local header = {} - header["typ"] = config["token-type"] or "jwt" + header["typ"] = config["token-type"] or "JWT" header["alg"] = config["algorithm"] or "RS512" if config["key-id"] then header["kid"] = config["key-id"] @@ -15,13 +18,13 @@ end function jwt.build_payload(config) local claims = { - jti = random_utils.random_hex(16), + jti = utils.random_hex(16), iat = os.time() } - claims["sub"] = config["subject"] or "sub_" .. random_utils.random_str(4) - claims["iss"] = config["issuer"] or "iss_" .. random_utils.random_str(4) - claims["aud"] = config["audience"] or "aud_" .. random_utils.random_str(4) + claims["sub"] = config["subject"] or "sub_" .. utils.random_str(4) + claims["iss"] = config["issuer"] or "iss_" .. utils.random_str(4) + claims["aud"] = config["audience"] or "aud_" .. utils.random_str(4) -- exp: now + expiryInMinutes * 60, // Current time + STS_TOKEN_EXPIRY_MINUTES minutes claims["exp"] = os.time() + 5 * 60 @@ -54,4 +57,35 @@ function jwt.sign_token(config, header, private_key, base_token) return signature end +function jwt.decode_token(contents) + local encoded_string = utils.utf8_from(contents) + logger("Encoded token = " .. encoded_string) + local t = {} + for str in string.gmatch(encoded_string, "([^\\.]+)") do + table.insert(t, str) + end + local header = utils.utf8_from(b64_decode_no_pad(t[1])) + logger("Token header = " .. inspect(header)) + local payload = utils.utf8_from(b64_decode_no_pad(t[2])) + logger("Token payload = " .. inspect(payload)) + local signature = t[3] + logger("Token signature = " .. signature) + + return { header = json.decode(header), payload = json.decode(payload), signature = signature, encoded = encoded_string }, nil +end + +function jwt.validate_signature(token, algorithm, key) + local parts = {} + for str in string.gmatch(token, "([^\\.]+)") do + table.insert(parts, str) + end + + if algorithm ~= "RS512" then + logger("Signature algorithm is set to " .. algorithm) + return false, "Only the RS512 alogirthim is supported at the moment" + end + + return rsa_validate(parts, algorithm, key) +end + return jwt diff --git a/plugins/jwt/matching.lua b/plugins/jwt/matching.lua new file mode 100644 index 00000000..06d6f0eb --- /dev/null +++ b/plugins/jwt/matching.lua @@ -0,0 +1,87 @@ +local jwt = require "jwt" +local inspect = require "inspect" + +local matching = {} + +function matching.validate_token(token, algorithm, key) + local result = {} + + local signature_valid = jwt.validate_signature(token.encoded, algorithm, key) + if not signature_valid then + table.insert(result, "Actual token signature is not valid") + end + + local expiration_time = token.payload["exp"] + if expiration_time < os.time() then + table.insert(result, "Actual token has expired") + end + + local not_before_time = token.payload["nbf"] + if not_before_time and not_before_time > os.time() then + table.insert(result, "Actual token is not to be used yet") + end + + return result +end + +function matching.match_headers(expected_header, actual_header) + logger("matching JWT headers") + logger("expected headers: " .. inspect(expected_header)) + logger("actual headers: " .. inspect(actual_header)) + return match_map(expected_header, actual_header, Set({"typ", "alg"}), + Set({"alg", "jku", "jwk", "kid", "x5u", "x5c", "x5t", "x5t#S256", "typ", "cty", "crit"}), Set({"jku"})) +end + +function matching.match_claims(expected_claims, actual_claims) + logger("matching JWT claims") + logger("expected claims: " .. inspect(expected_claims)) + logger("actual claims: " .. inspect(actual_claims)) + return match_map(expected_claims, actual_claims, Set({"iss", "sub", "aud", "exp"}), {}, Set({"exp", "nbf", "iat", "jti"})) +end + +function match_map(expected, actual, compulsary_keys, allowed_keys, keys_to_ignore) + local result = {} + + for k, v in pairs(expected) do + if not keys_to_ignore[k] then + if actual[k] ~= v then + result[k] = { + expected = v, + actual = actual[k], + mismatch = "Expected value " .. inspect(v) .. " but got value " .. inspect(actual[k]), + path = k + } + end + end + end + + local allowed_keys_empty = next(allowed_keys) == nil + for k, v in pairs(actual) do + if not allowed_keys_empty and not allowed_keys[k] then + result[k] = { + actual = v, + mismatch = k .. " is not allowed as a key", + path = k + } + end + end + + for k, v in pairs(compulsary_keys) do + if not actual[k] then + result[k] = { + mismatch = k .. " is a compulsary key, but was missing", + path = k + } + end + end + + return result +end + +function Set(list) + local set = {} + for _, l in ipairs(list) do set[l] = true end + return set +end + +return matching diff --git a/plugins/jwt/pact-plugin.json b/plugins/jwt/pact-plugin.json index ae57fc6a..b368127d 100644 --- a/plugins/jwt/pact-plugin.json +++ b/plugins/jwt/pact-plugin.json @@ -5,7 +5,5 @@ "version": "0.0.0", "executableType": "lua", "entryPoint": "plugin.lua", - "pluginConfig": { - } -} - + "pluginConfig": {} +} \ No newline at end of file diff --git a/plugins/jwt/plugin.lua b/plugins/jwt/plugin.lua index 05695119..4d5bb82e 100644 --- a/plugins/jwt/plugin.lua +++ b/plugins/jwt/plugin.lua @@ -2,6 +2,8 @@ local jwt = require "jwt" local json = require "json" +local inspect = require "inspect" +local matching = require "matching" -- Init function is called after the plugin script is loaded. It needs to return the plugin catalog -- entries to be added to the global catalog @@ -11,7 +13,7 @@ function init(implementation, version) -- Add some entropy to the random number generator math.randomseed(os.time()) - local params = { ["content-types"] = "application/jwt" } + local params = { ["content-types"] = "application/jwt;application/jwt+json" } local catalogue_entries = {} catalogue_entries[0] = { entryType="CONTENT_MATCHER", providerType="PLUGIN", key="jwt", values=params } catalogue_entries[1] = { entryType="CONTENT_GENERATOR", providerType="PLUGIN", key="jwt", values=params } @@ -56,7 +58,11 @@ function configure_interaction(content_type, config) } local contents = {} - --[[ pub part_name: String, + --[[ + /// Description of what part this interaction belongs to (in the case of there being more than + /// one, for instance, request/response messages) + pub part_name: String, + /// Body/Contents of the interaction pub body: OptionalBody, @@ -82,10 +88,76 @@ function configure_interaction(content_type, config) pub interaction_markup_type: String --]] contents[0] = { - body = signed_token, + body = { + contents = signed_token, + content_type = "application/jwt+json", + content_type_hint = "TEXT" + }, plugin_config = plugin_config } -- (Vec, Option) return { contents = contents, plugin_config = plugin_config } end + +-- This function does the actual matching +function match_contents(match_request) + --[[ + /// The expected contents from the Pact interaction + pub expected_contents: OptionalBody, + /// The actual contents that was received + pub actual_contents: OptionalBody, + /// Where there are keys or attributes in the data, indicates whether unexpected values are allowed + pub allow_unexpected_keys: bool, + /// Matching rules that apply + pub matching_rules: HashMap, + /// Plugin configuration form the Pact + pub plugin_configuration: Option + --]] + logger("Got a match request: " .. inspect(match_request)) + + local public_key = match_request.plugin_configuration.interaction_configuration["public-key"] + local algorithm = match_request.plugin_configuration.interaction_configuration["algorithm"] + + local expected_jwt, error = jwt.decode_token(match_request.expected_contents.contents) + if error then + return { error = error } + end + logger("Expected JWT: " .. inspect(expected_jwt)) + + local actual_jwt, actual_error = jwt.decode_token(match_request.actual_contents.contents) + if actual_error then + return { error = actual_error } + end + logger("Actual JWT: " .. inspect(actual_jwt)) + + --[[ + /// An error occurred trying to compare the contents + Error(String), + /// The content type was incorrect + TypeMismatch(String, String), + /// There were mismatched results + Mismatches(HashMap>), + /// All OK + OK + --]] + + local mismatches = {} + + local token_issues = matching.validate_token(actual_jwt, algorithm, public_key) + mismatches["$"] = token_issues + + local header_mismatches = matching.match_headers(expected_jwt.header, actual_jwt.header) + for k, v in pairs(header_mismatches) do + mismatches["header:" .. k] = v + end + + local claim_mismatches = matching.match_claims(expected_jwt.payload, actual_jwt.payload) + for k, v in pairs(claim_mismatches) do + mismatches["claims:" .. k] = v + end + + local result = { mismatches = mismatches } + logger("returning match result " .. inspect(result)) + return result +end diff --git a/plugins/jwt/utils.lua b/plugins/jwt/utils.lua new file mode 100644 index 00000000..6e7dade7 --- /dev/null +++ b/plugins/jwt/utils.lua @@ -0,0 +1,48 @@ +local utils = {} + +function utils.random_str(count) + local result = {} + + for _i, ch in random_str_iter(count) do + table.insert(result, ch) + end + + return table.concat(result) +end + +function random_str_iter(count) + return random_str_gen, count, 0 +end + +utils.S = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +function random_str_gen(count, index) + if index < count then + index = index + 1 + local i = math.random(string.len(utils.S)) + return index, string.sub(utils.S, i, i + 1) + end +end + +function utils.random_hex(count) + local result = {} + + local i = 0 + while i < count do + table.insert(result, string.format("%X", math.random(0, 16))) + i = i + 1 + end + + return table.concat(result) +end + +function utils.utf8_from(t) + local bytearr = {} + for _, v in ipairs(t) do + local utf8byte = v < 0 and (0xff + v + 1) or v + table.insert(bytearr, string.char(utf8byte)) + end + return table.concat(bytearr) +end + +return utils