Skip to content

Instantly share code, notes, and snippets.

@fintanmm
Last active January 18, 2024 16:30
Show Gist options
  • Save fintanmm/df7d547bf109029f8c122bfb313ad33b to your computer and use it in GitHub Desktop.
Messing around with Langchain4j and Jbang
///usr/bin/env jbang "$0" "$@" ; exit $?
//DEPS info.picocli:picocli:4.5.0
//DEPS info.picocli:picocli-codegen:4.5.0
//DEPS ch.qos.reload4j:reload4j:1.2.19
//DEPS dev.langchain4j:langchain4j:0.25.0
//DEPS dev.langchain4j:langchain4j-embeddings:0.25.0
//DEPS dev.langchain4j:langchain4j-ollama:0.25.0
import dev.langchain4j.chain.ConversationalRetrievalChain;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.language.StreamingLanguageModel;
import dev.langchain4j.model.ollama.OllamaChatModel;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import dev.langchain4j.model.ollama.OllamaStreamingLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import org.apache.log4j.BasicConfigurator;
import org.apache.log4j.Logger;
import picocli.CommandLine;
import picocli.CommandLine.Command;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.loadDocument;
import static java.time.Duration.*;
import static java.util.concurrent.TimeUnit.SECONDS;
import java.util.List;
@Command(name = "lc4j", mixinStandardHelpOptions = true, version = "lc4j 0.0.1", description = "lc4j made with jbang")
class Lc4j implements Runnable {

    @CommandLine.Option(names = { "-b",
            "--base-url" }, description = "Base url", defaultValue = "http://localhost:11434")
    private String baseUrl;

    @CommandLine.Option(names = { "-m", "--model" }, description = "Model name", defaultValue = "phi")
    private String modelName;

    @CommandLine.Option(names = { "-q",
            "--question" }, description = "Question", defaultValue = "What is the capital of Germany?")
    private String question;

    @CommandLine.Option(names = { "-t", "--temperature" }, description = "Temperature")
    private double temperature = 0.5;

    @CommandLine.Option(names = { "-o", "--timeout" }, description = "Timeout in seconds")
    private long timeout = 30;

    @CommandLine.Option(names = { "-f", "--file" }, description = "File to ingest")
    private String file;

    private static final Logger logger = Logger.getLogger(Lc4j.class.getName());

    public static void main(String... args) {
        BasicConfigurator.configure();
        // Propagate picocli's exit code so shell callers can detect failures.
        int exitCode = new CommandLine(new Lc4j()).execute(args);
        System.exit(exitCode);
    }

    /**
     * Entry point chosen by picocli: if a file was given, answer the question
     * against that file's content (RAG); otherwise ask the model directly.
     */
    @Override
    public void run() {
        if (Objects.nonNull(file)) {
            chatWithDocuments(file);
            return;
        }
        askQuestion(question);
    }

    /**
     * Ingests {@code filePath} into an in-memory embedding store and answers
     * the configured question via a retrieval-augmented conversational chain.
     *
     * @param filePath path of the text document to ingest
     */
    private void chatWithDocuments(String filePath) {
        EmbeddingModel embeddingModel = new OllamaEmbeddingModel(baseUrl, modelName, ofSeconds(timeout), 3);
        EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
        // The ingestor splits, embeds and stores the document in a single pass.
        // NOTE(fix): the previous version also split/embedded/stored the same
        // segments manually before calling ingest(), embedding everything twice,
        // and computed a question embedding that was never used.
        EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
                .documentSplitter(DocumentSplitters.recursive(500, 0))
                .embeddingModel(embeddingModel)
                .embeddingStore(embeddingStore)
                .build();
        Document document = loadDocument(filePath, new TextDocumentParser());
        ingestor.ingest(document);
        ConversationalRetrievalChain chain = ConversationalRetrievalChain.builder()
                .chatLanguageModel(OllamaChatModel.builder()
                        .baseUrl(baseUrl)
                        .modelName(modelName)
                        .temperature(temperature)
                        .timeout(ofSeconds(timeout))
                        .build())
                .retriever(EmbeddingStoreRetriever.from(embeddingStore, embeddingModel))
                .build();
        String answer = chain.execute(question);
        logger.info("Answer: " + answer);
    }

    /**
     * Streams an answer for {@code question} from the Ollama model, then logs
     * the complete response. Blocks for at most {@link #timeout} seconds.
     *
     * @param question the prompt sent to the language model
     */
    private void askQuestion(String question) {
        StreamingLanguageModel languageModel = createStreamingLanguageModel();
        CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
        languageModel.generate(question,
                new StreamingResponseHandler<String>() {
                    @Override
                    public void onNext(String token) {
                        // Tokens stream in here; the full answer arrives in
                        // onComplete, so no per-token accumulation is needed.
                    }

                    @Override
                    public void onComplete(Response<String> response) {
                        futureResponse.complete(response);
                    }

                    @Override
                    public void onError(Throwable error) {
                        futureResponse.completeExceptionally(error);
                    }
                });
        try {
            Response<String> response = futureResponse.get(timeout, SECONDS);
            logger.info("Answer: " + response.content());
        } catch (InterruptedException e) {
            // Restore the interrupt flag only when we were actually interrupted.
            Thread.currentThread().interrupt();
            logger.error("Interrupted while waiting for answer", e);
        } catch (ExecutionException | TimeoutException e) {
            // Real failures deserve error level and the full stack trace, not
            // a debug-level message string.
            logger.error("Error while generating answer", e);
        }
    }

    /** Builds a streaming Ollama language model from the CLI options. */
    private StreamingLanguageModel createStreamingLanguageModel() {
        return OllamaStreamingLanguageModel.builder()
                .baseUrl(baseUrl)
                .modelName(modelName)
                .temperature(temperature)
                .timeout(ofSeconds(timeout))
                .build();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment