Last active
January 23, 2025 15:39
-
-
Save DanielKocan/22d5088aaf45bf97d47c5299ea80e3d9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // Copyright Epic Games, Inc. All Rights Reserved. | |
| #include "AwesomeAIPlugin.h" | |
| #include <iostream> | |
| #include <sstream> | |
| #include "llama.h" | |
| // For online AI | |
| #include "HttpModule.h" | |
| #include "Interfaces/IHttpRequest.h" | |
| #include "Interfaces/IHttpResponse.h" | |
| #include "JsonObjectConverter.h" | |
| #define LOCTEXT_NAMESPACE "FAwesomeAIPluginModule" | |
| void FAwesomeAIPluginModule::StartupModule() | |
| { | |
| std::string assistantRole = "Helpful assistant"; | |
| std::string messageToAI = "Hey, who are you?"; | |
| std::string modelName = "gpt-3.5-turbo"; | |
| std::string secretKey = ""; | |
| // Prepare the JSON payload (actual data being sent in a request or message, often as part of an HTTP request or network communicationS) | |
| TSharedPtr<FJsonObject> JsonPayload = MakeShareable(new FJsonObject); | |
| JsonPayload->SetStringField("model", FString(modelName.c_str())); | |
| JsonPayload->SetNumberField("max_tokens", 4096); | |
| JsonPayload->SetNumberField("temperature", 0.7); | |
| TArray<TSharedPtr<FJsonValue>> Messages; | |
| TSharedPtr<FJsonObject> SystemMessageObject = MakeShareable(new FJsonObject); // Create the "system" message | |
| SystemMessageObject->SetStringField("role", "system"); | |
| SystemMessageObject->SetStringField("content", FString(assistantRole.c_str())); | |
| // Wrap the object in a FJsonValueObject | |
| Messages.Add(MakeShareable(new FJsonValueObject(SystemMessageObject))); | |
| // Create the new "user" message | |
| TSharedPtr<FJsonObject> UserMessageObject = MakeShareable(new FJsonObject); | |
| UserMessageObject->SetStringField("role", "user"); | |
| UserMessageObject->SetStringField("content", FString(messageToAI.c_str())); | |
| Messages.Add(MakeShareable(new FJsonValueObject(UserMessageObject))); // Wrap the object in a FJsonValueObject | |
| JsonPayload->SetArrayField("messages", Messages); | |
| FString JsonString; | |
| TSharedRef<TJsonWriter<>> Writer = TJsonWriterFactory<>::Create(&JsonString); | |
| FJsonSerializer::Serialize(JsonPayload.ToSharedRef(), Writer); | |
| // Set up HTTP request | |
| TSharedRef<IHttpRequest> HttpRequest = FHttpModule::Get().CreateRequest(); | |
| HttpRequest->SetURL("https://api.openai.com/v1/chat/completions"); | |
| HttpRequest->SetVerb("POST"); | |
| HttpRequest->SetHeader(TEXT("Content-Type"), TEXT("application/json")); | |
| HttpRequest->SetHeader(TEXT("Authorization"), FString(secretKey.c_str())); | |
| HttpRequest->SetContentAsString(JsonString); | |
| UE_LOG(LogTemp, Log, TEXT("Request content: %s \n"), *JsonString); | |
| // Bind callback for when the request is complete | |
| HttpRequest->OnProcessRequestComplete().BindLambda( | |
| [](FHttpRequestPtr Request, FHttpResponsePtr Response, bool bWasSuccessful) | |
| { | |
| if (bWasSuccessful && Response.IsValid() && EHttpResponseCodes::IsOk(Response->GetResponseCode())) | |
| { | |
| // Log the entire JSON response first for debug | |
| FString RawResponse = Response->GetContentAsString(); | |
| UE_LOG(LogTemp, Log, TEXT("Full OpenAI Response: %s"), *RawResponse); | |
| // Parse the response | |
| TSharedPtr<FJsonObject> JsonResponse; | |
| TSharedRef<TJsonReader<>> Reader = TJsonReaderFactory<>::Create(Response->GetContentAsString()); | |
| if (FJsonSerializer::Deserialize(Reader, JsonResponse) && JsonResponse.IsValid()) | |
| { | |
| const TArray<TSharedPtr<FJsonValue>>* Choices; | |
| if (JsonResponse->TryGetArrayField("choices", Choices) && Choices->Num() > 0) | |
| { | |
| TSharedPtr<FJsonObject> Choice = (*Choices)[0]->AsObject(); | |
| if (Choice.IsValid()) | |
| { | |
| TSharedPtr<FJsonObject> Message = Choice->GetObjectField("message"); | |
| if (Message.IsValid()) | |
| { | |
| const FString AIResponse = Message->GetStringField("content"); | |
| UE_LOG(LogTemp, Warning, TEXT("AI Response: %s"), *AIResponse); | |
| } | |
| } | |
| } | |
| } | |
| else | |
| { | |
| UE_LOG(LogTemp, Error, TEXT("Failed to parse JSON response.")); | |
| } | |
| } | |
| else | |
| { | |
| if (Response.IsValid()) | |
| { | |
| UE_LOG(LogTemp, Error, TEXT("HTTP Request Failed: %s"), *Response->GetContentAsString()); | |
| } | |
| else | |
| { | |
| UE_LOG(LogTemp, Error, TEXT("HTTP Request Failed: Invalid Response.")); | |
| } | |
| } | |
| }); | |
| // Send the request | |
| HttpRequest->ProcessRequest(); | |
| /* | |
| // Local AI | |
| // This code will execute after your module is loaded into memory; the exact timing is specified in the .uplugin file per-module | |
| // Set assistant role in the context | |
| std::string assistantRole = "Helpful assistant"; | |
| std::string messageToAI = "Hey, who are you?"; | |
| std::string model_path = "C:/CPlusPlus/LlamaCppAi/assets/aiModels/Mistral-7B-Instruct-v0.3.Q6_K.gguf"; | |
| std::vector<llama_chat_message> messages; // Holds the conversation context | |
| llama_model* model = nullptr; // Model instance | |
| llama_context* ctx = nullptr; // Context instance | |
| llama_sampler* sampler = nullptr; // Sampler instance | |
| std::vector<char> formatted; // Buffer for formatted prompts | |
| // --- Initialize Model --- | |
| llama_model_params model_params = llama_model_default_params(); | |
| model = llama_load_model_from_file(model_path.c_str(), model_params); | |
| if (!model) { | |
| UE_LOG(LogTemp, Error, TEXT("Unable to load model.")); | |
| return; | |
| } | |
| llama_context_params ctx_params = llama_context_default_params(); | |
| ctx_params.n_ctx = 2048; // Default context size (size of the context window, how many tokens (words, parts of words, punctuation) the model can consider at one time. If you exceed this limit, the model will start forgetting older parts of the conversation.) | |
| ctx = llama_new_context_with_model(model, ctx_params); // Actualy creates the workspace (memory) | |
| if (!ctx) { | |
| UE_LOG(LogTemp, Error, TEXT("Failed to create context.")); | |
| return; | |
| } | |
| // Initialize sampler (A sampler in the context of language models determines how the next token (word, part of a word, or symbol) is selected during text generation.) | |
| sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()); | |
| // Add specific sampling techniques | |
| llama_sampler_chain_add(sampler, llama_sampler_init_min_p(0.1f, 1)); // ensures that less probable tokens don’t get picked too often. | |
| llama_sampler_chain_add(sampler, llama_sampler_init_top_p(0.90f, 1)); // limits the pool of tokens to the top 90% by probability, making the output more focused. | |
| llama_sampler_chain_add(sampler, llama_sampler_init_temp(0.7)); // controls randomness; lower values make the AI's responses more predictable, while higher values make them more creative. | |
| llama_sampler_chain_add(sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); | |
| formatted.resize(llama_n_ctx(ctx)); | |
| // ----------------------------- | |
| // Add role for AI moderl and user input to messages | |
| messages.push_back({ "system", _strdup(assistantRole.c_str()) }); | |
| messages.push_back({ "user", _strdup(messageToAI.c_str()) }); | |
| int word_limit = 50; | |
| // Format the prompt with the updated messages | |
| int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); | |
| if (new_len > static_cast<int>(formatted.size())) { | |
| formatted.resize(new_len); | |
| new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); | |
| } | |
| if (new_len < 0) { | |
| UE_LOG(LogTemp, Error, TEXT("Failed to format the chat template.")); | |
| return; | |
| } | |
| std::string prompt(formatted.begin(), formatted.begin() + new_len); | |
| // Generate response (Main loop) | |
| auto generate = [&](const std::string& prompt) { | |
| const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true); | |
| std::vector<llama_token> prompt_tokens(n_prompt_tokens); | |
| if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) { | |
| return std::string("Error: Failed to tokenize prompt."); | |
| } | |
| llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); | |
| llama_token new_token_id; | |
| std::ostringstream response_stream; | |
| FString UELOG = ""; | |
| int word_count = 0; // Word counter to limit the number of words in the response | |
| while (true) { | |
| if (llama_decode(ctx, batch)) { | |
| return std::string("Error: Failed to decode."); | |
| } | |
| new_token_id = llama_sampler_sample(sampler, ctx, -1); | |
| if (llama_token_is_eog(model, new_token_id) || word_count >= word_limit) { | |
| break; | |
| } | |
| char buf[256]; | |
| int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true); | |
| if (n < 0) { | |
| return std::string("Error: Failed to convert token to piece."); | |
| } | |
| // Output each word as it's generated | |
| std::string word_piece(buf, n); | |
| UELOG += word_piece.c_str(); | |
| //std::cout << word_piece << std::flush; | |
| response_stream << word_piece; // Append word to the response | |
| word_count++; // Increment word count | |
| //UE_LOG(LogTemp, Log, TEXT("%s"), *UELOG); // Uncoment to print token by token generation. | |
| batch = llama_batch_get_one(&new_token_id, 1); | |
| } | |
| std::cout << std::endl; | |
| return response_stream.str(); | |
| }; | |
| std::string response = generate(prompt); | |
| UE_LOG(LogTemp, Warning, TEXT("Generated response: %s"), *FString(response.c_str())); | |
| // Add the response to the conversation history | |
| messages.push_back({ "assistant", _strdup(response.c_str()) }); | |
| // free resources | |
| llama_sampler_free(sampler); | |
| llama_free(ctx); | |
| llama_free_model(model); */ | |
| } | |
| void FAwesomeAIPluginModule::ShutdownModule() | |
| { | |
| // This function may be called during shutdown to clean up your module. For modules that support dynamic reloading, | |
| // we call this function before unloading the module. | |
| } | |
| #undef LOCTEXT_NAMESPACE | |
| IMPLEMENT_MODULE(FAwesomeAIPluginModule, AwesomeAIPlugin) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment