diff --git a/admin/http.go b/admin/http.go index 4d32aa9..b6a71ee 100644 --- a/admin/http.go +++ b/admin/http.go @@ -20,20 +20,20 @@ func Handler(app *model.AppState) http.Handler { mux := http.NewServeMux() mux.Handle("/login", loginHandler(app)) - mux.Handle("/logout", RequireAccount(app, logoutHandler(app))) + mux.Handle("/logout", requireAccount(app, logoutHandler(app))) mux.Handle("/register", registerAccountHandler(app)) - mux.Handle("/account", RequireAccount(app, accountIndexHandler(app))) - mux.Handle("/account/", RequireAccount(app, http.StripPrefix("/account", accountHandler(app)))) + mux.Handle("/account", requireAccount(app, accountIndexHandler(app))) + mux.Handle("/account/", requireAccount(app, http.StripPrefix("/account", accountHandler(app)))) - mux.Handle("/release/", RequireAccount(app, http.StripPrefix("/release", serveRelease(app)))) - mux.Handle("/artist/", RequireAccount(app, http.StripPrefix("/artist", serveArtist(app)))) - mux.Handle("/track/", RequireAccount(app, http.StripPrefix("/track", serveTrack(app)))) + mux.Handle("/release/", requireAccount(app, http.StripPrefix("/release", serveRelease(app)))) + mux.Handle("/artist/", requireAccount(app, http.StripPrefix("/artist", serveArtist(app)))) + mux.Handle("/track/", requireAccount(app, http.StripPrefix("/track", serveTrack(app)))) mux.Handle("/static/", http.StripPrefix("/static", staticHandler())) - mux.Handle("/", RequireAccount(app, AdminIndexHandler(app))) + mux.Handle("/", requireAccount(app, AdminIndexHandler(app))) // response wrapper to make sure a session cookie exists return enforceSession(app, mux) @@ -381,7 +381,7 @@ func logoutHandler(app *model.AppState) http.Handler { }) } -func RequireAccount(app *model.AppState, next http.Handler) http.HandlerFunc { +func requireAccount(app *model.AppState, next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session := r.Context().Value("session").(*model.Session) if session.Account == nil { diff --git a/api/api.go b/api/api.go index 9489126..50b1c63 100644 --- a/api/api.go +++ b/api/api.go @@ -1,11 +1,13 @@ package api import ( + "context" + "errors" "fmt" "net/http" + "os" "strings" - "arimelody-web/admin" "arimelody-web/controller" "arimelody-web/model" ) @@ -36,10 +38,10 @@ func Handler(app *model.AppState) http.Handler { ServeArtist(app, artist).ServeHTTP(w, r) case http.MethodPut: // PUT /api/v1/artist/{id} (admin) - admin.RequireAccount(app, UpdateArtist(app, artist)).ServeHTTP(w, r) + requireAccount(app, UpdateArtist(app, artist)).ServeHTTP(w, r) case http.MethodDelete: // DELETE /api/v1/artist/{id} (admin) - admin.RequireAccount(app, DeleteArtist(app, artist)).ServeHTTP(w, r) + requireAccount(app, DeleteArtist(app, artist)).ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -51,7 +53,7 @@ func Handler(app *model.AppState) http.Handler { ServeAllArtists(app).ServeHTTP(w, r) case http.MethodPost: // POST /api/v1/artist (admin) - admin.RequireAccount(app, CreateArtist(app)).ServeHTTP(w, r) + requireAccount(app, CreateArtist(app)).ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -78,10 +80,10 @@ func Handler(app *model.AppState) http.Handler { ServeRelease(app, release).ServeHTTP(w, r) case http.MethodPut: // PUT /api/v1/music/{id} (admin) - admin.RequireAccount(app, UpdateRelease(app, release)).ServeHTTP(w, r) + requireAccount(app, UpdateRelease(app, release)).ServeHTTP(w, r) case http.MethodDelete: // DELETE /api/v1/music/{id} (admin) - admin.RequireAccount(app, DeleteRelease(app, release)).ServeHTTP(w, r) + requireAccount(app, DeleteRelease(app, release)).ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -93,7 +95,7 @@ func Handler(app *model.AppState) http.Handler { ServeCatalog(app).ServeHTTP(w, r) case http.MethodPost: // POST /api/v1/music (admin) - admin.RequireAccount(app, CreateRelease(app)).ServeHTTP(w, r) + requireAccount(app, CreateRelease(app)).ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -117,13 +119,13 @@ func Handler(app *model.AppState) http.Handler { switch r.Method { case http.MethodGet: // GET /api/v1/track/{id} (admin) - admin.RequireAccount(app, ServeTrack(app, track)).ServeHTTP(w, r) + requireAccount(app, ServeTrack(app, track)).ServeHTTP(w, r) case http.MethodPut: // PUT /api/v1/track/{id} (admin) - admin.RequireAccount(app, UpdateTrack(app, track)).ServeHTTP(w, r) + requireAccount(app, UpdateTrack(app, track)).ServeHTTP(w, r) case http.MethodDelete: // DELETE /api/v1/track/{id} (admin) - admin.RequireAccount(app, DeleteTrack(app, track)).ServeHTTP(w, r) + requireAccount(app, DeleteTrack(app, track)).ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -132,10 +134,10 @@ func Handler(app *model.AppState) http.Handler { switch r.Method { case http.MethodGet: // GET /api/v1/track (admin) - admin.RequireAccount(app, ServeAllTracks(app)).ServeHTTP(w, r) + requireAccount(app, ServeAllTracks(app)).ServeHTTP(w, r) case http.MethodPost: // POST /api/v1/track (admin) - admin.RequireAccount(app, CreateTrack(app)).ServeHTTP(w, r) + requireAccount(app, CreateTrack(app)).ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -143,3 +145,51 @@ func Handler(app *model.AppState) http.Handler { return mux } + +func requireAccount(app *model.AppState, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session, err := getSession(app, r) + if err != nil { + fmt.Fprintf(os.Stderr, "WARN: Failed to get session: %v\n", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + if session.Account == nil { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + ctx := context.WithValue(r.Context(), "session", session) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func getSession(app *model.AppState, r *http.Request) (*model.Session, error) { + var token string + + // check cookies first + sessionCookie, err := r.Cookie(model.COOKIE_TOKEN) + if err != nil && err != http.ErrNoCookie { + return nil, errors.New(fmt.Sprintf("Failed to retrieve session cookie: %v\n", err)) + } + if sessionCookie != nil { + token = sessionCookie.Value + } else { + // check Authorization header + token = strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + } + + if token == "" { return nil, nil } + + // fetch existing session + session, err := controller.GetSession(app.DB, token) + + if err != nil && !strings.Contains(err.Error(), "no rows") { + return nil, errors.New(fmt.Sprintf("Failed to retrieve session: %v\n", err)) + } + + if session != nil { + // TODO: consider running security checks here (i.e. user agent mismatches) + } + + return session, nil +}