Sélection pondérée aléatoire en Java


Je veux choisir un élément aléatoire dans un ensemble, mais la chance de choisir un élément doit être proportionnelle au poids associé

Exemples d'entrées:

item                weight
----                ------
sword of misery         10
shield of happy          5
potion of dying          6
triple-edged sword       1

Donc, si j'ai 4 éléments possibles, la chance d'obtenir un élément sans poids serait de 1 sur 4.

Dans ce cas, un utilisateur devrait être 10 fois plus susceptible d'obtenir l'épée de misère que l'épée à triple tranchant.

Comment faire une sélection aléatoire pondérée en Java?

Author: Peter Lawrey, 2011-06-20

6 answers

J'utiliserais un NavigableMap

public class RandomCollection<E> {
    private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
    private final Random random;
    private double total = 0;

    public RandomCollection() {
        this(new Random());
    }

    public RandomCollection(Random random) {
        this.random = random;
    }

    public RandomCollection<E> add(double weight, E result) {
        if (weight <= 0) return this;
        total += weight;
        map.put(total, result);
        return this;
    }

    public E next() {
        double value = random.nextDouble() * total;
        return map.higherEntry(value).getValue();
    }
}

Disons que j'ai une liste d'animaux chien, chat, cheval avec des probabilités de 40%, 35%, 25% respectivement

RandomCollection<String> rc = new RandomCollection<>()
                              .add(40, "dog").add(35, "cat").add(25, "horse");

for (int i = 0; i < 10; i++) {
    System.out.println(rc.next());
} 
 85
Author: Peter Lawrey, 2017-06-18 19:12:13

Vous ne trouverez pas de cadre pour ce genre de problème, car la fonctionnalité demandée n'est rien de plus qu'une simple fonction. Faites quelque chose comme ceci:

interface Item {
    double getWeight();
}

class RandomItemChooser {
    public Item chooseOnWeight(List<Item> items) {
        double completeWeight = 0.0;
        for (Item item : items)
            completeWeight += item.getWeight();
        double r = Math.random() * completeWeight;
        double countWeight = 0.0;
        for (Item item : items) {
            countWeight += item.getWeight();
            if (countWeight >= r)
                return item;
        }
        throw new RuntimeException("Should never be shown.");
    }
}
 23
Author: Arne Deutsch, 2014-09-25 21:19:26

Il existe maintenant une classe pour cela dans Apache Commons: EnumeratedDistribution

Item selectedItem = new EnumeratedDistribution(itemWeights).sample();

itemWeights est un List<Pair<Item,Double>>, comme (en supposant l'interface Item dans la réponse d'Arne):

List<Pair<Item,Double>> itemWeights = Collections.newArrayList();
for (Item i : itemSet) {
    itemWeights.add(new Pair(i, i.getWeight()));
}

Ou en Java 8:

itemSet.stream().map(i -> new Pair(i, i.getWeight())).collect(toList());

Remarque: Pair ici doit être org.apache.commons.math3.util.Pair, pas org.apache.commons.lang3.tuple.Pair.

 13
Author: kdkeck, 2017-05-17 20:50:38

Utiliser une méthode alias

Si vous allez rouler beaucoup de fois (comme dans un jeu), vous devriez utiliser une méthode alias.

Le code ci-dessous est une implémentation assez longue d'une telle méthode d'alias, en effet. Mais c'est à cause de la partie initialisation. La récupération des éléments est très rapide (voir les méthodes next et applyAsInt elles ne bouclent pas).

Utilisation

Set<Item> items = ... ;
ToDoubleFunction<Item> weighter = ... ;

Random random = new Random();

RandomSelector<T> selector = RandomSelector.weighted(items, weighter);
Item drop = selector.next(random);

Mise en œuvre

