diff options
-rw-r--r-- | common.go | 4 | ||||
-rw-r--r-- | message.go | 44 | ||||
-rw-r--r-- | message_test.go | 2 | ||||
-rw-r--r-- | request.go | 20 | ||||
-rw-r--r-- | response.go | 20 | ||||
-rw-r--r-- | server.go | 2 |
6 files changed, 67 insertions, 25 deletions
@@ -30,10 +30,10 @@ func readLimitedLine(br *bufio.Reader, length int) ([]byte, error) { } } -func getInt(m *Message, param string) (int64, error) { +func getInt(m *Message, param string, fallback int64) (int64, error) { p := m.Param(param) if p == "" { - return 0, nil + return fallback, nil } n, err := strconv.ParseUint(p, 10, 63) if err != nil || (n != 0 && p[0] == '0') || (n == 0 && len(p) != 1) { @@ -114,22 +114,14 @@ func (msg *Message) TryComputeLength() bool { return true } -// SetLength sets the length header parameter to n. -func (msg *Message) SetLength(n int64) { - if n == 0 { - msg.SetParam("length", "") - } else { - setInt(msg, "length", n) - } +// Intent retrieves the message header intent. +func (msg *Message) Intent() string { + return msg.Header.Intent } -// Length gets the length header parameter (or 0 if it's not set or invalid). -func (msg *Message) Length() int64 { - n, err := getInt(msg, "length") - if err != nil { - return 0 - } - return n +// SetIntent sets the message header intent. +func (msg *Message) SetIntent(s string) { + msg.Header.Intent = s } // Param retrieves a header parameter. It performs no value validation. @@ -147,19 +139,29 @@ func (msg *Message) SetParam(key, value string) { } } -// Intent retrieves the message header intent. -func (msg *Message) Intent() string { - return msg.Header.Intent +// Length gets the length header parameter (or 0 if it's not set or invalid). +func (msg *Message) Length() int64 { + n, err := getInt(msg, "length", 0) + if err != nil { + return 0 + } + return n } -// SetIntent sets the message header intent. -func (msg *Message) SetIntent(s string) { - msg.Header.Intent = s +// SetLength sets the length header parameter to n. +// +// If negative or zero, the parameter is unset. +func (msg *Message) SetLength(n int64) { + if n <= 0 { + msg.SetParam("length", "") + } else { + setInt(msg, "length", n) + } } // Validate validates the message header parameter value format (length). func (msg *Message) Validate() error { - _, err := getInt(msg, "length") + _, err := getInt(msg, "length", 0) return err } diff --git a/message_test.go b/message_test.go index 6955fc3..1a73fa5 100644 --- a/message_test.go +++ b/message_test.go @@ -92,7 +92,7 @@ func TestParse(t *testing.T) { } } else if !msgEqual(msg, tst.m) { t.Errorf("\nexpected: %+v\ngot: %+v", tst.m, msg) - } else if l := msg.Length(); tst.v == nil && l != int64(len(tst.b)) { + } else if l := msg.Length(); tst.v == nil && (l != int64(len(tst.b)) && !(l < 0 == (len(tst.b) == 0))) { t.Errorf("%+v: expected length %d, got %d", tst.m, len(tst.b), l) } else if err = msg.Validate(); !errorEqual(err, tst.v) { t.Errorf("%+v: expected validation error %+v, got %+v (%s)", msg, tst.v, err, err) @@ -219,6 +219,26 @@ func (r *Request) SetSelect(selector, query string) error { return setSelect(&r.Message, "select", selector, query) } +// Length gets the length request parameter (or 0 if not set or invalid). +func (r *Request) Length() int64 { + n, err := getInt(&r.Message, "length", 0) + if err != nil { + return 0 + } + return n +} + +// SetLength sets the length request parameter to n. +// +// If n is negative or zero, the parameter is unset. +func (r *Request) SetLength(n int64) { + if n <= 0 { + r.SetParam("length", "") + } else { + setInt(&r.Message, "length", n) + } +} + // Validate validates the request header intent and parameter value format // (length, name, type, if_modified, select) func (r *Request) Validate() error { diff --git a/response.go b/response.go index 3ecb961..4c3f753 100644 --- a/response.go +++ b/response.go @@ -224,6 +224,26 @@ func (r *Response) SetSelect(selector, query string) error { return setSelect(&r.Message, "select", selector, query) } +// Length gets the length response parameter (or -1 if it's not set or invalid). +func (r *Response) Length() int64 { + n, err := getInt(&r.Message, "length", -1) + if err != nil { + return -1 + } + return n +} + +// SetLength sets the length response parameter to n. +// +// If negative, the parameter is unset. +func (r *Response) SetLength(n int64) { + if n < 0 { + r.SetParam("length", "") + } else { + setInt(&r.Message, "length", n) + } +} + // Validate validates the response intent and header parameter value format // (length, name, type, time, modified, location, reason, select) func (r *Response) Validate() error { @@ -139,7 +139,7 @@ func (srv *Server) HandleConn(conn net.Conn) { req.Body = io.LimitReader(req.Body, req.Length()) if srv.Validate { - if err != nil { + if err := req.Validate(); err != nil { panic(err) } } |