// OpenAiCompatibleGateway.java

package com.taxonomy.analysis.service;

import tools.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.*;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.client.ResourceAccessException;
import org.springframework.web.client.RestTemplate;

import com.taxonomy.preferences.PreferencesService;

import java.net.SocketTimeoutException;
import java.util.*;

/**
 * Gateway for OpenAI-compatible LLM APIs (OpenAI, DeepSeek, Qwen, Llama, Mistral).
 *
 * <p>All these providers share the same request/response format (messages array with
 * role/content, Bearer auth), but may have different endpoints, model names, and
 * rate limits.
 *
 * <p>Each {@code OpenAiCompatibleGateway} instance maintains its own sliding-window
 * throttle queue, so providers with generous rate limits (e.g. paid OpenAI) are not
 * penalised by providers with strict limits.
 *
 * <p>The effective requests-per-minute limit is read on every call from the preference
 * {@code llm.rpm.<provider>} (provider name lower-cased), defaulting to the
 * constructor-provided {@code defaultRpm}. When the effective value is 0 or negative,
 * no throttling is applied (suitable for self-hosted models like Llama or Mistral with
 * no API rate limit). Throttling is also skipped entirely when no
 * {@link PreferencesService} was supplied.
 */
public class OpenAiCompatibleGateway implements LlmGateway {

    private static final Logger log = LoggerFactory.getLogger(OpenAiCompatibleGateway.class);

    /** Buffer added to the sleep duration in the RPM throttle (ms). */
    private static final long THROTTLE_BUFFER_MS = 50L;

    /** Length of the sliding window used by the RPM throttle (ms). */
    private static final long THROTTLE_WINDOW_MS = 60_000L;

    /** Retry count used when the {@code llm.retry.max} preference is unavailable. */
    private static final int DEFAULT_MAX_RETRIES = 2;

    /** Read timeout used when the {@code llm.timeout.seconds} preference is unavailable. */
    private static final int DEFAULT_TIMEOUT_SECONDS = 60;

    private final LlmProvider provider;
    private final String url;
    private final String model;
    private final int defaultRpm;
    private final RestTemplate restTemplate;
    private final ObjectMapper objectMapper;
    private final LlmResponseParser responseParser;
    private final PreferencesService preferencesService;
    private final SimpleClientHttpRequestFactory llmRequestFactory;
    private final LlmRecordReplayService recordReplayService;

    /**
     * Sliding-window timestamps for per-gateway RPM throttling.
     * Only accessed from the {@code synchronized} {@link #throttle()} method.
     */
    private final ArrayDeque<Long> callTimestamps = new ArrayDeque<>();

    /**
     * Creates a gateway bound to one provider endpoint and model.
     *
     * @param provider            provider identity; its name is used in log messages and preference keys
     * @param url                 endpoint the chat request is POSTed to
     * @param model               value sent as the {@code model} field of the request body
     * @param defaultRpm          requests-per-minute limit used when no {@code llm.rpm.<provider>}
     *                            preference is set; 0 or less disables throttling
     * @param restTemplate        HTTP client used for the request
     * @param objectMapper        serializer for the JSON request body
     * @param responseParser      parser used by {@link #extractResponseText(String)}
     * @param preferencesService  source of retry/timeout/RPM preferences; may be {@code null},
     *                            in which case defaults are used and throttling is skipped
     * @param llmRequestFactory   request factory whose read timeout is updated before each call;
     *                            may be {@code null} (presumably the factory backing
     *                            {@code restTemplate} — confirm at the wiring site)
     * @param recordReplayService optional record/replay support; may be {@code null}
     */
    public OpenAiCompatibleGateway(LlmProvider provider,
                                    String url,
                                    String model,
                                    int defaultRpm,
                                    RestTemplate restTemplate,
                                    ObjectMapper objectMapper,
                                    LlmResponseParser responseParser,
                                    PreferencesService preferencesService,
                                    SimpleClientHttpRequestFactory llmRequestFactory,
                                    LlmRecordReplayService recordReplayService) {
        this.provider = provider;
        this.url = url;
        this.model = model;
        this.defaultRpm = defaultRpm;
        this.restTemplate = restTemplate;
        this.objectMapper = objectMapper;
        this.responseParser = responseParser;
        this.preferencesService = preferencesService;
        this.llmRequestFactory = llmRequestFactory;
        this.recordReplayService = recordReplayService;
    }

    /** Returns the name of the provider this gateway talks to. */
    @Override
    public String providerName() {
        return provider.name();
    }

