Skip to content
Draft
117 changes: 117 additions & 0 deletions handlers/authorize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package handlers

import (
"errors"
"net/http"

"github.com/mtlynch/screenjournal/v2/screenjournal"
)

var errForbidden = errors.New("forbidden")

type dbService struct {
Store
request *http.Request
}

func (d dbService) isOwnerOrAdmin(owner screenjournal.Username) bool {
return mustGetUsernameFromContext(d.request.Context()).Equal(owner) ||
isAdmin(d.request.Context())
}

func (d dbService) readReview(id screenjournal.ReviewID) (screenjournal.Review, error) {
return d.ReadReview(id)
}

func (d dbService) readComment(id screenjournal.CommentID) (screenjournal.ReviewComment, error) {
return d.ReadComment(id)
}

func (d dbService) readReaction(id screenjournal.ReactionID) (screenjournal.ReviewReaction, error) {
return d.ReadReaction(id)
}

func (d dbService) updateReview(id screenjournal.ReviewID, updated reviewPutRequest) (screenjournal.Review, error) {
review, err := d.readReview(id)
if err != nil {
return screenjournal.Review{}, err
}
if !d.isOwnerOrAdmin(review.Owner) {
return screenjournal.Review{}, errForbidden
}

review.Rating = updated.Rating
review.Blurb = updated.Blurb
review.Watched = updated.Watched

if err := d.UpdateReview(review); err != nil {
return screenjournal.Review{}, err
}

return review, nil
}

func (d dbService) deleteReview(id screenjournal.ReviewID) error {
review, err := d.readReview(id)
if err != nil {
return err
}
if !d.isOwnerOrAdmin(review.Owner) {
return errForbidden
}

if err := d.DeleteReview(id); err != nil {
return err
}

return nil
}

func (d dbService) updateComment(id screenjournal.CommentID, commentText screenjournal.CommentText) (screenjournal.ReviewComment, error) {
rc, err := d.readComment(id)
if err != nil {
return screenjournal.ReviewComment{}, err
}
if !d.isOwnerOrAdmin(rc.Owner) {
return screenjournal.ReviewComment{}, errForbidden
}

rc.CommentText = commentText
if err := d.UpdateComment(rc); err != nil {
return screenjournal.ReviewComment{}, err
}

return rc, nil
}

func (d dbService) deleteComment(id screenjournal.CommentID) error {
rc, err := d.readComment(id)
if err != nil {
return err
}
if !d.isOwnerOrAdmin(rc.Owner) {
return errForbidden
}

if err := d.DeleteComment(id); err != nil {
return err
}

return nil
}

func (d dbService) deleteReaction(id screenjournal.ReactionID) error {
rr, err := d.readReaction(id)
if err != nil {
return err
}
if !d.isOwnerOrAdmin(rr.Owner) {
return errForbidden
}

if err := d.DeleteReaction(id); err != nil {
return err
}

return nil
}
31 changes: 7 additions & 24 deletions handlers/comments.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package handlers

