diff --git a/api/api.go b/api/api.go index 50b1c63..4edd07b 100644 --- a/api/api.go +++ b/api/api.go @@ -38,10 +38,10 @@ func Handler(app *model.AppState) http.Handler { ServeArtist(app, artist).ServeHTTP(w, r) case http.MethodPut: // PUT /api/v1/artist/{id} (admin) - requireAccount(app, UpdateArtist(app, artist)).ServeHTTP(w, r) + requireAccount(UpdateArtist(app, artist)).ServeHTTP(w, r) case http.MethodDelete: // DELETE /api/v1/artist/{id} (admin) - requireAccount(app, DeleteArtist(app, artist)).ServeHTTP(w, r) + requireAccount(DeleteArtist(app, artist)).ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -53,7 +53,7 @@ func Handler(app *model.AppState) http.Handler { ServeAllArtists(app).ServeHTTP(w, r) case http.MethodPost: // POST /api/v1/artist (admin) - requireAccount(app, CreateArtist(app)).ServeHTTP(w, r) + requireAccount(CreateArtist(app)).ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -80,10 +80,10 @@ func Handler(app *model.AppState) http.Handler { ServeRelease(app, release).ServeHTTP(w, r) case http.MethodPut: // PUT /api/v1/music/{id} (admin) - requireAccount(app, UpdateRelease(app, release)).ServeHTTP(w, r) + requireAccount(UpdateRelease(app, release)).ServeHTTP(w, r) case http.MethodDelete: // DELETE /api/v1/music/{id} (admin) - requireAccount(app, DeleteRelease(app, release)).ServeHTTP(w, r) + requireAccount(DeleteRelease(app, release)).ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -95,7 +95,7 @@ func Handler(app *model.AppState) http.Handler { ServeCatalog(app).ServeHTTP(w, r) case http.MethodPost: // POST /api/v1/music (admin) - requireAccount(app, CreateRelease(app)).ServeHTTP(w, r) + requireAccount(CreateRelease(app)).ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -119,13 +119,13 @@ func Handler(app *model.AppState) http.Handler { switch r.Method { case http.MethodGet: // GET /api/v1/track/{id} (admin) - requireAccount(app, ServeTrack(app, track)).ServeHTTP(w, r) + requireAccount(ServeTrack(app, track)).ServeHTTP(w, r) case http.MethodPut: // PUT /api/v1/track/{id} (admin) - requireAccount(app, UpdateTrack(app, track)).ServeHTTP(w, r) + requireAccount(UpdateTrack(app, track)).ServeHTTP(w, r) case http.MethodDelete: // DELETE /api/v1/track/{id} (admin) - requireAccount(app, DeleteTrack(app, track)).ServeHTTP(w, r) + requireAccount(DeleteTrack(app, track)).ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -134,19 +134,15 @@ func Handler(app *model.AppState) http.Handler { switch r.Method { case http.MethodGet: // GET /api/v1/track (admin) - requireAccount(app, ServeAllTracks(app)).ServeHTTP(w, r) + requireAccount(ServeAllTracks(app)).ServeHTTP(w, r) case http.MethodPost: // POST /api/v1/track (admin) - requireAccount(app, CreateTrack(app)).ServeHTTP(w, r) + requireAccount(CreateTrack(app)).ServeHTTP(w, r) default: http.NotFound(w, r) } })) - return mux -} - -func requireAccount(app *model.AppState, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session, err := getSession(app, r) if err != nil { @@ -154,7 +150,15 @@ func requireAccount(app *model.AppState, next http.Handler) http.Handler { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - if session.Account == nil { + ctx := context.WithValue(r.Context(), "session", session) + mux.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func requireAccount(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session := r.Context().Value("session").(*model.Session) + if session == nil || session.Account == nil { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return }