EmbeddingService.java

/*******************************************************************************
 * Copyright (c) 2026 Carsten Hammer.
 *
 * This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License 2.0
 * which accompanies this distribution, and is available at
 * https://www.eclipse.org/legal/epl-2.0/
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 * Contributors:
 *     Carsten Hammer
 *******************************************************************************/
package org.eclipse.jgit.storage.hibernate.search;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.logging.Level;
import java.util.logging.Logger;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;

/**
 * Service for generating semantic embeddings from text using a local ONNX
 * model.
 * <p>
 * Uses Deep Java Library (DJL) with ONNX Runtime to run the
 * {@code all-MiniLM-L6-v2} sentence-transformer model locally on CPU. The
 * model produces 384-dimensional float vectors suitable for cosine similarity
 * search via Hibernate Search's {@code @VectorField}.
 * </p>
 * <p>
 * The model is lazily initialized on first use and cached for the lifetime of
 * the service. If the model cannot be loaded (e.g., no network on first run,
 * corrupted cache), the service degrades gracefully — {@link #embed(String)}
 * returns {@code null} and callers should set {@code hasEmbedding = false}.
 * Loading is attempted only once per instance; a failed attempt is not
 * retried.
 * </p>
 * <p>
 * Model initialization and {@link #close()} are serialized on the instance
 * monitor, and the resulting state is published through volatile fields, so
 * one instance may be used from several threads. The class implements
 * {@link AutoCloseable} so it can be managed with try-with-resources.
 * </p>
 *
 * <h3>Configuration (environment variables)</h3>
 * <ul>
 * <li>{@code JGIT_EMBEDDING_ENABLED} — set to {@code false} to disable
 * embedding generation entirely (default: {@code true})</li>
 * <li>{@code JGIT_EMBEDDING_MODEL_DIR} — local directory for cached model
 * files (default: DJL default cache {@code ~/.djl.ai/cache})</li>
 * <li>{@code JGIT_EMBEDDING_MODEL_NAME} — HuggingFace model ID (default:
 * {@code sentence-transformers/all-MiniLM-L6-v2})</li>
 * </ul>
 *
 * @see ai.djl.repository.zoo.ZooModel
 */
public class EmbeddingService implements AutoCloseable {

	private static final Logger LOG = Logger
			.getLogger(EmbeddingService.class.getName());

	/** Embedding vector dimension produced by all-MiniLM-L6-v2. */
	public static final int EMBEDDING_DIMENSION = 384;

	/** Maximum token length supported by the model. */
	private static final int MAX_TOKEN_LENGTH = 512;

	/** Rough number of characters per token assumed when truncating. */
	private static final int APPROX_CHARS_PER_TOKEN = 4;

	/** Default model name. */
	private static final String DEFAULT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"; //$NON-NLS-1$

	private final boolean enabled;

	private final String modelName;

	private final String modelDir;

	/** The loaded model, or {@code null} before loading / after close. */
	private volatile ZooModel<String, float[]> model;

	/**
	 * Set once the single load attempt has finished (successfully or not).
	 * Written last so that a thread observing {@code true} also observes the
	 * final values of {@link #model} and {@link #available}.
	 */
	private volatile boolean initAttempted;

	/** Whether a model is currently loaded and usable. */
	private volatile boolean available;

	/**
	 * Create an embedding service with default configuration from environment
	 * variables.
	 */
	public EmbeddingService() {
		this(isEnabledFromEnv(), getModelNameFromEnv(), getModelDirFromEnv());
	}

	/**
	 * Create an embedding service with explicit configuration.
	 * <p>
	 * No model is loaded here; loading is deferred to the first call of
	 * {@link #embed(String)} or {@link #isAvailable()}.
	 * </p>
	 *
	 * @param enabled
	 *            whether embedding generation is enabled
	 * @param modelName
	 *            the HuggingFace model identifier
	 * @param modelDir
	 *            local model cache directory, or {@code null} for DJL default
	 */
	public EmbeddingService(boolean enabled, String modelName,
			String modelDir) {
		this.enabled = enabled;
		this.modelName = modelName;
		this.modelDir = modelDir;
	}

