FrequencyClusterer.java
package org.hammer.audio.experimental.acoustic.tracking;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
* Groups {@link DetectedPeak detected peaks} across microphone channels into stable {@link
* FrequencyCluster frequency clusters}.
*
* <p>Peaks are sorted by magnitude and assigned greedily to the nearest existing cluster (in Hz),
* subject to a configurable matching tolerance. The tolerance can be specified either in absolute
* Hz or in cents (1200 cents = one octave); both are honoured and a peak matches when it falls
* within either tolerance. This combined matching gives stable behaviour when two sources sit close
* in frequency: the matching tolerance keeps each cluster narrow enough to keep them apart while
* the cents tolerance avoids over-splitting at high frequencies where the absolute spacing
* naturally grows.
*
* <p>Each cluster must collect peaks from at least {@code minPeaksPerCluster} distinct channels to
* be emitted; transient single-channel artefacts are dropped. The output is ordered by total
* magnitude (descending) and capped at {@code maxClusters}.
*/
public final class FrequencyClusterer {
private final double matchToleranceHz;
private final double matchToleranceCents;
private final int minPeaksPerCluster;
private final int maxClusters;
/**
* Configure a clusterer.
*
* @param matchToleranceHz absolute matching tolerance in Hz (>= 0)
* @param matchToleranceCents relative matching tolerance in cents (>= 0)
* @param minPeaksPerCluster minimum number of channels that must contribute a peak (>= 1)
* @param maxClusters maximum number of clusters returned (>= 1)
*/
public FrequencyClusterer(
double matchToleranceHz,
double matchToleranceCents,
int minPeaksPerCluster,
int maxClusters) {
if (!Double.isFinite(matchToleranceHz) || matchToleranceHz < 0.0) {
throw new IllegalArgumentException("matchToleranceHz must be finite and >= 0");
}
if (!Double.isFinite(matchToleranceCents) || matchToleranceCents < 0.0) {
throw new IllegalArgumentException("matchToleranceCents must be finite and >= 0");
}
if (minPeaksPerCluster < 1) {
throw new IllegalArgumentException("minPeaksPerCluster must be >= 1");
}
if (maxClusters < 1) {
throw new IllegalArgumentException("maxClusters must be >= 1");
}
this.matchToleranceHz = matchToleranceHz;
this.matchToleranceCents = matchToleranceCents;
this.minPeaksPerCluster = minPeaksPerCluster;
this.maxClusters = maxClusters;
}
/** Cluster peaks supplied as one list per channel. */
public List<FrequencyCluster> clusterPerChannel(
Collection<? extends Collection<DetectedPeak>> perChannel) {
if (perChannel == null) {
throw new IllegalArgumentException("perChannel must not be null");
}
List<DetectedPeak> all = new ArrayList<>();
for (Collection<DetectedPeak> peaks : perChannel) {
if (peaks != null) {
all.addAll(peaks);
}
}
return cluster(all);
}
/** Cluster a flat collection of peaks from possibly several channels. */
public List<FrequencyCluster> cluster(Collection<DetectedPeak> peaks) {
if (peaks == null) {
throw new IllegalArgumentException("peaks must not be null");
}
List<DetectedPeak> sorted = new ArrayList<>(peaks);
sorted.sort(Comparator.comparingDouble(DetectedPeak::magnitude).reversed());
List<List<DetectedPeak>> buckets = new ArrayList<>();
List<Double> bucketCenters = new ArrayList<>();
for (DetectedPeak peak : sorted) {
int matched = findMatch(bucketCenters, peak.frequencyHz());
if (matched >= 0 && hasChannel(buckets.get(matched), peak.channel())) {
// Already have a peak for this channel in the cluster; skip the weaker one.
continue;
}
if (matched < 0) {
List<DetectedPeak> bucket = new ArrayList<>();
bucket.add(peak);
buckets.add(bucket);
bucketCenters.add(peak.frequencyHz());
} else {
buckets.get(matched).add(peak);
bucketCenters.set(matched, weightedCentre(buckets.get(matched)));
}
}
List<FrequencyCluster> clusters = new ArrayList<>();
for (int i = 0; i < buckets.size(); i++) {
List<DetectedPeak> bucket = buckets.get(i);
if (countDistinctChannels(bucket) < minPeaksPerCluster) {
continue;
}
double totalMagnitude = 0.0;
for (DetectedPeak peak : bucket) {
totalMagnitude += peak.magnitude();
}
clusters.add(new FrequencyCluster(bucketCenters.get(i), totalMagnitude, bucket));
}
clusters.sort(Comparator.comparingDouble(FrequencyCluster::totalMagnitude).reversed());
if (clusters.size() > maxClusters) {
return Collections.unmodifiableList(new ArrayList<>(clusters.subList(0, maxClusters)));
}
return Collections.unmodifiableList(clusters);
}
private int findMatch(List<Double> centres, double frequencyHz) {
int best = -1;
double bestDelta = Double.POSITIVE_INFINITY;
for (int i = 0; i < centres.size(); i++) {
double centre = centres.get(i);
double delta = Math.abs(centre - frequencyHz);
if (matches(centre, frequencyHz) && delta < bestDelta) {
bestDelta = delta;
best = i;
}
}
return best;
}
private boolean matches(double centre, double frequencyHz) {
if (Math.abs(centre - frequencyHz) <= matchToleranceHz) {
return true;
}
if (matchToleranceCents <= 0.0 || centre <= 0.0 || frequencyHz <= 0.0) {
return false;
}
double cents = 1200.0 * Math.log(frequencyHz / centre) / Math.log(2.0);
return Math.abs(cents) <= matchToleranceCents;
}
private static boolean hasChannel(List<DetectedPeak> bucket, int channel) {
for (DetectedPeak peak : bucket) {
if (peak.channel() == channel) {
return true;
}
}
return false;
}
private static int countDistinctChannels(List<DetectedPeak> bucket) {
int mask = 0;
int count = 0;
for (DetectedPeak peak : bucket) {
int bit = 1 << Math.min(peak.channel(), 30);
if ((mask & bit) == 0) {
mask |= bit;
count++;
}
}
return count;
}
private static double weightedCentre(List<DetectedPeak> bucket) {
double sum = 0.0;
double weight = 0.0;
for (DetectedPeak peak : bucket) {
sum += peak.frequencyHz() * peak.magnitude();
weight += peak.magnitude();
}
return weight > 0.0 ? sum / weight : bucket.get(0).frequencyHz();
}
}