// SemanticChunkingService.java

package com.taxonomy.provenance.service;

import com.taxonomy.shared.service.LocalEmbeddingService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;

import java.text.BreakIterator;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

/**
 * Embedding-based semantic chunking as a fallback for documents with few or no
 * recognisable headings.
 *
 * <p>Algorithm:
 * <ol>
 *   <li>Split text into individual sentences.</li>
 *   <li>Form overlapping windows of {@code windowSize} sentences.</li>
 *   <li>Embed each window with {@link LocalEmbeddingService}.</li>
 *   <li>Compute cosine distance (1 - cosine similarity) between consecutive windows.</li>
 *   <li>Distances above a threshold indicate topic changes and become split points.</li>
 *   <li>Build chunks from the split points.</li>
 * </ol>
 *
 * <p>Requires the local embedding service to be available (ONNX model loaded).
 */
@Service
public class SemanticChunkingService {

    private static final Logger log = LoggerFactory.getLogger(SemanticChunkingService.class);

    private static final int DEFAULT_WINDOW_SIZE = 5;
    private static final double DEFAULT_THRESHOLD = 0.3;

    private final LocalEmbeddingService embeddingService;

    public SemanticChunkingService(LocalEmbeddingService embeddingService) {
        this.embeddingService = embeddingService;
    }

    /**
     * Splits raw text into semantically coherent chunks using the default
     * distance threshold.
     *
     * @param rawText    the raw text to chunk
     * @param windowSize number of sentences per sliding window (default 5)
     * @return an empty list if {@code rawText} is {@code null} or blank; a
     *         single-element list containing the stripped full text if
     *         {@code windowSize} is smaller than 1, if there are not more
     *         sentences than {@code windowSize}, if embedding is unavailable,
     *         or if chunking fails; otherwise the list of chunk strings
     */
    public List<String> chunk(String rawText, int windowSize) {
        return chunk(rawText, windowSize, DEFAULT_THRESHOLD);
    }

    /**
     * Convenience overload using the default window size and default threshold.
     *
     * @param rawText the raw text to chunk
     * @return see {@link #chunk(String, int)}
     */
    public List<String> chunk(String rawText) {
        return chunk(rawText, DEFAULT_WINDOW_SIZE, DEFAULT_THRESHOLD);
    }

    /**
     * Splits raw text into semantically coherent chunks with an explicit
     * cosine-distance threshold.
     *
     * @param rawText    the raw text to chunk
     * @param windowSize number of sentences per sliding window; must be at least 1
     * @param threshold  cosine distance strictly above which two consecutive
     *                   windows are treated as a topic boundary; a {@code NaN}
     *                   threshold never matches, so no split is made
     * @return an empty list if {@code rawText} is {@code null} or blank; a
     *         single-element list containing the stripped full text if
     *         {@code windowSize} is smaller than 1, if there are not more
     *         sentences than {@code windowSize}, if embedding is unavailable,
     *         or if chunking fails; otherwise the list of chunk strings
     */
    public List<String> chunk(String rawText, int windowSize, double threshold) {
        if (rawText == null || rawText.isBlank()) {
            return List.of();
        }

        if (windowSize < 1) {
            // A non-positive window cannot describe any sentence range; fall back
            // instead of embedding empty windows or relying on subList to throw.
            log.warn("Semantic chunking skipped: windowSize must be >= 1 but was {}", windowSize);
            return List.of(rawText.strip());
        }

        List<String> sentences = splitIntoSentences(rawText);
        boolean embeddingAvailable = embeddingService.isAvailable();

        // At least two windows are needed to measure one distance. Written as
        // "size <= windowSize" so that windowSize + 1 cannot overflow.
        if (sentences.size() <= windowSize || !embeddingAvailable) {
            log.debug("Semantic chunking skipped (sentences={}, embedding={})",
                    sentences.size(), embeddingAvailable);
            return List.of(rawText.strip());
        }

        try {
            return doSemanticChunking(sentences, windowSize, threshold);
        } catch (Exception e) {
            // Best effort: any failure while embedding or comparing degrades to one chunk.
            log.warn("Semantic chunking failed, returning full text as single chunk", e);
            return List.of(rawText.strip());
        }
    }

