package controller import ( "database/sql" "errors" "fmt" "net/http" "strings" "time" "arimelody-web/model" "github.com/jmoiron/sqlx" ) const TOKEN_LEN = 64 func GetSessionFromRequest(db *sqlx.DB, r *http.Request) (*model.Session, error) { sessionCookie, err := r.Cookie(model.COOKIE_TOKEN) if err != nil && err != http.ErrNoCookie { return nil, errors.New(fmt.Sprintf("Failed to retrieve session cookie: %v", err)) } var session *model.Session if sessionCookie != nil { // fetch existing session session, err = GetSession(db, sessionCookie.Value) if err != nil && !strings.Contains(err.Error(), "no rows") { return nil, errors.New(fmt.Sprintf("Failed to retrieve session: %v", err)) } if session != nil { // TODO: consider running security checks here (i.e. user agent mismatches) } } return session, nil } func CreateSession(db *sqlx.DB, userAgent string) (*model.Session, error) { tokenString := GenerateAlnumString(TOKEN_LEN) session := model.Session{ Token: string(tokenString), UserAgent: userAgent, CreatedAt: time.Now(), ExpiresAt: time.Now().Add(time.Hour * 24), } _, err := db.Exec("INSERT INTO session " + "(token, user_agent, created_at, expires_at) VALUES " + "($1, $2, $3, $4)", session.Token, session.UserAgent, session.CreatedAt, session.ExpiresAt, ) if err != nil { return nil, err } return &session, nil } // func WriteSession(db *sqlx.DB, session *model.Session) error { // _, err := db.Exec( // "UPDATE session " + // "SET account=$2,message=$3,error=$4 " + // "WHERE token=$1", // session.Token, // session.Account.ID, // session.Message, // session.Error, // ) // return err // } func SetSessionAttemptAccount(db *sqlx.DB, session *model.Session, account *model.Account) error { var err error session.AttemptAccount = account if account == nil { _, err = db.Exec("UPDATE session SET attempt_account=NULL WHERE token=$1", session.Token) } else { _, err = db.Exec("UPDATE session SET attempt_account=$2 WHERE token=$1", session.Token, account.ID) } return err } func SetSessionAccount(db *sqlx.DB, session *model.Session, account *model.Account) error { var err error session.Account = account if account == nil { _, err = db.Exec("UPDATE session SET account=NULL WHERE token=$1", session.Token) } else { _, err = db.Exec("UPDATE session SET account=$2 WHERE token=$1", session.Token, account.ID) } return err } func SetSessionMessage(db *sqlx.DB, session *model.Session, message string) error { var err error if message == "" { if !session.Message.Valid { return nil } session.Message = sql.NullString{ } _, err = db.Exec("UPDATE session SET message=NULL WHERE token=$1", session.Token) } else { session.Message = sql.NullString{ String: message, Valid: true } _, err = db.Exec("UPDATE session SET message=$2 WHERE token=$1", session.Token, message) } return err } func SetSessionError(db *sqlx.DB, session *model.Session, message string) error { var err error if message == "" { if !session.Error.Valid { return nil } session.Error = sql.NullString{ } _, err = db.Exec("UPDATE session SET error=NULL WHERE token=$1", session.Token) } else { session.Error = sql.NullString{ String: message, Valid: true } _, err = db.Exec("UPDATE session SET error=$2 WHERE token=$1", session.Token, message) } return err } func GetSession(db *sqlx.DB, token string) (*model.Session, error) { type dbSession struct { model.Session AttemptAccountID sql.NullString `db:"attempt_account"` AccountID sql.NullString `db:"account"` } session := dbSession{} err := db.Get( &session, "SELECT * FROM session WHERE token=$1", token, ) if err != nil { return nil, err } if session.AccountID.Valid { session.Account, err = GetAccountByID(db, session.AccountID.String) if err != nil { return nil, err } } if session.AttemptAccountID.Valid { session.AttemptAccount, err = GetAccountByID(db, session.AttemptAccountID.String) if err != nil { return nil, err } } return &session.Session, err } // func GetAllSessionsForAccount(db *sqlx.DB, accountID string) ([]model.Session, error) { // sessions := []model.Session{} // err := db.Select(&sessions, "SELECT * FROM session WHERE account=$1 AND expires_at>current_timestamp", accountID) // return sessions, err // } func DeleteAllSessionsForAccount(db *sqlx.DB, accountID string) error { _, err := db.Exec("DELETE FROM session WHERE account=$1", accountID) return err } func DeleteSession(db *sqlx.DB, token string) error { _, err := db.Exec("DELETE FROM session WHERE token=$1", token) return err }