diff --git a/api/uploads.go b/api/uploads.go index f2fe297..6e348d0 100644 --- a/api/uploads.go +++ b/api/uploads.go @@ -1,6 +1,7 @@ package api import ( + "arimelody-web/global" "bufio" "encoding/base64" "errors" @@ -15,6 +16,7 @@ func HandleImageUpload(data *string, directory string, filename string) (string, header := split[0] imageData, err := base64.StdEncoding.DecodeString(split[1]) ext, _ := strings.CutPrefix(header, "data:image/") + directory = filepath.Join(global.DATA_DIR, directory) switch ext { case "png": diff --git a/global/data.go b/global/data.go index 4f56a68..f88d25f 100644 --- a/global/data.go +++ b/global/data.go @@ -3,6 +3,7 @@ package global import ( "fmt" "os" + "path/filepath" "strings" "github.com/jmoiron/sqlx" @@ -34,7 +35,7 @@ var Args = func() map[string]string { return args }() -var HTTP_DOMAIN = func() string { +var HTTP_DOMAIN = func() string { domain := os.Getenv("ARIMELODY_HTTP_DOMAIN") if domain == "" { return "https://arimelody.me" @@ -42,4 +43,23 @@ var HTTP_DOMAIN = func() string { return domain }() +var DATA_DIR = func() string { + dir, err := filepath.Abs(os.Getenv("ARIMELODY_DATA_DIR")) + if err != nil { + fmt.Printf("FATAL: Failed to get working directory: %s\n", err.Error()) + os.Exit(1) + } + if dir != "" { + os.MkdirAll(dir, os.ModePerm) + } else { + var err error + dir, err = os.Getwd() + if err != nil { + fmt.Printf("FATAL: Failed to get working directory: %s\n", err.Error()) + os.Exit(1) + } + } + return dir +}() + var DB *sqlx.DB diff --git a/main.go b/main.go index a783421..901dcbb 100644 --- a/main.go +++ b/main.go @@ -72,7 +72,7 @@ func createServeMux() *http.ServeMux { mux.Handle("/admin/", http.StripPrefix("/admin", admin.Handler())) mux.Handle("/api/", http.StripPrefix("/api", api.Handler())) mux.Handle("/music/", http.StripPrefix("/music", view.MusicHandler())) - mux.Handle("/uploads/", http.StripPrefix("/uploads", staticHandler("uploads"))) + mux.Handle("/uploads/", http.StripPrefix("/uploads", staticHandler(filepath.Join(global.DATA_DIR, "uploads")))) mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/" || r.URL.Path == "/index.html" { err := templates.Pages["index"].Execute(w, nil)