Skip to content

Commit

Permalink
add gateway support
Browse files Browse the repository at this point in the history
  • Loading branch information
gabe committed Apr 2, 2024
1 parent 937a919 commit 6dcea61
Show file tree
Hide file tree
Showing 15 changed files with 147 additions and 77 deletions.
2 changes: 1 addition & 1 deletion impl/concurrencytest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func generateDIDPutRequest() (string, []byte, error) {
return "", nil, err
}

packet, err := did.DHT(doc.ID).ToDNSPacket(*doc, nil)
packet, err := did.DHT(doc.ID).ToDNSPacket(*doc, nil, nil)
if err != nil {
return "", nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions impl/integrationtest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ import (
"os/signal"
"time"

"github.com/sirupsen/logrus"

"github.com/TBD54566975/did-dht-method/internal/did"
"github.com/TBD54566975/did-dht-method/pkg/dht"
"github.com/sirupsen/logrus"
)

var (
Expand Down Expand Up @@ -97,7 +98,7 @@ func generateDIDPutRequest() (string, []byte, error) {
return "", nil, err
}

packet, err := did.DHT(doc.ID).ToDNSPacket(*doc, nil)
packet, err := did.DHT(doc.ID).ToDNSPacket(*doc, nil, nil)
if err != nil {
return "", nil, err
}
Expand Down
16 changes: 8 additions & 8 deletions impl/internal/did/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,31 @@ func NewGatewayClient(gatewayURL string) (*GatewayClient, error) {
}, nil
}

// GetDIDDocument gets a DID document and its types from a did:dht Gateway
func (c *GatewayClient) GetDIDDocument(id string) (*did.Document, []TypeIndex, error) {
// GetDIDDocument gets a DID document, its types, and authoritative gateways, from a did:dht Gateway
func (c *GatewayClient) GetDIDDocument(id string) (*did.Document, []TypeIndex, []AuthoritativeGateway, error) {
d := DHT(id)
if !d.IsValid() {
return nil, nil, errors.New("invalid did")
return nil, nil, nil, errors.New("invalid did")
}
suffix, err := d.Suffix()
if err != nil {
return nil, nil, errors.Wrap(err, "failed to get suffix")
return nil, nil, nil, errors.Wrap(err, "failed to get suffix")
}
resp, err := http.Get(c.gatewayURL + "/" + suffix)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to get did document")
return nil, nil, nil, errors.Wrap(err, "failed to get did document")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, nil, errors.Errorf("failed to get did document, status code: %d", resp.StatusCode)
return nil, nil, nil, errors.Errorf("failed to get did document, status code: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to read response body")
return nil, nil, nil, errors.Wrap(err, "failed to read response body")
}
msg := new(dns.Msg)
if err = msg.Unpack(body[72:]); err != nil {
return nil, nil, errors.Wrap(err, "failed to unpack records")
return nil, nil, nil, errors.Wrap(err, "failed to unpack records")
}
return d.FromDNSPacket(msg)
}
Expand Down
36 changes: 20 additions & 16 deletions impl/internal/did/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestClient(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, doc)

packet, err := DHT(doc.ID).ToDNSPacket(*doc, nil)
packet, err := DHT(doc.ID).ToDNSPacket(*doc, nil, nil)
assert.NoError(t, err)
assert.NotEmpty(t, packet)

Expand All @@ -34,7 +34,7 @@ func TestClient(t *testing.T) {
err = client.PutDocument(doc.ID, *bep44Put)
assert.NoError(t, err)

gotDID, _, err := client.GetDIDDocument(doc.ID)
gotDID, _, _, err := client.GetDIDDocument(doc.ID)
assert.NoError(t, err)
assert.EqualValues(t, doc, gotDID)

Expand All @@ -51,31 +51,35 @@ func TestClientInvalidGateway(t *testing.T) {
func TestInvalidDIDDocument(t *testing.T) {
client, err := NewGatewayClient("https://diddht.tbddev.test")
require.NoError(t, err)
require.NotNil(t, client)
require.NotEmpty(t, client)

did, ty, err := client.GetDIDDocument("this is not a valid did")
did, types, gateways, err := client.GetDIDDocument("this is not a valid did")
assert.Error(t, err)
assert.Nil(t, ty)
assert.Nil(t, did)
assert.Empty(t, did)
assert.Empty(t, types)
assert.Empty(t, gateways)

did, ty, err = client.GetDIDDocument("did:dht:example")
did, types, gateways, err = client.GetDIDDocument("did:dht:example")
assert.EqualError(t, err, "invalid did")
assert.Nil(t, ty)
assert.Nil(t, did)
assert.Empty(t, did)
assert.Empty(t, types)
assert.Empty(t, gateways)

did, ty, err = client.GetDIDDocument("did:dht:i9xkp8ddcbcg8jwq54ox699wuzxyifsqx4jru45zodqu453ksz6y")
did, types, gateways, err = client.GetDIDDocument("did:dht:i9xkp8ddcbcg8jwq54ox699wuzxyifsqx4jru45zodqu453ksz6y")
assert.Error(t, err) // this should error because the gateway URL is invalid
assert.Nil(t, ty)
assert.Nil(t, did)
assert.Empty(t, did)
assert.Empty(t, types)
assert.Empty(t, gateways)

client, err = NewGatewayClient("https://tbd.website")
require.NoError(t, err)
require.NotNil(t, client)
require.NotEmpty(t, client)

did, ty, err = client.GetDIDDocument("did:dht:i9xkp8ddcbcg8jwq54ox699wuzxyifsqx4jru45zodqu453ksz6y")
did, types, gateways, err = client.GetDIDDocument("did:dht:i9xkp8ddcbcg8jwq54ox699wuzxyifsqx4jru45zodqu453ksz6y")
assert.Error(t, err) // this should error because the gateway URL will return a non-200
assert.Nil(t, ty)
assert.Nil(t, did)
assert.Empty(t, did)
assert.Empty(t, types)
assert.Empty(t, gateways)

err = client.PutDocument("did:dht:example", bep44.Put{})
assert.Error(t, err)
Expand Down
69 changes: 52 additions & 17 deletions impl/internal/did/did.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import (
"github.com/TBD54566975/ssi-sdk/did"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/miekg/dns"
"github.com/pkg/errors"
"github.com/tv42/zbase32"
)

type (
DHT string
TypeIndex int
DHT string
TypeIndex int
AuthoritativeGateway string
)

const (
Expand Down Expand Up @@ -237,11 +239,16 @@ func GetDIDDHTIdentifier(pubKey []byte) string {
}

// ToDNSPacket converts a DID DHT Document to a DNS packet with an optional list of types to include
func (d DHT) ToDNSPacket(doc did.Document, types []TypeIndex) (*dns.Msg, error) {
func (d DHT) ToDNSPacket(doc did.Document, types []TypeIndex, gateways []AuthoritativeGateway) (*dns.Msg, error) {
var records []dns.RR
var rootRecord []string
keyLookup := make(map[string]string)

suffix, err := d.Suffix()
if err != nil {
return nil, errors.Wrap(err, "failed to get suffix while decoding DNS packet")
}

// first append the version to the root record
rootRecord = append(rootRecord, fmt.Sprintf("v=%d", Version))

Expand Down Expand Up @@ -285,6 +292,20 @@ func (d DHT) ToDNSPacket(doc did.Document, types []TypeIndex) (*dns.Msg, error)
records = append(records, &akaAnswer)
}

// add all gateways
for _, gateway := range gateways {
gatewayAnswer := dns.TXT{
Hdr: dns.RR_Header{
Name: fmt.Sprintf("_did.%s.", suffix),
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 7200,
},
Txt: []string{string(gateway)},
}
records = append(records, &gatewayAnswer)
}

// build all key records
var vmIDs []string
for i, vm := range doc.VerificationMethod {
Expand Down Expand Up @@ -420,7 +441,7 @@ func (d DHT) ToDNSPacket(doc did.Document, types []TypeIndex) (*dns.Msg, error)
// add the root record
rootAnswer := dns.TXT{
Hdr: dns.RR_Header{
Name: "_did.",
Name: fmt.Sprintf("_did.%s.", suffix),
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 7200,
Expand Down Expand Up @@ -483,11 +504,20 @@ func parseServiceData(serviceEndpoint any) string {
}

// FromDNSPacket converts a DNS packet to a DID DHT Document
func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, error) {
// Returns the DID Document, a list of types, a list of authoritative gateways, and an error
func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, []AuthoritativeGateway, error) {
doc := did.Document{
ID: d.String(),
}

suffix, err := d.Suffix()
if err != nil {
return nil, nil, nil, errors.Wrap(err, "failed to get suffix while decoding DNS packet")
}

// track the authoritative gateways
var gateways []AuthoritativeGateway
// track the types
var types []TypeIndex
keyLookup := make(map[string]string)
for _, rr := range msg.Answer {
Expand Down Expand Up @@ -520,23 +550,23 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, error) {
// Convert keyBase64URL back to PublicKeyJWK
pubKeyBytes, err := base64.RawURLEncoding.DecodeString(keyBase64URL)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
// as per the spec's guidance DNS representations use compressed keys, so we must unmarshall them as such
pubKey, err := crypto.BytesToPubKey(pubKeyBytes, keyType, crypto.ECDSAUnmarshalCompressed)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
pubKeyJWK, err := jwx.PublicKeyToPublicKeyJWK(&vmID, pubKey)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

// set the algorithm if it's not the default for the key type
if alg == "" {
defaultAlg := defaultAlgForJWK(*pubKeyJWK)
if defaultAlg == "" {
return nil, nil, fmt.Errorf("unable to provide default alg for unsupported key type: %s", keyType)
return nil, nil, nil, fmt.Errorf("unable to provide default alg for unsupported key type: %s", keyType)
}
pubKeyJWK.ALG = defaultAlg
} else {
Expand All @@ -545,7 +575,7 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, error) {

// make sure the controller of the identity key matches the DID
if vmID == "0" && controller != d.String() {
return nil, nil, fmt.Errorf("controller of identity key must be the DID itself, instead it is: %s", controller)
return nil, nil, nil, fmt.Errorf("controller of identity key must be the DID itself, instead it is: %s", controller)
}

// if the verification method ID is not set, set it to the thumbprint
Expand All @@ -554,7 +584,7 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, error) {
}

if vmID != "0" && pubKeyJWK.KID != vmID {
return nil, nil, fmt.Errorf("verification method JWK KID must be set to its thumbprint")
return nil, nil, nil, fmt.Errorf("verification method JWK KID must be set to its thumbprint")
}

vm := did.VerificationMethod{
Expand Down Expand Up @@ -601,17 +631,22 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, error) {

} else if record.Hdr.Name == "_typ._did." {
if record.Txt[0] == "" || len(record.Txt) != 1 {
return nil, nil, fmt.Errorf("invalid types record")
return nil, nil, nil, fmt.Errorf("invalid types record")
}
typesStr := strings.Split(strings.TrimPrefix(record.Txt[0], "id="), ",")
for _, t := range typesStr {
tInt, err := strconv.Atoi(t)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
types = append(types, TypeIndex(tInt))
}
} else if record.Hdr.Name == "_did." {
} else if record.Hdr.Name == fmt.Sprintf("_did.%s.", suffix) && record.Hdr.Rrtype == dns.TypeNS {
if record.Txt[0] == "" || len(record.Txt) != 1 {
return nil, nil, nil, fmt.Errorf("invalid gateways record: %s", record.String())
}
gateways = append(gateways, AuthoritativeGateway(record.Txt[0]))
} else if record.Hdr.Name == fmt.Sprintf("_did.%s.", suffix) && record.Hdr.Rrtype == dns.TypeTXT {
rootData := strings.Join(record.Txt, ";")
rootItems := strings.Split(rootData, ";")

Expand All @@ -628,7 +663,7 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, error) {
switch key {
case "v":
if len(valueItems) != 1 || valueItems[0] != strconv.Itoa(Version) {
return nil, nil, fmt.Errorf("invalid version: %s", values)
return nil, nil, nil, fmt.Errorf("invalid version: %s", values)
}
seenVersion = true
case "auth":
Expand All @@ -654,13 +689,13 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, error) {
}
}
if !seenVersion {
return nil, nil, fmt.Errorf("root record missing version identifier")
return nil, nil, nil, fmt.Errorf("root record missing version identifier")
}
}
}
}

return &doc, types, nil
return &doc, types, gateways, nil
}

func parseTxtData(data string) map[string]string {
Expand Down
Loading

0 comments on commit 6dcea61

Please sign in to comment.