Cette implémentation:

  • utilise Java 8;
  • est conçu pour être aussi rapide que possible (enfin, au moins, j'ai essayé de le faire en utilisant le micro-benchmarking);
  • est totalement thread-safe (gardez un Random dans chaque thread pour des performances maximales, utilisez ThreadLocalRandom?);
  • récupère les éléments dans O (1) , contrairement à ce que vous trouvez principalement sur Internet ou sur StackOverflow, où les implémentations naïves s'exécutent dans O (n) ou O (log(n));
  • maintient le éléments indépendants de leur poids, de sorte qu'un élément peut être attribué divers poids dans différents contextes.

De toute façon, voici le code. (Notez que je maintiens une version à jour de cette classe.)

import static java.util.Objects.requireNonNull;

import java.util.*;
import java.util.function.*;

public final class RandomSelector<T> {

  public static <T> RandomSelector<T> weighted(Set<T> elements, ToDoubleFunction<? super T> weighter)
      throws IllegalArgumentException {
    requireNonNull(elements, "elements must not be null");
    requireNonNull(weighter, "weighter must not be null");
    if (elements.isEmpty()) { throw new IllegalArgumentException("elements must not be empty"); }

    // Array is faster than anything. Use that.
    int size = elements.size();
    T[] elementArray = elements.toArray((T[]) new Object[size]);

    double totalWeight = 0d;
    double[] discreteProbabilities = new double[size];

    // Retrieve the probabilities
    for (int i = 0; i < size; i++) {
      double weight = weighter.applyAsDouble(elementArray[i]);
      if (weight < 0.0d) { throw new IllegalArgumentException("weighter may not return a negative number"); }
      discreteProbabilities[i] = weight;
      totalWeight += weight;
    }
    if (totalWeight == 0.0d) { throw new IllegalArgumentException("the total weight of elements must be greater than 0"); }

    // Normalize the probabilities
    for (int i = 0; i < size; i++) {
      discreteProbabilities[i] /= totalWeight;
    }
    return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities));
  }

  private final T[] elements;
  private final ToIntFunction<Random> selection;

  private RandomSelector(T[] elements, ToIntFunction<Random> selection) {
    this.elements = elements;
    this.selection = selection;
  }

  public T next(Random random) {
    return elements[selection.applyAsInt(random)];
  }

  private static class RandomWeightedSelection implements ToIntFunction<Random> {
    // Alias method implementation O(1)
    // using Vose's algorithm to initialize O(n)

    private final double[] probabilities;
    private final int[] alias;

    RandomWeightedSelection(double[] probabilities) {
      int size = probabilities.length;

      double average = 1.0d / size;
      int[] small = new int[size];
      int smallSize = 0;
      int[] large = new int[size];
      int largeSize = 0;

      // Describe a column as either small (below average) or large (above average).
      for (int i = 0; i < size; i++) {
        if (probabilities[i] < average) {
          small[smallSize++] = i;
        } else {
          large[largeSize++] = i;
        }
      }

      // For each column, saturate a small probability to average with a large probability.
      while (largeSize != 0 && smallSize != 0) {
        int less = small[--smallSize];
        int more = large[--largeSize];
        probabilities[less] = probabilities[less] * size;
        alias[less] = more;
        probabilities[more] += probabilities[less] - average;
        if (probabilities[more] < average) {
          small[smallSize++] = more;
        } else {
          large[largeSize++] = more;
        }
      }

      // Flush unused columns.
      while (smallSize != 0) {
        probabilities[small[--smallSize]] = 1.0d;
      }
      while (largeSize != 0) {
        probabilities[large[--largeSize]] = 1.0d;
      }
    }

    @Override public int applyAsInt(Random random) {
      // Call random once to decide which column will be used.
      int column = random.nextInt(probabilities.length);

      // Call random a second time to decide which will be used: the column or the alias.
      if (random.nextDouble() < probabilities[column]) {
        return column;
      } else {
        return alias[column];
      }
    }
  }
}
 4
Author: Olivier Grégoire, 2015-08-01 14:43:32
public class RandomCollection<E> {
  private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
  private double total = 0;

  public void add(double weight, E result) {
    if (weight <= 0 || map.containsValue(result))
      return;
    total += weight;
    map.put(total, result);
  }

  public E next() {
    double value = ThreadLocalRandom.current().nextDouble() * total;
    return map.ceilingEntry(value).getValue();
  }
}
 1
Author: ronen, 2016-11-24 22:58:35

Si vous devez supprimer des éléments après avoir choisi, vous pouvez utiliser une autre solution. Ajoutez tous les éléments dans une 'LinkedList', chaque élément doit être ajouté autant de fois que son poids est, puis utilisez Collections.shuffle() qui, selon JavaDoc

Permute aléatoirement la liste spécifiée en utilisant une source aléatoire par défaut. Toutes les permutations se produisent avec une probabilité à peu près égale.

Enfin, obtenez et supprimez des éléments en utilisant pop() ou removeFirst()

Map<String, Integer> map = new HashMap<String, Integer>() {{
    put("Five", 5);
    put("Four", 4);
    put("Three", 3);
    put("Two", 2);
    put("One", 1);
}};

LinkedList<String> list = new LinkedList<>();

for (Map.Entry<String, Integer> entry : map.entrySet()) {
    for (int i = 0; i < entry.getValue(); i++) {
        list.add(entry.getKey());
    }
}

Collections.shuffle(list);

int size = list.size();
for (int i = 0; i < size; i++) {
    System.out.println(list.pop());
}
 1
Author: Yuri Heiko, 2017-09-13 04:12:50