    /** Extracts the answer text from a raw response body by delegating to the OpenAI-format parser. */
    @Override
    public String extractResponseText(String rawResponseBody) {
        return responseParser.extractOpenAiText(rawResponseBody);
    }

    /**
     * Sends {@code prompt} as a single user message and returns the raw response body.
     *
     * <p>In replay mode a recorded response is returned without throttling or any HTTP
     * call. Otherwise the call is throttled, the read timeout is refreshed, and the
     * request is POSTed. HTTP 5xx responses and read timeouts are retried up to
     * {@code llm.retry.max} times with exponential backoff (1s, 2s, 4s, ...). If the
     * thread is interrupted while backing off, retrying stops and the last failure is
     * reported as if the retries were exhausted.
     *
     * @param prompt the prompt text sent as the message content
     * @param apiKey bearer token placed in the {@code Authorization} header
     * @return the raw response body, or {@code null} when replay finds no recording and
     *         live fallback is disabled, when the response is not 2xx or has no body,
     *         or when any error other than rate limiting / timeout occurs (it is logged)
     * @throws LlmRateLimitException if the provider answers HTTP 429
     * @throws LlmTimeoutException   if the read timeout persists after all retries
     */
    @Override
    public String sendHttpRequest(String prompt, String apiKey) {
        // REPLAY: return a previously recorded response — skips throttle and real API call.
        if (recordReplayService != null && recordReplayService.isReplayMode()) {
            Optional<String> recorded = recordReplayService.replay(prompt);
            if (recorded.isPresent()) {
                return recorded.get();
            }
            if (!recordReplayService.isFallbackLive()) {
                log.warn("No LLM recording found for prompt hash — no fallback configured");
                return null;
            }
            log.warn("No LLM recording found for prompt hash — falling back to live API");
        }

        // Real API path — throttle to respect RPM rate limits
        throttle();
        applyCurrentTimeout();

        Map<String, Object> body    = new LinkedHashMap<>();
        Map<String, String> message = new LinkedHashMap<>();
        message.put("role", "user");
        message.put("content", prompt);
        body.put("model", model);
        body.put("messages", List.of(message));

        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        headers.setBearerAuth(apiKey);
        try {
            HttpEntity<String> entity = new HttpEntity<>(objectMapper.writeValueAsString(body), headers);

            int maxRetries = preferencesService != null
                    ? preferencesService.getInt("llm.retry.max", DEFAULT_MAX_RETRIES)
                    : DEFAULT_MAX_RETRIES;
            int attempt = 0;
            while (true) {
                ResponseEntity<String> response;
                try {
                    response = restTemplate.exchange(url, HttpMethod.POST, entity, String.class);
                } catch (HttpClientErrorException e) {
                    // 4xx is never retried; 429 gets a dedicated exception so callers can react.
                    if (e.getStatusCode().value() == 429) {
                        throw new LlmRateLimitException(
                                provider + " rate limit (HTTP 429): " + e.getResponseBodyAsString(), e);
                    }
                    throw new RuntimeException(provider + " API error " + e.getStatusCode() + ": " +
                            e.getResponseBodyAsString(), e);
                } catch (HttpServerErrorException e) {
                    if (attempt < maxRetries) {
                        attempt++;
                        long backoffMs = backoffMillis(attempt);
                        log.warn("{} API server error {} — retry {}/{} after {}ms",
                                provider, e.getStatusCode(), attempt, maxRetries, backoffMs);
                        if (sleepBeforeRetry(backoffMs)) {
                            continue;
                        }
                        // Interrupted while backing off: give up and report the failure below.
                    }
                    throw new RuntimeException(provider + " API server error " + e.getStatusCode() + ": " +
                            e.getResponseBodyAsString(), e);
                } catch (ResourceAccessException e) {
                    if (e.getCause() instanceof SocketTimeoutException) {
                        int timeoutSeconds = preferencesService != null
                                ? preferencesService.getInt("llm.timeout.seconds", DEFAULT_TIMEOUT_SECONDS)
                                : DEFAULT_TIMEOUT_SECONDS;
                        if (attempt < maxRetries) {
                            attempt++;
                            long backoffMs = backoffMillis(attempt);
                            log.warn("{} API read timeout after {}s — retry {}/{} after {}ms",
                                    provider, timeoutSeconds, attempt, maxRetries, backoffMs);
                            if (sleepBeforeRetry(backoffMs)) {
                                continue;
                            }
                            // Interrupted while backing off: give up and report the timeout below.
                        }
                        throw new LlmTimeoutException(
                                provider + " API call timed out after " + timeoutSeconds + "s. "
                                + "You can increase the timeout in Preferences → llm.timeout.seconds.", e);
                    }
                    // Other I/O failures are handled by the generic catch below (logged, null returned).
                    throw e;
                }

                String responseBody = response.getBody();
                if (response.getStatusCode().is2xxSuccessful() && responseBody != null) {
                    log.info("LLM Response [{}] — raw response (first 500 chars): {}",
                            provider, responseBody.substring(0, Math.min(responseBody.length(), 500)));

                    // RECORD: persist prompt + response for future replay.
                    if (recordReplayService != null && recordReplayService.isRecordMode()) {
                        recordReplayService.record(prompt, responseBody, provider.name(), null);
                    }

                    return responseBody;
                }
                log.error("{} API returned status {}", provider, response.getStatusCode());
                return null;
            }
        } catch (LlmRateLimitException | LlmTimeoutException e) {
            // These two are part of the method's contract and must reach the caller.
            throw e;
        } catch (Exception e) {
            log.error("Error calling {} API", provider, e);
            return null;
        }
    }

