121 lines
4.1 KiB
Kotlin
121 lines
4.1 KiB
Kotlin
package com.android.trisolarisserver.component
|
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper
|
|
import org.springframework.beans.factory.annotation.Value
|
|
import org.springframework.http.HttpEntity
|
|
import org.springframework.http.HttpHeaders
|
|
import org.springframework.http.MediaType
|
|
import org.springframework.stereotype.Component
|
|
import org.springframework.web.client.RestTemplate
|
|
|
|
/**
 * HTTP client that sends OpenAI-style chat-completion payloads (a `messages` list with
 * `system`/`user` roles, optional `image_url` content parts) to the configured endpoint and
 * returns the text of the first choice.
 *
 * Every request carries the same model name, the same sampling parameters (all read from
 * `ai.llama.*` properties) and the same system prompt, which tells the model to transcribe only
 * clearly visible text and answer `NOT CLEARLY VISIBLE` otherwise.
 *
 * @property baseUrl URL the payload is POSTed to as-is (no path is appended), so it must be the
 *   full completion endpoint.
 */
@Component
class LlamaClient(
    private val restTemplate: RestTemplate,
    private val objectMapper: ObjectMapper,
    @Value("\${ai.llama.baseUrl}")
    private val baseUrl: String,
    @Value("\${ai.llama.temperature:0.7}")
    private val temperature: Double,
    @Value("\${ai.llama.topP:0.8}")
    private val topP: Double,
    @Value("\${ai.llama.minP:0.2}")
    private val minP: Double,
    @Value("\${ai.llama.repeatPenalty:1.0}")
    private val repeatPenalty: Double,
    @Value("\${ai.llama.topK:40}")
    private val topK: Int,
    @Value("\${ai.llama.model}")
    private val model: String
) {

    private val systemPrompt =
        "Read extremely carefully. Look only at visible text. " +
            "Return the exact text you can read verbatim. " +
            "If the text is unclear, partial, or inferred, return NOT CLEARLY VISIBLE. " +
            "Do not guess. Do not explain."

    /**
     * Asks [question] about the image at [imageUrl].
     *
     * @return the first choice's message content, or `""` when the response carries none.
     */
    fun ask(imageUrl: String, question: String): String =
        post(buildPayload(imageContent(question, imageUrl)))

    /**
     * Asks [question] about the image at [imageUrl], appending previously extracted [ocrText]
     * to the prompt under an `OCR:` heading so the model can cross-check it against the image.
     *
     * @return the first choice's message content, or `""` when the response carries none.
     */
    fun askWithOcr(imageUrl: String, ocrText: String, question: String): String =
        post(buildPayload(imageContent("${question}\n\nOCR:\n${ocrText}", imageUrl)))

    /**
     * Text-only variant: sends [question] followed by [content] under an `EMAIL:` heading,
     * with no image attached.
     *
     * @return the first choice's message content, or `""` when the response carries none.
     */
    fun askText(content: String, question: String): String =
        post(buildPayload("${question}\n\nEMAIL:\n${content}"))

    /** Builds the multi-part user content: one text part followed by one image-URL part. */
    private fun imageContent(text: String, imageUrl: String): List<Map<String, Any>> = listOf(
        mapOf("type" to "text", "text" to text),
        mapOf("type" to "image_url", "image_url" to mapOf("url" to imageUrl))
    )

    /**
     * Builds the full request body shared by all public methods: model, sampling parameters,
     * the fixed system prompt, and a single user message whose `content` is [userContent]
     * (either a plain string or a list of content parts).
     */
    private fun buildPayload(userContent: Any): Map<String, Any> = mapOf(
        "model" to model,
        "temperature" to temperature,
        "top_p" to topP,
        "min_p" to minP,
        "repeat_penalty" to repeatPenalty,
        "top_k" to topK,
        "messages" to listOf(
            mapOf("role" to "system", "content" to systemPrompt),
            mapOf("role" to "user", "content" to userContent)
        )
    )

    /**
     * POSTs [payload] as JSON to [baseUrl] and extracts `choices[0].message.content`.
     *
     * Returns `""` when the response body is absent or empty, when the content path is missing,
     * or when the content is JSON `null`. Exceptions thrown by [restTemplate] or by JSON parsing
     * are not caught here.
     */
    private fun post(payload: Map<String, Any>): String {
        val headers = HttpHeaders().apply { contentType = MediaType.APPLICATION_JSON }
        val response = restTemplate.postForEntity(baseUrl, HttpEntity(payload, headers), String::class.java)
        val body = response.body ?: return ""
        // Older Jackson versions return null (not a MissingNode) for empty input.
        val root = objectMapper.readTree(body) ?: return ""
        val content = root.path("choices").path(0).path("message").path("content")
        // NullNode.asText() yields the literal "null"; treat a null content like a missing one.
        return if (content.isNull) "" else content.asText()
    }
}
|