diff --git a/credential.go b/credential.go index 8168b42..3d7f4e8 100644 --- a/credential.go +++ b/credential.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "net/url" "strings" ) @@ -25,12 +26,45 @@ type credential struct { Prefix string } +func newCredential(endpoint, region, accessKey, secretKey, prefix, acl string) credential { + parsedEndpoint, _ := url.Parse(endpoint) + return credential{ + Endpoint: parsedEndpoint.String(), + Region: region, + AccessKey: accessKey, + SecretKey: secretKey, + Prefix: prefix, + ACL: acl, + } +} + func (cred credential) validate() error { + parsedEndpoint, err := url.Parse(cred.Endpoint) + if err != nil { + return fmt.Errorf("%w: endpoint must be a URL and not empty", errBadRequest) + } else if parsedEndpoint.Host == "" { + return fmt.Errorf("%w: endpoint must have a valid host", errBadRequest) + } else if parsedEndpoint.User != nil { + return fmt.Errorf("%w: endpoint must not have user credentials", errBadRequest) + } else if parsedEndpoint.RawQuery != "" { + return fmt.Errorf("%w: endpoint must not have query parameters", errBadRequest) + } else if parsedEndpoint.RawFragment != "" { + return fmt.Errorf("%w: endpoint must not have fragment", errBadRequest) + } else if parsedEndpoint.Scheme != "http" && parsedEndpoint.Scheme != "https" { + return fmt.Errorf("%w: endpoint must be http(s)", errBadRequest) + } + + if cred.Region == "" { + return fmt.Errorf("%w: region must not be empty", errBadRequest) + } + if strings.HasSuffix(cred.Endpoint, "/") { return fmt.Errorf("%w: endpoint should not end with slash", errBadRequest) } + if strings.HasPrefix(cred.Prefix, "/") { return fmt.Errorf("%w: prefix should not start with slash", errBadRequest) } + return nil } diff --git a/handlers-s3.go b/handlers-s3.go index b9e1b0d..6d30fa5 100644 --- a/handlers-s3.go +++ b/handlers-s3.go @@ -64,7 +64,7 @@ func handleCreateMultipartUpload(w http.ResponseWriter, req *http.Request) { } if err := r.validate(); err != nil { - errorResponse(w, req, fmt.Errorf("%w: %s", errBadRequest, err)) + errorResponse(w, req, err) return } @@ -73,7 +73,7 @@ func handleCreateMultipartUpload(w http.ResponseWriter, req *http.Request) { result, err := initiateMultipartUpload(key, cred) if err != nil { - errorResponse(w, req, fmt.Errorf("%w: %s", errInternalServerError, err)) + errorResponse(w, req, err) return } @@ -109,7 +109,7 @@ func handleGetUploadedParts(w http.ResponseWriter, req *http.Request) { for { page, err := listParts(key, uploadID, cred, nextPartNumberMarker) if err != nil { - errorResponse(w, req, fmt.Errorf("%w: %s", errInternalServerError, err)) + errorResponse(w, req, err) return } @@ -208,13 +208,13 @@ func handleCompleteMultipartUpload(w http.ResponseWriter, req *http.Request) { } if err := r.validate(); err != nil { - errorResponse(w, req, fmt.Errorf("%w: %s", errBadRequest, err)) + errorResponse(w, req, err) return } result, err := completeMultipartUpload(key, uploadID, r.Parts, cred) if err != nil { - errorResponse(w, req, fmt.Errorf("%w: %s", errInternalServerError, err)) + errorResponse(w, req, err) return } @@ -242,7 +242,7 @@ func handleAbortMultipartUpload(w http.ResponseWriter, req *http.Request) { err = abortMultipartUpload(key, uploadID, cred) if err != nil { - errorResponse(w, req, fmt.Errorf("%w: %s", errInternalServerError, err)) + errorResponse(w, req, err) return } } diff --git a/handlers.go b/handlers.go index 2ef5e47..4bdd1c6 100644 --- a/handlers.go +++ b/handlers.go @@ -114,15 +114,16 @@ type createReq struct { } func handleCreateForm(w http.ResponseWriter, req *http.Request) { - cred := credential{ - Endpoint: req.PostFormValue("Endpoint"), - Region: req.PostFormValue("Region"), - AccessKey: req.PostFormValue("AccessKey"), - SecretKey: req.PostFormValue("SecretKey"), - Prefix: req.PostFormValue("Prefix"), - } + cred := newCredential( + req.PostFormValue("Endpoint"), + req.PostFormValue("Region"), + req.PostFormValue("AccessKey"), + req.PostFormValue("SecretKey"), + req.PostFormValue("Prefix"), + req.PostFormValue("ACL"), + ) if err := cred.validate(); err != nil { - errorResponse(w, req, fmt.Errorf("%w: %s", errBadRequest, err)) + errorResponse(w, req, err) return } diff --git a/helpers.go b/helpers.go index 93ee136..250a36f 100644 --- a/helpers.go +++ b/helpers.go @@ -8,6 +8,8 @@ import ( var errNotFound = errors.New("not found") var errBadRequest = errors.New("bad request") var errInternalServerError = errors.New("internal server error") +var errUnauthorized = errors.New("unauthorized") +var errForbidden = errors.New("forbidden") func errorResponseStatus(w http.ResponseWriter, req *http.Request, err error) { errorStatus := http.StatusInternalServerError @@ -18,6 +20,10 @@ func errorResponseStatus(w http.ResponseWriter, req *http.Request, err error) { errorStatus = http.StatusBadRequest } else if errors.Is(err, errInternalServerError) { errorStatus = http.StatusInternalServerError + } else if errors.Is(err, errUnauthorized) { + errorStatus = http.StatusUnauthorized + } else if errors.Is(err, errForbidden) { + errorStatus = http.StatusForbidden } w.WriteHeader(errorStatus) @@ -31,3 +37,19 @@ func errorResponse(w http.ResponseWriter, req *http.Request, err error) { errorResponseStatus(w, req, err) w.Write([]byte(err.Error())) } + +// responseToError converts a HTTP status code to an error +func responseToError(resp *http.Response) error { + if resp.StatusCode == http.StatusNotFound { + return errNotFound + } else if resp.StatusCode == http.StatusBadRequest { + return errBadRequest + } else if resp.StatusCode == http.StatusInternalServerError { + return errInternalServerError + } else if resp.StatusCode == http.StatusUnauthorized { + return errUnauthorized + } else if resp.StatusCode == http.StatusForbidden { + return errForbidden + } + return nil +} diff --git a/main.go b/main.go index 05a09e5..545ad35 100644 --- a/main.go +++ b/main.go @@ -26,8 +26,8 @@ func main() { router.Methods(http.MethodGet).Path("/readyz").HandlerFunc(readyz) router.Methods(http.MethodGet).PathPrefix("/assets").HandlerFunc(handleAssets) - router.Methods(http.MethodGet).Path("/create").HandlerFunc(handleCreate) - router.Methods(http.MethodPost).Path("/create").HandlerFunc(handleCreateForm) + router.Methods(http.MethodGet).Path("/").HandlerFunc(handleCreate) + router.Methods(http.MethodPost).Path("/").HandlerFunc(handleCreateForm) uploadRouter := router.PathPrefix("/{id}").Subrouter() uploadTemplateRouter := uploadRouter.Path("").Subrouter() diff --git a/s3.go b/s3.go index 36520f7..10f924d 100644 --- a/s3.go +++ b/s3.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io/ioutil" + "log" "net/http" "net/url" "strconv" @@ -52,6 +53,32 @@ func stripETag(t string) string { return strings.TrimSuffix(strings.TrimPrefix(t, "\""), "\"") } +type errEndpoint struct { + err error + status string + body []byte +} + +func (e errEndpoint) Unwrap() error { + return e.err +} + +func (e errEndpoint) Error() string { + body := bytes.ReplaceAll(e.body, []byte("\n"), []byte("")) + if e.err != nil { + return fmt.Sprintf("endpoint responded with %v: %s", e.err, body) + } + return fmt.Sprintf("endpoint responded with %s: %s", e.status, body) +} + +func endpointReturnedError(resp *http.Response) error { + if resp.StatusCode < 200 || resp.StatusCode > 299 { + body, _ := ioutil.ReadAll(resp.Body) + return errEndpoint{responseToError(resp), resp.Status, body} + } + return nil +} + /* initiateMultipartUpload */ type initiateMultipartUploadResult struct { @@ -71,6 +98,7 @@ func initiateMultipartUpload( params.Set("uploads", "") unsignedReq, err := http.NewRequestWithContext(ctx, http.MethodPost, cred.Endpoint+"/"+key+"?"+params.Encode(), nil) if err != nil { + log.Printf("failure creating request: %v", err) return initiateMultipartUploadResult{}, err } if cred.ACL != "" { @@ -80,12 +108,14 @@ func initiateMultipartUpload( signedReq := sign(unsignedReq, cred) resp, err := httpClientS3.Do(signedReq) if err != nil { + log.Printf("failure connecting to endpoint: %v", err) return initiateMultipartUploadResult{}, err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := ioutil.ReadAll(resp.Body) - return initiateMultipartUploadResult{}, fmt.Errorf("endpoint request failed: %d: %s", resp.StatusCode, body) + err = endpointReturnedError(resp) + if err != nil { + log.Printf("endpoint responded negatively: %v", err) + return initiateMultipartUploadResult{}, err } result := initiateMultipartUploadResult{} @@ -136,18 +166,21 @@ func listParts( params.Set("uploadId", uploadID) unsignedReq, err := http.NewRequestWithContext(ctx, http.MethodGet, cred.Endpoint+"/"+key+"?"+params.Encode(), nil) if err != nil { + log.Printf("failure creating request: %v", err) return listPartsResult{}, err } signedReq := sign(unsignedReq, cred) resp, err := httpClientS3.Do(signedReq) if err != nil { + log.Printf("failure connecting to endpoint: %v", err) return listPartsResult{}, err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := ioutil.ReadAll(resp.Body) - return listPartsResult{}, fmt.Errorf("endpoint request failed: %d: %s", resp.StatusCode, body) + err = endpointReturnedError(resp) + if err != nil { + log.Printf("endpoint responded negatively: %v", err) + return listPartsResult{}, err } result := listPartsResult{} @@ -213,18 +246,21 @@ func completeMultipartUpload( params.Set("uploadId", uploadID) unsignedReq, err := http.NewRequestWithContext(ctx, http.MethodPost, cred.Endpoint+"/"+key+"?"+params.Encode(), &body) if err != nil { + log.Printf("failure creating request: %v", err) return completeMultipartUploadResult{}, err } signedReq := sign(unsignedReq, cred) resp, err := httpClientS3.Do(signedReq) if err != nil { + log.Printf("failure connecting to endpoint: %v", err) return completeMultipartUploadResult{}, err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := ioutil.ReadAll(resp.Body) - return completeMultipartUploadResult{}, fmt.Errorf("endpoint request failed: %d: %s", resp.StatusCode, body) + err = endpointReturnedError(resp) + if err != nil { + log.Printf("endpoint responded negatively: %v", err) + return completeMultipartUploadResult{}, err } result := completeMultipartUploadResult{} @@ -252,18 +288,21 @@ func abortMultipartUpload( params.Set("uploadId", uploadID) unsignedReq, err := http.NewRequestWithContext(ctx, http.MethodDelete, cred.Endpoint+"/"+key+"?"+params.Encode(), nil) if err != nil { + log.Printf("failure creating request: %v", err) return err } signedReq := sign(unsignedReq, cred) resp, err := httpClientS3.Do(signedReq) if err != nil { + log.Printf("failure connecting to endpoint: %v", err) return err } defer resp.Body.Close() - if resp.StatusCode != http.StatusNoContent { - body, _ := ioutil.ReadAll(resp.Body) - return fmt.Errorf("endpoint request failed: %d: %s", resp.StatusCode, body) + err = endpointReturnedError(resp) + if err != nil { + log.Printf("endpoint responded negatively: %v", err) + return err } return nil diff --git a/store.go b/store.go index ee16d36..ed70c25 100644 --- a/store.go +++ b/store.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log" "strings" "time" @@ -62,7 +63,7 @@ func (s *redisStore) ping() error { return err } if pong != "PONG" { - return errInternalServerError + return fmt.Errorf("%w: pong request failed", errInternalServerError) } return nil } @@ -74,6 +75,7 @@ func (s *redisStore) put(key string, data []byte, expire time.Duration) error { exists := 0 err := s.client.Do(ctx, radix.Cmd(&exists, "EXISTS", "upl:"+key)) if err != nil { + log.Printf("put failed on existence check: %v", err) return err } @@ -84,6 +86,7 @@ func (s *redisStore) put(key string, data []byte, expire time.Duration) error { expireS := int64(expire / time.Second) err = s.client.Do(ctx, radix.FlatCmd(nil, "SETEX", "upl:"+key, expireS, data)) if err != nil { + log.Printf("put failed: %v", err) return err } return nil @@ -96,6 +99,7 @@ func (s *redisStore) get(key string) ([]byte, error) { var data []byte err := s.client.Do(ctx, radix.Cmd(&data, "GET", "upl:"+key)) if err != nil { + log.Printf("get failed: %v", err) return nil, err } diff --git a/web/create.tmpl b/web/create.tmpl index 270a1cb..c01e773 100644 --- a/web/create.tmpl +++ b/web/create.tmpl @@ -1,17 +1,35 @@ -{{template "head.tmpl" "Create dropbox"}} -