| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541 | // Copyright 2011 Aaron Jacobs. All Rights Reserved.// Author: aaronjjacobs@gmail.com (Aaron Jacobs)//// Licensed under the Apache License, Version 2.0 (the "License");// you may not use this file except in compliance with the License.// You may obtain a copy of the License at////     http://www.apache.org/licenses/LICENSE-2.0//// Unless required by applicable law or agreed to in writing, software// distributed under the License is distributed on an "AS IS" BASIS,// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.// See the License for the specific language governing permissions and// limitations under the License.package oglematchersimport (	"errors"	"fmt"	"math"	"reflect")// Equals(x) returns a matcher that matches values v such that v and x are// equivalent. This includes the case when the comparison v == x using Go's// built-in comparison operator is legal (except for structs, which this// matcher does not support), but for convenience the following rules also// apply:////  *  Type checking is done based on underlying types rather than actual//     types, so that e.g. two aliases for string can be compared:////         type stringAlias1 string//         type stringAlias2 string////         a := "taco"//         b := stringAlias1("taco")//         c := stringAlias2("taco")////         ExpectTrue(a == b)  // Legal, passes//         ExpectTrue(b == c)  // Illegal, doesn't compile////         ExpectThat(a, Equals(b))  // Passes//         ExpectThat(b, Equals(c))  // Passes////  *  Values of numeric type are treated as if they were abstract numbers, and//     compared accordingly. Therefore Equals(17) will match int(17),//     int16(17), uint(17), float32(17), complex64(17), and so on.//// If you want a stricter matcher that contains no such cleverness, see// IdenticalTo instead.//// Arrays are supported by this matcher, but do not participate in the// exceptions above. Two arrays compared with this matcher must have identical// types, and their element type must itself be comparable according to Go's ==// operator.func Equals(x interface{}) Matcher {	v := reflect.ValueOf(x)	// This matcher doesn't support structs.	if v.Kind() == reflect.Struct {		panic(fmt.Sprintf("oglematchers.Equals: unsupported kind %v", v.Kind()))	}	// The == operator is not defined for non-nil slices.	if v.Kind() == reflect.Slice && v.Pointer() != uintptr(0) {		panic(fmt.Sprintf("oglematchers.Equals: non-nil slice"))	}	return &equalsMatcher{v}}type equalsMatcher struct {	expectedValue reflect.Value}////////////////////////////////////////////////////////////////////////// Numeric types////////////////////////////////////////////////////////////////////////func isSignedInteger(v reflect.Value) bool {	k := v.Kind()	return k >= reflect.Int && k <= reflect.Int64}func isUnsignedInteger(v reflect.Value) bool {	k := v.Kind()	return k >= reflect.Uint && k <= reflect.Uintptr}func isInteger(v reflect.Value) bool {	return isSignedInteger(v) || isUnsignedInteger(v)}func isFloat(v reflect.Value) bool {	k := v.Kind()	return k == reflect.Float32 || k == reflect.Float64}func isComplex(v reflect.Value) bool {	k := v.Kind()	return k == reflect.Complex64 || k == reflect.Complex128}func checkAgainstInt64(e int64, c reflect.Value) (err error) {	err = errors.New("")	switch {	case isSignedInteger(c):		if c.Int() == e {			err = nil		}	case isUnsignedInteger(c):		u := c.Uint()		if u <= math.MaxInt64 && int64(u) == e {			err = nil		}	// Turn around the various floating point types so that the checkAgainst*	// functions for them can deal with precision issues.	case isFloat(c), isComplex(c):		return Equals(c.Interface()).Matches(e)	default:		err = NewFatalError("which is not numeric")	}	return}func checkAgainstUint64(e uint64, c reflect.Value) (err error) {	err = errors.New("")	switch {	case isSignedInteger(c):		i := c.Int()		if i >= 0 && uint64(i) == e {			err = nil		}	case isUnsignedInteger(c):		if c.Uint() == e {			err = nil		}	// Turn around the various floating point types so that the checkAgainst*	// functions for them can deal with precision issues.	case isFloat(c), isComplex(c):		return Equals(c.Interface()).Matches(e)	default:		err = NewFatalError("which is not numeric")	}	return}func checkAgainstFloat32(e float32, c reflect.Value) (err error) {	err = errors.New("")	switch {	case isSignedInteger(c):		if float32(c.Int()) == e {			err = nil		}	case isUnsignedInteger(c):		if float32(c.Uint()) == e {			err = nil		}	case isFloat(c):		// Compare using float32 to avoid a false sense of precision; otherwise		// e.g. Equals(float32(0.1)) won't match float32(0.1).		if float32(c.Float()) == e {			err = nil		}	case isComplex(c):		comp := c.Complex()		rl := real(comp)		im := imag(comp)		// Compare using float32 to avoid a false sense of precision; otherwise		// e.g. Equals(float32(0.1)) won't match (0.1 + 0i).		if im == 0 && float32(rl) == e {			err = nil		}	default:		err = NewFatalError("which is not numeric")	}	return}func checkAgainstFloat64(e float64, c reflect.Value) (err error) {	err = errors.New("")	ck := c.Kind()	switch {	case isSignedInteger(c):		if float64(c.Int()) == e {			err = nil		}	case isUnsignedInteger(c):		if float64(c.Uint()) == e {			err = nil		}	// If the actual value is lower precision, turn the comparison around so we	// apply the low-precision rules. Otherwise, e.g. Equals(0.1) may not match	// float32(0.1).	case ck == reflect.Float32 || ck == reflect.Complex64:		return Equals(c.Interface()).Matches(e)		// Otherwise, compare with double precision.	case isFloat(c):		if c.Float() == e {			err = nil		}	case isComplex(c):		comp := c.Complex()		rl := real(comp)		im := imag(comp)		if im == 0 && rl == e {			err = nil		}	default:		err = NewFatalError("which is not numeric")	}	return}func checkAgainstComplex64(e complex64, c reflect.Value) (err error) {	err = errors.New("")	realPart := real(e)	imaginaryPart := imag(e)	switch {	case isInteger(c) || isFloat(c):		// If we have no imaginary part, then we should just compare against the		// real part. Otherwise, we can't be equal.		if imaginaryPart != 0 {			return		}		return checkAgainstFloat32(realPart, c)	case isComplex(c):		// Compare using complex64 to avoid a false sense of precision; otherwise		// e.g. Equals(0.1 + 0i) won't match float32(0.1).		if complex64(c.Complex()) == e {			err = nil		}	default:		err = NewFatalError("which is not numeric")	}	return}func checkAgainstComplex128(e complex128, c reflect.Value) (err error) {	err = errors.New("")	realPart := real(e)	imaginaryPart := imag(e)	switch {	case isInteger(c) || isFloat(c):		// If we have no imaginary part, then we should just compare against the		// real part. Otherwise, we can't be equal.		if imaginaryPart != 0 {			return		}		return checkAgainstFloat64(realPart, c)	case isComplex(c):		if c.Complex() == e {			err = nil		}	default:		err = NewFatalError("which is not numeric")	}	return}////////////////////////////////////////////////////////////////////////// Other types////////////////////////////////////////////////////////////////////////func checkAgainstBool(e bool, c reflect.Value) (err error) {	if c.Kind() != reflect.Bool {		err = NewFatalError("which is not a bool")		return	}	err = errors.New("")	if c.Bool() == e {		err = nil	}	return}func checkAgainstChan(e reflect.Value, c reflect.Value) (err error) {	// Create a description of e's type, e.g. "chan int".	typeStr := fmt.Sprintf("%s %s", e.Type().ChanDir(), e.Type().Elem())	// Make sure c is a chan of the correct type.	if c.Kind() != reflect.Chan ||		c.Type().ChanDir() != e.Type().ChanDir() ||		c.Type().Elem() != e.Type().Elem() {		err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr))		return	}	err = errors.New("")	if c.Pointer() == e.Pointer() {		err = nil	}	return}func checkAgainstFunc(e reflect.Value, c reflect.Value) (err error) {	// Make sure c is a function.	if c.Kind() != reflect.Func {		err = NewFatalError("which is not a function")		return	}	err = errors.New("")	if c.Pointer() == e.Pointer() {		err = nil	}	return}func checkAgainstMap(e reflect.Value, c reflect.Value) (err error) {	// Make sure c is a map.	if c.Kind() != reflect.Map {		err = NewFatalError("which is not a map")		return	}	err = errors.New("")	if c.Pointer() == e.Pointer() {		err = nil	}	return}func checkAgainstPtr(e reflect.Value, c reflect.Value) (err error) {	// Create a description of e's type, e.g. "*int".	typeStr := fmt.Sprintf("*%v", e.Type().Elem())	// Make sure c is a pointer of the correct type.	if c.Kind() != reflect.Ptr ||		c.Type().Elem() != e.Type().Elem() {		err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr))		return	}	err = errors.New("")	if c.Pointer() == e.Pointer() {		err = nil	}	return}func checkAgainstSlice(e reflect.Value, c reflect.Value) (err error) {	// Create a description of e's type, e.g. "[]int".	typeStr := fmt.Sprintf("[]%v", e.Type().Elem())	// Make sure c is a slice of the correct type.	if c.Kind() != reflect.Slice ||		c.Type().Elem() != e.Type().Elem() {		err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr))		return	}	err = errors.New("")	if c.Pointer() == e.Pointer() {		err = nil	}	return}func checkAgainstString(e reflect.Value, c reflect.Value) (err error) {	// Make sure c is a string.	if c.Kind() != reflect.String {		err = NewFatalError("which is not a string")		return	}	err = errors.New("")	if c.String() == e.String() {		err = nil	}	return}func checkAgainstArray(e reflect.Value, c reflect.Value) (err error) {	// Create a description of e's type, e.g. "[2]int".	typeStr := fmt.Sprintf("%v", e.Type())	// Make sure c is the correct type.	if c.Type() != e.Type() {		err = NewFatalError(fmt.Sprintf("which is not %s", typeStr))		return	}	// Check for equality.	if e.Interface() != c.Interface() {		err = errors.New("")		return	}	return}func checkAgainstUnsafePointer(e reflect.Value, c reflect.Value) (err error) {	// Make sure c is a pointer.	if c.Kind() != reflect.UnsafePointer {		err = NewFatalError("which is not a unsafe.Pointer")		return	}	err = errors.New("")	if c.Pointer() == e.Pointer() {		err = nil	}	return}func checkForNil(c reflect.Value) (err error) {	err = errors.New("")	// Make sure it is legal to call IsNil.	switch c.Kind() {	case reflect.Invalid:	case reflect.Chan:	case reflect.Func:	case reflect.Interface:	case reflect.Map:	case reflect.Ptr:	case reflect.Slice:	default:		err = NewFatalError("which cannot be compared to nil")		return	}	// Ask whether the value is nil. Handle a nil literal (kind Invalid)	// specially, since it's not legal to call IsNil there.	if c.Kind() == reflect.Invalid || c.IsNil() {		err = nil	}	return}////////////////////////////////////////////////////////////////////////// Public implementation////////////////////////////////////////////////////////////////////////func (m *equalsMatcher) Matches(candidate interface{}) error {	e := m.expectedValue	c := reflect.ValueOf(candidate)	ek := e.Kind()	switch {	case ek == reflect.Bool:		return checkAgainstBool(e.Bool(), c)	case isSignedInteger(e):		return checkAgainstInt64(e.Int(), c)	case isUnsignedInteger(e):		return checkAgainstUint64(e.Uint(), c)	case ek == reflect.Float32:		return checkAgainstFloat32(float32(e.Float()), c)	case ek == reflect.Float64:		return checkAgainstFloat64(e.Float(), c)	case ek == reflect.Complex64:		return checkAgainstComplex64(complex64(e.Complex()), c)	case ek == reflect.Complex128:		return checkAgainstComplex128(complex128(e.Complex()), c)	case ek == reflect.Chan:		return checkAgainstChan(e, c)	case ek == reflect.Func:		return checkAgainstFunc(e, c)	case ek == reflect.Map:		return checkAgainstMap(e, c)	case ek == reflect.Ptr:		return checkAgainstPtr(e, c)	case ek == reflect.Slice:		return checkAgainstSlice(e, c)	case ek == reflect.String:		return checkAgainstString(e, c)	case ek == reflect.Array:		return checkAgainstArray(e, c)	case ek == reflect.UnsafePointer:		return checkAgainstUnsafePointer(e, c)	case ek == reflect.Invalid:		return checkForNil(c)	}	panic(fmt.Sprintf("equalsMatcher.Matches: unexpected kind: %v", ek))}func (m *equalsMatcher) Description() string {	// Special case: handle nil.	if !m.expectedValue.IsValid() {		return "is nil"	}	return fmt.Sprintf("%v", m.expectedValue.Interface())}
 |