// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "math/bits"
// uint128 holds a 128-bit number as two 64-bit limbs, for use with the
// bits.Mul64 and bits.Add64 intrinsics.
type uint128 struct {
lo, hi uint64
}
// mul64 returns a * b.
func mul64(a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
return uint128{lo, hi}
}
// addMul64 returns v + a * b.
func addMul64(v uint128, a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
lo, c := bits.Add64(lo, v.lo, 0)
hi, _ = bits.Add64(hi, v.hi, c)
return uint128{lo, hi}
}
// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits.
func shiftRightBy51(a uint128) uint64 {
return (a.hi << (64 - 51)) | (a.lo >> 51)
}
func feMulGeneric(v, a, b *Element) {
a0 := a.l0
a1 := a.l1
a2 := a.l2
a3 := a.l3
a4 := a.l4
b0 := b.l0
b1 := b.l1
b2 := b.l2
b3 := b.l3
b4 := b.l4
// Limb multiplication works like pen-and-paper columnar multiplication, but
// with 51-bit limbs instead of digits.
//
// a4 a3 a2 a1 a0 x
// b4 b3 b2 b1 b0 =
// ------------------------
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a4b1 a3b1 a2b1 a1b1 a0b1 +
// a4b2 a3b2 a2b2 a1b2 a0b2 +
// a4b3 a3b3 a2b3 a1b3 a0b3 +
// a4b4 a3b4 a2b4 a1b4 a0b4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to
// reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5,
// r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc.
//
// Reduction can be carried out simultaneously to multiplication. For
// example, we do not compute r5: whenever the result of a multiplication
// belongs to r5, like a1b4, we multiply it by 19 and add the result to r0.
//
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a3b1 a2b1 a1b1 a0b1 19×a4b1 +
// a2b2 a1b2 a0b2 19×a4b2 19×a3b2 +
// a1b3 a0b3 19×a4b3 19×a3b3 19×a2b3 +
// a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// Finally we add up the columns into wide, overlapping limbs.
a1_19 := a1 * 19
a2_19 := a2 * 19
a3_19 := a3 * 19
a4_19 := a4 * 19
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
r0 := mul64(a0, b0)
r0 = addMul64(r0, a1_19, b4)
r0 = addMul64(r0, a2_19, b3)
r0 = addMul64(r0, a3_19, b2)
r0 = addMul64(r0, a4_19, b1)
// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
r1 := mul64(a0, b1)
r1 = addMul64(r1, a1, b0)
r1 = addMul64(r1, a2_19, b4)
r1 = addMul64(r1, a3_19, b3)
r1 = addMul64(r1, a4_19, b2)
// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
r2 := mul64(a0, b2)
r2 = addMul64(r2, a1, b1)
r2 = addMul64(r2, a2, b0)
r2 = addMul64(r2, a3_19, b4)
r2 = addMul64(r2, a4_19, b3)
// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
r3 := mul64(a0, b3)
r3 = addMul64(r3, a1, b2)
r3 = addMul64(r3, a2, b1)
r3 = addMul64(r3, a3, b0)
r3 = addMul64(r3, a4_19, b4)
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
r4 := mul64(a0, b4)
r4 = addMul64(r4, a1, b3)
r4 = addMul64(r4, a2, b2)
r4 = addMul64(r4, a3, b1)
r4 = addMul64(r4, a4, b0)
// After the multiplication, we need to reduce (carry) the five coefficients
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
// to respect the Element invariant.
//
// Overall, the reduction works the same as carryPropagate, except with
// wider inputs: we take the carry for each coefficient by shifting it right
// by 51, and add it to the limb above it. The top carry is multiplied by 19
// according to the reduction identity and added to the lowest limb.
//
// The largest coefficient (r0) will be at most 111 bits, which guarantees
// that all carries are at most 111 - 51 = 60 bits, which fits in a uint64.
//
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
// r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²)
// r0 < (1 + 19 × 4) × 2⁵² × 2⁵²
// r0 < 2⁷ × 2⁵² × 2⁵²
// r0 < 2¹¹¹
//
// Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most
// 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and
// allows us to easily apply the reduction identity.
//
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
// r4 < 5 × 2⁵² × 2⁵²
// r4 < 2¹⁰⁷
//
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
// Now all coefficients fit into 64-bit registers but are still too large to
// be passed around as an Element. We therefore do one last carry chain,
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
func feSquareGeneric(v, a *Element) {
l0 := a.l0
l1 := a.l1
l2 := a.l2
l3 := a.l3
l4 := a.l4
// Squaring works precisely like multiplication above, but thanks to its
// symmetry we get to group a few terms together.
//
// l4 l3 l2 l1 l0 x
// l4 l3 l2 l1 l0 =
// ------------------------
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l4l1 l3l1 l2l1 l1l1 l0l1 +
// l4l2 l3l2 l2l2 l1l2 l0l2 +
// l4l3 l3l3 l2l3 l1l3 l0l3 +
// l4l4 l3l4 l2l4 l1l4 l0l4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l3l1 l2l1 l1l1 l0l1 19×l4l1 +
// l2l2 l1l2 l0l2 19×l4l2 19×l3l2 +
// l1l3 l0l3 19×l4l3 19×l3l3 19×l2l3 +
// l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
// only three Mul64 and four Add64, instead of five and eight.
l0_2 := l0 * 2
l1_2 := l1 * 2
l1_38 := l1 * 38
l2_38 := l2 * 38
l3_38 := l3 * 38
l3_19 := l3 * 19
l4_19 := l4 * 19
// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
r0 := mul64(l0, l0)
r0 = addMul64(r0, l1_38, l4)
r0 = addMul64(r0, l2_38, l3)
// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
r1 := mul64(l0_2, l1)
r1 = addMul64(r1, l2_38, l4)
r1 = addMul64(r1, l3_19, l3)
// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
r2 := mul64(l0_2, l2)
r2 = addMul64(r2, l1, l1)
r2 = addMul64(r2, l3_38, l4)
// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
r3 := mul64(l0_2, l3)
r3 = addMul64(r3, l1_2, l2)
r3 = addMul64(r3, l4_19, l4)
// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
r4 := mul64(l0_2, l4)
r4 = addMul64(r4, l1_2, l3)
r4 = addMul64(r4, l2, l2)
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry.
func (v *Element) carryPropagateGeneric() *Element {
c0 := v.l0 >> 51
c1 := v.l1 >> 51
c2 := v.l2 >> 51
c3 := v.l3 >> 51
c4 := v.l4 >> 51
// c4 is at most 64 - 51 = 13 bits, so c4*19 is at most 18 bits, and
// the final l0 will be at most 52 bits. Similarly for the rest.
v.l0 = v.l0&maskLow51Bits + c4*19
v.l1 = v.l1&maskLow51Bits + c0
v.l2 = v.l2&maskLow51Bits + c1
v.l3 = v.l3&maskLow51Bits + c2
v.l4 = v.l4&maskLow51Bits + c3
return v
}
// Copyright (c) 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"encoding/binary"
"errors"
)
// A Scalar is an integer modulo
//
// l = 2^252 + 27742317777372353535851937790883648493
//
// which is the prime order of the edwards25519 group.
//
// This type works similarly to math/big.Int, and all arguments and
// receivers are allowed to alias.
//
// The zero value is a valid zero element.
type Scalar struct {
// s is the scalar in the Montgomery domain, in the format of the
// fiat-crypto implementation.
s fiatScalarMontgomeryDomainFieldElement
}
// The field implementation in scalar_fiat.go is generated by the fiat-crypto
// project (https://github.com/mit-plv/fiat-crypto) at version v0.0.9 (23d2dbc)
// from a formally verified model.
//
// fiat-crypto code comes under the following license.
//
// Copyright (c) 2015-2020 The fiat-crypto Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// THIS SOFTWARE IS PROVIDED BY the fiat-crypto authors "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Berkeley Software Design,
// Inc. BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// NewScalar returns a new zero Scalar.
func NewScalar() *Scalar {
return &Scalar{}
}
// MultiplyAdd sets s = x * y + z mod l, and returns s. It is equivalent to
// using Multiply and then Add.
func (s *Scalar) MultiplyAdd(x, y, z *Scalar) *Scalar {
// Make a copy of z in case it aliases s.
zCopy := new(Scalar).Set(z)
return s.Multiply(x, y).Add(s, zCopy)
}
// Add sets s = x + y mod l, and returns s.
func (s *Scalar) Add(x, y *Scalar) *Scalar {
// s = 1 * x + y mod l
fiatScalarAdd(&s.s, &x.s, &y.s)
return s
}
// Subtract sets s = x - y mod l, and returns s.
func (s *Scalar) Subtract(x, y *Scalar) *Scalar {
// s = -1 * y + x mod l
fiatScalarSub(&s.s, &x.s, &y.s)
return s
}
// Negate sets s = -x mod l, and returns s.
func (s *Scalar) Negate(x *Scalar) *Scalar {
// s = -1 * x + 0 mod l
fiatScalarOpp(&s.s, &x.s)
return s
}
// Multiply sets s = x * y mod l, and returns s.
func (s *Scalar) Multiply(x, y *Scalar) *Scalar {
// s = x * y + 0 mod l
fiatScalarMul(&s.s, &x.s, &y.s)
return s
}
// Set sets s = x, and returns s.
func (s *Scalar) Set(x *Scalar) *Scalar {
*s = *x
return s
}
// SetUniformBytes sets s = x mod l, where x is a 64-byte little-endian integer.
// If x is not of the right length, SetUniformBytes returns nil and an error,
// and the receiver is unchanged.
//
// SetUniformBytes can be used to set s to a uniformly distributed value given
// 64 uniformly distributed random bytes.
func (s *Scalar) SetUniformBytes(x []byte) (*Scalar, error) {
if len(x) != 64 {
return nil, errors.New("edwards25519: invalid SetUniformBytes input length")
}
// We have a value x of 512 bits, but our fiatScalarFromBytes function
// expects an input lower than l, which is a little over 252 bits.
//
// Instead of writing a reduction function that operates on wider inputs, we
// can interpret x as the sum of three shorter values a, b, and c.
//
// x = a + b * 2^168 + c * 2^336 mod l
//
// We then precompute 2^168 and 2^336 modulo l, and perform the reduction
// with two multiplications and two additions.
s.setShortBytes(x[:21])
t := new(Scalar).setShortBytes(x[21:42])
s.Add(s, t.Multiply(t, scalarTwo168))
t.setShortBytes(x[42:])
s.Add(s, t.Multiply(t, scalarTwo336))
return s, nil
}
// scalarTwo168 and scalarTwo336 are 2^168 and 2^336 modulo l, encoded as a
// fiatScalarMontgomeryDomainFieldElement, which is a little-endian 4-limb value
// in the 2^256 Montgomery domain.
var scalarTwo168 = &Scalar{s: [4]uint64{0x5b8ab432eac74798, 0x38afddd6de59d5d7,
0xa2c131b399411b7c, 0x6329a7ed9ce5a30}}
var scalarTwo336 = &Scalar{s: [4]uint64{0xbd3d108e2b35ecc5, 0x5c3a3718bdf9c90b,
0x63aa97a331b4f2ee, 0x3d217f5be65cb5c}}
// setShortBytes sets s = x mod l, where x is a little-endian integer shorter
// than 32 bytes.
func (s *Scalar) setShortBytes(x []byte) *Scalar {
if len(x) >= 32 {
panic("edwards25519: internal error: setShortBytes called with a long string")
}
var buf [32]byte
copy(buf[:], x)
fiatScalarFromBytes((*[4]uint64)(&s.s), &buf)
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s
}
// SetCanonicalBytes sets s = x, where x is a 32-byte little-endian encoding of
// s, and returns s. If x is not a canonical encoding of s, SetCanonicalBytes
// returns nil and an error, and the receiver is unchanged.
func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
if len(x) != 32 {
return nil, errors.New("invalid scalar length")
}
if !isReduced(x) {
return nil, errors.New("invalid scalar encoding")
}
fiatScalarFromBytes((*[4]uint64)(&s.s), (*[32]byte)(x))
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s, nil
}
// scalarMinusOneBytes is l - 1 in little endian.
var scalarMinusOneBytes = [32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}
// isReduced returns whether the given scalar in 32-byte little endian encoded
// form is reduced modulo l.
func isReduced(s []byte) bool {
if len(s) != 32 {
return false
}
for i := len(s) - 1; i >= 0; i-- {
switch {
case s[i] > scalarMinusOneBytes[i]:
return false
case s[i] < scalarMinusOneBytes[i]:
return true
}
}
return true
}
// SetBytesWithClamping applies the buffer pruning described in RFC 8032,
// Section 5.1.5 (also known as clamping) and sets s to the result. The input
// must be 32 bytes, and it is not modified. If x is not of the right length,
// SetBytesWithClamping returns nil and an error, and the receiver is unchanged.
//
// Note that since Scalar values are always reduced modulo the prime order of
// the curve, the resulting value will not preserve any of the cofactor-clearing
// properties that clamping is meant to provide. It will however work as
// expected as long as it is applied to points on the prime order subgroup, like
// in Ed25519. In fact, it is lost to history why RFC 8032 adopted the
// irrelevant RFC 7748 clamping, but it is now required for compatibility.
func (s *Scalar) SetBytesWithClamping(x []byte) (*Scalar, error) {
// The description above omits the purpose of the high bits of the clamping
// for brevity, but those are also lost to reductions, and are also
// irrelevant to edwards25519 as they protect against a specific
// implementation bug that was once observed in a generic Montgomery ladder.
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid SetBytesWithClamping input length")
}
// We need to use the wide reduction from SetUniformBytes, since clamping
// sets the 2^254 bit, making the value higher than the order.
var wideBytes [64]byte
copy(wideBytes[:], x[:])
wideBytes[0] &= 248
wideBytes[31] &= 63
wideBytes[31] |= 64
return s.SetUniformBytes(wideBytes[:])
}
// Bytes returns the canonical 32-byte little-endian encoding of s.
func (s *Scalar) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var encoded [32]byte
return s.bytes(&encoded)
}
func (s *Scalar) bytes(out *[32]byte) []byte {
var ss fiatScalarNonMontgomeryDomainFieldElement
fiatScalarFromMontgomery(&ss, &s.s)
fiatScalarToBytes(out, (*[4]uint64)(&ss))
return out[:]
}
// Equal returns 1 if s and t are equal, and 0 otherwise.
func (s *Scalar) Equal(t *Scalar) int {
var diff fiatScalarMontgomeryDomainFieldElement
fiatScalarSub(&diff, &s.s, &t.s)
var nonzero uint64
fiatScalarNonzero(&nonzero, (*[4]uint64)(&diff))
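// Smear any set bit of nonzero down into bit 0, so that the low bit is 1
// if and only if diff is nonzero.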
nonzero |= nonzero >> 32
nonzero |= nonzero >> 16
nonzero |= nonzero >> 8
nonzero |= nonzero >> 4
nonzero |= nonzero >> 2
nonzero |= nonzero >> 1
return int(^nonzero) & 1
}
// nonAdjacentForm computes a width-w non-adjacent form for this scalar.
//
// w must be between 2 and 8, or nonAdjacentForm will panic.
func (s *Scalar) nonAdjacentForm(w uint) [256]int8 {
// This implementation is adapted from the one
// in curve25519-dalek and is documented there:
// https://github.com/dalek-cryptography/curve25519-dalek/blob/f630041af28e9a405255f98a8a93adca18e4315b/src/scalar.rs#L800-L871
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
if w < 2 {
panic("w must be at least 2 by the definition of NAF")
} else if w > 8 {
panic("NAF digits must fit in int8")
}
var naf [256]int8
var digits [5]uint64
for i := 0; i < 4; i++ {
digits[i] = binary.LittleEndian.Uint64(b[i*8:])
}
width := uint64(1 << w)
windowMask := uint64(width - 1)
pos := uint(0)
carry := uint64(0)
for pos < 256 {
indexU64 := pos / 64
indexBit := pos % 64
var bitBuf uint64
if indexBit < 64-w {
// This window's bits are contained in a single u64
bitBuf = digits[indexU64] >> indexBit
} else {
// Combine the current 64 bits with bits from the next 64
bitBuf = (digits[indexU64] >> indexBit) | (digits[1+indexU64] << (64 - indexBit))
}
// Add carry into the current window
window := carry + (bitBuf & windowMask)
if window&1 == 0 {
// If the window value is even, preserve the carry and continue.
// Why is the carry preserved?
// If carry == 0 and window & 1 == 0,
// then the next carry should be 0
// If carry == 1 and window & 1 == 0,
// then bit_buf & 1 == 1 so the next carry should be 1
pos += 1
continue
}
if window < width/2 {
carry = 0
naf[pos] = int8(window)
} else {
carry = 1
naf[pos] = int8(window) - int8(width)
}
pos += w
}
return naf
}
func (s *Scalar) signedRadix16() [64]int8 {
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
var digits [64]int8
// Compute unsigned radix-16 digits:
for i := 0; i < 32; i++ {
digits[2*i] = int8(b[i] & 15)
digits[2*i+1] = int8((b[i] >> 4) & 15)
}
// Recenter coefficients:
for i := 0; i < 63; i++ {
carry := (digits[i] + 8) >> 4
digits[i] -= carry << 4
digits[i+1] += carry
}
return digits
}
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import "sync"
// basepointTable is a set of 32 affineLookupTables, where table i is generated
// from 256i * basepoint. It is precomputed the first time it's used.
func basepointTable() *[32]affineLookupTable {
basepointTablePrecomp.initOnce.Do(func() {
p := NewGeneratorPoint()
for i := 0; i < 32; i++ {
basepointTablePrecomp.table[i].FromP3(p)
for j := 0; j < 8; j++ {
p.Add(p, p)
}
}
})
return &basepointTablePrecomp.table
}
var basepointTablePrecomp struct {
table [32]affineLookupTable
initOnce sync.Once
}
// ScalarBaseMult sets v = x * B, where B is the canonical generator, and
// returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarBaseMult(x *Scalar) *Point {
basepointTable := basepointTable()
// Write x = sum(x_i * 16^i) so x*B = sum( B*x_i*16^i )
// as described in the Ed25519 paper
//
// Group even and odd coefficients
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + x_1*16^1*B + x_3*16^3*B + ... + x_63*16^63*B
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + 16*( x_1*16^0*B + x_3*16^2*B + ... + x_63*16^62*B)
//
// We use a lookup table for each i to get x_i*16^(2*i)*B
// and do four doublings to multiply by 16.
digits := x.signedRadix16()
multiple := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Accumulate the odd components first
v.Set(NewIdentityPoint())
for i := 1; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
// Multiply by 16
tmp2.FromP3(v) // tmp2 = v in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*v in P1xP1 coords
v.fromP1xP1(tmp1) // now v = 16*(odd components)
// Accumulate the even components
for i := 0; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
return v
}
// ScalarMult sets v = x * q, and returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarMult(x *Scalar, q *Point) *Point {
checkInitialized(q)
var table projLookupTable
table.FromP3(q)
// Write x = sum(x_i * 16^i)
// so x*Q = sum( Q*x_i*16^i )
// = Q*x_0 + 16*(Q*x_1 + 16*( ... + Q*x_63) ... )
// <------compute inside out---------
//
// We use the lookup table to get the x_i*Q values
// and do four doublings to compute 16*Q
digits := x.signedRadix16()
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
table.SelectInto(multiple, digits[63])
v.Set(NewIdentityPoint())
tmp1.Add(v, multiple) // tmp1 = x_63*Q in P1xP1 coords
for i := 62; i >= 0; i-- {
tmp2.FromP1xP1(tmp1) // tmp2 = (prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
table.SelectInto(multiple, digits[i])
tmp1.Add(v, multiple) // tmp1 = x_i*Q + 16*(prev) in P1xP1 coords
}
v.fromP1xP1(tmp1)
return v
}
// basepointNafTable is the nafLookupTable8 for the basepoint.
// It is precomputed the first time it's used.
func basepointNafTable() *nafLookupTable8 {
basepointNafTablePrecomp.initOnce.Do(func() {
basepointNafTablePrecomp.table.FromP3(NewGeneratorPoint())
})
return &basepointNafTablePrecomp.table
}
var basepointNafTablePrecomp struct {
table nafLookupTable8
initOnce sync.Once
}
// VarTimeDoubleScalarBaseMult sets v = a * A + b * B, where B is the canonical
// generator, and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeDoubleScalarBaseMult(a *Scalar, A *Point, b *Scalar) *Point {
checkInitialized(A)
// Similarly to the single variable-base approach, we compute
// digits and use them with a lookup table. However, because
// we are allowed to do variable-time operations, we don't
// need constant-time lookups or constant-time digit
// computations.
//
// So we use a non-adjacent form of some width w instead of
// radix 16. This is like a binary representation (one digit
// for each binary place) but we allow the digits to grow in
// magnitude up to 2^{w-1} so that the nonzero digits are as
// sparse as possible. Intuitively, this "condenses" the
// "mass" of the scalar onto sparse coefficients (meaning
// fewer additions).
basepointNafTable := basepointNafTable()
var aTable nafLookupTable5
aTable.FromP3(A)
// Because the basepoint is fixed, we can use a wider NAF
// corresponding to a bigger table.
aNaf := a.nonAdjacentForm(5)
bNaf := b.nonAdjacentForm(8)
// Find the first nonzero coefficient.
i := 255
for j := i; j >= 0; j-- {
if aNaf[j] != 0 || bNaf[j] != 0 {
break
}
}
multA := &projCached{}
multB := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
for ; i >= 0; i-- {
tmp1.Double(tmp2)
// Only update v if we have a nonzero coeff to add in.
if aNaf[i] > 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, aNaf[i])
tmp1.Add(v, multA)
} else if aNaf[i] < 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, -aNaf[i])
tmp1.Sub(v, multA)
}
if bNaf[i] > 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, bNaf[i])
tmp1.AddAffine(v, multB)
} else if bNaf[i] < 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, -bNaf[i])
tmp1.SubAffine(v, multB)
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"crypto/subtle"
)
// A dynamic lookup table for variable-base, constant-time scalar muls.
type projLookupTable struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, constant-time scalar muls.
type affineLookupTable struct {
points [8]affineCached
}
// A dynamic lookup table for variable-base, variable-time scalar muls.
type nafLookupTable5 struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, variable-time scalar muls.
type nafLookupTable8 struct {
points [64]affineCached
}
// Constructors.
// Builds a lookup table at runtime. Fast.
func (v *projLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to a projCached
// This is needlessly complicated because the API has explicit
// receivers instead of creating stack objects and relying on RVO
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(q, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *affineLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to affineCached
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(q, &v.points[i])))
}
}
// Builds a lookup table at runtime. Fast.
func (v *nafLookupTable5) FromP3(q *Point) {
// Goal: v.points[i] = (2*i+1)*Q, i.e., Q, 3Q, 5Q, ..., 15Q
// This allows lookup of -15Q, ..., -3Q, -Q, 0, Q, 3Q, ..., 15Q
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(&q2, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *nafLookupTable8) FromP3(q *Point) {
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 63; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(&q2, &v.points[i])))
}
}
// Selectors.
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *projLookupTable) SelectInto(dest *projCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *affineLookupTable) SelectInto(dest *affineCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Given odd x with 0 < x < 2^4, return x*Q (in variable time).
func (v *nafLookupTable5) SelectInto(dest *projCached, x int8) {
*dest = v.points[x/2]
}
// Given odd x with 0 < x < 2^7, return x*Q (in variable time).
func (v *nafLookupTable8) SelectInto(dest *affineCached, x int8) {
*dest = v.points[x/2]
}
/examples/blog/blog
/examples/orders/orders
/examples/basic/basic
.idea/
language: go
go_import_path: github.com/DATA-DOG/go-sqlmock
go:
- 1.2.x
- 1.3.x
- 1.4 # has no cover tool for latest releases
- 1.5.x
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- 1.11.x
- 1.12.x
- 1.13.x
- 1.14.x
script:
- go vet
- test -z "$(go fmt ./...)" # fail if not formatted properly
- go test -race -coverprofile=coverage.txt -covermode=atomic
after_success:
- bash <(curl -s https://codecov.io/bash)
The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses)
Copyright (c) 2013-2019, DATA-DOG team
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* The name DataDog.lt may not be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT,
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
[![Build Status](https://travis-ci.org/DATA-DOG/go-sqlmock.svg)](https://travis-ci.org/DATA-DOG/go-sqlmock)
[![GoDoc](https://godoc.org/github.com/DATA-DOG/go-sqlmock?status.svg)](https://godoc.org/github.com/DATA-DOG/go-sqlmock)
[![Go Report Card](https://goreportcard.com/badge/github.com/DATA-DOG/go-sqlmock)](https://goreportcard.com/report/github.com/DATA-DOG/go-sqlmock)
[![codecov.io](https://codecov.io/github/DATA-DOG/go-sqlmock/branch/master/graph/badge.svg)](https://codecov.io/github/DATA-DOG/go-sqlmock)
# Sql driver mock for Golang
**sqlmock** is a mock library implementing [sql/driver](https://godoc.org/database/sql/driver). It has one, and only one,
purpose - to simulate any **sql** driver behavior in tests, without needing a real database connection. It helps to
maintain a correct **TDD** workflow.
- this library is now complete and stable (you may not find new changes for this reason)
- supports concurrency and multiple connections.
- supports **go1.8** Context related feature mocking and Named sql parameters.
- does not require any modifications to your source code.
- the driver allows you to mock any sql driver method behavior.
- matches expectations in strict order by default.
- has no third party dependencies.
**NOTE:** in **v1.2.0** **sqlmock.Rows** was changed from an interface to a struct. If you were using any type references to that
interface, you will need to switch them to the pointer struct type. Also, **sqlmock.Rows** previously implemented the **driver.Rows**
interface, which was not required or useful for mocking and has been removed. We hope this will not cause issues.
## Looking for maintainers
I do not have much spare time for this library and am willing to transfer the repository ownership
to a person or an organization motivated to maintain it. Open up a conversation if you are interested. See #230.
## Install
go get github.com/DATA-DOG/go-sqlmock
## Documentation and Examples
Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference.
See **.travis.yml** for supported **go** versions.
A different use case is to functionally test with a real database - with [go-txdb](https://github.com/DATA-DOG/go-txdb),
all database related actions are isolated within a single transaction so the database can remain in the same state.
See implementation examples:
- [blog API server](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/blog)
- [the same orders example](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/orders)
### Something you may want to test, assuming you use the [go-mysql-driver](https://github.com/go-sql-driver/mysql)
``` go
package main
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
)
func recordStats(db *sql.DB, userID, productID int64) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
defer func() {
switch err {
case nil:
err = tx.Commit()
default:
tx.Rollback()
}
}()
if _, err = tx.Exec("UPDATE products SET views = views + 1"); err != nil {
return
}
if _, err = tx.Exec("INSERT INTO product_viewers (user_id, product_id) VALUES (?, ?)", userID, productID); err != nil {
return
}
return
}
func main() {
// @NOTE: the real connection is not required for tests
db, err := sql.Open("mysql", "root@/blog")
if err != nil {
panic(err)
}
defer db.Close()
if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil {
panic(err)
}
}
```
### Tests with sqlmock
``` go
package main
import (
"fmt"
"testing"
"github.com/DATA-DOG/go-sqlmock"
)
// a successful case
func TestShouldUpdateStats(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("INSERT INTO product_viewers").WithArgs(2, 3).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
// now we execute our method
if err = recordStats(db, 2, 3); err != nil {
t.Errorf("error was not expected while updating stats: %s", err)
}
// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
// a failing test case
func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("INSERT INTO product_viewers").
WithArgs(2, 3).
WillReturnError(fmt.Errorf("some error"))
mock.ExpectRollback()
// now we execute our method
if err = recordStats(db, 2, 3); err == nil {
t.Errorf("was expecting an error, but there was none")
}
// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
```
## Customize SQL query matching
There were plenty of requests from users regarding SQL query string validation and different matching options.
We have now implemented the `QueryMatcher` interface, which can be passed through an option when calling
`sqlmock.New` or `sqlmock.NewWithDSN`.
This makes it possible to plug in a library that, for example, parses and validates a `mysql` SQL AST,
and to create a custom QueryMatcher in order to validate SQL in sophisticated ways.
By default, **sqlmock** preserves backward compatibility: the default query matcher is `sqlmock.QueryMatcherRegexp`,
which uses the expected SQL string as a regular expression to match the incoming query string. There is also an equality matcher,
`QueryMatcherEqual`, which performs a full case-sensitive match.
In order to customize the QueryMatcher, use the following:
``` go
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
```
The query matcher can be fully customized based on user needs. **sqlmock** will not
provide standard SQL parsing matchers, since various drivers may not follow the same SQL standard.
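As an illustration only (not taken from the library's documented examples), a custom matcher can be built with `sqlmock.QueryMatcherFunc`, assuming the sqlmock version in use exports that adapter. This sketch normalizes case and whitespace before comparing; the names `looseMatcher` and `main` are ours:

``` go
package main

import (
	"fmt"
	"log"
	"strings"

	"github.com/DATA-DOG/go-sqlmock"
)

// looseMatcher compares queries case-insensitively with collapsed whitespace.
// It is only a sketch; adjust the normalization to whatever your tests need.
var looseMatcher = sqlmock.QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
	normalize := func(s string) string {
		return strings.ToLower(strings.Join(strings.Fields(s), " "))
	}
	if normalize(expectedSQL) != normalize(actualSQL) {
		return fmt.Errorf("actual query %q does not match expected %q", actualSQL, expectedSQL)
	}
	return nil
})

func main() {
	// Pass the custom matcher the same way QueryMatcherEqual is passed above.
	db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(looseMatcher))
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()
	_ = mock // set expectations here as in the examples above
}
```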
## Matching arguments like time.Time
There may be arguments of `struct` type which cannot easily be compared by value, like `time.Time`. In this case
**sqlmock** provides an [Argument](https://godoc.org/github.com/DATA-DOG/go-sqlmock#Argument) interface which
can be used for more sophisticated matching. Here is a simple example of time argument matching:
``` go
type AnyTime struct{}
// Match satisfies sqlmock.Argument interface
func (a AnyTime) Match(v driver.Value) bool {
_, ok := v.(time.Time)
return ok
}
func TestAnyTimeArgument(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectExec("INSERT INTO users").
WithArgs("john", AnyTime{}).
WillReturnResult(NewResult(1, 1))
_, err = db.Exec("INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now())
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
```
It only asserts that the argument is of type `time.Time`.
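If you only need to ignore an argument entirely rather than check its type, the built-in `sqlmock.AnyArg()` matcher (mentioned in the Change Log below) can be used instead of writing a custom type. For example, the expectation from the test above could be written as:

``` go
mock.ExpectExec("INSERT INTO users").
	WithArgs("john", sqlmock.AnyArg()).
	WillReturnResult(sqlmock.NewResult(1, 1))
```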
## Run tests
go test -race
## Change Log
- **2019-04-06** - added functionality to mock a sql MetaData request
- **2019-02-13** - added `go.mod`, removed the references and suggestions to use `gopkg.in`.
- **2018-12-11** - added expectation of Rows to be closed, while mocking expected query.
- **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching.
- **2017-09-01** - it is now possible to expect that prepared statement will be closed,
using **ExpectedPrepare.WillBeClosed**.
- **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct
but contains all methods as before and should maintain backwards compatibility. **ExpectedQuery.WillReturnRows** may now
accept multiple row sets.
- **2016-11-02** - `db.Prepare()` was not validating expected prepare SQL
query. It should still be validated even if Exec or Query is not
executed on that prepared statement.
- **2016-02-23** - added **sqlmock.AnyArg()** function to provide any kind
of argument matcher.
- **2016-02-23** - expected arguments are now converted to driver.Value the same way a real
driver does; the change may affect time.Time comparison and will be
stricter. See [issue](https://github.com/DATA-DOG/go-sqlmock/issues/31).
- **2015-08-27** - **v1** api change, concurrency support, all known issues fixed.
- **2014-08-16** - instead of a **panic** on reflect type mismatch when comparing query arguments, an error is now returned
- **2014-08-14** added **sqlmock.NewErrorResult** which gives an option to return driver.Result with errors for
interface methods, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/5)
- **2014-05-29** - allow matching arguments in more sophisticated ways by providing an **sqlmock.Argument** interface
- **2014-04-21** introduce **sqlmock.New()** to open a mock database connection for tests. This method
calls sql.DB.Ping to ensure that the connection is open, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/4).
This way, on Close it will reliably assert whether all expectations were met, even if the database was not triggered at all.
The old way is still available, but it is advisable to call db.Ping manually before asserting with db.Close.
- **2014-02-14** - RowsFromCSVString is now part of the Rows interface, named FromCSVString.
It has changed to allow more ways to construct rows and to easily extend this API in the future.
See [issue 1](https://github.com/DATA-DOG/go-sqlmock/issues/1).
**RowsFromCSVString** is deprecated and will be removed in the future.
## Contributions
Feel free to open a pull request. Note: if you wish to contribute an extension to the public API (exported methods or types),
please open an issue first to discuss whether these changes can be accepted. All backward incompatible changes are
and will be treated cautiously.
## License
The [three clause BSD license](http://en.wikipedia.org/wiki/BSD_licenses)
package sqlmock
import "database/sql/driver"
// Argument interface allows matching
// any argument in a specific way when used with
// ExpectedQuery and ExpectedExec expectations.
type Argument interface {
Match(driver.Value) bool
}
// AnyArg will return an Argument which can
// match any kind of argument.
//
// Useful for time.Time or similar kinds of arguments.
func AnyArg() Argument {
return anyArgument{}
}
type anyArgument struct{}
func (a anyArgument) Match(_ driver.Value) bool {
return true
}
package sqlmock
import "reflect"
// Column is a mocked column Metadata for rows.ColumnTypes()
type Column struct {
name string
dbType string
nullable bool
nullableOk bool
length int64
lengthOk bool
precision int64
scale int64
psOk bool
scanType reflect.Type
}
func (c *Column) Name() string {
return c.name
}
func (c *Column) DbType() string {
return c.dbType
}
func (c *Column) IsNullable() (bool, bool) {
return c.nullable, c.nullableOk
}
func (c *Column) Length() (int64, bool) {
return c.length, c.lengthOk
}
func (c *Column) PrecisionScale() (int64, int64, bool) {
return c.precision, c.scale, c.psOk
}
func (c *Column) ScanType() reflect.Type {
return c.scanType
}
// NewColumn returns a Column with specified name
func NewColumn(name string) *Column {
return &Column{
name: name,
}
}
// Nullable returns the column with nullable metadata set
func (c *Column) Nullable(nullable bool) *Column {
c.nullable = nullable
c.nullableOk = true
return c
}
// OfType returns the column with type metadata set
func (c *Column) OfType(dbType string, sampleValue interface{}) *Column {
c.dbType = dbType
c.scanType = reflect.TypeOf(sampleValue)
return c
}
// WithLength returns the column with length metadata set.
func (c *Column) WithLength(length int64) *Column {
c.length = length
c.lengthOk = true
return c
}
// WithPrecisionAndScale returns the column with precision and scale metadata set.
func (c *Column) WithPrecisionAndScale(precision, scale int64) *Column {
c.precision = precision
c.scale = scale
c.psOk = true
return c
}
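// Example (illustrative sketch, not part of the library): callers typically
// chain the builder methods above to describe column metadata, e.g.
//
//	age := sqlmock.NewColumn("age").OfType("BIGINT", int64(0)).Nullable(true)
//	price := sqlmock.NewColumn("price").OfType("DECIMAL", float64(0)).WithPrecisionAndScale(10, 2)
//
// How these columns are attached to mocked rows depends on the Rows
// constructors available in the sqlmock version in use.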
package sqlmock
import (
"database/sql"
"database/sql/driver"
"fmt"
"sync"
)
var pool *mockDriver
func init() {
pool = &mockDriver{
conns: make(map[string]*sqlmock),
}
sql.Register("sqlmock", pool)
}
type mockDriver struct {
sync.Mutex
counter int
conns map[string]*sqlmock
}
func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
d.Lock()
defer d.Unlock()
c, ok := d.conns[dsn]
if !ok {
return c, fmt.Errorf("expected a connection to be available, but it is not")
}
c.opened++
return c, nil
}
// New creates a sqlmock database connection and a mock to manage expectations.
// Accepts options, like ValueConverterOption, to use a ValueConverter from
// a specific driver.
// Pings the db so that all expectations can be asserted.
func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
pool.Lock()
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
pool.counter++
smock := &sqlmock{dsn: dsn, drv: pool, ordered: true}
pool.conns[dsn] = smock
pool.Unlock()
return smock.open(options)
}
// NewWithDSN creates a sqlmock database connection with a specific DSN
// and a mock to manage expectations.
// Accepts options, like ValueConverterOption, to use a ValueConverter from
// a specific driver.
// Pings the db so that all expectations can be asserted.
//
// This method was introduced because of sql abstraction
// libraries which do not provide a way to initialize
// with a sql.DB instance, for example the GORM library.
//
// Note, it will return an error if you attempt to create a mock
// with an already used dsn.
//
// It is not recommended to use this method unless you
// really need it and there is no other way around.
func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
pool.Lock()
if _, ok := pool.conns[dsn]; ok {
pool.Unlock()
return nil, nil, fmt.Errorf("cannot create a new mock database with the same dsn: %s", dsn)
}
smock := &sqlmock{dsn: dsn, drv: pool, ordered: true}
pool.conns[dsn] = smock
pool.Unlock()
return smock.open(options)
}
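// Example (illustrative sketch): because the mock driver is registered as
// "sqlmock" and connections are looked up by DSN, a library that insists on
// opening the database itself can reach the mock created here, e.g.
//
//	_, mock, err := sqlmock.NewWithDSN("fixed_dsn")
//	// ... the abstraction library then calls sql.Open("sqlmock", "fixed_dsn")
//	// internally, and expectations are asserted on mock as usual.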
// +build !go1.8
package sqlmock
import (
"database/sql/driver"
"fmt"
"reflect"
)
// WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query
func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery {
e.rows = &rowSets{sets: []*Rows{rows}, ex: e}
return e
}
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
if nil == e.args {
return nil
}
if len(args) != len(e.args) {
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
}
for k, v := range args {
// custom argument matcher
matcher, ok := e.args[k].(Argument)
if ok {
// @TODO: does it make sense to pass value instead of named value?
if !matcher.Match(v.Value) {
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
}
continue
}
dval := e.args[k]
// convert to driver converter
darg, err := e.converter.ConvertValue(dval)
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}
if !driver.IsValue(darg) {
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
}
if !reflect.DeepEqual(darg, v.Value) {
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
}
}
return nil
}
func (e *queryBasedExpectation) attemptArgMatch(args []namedValue) (err error) {
// catch panic
defer func() {
if e := recover(); e != nil {
_, ok := e.(error)
if !ok {
err = fmt.Errorf(e.(string))
}
}
}()
err = e.argsMatches(args)
return
}
// +build go1.8
package sqlmock
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
)
// WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
defs := 0
sets := make([]*Rows, len(rows))
for i, r := range rows {
sets[i] = r
if r.def != nil {
defs++
}
}
if defs > 0 && defs == len(sets) {
e.rows = &rowSetsWithDefinition{&rowSets{sets: sets, ex: e}}
} else {
e.rows = &rowSets{sets: sets, ex: e}
}
return e
}
func (e *queryBasedExpectation) argsMatches(args []driver.NamedValue) error {
if nil == e.args {
return nil
}
if len(args) != len(e.args) {
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
}
// @TODO should we assert either all args are named or ordinal?
for k, v := range args {
// custom argument matcher
matcher, ok := e.args[k].(Argument)
if ok {
if !matcher.Match(v.Value) {
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
}
continue
}
dval := e.args[k]
if named, isNamed := dval.(sql.NamedArg); isNamed {
dval = named.Value
if v.Name != named.Name {
return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name)
}
} else if k+1 != v.Ordinal {
return fmt.Errorf("argument %d: ordinal position: %d does not match expected: %d", k, k+1, v.Ordinal)
}
// convert to driver converter
darg, err := e.converter.ConvertValue(dval)
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}
if !reflect.DeepEqual(darg, v.Value) {
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
}
}
return nil
}
func (e *queryBasedExpectation) attemptArgMatch(args []driver.NamedValue) (err error) {
// catch panic
defer func() {
if e := recover(); e != nil {
_, ok := e.(error)
if !ok {
err = fmt.Errorf(e.(string))
}
}
}()
err = e.argsMatches(args)
return
}