Created
December 7, 2023 17:33
-
-
Save neuronicnobody/eb22e8b587ba33ab771ff1a438b48d51 to your computer and use it in GitHub Desktop.
OpenBB Demo Plugin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use extism_pdk::*; | |
use serde::{Serialize, Deserialize}; | |
use serde_json::json; | |
use std::cmp::Ordering; | |
use ndarray::Array1; | |
use ndarray::Zip; | |
use std::collections::HashMap; | |
use rust_decimal::Decimal; | |
use rust_decimal::prelude::ToPrimitive; | |
#[derive(Serialize, Debug, Clone)] | |
struct TaReq { | |
name: String, | |
prices: HashMap<u64, Decimal>, | |
params: String, | |
} | |
#[derive(Deserialize, Debug, Clone)] | |
struct TaRep { | |
data: HashMap<u64, Decimal>, | |
} | |
#[derive(Serialize, Deserialize, Debug, Clone)] | |
struct PluginReq { | |
prices: HashMap<u64, Decimal>, | |
params: HashMap<String, i32>, | |
} | |
#[derive(Serialize, Debug, Clone)] | |
struct Signal { | |
signal: HashMap<u64, i32>, | |
} | |
#[derive(Serialize, Debug, Clone)] | |
struct PluginMetadata { | |
description: String, | |
params: String, | |
} | |
#[host_fn] | |
extern "ExtismHost" { | |
fn get_ta(req: Json<TaReq>) -> Json<TaRep>; | |
} | |
#[plugin_fn] | |
pub unsafe fn get_metadata(_: ()) -> FnResult<Json<PluginMetadata>> { | |
let params = json!([{ | |
"param": "length", | |
"flag": "-l", | |
"desc": "EMA period to consider", | |
"default": 20 | |
}]); | |
let metadata = PluginMetadata { | |
description: "Strategy where stock is bought when Price > EMA(l)".to_string(), | |
params: params.to_string(), | |
}; | |
Ok(Json(metadata)) | |
} | |
#[plugin_fn] | |
pub unsafe fn call<'a>(Json(req): Json<PluginReq>) -> FnResult<Json<Signal>> { | |
info!("Extism: Inside Plugin"); | |
info!("PluginReq: {:#?}", req); | |
// Accessing a value in params hashmap | |
let periods = req.params.get("length").unwrap(); | |
let mut prices_vec: Vec<_> = req.prices.clone().into_iter().collect(); | |
prices_vec.sort_by(|&(k1, _), &(k2, _)| k1.cmp(&k2)); | |
let mut sorted_dates: Vec<u64> = prices_vec.iter() | |
.map(|&(k, _)| k) | |
.collect(); | |
// make call to host function to get the EMA based on the number of periods requested | |
let params = json!({ | |
"periods": periods, | |
}).to_string(); | |
let Json(rep) = unsafe { | |
get_ta(Json(TaReq { | |
name: "ema".to_string(), | |
prices: req.prices, | |
params: params, | |
}))? | |
}; | |
//info!("REP: {:#?}", rep); | |
let mut ema_vec: Vec<_> = rep.data.clone().into_iter().collect(); | |
ema_vec.sort_by(|&(k1, _), &(k2, _)| k1.cmp(&k2)); | |
info!("ema_vec: {:#?}", ema_vec); | |
// Extract the sorted prices into a Vec<f64>. | |
let sorted_ema: Vec<f64> = ema_vec.iter() | |
.map(|&(_, ref v)| v.to_f64().expect("Invalid decimal")) | |
.collect(); | |
// Extract the sorted prices into a Vec<f64>. | |
let sorted_prices: Vec<f64> = prices_vec.iter() | |
.map(|&(_, ref v)| v.to_f64().expect("Invalid decimal")) | |
.collect(); | |
info!("sorted_dates {:?}", sorted_dates); | |
info!("sorted_ema {:?}", sorted_ema); | |
info!("sorted_prices: {:#?}", sorted_prices); | |
// Create an ndarray from the sorted prices Vec<f64>. | |
let ema_array: Array1<f64> = Array1::from(sorted_ema.clone()); | |
let prices_array: Array1<f64> = Array1::from(sorted_prices.clone()); | |
// Now `prices_array` is an ndarray sorted by the keys of the original HashMap. | |
info!("ema_array {:?}", ema_array); | |
info!("prices_array {:?}", prices_array); | |
//let mut greater_than_mask = Array1::from_elem(prices_array.raw_dim(), false); | |
let mut comparison_mask = Array1::from_elem(ema_array.raw_dim(), 0i32); | |
// Perform element-wise comparison. | |
Zip::from(&mut comparison_mask) | |
.and(&prices_array) | |
.and(&ema_array) | |
.apply(|mask, &price, &ema| { | |
*mask = if price > ema { | |
1 // if the price is greater than the ema, set mask to 1 | |
} else { | |
0 // if they are equal or price less than ema, set mask to 0 | |
}; | |
}); | |
info!("Mask: {:?}", comparison_mask); | |
// Turn it into an iterator and collect into a vector. | |
let mut vector: Vec<_> = comparison_mask.into_iter().collect(); | |
info!("{:?}", vector); | |
let paired_vector: Vec<(u64, i32)> = sorted_dates.into_iter() | |
.zip(vector.into_iter()) | |
.collect(); | |
info!("{:?}", paired_vector); | |
let signal_data: HashMap<u64, i32> = paired_vector.into_iter().collect(); | |
let signal = Signal { signal: signal_data }; | |
Ok(Json(signal)) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment