NumPy’s random.choice in Go

Max Lefarov
Towards Data Science
5 min read · Jan 23, 2024


Generated with ChatGPT

Recently I helped implement some logic in Java that could easily be achieved with a single call to NumPy’s random.choice. It turned out to be one of those tasks that let you look into things you use every day but never have time to fully understand. Also, I’ve wanted to start learning Go for quite some time, so why not kill two birds with one stone and reimplement random.choice once again, this time in Go?

random.choice allows us to sample N elements from a provided collection according to the specified probabilities. Importantly (for the use case that motivated this work), it allows us to sample these elements without replacement, i.e., once an element of the collection has been sampled, it won’t be sampled again. For example, if we have a collection [A, B, C] with associated probabilities [0.1, 0.7, 0.2] and we want to sample 3 elements without replacement, the most likely output is [B, C, A], with probability 0.7 × (0.2 / 0.3) ≈ 0.47. If we sample with replacement, the most likely output is [B, B, B], with probability 0.7³ ≈ 0.34.
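The numbers above can be checked by hand: without replacement, each draw picks an element with probability equal to its share of the mass still left in the pool, and the probability of an ordered outcome is the product of those shares. The short sketch below computes that product; `seqProb` is a hypothetical helper introduced only for this illustration, not part of the implementation that follows.

```go
package main

import "fmt"

// seqProb returns the probability of drawing the given indices in this exact
// order, one at a time and without replacement, from the distribution probs.
// Hypothetical helper for illustration only.
func seqProb(probs []float64, order []int) float64 {
	remaining := 0.0
	for _, p := range probs {
		remaining += p
	}
	result := 1.0
	for _, idx := range order {
		// Chance of picking idx among the elements still in the pool.
		result *= probs[idx] / remaining
		// Remove idx from the pool before the next draw.
		remaining -= probs[idx]
	}
	return result
}

func main() {
	probs := []float64{0.1, 0.7, 0.2} // A, B, C
	fmt.Printf("P([B, C, A]) = %.4f\n", seqProb(probs, []int{1, 2, 0}))
	fmt.Printf("P([B, A, C]) = %.4f\n", seqProb(probs, []int{1, 0, 2}))
	fmt.Printf("P([C, B, A]) = %.4f\n", seqProb(probs, []int{2, 1, 0}))
}
```

Running it should print approximately 0.4667, 0.2333 and 0.1750, so [B, C, A] is the single most likely order even though it shows up in fewer than half of the draws.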

First, let’s define the signature of the Go function. We want to keep it as close to NumPy’s counterpart as possible.

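func Choice[T any](
	arr []T,         // collection to sample from
	size int,        // number of samples to draw
	replace bool,    // whether an element may be drawn more than once
	probs []float64, // probability of each element of arr
	rng *rand.Rand,  // source of randomness (math/rand)
) ([]T, error) {
	// implementation below
}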

A few things to notice about the function signature:

  • We’re using the type parameter T (Go generics, available since Go 1.18). It lets us call this function on slices of any element type, because its constraint, any, places no restrictions on T. This mimics the Pythonic semantics of random.choice, which doesn’t care about the type of elements stored in the input array.
  • We also pass a pointer to the random number generator (rng) that we use for sampling. I picked up this style of passing the generator explicitly (instead of accessing a global instance) from JAX. In my experience, it simplifies testing and reproducibility.
  • The function has two return values: the slice of samples and a value of type error. That’s how Go handles the “exceptional” execution flow, since Go doesn’t have assertions or exceptions.
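The reproducibility benefit from the second bullet is easy to demonstrate: two generators built from the same seed produce the same stream, so a test can fix the seed and assert on exact outputs. A minimal sketch with the standard `math/rand` package; the helper names and the seed 42 are arbitrary choices for illustration.

```go
package main

import (
	"fmt"
	"math/rand"
)

// drawN returns n uniform floats in [0, 1) taken from the injected generator.
func drawN(rng *rand.Rand, n int) []float64 {
	out := make([]float64, n)
	for i := range out {
		out[i] = rng.Float64()
	}
	return out
}

// sameSeedStreams draws n floats from two separate generators sharing a seed.
func sameSeedStreams(seed int64, n int) ([]float64, []float64) {
	a := drawN(rand.New(rand.NewSource(seed)), n)
	b := drawN(rand.New(rand.NewSource(seed)), n)
	return a, b
}

func main() {
	a, b := sameSeedStreams(42, 3)
	fmt.Println(a)
	fmt.Println(b) // prints the same three numbers as the line above
}
```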

Now we need to figure out how to sample elements from the discrete probability distribution defined by the probs argument using only random floats drawn uniformly from [0, 1) by the rng (rng.Float64()). Luckily, there’s a method for doing exactly that.

CDF Inversion Method

First of all, CDF stands for cumulative distribution function. In the discrete case, it can be represented as an array where the element at index i equals the sum of all input probabilities up to and including position i. Let’s materialize this definition in a simple helper function.

// CDF returns the cumulative sums of probs: cdf[i] = probs[0] + ... + probs[i].
func CDF(probs []float64) []float64 {
	cdf := make([]float64, len(probs))
	cum := 0.0

	for i := range cdf {
		cum += probs[i]
		cdf[i] = cum
	}

	return cdf
}
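Applied to the running example, the helper turns [0.1, 0.7, 0.2] into the bin borders [0.1, 0.8, 1.0], up to floating-point rounding (in float64, 0.1 + 0.7 evaluates to 0.7999999999999999 rather than 0.8). A self-contained check, repeating the helper so that it runs on its own:

```go
package main

import "fmt"

// CDF returns the running (cumulative) sums of probs, same logic as above.
func CDF(probs []float64) []float64 {
	cdf := make([]float64, len(probs))
	cum := 0.0
	for i, p := range probs {
		cum += p
		cdf[i] = cum
	}
	return cdf
}

func main() {
	// Probabilities of A, B, C from the running example.
	fmt.Println(CDF([]float64{0.1, 0.7, 0.2}))
}
```

Because of such rounding, it is safer to compare CDF values with a small tolerance than with exact equality.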

With the CDF of the discrete probability distribution and the random number generator, we can sample elements from the input collection as follows:

  1. Sample a random float from the uniform distribution on [0, 1).
  2. Find the first index where the CDF value is ≥ the random float.
  3. Return the element of the original collection at this index.
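The three steps can be traced by hand on the running example, whose CDF is [0.1, 0.8, 1.0]. In this sketch the uniform value u is passed in directly instead of being drawn from an rng (an illustrative choice that makes the mapping easy to test), and `drawIndex` is a hypothetical name:

```go
package main

import "fmt"

// drawIndex performs steps 2 and 3 for a given uniform value u in [0, 1):
// it returns the first index whose CDF value is >= u.
// Hypothetical helper; u stands in for rng.Float64() (step 1).
func drawIndex(u float64, cdf []float64) int {
	for i, border := range cdf {
		if border >= u {
			return i
		}
	}
	// Guard against rounding: if u is above every border, use the last bin.
	return len(cdf) - 1
}

func main() {
	items := []string{"A", "B", "C"}
	cdf := []float64{0.1, 0.8, 1.0} // CDF of [0.1, 0.7, 0.2]
	for _, u := range []float64{0.05, 0.45, 0.95} {
		// Prints "0.05 -> A", "0.45 -> B", "0.95 -> C" on separate lines.
		fmt.Println(u, "->", items[drawIndex(u, cdf)])
	}
}
```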

To understand why this works, we can do a simple visual experiment. Think of the CDF values as the right borders of bins placed on the interval [0, 1]. The width of each bin equals the corresponding input probability. Generating a uniform random float in [0, 1) is like throwing a ball at a random point of the interval and picking the bin it lands in. The probability of hitting a bin equals its width, i.e., the input probability (exactly what we need). Here’s a visual demonstration for our earlier example of the collection [A, B, C] with associated probabilities [0.1, 0.7, 0.2].

Created by the author in Excalidraw
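The ball-throwing picture can also be checked numerically: draw many uniform floats, drop each into the first bin whose right border is ≥ the value, and compare the hit frequencies with the input probabilities. A sketch assuming a fixed seed and 100,000 throws (both arbitrary); `empiricalFreqs` is a hypothetical helper:

```go
package main

import (
	"fmt"
	"math/rand"
)

// empiricalFreqs throws n uniform "balls" onto [0, 1), assigns each to the
// first bin whose right border (CDF value) is >= the ball, and returns the
// fraction of balls that landed in each bin.
func empiricalFreqs(probs []float64, n int, seed int64) []float64 {
	// Bin borders are the cumulative sums of the probabilities.
	cdf := make([]float64, len(probs))
	cum := 0.0
	for i, p := range probs {
		cum += p
		cdf[i] = cum
	}

	rng := rand.New(rand.NewSource(seed))
	counts := make([]int, len(probs))
	for k := 0; k < n; k++ {
		u := rng.Float64()
		idx := len(cdf) - 1 // fallback in case rounding leaves u above the last border
		for i, border := range cdf {
			if border >= u {
				idx = i
				break
			}
		}
		counts[idx]++
	}

	freqs := make([]float64, len(probs))
	for i, c := range counts {
		freqs[i] = float64(c) / float64(n)
	}
	return freqs
}

func main() {
	// Frequencies for A, B, C should be close to [0.1, 0.7, 0.2].
	fmt.Println(empiricalFreqs([]float64{0.1, 0.7, 0.2}, 100000, 1))
}
```

The exact digits depend on the seed, but each frequency should land within a fraction of a percent of its probability.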

To get the index of the bin, we return the index of the first right border that is greater than or equal to the sampled value. Again, a simple helper function does exactly this:

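// FindIndexFromRight returns the index of the first right bin border (CDF value)
// that is >= val. If rounding leaves val above every border, it falls back to
// the last bin.
func FindIndexFromRight(val float64, cdf []float64) int {
	for i, cumProb := range cdf {
		if cumProb >= val {
			return i
		}
	}

	return len(cdf) - 1
}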
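`FindIndexFromRight` scans the CDF linearly, so each draw costs O(n). Since the CDF of non-negative probabilities is non-decreasing, the same “first border ≥ value” lookup can also be done by binary search in O(log n). The variant below swaps in the standard library’s `sort.SearchFloat64s`, which returns the smallest index i with a[i] >= x, or len(a) if there is none; it is an optional alternative rather than what the article uses, and `findIndexBinary` is a hypothetical name:

```go
package main

import (
	"fmt"
	"sort"
)

// findIndexBinary returns the first index whose CDF value is >= val,
// using binary search (valid because the CDF is sorted ascending).
func findIndexBinary(val float64, cdf []float64) int {
	i := sort.SearchFloat64s(cdf, val)
	if i == len(cdf) {
		// val is above every border (possible through rounding): use the
		// last bin, matching the fallback of the linear version.
		return len(cdf) - 1
	}
	return i
}

func main() {
	cdf := []float64{0.1, 0.8, 1.0}
	fmt.Println(findIndexBinary(0.45, cdf)) // 1
	fmt.Println(findIndexBinary(0.8, cdf))  // 1 (borders are inclusive)
}
```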

Putting everything together

With that, we have everything we need to implement random.choice with replacement. Sampling without replacement requires one more trick. To ensure that we don’t draw an element that was already sampled, we can set its probability to 0 after sampling. This, however, invalidates our discrete probability distribution, because the probabilities no longer sum to 1. To correct for that, we re-normalize the probabilities by dividing them by the new total. As a bonus, we can perform the re-normalization directly on the CDF instead of re-normalizing the input probabilities and then computing the CDF. Putting everything together:

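import (
	"errors"
	"math/rand"
)

// Choice draws size elements from arr, where element i is picked with
// probability probs[i]. Without replacement, a drawn element is masked out
// and the remaining probabilities are re-normalized before the next draw.
func Choice[T any](
	arr []T,
	size int,
	replace bool,
	probs []float64,
	rng *rand.Rand,
) ([]T, error) {
	if len(probs) != len(arr) {
		return nil, errors.New("arr and probs must have the same length")
	}
	if size < 0 || (size > 0 && len(arr) == 0) {
		return nil, errors.New("invalid size for the given array")
	}
	if !replace && (size > len(arr)) {
		return nil, errors.New("cannot sample more than array size without replacement")
	}

	// Work on a copy so that masking does not modify the caller's slice.
	samples := make([]T, size)
	probsCopy := make([]float64, len(probs))
	copy(probsCopy, probs)

	for i := 0; i < size; i++ {
		cdf := CDF(probsCopy)

		if !replace {
			// Masking broke the sum-to-one property: re-normalize the CDF.
			total := cdf[len(cdf)-1]
			if total <= 0 {
				return nil, errors.New("fewer non-zero probabilities than size")
			}
			for cdfInd := range cdf {
				cdf[cdfInd] /= total
			}
		}

		randFloat := rng.Float64() // uniform in [0, 1)
		sampledIndex := FindIndexFromRight(randFloat, cdf)
		samples[i] = arr[sampledIndex]

		if !replace {
			// Zero out the drawn element so it cannot be drawn again.
			probsCopy[sampledIndex] = 0.0
		}
	}

	return samples, nil
}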
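Normalizing the CDF directly works because dividing every running sum by the total gives the same numbers as summing the already-divided probabilities. A small numerical check on [0.1, 0.7, 0.2] with B masked out, where both routes should give [1/3, 1/3, 1]; `cumsum`, `viaCDF` and `viaProbs` are illustrative names, not part of the implementation:

```go
package main

import "fmt"

// cumsum returns the running sums of xs.
func cumsum(xs []float64) []float64 {
	out := make([]float64, len(xs))
	s := 0.0
	for i, x := range xs {
		s += x
		out[i] = s
	}
	return out
}

// viaCDF builds the CDF first and then divides it by its last element.
func viaCDF(probs []float64) []float64 {
	cdf := cumsum(probs)
	total := cdf[len(cdf)-1]
	for i := range cdf {
		cdf[i] /= total
	}
	return cdf
}

// viaProbs normalizes the probabilities first and then builds the CDF.
func viaProbs(probs []float64) []float64 {
	total := 0.0
	for _, p := range probs {
		total += p
	}
	normed := make([]float64, len(probs))
	for i, p := range probs {
		normed[i] = p / total
	}
	return cumsum(normed)
}

func main() {
	masked := []float64{0.1, 0.0, 0.2} // B was already sampled, so its probability is 0
	fmt.Println(viaCDF(masked))
	fmt.Println(viaProbs(masked))
}
```

Both lines should show roughly [0.333 0.333 1], differing at most in the last floating-point digits; normalizing the CDF simply saves one pass over the probabilities.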

Again, a few things to pay attention to:

  • If we sample without replacement, we can’t draw more samples than there are elements in the input collection.
  • A Go slice is a view of an underlying array: passing it to a function copies the slice header, not the elements, so writes to its elements are visible to the caller. Thus, we copy the input probabilities so that masking doesn’t modify the caller’s slice.
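The effect described in the last bullet can be seen in isolation. In this sketch (hypothetical helpers `maskFirst` and `maskFirstCopy`), writing to an element of a slice parameter changes the caller’s data, while writing to a copy made with the built-in `copy` does not:

```go
package main

import "fmt"

// maskFirst writes through the slice it receives, so the caller sees the change.
func maskFirst(probs []float64) {
	probs[0] = 0.0
}

// maskFirstCopy masks a private copy and leaves the caller's slice untouched.
func maskFirstCopy(probs []float64) []float64 {
	probsCopy := make([]float64, len(probs))
	copy(probsCopy, probs)
	probsCopy[0] = 0.0
	return probsCopy
}

func main() {
	a := []float64{0.1, 0.7, 0.2}
	maskFirst(a)
	fmt.Println(a) // [0 0.7 0.2]: the original was modified

	b := []float64{0.1, 0.7, 0.2}
	masked := maskFirstCopy(b)
	fmt.Println(b, masked) // [0.1 0.7 0.2] [0 0.7 0.2]
}
```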

And the driver code:

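package main

import (
	"fmt"
	"log"
	"math/rand"
	"time"
)

// Choice, CDF and FindIndexFromRight are assumed to live in this same package.
func main() {
	// Clock-based seed: output varies between runs; use a fixed seed to reproduce.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))

	arr := []string{"A", "B", "C", "D"}
	probs := []float64{0.1, 0.6, 0.2, 0.1}

	samples, err := Choice(arr, 3, false, probs, rng)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println("Samples:", samples)
}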
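As a quick sanity check of the whole pipeline, here is a sketch: with a fixed seed the output is reproducible, and without replacement no element repeats. It repeats condensed versions of `CDF`, `FindIndexFromRight` and `Choice` so that it compiles on its own (generics require Go 1.18+); the seed 7 and the helpers `distinct` and `runWithSeed` are arbitrary, illustrative choices.

```go
package main

import (
	"errors"
	"fmt"
	"math/rand"
)

// CDF returns the running sums of probs (condensed copy of the helper above).
func CDF(probs []float64) []float64 {
	cdf := make([]float64, len(probs))
	cum := 0.0
	for i, p := range probs {
		cum += p
		cdf[i] = cum
	}
	return cdf
}

// FindIndexFromRight returns the first index whose CDF value is >= val.
func FindIndexFromRight(val float64, cdf []float64) int {
	for i, cumProb := range cdf {
		if cumProb >= val {
			return i
		}
	}
	return len(cdf) - 1
}

// Choice is a condensed version of the function above: rebuild the CDF,
// re-normalize when sampling without replacement, draw, then mask.
func Choice[T any](arr []T, size int, replace bool, probs []float64, rng *rand.Rand) ([]T, error) {
	if len(probs) != len(arr) || size < 0 || (size > 0 && len(arr) == 0) {
		return nil, errors.New("invalid arguments")
	}
	if !replace && size > len(arr) {
		return nil, errors.New("cannot sample more than array size without replacement")
	}
	samples := make([]T, size)
	probsCopy := append([]float64(nil), probs...)
	for i := range samples {
		cdf := CDF(probsCopy)
		if !replace {
			total := cdf[len(cdf)-1]
			if total <= 0 {
				return nil, errors.New("no non-zero probabilities left")
			}
			for j := range cdf {
				cdf[j] /= total
			}
		}
		idx := FindIndexFromRight(rng.Float64(), cdf)
		samples[i] = arr[idx]
		if !replace {
			probsCopy[idx] = 0.0
		}
	}
	return samples, nil
}

// distinct reports whether no value appears twice in xs.
func distinct(xs []string) bool {
	seen := map[string]bool{}
	for _, x := range xs {
		if seen[x] {
			return false
		}
		seen[x] = true
	}
	return true
}

// runWithSeed draws 3 of the 4 letters without replacement using a fixed seed.
func runWithSeed(seed int64) []string {
	arr := []string{"A", "B", "C", "D"}
	probs := []float64{0.1, 0.6, 0.2, 0.1}
	samples, err := Choice(arr, 3, false, probs, rand.New(rand.NewSource(seed)))
	if err != nil {
		panic(err)
	}
	return samples
}

func main() {
	first, second := runWithSeed(7), runWithSeed(7)
	fmt.Println(first, second, distinct(first)) // both slices match; last value is true
}
```

Swapping the fixed seed for time.Now().UnixNano(), as in the driver, gives different samples on every run.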

That’s it! Feel free to drop a question in the comments.
