// Package api wires up the HTTP handlers for the versioned (/v1) API.
package api

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"os"
	"strings"

	"arimelody-web/controller"
	"arimelody-web/model"
)

// Handler returns the root http.Handler for the API.
//
// Every request first has its session resolved via getSession; the result
// (which may be nil for anonymous requests) is stored in the request context
// under the "session" key before the request is routed. Routes marked
// "(admin)" are wrapped in requireAccount and answer 401 without a logged-in
// account. Unsupported methods on a known route answer 404.
func Handler(app *model.AppState) http.Handler {
	mux := http.NewServeMux()

	// TODO: generate API keys on the frontend

	// ARTIST ENDPOINTS

	mux.Handle("/v1/artist/", http.StripPrefix("/v1/artist", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		artistID := resourceIDFromPath(r)
		artist, err := controller.GetArtist(app.DB, artistID)
		if err != nil {
			writeLookupError(w, r, "artist", artistID, err)
			return
		}

		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/artist/{id}
			ServeArtist(app, artist).ServeHTTP(w, r)
		case http.MethodPut:
			// PUT /api/v1/artist/{id} (admin)
			requireAccount(UpdateArtist(app, artist)).ServeHTTP(w, r)
		case http.MethodDelete:
			// DELETE /api/v1/artist/{id} (admin)
			requireAccount(DeleteArtist(app, artist)).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	})))
	mux.Handle("/v1/artist", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/artist
			ServeAllArtists(app).ServeHTTP(w, r)
		case http.MethodPost:
			// POST /api/v1/artist (admin)
			requireAccount(CreateArtist(app)).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	}))

	// RELEASE ENDPOINTS

	mux.Handle("/v1/music/", http.StripPrefix("/v1/music", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		releaseID := resourceIDFromPath(r)
		release, err := controller.GetRelease(app.DB, releaseID, true)
		if err != nil {
			writeLookupError(w, r, "release", releaseID, err)
			return
		}

		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/music/{id}
			ServeRelease(app, release).ServeHTTP(w, r)
		case http.MethodPut:
			// PUT /api/v1/music/{id} (admin)
			requireAccount(UpdateRelease(app, release)).ServeHTTP(w, r)
		case http.MethodDelete:
			// DELETE /api/v1/music/{id} (admin)
			requireAccount(DeleteRelease(app, release)).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	})))
	mux.Handle("/v1/music", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/music
			ServeCatalog(app).ServeHTTP(w, r)
		case http.MethodPost:
			// POST /api/v1/music (admin)
			requireAccount(CreateRelease(app)).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	}))

	// TRACK ENDPOINTS

	mux.Handle("/v1/track/", http.StripPrefix("/v1/track", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		trackID := resourceIDFromPath(r)
		track, err := controller.GetTrack(app.DB, trackID)
		if err != nil {
			writeLookupError(w, r, "track", trackID, err)
			return
		}

		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/track/{id} (admin)
			requireAccount(ServeTrack(app, track)).ServeHTTP(w, r)
		case http.MethodPut:
			// PUT /api/v1/track/{id} (admin)
			requireAccount(UpdateTrack(app, track)).ServeHTTP(w, r)
		case http.MethodDelete:
			// DELETE /api/v1/track/{id} (admin)
			requireAccount(DeleteTrack(app, track)).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	})))
	mux.Handle("/v1/track", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/track (admin)
			requireAccount(ServeAllTracks(app)).ServeHTTP(w, r)
		case http.MethodPost:
			// POST /api/v1/track (admin)
			requireAccount(CreateTrack(app)).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	}))

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		session, err := getSession(app, r)
		if err != nil {
			fmt.Fprintf(os.Stderr, "WARN: Failed to get session: %v\n", err)
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		// NOTE(review): a plain string context key is kept on purpose; code
		// outside this file may read "session" from the context, so switching
		// to an unexported key type has to be done package-wide.
		ctx := context.WithValue(r.Context(), "session", session)
		mux.ServeHTTP(w, r.WithContext(ctx))
	})
}

// resourceIDFromPath returns the first segment of the request path, ignoring
// one leading slash. It is used behind http.StripPrefix, where that segment is
// the resource ID. Unlike slicing the path directly, it does not panic when
// the path is empty.
func resourceIDFromPath(r *http.Request) string {
	return strings.Split(strings.TrimPrefix(r.URL.Path, "/"), "/")[0]
}

// writeLookupError writes the response for a failed resource lookup: 404 when
// the error text contains "no rows", otherwise a logged warning and a 500.
// kind and id are only used in the log message.
func writeLookupError(w http.ResponseWriter, r *http.Request, kind, id string, err error) {
	// NOTE(review): string matching is kept because it is not visible here
	// whether the controller package preserves sql.ErrNoRows for errors.Is.
	if strings.Contains(err.Error(), "no rows") {
		http.NotFound(w, r)
		return
	}
	fmt.Printf("WARN: Error while retrieving %s %s: %s\n", kind, id, err)
	http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

// requireAccount wraps next so that it only runs when the request context
// carries a session with a non-nil Account; otherwise it answers 401.
//
// The session is read from the "session" context value. A missing value or a
// value of another type is treated as "not logged in" rather than panicking.
func requireAccount(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		session, ok := r.Context().Value("session").(*model.Session)
		if !ok || session == nil || session.Account == nil {
			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}
		// The session is already in the request context; no need to re-attach it.
		next.ServeHTTP(w, r)
	})
}

// getSession resolves the session for a request.
//
// The token is taken from the model.COOKIE_TOKEN cookie when that cookie is
// present, otherwise from the Authorization header with a leading "Bearer "
// removed. It returns (nil, nil) when no token is supplied. A lookup error
// whose text contains "no rows" is not treated as a failure; whatever session
// value the controller returned is passed through with a nil error. Any other
// failure is returned wrapped so callers can inspect the cause.
func getSession(app *model.AppState, r *http.Request) (*model.Session, error) {
	var token string

	// check cookies first
	sessionCookie, err := r.Cookie(model.COOKIE_TOKEN)
	if err != nil && !errors.Is(err, http.ErrNoCookie) {
		return nil, fmt.Errorf("retrieving session cookie: %w", err)
	}
	if sessionCookie != nil {
		token = sessionCookie.Value
	} else {
		// check Authorization header
		token = strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
	}

	if token == "" {
		return nil, nil
	}

	// fetch existing session
	session, err := controller.GetSession(app.DB, token)
	if err != nil && !strings.Contains(err.Error(), "no rows") {
		return nil, fmt.Errorf("retrieving session: %w", err)
	}

	// TODO: consider running security checks here (i.e. user agent mismatches)

	return session, nil
}