    /** Exponential backoff for the given 1-based retry attempt: 1s, 2s, 4s, ... */
    private static long backoffMillis(int attempt) {
        return 1000L * (1L << (attempt - 1));
    }

    /**
     * Sleeps for {@code backoffMs} before a retry.
     *
     * @return {@code true} if the full pause elapsed; {@code false} if the thread was
     *         interrupted, in which case the interrupt flag is restored and the caller
     *         should stop retrying
     */
    private boolean sleepBeforeRetry(long backoffMs) {
        try {
            Thread.sleep(backoffMs);
            return true;
        } catch (InterruptedException ie) {
            Thread.currentThread().interrupt();
            log.warn("{} API retry aborted — interrupted during backoff", provider);
            return false;
        }
    }

    // ── Per-gateway RPM throttle (sliding window) ─────────────────────────────

    /**
     * Paces outgoing calls using a sliding-window approach.
     *
     * <p>Reads the provider-specific preference {@code llm.rpm.<provider>}, using the
     * constructor-provided {@code defaultRpm} as its default value. When the effective
     * RPM is 0 or negative, or when no {@link PreferencesService} is available, no
     * throttling is applied.
     *
     * <p>Blocks (while holding this gateway's monitor, so concurrent callers queue up)
     * until fewer than {@code rpm} calls remain in the last 60 seconds, then records the
     * current call. If the thread is interrupted while waiting, the interrupt flag is
     * restored and the call is recorded immediately without further waiting.
     */
    synchronized void throttle() {
        if (preferencesService == null) return;
        // Locale.ROOT keeps the key stable in locales with special casing rules (e.g. Turkish 'I').
        String prefKey = "llm.rpm." + provider.name().toLowerCase(Locale.ROOT);
        int rpm = preferencesService.getInt(prefKey, defaultRpm);
        if (rpm <= 0) return;

        long windowStart = System.currentTimeMillis() - THROTTLE_WINDOW_MS;
        while (!callTimestamps.isEmpty() && callTimestamps.peekFirst() < windowStart) {
            callTimestamps.pollFirst();
        }

        // Loop rather than wait once: if the RPM preference was lowered at runtime the
        // window may hold more than rpm entries, and several must expire before we proceed.
        while (callTimestamps.size() >= rpm) {
            long oldest = callTimestamps.peekFirst();
            long sleepMs = oldest + THROTTLE_WINDOW_MS - System.currentTimeMillis() + THROTTLE_BUFFER_MS;
            if (sleepMs > 0) {
                log.debug("{} RPM throttle: sleeping {}ms (rpm={}, calls in window={})",
                        provider, sleepMs, rpm, callTimestamps.size());
                try {
                    Thread.sleep(sleepMs);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
            // The oldest entry has now left the window; dropping it guarantees progress.
            callTimestamps.pollFirst();
        }

        callTimestamps.addLast(System.currentTimeMillis());
    }

    /**
     * Applies the current {@code llm.timeout.seconds} preference (default 60) as the
     * read timeout of the request factory. Does nothing when either the preferences
     * service or the request factory is absent.
     */
    private void applyCurrentTimeout() {
        if (preferencesService == null || llmRequestFactory == null) return;
        int timeoutSeconds = preferencesService.getInt("llm.timeout.seconds", DEFAULT_TIMEOUT_SECONDS);
        // Convert in long arithmetic and clamp so very large values cannot overflow int.
        long timeoutMs = Math.min(timeoutSeconds * 1000L, Integer.MAX_VALUE);
        llmRequestFactory.setReadTimeout((int) Math.max(timeoutMs, Integer.MIN_VALUE));
    }
}