// LocalEmbeddingService.java
package com.taxonomy.shared.service;
import ai.djl.inference.Predictor;
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import com.taxonomy.dto.TaxonomyNodeDto;
import com.taxonomy.catalog.model.TaxonomyNode;
import jakarta.persistence.EntityManager;
import jakarta.persistence.PersistenceContext;
import org.hibernate.search.mapper.orm.Search;
import org.hibernate.search.mapper.orm.session.SearchSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.*;
import java.util.stream.Collectors;
import com.taxonomy.analysis.service.LlmService;
import com.taxonomy.search.NodeEmbeddingBinder;
/**
* Local embedding service that scores taxonomy nodes against a business requirement using
* the {@code BAAI/bge-small-en-v1.5} ONNX model loaded via DJL.
*
* <h2>Architecture</h2>
* <p>The DJL model is <em>lazily initialised</em> on first use — application startup is not
* slowed down and no model is downloaded unless actually needed.
*
* <p>Vector storage and KNN retrieval are handled by Hibernate Search (Lucene backend).
* The {@code @VectorField(name = "embedding")} on {@link TaxonomyNode} (via
* {@link com.taxonomy.search.NodeEmbeddingBinder}) stores the pre-computed embedding.
* Queries use {@code f.knn(k).field("embedding").matching(queryVector)}.
*
* <h2>Configuration</h2>
* <ul>
* <li>{@code TAXONOMY_EMBEDDING_ENABLED} (default {@code true}) — set to {@code false} to
* disable all embedding and semantic search globally.</li>
* <li>{@code TAXONOMY_EMBEDDING_MODEL_DIR} — path to a pre-downloaded model directory;
* empty = auto-download from HuggingFace into {@code ~/.djl.ai/cache/taxonomy/}.</li>
* <li>{@code TAXONOMY_EMBEDDING_MODEL_NAME} — HuggingFace model URL or local path;
* default {@code https://huggingface.co/BAAI/bge-small-en-v1.5}.</li>
* <li>{@code TAXONOMY_EMBEDDING_ALLOW_DOWNLOAD} (default {@code true}) — set to
* {@code false} to prevent runtime model downloads (CI mode). When disabled, a local
* model must be provided via {@code TAXONOMY_EMBEDDING_MODEL_DIR}.</li>
* </ul>
*
* <h2>Graceful degradation</h2>
* <p>When {@code TAXONOMY_EMBEDDING_ENABLED=false} or the model fails to load, all semantic
* search methods return empty results without throwing, and {@link #isAvailable()} returns
* {@code false}.
*
* <h2>Scoring</h2>
* <p>Hibernate Search's KNN query returns Lucene-normalized cosine-similarity scores in
* [0, 1], computed as {@code (1 + cosineSimilarity) / 2}. Raw cosine similarity is
* recovered as {@code 2 * luceneScore - 1} and mapped to 0–100.
*
* <p>Enable as the LLM provider with {@code LLM_PROVIDER=LOCAL_ONNX}. No API key required.
*/
@Service
public class LocalEmbeddingService {
// SLF4J logger; parameterized logging is used throughout this service.
private static final Logger log = LoggerFactory.getLogger(LocalEmbeddingService.class);
/**
 * Default HuggingFace model identifier for downloading the ONNX export.
 *
 * <p><strong>Note:</strong> The previous default {@code djl://ai.djl.huggingface/BAAI/bge-small-en-v1.5}
 * never worked because DJL's {@code HfModelZoo} registers with GROUP_ID
 * {@code ai.djl.huggingface.pytorch} and only supports the PyTorch engine —
 * not OnnxRuntime. We now download the ONNX model files directly from
 * HuggingFace into a local cache directory.
 *
 * <p>NOTE(review): this constant duplicates the literal default inside the
 * {@code @Value("${embedding.model.name:…}")} annotation below — keep the two in sync.
 */
public static final String DEFAULT_MODEL_URL =
"https://huggingface.co/BAAI/bge-small-en-v1.5";
/** Base URL pattern for resolving individual files from a HuggingFace model repo. */
private static final String HF_RESOLVE_PATTERN = "%s/resolve/main/%s";
/** Files to download from HuggingFace for a working ONNX model directory. */
private static final String[] HF_MODEL_FILES = {
"onnx/model.onnx",
"tokenizer.json"
};
/**
 * Default query prefix for asymmetric retrieval with BGE models.
 * Prepended to query texts (but not document texts) so the model produces
 * retrieval-oriented embeddings that work better for search and scoring.
 *
 * <p>NOTE(review): also duplicated as the literal default of the
 * {@code embedding.query.prefix} property below — keep in sync.
 */
static final String DEFAULT_QUERY_PREFIX =
"Represent this sentence for searching relevant passages: ";
/**
 * Raw cosine-similarity threshold below which a node receives score 0.
 * Value in (-1, 1); 0.25 means "weak or no semantic overlap".
 *
 * <p>NOTE(review): not referenced anywhere in this file — {@code scoreNodes} only
 * clamps negative similarities to 0 via {@code Math.max}. Confirm whether the 0.25
 * cutoff should actually be applied, or remove the constant.
 */
static final double THRESHOLD = 0.25;
// ── Configuration (from application.properties / env vars) ───────────────
// Global on/off switch (TAXONOMY_EMBEDDING_ENABLED); checked by isEnabled()/isAvailable().
@Value("${embedding.enabled:true}")
private boolean embeddingEnabled;
// Optional pre-downloaded local model directory (TAXONOMY_EMBEDDING_MODEL_DIR);
// when non-blank it wins over modelName in effectiveModelUrl().
@Value("${embedding.model.dir:}")
private String modelDir;
// HuggingFace repo URL or local path (TAXONOMY_EMBEDDING_MODEL_NAME).
@Value("${embedding.model.name:https://huggingface.co/BAAI/bge-small-en-v1.5}")
private String modelName;
// Prefix prepended to query texts only — see embedQuery() and DEFAULT_QUERY_PREFIX.
@Value("${embedding.query.prefix:Represent this sentence for searching relevant passages: }")
private String queryPrefix;
// When false, getModel() rejects remote URLs (CI mode, TAXONOMY_EMBEDDING_ALLOW_DOWNLOAD).
@Value("${embedding.allow-download:true}")
private boolean allowDownload;
// ── DJL model (lazy) ──────────────────────────────────────────────────────
// Lazily loaded model; volatile for the double-checked locking in getModel().
private volatile ZooModel<String, float[]> model;
// Latched to true after the first failed load so later calls fail fast.
private volatile boolean modelLoadFailed = false;
// Guards the one-time model load in getModel().
private final Object modelLock = new Object();
// ── Dependencies ──────────────────────────────────────────────────────────
// JPA entity manager; also the entry point for Hibernate Search sessions.
@PersistenceContext
private EntityManager entityManager;
// ── Model lifecycle ───────────────────────────────────────────────────────
/**
 * Reports whether embedding is switched on in configuration
 * ({@code TAXONOMY_EMBEDDING_ENABLED}; defaults to {@code true}).
 *
 * @return the configured enabled flag
 */
public boolean isEnabled() {
    return this.embeddingEnabled;
}
/**
 * Reports whether semantic search can currently be served: embedding must be
 * enabled in configuration AND the DJL model must not have failed to load
 * (a not-yet-attempted load still counts as available).
 *
 * @return {@code true} when embedding operations may be attempted
 */
public boolean isAvailable() {
    if (!embeddingEnabled) {
        return false;
    }
    return !modelLoadFailed;
}
/**
 * Resolves the model location to use: the configured local directory
 * ({@code embedding.model.dir}) when non-blank, otherwise the configured
 * model name/URL ({@code embedding.model.name}).
 *
 * @return the effective model directory path or download URL
 */
public String effectiveModelUrl() {
    if (modelDir == null || modelDir.isBlank()) {
        return modelName;
    }
    return modelDir;
}
/**
 * Returns the lazily loaded DJL model, downloading it on first call.
 *
 * <p>Double-checked locking on the {@code volatile} {@link #model} field ensures the
 * expensive load/download runs at most once; concurrent callers wait on {@link #modelLock}.
 * Any failure latches {@link #modelLoadFailed} so subsequent calls fail fast.
 *
 * @return the loaded DJL model (never {@code null})
 * @throws IllegalStateException if embedding is disabled, a previous load failed, or
 *     downloads are disabled while only a remote model URL is configured
 * @throws Exception if model loading fails
 */
ZooModel<String, float[]> getModel() throws Exception {
    if (!embeddingEnabled) {
        throw new IllegalStateException("Embedding is disabled (TAXONOMY_EMBEDDING_ENABLED=false)");
    }
    if (modelLoadFailed) {
        throw new IllegalStateException("DJL model failed to load previously; embedding unavailable");
    }
    if (model == null) {
        synchronized (modelLock) {
            if (model == null) {
                String url = effectiveModelUrl();
                // When downloads are disabled, only local paths are allowed
                if (!allowDownload && (url.startsWith("http://") || url.startsWith("https://")
                        || url.startsWith("djl://"))) {
                    modelLoadFailed = true;
                    log.error("Model download disabled (embedding.allow-download=false) "
                            + "and no local model found. Set TAXONOMY_EMBEDDING_MODEL_DIR.");
                    throw new IllegalStateException(
                            "No local model and download disabled (TAXONOMY_EMBEDDING_ALLOW_DOWNLOAD=false)");
                }
                log.info("Loading embedding model via DJL / ONNX Runtime from {} …", url);
                try {
                    model = loadModel(url);
                    log.info("Embedding model loaded successfully.");
                } catch (Exception | LinkageError primary) {
                    // LinkageError covers UnsatisfiedLinkError / NoClassDefFoundError
                    // from native DJL / ONNX Runtime library loading failures.
                    modelLoadFailed = true;
                    // Fix: pass the throwable as the final SLF4J argument so the full
                    // stack trace is logged (previously only getMessage() was recorded,
                    // hiding the root cause of native-library failures).
                    log.error("Failed to load embedding model from '{}'; semantic search disabled.",
                            url, primary);
                    if (primary instanceof Exception ex) throw ex;
                    throw new Exception("Native library loading failed", primary);
                }
            }
        }
    }
    return model;
}
/**
 * Loads a DJL ONNX model from the given URL or local path.
 *
 * <p>Accepted forms:
 * <ul>
 *   <li>{@code https://huggingface.co/{org}/{model}} — fetches the ONNX files into the
 *       local cache directory and loads from there.</li>
 *   <li>{@code djl://…} — legacy form; rewritten to the equivalent HuggingFace URL
 *       and downloaded (the djl:// scheme never worked with OnnxRuntime).</li>
 *   <li>{@code file:} URI or plain directory path — loaded directly.</li>
 * </ul>
 * A minimal {@code serving.properties} is generated in the target directory when
 * missing or invalid (see {@code ensureServingProperties}).
 */
private ZooModel<String, float[]> loadModel(String url) throws Exception {
    final String localPath;
    if (url.startsWith("https://huggingface.co/") || url.startsWith("http://huggingface.co/")) {
        localPath = downloadHuggingFaceModel(url);
    } else if (url.startsWith("djl://")) {
        // Legacy djl:// URLs never worked with OnnxRuntime — extract the model ID
        // and try to download from HuggingFace instead.
        // djl://ai.djl.huggingface/BAAI/bge-small-en-v1.5 → BAAI/bge-small-en-v1.5
        String modelId = url.replaceFirst("djl://[^/]+/", "");
        String hfUrl = "https://huggingface.co/" + modelId;
        log.warn("Migrating legacy djl:// URL to HuggingFace download: {} → {}", url, hfUrl);
        localPath = downloadHuggingFaceModel(hfUrl);
    } else if (url.startsWith("file:")) {
        localPath = fileUriToLocalPath(url);
    } else {
        // Bare local directory path.
        localPath = url;
    }
    ensureServingProperties(localPath);
    java.nio.file.Path modelPath = java.nio.file.Path.of(localPath);
    log.info("Loading DJL model from local path: {}", modelPath.toAbsolutePath());
    try {
        return Criteria.builder()
                .setTypes(String.class, float[].class)
                .optModelPath(modelPath)
                .optModelName("model")
                .optEngine("OnnxRuntime")
                .optArgument("includeTokenTypes", true)
                .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
                .build().loadModel();
    } catch (Exception e) {
        log.error("DJL Criteria.loadModel() failed for path '{}': {}", modelPath.toAbsolutePath(), e.getMessage(), e);
        throw e;
    }
}

/** Converts a {@code file:} URI to a filesystem path, tolerating malformed URIs. */
private String fileUriToLocalPath(String url) {
    try {
        // Use URI + Paths to correctly handle file:/, file:///, and Windows drive letters
        return java.nio.file.Paths.get(java.net.URI.create(url)).toString();
    } catch (IllegalArgumentException e) {
        // Fall back to legacy-style handling if URI is malformed
        log.warn("Invalid file: URI '{}', falling back to raw path handling", url, e);
        return url.replaceFirst("^file:(//)?", "");
    }
}
/**
 * Downloads the ONNX model files from a HuggingFace repository URL into a local
 * cache directory under {@code ~/.djl.ai/cache/taxonomy/}.
 *
 * <p>Downloads are skipped if the files already exist locally (idempotent). Each file
 * is streamed to a temporary {@code .part} file and then moved into place, so an
 * interrupted download can never leave a truncated file that the non-empty cache
 * check would later mistake for a valid entry.
 *
 * @param hfRepoUrl e.g. {@code https://huggingface.co/BAAI/bge-small-en-v1.5}
 * @return absolute path to the local model directory
 * @throws Exception on I/O failure or a non-200 HTTP response
 *     (the latter latches {@link #modelLoadFailed})
 */
private String downloadHuggingFaceModel(String hfRepoUrl) throws Exception {
    // Derive a cache directory name from the repo URL
    // https://huggingface.co/BAAI/bge-small-en-v1.5 → BAAI--bge-small-en-v1.5
    String repoId = hfRepoUrl
            .replaceFirst("https?://huggingface\\.co/", "")
            .replaceAll("[/\\\\]", "--");
    java.nio.file.Path cacheDir = java.nio.file.Path.of(
            System.getProperty("user.home"), ".djl.ai", "cache", "taxonomy", repoId);
    java.nio.file.Files.createDirectories(cacheDir);
    String baseUrl = hfRepoUrl.endsWith("/")
            ? hfRepoUrl.substring(0, hfRepoUrl.length() - 1) : hfRepoUrl;
    // Fix: HttpClient is loop-invariant — build it once instead of once per file.
    java.net.http.HttpClient httpClient = java.net.http.HttpClient.newBuilder()
            .connectTimeout(java.time.Duration.ofSeconds(30))
            .followRedirects(java.net.http.HttpClient.Redirect.NORMAL)
            .build();
    for (String relPath : HF_MODEL_FILES) {
        String fileUrl = String.format(HF_RESOLVE_PATTERN, baseUrl, relPath);
        // Flatten onnx/model.onnx → model.onnx in local cache
        String localName = relPath.contains("/")
                ? relPath.substring(relPath.lastIndexOf('/') + 1) : relPath;
        java.nio.file.Path localFile = cacheDir.resolve(localName);
        if (java.nio.file.Files.exists(localFile) && java.nio.file.Files.size(localFile) > 0) {
            log.debug("Model file already cached: {}", localFile);
            continue;
        }
        log.info("Downloading {} → {}", fileUrl, localFile);
        java.net.http.HttpRequest request = java.net.http.HttpRequest.newBuilder()
                .uri(java.net.URI.create(fileUrl))
                .timeout(java.time.Duration.ofMinutes(5))
                .GET()
                .build();
        java.net.http.HttpResponse<java.io.InputStream> response;
        try {
            response = httpClient.send(
                    request, java.net.http.HttpResponse.BodyHandlers.ofInputStream());
        } catch (InterruptedException ie) {
            // Restore the interrupt flag before surfacing the failure.
            Thread.currentThread().interrupt();
            throw new java.io.IOException("Download interrupted for " + fileUrl, ie);
        }
        if (response.statusCode() != 200) {
            modelLoadFailed = true;
            throw new Exception("Failed to download " + fileUrl
                    + ": HTTP " + response.statusCode());
        }
        // Fix: stream to a temp file in the same directory and move into place, so a
        // crash mid-copy cannot leave a partial localFile that passes the size>0
        // cache check above on the next run.
        java.nio.file.Path tmpFile =
                java.nio.file.Files.createTempFile(cacheDir, localName, ".part");
        try {
            try (java.io.InputStream in = response.body()) {
                java.nio.file.Files.copy(in, tmpFile,
                        java.nio.file.StandardCopyOption.REPLACE_EXISTING);
            }
            try {
                java.nio.file.Files.move(tmpFile, localFile,
                        java.nio.file.StandardCopyOption.REPLACE_EXISTING,
                        java.nio.file.StandardCopyOption.ATOMIC_MOVE);
            } catch (java.nio.file.AtomicMoveNotSupportedException notAtomic) {
                // Same-directory moves are atomic on common filesystems; fall back otherwise.
                java.nio.file.Files.move(tmpFile, localFile,
                        java.nio.file.StandardCopyOption.REPLACE_EXISTING);
            }
        } finally {
            // No-op on success (tmp was moved); cleans up the partial file on failure.
            java.nio.file.Files.deleteIfExists(tmpFile);
        }
        log.info("Downloaded {} ({} bytes)", localName, java.nio.file.Files.size(localFile));
    }
    return cacheDir.toAbsolutePath().toString();
}
/**
 * Ensures the local model directory at {@code url} contains a usable
 * {@code serving.properties} so DJL can discover the ONNX model. An existing file is
 * accepted only if it names the OnnxRuntime engine, the
 * {@code TextEmbeddingTranslatorFactory}, and {@code includeTokenTypes=true};
 * otherwise (or when absent) a minimal file is written — but only if an {@code .onnx}
 * file is actually present. Best-effort: any failure is logged and swallowed.
 */
private static final String SERVING_PROPERTIES_CONTENT =
        "engine=OnnxRuntime\n"
        + "option.modelName=model\n"
        + "translatorFactory=ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory\n"
        + "option.mapLocation=true\n"
        + "option.includeTokenTypes=true\n";
private void ensureServingProperties(String url) {
    try {
        String path = url;
        if (path.startsWith("file://")) {
            path = path.substring("file://".length());
        }
        java.nio.file.Path dir = java.nio.file.Path.of(path);
        if (!java.nio.file.Files.isDirectory(dir)) {
            return;
        }
        java.nio.file.Path propsFile = dir.resolve("serving.properties");
        if (java.nio.file.Files.exists(propsFile)) {
            String current = java.nio.file.Files.readString(propsFile);
            boolean usable = current.contains("engine=OnnxRuntime")
                    && current.contains("TextEmbeddingTranslatorFactory")
                    && current.contains("includeTokenTypes=true");
            if (usable) {
                return;
            }
            log.warn("serving.properties exists but is missing OnnxRuntime engine, "
                    + "TextEmbeddingTranslatorFactory, or includeTokenTypes; regenerating");
            // fall through and rewrite the file below
        }
        // Only generate if there is actually an ONNX model file present.
        boolean onnxPresent;
        try (var entries = java.nio.file.Files.list(dir)) {
            onnxPresent = entries.anyMatch(p -> p.getFileName().toString().endsWith(".onnx"));
        }
        if (!onnxPresent) {
            return;
        }
        java.nio.file.Files.writeString(propsFile, SERVING_PROPERTIES_CONTENT);
        log.info("Auto-generated serving.properties in {}", dir);
    } catch (Exception e) {
        log.warn("Could not auto-generate serving.properties: {}", e.getMessage());
    }
}
// ── Index status ──────────────────────────────────────────────────────────
/**
 * Returns the number of nodes currently in the Hibernate Search index.
 * Used by the embedding-status endpoint.
 *
 * <p>Best-effort: any search failure yields 0 instead of propagating.
 *
 * @return indexed node count, or 0 when the index cannot be queried
 */
@Transactional(readOnly = true)
public int indexedNodeCount() {
    try {
        SearchSession session = Search.session(entityManager);
        return (int) session.search(TaxonomyNode.class)
                .where(f -> f.matchAll())
                .fetchTotalHitCount();
    } catch (Exception e) {
        // Fix: the failure was silently swallowed — log it (debug, since this is a
        // status probe) so a broken index is diagnosable, keeping the 0 fallback.
        log.debug("Could not determine indexed node count; returning 0", e);
        return 0;
    }
}
// ── Public API ────────────────────────────────────────────────────────────
/**
 * Computes the embedding vector for {@code text} as a <em>document</em> (no query
 * prefix applied). Used when indexing taxonomy nodes; use {@link #embedQuery(String)}
 * for search queries instead.
 *
 * @param text the document text to embed
 * @return the embedding vector produced by the DJL model
 * @throws Exception if the model cannot be loaded or inference fails
 */
public float[] embed(String text) throws Exception {
    ZooModel<String, float[]> loaded = getModel();
    try (Predictor<String, float[]> p = loaded.newPredictor()) {
        return p.predict(text);
    }
}
/**
 * Computes the embedding vector for a <em>query</em> text, prepending the configured
 * query prefix ({@code embedding.query.prefix}) for asymmetric retrieval — BGE models
 * rank better when queries carry this prefix.
 *
 * @param text the raw query text
 * @return the embedding of the (possibly prefixed) query
 * @throws Exception if the model cannot be loaded or inference fails
 */
public float[] embedQuery(String text) throws Exception {
    boolean hasPrefix = queryPrefix != null && !queryPrefix.isEmpty();
    if (!hasPrefix) {
        return embed(text);
    }
    return embed(queryPrefix + text);
}
/**
 * Scores each taxonomy node against {@code businessText} using a Hibernate Search
 * KNN vector query. Used by {@link LlmService} when {@code LLM_PROVIDER=LOCAL_ONNX}.
 *
 * <p>Every node starts at score 0; nodes returned by the KNN query are overwritten
 * with a percentage derived from the Lucene score (which for COSINE similarity is
 * {@code (1 + cosineSimilarity) / 2}):
 * {@code percentage = clamp(round((2 * luceneScore - 1) * 100), 0, 100)}.
 *
 * <p>NOTE(review): the class-level {@link #THRESHOLD} constant (0.25 cutoff) is not
 * applied here — only negative similarities are clamped to 0. Confirm intent.
 *
 * <p>On any error all nodes keep score 0 and the exception is logged.
 *
 * @param businessText free-text business requirement to score against
 * @param nodes the candidate nodes; every node's code appears in the result map
 * @return map of node code → score in [0, 100]
 */
@Transactional(readOnly = true)
public Map<String, Integer> scoreNodes(String businessText, List<TaxonomyNode> nodes) {
// Pre-fill with 0 so nodes missing from the KNN result (or any failure) score 0.
Map<String, Integer> scores = new HashMap<>();
for (TaxonomyNode node : nodes) {
scores.put(node.getCode(), 0);
}
if (!isAvailable()) return scores;
try {
// Query embedding (with BGE query prefix) for asymmetric retrieval.
float[] queryVector = embedQuery(businessText);
List<String> nodeCodes = nodes.stream()
.map(TaxonomyNode::getCode).collect(Collectors.toList());
SearchSession session = Search.session(entityManager);
// Use score projection so we can map Lucene scores to percentages
List<List<?>> hits = session.search(TaxonomyNode.class)
.select(f -> f.composite(f.entity(TaxonomyNode.class), f.score()))
.where(f -> f.knn(nodes.size())
.field("embedding")
.matching(queryVector)
// Restrict the KNN search to the candidate nodes only.
.filter(f.terms().field("code").matchingAny(nodeCodes)))
.fetchHits(nodes.size());
for (List<?> hit : hits) {
// Composite projection: index 0 = entity, index 1 = Lucene score.
TaxonomyNode node = (TaxonomyNode) hit.get(0);
float luceneScore = (Float) hit.get(1);
// Lucene COSINE KNN score = (1 + cosineSim) / 2 → cosineSim in [-1, 1]
// Map cosineSim to percentage: cosineSim = 2 * luceneScore - 1
int percentage = (int) Math.round((2.0 * luceneScore - 1.0) * 100.0);
percentage = Math.max(0, Math.min(100, percentage));
scores.put(node.getCode(), percentage);
}
log.info("LOCAL_ONNX scores: {}", scores);
} catch (Exception e) {
log.error("Error in KNN vector scoring; returning zero scores", e);
}
return scores;
}
/**
 * Semantic search across the full taxonomy index.
 * Returns up to {@code topK} taxonomy node DTOs ranked by cosine similarity to
 * {@code queryText}. Returns an empty list when embedding is not available,
 * {@code topK} is non-positive, or the search fails.
 *
 * @param queryText natural-language description (e.g. "secure voice communications")
 * @param topK maximum number of results
 * @return ranked list of matching taxonomy nodes (flat DTOs, no children)
 */
@Transactional(readOnly = true)
public List<TaxonomyNodeDto> semanticSearch(String queryText, int topK) {
    // Fix: guard topK <= 0 explicitly instead of letting knn(0) fail inside the try.
    if (!isAvailable() || topK <= 0) return Collections.emptyList();
    try {
        float[] queryVector = embedQuery(queryText);
        SearchSession session = Search.session(entityManager);
        List<TaxonomyNode> hits = session.search(TaxonomyNode.class)
                .where(f -> f.knn(topK).field("embedding").matching(queryVector))
                .fetchHits(topK);
        return hits.stream().map(this::toFlatDto).collect(Collectors.toList());
    } catch (Exception e) {
        // Fix: pass the throwable to SLF4J so the full stack trace is logged
        // (previously only getMessage() was recorded).
        log.error("Semantic search failed for query '{}'", queryText, e);
        return Collections.emptyList();
    }
}
/**
 * Find taxonomy nodes semantically similar to a given node, identified by its code.
 * Uses the node's enriched text (name + description) as the query with a
 * <em>document</em> embedding (no query prefix — node-to-node comparison is symmetric).
 * Excludes the source node itself from the results.
 *
 * @param nodeCode code of the reference node (e.g. "BP.001")
 * @param topK maximum number of similar nodes to return
 * @return ranked list of similar taxonomy node DTOs; empty when embedding is
 *     unavailable, {@code topK} is non-positive, the node is unknown, or search fails
 */
@Transactional(readOnly = true)
public List<TaxonomyNodeDto> findSimilarNodes(String nodeCode, int topK) {
    // Fix: guard topK <= 0 explicitly instead of letting knn(topK + 1) misbehave.
    if (!isAvailable() || topK <= 0) return Collections.emptyList();
    try {
        TaxonomyNode node = entityManager.createQuery(
                "SELECT n FROM TaxonomyNode n WHERE n.code = :code", TaxonomyNode.class)
                .setParameter("code", nodeCode)
                .getResultStream().findFirst().orElse(null);
        if (node == null) {
            log.warn("Node '{}' not found in database", nodeCode);
            return Collections.emptyList();
        }
        String nodeText = buildNodeText(node);
        float[] queryVector = embed(nodeText);
        SearchSession session = Search.session(entityManager);
        // Retrieve topK+1 so we can exclude the source node
        List<TaxonomyNode> hits = session.search(TaxonomyNode.class)
                .where(f -> f.knn(topK + 1).field("embedding").matching(queryVector))
                .fetchHits(topK + 1);
        return hits.stream()
                .filter(n -> !nodeCode.equals(n.getCode()))
                .limit(topK)
                .map(this::toFlatDto)
                .collect(Collectors.toList());
    } catch (Exception e) {
        // Fix: pass the throwable to SLF4J so the full stack trace is logged
        // (previously only getMessage() was recorded).
        log.error("findSimilarNodes failed for node '{}'", nodeCode, e);
        return Collections.emptyList();
    }
}
// ── Helpers ───────────────────────────────────────────────────────────────
/**
 * Builds the text used to embed a node: the English name, followed by
 * ". " and the English description when the description is non-blank.
 */
private String buildNodeText(TaxonomyNode node) {
    String name = node.getNameEn();
    String description = node.getDescriptionEn();
    String text = (name != null) ? name : "";
    if (description != null && !description.isBlank()) {
        text = text + ". " + description;
    }
    return text;
}
/** Maps a {@link TaxonomyNode} entity to a flat {@link TaxonomyNodeDto} (no children). */
private TaxonomyNodeDto toFlatDto(TaxonomyNode node) {
    var dto = new TaxonomyNodeDto();
    // Identity
    dto.setId(node.getId());
    dto.setUuid(node.getUuid());
    dto.setCode(node.getCode());
    // Localized names and descriptions
    dto.setNameEn(node.getNameEn());
    dto.setNameDe(node.getNameDe());
    dto.setDescriptionEn(node.getDescriptionEn());
    dto.setDescriptionDe(node.getDescriptionDe());
    // Tree placement
    dto.setParentCode(node.getParentCode());
    dto.setTaxonomyRoot(node.getTaxonomyRoot());
    dto.setLevel(node.getLevel());
    dto.setSortOrder(node.getSortOrder());
    // Provenance / metadata
    dto.setDataset(node.getDataset());
    dto.setExternalId(node.getExternalId());
    dto.setSource(node.getSource());
    dto.setReference(node.getReference());
    dto.setState(node.getState());
    return dto;
}
}