diff --git a/README.md b/README.md index 71035ef7..74003285 100644 --- a/README.md +++ b/README.md @@ -123,8 +123,6 @@ tootik has three kinds of posts: User A is allowed to send a message to user B only if B follows A. -**However, tootik does not implement access control.** Messages and posts are "unlisted" (in [Mastodon](https://github.com/mastodon/mastodon) jargon) and users who haven't received them can "discover" them. Every post is associated with an ID and /view will display a post given a hash of this ID: anyone who knows the hash associated with a message can see the message, even if the message was sent to another user and even if unauthenticated. - ### Post Visibility | Post type | To | CC | @@ -135,11 +133,11 @@ User A is allowed to send a message to user B only if B follows A. ### Reply Visibility -| Post type | To | CC | -|-------------|-------------|------------------------------------------------| -| Message | Post author | - | -| Post | Post author | Mentions and followers of reply author | -| Public post | Post author | Mentions, followers of reply author and Public | +| Post type | To | CC | +|-------------|-------------|-----------------------------------------------------------------| +| Message | Post author | - | +| Post | Post author | Post recipients, mentions and followers of reply author | +| Public post | Post author | Post recipients, mentions, followers of reply author and Public | ### Post Editing diff --git a/front/reply.go b/front/reply.go index 53529f6a..9a12c535 100644 --- a/front/reply.go +++ b/front/reply.go @@ -63,6 +63,14 @@ func reply(w text.Writer, r *request) { } else if !note.IsPublic() { to.Add(note.AttributedTo) cc.Add(r.User.Followers) + note.To.Range(func(id string, _ struct{}) bool { + cc.Add(id) + return true + }) + note.CC.Range(func(id string, _ struct{}) bool { + cc.Add(id) + return true + }) } else { r.Log.Error("Post audience is invalid", "post", note.ID) w.Error() diff --git a/front/view.go b/front/view.go index 322323f2..c1c389bc 100644 --- a/front/view.go +++ b/front/view.go @@ -43,7 +43,13 @@ func view(w text.Writer, r *request) { var noteString, authorString string var groupString sql.NullString - if err := r.QueryRow(`select notes.object, persons.actor, groups.actor from notes join persons on persons.id = notes.author left join (select id, actor from persons where actor->>'type' = 'Group') groups on groups.id = notes.groupid where notes.hash = ?`, hash).Scan(¬eString, &authorString, &groupString); err != nil && errors.Is(err, sql.ErrNoRows) { + + if r.User == nil { + err = r.QueryRow(`select notes.object, persons.actor, groups.actor from notes join persons on persons.id = notes.author left join (select id, actor from persons where actor->>'type' = 'Group') groups on groups.id = notes.groupid where notes.hash = ? and notes.public = 1`, hash).Scan(¬eString, &authorString, &groupString) + } else { + err = r.QueryRow(`select notes.object, persons.actor, groups.actor from notes join persons on persons.id = notes.author left join (select id, actor from persons where actor->>'type' = 'Group') groups on groups.id = notes.groupid where notes.hash = $1 and (notes.public = 1 or notes.author = $2 or $2 in (notes.cc0, notes.to0, notes.cc1, notes.to1, notes.cc2, notes.to2) or (notes.to2 is not null and exists (select 1 from json_each(notes.object->'to') where value = $2)) or (notes.cc2 is not null and exists (select 1 from json_each(notes.object->'cc') where value = $2)) or exists (select 1 from (select persons.actor->>'followers' as followers from persons join follows on follows.followed = persons.id where follows.accepted = 1 and follows.follower = $2) follows where follows.followers in (notes.cc0, notes.to0, notes.cc1, notes.to1, notes.cc2, notes.to2) or (notes.to2 is not null and exists (select 1 from json_each(notes.object->'to') where value = follows.followers)) or (notes.cc2 is not null and exists (select 1 from json_each(notes.object->'cc') where value = follows.followers))))`, hash, r.User.ID).Scan(¬eString, &authorString, &groupString) + } + if err != nil && errors.Is(err, sql.ErrNoRows) { r.Log.Info("Post was not found", "hash", hash) w.Status(40, "Post not found") return @@ -92,12 +98,9 @@ func view(w text.Writer, r *request) { } else { rows, err = r.Query( `select replies.object, persons.actor from - ( - select replies.object, replies.author, replies.inserted from notes join notes replies on replies.object->>'inReplyTo' = notes.id where notes.hash = $1 and (replies.public = 1 or replies.author = $2 or ($2 in (replies.cc0, replies.to0, replies.cc1, replies.to1, replies.cc2, replies.to2) or (replies.to2 is not null and exists (select 1 from json_each(replies.object->'to') where value = $2)) or (replies.cc2 is not null and exists (select 1 from json_each(replies.object->'cc') where value = $2)))) - union - select replies.object, replies.author, replies.inserted from notes join notes replies on replies.object->>'inReplyTo' = notes.id join (select persons.actor->>'followers' as followers from persons join follows on follows.followed = persons.id where follows.accepted = 1 and follows.followed = $2) follows on follows.followers in (replies.cc0, replies.to0, replies.cc1, replies.to1, replies.cc2, replies.to2) or (notes.to2 is not null and exists (select 1 from json_each(replies.object->'to') where value = follows.followers)) or (notes.cc2 is not null and exists (select 1 from json_each(replies.object->'cc') where value = follows.followers)) where notes.hash = $1 - ) replies + notes join notes replies on replies.object->>'inReplyTo' = notes.id left join persons on persons.id = replies.author + where notes.hash = $1 and (replies.public = 1 or replies.author = $2 or $2 in (replies.cc0, replies.to0, replies.cc1, replies.to1, replies.cc2, replies.to2) or (replies.to2 is not null and exists (select 1 from json_each(replies.object->'to') where value = $2)) or (replies.cc2 is not null and exists (select 1 from json_each(replies.object->'cc') where value = $2)) or exists (select 1 from persons join follows on follows.followed = persons.id where follows.accepted = 1 and follows.follower = $2 and persons.actor->>'followers' in (replies.cc0, replies.to0, replies.cc1, replies.to1, replies.cc2, replies.to2) or (notes.to2 is not null and exists (select 1 from json_each(replies.object->'to') where value = persons.actor->>'followers')) or (notes.cc2 is not null and exists (select 1 from json_each(replies.object->'cc') where value = persons.actor->>'followers')))) order by replies.inserted desc limit $3 offset $4`, hash, r.User.ID, diff --git a/test/hashtag_test.go b/test/hashtag_test.go index a13eb6e7..a48a3ab0 100644 --- a/test/hashtag_test.go +++ b/test/hashtag_test.go @@ -17,6 +17,8 @@ limitations under the License. package test import ( + "crypto/sha256" + "fmt" "github.com/stretchr/testify/assert" "testing" ) @@ -126,6 +128,9 @@ func TestHashtag_PostToFollowers(t *testing.T) { assert := assert.New(t) + follow := server.Handle(fmt.Sprintf("/users/follow/%x", sha256.Sum256([]byte(server.Alice.ID))), server.Bob) + assert.Equal(fmt.Sprintf("30 /users/outbox/%x\r\n", sha256.Sum256([]byte(server.Alice.ID))), follow) + whisper := server.Handle("/users/whisper?Hello%20%23world", server.Alice) assert.Regexp("^30 /users/view/[0-9a-f]{64}\r\n$", whisper) diff --git a/test/poll_test.go b/test/poll_test.go index 4c6c71e9..134146f4 100644 --- a/test/poll_test.go +++ b/test/poll_test.go @@ -743,6 +743,12 @@ func TestPoll_LocalVoteVisibilityFollowers(t *testing.T) { assert := assert.New(t) + follow := server.Handle(fmt.Sprintf("/users/follow/%x", sha256.Sum256([]byte(server.Alice.ID))), server.Bob) + assert.Equal(fmt.Sprintf("30 /users/outbox/%x\r\n", sha256.Sum256([]byte(server.Alice.ID))), follow) + + follow = server.Handle(fmt.Sprintf("/users/follow/%x", sha256.Sum256([]byte(server.Alice.ID))), server.Carol) + assert.Equal(fmt.Sprintf("30 /users/outbox/%x\r\n", sha256.Sum256([]byte(server.Alice.ID))), follow) + whisper := server.Handle("/users/whisper?%5bPOLL%20So%2c%20polls%20on%20Station%20are%20pretty%20cool%2c%20right%3f%5d%20Nope%20%7c%20Hell%20yeah%21%20%7c%20I%20couldn%27t%20care%20less", server.Alice) assert.Regexp("^30 /users/view/[0-9a-f]{64}\r\n$", whisper) @@ -785,12 +791,7 @@ func TestPoll_LocalVoteVisibilityFollowers(t *testing.T) { assert.Contains(view, "carol") view = server.Handle("/view/"+whisper[15:len(whisper)-2], nil) - assert.Contains(view, "So, polls on Station are pretty cool, right?") - assert.NotContains(view, "Vote") - assert.Contains(strings.Split(view, "\n"), "1 ████████ Hell yeah!") - assert.Contains(strings.Split(view, "\n"), "1 ████████ I couldn't care less") - assert.NotContains(view, "bob") - assert.NotContains(view, "carol") + assert.Equal("40 Post not found\r\n", view) } func TestPoll_LocalVoteVisibilityPublic(t *testing.T) { diff --git a/test/reply_test.go b/test/reply_test.go index 05b3427e..e204c232 100644 --- a/test/reply_test.go +++ b/test/reply_test.go @@ -207,8 +207,7 @@ func TestReply_ReplyToPublicPostByNotFollowedUser(t *testing.T) { assert.Regexp("30 /users/view/[0-9a-f]{64}", reply) view = server.Handle("/users/view/"+hash, server.Alice) - assert.Contains(view, "Hello world") - assert.Contains(view, "Welcome Bob") + assert.Equal("40 Post not found\r\n", view) users := server.Handle("/users/inbox/today", server.Alice) assert.NotContains(users, "Hello world") diff --git a/test/whisper_test.go b/test/whisper_test.go index ce92c8c9..218fecd5 100644 --- a/test/whisper_test.go +++ b/test/whisper_test.go @@ -29,10 +29,38 @@ func TestWhisper_HappyFlow(t *testing.T) { assert := assert.New(t) + follow := server.Handle(fmt.Sprintf("/users/follow/%x", sha256.Sum256([]byte(server.Alice.ID))), server.Bob) + assert.Equal(fmt.Sprintf("30 /users/outbox/%x\r\n", sha256.Sum256([]byte(server.Alice.ID))), follow) + + whisper := server.Handle("/users/whisper?Hello%20world", server.Alice) + assert.Regexp("^30 /users/view/[0-9a-f]{64}\r\n$", whisper) + + view := server.Handle(whisper[3:len(whisper)-2], server.Bob) + assert.Contains(view, "Hello world") + + outbox := server.Handle(fmt.Sprintf("/users/outbox/%x", sha256.Sum256([]byte(server.Alice.ID))), server.Bob) + assert.Contains(outbox, "Hello world") + + local := server.Handle("/local", server.Carol) + assert.NotContains(local, "Hello world") +} + +func TestWhisper_FollowAfterPost(t *testing.T) { + server := newTestServer() + defer server.Shutdown() + + assert := assert.New(t) + whisper := server.Handle("/users/whisper?Hello%20world", server.Alice) assert.Regexp("^30 /users/view/[0-9a-f]{64}\r\n$", whisper) view := server.Handle(whisper[3:len(whisper)-2], server.Bob) + assert.Equal("40 Post not found\r\n", view) + + follow := server.Handle(fmt.Sprintf("/users/follow/%x", sha256.Sum256([]byte(server.Alice.ID))), server.Bob) + assert.Equal(fmt.Sprintf("30 /users/outbox/%x\r\n", sha256.Sum256([]byte(server.Alice.ID))), follow) + + view = server.Handle(whisper[3:len(whisper)-2], server.Bob) assert.Contains(view, "Hello world") outbox := server.Handle(fmt.Sprintf("/users/outbox/%x", sha256.Sum256([]byte(server.Alice.ID))), server.Bob) @@ -48,6 +76,9 @@ func TestWhisper_Throttling(t *testing.T) { assert := assert.New(t) + follow := server.Handle(fmt.Sprintf("/users/follow/%x", sha256.Sum256([]byte(server.Alice.ID))), server.Bob) + assert.Equal(fmt.Sprintf("30 /users/outbox/%x\r\n", sha256.Sum256([]byte(server.Alice.ID))), follow) + whisper := server.Handle("/users/whisper?Hello%20world", server.Alice) assert.Regexp("^30 /users/view/[0-9a-f]{64}\r\n$", whisper)