diff --git a/admin/http.go b/admin/http.go index 83d2a18..0ca61a3 100644 --- a/admin/http.go +++ b/admin/http.go @@ -477,30 +477,13 @@ func staticHandler() http.Handler { func enforceSession(app *model.AppState, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sessionCookie, err := r.Cookie(model.COOKIE_TOKEN) - if err != nil && err != http.ErrNoCookie { - fmt.Fprintf(os.Stderr, "WARN: Failed to retrieve session cookie: %v\n", err) + session, err := controller.GetSessionFromRequest(app.DB, r) + if err != nil { + fmt.Fprintf(os.Stderr, "WARN: Failed to retrieve session: %v\n", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - var session *model.Session - - if sessionCookie != nil { - // fetch existing session - session, err = controller.GetSession(app.DB, sessionCookie.Value) - - if err != nil && !strings.Contains(err.Error(), "no rows") { - fmt.Fprintf(os.Stderr, "WARN: Failed to retrieve session: %v\n", err) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return - } - - if session != nil { - // TODO: consider running security checks here (i.e. user agent mismatches) - } - } - if session == nil { // create a new session session, err = controller.CreateSession(app.DB, r.UserAgent()) diff --git a/api/release.go b/api/release.go index 4e7372f..b89cec8 100644 --- a/api/release.go +++ b/api/release.go @@ -19,7 +19,13 @@ func ServeRelease(app *model.AppState, release *model.Release) http.Handler { // only allow authorised users to view hidden releases privileged := false if !release.Visible { - session := r.Context().Value("session").(*model.Session) + session, err := controller.GetSessionFromRequest(app.DB, r) + if err != nil { + fmt.Fprintf(os.Stderr, "WARN: Failed to retrieve session: %v\n", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + if session != nil && session.Account != nil { // TODO: check privilege on release privileged = true diff --git a/controller/account.go b/controller/account.go index 0cf3364..9c7c1e1 100644 --- a/controller/account.go +++ b/controller/account.go @@ -2,7 +2,6 @@ package controller import ( "arimelody-web/model" - "net/http" "strings" "github.com/jmoiron/sqlx" @@ -77,19 +76,6 @@ func GetAccountBySession(db *sqlx.DB, sessionToken string) (*model.Account, erro return &account, nil } -func GetSessionFromRequest(db *sqlx.DB, r *http.Request) string { - tokenStr := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") - if len(tokenStr) > 0 { - return tokenStr - } - - cookie, err := r.Cookie(model.COOKIE_TOKEN) - if err != nil { - return "" - } - return cookie.Value -} - func CreateAccount(db *sqlx.DB, account *model.Account) error { err := db.Get( &account.ID, diff --git a/controller/session.go b/controller/session.go index 6e566f5..cf423fe 100644 --- a/controller/session.go +++ b/controller/session.go @@ -2,6 +2,10 @@ package controller import ( "database/sql" + "errors" + "fmt" + "net/http" + "strings" "time" "arimelody-web/model" @@ -11,6 +15,30 @@ import ( const TOKEN_LEN = 64 +func GetSessionFromRequest(db *sqlx.DB, r *http.Request) (*model.Session, error) { + sessionCookie, err := r.Cookie(model.COOKIE_TOKEN) + if err != nil && err != http.ErrNoCookie { + return nil, errors.New(fmt.Sprintf("Failed to retrieve session cookie: %v", err)) + } + + var session *model.Session + + if sessionCookie != nil { + // fetch existing session + session, err = GetSession(db, sessionCookie.Value) + + if err != nil && !strings.Contains(err.Error(), "no rows") { + return nil, errors.New(fmt.Sprintf("Failed to retrieve session: %v", err)) + } + + if session != nil { + // TODO: consider running security checks here (i.e. user agent mismatches) + } + } + + return session, nil +} + func CreateSession(db *sqlx.DB, userAgent string) (*model.Session, error) { tokenString := GenerateAlnumString(TOKEN_LEN) diff --git a/view/music.go b/view/music.go index 7b86270..89e428c 100644 --- a/view/music.go +++ b/view/music.go @@ -3,6 +3,7 @@ package view import ( "fmt" "net/http" + "os" "arimelody-web/controller" "arimelody-web/model" @@ -59,7 +60,13 @@ func ServeGateway(app *model.AppState, release *model.Release) http.Handler { // only allow authorised users to view hidden releases privileged := false if !release.Visible { - session := r.Context().Value("session").(*model.Session) + session, err := controller.GetSessionFromRequest(app.DB, r) + if err != nil { + fmt.Fprintf(os.Stderr, "WARN: Failed to retrieve session: %v\n", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + if session != nil && session.Account != nil { // TODO: check privilege on release privileged = true