aboutsummaryrefslogtreecommitdiffstats
path: root/storage/storage.go
diff options
context:
space:
mode:
Diffstat (limited to 'storage/storage.go')
-rw-r--r--storage/storage.go280
1 files changed, 280 insertions, 0 deletions
diff --git a/storage/storage.go b/storage/storage.go
new file mode 100644
index 0000000..3e2b816
--- /dev/null
+++ b/storage/storage.go
@@ -0,0 +1,280 @@
+package storage
+
+import (
+ "crypto/sha1"
+ "encoding/base64"
+ "encoding/hex"
+ "errors"
+ "io"
+ "io/ioutil"
+ "math/rand"
+ "mime"
+ "os"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+)
+
+const (
+ MaxIdTries = 64
+
+ DefaultIdCharset = "abcdefghijklmnopqrstuvwxyz"
+ DefaultIdLength = 6
+ DefaultMaxSize = 50 * 1024 * 1024
+)
+
+type Storage struct {
+ Folder string
+ IdCharset string
+ IdLength int
+ MaxSize int64
+ FilterMime []string
+ FilterExt []string
+ Whitelist bool
+}
+
+type ErrForbidden struct{ Type string }
+
+func (e ErrForbidden) Error() string { return "forbidden type: " + e.Type }
+
+type ErrTooLarge struct{ Size int64 }
+
+func (e ErrTooLarge) Error() string {
+ return "file exceeds maximum allowed size of " + strconv.FormatInt(e.Size, 10) + " bytes"
+}
+
+type ErrNotFound struct{ Name string }
+
+func (e ErrNotFound) Error() string { return "file " + e.Name + " not found" }
+
+func NewStorage(folder string) *Storage {
+ if err := os.MkdirAll(path.Join(folder, "temp"), 0755); err != nil {
+ panic(err)
+ }
+ if err := os.MkdirAll(path.Join(folder, "files"), 0755); err != nil {
+ panic(err)
+ }
+ if err := os.MkdirAll(path.Join(folder, "ids"), 0755); err != nil {
+ panic(err)
+ }
+
+ return &Storage{
+ Folder: folder,
+ IdCharset: DefaultIdCharset,
+ IdLength: DefaultIdLength,
+ MaxSize: DefaultMaxSize,
+ }
+}
+
+func (s *Storage) Get(id string) (file *os.File, hash string, size int64, modtime time.Time, err error) {
+ ext := path.Ext(id)
+ id = id[:len(id)-len(ext)]
+ for i := 0; i < len(id); i++ {
+ if !strings.ContainsRune(s.IdCharset, rune(id[i])) {
+ err = errors.New("invalid ID: " + id)
+ return
+ }
+ }
+ folder := s.idToFolder("ids", id)
+ files, err := ioutil.ReadDir(folder)
+ if err != nil {
+ err = ErrNotFound{id + ext}
+ return
+ }
+ if len(files) < 1 {
+ err = errors.New("internal storage error")
+ return
+ }
+ fn := files[0].Name()
+ fp := path.Join(folder, fn)
+ target, err := os.Readlink(fp)
+ if err != nil {
+ return
+ }
+ bhash, err := base64.RawURLEncoding.DecodeString(path.Base(path.Dir(target)))
+ if err != nil {
+ return
+ }
+ hash = hex.EncodeToString(bhash)
+ if path.Ext(fn) != ext {
+ err = ErrNotFound{id + ext}
+ return
+ }
+ stat, err := os.Lstat(fp)
+ if err != nil {
+ return
+ }
+ modtime = stat.ModTime()
+ size = stat.Size()
+ file, err = os.Open(fp)
+ return
+}
+
+var errFileExists = errors.New("file exists")
+
+func (s *Storage) New(r io.Reader, name string) (id, hash string, size int64, err error) {
+ temp, err := ioutil.TempFile(path.Join(s.Folder, "temp"), "file")
+ if err != nil {
+ return
+ }
+ defer func() {
+ if temp != nil {
+ temp.Close()
+ os.Remove(temp.Name())
+ }
+ }()
+
+ hash, size, err = s.readInput(temp, r)
+ if err != nil {
+ return
+ }
+ _, ext, err := s.getMimeExt(temp.Name(), name)
+ if err != nil {
+ return
+ }
+ id, err = s.storeFile(temp, hash, ext)
+ if err == nil {
+ temp = nil // prevent deletion
+ } else if err == errFileExists {
+ err = nil
+ }
+
+ return
+}
+
+func (s *Storage) randomId() string {
+ id := make([]byte, s.IdLength)
+ for i := 0; i < len(id); i++ {
+ id[i] = s.IdCharset[rand.Intn(len(s.IdCharset))]
+ }
+ return string(id)
+}
+
+func (s *Storage) idToFolder(subfolder, id string) string {
+ name := id
+ for len(name) < 4 {
+ name = "_" + name
+ }
+ return path.Join(s.Folder, subfolder, name[0:2], name[2:4], id)
+}
+
+func (s *Storage) readInput(w io.Writer, r io.Reader) (hash string, size int64, err error) {
+ h := sha1.New()
+ w = io.MultiWriter(h, w)
+ if s.MaxSize > 0 {
+ r = io.LimitReader(r, s.MaxSize+1)
+ }
+ size, err = io.Copy(w, r)
+ if err != nil {
+ return
+ }
+ if lr, ok := r.(*io.LimitedReader); ok && lr.N == 0 {
+ err = ErrTooLarge{s.MaxSize}
+ return
+ }
+ hash = base64.RawURLEncoding.EncodeToString(h.Sum(nil))
+ return
+}
+
+func (s *Storage) getMimeExt(fpath string, name string) (mimetype, ext string, err error) {
+ mimetype, err = GetMimeType(fpath)
+ if err != nil {
+ return
+ }
+
+ // choose file extension, prefer the user-provided one
+ ext = path.Ext(name)
+ exts, err := mime.ExtensionsByType(mimetype)
+ if err != nil {
+ return
+ }
+ if ext != "" && find(exts, ext) == "" {
+ ext = ""
+ if len(exts) > 0 {
+ ext = exts[0]
+ }
+ }
+
+ filtered, ok := s.findFilter(exts, mimetype)
+ if !ok && s.Whitelist { // whitelist: reject if not on filters
+ err = ErrForbidden{mimetype}
+ } else if ok && !s.Whitelist { // blacklist: reject if filtered
+ // only block application/octet-stream if explicitly requested
+ if mimetype != "application/octet-stream" || find(s.FilterMime, mimetype) != "" {
+ err = ErrForbidden{filtered}
+ }
+ }
+
+ return
+}
+
+func (s *Storage) findFilter(exts []string, mimetype string) (match string, ok bool) {
+ if m := find(s.FilterMime, mimetype); m == mimetype {
+ return m, true
+ }
+ for _, ext := range exts {
+ if e := find(s.FilterExt, ext); e == ext {
+ return e, true
+ }
+ }
+ return "", false
+}
+
+func (s *Storage) storeFile(file *os.File, hash, ext string) (id string, err error) {
+ hfolder := s.idToFolder("files", hash)
+ hpath := path.Join(hfolder, "file")
+ fexists := false
+
+ os.MkdirAll(path.Dir(hfolder), 0755)
+ err = os.Mkdir(hfolder, 0755)
+ if err != nil {
+ if _, err = os.Stat(hpath); err != nil {
+ err = errors.New("internal storage error")
+ }
+ fexists = true
+ } else {
+ err = os.Rename(file.Name(), hpath)
+ os.Chmod(hpath, 0644)
+ }
+ if err != nil {
+ return
+ }
+
+ fpath := ""
+ for i := 0; i < MaxIdTries; i++ {
+ id = s.randomId()
+ dir := s.idToFolder("ids", id)
+ os.MkdirAll(path.Dir(dir), 0755)
+ err = os.Mkdir(dir, 0755)
+ if err == nil {
+ fpath = path.Join(dir, "file"+ext)
+ id += ext
+ break
+ }
+ }
+ if fpath == "" {
+ err = errors.New("internal storage error")
+ return
+ }
+ rhpath, err := filepath.Rel(path.Dir(fpath), hpath)
+ if err != nil {
+ return
+ }
+ err = os.Symlink(rhpath, fpath)
+
+ if fexists && err == nil {
+ err = errFileExists
+ }
+ return
+}
+
+func find(ss []string, search string) string {
+ for _, s := range ss {
+ if s == search {
+ return s
+ }
+ }
+ return ""
+}