| OLD | NEW |
| 1 // Copyright 2015 The LUCI Authors. All rights reserved. | 1 // Copyright 2015 The LUCI Authors. All rights reserved. |
| 2 // Use of this source code is governed under the Apache License, Version 2.0 | 2 // Use of this source code is governed under the Apache License, Version 2.0 |
| 3 // that can be found in the LICENSE file. | 3 // that can be found in the LICENSE file. |
| 4 | 4 |
| 5 // Package mathrand implements a mockable interface for math/rand.Rand. | 5 // Package mathrand implements a mockable interface for math/rand.Rand. |
| 6 // | 6 // |
| 7 // It is controllable through context.Context. You should use this instead of | 7 // It is controllable through context.Context. You should use this instead of |
| 8 // math/rand directly, to allow you to make deterministic tests. | 8 // math/rand directly, to allow you to make deterministic tests. |
| 9 package mathrand | 9 package mathrand |
| 10 | 10 |
| 11 import ( | 11 import ( |
| 12 "math/rand" | 12 "math/rand" |
| 13 "sync" |
| 13 | 14 |
| 14 "golang.org/x/net/context" | 15 "golang.org/x/net/context" |
| 15 ) | 16 ) |
| 16 | 17 |
| 17 var key = "holds a rand.Rand for mathrand" | 18 var key = "holds a rand.Rand for mathrand" |
| 18 | 19 |
| 19 func newRand() *rand.Rand { | 20 var ( |
| 20 » return rand.New(rand.NewSource(rand.Int63())) | 21 » // globalOnce performs one-time global random variable initialization. |
| 22 » globalOnce sync.Once |
| 23 |
| 24 » // globalRandBase is the gloal *rand.Rand instance. It MUST NOT BE USED |
| 25 » // without holding globalRand's lock. |
| 26 » // |
| 27 » // globalRandBase must not be accessed directly; instead, it must be obt
ained |
| 28 » // through getGlobalRand to ensure that it is initialized. |
| 29 » globalRandBase *rand.Rand |
| 30 |
| 31 » // globalRand is the locking Rand implementation that wraps globalRandBa
se. |
| 32 » // |
| 33 » // globalRand must not be accessed directly; instead, it must be obtaine
d |
| 34 » // through getGlobalRand to ensure that it is initialized. |
| 35 » globalRand *Locking |
| 36 ) |
| 37 |
| 38 // getGlobalRand returns globalRand and its Locking wrapper. This must be used |
| 39 // instead of direct variable access in order to ensure that everything is |
| 40 // initialized. |
| 41 // |
| 42 // We use a Once to perform this initialization so that we can enable |
| 43 // applications to set the seed via rand.Seed if they wish. |
| 44 func getGlobalRand() (*Locking, *rand.Rand) { |
| 45 » globalOnce.Do(func() { |
| 46 » » globalRandBase = newRand() |
| 47 » » globalRand = &Locking{R: wrapped{globalRandBase}} |
| 48 » }) |
| 49 » return globalRand, globalRandBase |
| 21 } | 50 } |
| 22 | 51 |
| 52 func newRand() *rand.Rand { return rand.New(rand.NewSource(rand.Int63())) } |
| 53 |
| 54 // getRand returns the Rand installed in c, or nil if no Rand is installed. |
| 23 func getRand(c context.Context) Rand { | 55 func getRand(c context.Context) Rand { |
| 24 if r, ok := c.Value(&key).(Rand); ok { | 56 if r, ok := c.Value(&key).(Rand); ok { |
| 25 return r | 57 return r |
| 26 } | 58 } |
| 27 return nil | 59 return nil |
| 28 } | 60 } |
| 29 | 61 |
| 30 // Get gets a Rand from the Context. The resulting Rand is safe for concurrent | 62 // Get gets a Rand from the Context. The resulting Rand is safe for concurrent |
| 31 // use. | 63 // use. |
| 32 // | 64 // |
| 33 // If one hasn't been set, this creates a new Rand object with a Source | 65 // If one hasn't been set, this will return a global Rand object backed by a |
| 34 // initialized from the global randomness source provided by stdlib. | 66 // shared rand.Rand singleton. Just like in "math/rand", rand.Seed can be called |
| 35 // | 67 // prior to using Get to set the seed used by this singleton. |
| 36 // If you want to get just a single random number, prefer to use a corresponding | |
| 37 // global function instead: they know how to use math/rand global RNG directly | |
| 38 // and thus are much faster in case the context doesn't have a rand.Rand | |
| 39 // installed. | |
| 40 // | |
| 41 // Use 'Get' only if you plan to obtain a large series of random numbers. | |
| 42 func Get(c context.Context) Rand { | 68 func Get(c context.Context) Rand { |
| 43 if r := getRand(c); r != nil { | 69 if r := getRand(c); r != nil { |
| 44 return r | 70 return r |
| 45 } | 71 } |
| 46 | 72 |
| 47 » // Generate a new Rand instance and return it. Our callers expect this t
o be | 73 » // Use the global instance. |
| 48 » // concurrency-safe. | 74 » gr, _ := getGlobalRand() |
| 49 » return wrapLocking(wrapRand(newRand())) | 75 » return gr |
| 50 } | 76 } |
| 51 | 77 |
| 52 // Set sets the current *"math/rand".Rand object in the context. | 78 // Set sets the current *"math/rand".Rand object in the context. |
| 53 // | 79 // |
| 54 // Useful for testing with a quick mock. The supplied *rand.Rand will be wrapped | 80 // Useful for testing with a quick mock. The supplied *rand.Rand will be wrapped |
| 55 // in a *Locking if necessary such that when it is returned from Get, it is | 81 // in a *Locking if necessary such that when it is returned from Get, it is |
| 56 // safe for concurrent use. | 82 // safe for concurrent use. |
| 57 func Set(c context.Context, mr *rand.Rand) context.Context { | 83 func Set(c context.Context, mr *rand.Rand) context.Context { |
| 58 var r Rand | 84 var r Rand |
| 59 if mr != nil { | 85 if mr != nil { |
| (...skipping 10 matching lines...) Expand all Loading... |
| 70 func SetRand(c context.Context, r Rand) context.Context { | 96 func SetRand(c context.Context, r Rand) context.Context { |
| 71 if r != nil { | 97 if r != nil { |
| 72 r = wrapLocking(r) | 98 r = wrapLocking(r) |
| 73 } | 99 } |
| 74 return context.WithValue(c, &key, r) | 100 return context.WithValue(c, &key, r) |
| 75 } | 101 } |
| 76 | 102 |
| 77 //////////////////////////////////////////////////////////////////////////////// | 103 //////////////////////////////////////////////////////////////////////////////// |
| 78 // Top-level convenience functions mirroring math/rand package API. | 104 // Top-level convenience functions mirroring math/rand package API. |
| 79 // | 105 // |
| 80 // They are here to optimize the case when the context doesn't have math.Rand | 106 // This makes mathrand API more similar to math/rand API. |
| 81 // installed. Using Get(ctx).<function> in this case is semantically equivalent | |
| 82 // to using global RNG, but much slower (up to 400x slower), because it creates | |
| 83 // and seeds new math.Rand object on each call. | |
| 84 // | |
| 85 // Unfortunately since math.Rand is a struct, and its implementation is not | |
| 86 // thread-safe, we can't just return some global math.Rand instance. The stdlib | |
| 87 // has one, but it is private, and we can't reimplement it because stdlib does | |
| 88 // some disgusting type casts to private types in math.Rand implementation, e.g: | |
| 89 // https://github.com/golang/go/blob/fb3cf5c/src/math/rand/rand.go#L183 | |
| 90 // | |
| 91 // This also makes mathrand API more similar to math/rand API. | |
| 92 | 107 |
| 93 // Int63 returns a non-negative pseudo-random 63-bit integer as an int64 | 108 // Int63 returns a non-negative pseudo-random 63-bit integer as an int64 |
| 94 // from the source in the context or the shared global source. | 109 // from the source in the context or the shared global source. |
| 95 func Int63(c context.Context) int64 { | 110 func Int63(c context.Context) int64 { return Get(c).Int63() } |
| 96 » if r := getRand(c); r != nil { | |
| 97 » » return r.Int63() | |
| 98 » } | |
| 99 » return rand.Int63() | |
| 100 } | |
| 101 | 111 |
| 102 // Uint32 returns a pseudo-random 32-bit value as a uint32 from the source in | 112 // Uint32 returns a pseudo-random 32-bit value as a uint32 from the source in |
| 103 // the context or the shared global source. | 113 // the context or the shared global source. |
| 104 func Uint32(c context.Context) uint32 { | 114 func Uint32(c context.Context) uint32 { return Get(c).Uint32() } |
| 105 » if r := getRand(c); r != nil { | |
| 106 » » return r.Uint32() | |
| 107 » } | |
| 108 » return rand.Uint32() | |
| 109 } | |
| 110 | 115 |
| 111 // Int31 returns a non-negative pseudo-random 31-bit integer as an int32 from | 116 // Int31 returns a non-negative pseudo-random 31-bit integer as an int32 from |
| 112 // the source in the context or the shared global source. | 117 // the source in the context or the shared global source. |
| 113 func Int31(c context.Context) int32 { | 118 func Int31(c context.Context) int32 { return Get(c).Int31() } |
| 114 » if r := getRand(c); r != nil { | |
| 115 » » return r.Int31() | |
| 116 » } | |
| 117 » return rand.Int31() | |
| 118 } | |
| 119 | 119 |
| 120 // Int returns a non-negative pseudo-random int from the source in the context | 120 // Int returns a non-negative pseudo-random int from the source in the context |
| 121 // or the shared global source. | 121 // or the shared global source. |
| 122 func Int(c context.Context) int { | 122 func Int(c context.Context) int { return Get(c).Int() } |
| 123 » if r := getRand(c); r != nil { | |
| 124 » » return r.Int() | |
| 125 » } | |
| 126 » return rand.Int() | |
| 127 } | |
| 128 | 123 |
| 129 // Int63n returns, as an int64, a non-negative pseudo-random number in [0,n) | 124 // Int63n returns, as an int64, a non-negative pseudo-random number in [0,n) |
| 130 // from the source in the context or the shared global source. | 125 // from the source in the context or the shared global source. |
| 131 // | 126 // |
| 132 // It panics if n <= 0. | 127 // It panics if n <= 0. |
| 133 func Int63n(c context.Context, n int64) int64 { | 128 func Int63n(c context.Context, n int64) int64 { return Get(c).Int63n(n) } |
| 134 » if r := getRand(c); r != nil { | |
| 135 » » return r.Int63n(n) | |
| 136 » } | |
| 137 » return rand.Int63n(n) | |
| 138 } | |
| 139 | 129 |
| 140 // Int31n returns, as an int32, a non-negative pseudo-random number in [0,n) | 130 // Int31n returns, as an int32, a non-negative pseudo-random number in [0,n) |
| 141 // from the source in the context or the shared global source. | 131 // from the source in the context or the shared global source. |
| 142 // | 132 // |
| 143 // It panics if n <= 0. | 133 // It panics if n <= 0. |
| 144 func Int31n(c context.Context, n int32) int32 { | 134 func Int31n(c context.Context, n int32) int32 { return Get(c).Int31n(n) } |
| 145 » if r := getRand(c); r != nil { | |
| 146 » » return r.Int31n(n) | |
| 147 » } | |
| 148 » return rand.Int31n(n) | |
| 149 } | |
| 150 | 135 |
| 151 // Intn returns, as an int, a non-negative pseudo-random number in [0,n) from | 136 // Intn returns, as an int, a non-negative pseudo-random number in [0,n) from |
| 152 // the source in the context or the shared global source. | 137 // the source in the context or the shared global source. |
| 153 // | 138 // |
| 154 // It panics if n <= 0. | 139 // It panics if n <= 0. |
| 155 func Intn(c context.Context, n int) int { | 140 func Intn(c context.Context, n int) int { return Get(c).Intn(n) } |
| 156 » if r := getRand(c); r != nil { | |
| 157 » » return r.Intn(n) | |
| 158 » } | |
| 159 » return rand.Intn(n) | |
| 160 } | |
| 161 | 141 |
| 162 // Float64 returns, as a float64, a pseudo-random number in [0.0,1.0) from | 142 // Float64 returns, as a float64, a pseudo-random number in [0.0,1.0) from |
| 163 // the source in the context or the shared global source. | 143 // the source in the context or the shared global source. |
| 164 func Float64(c context.Context) float64 { | 144 func Float64(c context.Context) float64 { return Get(c).Float64() } |
| 165 » if r := getRand(c); r != nil { | |
| 166 » » return r.Float64() | |
| 167 » } | |
| 168 » return rand.Float64() | |
| 169 } | |
| 170 | 145 |
| 171 // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0) from | 146 // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0) from |
| 172 // the source in the context or the shared global source. | 147 // the source in the context or the shared global source. |
| 173 func Float32(c context.Context) float32 { | 148 func Float32(c context.Context) float32 { return Get(c).Float32() } |
| 174 » if r := getRand(c); r != nil { | |
| 175 » » return r.Float32() | |
| 176 » } | |
| 177 » return rand.Float32() | |
| 178 } | |
| 179 | 149 |
| 180 // Perm returns, as a slice of n ints, a pseudo-random permutation of the | 150 // Perm returns, as a slice of n ints, a pseudo-random permutation of the |
| 181 // integers [0,n) from the source in the context or the shared global source. | 151 // integers [0,n) from the source in the context or the shared global source. |
| 182 func Perm(c context.Context, n int) []int { | 152 func Perm(c context.Context, n int) []int { return Get(c).Perm(n) } |
| 183 » if r := getRand(c); r != nil { | |
| 184 » » return r.Perm(n) | |
| 185 » } | |
| 186 » return rand.Perm(n) | |
| 187 } | |
| 188 | 153 |
| 189 // Read generates len(p) random bytes from the source in the context or | 154 // Read generates len(p) random bytes from the source in the context or |
| 190 // the shared global source and writes them into p. It always returns len(p) | 155 // the shared global source and writes them into p. It always returns len(p) |
| 191 // and a nil error. | 156 // and a nil error. |
| 192 func Read(c context.Context, p []byte) (n int, err error) { | 157 func Read(c context.Context, p []byte) (n int, err error) { return Get(c).Read(p
) } |
| 193 » if r := getRand(c); r != nil { | |
| 194 » » return r.Read(p) | |
| 195 » } | |
| 196 » return rand.Read(p) | |
| 197 } | |
| 198 | 158 |
| 199 // NormFloat64 returns a normally distributed float64 in the range | 159 // NormFloat64 returns a normally distributed float64 in the range |
| 200 // [-math.MaxFloat64, +math.MaxFloat64] with standard normal distribution | 160 // [-math.MaxFloat64, +math.MaxFloat64] with standard normal distribution |
| 201 // (mean = 0, stddev = 1) from the source in the context or the shared global | 161 // (mean = 0, stddev = 1) from the source in the context or the shared global |
| 202 // source. | 162 // source. |
| 203 // | 163 // |
| 204 // To produce a different normal distribution, callers can adjust the output | 164 // To produce a different normal distribution, callers can adjust the output |
| 205 // using: | 165 // using: |
| 206 // | 166 // |
| 207 // sample = NormFloat64(ctx) * desiredStdDev + desiredMean | 167 // sample = NormFloat64(ctx) * desiredStdDev + desiredMean |
| 208 // | 168 // |
| 209 func NormFloat64(c context.Context) float64 { | 169 func NormFloat64(c context.Context) float64 { return Get(c).NormFloat64() } |
| 210 » if r := getRand(c); r != nil { | |
| 211 » » return r.NormFloat64() | |
| 212 » } | |
| 213 » return rand.NormFloat64() | |
| 214 } | |
| 215 | 170 |
| 216 // ExpFloat64 returns an exponentially distributed float64 in the range | 171 // ExpFloat64 returns an exponentially distributed float64 in the range |
| 217 // (0, +math.MaxFloat64] with an exponential distribution whose rate parameter | 172 // (0, +math.MaxFloat64] with an exponential distribution whose rate parameter |
| 218 // (lambda) is 1 and whose mean is 1/lambda (1) from the source in the context | 173 // (lambda) is 1 and whose mean is 1/lambda (1) from the source in the context |
| 219 // or the shared global source. | 174 // or the shared global source. |
| 220 // | 175 // |
| 221 // To produce a distribution with a different rate parameter, callers can adjust | 176 // To produce a distribution with a different rate parameter, callers can adjust |
| 222 // the output using: | 177 // the output using: |
| 223 // | 178 // |
| 224 // sample = ExpFloat64(ctx) / desiredRateParameter | 179 // sample = ExpFloat64(ctx) / desiredRateParameter |
| 225 // | 180 // |
| 226 func ExpFloat64(c context.Context) float64 { | 181 func ExpFloat64(c context.Context) float64 { return Get(c).ExpFloat64() } |
| 227 » if r := getRand(c); r != nil { | |
| 228 » » return r.ExpFloat64() | |
| 229 » } | |
| 230 » return rand.ExpFloat64() | |
| 231 } | |
| 232 | 182 |
| 233 // WithGoRand invokes the supplied "fn" while holding an exclusive lock | 183 // WithGoRand invokes the supplied "fn" while holding an exclusive lock |
| 234 // for it. This can be used by callers to pull and use a *rand.Rand instance | 184 // for it. This can be used by callers to pull and use a *rand.Rand instance |
| 235 // out of the Context safely. | 185 // out of the Context safely. |
| 236 // | 186 // |
| 237 // The callback's r must not be retained or used outside of hte scope of the | 187 // The callback's r must not be retained or used outside of the scope of the |
| 238 // callback. | 188 // callback. |
| 239 func WithGoRand(c context.Context, fn func(r *rand.Rand) error) error { | 189 func WithGoRand(c context.Context, fn func(r *rand.Rand) error) error { |
| 240 if r := getRand(c); r != nil { | 190 if r := getRand(c); r != nil { |
| 241 return r.WithGoRand(fn) | 191 return r.WithGoRand(fn) |
| 242 } | 192 } |
| 243 | 193 |
| 244 » // No Rand is installed in our Context. Generate a single-use Rand insta
nce. | 194 » // Return our globalRandBase. We MUST hold globalRand's lock in order fo
r this |
| 245 » // We don't need to wrap this at all, since the premise of this method i
s | 195 » // to be safe. |
| 246 » // that the result is not safe for concurrent use. | 196 » l, base := getGlobalRand() |
| 247 » return fn(newRand()) | 197 » l.Lock() |
| 198 » defer l.Unlock() |
| 199 » return fn(base) |
| 248 } | 200 } |
| OLD | NEW |