    /**
     * Embeds every sliding window, measures the cosine distance between each
     * pair of consecutive windows and cuts the sentence list at the distances
     * that exceed {@code threshold}.
     */
    private List<String> doSemanticChunking(List<String> sentences, int windowSize, double threshold)
            throws Exception {
        int windowCount = sentences.size() - windowSize + 1;
        List<Double> distances = new ArrayList<>(Math.max(0, windowCount - 1));

        // Only the previous embedding is needed to compute the next distance,
        // so the embeddings are not all kept in memory.
        float[] previous = null;
        for (int i = 0; i < windowCount; i++) {
            String window = String.join(" ", sentences.subList(i, i + windowSize));
            float[] current = embeddingService.embed(window);
            if (i > 0) {
                distances.add(1.0 - cosineSimilarity(previous, current));
            }
            previous = current;
        }

        List<Integer> splitPoints = findSplitPoints(distances, threshold);
        return buildChunksFromSplitPoints(sentences, splitPoints, windowSize);
    }

    /**
     * Finds indices where the cosine distance exceeds the threshold,
     * indicating likely topic boundaries.
     *
     * @param distances cosine distances between consecutive windows
     * @param threshold value a distance must be strictly greater than to count
     * @return ascending indices into {@code distances} whose value is above
     *         {@code threshold}; empty if none is
     */
    List<Integer> findSplitPoints(List<Double> distances, double threshold) {
        List<Integer> splits = new ArrayList<>();
        for (int i = 0; i < distances.size(); i++) {
            if (distances.get(i) > threshold) {
                splits.add(i);
            }
        }
        return splits;
    }

    /**
     * Reassembles sentence groups into chunks using the identified split points.
     *
     * <p>Distance index {@code i} compares the window starting at sentence
     * {@code i} with the window starting at sentence {@code i + 1}; it is
     * mapped to a sentence boundary by adding an offset of about half a window.
     */
    private List<String> buildChunksFromSplitPoints(List<String> sentences,
                                                     List<Integer> splitPoints,
                                                     int windowSize) {
        List<String> chunks = new ArrayList<>();
        // Centre of the window. Clamped to 1 because with single-sentence windows
        // distance i lies between sentence i and i + 1, so the cut belongs before
        // sentence i + 1; an offset of 0 would cut one sentence too early.
        int offset = Math.max(1, windowSize / 2);
        int start = 0;

        for (int splitIdx : splitPoints) {
            int splitSentence = splitIdx + offset;
            // Skip boundaries that would create an empty chunk or lie past the end.
            if (splitSentence > start && splitSentence < sentences.size()) {
                chunks.add(joinSentences(sentences, start, splitSentence));
                start = splitSentence;
            }
        }
        // Remaining sentences
        if (start < sentences.size()) {
            chunks.add(joinSentences(sentences, start, sentences.size()));
        }

        return chunks;
    }

    /**
     * Joins the stripped sentences in {@code [from, to)} with single spaces.
     */
    private static String joinSentences(List<String> sentences, int from, int to) {
        StringBuilder sb = new StringBuilder();
        for (int i = from; i < to; i++) {
            if (!sb.isEmpty()) {
                sb.append(' ');
            }
            sb.append(sentences.get(i).strip());
        }
        return sb.toString();
    }

    /**
     * Splits text into sentences using a German-locale {@link BreakIterator}.
     *
     * @param text the text to split; {@code null} or empty yields an empty list
     * @return the stripped, non-empty sentences in document order
     */
    List<String> splitIntoSentences(String text) {
        List<String> sentences = new ArrayList<>();
        if (text == null || text.isEmpty()) {
            return sentences;
        }
        BreakIterator iter = BreakIterator.getSentenceInstance(Locale.GERMAN);
        iter.setText(text);
        int start = iter.first();
        for (int end = iter.next(); end != BreakIterator.DONE; end = iter.next()) {
            String sentence = text.substring(start, end).strip();
            if (!sentence.isEmpty()) {
                sentences.add(sentence);
            }
            start = end;
        }
        return sentences;
    }

    /**
     * Computes the cosine similarity between two vectors.
     *
     * @param a first vector
     * @param b second vector, same length as {@code a}
     * @return the cosine similarity, or {@code 0} if either vector has zero norm
     * @throws IllegalArgumentException if the vectors differ in length
     */
    static double cosineSimilarity(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                    "Vector dimensions differ: " + a.length + " vs " + b.length);
        }
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            // Widen before multiplying so the product is not rounded to float precision.
            double x = a[i];
            double y = b[i];
            dotProduct += x * y;
            normA += x * x;
            normB += y * y;
        }
        double denominator = Math.sqrt(normA) * Math.sqrt(normB);
        return denominator == 0 ? 0 : dotProduct / denominator;
    }
}