summaryrefslogtreecommitdiffstats
path: root/common_test.go
blob: 8e8c67382f3ff5a0e4a6466df3be76222dbadac1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package cnp

import (
	"bytes"
	"io"
	"io/ioutil"
	"reflect"
	"strings"
)

// headerEqual reports whether two headers match field by field: Intent,
// VersionMajor, VersionMinor and, last, their Parameters (via paramEqual).
// Checks stop at the first mismatch, so paramEqual only runs when every
// scalar field already agrees.
func headerEqual(a, b Header) bool {
	if a.Intent != b.Intent {
		return false
	}
	if a.VersionMajor != b.VersionMajor || a.VersionMinor != b.VersionMinor {
		return false
	}
	return paramEqual(a.Parameters, b.Parameters)
}

// paramEqual reports whether a and b hold exactly the same key/value pairs.
//
// Key presence in b is checked explicitly. Indexing a map with a missing key
// yields the value type's zero value, so a plain a[k] != b[k] comparison
// would wrongly treat {"x": ""} and {"y": ""} as equal (same length, and both
// lookups of "x" produce ""). A nil and an empty Parameters compare equal.
func paramEqual(a, b Parameters) bool {
	if len(a) != len(b) {
		return false
	}
	for k, av := range a {
		// With equal lengths, "every key of a is in b with the same value"
		// implies b has no extra keys either.
		bv, ok := b[k]
		if !ok || av != bv {
			return false
		}
	}
	return true
}

// errorEqual reports whether a and b have the same dynamic type. Error
// messages and wrapped values are deliberately not compared. Two nil errors
// are equal; a nil and a non-nil error are not.
func errorEqual(a, b error) bool {
	bothNil := a == nil && b == nil
	if bothNil {
		return true
	}
	typeA, typeB := reflect.TypeOf(a), reflect.TypeOf(b)
	return typeA == typeB
}

// msgEqual reports whether two messages have equal headers (headerEqual) and
// equal bodies (bodyEqual). The body is only inspected, and therefore only
// read, when the headers already match. Both pointers are dereferenced, so
// neither may be nil.
func msgEqual(a, b *Message) bool {
	if !headerEqual(a.Header, b.Header) {
		return false
	}
	return bodyEqual(a.Body, b.Body)
}

// bodyEqual reports whether readers a and b yield the same bytes.
//
// A nil reader is treated as empty. Each reader is consumed to EOF; readers
// that also implement io.Seeker are then rewound to offset 0 (not to their
// prior offset) so they can be read again. a is fully read and rewound before
// b is touched. Read errors panic, since this is a test-only helper.
func bodyEqual(a, b io.Reader) bool {
	if a == nil && b == nil {
		return true
	}
	slurp := func(r io.Reader) []byte {
		if r == nil {
			r = strings.NewReader("")
		}
		data, err := ioutil.ReadAll(r)
		if err != nil {
			panic(err)
		}
		// Best-effort rewind: the Seek result is deliberately ignored.
		if seeker, ok := r.(io.Seeker); ok {
			_, _ = seeker.Seek(0, io.SeekStart)
		}
		return data
	}
	left := slurp(a)
	right := slurp(b)
	return bytes.Equal(left, right)
}

// testStringReader is a minimal io.Reader over a string. It defines only Read
// (no Seek or other optional interfaces). Presumably this lets tests feed a
// body that cannot be rewound; confirm against its call sites.
type testStringReader struct {
	s string
}

// Read copies as much of the unread string into b as fits and drops the
// copied prefix. It returns (0, io.EOF) once nothing remains. A zero-length b
// with data remaining returns (0, nil).
func (r *testStringReader) Read(b []byte) (n int, err error) {
	if len(r.s) == 0 {
		return 0, io.EOF
	}
	n = copy(b, r.s)
	r.s = r.s[n:]
	return n, nil
}