From d3536e6741351fb13a9f6a327637bc2a4619fea4 Mon Sep 17 00:00:00 2001 From: clsr Date: Tue, 15 Nov 2016 16:57:51 +0100 Subject: Move storage to its own package --- api.go | 11 ++- magic.go | 31 ------ main.go | 25 ++--- storage.go | 280 ----------------------------------------------------- storage/magic.go | 31 ++++++ storage/storage.go | 280 +++++++++++++++++++++++++++++++++++++++++++++++++++++ website.go | 4 +- 7 files changed, 332 insertions(+), 330 deletions(-) delete mode 100644 magic.go delete mode 100644 storage.go create mode 100644 storage/magic.go create mode 100644 storage/storage.go diff --git a/api.go b/api.go index 464d110..46563dd 100644 --- a/api.go +++ b/api.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "git.clsr.net/gomf/storage" "io" "mime" "net/http" @@ -17,9 +18,9 @@ import ( ) func handleFile(w http.ResponseWriter, r *http.Request) { - f, hash, size, modtime, err := storage.Get(strings.TrimLeft(r.URL.Path, "/")) + f, hash, size, modtime, err := uploads.Get(strings.TrimLeft(r.URL.Path, "/")) if err != nil { - if _, ok := err.(ErrNotFound); ok { + if _, ok := err.(storage.ErrNotFound); ok { http.Error(w, err.Error(), http.StatusNotFound) } else { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -94,13 +95,13 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { continue } - id, hash, size, err := storage.New(part, part.FileName()) + id, hash, size, err := uploads.New(part, part.FileName()) if err != nil { resp.ErrorCode = http.StatusInternalServerError resp.Description = err.Error() - if _, ok := err.(ErrTooLarge); ok { + if _, ok := err.(storage.ErrTooLarge); ok { resp.ErrorCode = http.StatusRequestEntityTooLarge - } else if _, ok := err.(ErrForbidden); ok { + } else if _, ok := err.(storage.ErrForbidden); ok { resp.ErrorCode = http.StatusForbidden } break diff --git a/magic.go b/magic.go deleted file mode 100644 index c93e15d..0000000 --- a/magic.go +++ /dev/null @@ -1,31 +0,0 @@ -package main - -// #cgo LDFLAGS: -lmagic -// #include -// #include -import "C" -import "errors" -import "unsafe" - -var magic C.magic_t - -func init() { - magic = C.magic_open(C.MAGIC_MIME_TYPE | C.MAGIC_SYMLINK | C.MAGIC_ERROR) - if magic == nil { - panic("unable to initialize libmagic") - } - if C.magic_load(magic, nil) != 0 { - C.magic_close(magic) - panic("unable to load libmagic database: " + C.GoString(C.magic_error(magic))) - } -} - -func GetMimeType(fname string) (string, error) { - cfname := C.CString(fname) - defer C.free(unsafe.Pointer(cfname)) - mime := C.magic_file(magic, cfname) - if mime == nil { - return "", errors.New(C.GoString(C.magic_error(magic))) - } - return C.GoString(mime), nil -} diff --git a/main.go b/main.go index d66416f..e8ff33b 100644 --- a/main.go +++ b/main.go @@ -3,13 +3,14 @@ package main import ( "flag" "fmt" + "git.clsr.net/gomf/storage" "math/rand" "net/http" "strings" "time" ) -var storage *Storage +var uploads *storage.Storage var ( uploadUrl string @@ -69,12 +70,12 @@ func main() { listenHttps := flag.String("https", "", "address to listen on for HTTPS") cert := flag.String("cert", "", "path to TLS certificate (for HTTPS)") key := flag.String("key", "", "path to TLS key (for HTTPS)") - maxSize := flag.Int64("max-size", DefaultMaxSize, "max filesize in bytes") + maxSize := flag.Int64("max-size", storage.DefaultMaxSize, "max filesize in bytes") filterMime := flag.String("filter-mime", "application/x-dosexec,application/x-msdos-program", "comma-separated list of filtered MIME types") filterExt := flag.String("filter-ext", "exe,dll,msi,scr,com,pif", "comma-separated list of filtered file extensions") whitelist := flag.Bool("whitelist", false, "use filter as a whitelist instead of blacklist") grill := flag.Bool("grill", false, "enable grills") - idLength := flag.Int("id-length", DefaultIdLength, "length of uploaded file IDs") + idLength := flag.Int("id-length", storage.DefaultIdLength, "length of uploaded file IDs") idCharset := flag.String("id-charset", "", "charset for uploaded file IDs (default lowercase letters a-z)") enableLog := flag.Bool("log", false, "enable logging") logIP := flag.Bool("log-ip", false, "log IP addresses") @@ -91,17 +92,17 @@ func main() { initWebsite() - storage = NewStorage("upload") - storage.FilterExt = strings.Split(*filterExt, ",") - for i := range storage.FilterExt { - storage.FilterExt[i] = "." + storage.FilterExt[i] + uploads = storage.NewStorage("upload") + uploads.FilterExt = strings.Split(*filterExt, ",") + for i := range uploads.FilterExt { + uploads.FilterExt[i] = "." + uploads.FilterExt[i] } - storage.FilterMime = strings.Split(*filterMime, ",") - storage.Whitelist = *whitelist - storage.IdLength = *idLength - storage.MaxSize = *maxSize + uploads.FilterMime = strings.Split(*filterMime, ",") + uploads.Whitelist = *whitelist + uploads.IdLength = *idLength + uploads.MaxSize = *maxSize if *idCharset != "" { - storage.IdCharset = *idCharset + uploads.IdCharset = *idCharset } if !*enableLog { diff --git a/storage.go b/storage.go deleted file mode 100644 index c0e8130..0000000 --- a/storage.go +++ /dev/null @@ -1,280 +0,0 @@ -package main - -import ( - "crypto/sha1" - "encoding/base64" - "encoding/hex" - "errors" - "io" - "io/ioutil" - "math/rand" - "mime" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "time" -) - -const ( - MaxIdTries = 64 - - DefaultIdCharset = "abcdefghijklmnopqrstuvwxyz" - DefaultIdLength = 6 - DefaultMaxSize = 50 * 1024 * 1024 -) - -type Storage struct { - Folder string - IdCharset string - IdLength int - MaxSize int64 - FilterMime []string - FilterExt []string - Whitelist bool -} - -type ErrForbidden struct{ Type string } - -func (e ErrForbidden) Error() string { return "forbidden type: " + e.Type } - -type ErrTooLarge struct{ Size int64 } - -func (e ErrTooLarge) Error() string { - return "file exceeds maximum allowed size of " + strconv.FormatInt(e.Size, 10) + " bytes" -} - -type ErrNotFound struct{ Name string } - -func (e ErrNotFound) Error() string { return "file " + e.Name + " not found" } - -func NewStorage(folder string) *Storage { - if err := os.MkdirAll(path.Join(folder, "temp"), 0755); err != nil { - panic(err) - } - if err := os.MkdirAll(path.Join(folder, "files"), 0755); err != nil { - panic(err) - } - if err := os.MkdirAll(path.Join(folder, "ids"), 0755); err != nil { - panic(err) - } - - return &Storage{ - Folder: folder, - IdCharset: DefaultIdCharset, - IdLength: DefaultIdLength, - MaxSize: DefaultMaxSize, - } -} - -func (s *Storage) Get(id string) (file *os.File, hash string, size int64, modtime time.Time, err error) { - ext := path.Ext(id) - id = id[:len(id)-len(ext)] - for i := 0; i < len(id); i++ { - if !strings.ContainsRune(s.IdCharset, rune(id[i])) { - err = errors.New("invalid ID: " + id) - return - } - } - folder := s.idToFolder("ids", id) - files, err := ioutil.ReadDir(folder) - if err != nil { - err = ErrNotFound{id + ext} - return - } - if len(files) < 1 { - err = errors.New("internal storage error") - return - } - fn := files[0].Name() - fp := path.Join(folder, fn) - target, err := os.Readlink(fp) - if err != nil { - return - } - bhash, err := base64.RawURLEncoding.DecodeString(path.Base(path.Dir(target))) - if err != nil { - return - } - hash = hex.EncodeToString(bhash) - if path.Ext(fn) != ext { - err = ErrNotFound{id + ext} - return - } - stat, err := os.Lstat(fp) - if err != nil { - return - } - modtime = stat.ModTime() - size = stat.Size() - file, err = os.Open(fp) - return -} - -var errFileExists = errors.New("file exists") - -func (s *Storage) New(r io.Reader, name string) (id, hash string, size int64, err error) { - temp, err := ioutil.TempFile(path.Join(s.Folder, "temp"), "file") - if err != nil { - return - } - defer func() { - if temp != nil { - temp.Close() - os.Remove(temp.Name()) - } - }() - - hash, size, err = s.readInput(temp, r) - if err != nil { - return - } - _, ext, err := s.getMimeExt(temp.Name(), name) - if err != nil { - return - } - id, err = s.storeFile(temp, hash, ext) - if err == nil { - temp = nil // prevent deletion - } else if err == errFileExists { - err = nil - } - - return -} - -func (s *Storage) randomId() string { - id := make([]byte, s.IdLength) - for i := 0; i < len(id); i++ { - id[i] = s.IdCharset[rand.Intn(len(s.IdCharset))] - } - return string(id) -} - -func (s *Storage) idToFolder(subfolder, id string) string { - name := id - for len(name) < 4 { - name = "_" + name - } - return path.Join(s.Folder, subfolder, name[0:2], name[2:4], id) -} - -func (s *Storage) readInput(w io.Writer, r io.Reader) (hash string, size int64, err error) { - h := sha1.New() - w = io.MultiWriter(h, w) - if s.MaxSize > 0 { - r = io.LimitReader(r, s.MaxSize+1) - } - size, err = io.Copy(w, r) - if err != nil { - return - } - if lr, ok := r.(*io.LimitedReader); ok && lr.N == 0 { - err = ErrTooLarge{s.MaxSize} - return - } - hash = base64.RawURLEncoding.EncodeToString(h.Sum(nil)) - return -} - -func (s *Storage) getMimeExt(fpath string, name string) (mimetype, ext string, err error) { - mimetype, err = GetMimeType(fpath) - if err != nil { - return - } - - // choose file extension, prefer the user-provided one - ext = path.Ext(name) - exts, err := mime.ExtensionsByType(mimetype) - if err != nil { - return - } - if ext != "" && find(exts, ext) == "" { - ext = "" - if len(exts) > 0 { - ext = exts[0] - } - } - - filtered, ok := s.findFilter(exts, mimetype) - if !ok && s.Whitelist { // whitelist: reject if not on filters - err = ErrForbidden{mimetype} - } else if ok && !s.Whitelist { // blacklist: reject if filtered - // only block application/octet-stream if explicitly requested - if mimetype != "application/octet-stream" || find(s.FilterMime, mimetype) != "" { - err = ErrForbidden{filtered} - } - } - - return -} - -func (s *Storage) findFilter(exts []string, mimetype string) (match string, ok bool) { - if m := find(s.FilterMime, mimetype); m == mimetype { - return m, true - } - for _, ext := range exts { - if e := find(s.FilterExt, ext); e == ext { - return e, true - } - } - return "", false -} - -func (s *Storage) storeFile(file *os.File, hash, ext string) (id string, err error) { - hfolder := s.idToFolder("files", hash) - hpath := path.Join(hfolder, "file") - fexists := false - - os.MkdirAll(path.Dir(hfolder), 0755) - err = os.Mkdir(hfolder, 0755) - if err != nil { - if _, err = os.Stat(hpath); err != nil { - err = errors.New("internal storage error") - } - fexists = true - } else { - err = os.Rename(file.Name(), hpath) - os.Chmod(hpath, 0644) - } - if err != nil { - return - } - - fpath := "" - for i := 0; i < MaxIdTries; i++ { - id = s.randomId() - dir := s.idToFolder("ids", id) - os.MkdirAll(path.Dir(dir), 0755) - err = os.Mkdir(dir, 0755) - if err == nil { - fpath = path.Join(dir, "file"+ext) - id += ext - break - } - } - if fpath == "" { - err = errors.New("internal storage error") - return - } - rhpath, err := filepath.Rel(path.Dir(fpath), hpath) - if err != nil { - return - } - err = os.Symlink(rhpath, fpath) - - if fexists && err == nil { - err = errFileExists - } - return -} - -func find(ss []string, search string) string { - for _, s := range ss { - if s == search { - return s - } - } - return "" -} diff --git a/storage/magic.go b/storage/magic.go new file mode 100644 index 0000000..16f655d --- /dev/null +++ b/storage/magic.go @@ -0,0 +1,31 @@ +package storage + +// #cgo LDFLAGS: -lmagic +// #include +// #include +import "C" +import "errors" +import "unsafe" + +var magic C.magic_t + +func init() { + magic = C.magic_open(C.MAGIC_MIME_TYPE | C.MAGIC_SYMLINK | C.MAGIC_ERROR) + if magic == nil { + panic("unable to initialize libmagic") + } + if C.magic_load(magic, nil) != 0 { + C.magic_close(magic) + panic("unable to load libmagic database: " + C.GoString(C.magic_error(magic))) + } +} + +func GetMimeType(fname string) (string, error) { + cfname := C.CString(fname) + defer C.free(unsafe.Pointer(cfname)) + mime := C.magic_file(magic, cfname) + if mime == nil { + return "", errors.New(C.GoString(C.magic_error(magic))) + } + return C.GoString(mime), nil +} diff --git a/storage/storage.go b/storage/storage.go new file mode 100644 index 0000000..3e2b816 --- /dev/null +++ b/storage/storage.go @@ -0,0 +1,280 @@ +package storage + +import ( + "crypto/sha1" + "encoding/base64" + "encoding/hex" + "errors" + "io" + "io/ioutil" + "math/rand" + "mime" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" +) + +const ( + MaxIdTries = 64 + + DefaultIdCharset = "abcdefghijklmnopqrstuvwxyz" + DefaultIdLength = 6 + DefaultMaxSize = 50 * 1024 * 1024 +) + +type Storage struct { + Folder string + IdCharset string + IdLength int + MaxSize int64 + FilterMime []string + FilterExt []string + Whitelist bool +} + +type ErrForbidden struct{ Type string } + +func (e ErrForbidden) Error() string { return "forbidden type: " + e.Type } + +type ErrTooLarge struct{ Size int64 } + +func (e ErrTooLarge) Error() string { + return "file exceeds maximum allowed size of " + strconv.FormatInt(e.Size, 10) + " bytes" +} + +type ErrNotFound struct{ Name string } + +func (e ErrNotFound) Error() string { return "file " + e.Name + " not found" } + +func NewStorage(folder string) *Storage { + if err := os.MkdirAll(path.Join(folder, "temp"), 0755); err != nil { + panic(err) + } + if err := os.MkdirAll(path.Join(folder, "files"), 0755); err != nil { + panic(err) + } + if err := os.MkdirAll(path.Join(folder, "ids"), 0755); err != nil { + panic(err) + } + + return &Storage{ + Folder: folder, + IdCharset: DefaultIdCharset, + IdLength: DefaultIdLength, + MaxSize: DefaultMaxSize, + } +} + +func (s *Storage) Get(id string) (file *os.File, hash string, size int64, modtime time.Time, err error) { + ext := path.Ext(id) + id = id[:len(id)-len(ext)] + for i := 0; i < len(id); i++ { + if !strings.ContainsRune(s.IdCharset, rune(id[i])) { + err = errors.New("invalid ID: " + id) + return + } + } + folder := s.idToFolder("ids", id) + files, err := ioutil.ReadDir(folder) + if err != nil { + err = ErrNotFound{id + ext} + return + } + if len(files) < 1 { + err = errors.New("internal storage error") + return + } + fn := files[0].Name() + fp := path.Join(folder, fn) + target, err := os.Readlink(fp) + if err != nil { + return + } + bhash, err := base64.RawURLEncoding.DecodeString(path.Base(path.Dir(target))) + if err != nil { + return + } + hash = hex.EncodeToString(bhash) + if path.Ext(fn) != ext { + err = ErrNotFound{id + ext} + return + } + stat, err := os.Lstat(fp) + if err != nil { + return + } + modtime = stat.ModTime() + size = stat.Size() + file, err = os.Open(fp) + return +} + +var errFileExists = errors.New("file exists") + +func (s *Storage) New(r io.Reader, name string) (id, hash string, size int64, err error) { + temp, err := ioutil.TempFile(path.Join(s.Folder, "temp"), "file") + if err != nil { + return + } + defer func() { + if temp != nil { + temp.Close() + os.Remove(temp.Name()) + } + }() + + hash, size, err = s.readInput(temp, r) + if err != nil { + return + } + _, ext, err := s.getMimeExt(temp.Name(), name) + if err != nil { + return + } + id, err = s.storeFile(temp, hash, ext) + if err == nil { + temp = nil // prevent deletion + } else if err == errFileExists { + err = nil + } + + return +} + +func (s *Storage) randomId() string { + id := make([]byte, s.IdLength) + for i := 0; i < len(id); i++ { + id[i] = s.IdCharset[rand.Intn(len(s.IdCharset))] + } + return string(id) +} + +func (s *Storage) idToFolder(subfolder, id string) string { + name := id + for len(name) < 4 { + name = "_" + name + } + return path.Join(s.Folder, subfolder, name[0:2], name[2:4], id) +} + +func (s *Storage) readInput(w io.Writer, r io.Reader) (hash string, size int64, err error) { + h := sha1.New() + w = io.MultiWriter(h, w) + if s.MaxSize > 0 { + r = io.LimitReader(r, s.MaxSize+1) + } + size, err = io.Copy(w, r) + if err != nil { + return + } + if lr, ok := r.(*io.LimitedReader); ok && lr.N == 0 { + err = ErrTooLarge{s.MaxSize} + return + } + hash = base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + return +} + +func (s *Storage) getMimeExt(fpath string, name string) (mimetype, ext string, err error) { + mimetype, err = GetMimeType(fpath) + if err != nil { + return + } + + // choose file extension, prefer the user-provided one + ext = path.Ext(name) + exts, err := mime.ExtensionsByType(mimetype) + if err != nil { + return + } + if ext != "" && find(exts, ext) == "" { + ext = "" + if len(exts) > 0 { + ext = exts[0] + } + } + + filtered, ok := s.findFilter(exts, mimetype) + if !ok && s.Whitelist { // whitelist: reject if not on filters + err = ErrForbidden{mimetype} + } else if ok && !s.Whitelist { // blacklist: reject if filtered + // only block application/octet-stream if explicitly requested + if mimetype != "application/octet-stream" || find(s.FilterMime, mimetype) != "" { + err = ErrForbidden{filtered} + } + } + + return +} + +func (s *Storage) findFilter(exts []string, mimetype string) (match string, ok bool) { + if m := find(s.FilterMime, mimetype); m == mimetype { + return m, true + } + for _, ext := range exts { + if e := find(s.FilterExt, ext); e == ext { + return e, true + } + } + return "", false +} + +func (s *Storage) storeFile(file *os.File, hash, ext string) (id string, err error) { + hfolder := s.idToFolder("files", hash) + hpath := path.Join(hfolder, "file") + fexists := false + + os.MkdirAll(path.Dir(hfolder), 0755) + err = os.Mkdir(hfolder, 0755) + if err != nil { + if _, err = os.Stat(hpath); err != nil { + err = errors.New("internal storage error") + } + fexists = true + } else { + err = os.Rename(file.Name(), hpath) + os.Chmod(hpath, 0644) + } + if err != nil { + return + } + + fpath := "" + for i := 0; i < MaxIdTries; i++ { + id = s.randomId() + dir := s.idToFolder("ids", id) + os.MkdirAll(path.Dir(dir), 0755) + err = os.Mkdir(dir, 0755) + if err == nil { + fpath = path.Join(dir, "file"+ext) + id += ext + break + } + } + if fpath == "" { + err = errors.New("internal storage error") + return + } + rhpath, err := filepath.Rel(path.Dir(fpath), hpath) + if err != nil { + return + } + err = os.Symlink(rhpath, fpath) + + if fexists && err == nil { + err = errFileExists + } + return +} + +func find(ss []string, search string) string { + for _, s := range ss { + if s == search { + return s + } + } + return "" +} diff --git a/website.go b/website.go index 7cc32c1..cf10586 100644 --- a/website.go +++ b/website.go @@ -80,8 +80,8 @@ func newContext() pageContext { SiteName: siteName, Abuse: abuseMail, Contact: contactMail, - MaxSizeBytes: storage.MaxSize, - MaxSize: humanize(storage.MaxSize), + MaxSizeBytes: uploads.MaxSize, + MaxSize: humanize(uploads.MaxSize), Pages: pages, } } -- cgit