	/**
	 * Generate a semantic embedding vector for the given text.
	 * <p>
	 * The text is shortened with a character-based approximation of the
	 * model's maximum token length (512 tokens) before encoding. If the model
	 * is not available, embedding is disabled, the text is {@code null} or
	 * blank, or inference fails, this method returns {@code null} instead of
	 * throwing.
	 * </p>
	 *
	 * @param text
	 *            the input text to embed
	 * @return a 384-dimensional float array, or {@code null} if embedding is
	 *         unavailable
	 */
	public float[] embed(String text) {
		if (!enabled || text == null || text.isBlank()) {
			return null;
		}
		ensureInitialized();
		// Read the volatile field once: a concurrent close() may null it
		// between the check and the use, which previously could cause an NPE.
		ZooModel<String, float[]> current = model;
		if (!available || current == null) {
			return null;
		}
		String truncated = truncateToTokenLimit(text);
		try (Predictor<String, float[]> predictor = current.newPredictor()) {
			return predictor.predict(truncated);
		} catch (TranslateException | RuntimeException e) {
			// Unchecked failures (e.g. the model being closed concurrently)
			// are treated like translation failures so that callers only
			// ever see the documented null result.
			LOG.log(Level.WARNING,
					"Failed to generate embedding", e); //$NON-NLS-1$
			return null;
		}
	}

	/**
	 * Build the embedding input text from Java source metadata.
	 * <p>
	 * Combines class name, documentation, method signatures and package name
	 * into a single string optimized for semantic search. Parts that are
	 * {@code null} or empty are skipped; a separator is only emitted when
	 * some text precedes the part.
	 * </p>
	 *
	 * @param simpleClassName
	 *            the simple class name (may be null)
	 * @param typeDocumentation
	 *            the Javadoc documentation (may be null)
	 * @param methodSignatures
	 *            newline-separated method signatures (may be null)
	 * @param packageName
	 *            the package name (may be null)
	 * @return the combined embedding input text, empty if all parts are
	 *         absent
	 */
	public static String buildEmbeddingText(String simpleClassName,
			String typeDocumentation, String methodSignatures,
			String packageName) {
		StringBuilder sb = new StringBuilder();
		if (simpleClassName != null && !simpleClassName.isEmpty()) {
			sb.append(simpleClassName);
		}
		if (typeDocumentation != null && !typeDocumentation.isEmpty()) {
			if (sb.length() > 0) {
				sb.append(": "); //$NON-NLS-1$
			}
			sb.append(typeDocumentation);
		}
		if (methodSignatures != null && !methodSignatures.isEmpty()) {
			if (sb.length() > 0) {
				sb.append("\nMethods: "); //$NON-NLS-1$
			}
			sb.append(methodSignatures);
		}
		if (packageName != null && !packageName.isEmpty()) {
			if (sb.length() > 0) {
				sb.append("\nPackage: "); //$NON-NLS-1$
			}
			sb.append(packageName);
		}
		return sb.toString();
	}

	/**
	 * Check if the embedding service is available (model loaded successfully).
	 * <p>
	 * Triggers the one-time model load if it has not been attempted yet.
	 * </p>
	 *
	 * @return {@code true} if embeddings can be generated
	 */
	public boolean isAvailable() {
		if (!enabled) {
			return false;
		}
		ensureInitialized();
		return available;
	}

	/**
	 * Check if embedding generation is enabled.
	 *
	 * @return {@code true} if embedding is enabled
	 */
	public boolean isEnabled() {
		return enabled;
	}

	/**
	 * Close the model and release resources.
	 * <p>
	 * Does nothing if no model is currently loaded. Once a loaded model has
	 * been closed, {@link #embed(String)} returns {@code null} and
	 * {@link #isAvailable()} returns {@code false}; the model is not loaded
	 * again.
	 * </p>
	 */
	@Override
	public synchronized void close() {
		ZooModel<String, float[]> current = model;
		if (current == null) {
			return;
		}
		// Unpublish first so concurrent readers stop handing out the model
		// before its resources are released.
		available = false;
		model = null;
		current.close();
	}

