Skip to content

Instantly share code, notes, and snippets.

@praeclarum
Created July 19, 2018 17:48
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save praeclarum/7b5029656962864936d7667ae2f4a624 to your computer and use it in GitHub Desktop.
Save praeclarum/7b5029656962864936d7667ae2f4a624 to your computer and use it in GitHub Desktop.
Predicts the next C# tokens, given a history of previous tokens, using CoreML on iOS with F#
/// Given previous tokens, predict the next token (and runners up).
///
/// Feeds each element of `previousKinds` through the CoreML model in order,
/// passing the LSTM state produced by one step into the next step. Only the
/// output of the final step is turned into predictions: candidates with a
/// probability above 1.0e-4, ordered from most to least probable.
///
/// Returns an empty array when `ios11` is false or `previousKinds` is empty.
/// Raises (via `failwith`) when the model input cannot be built or when the
/// model reports a prediction error.
let predictNextToken (previousKinds : SyntaxKind[]) : Prediction[] =
    if ios11 then
        let model : MLModel = model.Value // Load the cached model
        let mutable predictions : Prediction[] = [| |]
        // RNNs require external memory: the state outputs of one step are the
        // state inputs of the next. Null means "no state yet" (first step).
        let mutable lstm_1_h : MLMultiArray = null
        let mutable lstm_1_c : MLMultiArray = null
        // Input key sets for the first step (token only) and later steps
        // (token plus both state arrays)
        let inputKeys1 = [| s_prevVectorizedToken |]
        let inputKeys3 = [| s_prevVectorizedToken; s_lstm_1_h_in; s_lstm_1_c_in |]
        let lastIndex = previousKinds.Length - 1
        // Run the model for each previous token
        for kindIndex, prevKind in previousKinds |> Array.indexed do
            // Convert the token to a vector for the model
            let vectorizedToken = CSharpPredictor.kindToVector prevKind
            // The first run doesn't include the memory
            let inputKeys, inputValues =
                if lstm_1_h <> null then
                    inputKeys3, [| vectorizedToken :> NSObject; lstm_1_h :> NSObject; lstm_1_c :> NSObject |]
                else
                    inputKeys1, [| vectorizedToken :> NSObject |]
            let inputDict = NSDictionary<NSString, NSObject>.FromObjectsAndKeys (inputValues, inputKeys, System.nint inputKeys.Length)
            // Scoped per iteration so an error can never leak from a previous step
            let mutable inputError : NSError = null
            let inputFeatures = new MLDictionaryFeatureProvider (inputDict, &inputError)
            // Do not run the model on an input that failed to build
            if inputError <> null then
                Debug.WriteLine (inputError)
                failwith "Failed to create prediction input"
            // Run the prediction
            match model.GetPrediction (inputFeatures) with
            | _, predictionError when predictionError <> null ->
                Debug.WriteLine (predictionError)
                failwith "Prediction failed"
            | output, _ ->
                // Carry the recurrent state forward to the next step
                lstm_1_h <- output.GetFeatureValue("lstm_1_h_out").MultiArrayValue
                lstm_1_c <- output.GetFeatureValue("lstm_1_c_out").MultiArrayValue
                // If this is the last prediction, store the results
                if kindIndex = lastIndex then
                    let probabilities =
                        (output.GetFeatureValue("nextTokenProbabilities").DictionaryValue
                         :> Collections.Generic.IDictionary<NSObject, NSNumber>)
                    predictions <-
                        probabilities
                        |> Seq.map (fun (x : System.Collections.Generic.KeyValuePair<NSObject, NSNumber>) -> string x.Key, x.Value.DoubleValue)
                        // Drop candidates the model considers negligible
                        |> Seq.filter (fun (_, p) -> p > 1.0e-4)
                        // Most probable first
                        |> Seq.sortByDescending snd
                        |> Seq.map (fun (tokenName, p) ->
                            let kind = CSharpPredictor.stringToSyntaxKind tokenName
                            let insertText, formatText = CSharpPredictor.kindToCompletion kind
                            kind, insertText, formatText, p)
                        |> Array.ofSeq
        predictions
    else
        [| |]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment