Weighted random number
Write a function that returns values at random, where each value's probability is proportional to its weight.
Let me give you an example. Suppose we have 3 elements with these weights: A (1), B (1) and C (2). The total weight is 4, so the function should return A with probability 1/4 = 25%, B with 25% and C with 2/4 = 50%.
The answer is not obvious, but it is not too hard to work out. Even so, the majority of candidates fail to write bug-free code for it. That makes it a great question for coding interviews.
Solution)
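import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

public class WeightedCollection<E> {
    // key = running sum of the weights up to and including this element
    private final NavigableMap<Integer, E> map = new TreeMap<>();
    private final Random random;
    private int total = 0;

    public WeightedCollection() {
        this(new Random());
    }

    public WeightedCollection(Random random) {
        this.random = random;
    }

    public void add(int weight, E object) {
        if (weight <= 0) return; // non-positive weights are ignored
        total += weight;
        map.put(total, object);
    }

    // O(log n): pick a value in [1, total] and return the element of the first key >= value.
    // The same idea also works with floating-point weights (double keys, random.nextDouble() * total).
    public E next() {
        if (total == 0) throw new IllegalStateException("no elements added");
        int value = random.nextInt(total) + 1;
        return map.ceilingEntry(value).getValue();
    }
}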
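To see why ceilingEntry picks the right element, here is a small trace. It is a sketch, and the class name CeilingEntryDemo is hypothetical: it builds the same running-sum keys that add() would create for A (1), B (1) and C (2), then maps every possible value of random.nextInt(total) + 1 to an element.

```java
import java.util.NavigableMap;
import java.util.TreeMap;

class CeilingEntryDemo {
    // Same keys WeightedCollection.add would produce for A(1), B(1), C(2): 1 -> A, 2 -> B, 4 -> C
    static NavigableMap<Integer, String> buildMap() {
        String[] elements = {"A", "B", "C"};
        int[] weights = {1, 1, 2};
        NavigableMap<Integer, String> map = new TreeMap<>();
        int total = 0;
        for (int i = 0; i < elements.length; i++) {
            total += weights[i];
            map.put(total, elements[i]);
        }
        return map;
    }

    public static void main(String[] args) {
        NavigableMap<Integer, String> map = buildMap();
        // random.nextInt(4) + 1 yields 1, 2, 3 or 4, each with probability 1/4
        for (int value = 1; value <= 4; value++) {
            System.out.println(value + " -> " + map.ceilingEntry(value).getValue());
        }
    }
}
```

This prints 1 -> A, 2 -> B, 3 -> C and 4 -> C. Two of the four equally likely values land on C, so C comes out half the time and A and B a quarter of the time each.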
Keywords: weighted random, TreeMap, NavigableMap<K,V>
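1) If the weights are known in advance (pre-processing): store the running sums as TreeMap keys and look up the random value with the ceilingEntry method, as above.
2) Store all prefix sums in an array and find the first one that covers the random value. For example, weights (0.1, 0.2, 0.3, 0.4) give prefix sums (0.1, 0.3, 0.6, 1.0), so a random value of 0.7 falls in the fourth element's range. The lookup is O(n) with a linear scan or O(log n) with binary search.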
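Approach 2 with floating-point weights could look like the sketch below. It is an illustration rather than the author's code: PrefixSumSampler and indexFor are hypothetical names, and it assumes every weight is positive. Element i owns the half-open range [prefix[i-1], prefix[i]), and a hand-written binary search finds the first prefix sum strictly greater than the random value.

```java
import java.util.Random;

class PrefixSumSampler<E> {
    private final double[] prefix; // prefix[i] = w[0] + ... + w[i], strictly increasing
    private final E[] elements;
    private final Random random;

    PrefixSumSampler(double[] weights, E[] elements, Random random) {
        if (weights.length == 0 || weights.length != elements.length)
            throw new IllegalArgumentException("need exactly one weight per element");
        prefix = new double[weights.length];
        double sum = 0;
        for (int i = 0; i < weights.length; i++) {
            if (!(weights[i] > 0)) throw new IllegalArgumentException("weights must be positive");
            sum += weights[i];
            prefix[i] = sum;
        }
        this.elements = elements;
        this.random = random;
    }

    // First index whose prefix sum is strictly greater than r, in O(log n).
    // Returns the last index if no prefix sum exceeds r (e.g. rounding pushed r up to the total).
    int indexFor(double r) {
        int lo = 0, hi = prefix.length - 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (prefix[mid] > r) hi = mid;
            else lo = mid + 1;
        }
        return lo;
    }

    E next() {
        double r = random.nextDouble() * prefix[prefix.length - 1]; // in [0, total), up to rounding
        return elements[indexFor(r)];
    }
}
```

With weights (0.1, 0.2, 0.3, 0.4), indexFor(0.7) returns 3, i.e. the fourth element. Building the array costs O(n) once; each draw is then O(log n).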
Possible questions)
- Are the weights integers or floating-point numbers?
- What is the element type of the component (a generic <E>, String or Character)?
- Can we use an existing library such as Random that generates uniformly distributed numbers?
- What is the main goal of this component? Is it time-sensitive or memory-sensitive?
- Are the weights fixed in advance, or can weights be added dynamically?
- Do we have to support a multi-threaded environment?
Testing
- Unit testing (functional testing)
- Statistical (large-sample) testing: draw many samples, plot the counts as a histogram, and check that each element's frequency is close to its expected probability
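The statistical test can be sketched as follows. It is a minimal, self-contained illustration under a few assumptions: WeightedSamplingCheck, frequencies and matches are hypothetical names, the sampler inside is a stand-in that reuses the running-sum + ceilingEntry idea from WeightedCollection with A (1), B (1), C (2), a fixed seed keeps runs reproducible, and Java 11+ is assumed for String.repeat.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

class WeightedSamplingCheck {
    // Draws n samples from a stand-in sampler (keys 1 -> A, 2 -> B, 4 -> C; total 4)
    // and returns the observed frequency of each element.
    static Map<String, Double> frequencies(int n, long seed) {
        NavigableMap<Integer, String> map = new TreeMap<>();
        map.put(1, "A");
        map.put(2, "B");
        map.put(4, "C");
        Random random = new Random(seed);
        Map<String, Integer> counts = new HashMap<>();
        for (int i = 0; i < n; i++) {
            String element = map.ceilingEntry(random.nextInt(4) + 1).getValue();
            counts.merge(element, 1, Integer::sum);
        }
        Map<String, Double> freq = new HashMap<>();
        counts.forEach((k, v) -> freq.put(k, v / (double) n));
        return freq;
    }

    // True if every expected element's observed frequency is within tol of its probability.
    static boolean matches(Map<String, Double> freq, Map<String, Double> expected, double tol) {
        for (Map.Entry<String, Double> e : expected.entrySet()) {
            double observed = freq.getOrDefault(e.getKey(), 0.0);
            if (Math.abs(observed - e.getValue()) > tol) return false;
        }
        return true;
    }

    public static void main(String[] args) {
        Map<String, Double> freq = frequencies(1_000_000, 42L);
        Map<String, Double> expected = Map.of("A", 0.25, "B", 0.25, "C", 0.5);
        for (String key : new String[] {"A", "B", "C"}) {
            double f = freq.getOrDefault(key, 0.0);
            // crude text histogram: one '#' per 2.5% of the samples
            System.out.printf("%s %.3f %s%n", key, f, "#".repeat((int) Math.round(f * 40)));
        }
        System.out.println(matches(freq, expected, 0.01) ? "PASS" : "FAIL");
    }
}
```

With a million draws the standard deviation of each frequency is below 0.001, so a tolerance of 0.01 is loose enough to avoid flaky failures but still catches a clearly skewed sampler.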
The next solution uses integer weights and a generic element type. Each element owns an interval whose length is its weight:
|----------------------------------------------|
|-- A --|-- B ----------|-- C -----------------|
                                               ^ sum
|-- A --|-- B ----------|-- C -----------------|-- D -----|
                                               ^ sum      ^ sum + D's weight
Interval A, Interval B and Interval C cover (0, sum]; adding D appends the interval (sum, sum + D's weight].
Maintaining a running sum of weights works for floating-point weights as well.
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class WeightedRandomGenerator<E> {
    // Each element owns the interval (start, end]; its length equals the weight.
    private static class Interval<T> {
        final int start;
        final int end;
        final T element;

        Interval(int start, int end, T element) {
            this.start = start;
            this.end = end;
            this.element = element;
        }

        boolean in(int value) {
            return value > start && value <= end;
        }
    }

    private final Random random;
    private final List<Interval<E>> weights = new ArrayList<>();
    private int total = 0;

    public WeightedRandomGenerator() {
        this(new Random());
    }

    public WeightedRandomGenerator(Random r) {
        random = r;
    }

    public synchronized void addWeight(int weight, E element) {
        if (weight <= 0) return; // an empty interval could never be selected
        weights.add(new Interval<>(total, total + weight, element));
        total += weight;
    }

    // Linear scan: O(n)
    private Interval<E> findInterval(int value) {
        for (Interval<E> interval : weights)
            if (interval.in(value)) return interval;
        return null;
    }

    // Binary search: O(log n), because the intervals are sorted and contiguous
    private Interval<E> findIntervalBS(int value) {
        int s = 0, e = weights.size() - 1;
        while (s <= e) {
            int m = (s + e) >>> 1;
            Interval<E> interval = weights.get(m);
            if (interval.in(value)) return interval;
            if (value > interval.end) s = m + 1;
            else e = m - 1;
        }
        return null;
    }

    // O(log n)
    public synchronized E next() {
        if (total == 0) return null;
        // random value in [1, total]; exactly one interval (start, end] contains it
        Interval<E> interval = findIntervalBS(random.nextInt(total) + 1);
        return interval.element;
    }
}
If we have enough memory, we can instead build a lookup table that maps every integer value in [0, total) to its element, by storing each element weight times in a list. Sampling then takes O(1) time.
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class WeightedRandomGenerator<E> {
    private final Random random;
    // Each element is stored weight times, so memory is O(total weight)
    private final List<E> weights = new ArrayList<>();

    public WeightedRandomGenerator() {
        this(new Random());
    }

    public WeightedRandomGenerator(Random r) {
        random = r;
    }

    public synchronized void addWeight(int weight, E element) {
        // occupy the next weight slots of the table
        for (int i = 0; i < weight; i++)
            weights.add(element);
    }

    // O(1): every slot is equally likely
    public synchronized E next() {
        if (weights.isEmpty()) return null;
        return weights.get(random.nextInt(weights.size()));
    }
}
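A small worked example of the table for A (1), B (1), C (2) makes the trade-off concrete. LookupTableDemo and buildTable are hypothetical names; the loop mirrors addWeight above. The table is [A, B, C, C], nextInt(4) picks each slot with probability 1/4, and C owns two slots. The cost is one slot per unit of weight, so this only suits small integer weights: a total weight of 10^9 would need a billion entries.

```java
import java.util.ArrayList;
import java.util.List;

class LookupTableDemo {
    // Expands each element into weight consecutive slots, like addWeight does.
    static List<String> buildTable(String[] elements, int[] weights) {
        List<String> table = new ArrayList<>();
        for (int i = 0; i < elements.length; i++)
            for (int j = 0; j < weights[i]; j++)
                table.add(elements[i]);
        return table;
    }

    public static void main(String[] args) {
        List<String> table = buildTable(new String[] {"A", "B", "C"}, new int[] {1, 1, 2});
        System.out.println(table); // [A, B, C, C]
        // nextInt(table.size()) picks each slot with probability 1/4, so C gets 2/4.
    }
}
```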