Weighted random number

Write a function that returns values randomly, according to their weight.

Let me give you an example. Suppose we have 3 elements with their weights: A (1), B (1) and C (2). The function should return A with probability 25%, B with 25% and C with 50% based on the weights.

The answer is not obvious, but it is not too hard to work out. Writing bug-free code is another matter, and most candidates stumble there, which makes this a good question for coding interviews.

Solution)

import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

public class WeightedCollection<E> {

    private NavigableMap<Integer, E> map = new TreeMap<Integer, E>();
    private Random random;
    private int total = 0;

    public WeightedCollection() {
        this(new Random());
    }

    public WeightedCollection(Random random) {
        this.random = random;
    }

    public void add(int weight, E object) {
        if (weight <= 0) return;
        total += weight;
        map.put(total, object);
    }

    public E next() {
        if (total == 0) return null; // nothing added yet; nextInt(0) would throw
        int value = random.nextInt(total) + 1; // uniform in [1, total]; can also use floating-point weights
        return map.ceilingEntry(value).getValue();
    }

}

weighted

TreeMap

NavigableMap<K,V>

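1) If the weights are all known up front (pre-processing): build a TreeMap keyed by the running sum of weights and look up with the ceilingEntry method.

2) Store the cumulative sums and find the first one that covers the random value. For example, weights (0.1, 0.2, 0.3, 0.4) give cumulative sums (0.1, 0.3, 0.6, 1.0), so a random value of 0.7 selects the fourth element. The search is O(n) with a linear scan or O(log n) with binary search.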
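A minimal sketch of option 2 with floating-point weights, assuming the weights are fixed at construction time and all positive. The class name PrefixSumPicker and its methods are hypothetical, chosen only for this illustration.

```java
import java.util.Random;

class PrefixSumPicker {
    private final double[] cumulative; // cumulative[i] = weights[0] + ... + weights[i]
    private final Random random;

    PrefixSumPicker(double[] weights, Random random) {
        if (weights.length == 0) throw new IllegalArgumentException("no weights given");
        this.cumulative = new double[weights.length];
        double sum = 0;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] <= 0) throw new IllegalArgumentException("weights must be positive");
            sum += weights[i];
            cumulative[i] = sum;
        }
        this.random = random;
    }

    // Binary search for the first index whose cumulative sum is strictly greater than target.
    // Assumes 0 <= target < total, so element i owns the range [cumulative[i-1], cumulative[i]).
    int indexFor(double target) {
        int lo = 0, hi = cumulative.length - 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (cumulative[mid] <= target) lo = mid + 1;
            else hi = mid;
        }
        return lo;
    }

    // nextDouble() is uniform in [0, 1), so the scaled target lies in [0, total).
    int nextIndex() {
        double total = cumulative[cumulative.length - 1];
        return indexFor(random.nextDouble() * total);
    }
}
```

The sketch returns an index rather than an element, so the caller can keep the elements in a parallel array or list.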

Possible questions)

  1. Are the given weights integers or floating-point numbers?
  2. What is the element type for the random component? (A generic element <E>, a String, or a Character?)
  3. Can we use the existing Random library, which generates uniformly distributed random numbers?
  4. What is the main goal of this component? Is it time-sensitive or memory-sensitive?
  5. What about the weights? Can we assume they are pre-defined, or must we allow weights to be added dynamically?
  6. Do we have to consider a multi-threaded environment?

Testing

  1. Unit testing (functional testing)
  2. Large-scale testing: draw many results, plot them as a histogram, and check that each element's observed frequency matches its expected probability.
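The large-scale check could be sketched as follows: draw many samples and compare each observed frequency with weight / total. The class WeightedSamplingCheck, its method name, and the tolerance mentioned afterwards are assumptions made for this illustration; the sampler inside uses the same TreeMap and ceilingEntry idea as the solution above and assumes positive integer weights.

```java
import java.util.Random;
import java.util.TreeMap;

class WeightedSamplingCheck {
    // Draws `draws` samples and returns the observed frequency of each index.
    // Assumes every weight is positive (a zero weight would overwrite the previous key).
    static double[] observedFrequencies(int[] weights, int draws, Random random) {
        TreeMap<Integer, Integer> cumulative = new TreeMap<>();
        int total = 0;
        for (int i = 0; i < weights.length; i++) {
            total += weights[i];
            cumulative.put(total, i); // key = running sum, value = element index
        }
        int[] counts = new int[weights.length];
        for (int d = 0; d < draws; d++) {
            int value = random.nextInt(total) + 1; // uniform in [1, total]
            counts[cumulative.ceilingEntry(value).getValue()]++;
        }
        double[] frequencies = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            frequencies[i] = (double) counts[i] / draws;
        }
        return frequencies;
    }

    public static void main(String[] args) {
        int[] weights = {1, 1, 2}; // A, B, C from the example
        double[] f = observedFrequencies(weights, 100_000, new Random());
        for (int i = 0; i < f.length; i++) {
            System.out.printf("element %d: observed %.4f, expected %.4f%n", i, f[i], weights[i] / 4.0);
        }
    }
}
```

In an automated test, a fixed seed and a tolerance (for example 0.02 at 100,000 draws) keep the check deterministic while still catching off-by-one errors in the interval boundaries.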

This solution uses an Integer weight and a generic element type <E>.

|----------------------------------------------|
|-- A --|-- B ----------|-- C -----------------|
                                               ^ sum

|-- A --|-- B ----------|-- C -----------------|-- D -----|
                                               ^ sum      ^ sum + D's weight

Interval A, Interval B, Interval C

Maintaining a running sum of weights also works for floating-point weights.
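As a rough sketch of that idea, the running-sum map can be keyed by Double instead of Integer. The class DoubleWeightedCollection and its pick method are hypothetical names for this illustration. Because nextDouble() is uniform in [0, 1), the scaled value lies in [0, total), so the lookup uses higherEntry (strictly greater key) rather than ceilingEntry.

```java
import java.util.Map;
import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

class DoubleWeightedCollection<E> {
    private final NavigableMap<Double, E> map = new TreeMap<>();
    private final Random random;
    private double total = 0.0;

    DoubleWeightedCollection(Random random) {
        this.random = random;
    }

    void add(double weight, E element) {
        if (weight <= 0) return; // ignore non-positive weights, as the integer version does
        total += weight;
        map.put(total, element); // key = running sum after adding this element
    }

    // Returns the element whose range [previous sum, own sum) contains value,
    // or null when value is not below the total.
    E pick(double value) {
        Map.Entry<Double, E> entry = map.higherEntry(value);
        return entry == null ? null : entry.getValue();
    }

    E next() {
        if (map.isEmpty()) return null;
        return pick(random.nextDouble() * total);
    }
}
```

One caveat of this sketch: a weight so small that it does not change the running sum in floating-point arithmetic would overwrite the previous entry, so very uneven weights may need extra care.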

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class WeightedRandomGenerator<E> {
    // The range (start, end] owned by one element.
    private static class Interval<T> {
        final int start;
        final int end;
        final T element;

        Interval() {
            this(0, 0, null);
        }

        Interval(int s, int e, T elem) {
            start = s;
            end = e;
            element = elem;
        }

        boolean in(int value) {
            return value > start && value <= end;
        }
    }

    private final Random random;
    private final List<Interval<E>> weights;
    private int total;

    public WeightedRandomGenerator() {
        this(new Random(System.currentTimeMillis()));
    }

    public WeightedRandomGenerator(Random r) {
        random = r;
        total = 0;
        weights = new ArrayList<>();
    }

    public synchronized void addWeight(int weight, E element) {
        if (weight <= 0) return; // a non-positive weight would create an empty interval
        weights.add(new Interval<>(total, total + weight, element));
        total += weight;
    }

    // O(n): linear scan over the intervals
    private Interval<E> findInterval(int value) {
        for (Interval<E> interval : weights)
            if (interval.in(value)) return interval;
        return null;
    }

    // O(log n): binary search; the intervals are sorted because they are appended in order
    private Interval<E> findIntervalBS(int value) {
        int s = 0, e = weights.size() - 1;
        while (s <= e) {
            int m = s + (e - s) / 2;
            Interval<E> interval = weights.get(m);
            if (interval.in(value)) return interval;
            if (value > interval.end) s = m + 1;
            else e = m - 1;
        }
        return null;
    }

    // O(log n)
    public synchronized E next() {
        if (total == 0) return null;
        // generate a random value in [1, total] and find the interval that contains it
        Interval<E> interval = findIntervalBS(random.nextInt(total) + 1);
        // a match always exists for values in [1, total]; the fallback to the last element is only defensive
        return interval == null ? weights.get(weights.size() - 1).element : interval.element;
    }
}

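If we have enough memory, we can store one entry per unit of weight, so every possible integer value maps directly to its element and lookup is O(1). The version below uses an ArrayList indexed by the random value rather than a hash map.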

import java.util.ArrayList;
import java.util.Random;

public class WeightedRandomGenerator<E> {
    private final Random random;
    private final ArrayList<E> weights;

    public WeightedRandomGenerator() {
        this(new Random(System.currentTimeMillis()));
    }

    public WeightedRandomGenerator(Random r) {
        random = r;
        weights = new ArrayList<>();
    }

    public synchronized void addWeight(int weight, E element) {
        // store the element `weight` times, so it owns `weight` slots in the list
        for (int i = 0; i < weight; i++)
            weights.add(element);
    }

    // O(1)
    public synchronized E next() {
        if (weights.isEmpty()) return null;
        // nextInt(n) is uniform in [0, n-1], so the value is always a valid index
        return weights.get(random.nextInt(weights.size()));
    }
}
