diff options
-rw-r--r-- | client.go | 39 | ||||
-rw-r--r-- | cnp.go | 16 | ||||
-rw-r--r-- | common.go | 141 | ||||
-rw-r--r-- | common_test.go | 76 | ||||
-rw-r--r-- | error.go | 205 | ||||
-rw-r--r-- | header.go | 365 | ||||
-rw-r--r-- | header_test.go | 200 | ||||
-rw-r--r-- | message.go | 176 | ||||
-rw-r--r-- | message_test.go | 204 | ||||
-rw-r--r-- | request.go | 260 | ||||
-rw-r--r-- | request_test.go | 341 | ||||
-rw-r--r-- | response.go | 249 | ||||
-rw-r--r-- | response_test.go | 317 | ||||
-rw-r--r-- | server.go | 257 |
14 files changed, 2846 insertions, 0 deletions
diff --git a/client.go b/client.go new file mode 100644 index 0000000..b208637 --- /dev/null +++ b/client.go @@ -0,0 +1,39 @@ +package cnp + +import ( + "net" + "strconv" + "strings" +) + +// TODO: make more modular and extensible like net/http + +// Send sends a CNP request to a server and returns the response. +// +// The TCP connection is made using net.Dial. +func Send(r *Request) (*Response, error) { + host := r.Host() + if strings.LastIndexByte(host, ':') <= strings.LastIndexByte(host, ']') { // missing/default port + host = net.JoinHostPort(host, strconv.Itoa(DefaultPort)) + } + conn, err := net.Dial("tcp", host) + if err != nil { + return nil, err + } + if err = r.Write(conn); err != nil { + return nil, err + } + if err = r.Close(); err != nil { + return nil, err + } + return ParseResponse(conn) +} + +// Get sends a body-less request to a given URL. +func Get(url string) (*Response, error) { + req, err := NewRequestURL(url, nil) + if err != nil { + return nil, err + } + return Send(req) +} @@ -0,0 +1,16 @@ +// Package cnp provides CNP client and server implementations. +package cnp // import "contnet.org/lib/cnp-go" + +const ( + // DefaultPort represents the default port for the cnp:// schema. + DefaultPort = 25454 + + // MaxHeaderLength is the maximum byte size of the header. + MaxHeaderLength = 1 * 1024 * 1024 + + // VersionMajor is the major CNP version (X in cnp/X.Y). + VersionMajor = 0 + + // VersionMinor is the major CNP version (Y in cnp/X.Y). + VersionMinor = 3 +) diff --git a/common.go b/common.go new file mode 100644 index 0000000..e6173b0 --- /dev/null +++ b/common.go @@ -0,0 +1,141 @@ +package cnp + +import ( + "bufio" + "bytes" + "io" + "mime" + "strconv" + "strings" + "time" +) + +func readLimitedLine(br *bufio.Reader, length int) ([]byte, error) { + var buf bytes.Buffer + + for { + data, err := br.ReadSlice('\n') + if len(data) > 0 { + buf.Write(data) + } + if buf.Len() > length { + return nil, ErrorTooLarge{"header exceeds maximum permitted size"} + } + if err == nil || err == io.EOF { + return buf.Bytes(), nil + } + if err != bufio.ErrBufferFull { + return nil, ErrorSyntax{"invalid header: missing line feed"} + } + } +} + +func getInt(m *Message, param string) (int64, error) { + p := m.Param(param) + if p == "" { + return 0, nil + } + n, err := strconv.ParseUint(p, 10, 63) + if err != nil || (n != 0 && p[0] == '0') || (n == 0 && len(p) != 1) { + return int64(n), ErrorInvalid{"invalid parameter: " + param + " is not a valid integer"} + } + return int64(n), nil +} + +func setInt(m *Message, param string, n int64) { + if n < 0 { + n = 0 + } + m.SetParam(param, strconv.FormatInt(n, 10)) +} + +func getFilename(m *Message, param string) (string, error) { + name := m.Param(param) + if strings.ContainsAny(name, "/\x00") { + return name, ErrorInvalid{"invalid parameter: " + param + " contains invalid characters"} + } + return name, nil +} + +func setFilename(m *Message, param, name string) error { + if strings.ContainsAny(name, "/\x00") { + return ErrorInvalid{"invalid parameter: " + param + " contains invalid characters"} + } + m.SetParam(param, name) + return nil +} + +func getType(m *Message, param string) (typ string, err error) { + t := m.Param(param) + if t != "" { + var params map[string]string + typ, params, err = mime.ParseMediaType(t) + if err != nil || !validMimeType(t) || len(params) > 0 { // may not contain params + err = ErrorInvalid{"invalid parameter: " + param + " is not a valid mime type"} + } + } + if typ == "" { + typ = "application/octet-stream" + } + return +} + +func setType(m *Message, param, typ string) error { + if typ == "" { + m.SetParam(param, "") + return nil + } + + /*ss := strings.Split(typ, "/") + if len(ss) != 2 || len(ss[0]) == 0 || strings.ContainsAny(typ, "\x00 \n\t\r;,") { + return ErrorInvalid{…} + }*/ + + t := mime.FormatMediaType(typ, nil) + if t == "" || !validMimeType(typ) { + return ErrorInvalid{"invalid parameter: " + param + " is not a valid mime type"} + } + m.SetParam(param, t) + + return nil +} + +func validMimeType(typ string) bool { + ss := strings.Split(typ, "/") + if len(ss) != 2 || ss[0] == "" || ss[1] == "" { + return false + } + for _, r := range typ { + /*switch r { + case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '[', ']', '?', '=': + // tspecials except / + return false + }*/ // handled by mime.ParseMediaType + if r <= ' ' || r >= '\x7f' { + // control codes, whitespace, null + return false + } + } + return true +} + +func getTime(m *Message, param string) (time.Time, error) { + t := m.Param(param) + var z time.Time + if t == "" { + return z, nil + } + ts, err := time.Parse(time.RFC3339, t) + if err != nil || !strings.HasSuffix(t, "Z") { + return z, ErrorInvalid{"invalid parameter: " + param + " is not a valid RFC3339 timestamp"} + } + return ts, nil +} + +func setTime(m *Message, param string, t time.Time) { + if t.IsZero() { + m.SetParam(param, "") + } else { + m.SetParam(param, t.UTC().Format(time.RFC3339)) + } +} diff --git a/common_test.go b/common_test.go new file mode 100644 index 0000000..8e8c673 --- /dev/null +++ b/common_test.go @@ -0,0 +1,76 @@ +package cnp + +import ( + "bytes" + "io" + "io/ioutil" + "reflect" + "strings" +) + +func headerEqual(a, b Header) bool { + return a.Intent == b.Intent && a.VersionMajor == b.VersionMajor && a.VersionMinor == b.VersionMinor && paramEqual(a.Parameters, b.Parameters) +} + +func paramEqual(a, b Parameters) bool { + if len(a) != len(b) { + return false + } + for k := range a { + if a[k] != b[k] { + return false + } + } + return true +} + +func errorEqual(a, b error) bool { + if a == nil && b == nil { + return true + } + return reflect.TypeOf(a) == reflect.TypeOf(b) +} + +func msgEqual(a, b *Message) bool { + return headerEqual(a.Header, b.Header) && bodyEqual(a.Body, b.Body) +} + +func bodyEqual(a, b io.Reader) bool { + if a == nil && b == nil { + return true + } + if a == nil { + a = strings.NewReader("") + } + if b == nil { + b = strings.NewReader("") + } + ba, err := ioutil.ReadAll(a) + if err != nil { + panic(err) + } + if s, ok := a.(io.Seeker); ok { + _, _ = s.Seek(0, io.SeekStart) + } + bb, err := ioutil.ReadAll(b) + if err != nil { + panic(err) + } + if s, ok := b.(io.Seeker); ok { + _, _ = s.Seek(0, io.SeekStart) + } + return bytes.Equal(ba, bb) +} + +type testStringReader struct { + s string +} + +func (r *testStringReader) Read(b []byte) (n int, err error) { + if r.s == "" { + return 0, io.EOF + } + n = copy(b, r.s) + r.s = r.s[n:] + return +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..b6270b9 --- /dev/null +++ b/error.go @@ -0,0 +1,205 @@ +package cnp + +// Error represents an error as used in a CNP error response. +type Error interface { + // CNPError returns the value of the "error" parameter in a CNP error + // response. + CNPError() string + + Error() string +} + +const ( + // ReasonSyntax represents the "syntax" reason parameter value. + ReasonSyntax = "syntax" + // ReasonVersion represents the "version" reason parameter value. + ReasonVersion = "version" + // ReasonInvalid represents the "invalid" reason parameter value. + ReasonInvalid = "invalid" + // ReasonNotSupported represents the "not_supported" reason parameter value. + ReasonNotSupported = "not_supported" + // ReasonTooLarge represents the "too_large" reason parameter value. + ReasonTooLarge = "too_large" + // ReasonNotFound represents the "not_found" reason parameter value. + ReasonNotFound = "not_found" + // ReasonDenied represents the "denied" reason parameter value. + ReasonDenied = "denied" + // ReasonRejected represents the "rejected" reason parameter value. + ReasonRejected = "rejected" + // ReasonServerError represents the "server_error" reason parameter value. + ReasonServerError = "server_error" +) + +// NewError returns a new Error based on a reason parameter value. +// +// If the reason is blank, nil is returned. +// If the reason is unknown, ErrorServerError is returned. +func NewError(reason string) Error { + switch reason { + case ReasonSyntax: + return ErrorSyntax{"Invalid CNP Message Syntax"} + case ReasonVersion: + return ErrorVersion{"Unsupported CNP Protocol Version"} + case ReasonInvalid: + return ErrorInvalid{"Invalid CNP Message"} + case ReasonNotSupported: + return ErrorNotSupported{"CNP Feature Not Supported"} + case ReasonTooLarge: + return ErrorTooLarge{"CNP Message Too Large"} + case ReasonNotFound: + return ErrorNotFound{"Not Found"} + case ReasonDenied: + return ErrorDenied{"Denied"} + case ReasonRejected: + return ErrorRejected{"Rejected"} + case "": + return nil + default: + fallthrough + case ReasonServerError: + return ErrorServerError{"Internal Server Error"} + } +} + +// ErrorSyntax represents the CNP "syntax" error reason. +type ErrorSyntax struct { + Reason string +} + +// CNPError on ErrorSyntax returns the error parameter value "syntax". +func (e ErrorSyntax) CNPError() string { + return ReasonSyntax +} + +func (e ErrorSyntax) Error() string { + return "CNP syntax error: " + e.Reason +} + +// ErrorVersion represents the CNP "version" error reason. +type ErrorVersion struct { + Reason string +} + +// CNPError on ErrorVersion returns the error parameter value "version". +func (e ErrorVersion) CNPError() string { + return ReasonVersion +} + +func (e ErrorVersion) Error() string { + return "Unsupported CNP version: " + e.Reason +} + +// ErrorInvalid represents the CNP "invalid" error reason. +type ErrorInvalid struct { + Reason string +} + +// CNPError on ErrorInvalid returns the error parameter value "invalid". +func (e ErrorInvalid) CNPError() string { + return ReasonInvalid +} + +func (e ErrorInvalid) Error() string { + return "Invalid CNP message: " + e.Reason +} + +// ErrorNotSupported represents the CNP "not_supported" error reason. +type ErrorNotSupported struct { + Reason string +} + +// CNPError on ErrorNotSupported returns the error parameter value +// "not_supported". +func (e ErrorNotSupported) CNPError() string { + return ReasonNotSupported +} + +func (e ErrorNotSupported) Error() string { + return "Requested CNP feature is not supported: " + e.Reason +} + +// ErrorTooLarge represents the CNP "too_large" error reason. +type ErrorTooLarge struct { + Reason string +} + +// CNPError on ErrorTooLarge returns the error parameter value "too_large". +func (e ErrorTooLarge) CNPError() string { + return ReasonTooLarge +} + +func (e ErrorTooLarge) Error() string { + return "CNP message is too large: " + e.Reason +} + +// ErrorNotFound represents the CNP "not_found" error reason. +type ErrorNotFound struct { + Reason string +} + +// CNPError on ErrorNotFound returns the error parameter value "not_found". +func (e ErrorNotFound) CNPError() string { + return ReasonNotFound +} + +func (e ErrorNotFound) Error() string { + return "Requested path was not found: " + e.Reason +} + +// ErrorDenied represents the CNP "denied" error reason. +type ErrorDenied struct { + Reason string +} + +// CNPError on ErrorDenied returns the error parameter value "denied". +func (e ErrorDenied) CNPError() string { + return ReasonDenied +} + +func (e ErrorDenied) Error() string { + return "Server denied access: " + e.Reason +} + +// ErrorRejected represents the CNP "rejected" error reason. +type ErrorRejected struct { + Reason string +} + +// CNPError on ErrorRejected returns the error parameter value "rejected". +func (e ErrorRejected) CNPError() string { + return ReasonRejected +} + +func (e ErrorRejected) Error() string { + return "Server rejected request: " + e.Reason +} + +// ErrorServerError represents the CNP "server_error" error reason. +type ErrorServerError struct { + Reason string +} + +// CNPError on ErrorServerError returns the error parameter value +// "server_error". +func (e ErrorServerError) CNPError() string { + return ReasonServerError +} + +func (e ErrorServerError) Error() string { + return "Internal server error: " + e.Reason +} + +// ErrorURL is a non-CNPError that represents an invalid CNP URL. +type ErrorURL struct { + // Err represents the error reason. + Err error + // URL is the URL that triggered the error. + URL string +} + +func (e ErrorURL) Error() string { + if e.Err == nil { + return "CNP URL error" + } + return e.Err.Error() +} diff --git a/header.go b/header.go new file mode 100644 index 0000000..e185dca --- /dev/null +++ b/header.go @@ -0,0 +1,365 @@ +package cnp + +import ( + "bufio" + "bytes" + "io" + "sort" + "strconv" +) + +const ( + // IntentOK represents the "ok" response intent. + IntentOK = "ok" + // IntentNotModified represents the "not_modified" response intent. + IntentNotModified = "not_modified" + // IntentError represents the "error" response intent. + IntentError = "error" + // IntentRedirect represents the "redirect" response intent. + IntentRedirect = "redirect" +) + +// Header represents a CNP message header +type Header struct { + // VersionMajor is the major CNP version given in the header. + VersionMajor int + // VersionMinor is the minor CNP version given in the header. + VersionMinor int + // Intent is the intent string of the message. + Intent string + // Parameters is a decoded map of the message parameters. + Parameters Parameters +} + +// NewHeader creates a new CNP header from an intent and optional parameter +// map. +func NewHeader(intent string, params Parameters) Header { + if params == nil { + params = make(Parameters) + } + return Header{ + VersionMajor: VersionMajor, + VersionMinor: VersionMinor, + Intent: intent, + Parameters: params, + } +} + +// ParseHeader parses a CNP header from a bytestring. +// The line parameter must be a single line that ends with a line feed. +func ParseHeader(line []byte) (h Header, err error) { + if bytes.IndexByte(line, '\x00') >= 0 { + err = ErrorSyntax{"invalid header: NUL byte"} + return + } + + if !bytes.HasPrefix(line, []byte("cnp/")) { + err = ErrorSyntax{"invalid header: no version string"} + return + } + line = line[len("cnp/"):] + + var o int + + h.VersionMajor, o, err = parseHeaderNumber(line) + if err != nil { + return + } + line = line[o:] + + if len(line) == 0 || line[0] != '.' { + err = ErrorSyntax{"invalid header: invalid version number tuple"} + return + } + line = line[1:] + h.VersionMinor, o, err = parseHeaderNumber(line) + if err != nil { + return + } + line = line[o:] + + if len(line) == 0 || line[0] != ' ' { + err = ErrorSyntax{"invalid header: expected a space"} + return + } + line = line[1:] + + h.Intent, o, err = parseHeaderIntent(line) + if err != nil { + return + } + line = line[o:] + + h.Parameters, o, err = parseHeaderParameters(line) + if err != nil { + return + } + line = line[o:] + + if len(line) != 1 || line[0] != '\n' { + err = ErrorSyntax{"invalid header: expected line feed and end"} + return + } + + return +} + +func parseHeaderNumber(line []byte) (n int, length int, err error) { + i := 0 + for ; i < len(line); i++ { + b := line[i] + if '0' <= b && b <= '9' { + } else { + break + } + } + n, err = strconv.Atoi(string(line[:i])) + if err != nil || (i > 1 && line[0] == '0') { + err = ErrorSyntax{"invalid header: invalid version number format"} + return + } + length = i + return +} + +func parseHeaderIntent(line []byte) (intent string, length int, err error) { + p := bytes.IndexAny(line, " \n=\x00") + if p < 0 { + p = len(line) - 1 + } + if p == 0 { + err = ErrorSyntax{"invalid header: empty intent"} + return + } + intent, err = Unescape(line[:p]) + length = p + return +} + +func parseHeaderParameters(line []byte) (params Parameters, length int, err error) { + params = make(Parameters) + + for len(line) > 1 { + if line[0] == '\n' { + break + } else if line[0] != ' ' { + err = ErrorSyntax{"invalid header: expected a space"} + return + } + line = line[1:] + length++ + + var k, v string + var o int + k, v, o, err = parseHeaderParameter(line) + if err != nil { + return + } + if _, ok := params[k]; ok { + err = ErrorSyntax{"invalid header: duplicate parameter"} + return + } + params[k] = v + line = line[o:] + length += o + } + + return +} + +func parseHeaderParameter(line []byte) (key, value string, length int, err error) { + var o int + key, o, err = parseHeaderParamString(line) + if err != nil { + return + } + line = line[o:] + length += o + + if len(line) == 0 || line[0] != '=' { + err = ErrorSyntax{"invalid header: expected an equals sign"} + return + } + line = line[1:] + length++ + + value, o, err = parseHeaderParamString(line) + if err != nil { + return + } + length += o + + return +} + +func parseHeaderParamString(line []byte) (s string, length int, err error) { + for ; length < len(line); length++ { + b := line[length] + if b == '\n' || b == ' ' || b == '=' { + break + } + } + s, err = Unescape(line[:length]) + return +} + +// Version returns the message's CNP version as a "cnp/X.Y" string. +func (h Header) Version() string { + return "cnp/" + strconv.Itoa(h.VersionMajor) + "." + strconv.Itoa(h.VersionMinor) +} + +// Write writes the CNP message header line in the wire format. +// The written line ends with a line feed. +func (h Header) Write(w io.Writer) (err error) { + bw := bufio.NewWriter(w) + _, err = bw.WriteString(h.Version()) + if err != nil { + return + } + err = bw.WriteByte(' ') + if err != nil { + return + } + _, err = bw.Write(Escape(h.Intent)) + if err != nil { + return + } + err = h.Parameters.Write(bw) + if err != nil { + return + } + err = bw.WriteByte('\n') + if err != nil { + return + } + return bw.Flush() +} + +func (h Header) String() string { + var buf bytes.Buffer + _ = h.Write(&buf) + return buf.String() +} + +// Parameters represents CNP message parameter key=value pairs. +type Parameters map[string]string + +// Write writes the parameters encoded for inclusion in the wire format. +// Includes a leading space. +func (p Parameters) Write(w io.Writer) (err error) { + bw := bufio.NewWriter(w) + keys := []string{} + for k := range p { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + err = bw.WriteByte(' ') + if err != nil { + return + } + _, err = bw.Write(Escape(k)) + if err != nil { + return + } + err = bw.WriteByte('=') + if err != nil { + return + } + _, err = bw.Write(Escape(p[k])) + if err != nil { + return + } + } + return +} + +// Escape CNP-escapes the bytestring s. +func Escape(s string) []byte { + el := escapeLength(s) + if el == len(s) { + return []byte(s) + } + bs := make([]byte, el) + bi := 0 + for i := 0; i < len(s); i++ { + switch s[i] { + case '\x00': + bs[bi] = '\\' + bs[bi+1] = '0' + bi += 2 + + case '\n': + bs[bi] = '\\' + bs[bi+1] = 'n' + bi += 2 + + case ' ': + bs[bi] = '\\' + bs[bi+1] = '_' + bi += 2 + + case '=': + bs[bi] = '\\' + bs[bi+1] = '-' + bi += 2 + + case '\\': + bs[bi] = '\\' + bs[bi+1] = '\\' + bi += 2 + + default: + bs[bi] = s[i] + bi++ + } + } + return bs +} + +func escapeLength(s string) (l int) { + for i := 0; i < len(s); i++ { + switch s[i] { + case '\x00', '\n', ' ', '=', '\\': + l += 2 + default: + l++ + } + } + return +} + +// Unescape unescapes the bs from wire format into a bytestring. +func Unescape(bs []byte) (string, error) { + buf := make([]byte, len(bs)) + bi := 0 + + for i := 0; i < len(bs); i++ { + switch bs[i] { + case '\\': + i++ + if i >= len(bs) { + return string(buf[:bi]), ErrorSyntax{"invalid escape sequence: unexpected end of string"} + } + switch bs[i] { + case '0': + buf[bi] = '\x00' + case 'n': + buf[bi] = '\n' + case '_': + buf[bi] = ' ' + case '-': + buf[bi] = '=' + case '\\': + buf[bi] = '\\' + default: + return string(buf[:bi]), ErrorSyntax{"invalid escape sequence: undefined sequence"} + } + default: + buf[bi] = bs[i] + } + bi++ + } + + return string(buf[:bi]), nil +} diff --git a/header_test.go b/header_test.go new file mode 100644 index 0000000..b035078 --- /dev/null +++ b/header_test.go @@ -0,0 +1,200 @@ +package cnp + +import "testing" + +var escapes = map[string]string{ + "": ``, + "ContNet": `ContNet`, + " ": `\_`, + "=": `\-`, + "\n": `\n`, + "\x00": `\0`, + "\\": `\\`, + "a\nb c=d\x00e\\f": `a\nb\_c\-d\0e\\f`, + "\n\n\n": `\n\n\n`, + "===": `\-\-\-`, + " ": `\_\_\_`, + "\x00\x00\x00": `\0\0\0`, + "\\\\\\": `\\\\\\`, + " =\n\x00\\": `\_\-\n\0\\`, + "\b5Ὂg̀9! ℃ᾭG": "\b5Ὂg̀9!\\_℃ᾭG", + "\xff\x00\xee\xaa\xee": "\xff\\0\xee\xaa\xee", + "\x00\x10\x20\x30\x40": "\\0\x10\\_\x30\x40", + "\x10\x50\x90\xe0": "\x10\x50\x90\xe0", + "Hello, 世界": `Hello,\_世界`, + "\xed\x9f\xbf": "\xed\x9f\xbf", + "\xee\x80\x80": "\xee\x80\x80", + "\xef\xbf\xbd": "\xef\xbf\xbd", + "\x80\x80\x80\x80": "\x80\x80\x80\x80", +} + +func TestEscape(t *testing.T) { + for k, v := range escapes { + e := string(Escape(k)) + if e != v { + t.Errorf("Escape(%q) -> %q, expected %q", k, e, v) + } + } + for i := 0; i <= 255; i++ { + switch i { + case '\x00', ' ', '=', '\n', '\\': + continue + default: + s := string([]byte{byte(i)}) + b := Escape(s) + if s != string(b) { + t.Errorf("Escape(%q) -> %q, expected %q", s, b, s) + } + } + } +} + +func TestUnEscape(t *testing.T) { + for k, v := range escapes { + u, err := Unescape([]byte(v)) + if err != nil { + t.Errorf("Unescape(%q): error: %s", v, err) + } else if u != k { + t.Errorf("Unescape(%q) -> %q, expected %q", v, u, k) + } + } + for i := 0; i <= 255; i++ { + switch i { + case '\x00', ' ', '=', '\n', '\\': + continue + default: + b := []byte{byte(i)} + s, err := Unescape(b) + if err != nil { + t.Errorf("Unescape(%q): error: %s", b, err) + } else if string(b) != s { + t.Errorf("Escape(%q) -> %q, expected %q", b, s, b) + } + } + } +} + +var headers = map[string]struct { + h Header + e error +}{ + // invalid version + "cnp/0.3ok\n": {e: ErrorSyntax{}}, + "cwp/0.3 ok\n": {e: ErrorSyntax{}}, + "cnp/0.03 ok\n": {e: ErrorSyntax{}}, + "cnp/00.3 ok\n": {e: ErrorSyntax{}}, + "cnp/0..3 ok\n": {e: ErrorSyntax{}}, + "cnp/.3 ok\n": {e: ErrorSyntax{}}, + "cnp/0. ok\n": {e: ErrorSyntax{}}, + "cnp/. ok\n": {e: ErrorSyntax{}}, + "cnp/0,3 ok\n": {e: ErrorSyntax{}}, + "/0.3 ok\n": {e: ErrorSyntax{}}, + "0.3 ok\n": {e: ErrorSyntax{}}, + "cnp/ ok\n": {e: ErrorSyntax{}}, + "cnp ok\n": {e: ErrorSyntax{}}, + "cnp.0.3 ok\n": {e: ErrorSyntax{}}, + "cnp/03 ok\n": {e: ErrorSyntax{}}, + "cnp/3 ok\n": {e: ErrorSyntax{}}, + "cnp/0 ok\n": {e: ErrorSyntax{}}, + "cnp/0 3 ok\n": {e: ErrorSyntax{}}, + "cnp/0/3 ok\n": {e: ErrorSyntax{}}, + + // missing/invalid intent + "cnp/0.3\n": {e: ErrorSyntax{}}, + "cnp/0.3 \n": {e: ErrorSyntax{}}, + "cnp/0.3 o\x00k\n": {e: ErrorSyntax{}}, + "cnp/0.3 foo=bar\n": {e: ErrorSyntax{}}, + + // missing/invalid line end + "cnp/0.3 ok \n": {e: ErrorSyntax{}}, + "cnp/0.3 ok\n\n": {e: ErrorSyntax{}}, + "cnp/0.3 ok": {e: ErrorSyntax{}}, + "cnp/0.3 ok ": {e: ErrorSyntax{}}, + "cnp/0.3 ok foo=bar \n": {e: ErrorSyntax{}}, + "cnp/0.3 ok foo=bar": {e: ErrorSyntax{}}, + "cnp/0.3 ok = =\n": {e: ErrorSyntax{}}, + + // spaces + "cnp/0.3 ok foo=bar\n": {e: ErrorSyntax{}}, + "cnp/0.3 ok\n": {e: ErrorSyntax{}}, + "cnp/0.3\tok\n": {e: ErrorSyntax{}}, + + // invalid params + "cnp/0.3 ok foo==bar\n": {e: ErrorSyntax{}}, + "cnp/0.3 ok foo=bar=baz\n": {e: ErrorSyntax{}}, + "cnp/0.3 ok foo=bar baz=quux \n": {e: ErrorSyntax{}}, + "cnp/0.3 ok foo\\-bar\n": {e: ErrorSyntax{}}, + + // invalid escape sequences + "cnp/0.3 o\\k\n": {e: ErrorSyntax{}}, + "cnp/0.3 ok qwe=\\\n": {e: ErrorSyntax{}}, + "cnp/0.3 ok fo\\o=bar\n": {e: ErrorSyntax{}}, + + // valid + "cnp/0.0 ok\n": {h: Header{0, 0, "ok", nil}}, + "cnp/0.3 ok\n": {h: Header{0, 3, "ok", nil}}, + "cnp/1.0 ok\n": {h: Header{1, 0, "ok", nil}}, + "cnp/123456.987654 ok\n": {h: Header{123456, 987654, "ok", nil}}, + "cnp/0.1 ok\n": {h: Header{0, 1, "ok", nil}}, + "cnp/0.3 ok\r\n": {h: Header{0, 3, "ok\r", nil}}, + "cnp/0.3 foo\\nbar\n": {h: Header{0, 3, "foo\nbar", nil}}, + "cnp/0.3 \\-\\_\\n\\0\\\\\n": {h: Header{0, 3, "= \n\x00\\", nil}}, + + // valid with params + "cnp/0.3 ok type=text/plain\n": {h: Header{0, 3, "ok", Parameters{"type": "text/plain"}}}, + "cnp/0.3 ok baz=quux foo=bar qwe=asd\n": {h: Header{0, 3, "ok", Parameters{"foo": "bar", "baz": "quux", "qwe": "asd"}}}, + "cnp/0.3 ok = \\_=\\0 \\-=\\n\r\n": {h: Header{0, 3, "ok", Parameters{"": "", "=": "\n\r", " ": "\x00"}}}, +} + +func TestHeaderParse(t *testing.T) { + for raw, tst := range headers { + hdr, err := ParseHeader([]byte(raw)) + if err != nil || tst.e != nil { + if !errorEqual(err, tst.e) { + t.Errorf("ParseHeader(%q): expected error %+v, got %+v (%v)", raw, tst.e, err, err) + } + } else if !headerEqual(tst.h, hdr) { + t.Errorf("ParseHeader(%q): expected %+v, got %+v", raw, tst.h, hdr) + } + } +} + +func TestHeaderCompose(t *testing.T) { + for raw, tst := range headers { + if tst.e != nil { + continue + } + // Parameters.Write currently sorts parameter keys + /*if len(tst.h.Parameters) > 1 { + continue // can't depend on parameter order + }*/ + str := tst.h.String() + if raw != str { + t.Errorf("%+v.String(): expected %q, got %q", tst.h, raw, str) + } + } +} + +func TestNewHeader(t *testing.T) { + raw := "cnp/0.3 ok baz=quux foo=bar qwe=asd\n" + hdr := Header{0, 3, "ok", Parameters{"foo": "bar", "baz": "quux", "qwe": "asd"}} + h := NewHeader("ok", Parameters{"foo": "bar", "baz": "quux", "qwe": "asd"}) + if !headerEqual(hdr, h) { + t.Errorf("%+v: expected %+v", h, hdr) + } + s := h.String() + if raw != s { + t.Errorf("%q: expected %q", h, hdr) + } + + raw = "cnp/0.3 ok\n" + hdr = Header{0, 3, "ok", nil} + h = NewHeader("ok", nil) + if !headerEqual(hdr, h) { + t.Errorf("%+v: expected %+v", h, hdr) + } + s = h.String() + if raw != s { + t.Errorf("%q: expected %q", h, hdr) + } +} diff --git a/message.go b/message.go new file mode 100644 index 0000000..eb3d2b9 --- /dev/null +++ b/message.go @@ -0,0 +1,176 @@ +package cnp + +import ( + "bufio" + "bytes" + "io" + "io/ioutil" + "strings" +) + +// Message represents a CNP message. +type Message struct { + Header Header + Body io.Reader + closer io.Closer +} + +// NewMessage creates a new CNP message. +// +// This method also calls Message.TryComputeLength(). +func NewMessage(intent string, body io.Reader) *Message { + msg := &Message{ + Header: NewHeader(intent, nil), + Body: body, + } + + msg.TryComputeLength() + + return msg +} + +// ParseMessage parses a CNP message. +// +// The message's Body field is set to a bufio.Reader wrapping r. If r is a +// io.Closer, it is also stored separately for usage with Message.Close(). +func ParseMessage(r io.Reader) (*Message, error) { + br := bufio.NewReader(r) + + line, err := readLimitedLine(br, MaxHeaderLength) + if err != nil { + return nil, err + } + + h, err := ParseHeader(line) + if err != nil { + return nil, err + } + + msg := &Message{ + Header: h, + Body: br, + } + + if rc, ok := r.(io.Closer); ok { + msg.closer = rc + } + + return msg, nil +} + +// Close attempts to close the message body reader. +// +// If the message's Body field is an io.Closer, its Close() method is called. +// If Body is not an io.Closer and the message was created with ParseMessage +// provided with an io.Closer, the Close() method on the original reader will +// be called. +// Otherwise, this function does nothing. +func (msg *Message) Close() error { + if rc, ok := msg.Body.(io.Closer); ok { + return rc.Close() + } + if msg.closer != nil { + return msg.closer.Close() + } + return nil +} + +// ComputeLength sets the length header parameter based on the message body. +// First, msg.TryComputeLength() is attempted; if that fails, the request is +// fully read into a buffer and msg.Body is set to a bytes.Reader. +func (msg *Message) ComputeLength() error { + if !msg.TryComputeLength() { + buf, err := ioutil.ReadAll(msg.Body) + if len(buf) > 0 { + msg.Body = bytes.NewReader(buf) + msg.SetLength(int64(len(buf))) + } + if err != nil { + return err + } + } + + return nil +} + +// TryComputeLength sets the length header parameter to the length of the +// message body if it's one of *bytes.Buffer, *bytes.Reader or *strings.Reader +// and returns true. If msg.Body is nil, the length parameter is unset and the +// function returns true. Otherwise, false is returned. +func (msg *Message) TryComputeLength() bool { + switch v := msg.Body.(type) { + case *bytes.Buffer: + msg.SetLength(int64(v.Len())) + case *bytes.Reader: + msg.SetLength(int64(v.Len())) + case *strings.Reader: + msg.SetLength(int64(v.Len())) + case nil: + msg.SetLength(0) + default: + return false + } + return true +} + +// SetLength sets the length header parameter to n. +func (msg *Message) SetLength(n int64) { + if n == 0 { + msg.SetParam("length", "") + } else { + setInt(msg, "length", n) + } +} + +// Length gets the length header parameter (or 0 if it's not set or invalid). +func (msg *Message) Length() int64 { + n, err := getInt(msg, "length") + if err != nil { + return 0 + } + return n +} + +// Param retrieves a header parameter. +func (msg *Message) Param(key string) string { + return msg.Header.Parameters[key] +} + +// SetParam sets a header parameter. If the value is empty, the parameter is +// unset. +func (msg *Message) SetParam(key, value string) { + if len(value) == 0 { + delete(msg.Header.Parameters, key) + } else { + msg.Header.Parameters[key] = value + } +} + +// Intent retrieves the message header intent. +func (msg *Message) Intent() string { + return msg.Header.Intent +} + +// SetIntent sets the message header intent. +func (msg *Message) SetIntent(s string) { + msg.Header.Intent = s +} + +// Validate validates the message header parameter value format (length). +func (msg *Message) Validate() error { + _, err := getInt(msg, "length") + return err +} + +// Write writes the message header and body to w. +func (msg *Message) Write(w io.Writer) error { + if err := msg.Header.Write(w); err != nil { + return err + } + if msg.Body != nil { + if _, err := io.Copy(w, msg.Body); err != nil { + return err + } + } + return nil +} diff --git a/message_test.go b/message_test.go new file mode 100644 index 0000000..6955fc3 --- /dev/null +++ b/message_test.go @@ -0,0 +1,204 @@ +package cnp + +import ( + "bytes" + "io" + "strconv" + "strings" + "testing" +) + +var ( + messageTests = []messageTest{ + { + "qweasd", "text/plain", nil, nil, + "cnp/0.3 test1 length=6 type=text/plain\nqweasd", + &Message{ + Header: Header{ + VersionMajor: 0, + VersionMinor: 3, + Intent: "test1", + Parameters: Parameters{ + "type": "text/plain", + "length": "6", + }, + }, + Body: strings.NewReader("qweasd"), + }, + }, + + { + "qweasd", "text", nil, ErrorInvalid{}, + "cnp/0.3 test2 length=w type=text/plain\nqweasd", + &Message{ + Header: Header{ + VersionMajor: 0, + VersionMinor: 3, + Intent: "test2", + Parameters: Parameters{ + "type": "text/plain", + "length": "w", + }, + }, + Body: strings.NewReader("qweasd"), + }, + }, + + { + "", "text/plain", ErrorSyntax{}, nil, + "cnp/0.3 test3 type=text/plain", + nil, + }, + + { + "", "text/plain", nil, nil, + "cnp/0.3 test4 type=text/plain\n", + &Message{ + Header: Header{ + VersionMajor: 0, + VersionMinor: 3, + Intent: "test4", + Parameters: Parameters{ + "type": "text/plain", + }, + }, + Body: nil, + }, + }, + } +) + +type messageTest struct { + b, t string + p, v error + s string + m *Message +} + +func (tst *messageTest) Reset() { + if tst.m != nil && tst.b != "" { + tst.m.Body = strings.NewReader(tst.b) + } +} + +func TestParse(t *testing.T) { + for _, tst := range messageTests { + tst.Reset() + msg, err := ParseMessage(strings.NewReader(tst.s)) + if tst.p != nil || err != nil { + if !errorEqual(tst.p, err) { + t.Errorf("ParseMessage(%q): expected error %+v, got %+v (%v)", tst.s, tst.p, err, err) + continue + } + } else if !msgEqual(msg, tst.m) { + t.Errorf("\nexpected: %+v\ngot: %+v", tst.m, msg) + } else if l := msg.Length(); tst.v == nil && l != int64(len(tst.b)) { + t.Errorf("%+v: expected length %d, got %d", tst.m, len(tst.b), l) + } else if err = msg.Validate(); !errorEqual(err, tst.v) { + t.Errorf("%+v: expected validation error %+v, got %+v (%s)", msg, tst.v, err, err) + } + } +} + +func TestNew(t *testing.T) { + for i, tst := range messageTests { + if tst.v != nil || tst.p != nil { + continue // skip invalid messages + } + tst.Reset() + var r io.Reader + if tst.b != "" { + r = strings.NewReader(tst.b) + } + msg := NewMessage("test"+strconv.Itoa(i+1), r) + msg.SetParam("type", tst.t) + if !msgEqual(msg, tst.m) { + t.Errorf("\nexpected: %+v\ngot: %+v", tst.m, msg) + } + } +} + +func TestWrite(t *testing.T) { + for _, tst := range messageTests { + if tst.p != nil { + continue // skip invalid messages + } + tst.Reset() + var buf bytes.Buffer + err := tst.m.Write(&buf) + if err != nil { + t.Errorf("%+v: error writing message body: %s", tst.m, err) + } + if buf.String() != tst.s { + t.Errorf("%+v:\nexpected: %q\ngot: %q", tst.m, tst.s, buf.String()) + } + } +} + +func TestComputeLength(t *testing.T) { + s := "qweasd" + r := &testStringReader{"qweasd"} + msg := NewMessage("ok", r) + if _, ok := msg.Header.Parameters["length"]; msg.Length() != 0 || msg.Param("length") != "" || ok { + t.Fatalf("%+v: expected no length parameter", msg) + } + if err := msg.ComputeLength(); err != nil { + t.Fatalf("%+v.ComputeLength(): error: %s", msg, err) + } + if l := msg.Length(); l != int64(len(s)) { + t.Fatalf("%+v: invalid length %d, expected %d\n", msg, l, len(s)) + } + if _, ok := msg.Body.(*testStringReader); r.s != "" || ok || r == msg.Body { + t.Fatalf("%+v: did not buffer body correctly", msg) + } + if !bodyEqual(msg.Body, strings.NewReader(s)) { + t.Fatalf("%+v: incorrect buffered body", msg) + } +} + +func TestParseTooLarge(t *testing.T) { + s := "cnp/0.3 ok text=" + str := s + strings.Repeat(".", MaxHeaderLength-len(s)) + "\n" + _, err := ParseMessage(strings.NewReader(str)) + if _, ok := err.(ErrorTooLarge); !ok { + t.Errorf("\nexpected: ErrorTooLarge\ngot: %+v", err) + } +} + +type noopCloser struct { + r io.Reader + closed bool +} + +func (n *noopCloser) Close() error { + n.closed = true + return nil +} + +func (n *noopCloser) Read(b []byte) (int, error) { + return n.r.Read(b) +} + +func TestClose(t *testing.T) { + var r io.Reader = strings.NewReader("cnp/0.3 ok\nqweasd") + if msg, err := ParseMessage(r); err != nil { + t.Errorf("ParseMessage error: %v", err) + } else if err = msg.Close(); err != nil { + t.Errorf("Error closing message: %v", err) + } + r = &noopCloser{r: strings.NewReader("cnp/0.3 ok\nqweasd")} + if msg, err := ParseMessage(r); err != nil { + t.Errorf("ParseMessage error: %v", err) + } else if err = msg.Close(); err != nil { + t.Errorf("Error closing message: %v", err) + } else if !r.(*noopCloser).closed { + t.Errorf("Reader was not closed") + } + r = &noopCloser{r: strings.NewReader("qweasd")} + msg := NewMessage("ok", r) + if err := msg.Close(); err != nil { + t.Errorf("Error closing message: %v", err) + } else if !r.(*noopCloser).closed { + t.Errorf("Reader was not closed") + } +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..8c5a34e --- /dev/null +++ b/request.go @@ -0,0 +1,260 @@ +package cnp + +import ( + "bytes" + "errors" + "io" + "net/url" + "path" + "strconv" + "strings" + "time" +) + +// Request represents a CNP request message. +type Request struct { + Message +} + +// NewRequest creates a new Request from a host, path and optional body data. +func NewRequest(host, pth string, body []byte) (*Request, error) { + var r io.Reader + if body != nil { + r = bytes.NewReader(body) + } + req := &Request{*NewMessage("/", r)} + if err := req.SetHostPath(host, pth); err != nil { + return nil, err + } + return req, nil +} + +// NewRequestURL creates a new Request from a URL and body data. +func NewRequestURL(urlStr string, body []byte) (*Request, error) { + // XXX: handle //example.com/path URLs + if strings.HasPrefix(urlStr, "//") { + urlStr = "cnp:" + urlStr + } + + u, err := url.ParseRequestURI(urlStr) + if err != nil { + return nil, ErrorURL{err, urlStr} + } + + if u.Scheme != "cnp" && u.Scheme != "" { + return nil, ErrorURL{errors.New("NewRequestURL: URL is not a cnp:// URL"), urlStr} + } + if u.Opaque != "" { + return nil, ErrorURL{errors.New("NewRequestURL: CNP URL may not contain opaque data"), urlStr} + } + if u.User != nil { + return nil, ErrorURL{errors.New("NewRequestURL: CNP URL cannot may not contain userinfo"), urlStr} + } + + host := u.Hostname() + if strings.ContainsRune(host, ':') { // IPv6 + host = "[" + host + "]" + } + port := DefaultPort + if sp := u.Port(); sp != "" { + port, err = strconv.Atoi(sp) + if err != nil { + return nil, ErrorURL{err, urlStr} + } + } + if port != DefaultPort { + host = host + ":" + strconv.Itoa(port) + } + + pth := u.Path + if pth == "" { + pth = "/" + } + /*if u.RawQuery != "" { + q, err := url.QueryUnescape(u.RawQuery) + if err != nil { + return nil, ErrorURL{err, urlStr} + } + pth = pth + "?" + q + }*/ + + return NewRequest(host, pth, body) +} + +// ParseRequest parses a request message. +func ParseRequest(r io.Reader) (*Request, error) { + msg, err := ParseMessage(r) + if err != nil { + return nil, err + } + + if err = validateRequestIntent(msg.Intent()); err != nil { + return nil, err + } + + return &Request{*msg}, nil +} + +// SetHost sets the host part of the request intent, leaving path unchanged. +func (r *Request) SetHost(host string) error { + return r.SetHostPath(host, r.Path()) +} + +// SetPath sets the path part of the request intent, leaving host unchanged. +func (r *Request) SetPath(pth string) error { + return r.SetHostPath(r.Host(), pth) +} + +// SetHostPath sets the request intent. +func (r *Request) SetHostPath(host, pth string) error { + if len(pth) < 1 || pth[0] != '/' { + return ErrorInvalid{"invalid request: invalid path"} + } + if strings.ContainsRune(host, '/') { + return ErrorInvalid{"invalid request: invalid host"} + } + r.SetIntent(host + Clean(pth)) + return nil +} + +// Host returns the host part of the request intent. +func (r *Request) Host() string { + host, _ := r.HostPath() + return host +} + +// Path returns the path part of the request intent. +func (r *Request) Path() string { + _, pth := r.HostPath() + return pth +} + +// HostPath returns the host and path parts of the request intent. +func (r *Request) HostPath() (host string, pth string) { + ss := strings.SplitN(r.Intent(), "/", 2) + if len(ss) != 2 { + return "", "/" + } + host = ss[0] + pth = "/" + ss[1] + return +} + +// URL returns a cnp:// URL based on this request's intent. +func (r *Request) URL() *url.URL { + var u url.URL + u.Scheme = "cnp" + u.Host = r.Host() + u.Path = r.Path() + return &u +} + +// Name retrieves the name request parameter. +// +// If the name request parameter is not a valid filename, an empty string is +// returned. +func (r *Request) Name() string { + name, err := getFilename(&r.Message, "name") + if err != nil { + return "" + } + return name +} + +// SetName sets the name request parameter. +// +// An error is raised if the name includes characters not valid in a filename +// (slash, null byte). +func (r *Request) SetName(name string) error { + return setFilename(&r.Message, "name", name) +} + +// Type retrieves the type request parameter. +// +// If the type request parameter is invalid or empty, the default value +// "application/octet-stream" is returned. +func (r *Request) Type() string { + typ, _ := getType(&r.Message, "type") + return typ +} + +// SetType sets the type request parameter. +// +// An error is raised if typ is not a valid format for a MIME type. +func (r *Request) SetType(typ string) error { + return setType(&r.Message, "type", typ) +} + +// IfModified retrieves the if_modified request parameter. +// +// If the parameter isn't a valid RFC3339 timestamp, a zero time.Time is +// returned. +func (r *Request) IfModified() time.Time { + t, err := getTime(&r.Message, "if_modified") + if err != nil { + return time.Time{} + } + return t +} + +// SetIfModified sets the if_modified request parameter. +// +// If t is the zero time value, the if_modified parameter is unset. +func (r *Request) SetIfModified(t time.Time) { + setTime(&r.Message, "if_modified", t) +} + +// Validate validates the request header intent and parameter value format +// (length, name, type, if_modified) +func (r *Request) Validate() error { + if err := validateRequestIntent(r.Intent()); err != nil { + return err + } + if err := r.Message.Validate(); err != nil { + return err + } + if _, err := getFilename(&r.Message, "name"); err != nil { + return err + } + if _, err := getType(&r.Message, "type"); err != nil { + return err + } + if _, err := getTime(&r.Message, "if_modified"); err != nil { + return err + } + return nil +} + +func validateRequestIntent(intent string) error { + ss := strings.SplitN(intent, "/", 2) + if len(ss) != 2 { + return ErrorInvalid{"invalid request: invalid intent"} + } + host, pth := ss[0], ss[1] + + if strings.ContainsAny(host, "\x00 ") || strings.ContainsRune(pth, '\x00') { + return ErrorInvalid{"invalid request: invalid intent"} + } + + return nil +} + +// Write ensures that the request's length parameter is set if it has body and +// then writes it to w. +func (r *Request) Write(w io.Writer) error { + if _, ok := r.Header.Parameters["length"]; !ok { + if err := r.ComputeLength(); err != nil { + return err + } + } + return r.Message.Write(w) +} + +// Clean cleans a CNP intent path. +func Clean(s string) string { + c := path.Clean(s) + if len(s) > 0 && len(c) > 0 && s[len(s)-1] == '/' && c[len(c)-1] != '/' { + return c + "/" + } + return c +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..bb4aa9b --- /dev/null +++ b/request_test.go @@ -0,0 +1,341 @@ +package cnp + +import ( + "bytes" + "strconv" + "strings" + "testing" + "time" +) + +type requestTest struct { + h, s string + p Parameters + e, v error +} + +var requestTests = []requestTest{ + // invalid intent + {"", "", nil, ErrorInvalid{}, nil}, + {"", "foo/bar", nil, ErrorInvalid{}, nil}, + {"foo/", "bar", nil, ErrorInvalid{}, nil}, + {"foo/", "/bar", nil, ErrorInvalid{}, nil}, + + // invalid request params + {"", "/", Parameters{"length": "w"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"length": "-1"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"length": "03"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"name": "foo/bar"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"name": "foo\x00bar"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"type": "foo"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"type": "\x00"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"type": "text/plain\x00"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"type": "foo/bar "}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"type": " foo/bar"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"type": "foo /bar"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"type": "foo/ bar"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"type": "foo/bar\n"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"type": "foo/b(r"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "now"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "today"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "Thu Jan 1 00:00:00 UTC 1970"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "1970-01-01 00:00:00+00:00"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "1970-01-01 00:00:00 UTC"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "1970-01-01 00:00:00+0000"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "1970-01-01 00:00:00+00"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "1970-01-01T00:00:00+00:00"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "1970-01-01T00:00:00 UTC"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "1970-01-01T00:00:00+0000"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "1970-01-01T00:00:00+00"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0000-00-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0000-01-00T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0000-01-01T24:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0000-01-01T00:60:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0000-01-01T00:00:60Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0000-11-31T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0001-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0002-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0003-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0005-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "0100-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "1000-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "123-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "12345-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "-5-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"", "/", Parameters{"if_modified": "-2005-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + + // valid simple requests + {"", "/", nil, nil, nil}, + {"", "/foo/bar", nil, nil, nil}, + {"cnp.example.com", "/", nil, nil, nil}, + {"foo", "/bar", nil, nil, nil}, + {"example.com", "/ f=#\\oo///.././.../~/\x01/\xff/ba\nr", nil, nil, nil}, + + // valid request params + {"", "/", Parameters{"length": "", "name": "", "type": "", "if_modified": "", "": "", "q\x00we": "=a s\nd"}, nil, nil}, + {"", "/", Parameters{"length": "0"}, nil, nil}, + {"", "/", Parameters{"length": "1"}, nil, nil}, + {"", "/", Parameters{"length": "12345670089000000"}, nil, nil}, + {"", "/", Parameters{"name": "foobar"}, nil, nil}, + {"", "/", Parameters{"name": "x"}, nil, nil}, + {"", "/", Parameters{"name": "..-~!foo bar\nbaz\rquux=qwe\\asd"}, nil, nil}, + {"", "/", Parameters{"name": strings.Repeat("w", 1024*8)}, nil, nil}, + {"", "/", Parameters{"type": "qwe/asd"}, nil, nil}, + {"", "/", Parameters{"type": "qwe+asd/foo"}, nil, nil}, + {"", "/", Parameters{"type": "qwe-asd/foo"}, nil, nil}, + {"", "/", Parameters{"if_modified": "1970-01-01T00:00:00Z"}, nil, nil}, + {"", "/", Parameters{"if_modified": "0000-01-01T00:00:00Z"}, nil, nil}, + {"", "/", Parameters{"if_modified": "9999-12-31T23:59:59Z"}, nil, nil}, + {"", "/", Parameters{"if_modified": "0123-05-06T07:08:09Z"}, nil, nil}, + {"", "/", Parameters{"if_modified": "0000-02-29T00:00:00Z"}, nil, nil}, + {"", "/", Parameters{"if_modified": "2000-02-29T00:00:00Z"}, nil, nil}, +} + +func TestNewRequest(t *testing.T) { + for _, tst := range requestTests { + req, err := NewRequest(tst.h, tst.s, []byte{}) + if !errorEqual(tst.e, err) { + t.Errorf("NewRequest(%q, %q): expected error %+v, got %+v (%v)", tst.h, tst.s, tst.e, err, err) + continue + } + if tst.e == nil { + if tst.h != req.Host() { + t.Errorf("NewRequest(%q: %q): got unexpected host %q", tst.h, tst.s, req.Host()) + } else if Clean(tst.s) != req.Path() { + t.Errorf("NewRequest(%q: %q): got unexpected path %q", tst.h, tst.s, req.Path()) + } + } + } +} + +type requestURLTest struct { + u string + h, p string + e error +} + +var requestURLTests = []requestURLTest{ + // invalid + {"cnp:example.com/path/to/file", "", "", ErrorURL{}}, + {"http://example.com/path/to/file", "", "", ErrorURL{}}, + {"cnp://foo@bar:example.com/path/to/file", "", "", ErrorURL{}}, + {"cnp://foo:example.com/path/to/file", "", "", ErrorURL{}}, + {"", "", "/", ErrorURL{}}, + {"cnp://example.com/%5", "", "", ErrorURL{}}, + //{"cnp://example.com/?%5", "", "", ErrorURL{}}, + + // valid + {"cnp://example.com/path/to/file", "example.com", "/path/to/file", nil}, + {"cnp://2130706433/", "2130706433", "/", nil}, + //{"cnp://02130706433/", "2130706433", "/", nil}, + {"cnp://127.0.0.1/", "127.0.0.1", "/", nil}, + //{"cnp://0127.0.00.01/", "127.0.0.1", "/", nil}, + {"cnp://localhost", "localhost", "/", nil}, + {"cnp://[::1]/foo%20bar", "[::1]", "/foo bar", nil}, + {"cnp://[2001:db8::7334]/foo%0abar", "[2001:db8::7334]", "/foo\nbar", nil}, + //{"cnp://[2001:0db8:0000::7334]/foo%0abar", "[2001:db8::7334]", "/foo\nbar", nil}, + //{"cnp://example.com/qwe%20asd?foo=bar&baz=qu%20ux", "example.com", "/qwe asd?foo=bar&baz=qu ux", nil}, + {"cnp://example.com/qwe%20asd?foo=bar&baz=qu%20ux", "example.com", "/qwe asd", nil}, + {"cnp://localhost:25454", "localhost", "/", nil}, + {"cnp://localhost:12345/foo/bar", "localhost:12345", "/foo/bar", nil}, + {"cnp://localhost:25454/foo/bar", "localhost", "/foo/bar", nil}, + {"cnp://2130706433:12345/foo/bar", "2130706433:12345", "/foo/bar", nil}, + {"cnp://2130706433:25454/foo/bar", "2130706433", "/foo/bar", nil}, + {"cnp://127.0.0.1:12345/foo/bar", "127.0.0.1:12345", "/foo/bar", nil}, + {"cnp://127.0.0.1:25454/foo/bar", "127.0.0.1", "/foo/bar", nil}, + {"cnp://[::1]:12345/foo/bar", "[::1]:12345", "/foo/bar", nil}, + {"cnp://[::1]:25454/foo/bar", "[::1]", "/foo/bar", nil}, + {"//example.com/path/", "example.com", "/path/", nil}, + {"/example.com/path/", "", "/example.com/path/", nil}, + {"/foo/bar", "", "/foo/bar", nil}, + {"/", "", "/", nil}, + {"cnp://example.com/ ™☺)\n", "example.com", "/ ™☺)\n", nil}, + {"cnp://œ¤å₥¶ḹə.©°ɱ/baz/../foo/.//bar////.//../bar//", "œ¤å₥¶ḹə.©°ɱ", "/foo/bar/", nil}, + {"cnp://foo/bar/", "foo", "/bar/", nil}, + {"cnp:///foo/bar/", "", "/foo/bar/", nil}, + {"cnp:////////foo/bar/", "", "/foo/bar/", nil}, +} + +func TestNewRequestURL(t *testing.T) { + for _, tst := range requestURLTests { + req, err := NewRequestURL(tst.u, nil) + if !errorEqual(tst.e, err) { + t.Errorf("NewRequestURL(%q, nil): expected error %+v, got %+v (%v)", tst.u, tst.e, err, err) + continue + } + if tst.e == nil { + if req.Host() != tst.h { + t.Errorf("NewRequestURL(%q, nil): expected host %q, got %q", tst.u, tst.h, req.Host()) + } else if req.Path() != tst.p { + t.Errorf("NewRequestURL(%q, nil): expected path %q, got %q", tst.u, tst.p, req.Path()) + } + } + } +} + +func TestWriteParseRequest(t *testing.T) { + for _, tst := range requestTests { + if tst.e != nil { + continue // skip syntax errors but not invalid params + } + req := &Request{Message{Header{0, 3, "/", Parameters{}}, nil, nil}} + if err := req.SetHost(tst.h); err != nil { + t.Errorf("SetHost(%q) error: %v", tst.h, err) + continue + } + if err := req.SetPath(tst.s); err != nil { + t.Errorf("SetPath(%q) error: %v", tst.s, err) + continue + } + for k, v := range tst.p { + req.SetParam(k, v) + } + + var buf bytes.Buffer + if err := req.Write(&buf); err != nil { + t.Errorf("%+v.Write: error %v", req, err) + continue + } + + s := buf.String() + req2, err := ParseRequest(strings.NewReader(s)) + if err != nil { + t.Errorf("ParseRequest(%q) error: %v", s, err) + continue + } + if !msgEqual(&req.Message, &req2.Message) { + t.Errorf("Write/Parse: expected %+v, got %+v", req, req2) + continue + } + } +} + +func TestRequestValidate(t *testing.T) { + for _, tst := range requestTests { + if tst.e != nil { + continue + } + req := &Request{Message{Header{Intent: tst.h + tst.s, Parameters: tst.p}, nil, nil}} + if err := req.Validate(); !errorEqual(tst.v, err) { + t.Errorf("%+v.Validate(): expected error %+v, got %+v (%v)", req, tst.v, err, err) + } + } +} + +func TestRequestGetSet(t *testing.T) { + e := func(k, v string, expect, err error) { + if !errorEqual(err, expect) { + t.Errorf("setting param %q to %q: expected %+v, got %+v (%v)", k, v, expect, err, err) + } + } + c := func(k, v string, expect error, p string) { + if expect == nil && v != p { + t.Errorf("getting param %q: expected %q, got %q", k, v, p) + } + } + + for _, tst := range requestTests { + if tst.e != nil { + continue + } + req, _ := NewRequest(tst.h, tst.s, nil) + for k, v := range tst.p { + switch k { + case "length": + n, _ := strconv.ParseUint(v, 10, 63) + req.SetLength(int64(n)) + s := strconv.FormatInt(req.Length(), 10) + if s == "0" && v == "" { + s = "" + } + c(k, v, tst.v, s) + case "name": + e(k, v, tst.v, req.SetName(v)) + c(k, v, tst.v, req.Name()) + case "type": + e(k, v, tst.v, req.SetType(v)) + t := req.Type() + if t == "application/octet-stream" && v == "" { + t = "" + } + c(k, v, tst.v, t) + case "if_modified": + if tst.v != nil { + continue // invalid time format, can't even parse + } + var tm time.Time + if v != "" { + tm, _ = time.Parse(time.RFC3339, v) + } + req.SetIfModified(tm) + tm2 := "" + if !req.IfModified().IsZero() { + tm2 = req.IfModified().Format(time.RFC3339) + } + c(k, v, tst.v, tm2) + default: + req.SetParam(k, v) + c(k, v, tst.v, req.Param(k)) + } + } + } +} + +var urlTests = map[string]string{ + "œ¤å₥¶ḹə.©°ɱ/foo/bar": "cnp://%C5%93%C2%A4%C3%A5%E2%82%A5%C2%B6%E1%B8%B9%C9%99.%C2%A9%C2%B0%C9%B1/foo/bar", + "2130706433/": "cnp://2130706433/", + "example.com/path/to/file/": "cnp://example.com/path/to/file/", + "[::1]/foo bar": "cnp://[::1]/foo%20bar", + "[2001:db8::7334]/foo\nbar": "cnp://[2001:db8::7334]/foo%0Abar", + "example.com/qwe asd": "cnp://example.com/qwe%20asd", + "example.com/ ™☺)\n": "cnp://example.com/%20%E2%84%A2%E2%98%BA%29%0A", +} + +func TestRequestURL(t *testing.T) { + for k, v := range urlTests { + t.Run(v, func(t *testing.T) { + req, err := NewRequestURL(v, nil) + if err != nil { + t.Fatalf("NewRequestURL(%q): error: %v", v, err) + } + if req.Intent() != k { + t.Fatalf("NewRequestURL(%q).Intent(): expected %q, got %q", v, k, req.Intent()) + } + if req.URL().String() != v { + t.Fatalf("NewRequestURL(%q).URL(): expected %q, got %q", v, v, req.URL()) + } + }) + } +} + +var cleanTests = map[string]string{ + "/../.././//foo/bar/..": "/foo", + "//foo/bar/../": "/foo/", + "/": "/", + "//": "/", + "/..": "/", + "/../..": "/", + "/../": "/", + "/.": "/", + "/./": "/", + "/./.": "/", + "/././": "/", + "/./../": "/", + "/.././": "/", + "/foo/bar/../baz": "/foo/baz", + "/foo/bar/../baz/": "/foo/baz/", + "/foo/../foo/bar/../bar/baz/quux/../": "/foo/bar/baz/", + "/foo/../foo/bar/../bar/baz/quux/..": "/foo/bar/baz", +} + +func TestClean(t *testing.T) { + for k, v := range cleanTests { + t.Run(k, func(t *testing.T) { + c := Clean(k) + if c != v { + t.Fatalf("Clean(%q): expected %q, got %q", k, v, c) + } + }) + } +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..bf23f20 --- /dev/null +++ b/response.go @@ -0,0 +1,249 @@ +package cnp + +import ( + "bytes" + "io" + "strings" + "time" +) + +// Response represents a CNP response message. +type Response struct { + Message +} + +// NewResponse creates a new Response from a response intent and optional body +// data. +func NewResponse(intent string, body []byte) (resp *Response, err error) { + var r io.Reader + if body != nil { + r = bytes.NewReader(body) + } + resp = &Response{*NewMessage("", r)} + err = resp.SetResponseIntent(intent) + return +} + +var responseIntents = map[string]bool{ + IntentOK: true, + IntentNotModified: true, + IntentRedirect: true, + IntentError: true, +} + +// ParseResponse parses a response message. +func ParseResponse(r io.Reader) (*Response, error) { + msg, err := ParseMessage(r) + if msg == nil { + return nil, err + } + return &Response{*msg}, err +} + +// ResponseIntent retrieves the response intent. +// If the intent is unknown, the "error" intent is returned. +func (r *Response) ResponseIntent() string { + s := r.Intent() + if !responseIntents[s] { + return IntentError + } + return s +} + +// SetResponseIntent sets the response intent. +// If the provided intent is not a known response intent, an error is returned +// and the intent is set to "error". +func (r *Response) SetResponseIntent(intent string) error { + if !responseIntents[intent] { + r.SetIntent(IntentError) + return ErrorInvalid{"invalid response: unknown response intent"} + } + r.SetIntent(intent) + return nil +} + +// Name retrieves the name response parameter. +// +// If the name response parameter is not a valid filename, an empty string is +// returned. +func (r *Response) Name() string { + name, err := getFilename(&r.Message, "name") + if err != nil { + return "" + } + return name +} + +// SetName sets the name response parameter. +// +// An error is raised if the name includes characters not valid in a filename +// (slash, null byte). +func (r *Response) SetName(name string) error { + return setFilename(&r.Message, "name", name) +} + +// Type retrieves the type response parameter. +// +// If the type response parameter is invalid or empty, the default value +// "application/octet-stream" is returned. +func (r *Response) Type() string { + typ, _ := getType(&r.Message, "type") + return typ +} + +// SetType sets the type response parameter. +// +// An error is raised if typ is not a valid format for a MIME type. +func (r *Response) SetType(typ string) error { + return setType(&r.Message, "type", typ) +} + +// Time retrieves the time response parameter. +// +// If the parameter isn't a valid RFC3339 timestamp, a zero time.Time is +// returned. +func (r *Response) Time() time.Time { + t, err := getTime(&r.Message, "time") + if err != nil { + return time.Time{} + } + return t +} + +// SetTime sets the time response parameter. +// +// If t is the zero time value, the time parameter is unset. +func (r *Response) SetTime(t time.Time) { + setTime(&r.Message, "time", t) +} + +// Modified retrieves the modified Response parameter. +// +// If the parameter isn't a valid RFC3339 timestamp, a zero time.Time is +// returned. +func (r *Response) Modified() time.Time { + t, err := getTime(&r.Message, "modified") + if err != nil { + return time.Time{} + } + return t +} + +// SetModified sets the modified response parameter. +// +// If the time response parameter is empty, it's set to the current time. +// If t is the zero time value, the modified parameter is unset. +func (r *Response) SetModified(t time.Time) { + setTime(&r.Message, "modified", t) + if r.Time().IsZero() { + r.SetTime(time.Now()) + } +} + +// Location retrieves the host and path from the location response +// parameter. +// +// If the location parameter is empty, it returns empty host and path. If the +// location parameter is invalid, an error is returned. +func (r *Response) Location() (host, path string, err error) { + l := r.Param("location") + if l == "" { + return "", "", nil + } + if err := validateRequestIntent(l); err != nil { + return "", "", err + } + ss := strings.SplitN(l, "/", 2) + return ss[0], "/" + ss[1], nil +} + +// SetLocation sets the location response parameter to host and path. +// +// If the host or path are invalid +func (r *Response) SetLocation(host, path string) error { + if strings.ContainsRune(host, '/') { + return ErrorInvalid{"invalid response: invalid location parameter"} + } + l := host + path + if err := validateRequestIntent(l); err != nil { + return ErrorInvalid{"invalid response: invalid location parameter"} + } + r.SetParam("location", l) + return nil +} + +var responseErrorReasons = map[string]bool{ + ReasonSyntax: true, + ReasonVersion: true, + ReasonInvalid: true, + ReasonNotSupported: true, + ReasonTooLarge: true, + ReasonNotFound: true, + ReasonDenied: true, + ReasonRejected: true, + ReasonServerError: true, + "": true, +} + +// Reason retrieves the reason response parameter. +// +// If the reason is nonempty and unknown, "server_error" is returned. +func (r *Response) Reason() string { + reason := r.Param("reason") + if _, ok := responseErrorReasons[reason]; ok { + return reason + } + return ReasonServerError +} + +// SetReason sets the reason response parameter. +// +// If the reason is nonempty and unknown, it sets "server_error" instead. +func (r *Response) SetReason(reason string) error { + if _, ok := responseErrorReasons[reason]; !ok { + r.SetParam("reason", ReasonServerError) + return ErrorInvalid{"invalid response: unknown error reason"} + } + r.SetParam("reason", reason) + return nil +} + +// Validate validates the response intent and header parameter value format +// (length, name, type, time, modified, location, reason) +func (r *Response) Validate() error { + if !responseIntents[r.Intent()] { + return ErrorInvalid{"invalid response: unknown response intent"} + } + if err := r.Message.Validate(); err != nil { + return err + } + if _, err := getFilename(&r.Message, "name"); err != nil { + return err + } + if _, err := getType(&r.Message, "type"); err != nil { + return err + } + if _, err := getTime(&r.Message, "time"); err != nil { + return err + } + if _, err := getTime(&r.Message, "modified"); err != nil { + return err + } + if h, p, err := r.Location(); err != nil { + return err + } else if r.ResponseIntent() == IntentRedirect && h == "" && p == "" { + return ErrorInvalid{"invalid response: redirect response needs location parameter"} + } + if !responseErrorReasons[r.Param("reason")] { + return ErrorInvalid{"invalid response: unknown error reason"} + } + return nil +} + +// Write writes the response to w. +func (r *Response) Write(w io.Writer) error { + if _, ok := r.Header.Parameters["length"]; !ok { + r.TryComputeLength() + } + return r.Message.Write(w) +} diff --git a/response_test.go b/response_test.go new file mode 100644 index 0000000..8033335 --- /dev/null +++ b/response_test.go @@ -0,0 +1,317 @@ +package cnp + +import ( + "bytes" + "strconv" + "strings" + "testing" + "time" +) + +var ( + responseTests = []responseTest{ + // invalid intent + {"", nil, ErrorInvalid{}, ErrorInvalid{}}, + {"foo", nil, ErrorInvalid{}, ErrorInvalid{}}, + {"example.com/path", nil, ErrorInvalid{}, ErrorInvalid{}}, + {"errors", nil, ErrorInvalid{}, ErrorInvalid{}}, + {" ok", nil, ErrorInvalid{}, ErrorInvalid{}}, + {"ok ", nil, ErrorInvalid{}, ErrorInvalid{}}, + {"оk", nil, ErrorInvalid{}, ErrorInvalid{}}, + + // invalid response params + {"ok", Parameters{"length": "w"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"length": "-1"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"length": "03"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"name": "/.."}, nil, ErrorInvalid{}}, + {"ok", Parameters{"name": "/"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"name": "foo/bar"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"name": "foo/bar"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"name": "foo\x00bar"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"type": "foo"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"type": "\x00"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"type": "text/plain\x00"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"type": "foo/bar "}, nil, ErrorInvalid{}}, + {"ok", Parameters{"type": " foo/bar"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"type": "foo /bar"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"type": "foo/ bar"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"type": "foo/bar\n"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"type": "foo/b(r"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "now"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "today"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "Thu Jan 1 00:00:00 UTC 1970"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "1970-01-01 00:00:00+00:00"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "1970-01-01 00:00:00 UTC"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "1970-01-01 00:00:00+0000"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "1970-01-01 00:00:00+00"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "1970-01-01T00:00:00+00:00"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "1970-01-01T00:00:00 UTC"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "1970-01-01T00:00:00+0000"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "1970-01-01T00:00:00+00"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0000-00-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0000-01-00T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0000-01-01T24:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0000-01-01T00:60:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0000-01-01T00:00:60Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0000-11-31T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0001-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0002-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0003-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0005-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "0100-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "1000-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "123-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "12345-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "-5-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"time": "-2005-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "now"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "today"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "Thu Jan 1 00:00:00 UTC 1970"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "1970-01-01 00:00:00+00:00"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "1970-01-01 00:00:00 UTC"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "1970-01-01 00:00:00+0000"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "1970-01-01 00:00:00+00"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "1970-01-01T00:00:00+00:00"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "1970-01-01T00:00:00 UTC"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "1970-01-01T00:00:00+0000"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "1970-01-01T00:00:00+00"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0000-00-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0000-01-00T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0000-01-01T24:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0000-01-01T00:60:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0000-01-01T00:00:60Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0000-11-31T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0001-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0002-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0003-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0005-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "0100-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "1000-02-29T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "123-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "12345-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "-5-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"ok", Parameters{"modified": "-2005-01-01T00:00:00Z"}, nil, ErrorInvalid{}}, + {"redirect", Parameters{"location": "foo"}, nil, ErrorInvalid{}}, + {"redirect", Parameters{"location": "foo bar/baz quux"}, nil, ErrorInvalid{}}, + {"redirect", Parameters{"location": "/foo\x00bar"}, nil, ErrorInvalid{}}, + {"error", Parameters{"reason": "ok"}, nil, ErrorInvalid{}}, + {"error", Parameters{"reason": "not supported"}, nil, ErrorInvalid{}}, + {"error", Parameters{"reason": "syntax\n"}, nil, ErrorInvalid{}}, + {"error", Parameters{"reason": " server_error"}, nil, ErrorInvalid{}}, + {"error", Parameters{"reason": "invalid "}, nil, ErrorInvalid{}}, + + // invalid: redirect *requires* the location parameter + {"redirect", nil, nil, ErrorInvalid{}}, + {"redirect", Parameters{"location": ""}, nil, ErrorInvalid{}}, + + // valid simple responses + {"ok", nil, nil, nil}, + {"not_modified", nil, nil, nil}, + {"error", nil, nil, nil}, + + // valid responses with parameters + {"ok", Parameters{"length": "", "name": "", "type": "", "time": "", "reason": "", "location": "", "modified": "", "": "", "q\x00we": "=a s\nd"}, nil, nil}, + + {"ok", Parameters{"length": "0"}, nil, nil}, + {"ok", Parameters{"length": "1"}, nil, nil}, + {"ok", Parameters{"length": "12345670089000000"}, nil, nil}, + {"ok", Parameters{"length": "12345678900"}, nil, nil}, + + {"ok", Parameters{"name": "foobar"}, nil, nil}, + {"ok", Parameters{"name": "foo bar"}, nil, nil}, + {"ok", Parameters{"name": "foo=bar\nbaz\rquux"}, nil, nil}, + {"ok", Parameters{"name": "..-~!foo bar\nbaz\rquux=qwe\\asd"}, nil, nil}, + {"ok", Parameters{"name": strings.Repeat("w", 1024*8)}, nil, nil}, + {"ok", Parameters{"name": " "}, nil, nil}, + {"ok", Parameters{"name": ".."}, nil, nil}, + {"ok", Parameters{"name": "."}, nil, nil}, + + {"ok", Parameters{"type": "foo/bar"}, nil, nil}, + {"ok", Parameters{"type": "application/octet-stream"}, nil, nil}, + {"ok", Parameters{"type": "x-test/x-testing"}, nil, nil}, + {"ok", Parameters{"type": "application/vnd.testing.test-test.5+xml"}, nil, nil}, + + {"ok", Parameters{"modified": "1970-01-01T00:00:00Z"}, nil, nil}, + {"ok", Parameters{"modified": "0000-01-01T00:00:00Z"}, nil, nil}, + {"ok", Parameters{"modified": "9999-12-31T23:59:59Z"}, nil, nil}, + {"ok", Parameters{"modified": "0123-05-06T07:08:09Z"}, nil, nil}, + {"ok", Parameters{"modified": "0000-02-29T00:00:00Z"}, nil, nil}, + {"ok", Parameters{"modified": "2000-02-29T00:00:00Z"}, nil, nil}, + + {"ok", Parameters{"time": "1970-01-01T00:00:00Z"}, nil, nil}, + {"ok", Parameters{"time": "0000-01-01T00:00:00Z"}, nil, nil}, + {"ok", Parameters{"time": "9999-12-31T23:59:59Z"}, nil, nil}, + {"ok", Parameters{"time": "0123-05-06T07:08:09Z"}, nil, nil}, + {"ok", Parameters{"time": "0000-02-29T00:00:00Z"}, nil, nil}, + {"ok", Parameters{"time": "2000-02-29T00:00:00Z"}, nil, nil}, + + {"error", Parameters{"reason": "syntax"}, nil, nil}, + {"error", Parameters{"reason": "version"}, nil, nil}, + {"error", Parameters{"reason": "invalid"}, nil, nil}, + {"error", Parameters{"reason": "not_supported"}, nil, nil}, + {"error", Parameters{"reason": "too_large"}, nil, nil}, + {"error", Parameters{"reason": "not_found"}, nil, nil}, + {"error", Parameters{"reason": "denied"}, nil, nil}, + {"error", Parameters{"reason": "rejected"}, nil, nil}, + {"error", Parameters{"reason": "server_error"}, nil, nil}, + + {"redirect", Parameters{"location": "/"}, nil, nil}, + {"redirect", Parameters{"location": "foo/bar"}, nil, nil}, + {"redirect", Parameters{"location": "foo/"}, nil, nil}, + {"redirect", Parameters{"location": "/bar"}, nil, nil}, + {"redirect", Parameters{"location": "[::1]:12345/ foo\n\x01\xff/"}, nil, nil}, + {"redirect", Parameters{"location": "/../../////././.."}, nil, nil}, + } +) + +type responseTest struct { + i string + p Parameters + e, v error +} + +func TestNewResponse(t *testing.T) { + for _, tst := range responseTests { + resp, err := NewResponse(tst.i, []byte{}) + if !errorEqual(tst.e, err) { + t.Errorf("NewResponse(%q): expected error %+v, got %+v (%v)", tst.i, tst.e, err, err) + continue + } + if tst.e == nil { + if tst.i != resp.ResponseIntent() { + t.Errorf("NewResponse(%q): got unexpected intent %q", tst.i, resp.ResponseIntent()) + } + } + } +} + +func TestWriteParseResponse(t *testing.T) { + for _, tst := range responseTests { + if tst.e != nil { + continue // skip invalid intents + } + resp := &Response{Message{Header{0, 3, "ok", Parameters{}}, nil, nil}} + if err := resp.SetResponseIntent(tst.i); err != nil { + t.Errorf("SetResponseIntent(%q) error: %v", tst.i, err) + continue + } + for k, v := range tst.p { + resp.SetParam(k, v) + } + + var buf bytes.Buffer + if err := resp.Write(&buf); err != nil { + t.Errorf("%+v.Write: error %v", resp, err) + continue + } + + s := buf.String() + resp2, err := ParseResponse(strings.NewReader(s)) + if err != nil { + t.Errorf("ParseResponse(%q) error: %v", s, err) + continue + } + if !msgEqual(&resp.Message, &resp2.Message) { + t.Errorf("Write/Parse: expected %+v, got %+v", resp, resp2) + continue + } + } +} + +func TestResponseValidate(t *testing.T) { + for _, tst := range responseTests { + resp := &Response{Message{Header{Intent: tst.i, Parameters: tst.p}, nil, nil}} + if err := resp.Validate(); !errorEqual(tst.v, err) { + t.Errorf("%+v.Validate(): expected error %+v, got %+v (%v)", resp, tst.v, err, err) + } + } +} + +func TestResponseGetSet(t *testing.T) { + e := func(k, v string, expect, err error) { + if !errorEqual(err, expect) { + t.Errorf("setting param %q to %q: expected %+v, got %+v (%v)", k, v, expect, err, err) + } + } + c := func(k, v string, expect error, p string) { + if expect == nil && v != p { + t.Errorf("getting param %q: expected %q, got %q", k, v, p) + } + } + + for _, tst := range responseTests { + resp, _ := NewResponse("ok", nil) + if err := resp.SetResponseIntent(tst.i); !errorEqual(err, tst.e) { + t.Errorf("setting response intent to %q: expected %+v, got %+v (%v)", tst.i, tst.e, err, err) + continue + } + if tst.e == nil && resp.ResponseIntent() != tst.i { + t.Errorf("getting response intent: expected %q, got %q", tst.i, resp.ResponseIntent()) + continue + } + for k, v := range tst.p { + switch k { + case "length": + n, _ := strconv.ParseUint(v, 10, 63) + resp.SetLength(int64(n)) + s := strconv.FormatInt(resp.Length(), 10) + if s == "0" && v == "" { + s = "" + } + c(k, v, tst.v, s) + case "name": + e(k, v, tst.v, resp.SetName(v)) + c(k, v, tst.v, resp.Name()) + case "type": + e(k, v, tst.v, resp.SetType(v)) + t := resp.Type() + if t == "application/octet-stream" && v == "" { + t = "" + } + c(k, v, tst.v, t) + case "time": + var tm time.Time + if v != "" { + tm, _ = time.Parse(time.RFC3339, v) + } + resp.SetTime(tm) + tm2 := "" + if !resp.Time().IsZero() { + tm2 = resp.Time().Format(time.RFC3339) + } + c(k, v, tst.v, tm2) + case "modified": + var tm time.Time + if v != "" { + tm, _ = time.Parse(time.RFC3339, v) + } + resp.SetModified(tm) + tm2 := "" + if !resp.Modified().IsZero() { + tm2 = resp.Modified().Format(time.RFC3339) + } + c(k, v, tst.v, tm2) + case "location": + ss := strings.SplitN(v, "/", 2) + if len(ss) != 2 { + continue // invalid location + } + e(k, v, tst.v, resp.SetLocation(ss[0], "/"+ss[1])) + host, path, err := resp.Location() + if err != nil { + t.Errorf("getting parameter location: error %v", err) + } else { + c(k, v, tst.v, host+path) + } + case "reason": + e(k, v, tst.v, resp.SetReason(v)) + c(k, v, tst.v, resp.Reason()) + default: + resp.SetParam(k, v) + c(k, v, tst.v, resp.Param(k)) + } + } + } +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..3413923 --- /dev/null +++ b/server.go @@ -0,0 +1,257 @@ +package cnp + +import ( + "bytes" + "errors" + "io" + "log" + "net" + "os" + "strconv" + "strings" + "time" +) + +// TODO: make more modular and extensible like net/http + +// Server represents a CNP server. +type Server struct { + // LogAccess is called for every finished response if it's non-nil. + LogAccess func(resp ResponseWriter, req *Request, respIntent string, respBytes int64) + + // LogError is called when an error happens if it's non-nil. + LogError func(err interface{}) + + // AccessLog is used to log finished responses if LogAccess is nil. + AccessLogger *log.Logger + + // ErrorLogger is used to log errors if LogError is nil. + ErrorLogger *log.Logger + + // Address is the host:port that the server listens on. + Address string + + // Handler is used to handle received requests. + Handler Handler + + // Validate enables request parameter value validation; invalid requests + // are responded with errors. + Validate bool + + sock net.Conn +} + +// NewServer creates a new Server with default access and errors logs and sets +// the listen address to "localhost". +func NewServer() *Server { + return &Server{ + AccessLogger: log.New(os.Stdout, "", 0), + ErrorLogger: log.New(os.Stderr, "error: ", log.LstdFlags), + Validate: true, + Address: "localhost", + } +} + +// ListenAndServe uses net.Listen to listen on TCP srv.Address for new +// requests, then calls srv.Serve. +func (srv *Server) ListenAndServe() error { + addr := srv.Address + if strings.LastIndexByte(addr, ':') <= strings.LastIndexByte(addr, ']') || + strings.Count(addr, ":") > 1 && !strings.HasPrefix(addr, "[") { // missing/default port + addr = net.JoinHostPort(addr, strconv.Itoa(DefaultPort)) + } + l, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return srv.Serve(l) +} + +// Serve listens on l for new connections and dispatches HandleConn goroutines. +func (srv *Server) Serve(l net.Listener) error { + defer l.Close() + + for { + conn, err := l.Accept() + if err != nil { + return err + } + go srv.HandleConn(conn) + } +} + +func (srv *Server) sendError(conn net.Conn, req *Request, err Error) { + er, _ := NewResponse(IntentError, []byte(err.Error()+"\n")) + er.SetParam("reason", err.CNPError()) + er.SetParam("type", "text/plain") + var buf bytes.Buffer + er.Write(&buf) + er.SetLength(int64(buf.Len())) + l, e2 := io.Copy(conn, &buf) + if e2 != nil { + srv.logError(e2) + } + srv.logAccess(&responseWriter{addr: conn.RemoteAddr()}, req, er.Intent(), l) +} + +// HandleConn reads a CNP request from conn and runs a handler to respond. +func (srv *Server) HandleConn(conn net.Conn) { + var rw *responseWriter + var req *Request + + defer func() { + /*_, err := io.Copy(ioutil.Discard, req.Body) + if err != nil { + srv.ErrorLog.Print(err) + }*/ + + if rw != nil && rw.headerWritten { + srv.logAccess(rw, req, rw.resp.Header.Intent, rw.n) + } + + if rec := recover(); rec != nil { + srv.logError(rec) + if err, ok := rec.(Error); ok && rw != nil && !rw.headerWritten { + srv.sendError(conn, req, err) + } + } + + if req != nil { + req.Close() + } + }() + + req, err := ParseRequest(conn) + req.Body = io.LimitReader(req.Body, req.Length()) + if err != nil { + if e, ok := err.(Error); ok { + resp, _ := NewResponse(IntentError, nil) + resp.SetParam("reason", e.CNPError()) + resp.Write(conn) + return + } + panic(err) + } + + if srv.Validate { + err = req.Validate() + if err != nil { + srv.sendError(conn, req, err.(Error)) + return + } + } + + if srv.Handler != nil { + resp, _ := NewResponse(IntentOK, nil) + rw = &responseWriter{ + w: conn, + resp: resp, + addr: conn.RemoteAddr(), + } + srv.Handler.ServeCNP(rw, req) + if !rw.headerWritten { + rw.WriteHeader() + } + } +} + +func (srv *Server) logAccess(resp ResponseWriter, req *Request, respIntent string, respBytes int64) { + if srv.LogAccess != nil { + srv.LogAccess(resp, req, respIntent, respBytes) + } else if srv.AccessLogger != nil { + srv.AccessLogger.Printf("%s - - %s %q %s %d", + resp.RemoteAddr(), + time.Now().Format("[02/Jan/2006:03:04:05 -0700]"), + req.Header.Version()+" "+string(Escape(req.Host()))+string(Escape(req.Path())), + respIntent, + respBytes, + ) + } +} + +func (srv *Server) logError(err interface{}) { + if srv.LogError != nil { + srv.LogError(err) + } else if srv.ErrorLogger != nil { + srv.ErrorLogger.Println(err) + } +} + +// Handler handles CNP requests accepted by the server. +type Handler interface { + // ServeCNP responds to a CNP request. + // + // This function must be safe for concurrent use. + ServeCNP(resp ResponseWriter, req *Request) +} + +// HandlerFunc allows using raw functions as handlers. +type HandlerFunc func(resp ResponseWriter, req *Request) + +// ServeCNP calls h(resp, req). +func (h HandlerFunc) ServeCNP(resp ResponseWriter, req *Request) { + h(resp, req) +} + +// ResponseWriter is used by a CNP server to write responses to CNP requests. +type ResponseWriter interface { + // Response returns a Response object whose header will be written to the + // socket. The body is nil and should be ignored. + Response() *Response + + // RemoteAddr returns the network address of the client. + RemoteAddr() net.Addr + + // WriteHeader sends a CNP header with an intent. + WriteHeader() error + + // Write sends data in the CNP response body. + // + // If WriteHeader has not been called yet, it also calls WriteHeader("ok"). + Write(data []byte) (int, error) +} + +type responseWriter struct { + w io.Writer + addr net.Addr + resp *Response + srv *Server + n int64 + headerWritten bool +} + +func (r *responseWriter) Response() *Response { + return r.resp +} + +func (r *responseWriter) RemoteAddr() net.Addr { + return r.addr +} + +func (r *responseWriter) WriteHeader() error { + if r.headerWritten { + err := errors.New("multiple WriteHeader calls") + r.srv.logError(err) + return err + } + r.headerWritten = true + return r.resp.Header.Write(r) +} + +func (r *responseWriter) Write(data []byte) (int, error) { + if !r.headerWritten { + r.WriteHeader() + } + n, err := r.w.Write(data) + r.n += int64(n) + return n, err +} + +// ListenAndServe creates a new Server with a listen address and a handler and +// calls its ListenAndServe method. +func ListenAndServe(addr string, handler Handler) error { + srv := NewServer() + srv.Address = addr + srv.Handler = handler + return srv.ListenAndServe() +} |