	/**
	 * Load the model on first use. Exactly one load attempt is made per
	 * instance; any failure is logged and leaves the service unavailable.
	 */
	private void ensureInitialized() {
		// Lock-free fast path once the single attempt has completed.
		if (initAttempted) {
			return;
		}
		synchronized (this) {
			if (initAttempted) {
				return;
			}
			try {
				LOG.log(Level.INFO,
						"Initializing embedding model: {0}", modelName); //$NON-NLS-1$
				// NOTE(review): the URL addresses the "pytorch" zoo group
				// while the engine is "OnnxRuntime" — confirm this
				// combination resolves to an ONNX artifact.
				Criteria.Builder<String, float[]> builder = Criteria.builder()
						.setTypes(String.class, float[].class)
						.optModelUrls(
								"djl://ai.djl.huggingface.pytorch/" //$NON-NLS-1$
										+ modelName)
						.optEngine("OnnxRuntime") //$NON-NLS-1$
						.optTranslatorFactory(
								new ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory());
				if (modelDir != null && !modelDir.isEmpty()) {
					// This is a JVM-wide setting, not a per-instance one.
					Path dir = Paths.get(modelDir);
					System.setProperty("DJL_CACHE_DIR", //$NON-NLS-1$
							dir.toAbsolutePath().toString());
				}
				model = builder.build().loadModel();
				available = true;
				LOG.log(Level.INFO,
						"Embedding model loaded successfully: {0}", //$NON-NLS-1$
						modelName);
			} catch (ModelNotFoundException | MalformedModelException
					| IOException | RuntimeException e) {
				// RuntimeException is included so that unchecked failures
				// (e.g. an invalid cache path) degrade gracefully instead of
				// escaping from the first embed() call only.
				LOG.log(Level.WARNING,
						"Failed to load embedding model — vector search disabled. " //$NON-NLS-1$
								+ "Full-text search remains functional.", //$NON-NLS-1$
						e);
				available = false;
			} finally {
				// Published last: the fast path above must never observe
				// "attempted" while loading is still in progress.
				initAttempted = true;
			}
		}
	}

	/**
	 * Truncate text to approximate the model's token limit.
	 * <p>
	 * Uses a simple character-based approximation (~4 characters per token
	 * for English text). This may over- or under-estimate for non-English
	 * text or code with many special characters. The model itself handles
	 * actual tokenization and will truncate at its internal token limit.
	 * </p>
	 * <p>
	 * The cut never separates the two halves of a surrogate pair; in that
	 * case the result is one character shorter than the limit.
	 * </p>
	 *
	 * @param text
	 *            the text to shorten, must not be {@code null}
	 * @return {@code text} itself if it fits, otherwise a prefix of it
	 */
	private static String truncateToTokenLimit(String text) {
		int maxChars = MAX_TOKEN_LENGTH * APPROX_CHARS_PER_TOKEN;
		if (text.length() <= maxChars) {
			return text;
		}
		int end = maxChars;
		// text.length() > maxChars here, so charAt(end) is in range.
		if (Character.isHighSurrogate(text.charAt(end - 1))
				&& Character.isLowSurrogate(text.charAt(end))) {
			end--;
		}
		return text.substring(0, end);
	}

	/**
	 * Read {@code JGIT_EMBEDDING_ENABLED}; anything other than a
	 * case-insensitive {@code false} (including an unset variable) enables
	 * the service.
	 */
	private static boolean isEnabledFromEnv() {
		String val = System.getenv("JGIT_EMBEDDING_ENABLED"); //$NON-NLS-1$
		// equalsIgnoreCase(null) is false, so an unset variable means enabled.
		return !"false".equalsIgnoreCase(val); //$NON-NLS-1$
	}

	/**
	 * Read {@code JGIT_EMBEDDING_MODEL_NAME}, falling back to the default
	 * model when the variable is unset or empty.
	 */
	private static String getModelNameFromEnv() {
		String val = System.getenv("JGIT_EMBEDDING_MODEL_NAME"); //$NON-NLS-1$
		return val != null && !val.isEmpty() ? val : DEFAULT_MODEL_NAME;
	}

	/**
	 * Read {@code JGIT_EMBEDDING_MODEL_DIR}; may return {@code null} when
	 * the variable is unset.
	 */
	private static String getModelDirFromEnv() {
		return System.getenv("JGIT_EMBEDDING_MODEL_DIR"); //$NON-NLS-1$
	}
}