import (
"errors"
"fmt"
"html/template"
"log"
Expand Down Expand Up @@ -197,23 +198,14 @@ func (s Server) commentsPut() http.HandlerFunc {
return
}

rc, err := s.getDB(r).ReadComment(req.CommentID)
rc, err := s.getDB(r).updateComment(req.CommentID, req.CommentText)
if err == store.ErrCommentNotFound {
http.Error(w, "Comment not found", http.StatusNotFound)
return
} else if err != nil {
log.Printf("failed to read comment: %v", err)
http.Error(w, fmt.Sprintf("Failed to read comment: %v", err), http.StatusInternalServerError)
return
}

if !mustGetUsernameFromContext(r.Context()).Equal(rc.Owner) {
} else if errors.Is(err, errForbidden) {
http.Error(w, "Can't edit another user's comment", http.StatusForbidden)
return
}

rc.CommentText = req.CommentText
if err := s.getDB(r).UpdateComment(rc); err != nil {
} else if err != nil {
log.Printf("failed to update comment: %v", err)
http.Error(w, fmt.Sprintf("Failed to update comment: %v", err), http.StatusInternalServerError)
return
Expand Down Expand Up @@ -241,22 +233,13 @@ func (s Server) commentsDelete() http.HandlerFunc {
return
}

rc, err := s.getDB(r).ReadComment(cid)
if err == store.ErrCommentNotFound {
if err := s.getDB(r).deleteComment(cid); err == store.ErrCommentNotFound {
http.Error(w, "Comment not found", http.StatusNotFound)
return
} else if err != nil {
log.Printf("failed to read comment: %v", err)
http.Error(w, fmt.Sprintf("Failed to read comment: %v", err), http.StatusInternalServerError)
return
}

if !mustGetUsernameFromContext(r.Context()).Equal(rc.Owner) {
} else if errors.Is(err, errForbidden) {
http.Error(w, "Can't delete another user's comment", http.StatusForbidden)
return
}

if err := s.getDB(r).DeleteComment(cid); err != nil {
} else if err != nil {
log.Printf("failed to delete comment id=%v: %v", cid, err)
http.Error(w, "Failed to delete comment: %v", http.StatusInternalServerError)
return
Expand Down
60 changes: 60 additions & 0 deletions handlers/comments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,40 @@ func TestCommentsPut(t *testing.T) {
},
status: http.StatusForbidden,
},
{
description: "allows an admin to update another user's comment",
route: "/api/comments/1",
payload: "comment=Admin%20updated%20this%20comment",
sessionToken: "adm123",
sessions: []mockSessionEntry{
makeCommentsTestData().sessions.userA,
makeCommentsTestData().sessions.userB,
{
token: "adm123",
session: sessions.Session{
Username: screenjournal.Username("admin"),
IsAdmin: true,
},
},
},
comments: []screenjournal.ReviewComment{
{
ID: screenjournal.CommentID(1),
Owner: makeCommentsTestData().sessions.userA.session.Username,
CommentText: screenjournal.CommentText("Good insights!"),
Review: makeCommentsTestData().reviews.userBTheWaterBoy,
},
},
status: http.StatusOK,
expectedComments: []screenjournal.ReviewComment{
{
ID: screenjournal.CommentID(1),
Owner: makeCommentsTestData().sessions.userA.session.Username,
CommentText: screenjournal.CommentText("Admin updated this comment"),
Review: makeCommentsTestData().reviews.userBTheWaterBoy,
},
},
},
{
description: "prevents an unauthenticated user from updating any comment",
route: "/api/comments/1",
Expand Down Expand Up @@ -573,6 +607,32 @@ func TestCommentsDelete(t *testing.T) {
},
status: http.StatusForbidden,
},
{
description: "allows an admin to delete another user's comment",
route: "/api/comments/1",
sessionToken: "adm123",
sessions: []mockSessionEntry{
makeCommentsTestData().sessions.userA,
makeCommentsTestData().sessions.userB,
{
token: "adm123",
session: sessions.Session{
Username: screenjournal.Username("admin"),
IsAdmin: true,
},
},
},
comments: []screenjournal.ReviewComment{
{
ID: screenjournal.CommentID(1),
Owner: makeCommentsTestData().sessions.userA.session.Username,
CommentText: screenjournal.CommentText("Good insights!"),
Review: makeCommentsTestData().reviews.userBTheWaterBoy,
},
},
status: http.StatusNoContent,
expectedComments: []screenjournal.ReviewComment{},
},
{
description: "prevents an unauthenticated user from deleting any comment",
route: "/api/comments/1",
Expand Down
11 changes: 11 additions & 0 deletions handlers/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package handlers

import "net/http"

func (s Server) getDB(r *http.Request) dbService {
return s.db.dbForRequest(r)
}

func (s Server) getAuthenticator(r *http.Request) Authenticator {
return s.db.authenticatorForRequest(r, s.authenticator)
}
27 changes: 22 additions & 5 deletions handlers/db_dev.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,16 @@ var sharedDBSettings = dbSettings{
tokenToDB: map[dbToken]Store{},
}

type sessionDBProvider struct {
store Store
}

func newDBProvider(store Store) dbProvider {
return sessionDBProvider{
store: store,
}
}

func (dbs *dbSettings) IsSessionIsolationEnabled() bool {
dbs.lock.RLock()
dbs.lock.RUnlock()
Expand All @@ -220,9 +230,9 @@ func (dbs *dbSettings) SaveDB(token dbToken, db Store) {
dbs.tokenToDB[token] = db
}

func (s Server) getDB(r *http.Request) Store {
func (p sessionDBProvider) rawDB(r *http.Request) Store {
if !sharedDBSettings.IsSessionIsolationEnabled() {
return s.store
return p.store
}
c, err := r.Cookie(dbTokenCookieName)
if err != nil {
Expand All @@ -231,11 +241,18 @@ func (s Server) getDB(r *http.Request) Store {
return sharedDBSettings.GetDB(dbToken(c.Value))
}

func (s Server) getAuthenticator(r *http.Request) Authenticator {
func (p sessionDBProvider) dbForRequest(r *http.Request) dbService {
return dbService{
Store: p.rawDB(r),
request: r,
}
}

func (p sessionDBProvider) authenticatorForRequest(r *http.Request, fallback Authenticator) Authenticator {
if !sharedDBSettings.IsSessionIsolationEnabled() {
return s.authenticator
return fallback
}
return auth.New(s.getDB(r))
return auth.New(p.rawDB(r))
}

func dbPerSessionPost() http.HandlerFunc {
Expand Down
21 changes: 17 additions & 4 deletions handlers/db_prod.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,23 @@ func (s *Server) addDevRoutes() {
// no-op
}

func (s Server) getDB(*http.Request) Store {
return s.store
type staticDBProvider struct {
store Store
}

func (s Server) getAuthenticator(_ *http.Request) Authenticator {
return s.authenticator
func newDBProvider(store Store) dbProvider {
return staticDBProvider{
store: store,
}
}

func (p staticDBProvider) dbForRequest(r *http.Request) dbService {
return dbService{
Store: p.store,
request: r,
}
}

func (p staticDBProvider) authenticatorForRequest(_ *http.Request, fallback Authenticator) Authenticator {
return fallback
}
17 changes: 4 additions & 13 deletions handlers/reactions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package handlers

import (
"errors"
"fmt"
"html/template"
"log"
Expand Down Expand Up @@ -119,23 +120,13 @@ func (s Server) reactionsDelete() http.HandlerFunc {
return
}

rr, err := s.getDB(r).ReadReaction(rid)
if err == store.ErrReactionNotFound {
if err := s.getDB(r).deleteReaction(rid); err == store.ErrReactionNotFound {
http.Error(w, "Reaction not found", http.StatusNotFound)
return
} else if err != nil {
log.Printf("failed to read reaction: %v", err)
http.Error(w, fmt.Sprintf("Failed to read reaction: %v", err), http.StatusInternalServerError)
return
}

loggedInUsername := mustGetUsernameFromContext(r.Context())
if !loggedInUsername.Equal(rr.Owner) && !isAdmin(r.Context()) {
} else if errors.Is(err, errForbidden) {
http.Error(w, "Can't delete another user's reaction", http.StatusForbidden)
return
}

if err := s.getDB(r).DeleteReaction(rid); err != nil {
} else if err != nil {
log.Printf("failed to delete reaction id=%v: %v", rid, err)
http.Error(w, "Failed to delete reaction", http.StatusInternalServerError)
return
Expand Down
Loading