@robinmonjo
Last active April 8, 2023 22:07

fastai-7-collaborative-filtering - fork

Mix.install([
  {:axon, "~> 0.5.1"},
  {:explorer, "~> 0.5.6"},
  {:exla, "~> 0.5.2"},
  {:kino, "~> 0.9.1"}
])

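# Compile numerical definitions with EXLA: global_default_options/1 sets the
# default for every process, default_options/1 for the current process.
Nx.Defn.global_default_options(compiler: EXLA)
Nx.Defn.default_options(compiler: EXLA)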

alias Explorer.DataFrame
alias Explorer.Series

Why this notebook?

I have been following the fast.ai lectures, but using Elixir instead of Python.

Apart from the first lectures, I'm a total beginner in the ML world.

I'm kind of stuck on lecture 7, about collaborative filtering, and can't figure out why.

The original notebook is available here

This notebook explains it all.

First, download the dataset from here.

Unzip it into ./ so that the ratings file ends up at ./ml-100k/u.data.
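You can also do the extraction from the notebook with Erlang's built-in `:zip` module. This is a minimal sketch. It assumes the download was saved as `ml-100k.zip` in the current directory. Both that file name and the `MovieLens.unzip/2` helper are my own assumptions, not part of the original notebook.

```elixir
defmodule MovieLens do
  # Extracts a zip archive into `dest` (the current directory by default).
  # Returns {:ok, extracted_paths} or {:error, reason}, like :zip.unzip/2.
  def unzip(zip_path, dest \\ ".") do
    :zip.unzip(String.to_charlist(zip_path), cwd: String.to_charlist(dest))
  end
end

# Hypothetical usage, assuming the archive sits next to the notebook:
# {:ok, _files} = MovieLens.unzip("ml-100k.zip")
```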

# u.data is tab-separated and has no header row
df =
  DataFrame.from_csv!("./ml-100k/u.data", delimiter: "\t", header: false)
  |> DataFrame.rename([:user, :movie, :rating, :timestamp])
#Explorer.DataFrame<
  Polars[100000 x 4]
  user integer [196, 186, 22, 244, 166, ...]
  movie integer [242, 302, 377, 51, 346, ...]
  rating integer [3, 3, 1, 2, 1, ...]
  timestamp integer [881250949, 891717742, 878887116, 880606923, 886397596, ...]
>

OK, we have the data. We won't use the timestamp column.
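Each line of u.data holds four tab-separated integers: user ID, movie ID, rating (1 to 5) and a Unix timestamp. This is the same order as the `rename` above. The stdlib-only sketch below parses one such line without Explorer and drops the timestamp. It uses the first row shown in the output above as input. `UData.parse_line/1` is a hypothetical helper for illustration.

```elixir
defmodule UData do
  # Parses one tab-separated line of ml-100k/u.data into {user, movie, rating}.
  # The fourth field (timestamp) is parsed but discarded, since we don't use it.
  def parse_line(line) do
    [user, movie, rating, _timestamp] =
      line
      |> String.trim_trailing()
      |> String.split("\t")
      |> Enum.map(&String.to_integer/1)

    {user, movie, rating}
  end
end

UData.parse_line("196\t242\t3\t881250949\n")
# => {196, 242, 3}
```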

Now we split the dataset into a training data frame (80% of the rows) and a validation data frame (the remaining 20%).

# 80% of the shuffled rows for training, the remaining 20% for validation
{size, _} = DataFrame.shape(df)
n_train = ceil(size * 0.8)

shuffled_df = DataFrame.shuffle(df)
train_df = DataFrame.slice(shuffled_df, 0..(n_train - 1))
# Last valid row index is size - 1
validation_df = DataFrame.slice(shuffled_df, n_train..(size - 1))
{train_df, validation_df}
{#Explorer.DataFrame<
   Polars[80000 x 4]
   user integer [468, 695, 318, 327, 880, ...]
   movie integer [1168, 882, 628, 152, 571, ...]
   rating integer [2, 4, 4, 3, 2, ...]
   timestamp integer [875302155, 888805836, 884494757, 887819194, 880175187, ...]
 >,
 #Explorer.DataFrame<
   Polars[20000 x 4]
   user integer [181, 26, 102, 294, 618, ...]
   movie integer [1390, 109, 856, 979, 582, ...]
   rating integer [1, 3, 2, 3, 4, ...]
   timestamp integer [878962052, 891376987, 892993927, 877819897, 891309217, ...]
 >}
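Because `DataFrame.shuffle/1` is called without a fixed seed here, each run produces a different split. That makes it harder to compare training runs while debugging. The sketch below is a stdlib-only illustration of the same 80/20 split on a plain list with a seeded shuffle. It relies on `Enum.shuffle/1` drawing from the process's `:rand` state. `Split.train_validation/3` and the seed value are hypothetical.

```elixir
defmodule Split do
  # Shuffles `rows` reproducibly and splits them into {train, validation}.
  # Seeding :rand (used by Enum.shuffle/1) makes the split repeatable.
  def train_validation(rows, train_fraction \\ 0.8, seed \\ {1, 2, 3}) do
    :rand.seed(:exsss, seed)
    n_train = ceil(length(rows) * train_fraction)

    rows
    |> Enum.shuffle()
    |> Enum.split(n_train)
  end
end

{train, validation} = Split.train_validation(Enum.to_list(1..100_000))
{length(train), length(validation)}
# => {80000, 20000}
```

`Enum.split/2` cuts at exactly `n_train` elements, so the two parts can neither overlap nor lose a row.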

Here we prepare the inputs for the training loop. Our model has two inputs:

  • user_input
  • movie_input

The user and movie inputs are their IDs; the output is the rating.
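For reference, the simplest model in the lecture works as follows. It looks up a learned factor vector for the user and one for the movie. It predicts the rating as their dot product plus a per-user and a per-movie bias. A scaled sigmoid then squashes the result into a rating range; the fastai notebook uses `y_range=(0, 5.5)`. The stdlib-only sketch below shows that forward pass for a single (user, movie) pair. The factor tables are plain maps keyed by ID with made-up values, and `DotProductModel.predict/4` is a hypothetical name.

Since MovieLens 100k user IDs run from 1 to 943 and movie IDs from 1 to 1682, an embedding table indexed directly by ID needs 944 and 1683 rows. The alternative is to shift the IDs down by one. That detail may be worth double-checking in an embedding-based model.

```elixir
defmodule DotProductModel do
  # Forward pass for one pair:
  #   low + (high - low) * sigmoid(dot(u, m) + user_bias + movie_bias)
  # `params` holds factor vectors and biases keyed by ID (hypothetical layout).
  def predict(params, user, movie, range \\ {0.0, 5.5}) do
    {low, high} = range
    u = Map.fetch!(params.user_factors, user)
    m = Map.fetch!(params.movie_factors, movie)

    dot = Enum.zip_with(u, m, &(&1 * &2)) |> Enum.sum()
    raw = dot + Map.fetch!(params.user_bias, user) + Map.fetch!(params.movie_bias, movie)

    # Scaled sigmoid keeps every prediction strictly inside (low, high).
    low + (high - low) / (1 + :math.exp(-raw))
  end
end

# Made-up 3-dimensional factors for user 196 and movie 242
params = %{
  user_factors: %{196 => [0.5, -0.2, 0.1]},
  movie_factors: %{242 => [0.4, 0.3, -0.6]},
  user_bias: %{196 => 0.1},
  movie_bias: %{242 => 0.2}
}

DotProductModel.predict(params, 196, 242)
```

In a real model, the maps would be embedding tables trained with gradient descent, and the same computation would run on whole batches of IDs at once.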

PS: I know this code is not sexy, but it does the job 😋

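batch_size = 1000

# Build {user, movie, rating} tuples, then group them into batches of 1000.
# Each batch becomes {%{"user_input" => ..., "movie_input" => ...}, ratings},
# where every tensor has shape {batch_size, 1}.
train_inputs =
  Enum.zip([
    Series.to_enum(train_df["user"]),
    Series.to_enum(train_df["movie"]),
    Series.to_enum(train_df["rating"])
  ])
  |> Enum.chunk_every(batch_size)
  |> Enum.map(fn batch ->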
    {
      %{
        "user_input" =>
          Enum.map(batch, fn {u, _, _} ->
            [u]
          end)
          |> Nx.tensor(),
        "movie_input" =>
          Enum.map(batch, fn {_, m, _} ->
            [m]
          end)
          |> Nx.tensor()
      },
      Enum.map(batch, fn {_, _, r} ->
        [r]
      end)
      |> Nx.tensor()
    }
  end)
[
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [1168],
         [882],
         [628],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [468],
         [695],
         [318],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [2],
       [4],
       [4],
       ...
     ]
   >},
  ...
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [1011],
         [481],
         [8],
         [478],
         [402],
         [69],
         [1031],
         [554],
         [732],
         [1483],
         [50],
         [240],
         [762],
         [73],
         [1],
         [12],
         [559],
         [19],
         [435],
         [257],
         [153],
         [665],
         [1136],
         [498],
         [423],
         [38],
         [815],
         [780],
         [641],
         [299],
         [64],
         [584],
         [483],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [458],
         [553],
         [746],
         [450],
         [505],
         [458],
         [38],
         [671],
         [655],
         [416],
         [49],
         [620],
         [145],
         [7],
         [545],
         [327],
         [561],
         [52],
         [82],
         [221],
         [922],
         [826],
         [314],
         [748],
         [387],
         [389],
         [376],
         [99],
         [469],
         [159],
         [314],
         [807],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [3],
       [3],
       [4],
       [5],
       [5],
       [2],
       [5],
       [4],
       [3],
       [4],
       [1],
       [5],
       [3],
       [3],
       [5],
       [3],
       [1],
       [5],
       [5],
       [4],
       [4],
       [5],
       [5],
       [4],
       [3],
       [2],
       [3],
       [5],
       [4],
       [3],
       [5],
       [4],
       [5],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [1],
         [102],
         [1226],
         [511],
         [9],
         [202],
         [1101],
         [259],
         [602],
         [325],
         [578],
         [271],
         [1214],
         [1168],
         [177],
         [1044],
         [85],
         [1136],
         [134],
         [631],
         [627],
         [8],
         [1039],
         [286],
         [846],
         [300],
         [294],
         [133],
         [4],
         [8],
         [949],
         [844],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [390],
         [330],
         [592],
         [382],
         [486],
         [904],
         [796],
         [32],
         [807],
         [250],
         [429],
         [303],
         [655],
         [389],
         [382],
         [524],
         [406],
         [627],
         [911],
         [291],
         [95],
         [506],
         [642],
         [284],
         [334],
         [323],
         [341],
         [60],
         [738],
         [24],
         [533],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [5],
       [4],
       [4],
       [4],
       [5],
       [2],
       [5],
       [2],
       [5],
       [4],
       [3],
       [2],
       [2],
       [3],
       [4],
       [4],
       [2],
       [4],
       [4],
       [5],
       [4],
       [5],
       [5],
       [4],
       [3],
       [2],
       [3],
       [4],
       [4],
       [5],
       [4],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [511],
         [1011],
         [347],
         [172],
         [233],
         [418],
         [522],
         [476],
         [277],
         [127],
         [276],
         [217],
         [1011],
         [174],
         [1204],
         [281],
         [97],
         [984],
         [327],
         [772],
         [923],
         [125],
         [53],
         [682],
         [181],
         [181],
         [22],
         [172],
         [491],
         [181],
         [641],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [541],
         [936],
         [174],
         [121],
         [375],
         [454],
         [269],
         [933],
         [151],
         [474],
         [447],
         [617],
         [592],
         [398],
         [608],
         [207],
         [346],
         [332],
         [314],
         [487],
         [535],
         [46],
         [94],
         [550],
         [753],
         [837],
         [328],
         [807],
         [58],
         [121],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [4],
       [4],
       [5],
       [4],
       [3],
       [5],
       [2],
       [4],
       [5],
       [4],
       [1],
       [4],
       [5],
       [2],
       [3],
       [4],
       [2],
       [4],
       [3],
       [4],
       [4],
       [4],
       [4],
       [3],
       [3],
       [5],
       [5],
       [4],
       [5],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [301],
         [517],
         [234],
         [309],
         [427],
         [924],
         [644],
         [742],
         [66],
         [273],
         [266],
         [47],
         [313],
         [707],
         [1444],
         [147],
         [582],
         [378],
         [182],
         [197],
         [265],
         [845],
         [428],
         [909],
         [109],
         [153],
         [406],
         [665],
         [97],
         [227],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [489],
         [716],
         [109],
         [308],
         [617],
         [537],
         [870],
         [387],
         [216],
         [456],
         [451],
         [325],
         [499],
         [488],
         [882],
         [416],
         [778],
         [200],
         [269],
         [912],
         [922],
         [862],
         [269],
         [655],
         [13],
         [291],
         [601],
         [374],
         [840],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [3],
       [5],
       [4],
       [1],
       [4],
       [3],
       [2],
       [2],
       [2],
       [3],
       [2],
       [3],
       [5],
       [2],
       [4],
       [5],
       [1],
       [5],
       [4],
       [5],
       [5],
       [4],
       [5],
       [3],
       [4],
       [4],
       [2],
       [4],
       [3],
       [2],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [679],
         [60],
         [1121],
         [198],
         [1293],
         [661],
         [237],
         [528],
         [217],
         [227],
         [81],
         [82],
         [364],
         [232],
         [1016],
         [259],
         [768],
         [527],
         [421],
         [1115],
         [230],
         [731],
         [1118],
         [333],
         [456],
         [332],
         [466],
         [300],
         [232],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [833],
         [639],
         [585],
         [727],
         [519],
         [785],
         [653],
         [452],
         [264],
         [200],
         [279],
         [216],
         [642],
         [746],
         [593],
         [559],
         [311],
         [682],
         [727],
         [460],
         [417],
         [158],
         [621],
         [63],
         [825],
         [140],
         [6],
         [621],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [3],
       [3],
       [4],
       [4],
       [5],
       [3],
       [2],
       [4],
       [3],
       [5],
       [4],
       [4],
       [5],
       [3],
       [4],
       [3],
       [2],
       [3],
       [5],
       [3],
       [3],
       [2],
       [3],
       [4],
       [3],
       [3],
       [4],
       [3],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [1037],
         [477],
         [750],
         [274],
         [1018],
         [989],
         [1072],
         [118],
         [83],
         [1039],
         [596],
         [269],
         [607],
         [685],
         [618],
         [429],
         [411],
         [32],
         [17],
         [31],
         [474],
         [333],
         [328],
         [293],
         [273],
         [405],
         [12],
         [1278],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [760],
         [938],
         [919],
         [298],
         [429],
         [782],
         [87],
         [68],
         [492],
         [253],
         [815],
         [448],
         [889],
         [221],
         [450],
         [350],
         [144],
         [354],
         [496],
         [741],
         [568],
         [464],
         [116],
         [416],
         [697],
         [442],
         [491],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [5],
       [1],
       [3],
       [3],
       [3],
       [3],
       [3],
       [2],
       [4],
       [4],
       [5],
       [5],
       [4],
       [3],
       [4],
       [4],
       [4],
       [3],
       [3],
       [3],
       [5],
       [4],
       [3],
       [5],
       [5],
       [3],
       [5],
       [5],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [174],
         [1314],
         [100],
         [1280],
         [268],
         [354],
         [93],
         [173],
         [44],
         [19],
         [173],
         [98],
         [475],
         [1141],
         [467],
         [11],
         [656],
         [91],
         [410],
         [302],
         [135],
         [182],
         [286],
         [430],
         [257],
         [422],
         [176],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [780],
         [425],
         [870],
         [286],
         [193],
         [863],
         [439],
         [271],
         [504],
         [558],
         [454],
         [465],
         [222],
         [64],
         [833],
         [201],
         [643],
         [790],
         [727],
         [880],
         [682],
         [331],
         [16],
         [833],
         [697],
         [484],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [5],
       [3],
       [4],
       [5],
       [3],
       [1],
       [4],
       [4],
       [4],
       [5],
       [2],
       [4],
       [4],
       [5],
       [2],
       [4],
       [4],
       [3],
       [2],
       [5],
       [4],
       [4],
       [2],
       [4],
       [5],
       [3],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [288],
         [516],
         [107],
         [1051],
         [426],
         [778],
         [498],
         [96],
         [136],
         [193],
         [48],
         [177],
         [277],
         [545],
         [357],
         [318],
         [210],
         [332],
         [631],
         [96],
         [529],
         [960],
         [234],
         [751],
         [268],
         [283],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [773],
         [308],
         [697],
         [637],
         [301],
         [914],
         [406],
         [436],
         [532],
         [881],
         [385],
         [671],
         [234],
         [774],
         [25],
         [912],
         [406],
         [446],
         [327],
         [655],
         [60],
         [655],
         [442],
         [782],
         [715],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [2],
       [4],
       [5],
       [2],
       [4],
       [5],
       [5],
       [4],
       [5],
       [5],
       [5],
       [4],
       [3],
       [1],
       [4],
       [4],
       [5],
       [3],
       [3],
       [3],
       [4],
       [3],
       [4],
       [2],
       [4],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [686],
         [552],
         [234],
         [429],
         [313],
         [48],
         [393],
         [661],
         [419],
         [567],
         [11],
         [185],
         [187],
         [120],
         [455],
         [833],
         [395],
         [35],
         [201],
         [1065],
         [385],
         [515],
         [528],
         [1175],
         [668],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [276],
         [95],
         [244],
         [387],
         [711],
         [474],
         [690],
         [506],
         [416],
         [802],
         [592],
         [308],
         [780],
         [125],
         [181],
         [663],
         [881],
         [405],
         [198],
         [269],
         [566],
         [716],
         [457],
         [86],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [3],
       [1],
       [3],
       [3],
       [4],
       [4],
       [4],
       [5],
       [4],
       [4],
       [5],
       [4],
       [5],
       [1],
       [1],
       [4],
       [3],
       [2],
       [3],
       [5],
       [3],
       [5],
       [5],
       [5],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [546],
         [281],
         [226],
         [879],
         [19],
         [300],
         [762],
         [1483],
         [468],
         [7],
         [1017],
         [8],
         [73],
         [269],
         [168],
         [334],
         [153],
         [51],
         [173],
         [258],
         [471],
         [273],
         [99],
         [170],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [805],
         [655],
         [561],
         [721],
         [321],
         [701],
         [634],
         [676],
         [311],
         [697],
         [339],
         [409],
         [864],
         [2],
         [653],
         [557],
         [218],
         [610],
         [64],
         [877],
         [360],
         [301],
         [401],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [2],
       [2],
       [1],
       [4],
       [4],
       [3],
       [3],
       [4],
       [4],
       [5],
       [5],
       [3],
       [5],
       [4],
       [3],
       [4],
       [4],
       [5],
       [5],
       [4],
       [4],
       [1],
       [4],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [81],
         [289],
         [372],
         [82],
         [748],
         [1165],
         [93],
         [326],
         [117],
         [620],
         [200],
         [151],
         [176],
         [323],
         [1149],
         [202],
         [121],
         [227],
         [378],
         [980],
         [283],
         [269],
         [1411],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [640],
         [69],
         [883],
         [622],
         [258],
         [344],
         [151],
         [149],
         [251],
         [907],
         [269],
         [852],
         [488],
         [515],
         [606],
         [840],
         [349],
         [921],
         [694],
         [561],
         [361],
         [400],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [5],
       [4],
       [3],
       [3],
       [5],
       [1],
       [5],
       [3],
       [4],
       [4],
       [4],
       [4],
       [4],
       [3],
       [4],
       [5],
       [2],
       [3],
       [3],
       [3],
       [4],
       [4],
       [1],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [168],
         [1013],
         [88],
         [322],
         [881],
         [1209],
         [234],
         [625],
         [899],
         [559],
         [64],
         [215],
         [150],
         [248],
         [483],
         [1381],
         [268],
         [509],
         [153],
         [690],
         [83],
         [198],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [850],
         [459],
         [758],
         [397],
         [510],
         [487],
         [90],
         [49],
         [418],
         [286],
         [886],
         [561],
         [640],
         [727],
         [499],
         [662],
         [708],
         [244],
         [539],
         [324],
         [698],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [5],
       [3],
       [4],
       [1],
       [2],
       [4],
       [4],
       [3],
       [5],
       [4],
       [5],
       [3],
       [4],
       [5],
       [5],
       [5],
       [3],
       [5],
       [5],
       [4],
       [5],
       [5],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [507],
         [651],
         [64],
         [1419],
         [728],
         [338],
         [1549],
         [416],
         [767],
         [405],
         [1053],
         [1314],
         [283],
         [326],
         [283],
         [895],
         [64],
         [226],
         [274],
         [526],
         [847],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [60],
         [537],
         [472],
         [197],
         [709],
         [262],
         [655],
         [453],
         [617],
         [22],
         [43],
         [268],
         [786],
         [832],
         [527],
         [365],
         [288],
         [345],
         [396],
         [11],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [3],
       [5],
       [2],
       [4],
       [4],
       [2],
       [2],
       [3],
       [1],
       [3],
       [2],
       [4],
       [4],
       [4],
       [4],
       [5],
       [3],
       [4],
       [3],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [201],
         [144],
         [921],
         [546],
         [275],
         [483],
         [756],
         [742],
         [237],
         [1228],
         [283],
         [864],
         [1411],
         [257],
         [423],
         [553],
         [122],
         [751],
         [1109],
         [349],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [682],
         [374],
         [591],
         [717],
         [308],
         [312],
         [200],
         [552],
         [84],
         [648],
         [558],
         [94],
         [385],
         [642],
         [20],
         [883],
         [854],
         [676],
         [59],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [5],
       [4],
       [3],
       [4],
       [5],
       [3],
       [4],
       [4],
       [3],
       [3],
       [2],
       [3],
       [5],
       [2],
       [4],
       [3],
       [4],
       [3],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [926],
         [12],
         [1011],
         [596],
         [165],
         [685],
         [829],
         [25],
         [226],
         [333],
         [369],
         [672],
         [449],
         [382],
         [530],
         [473],
         [298],
         [864],
         [928],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [822],
         [435],
         [194],
         [312],
         [232],
         [348],
         [347],
         [56],
         [843],
         [809],
         [501],
         [92],
         [495],
         [327],
         [350],
         [168],
         [222],
         [344],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [2],
       [5],
       [3],
       [5],
       [4],
       [4],
       [4],
       [4],
       [3],
       [3],
       [4],
       [3],
       [5],
       [3],
       [4],
       [2],
       [4],
       [3],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [99],
         [144],
         [281],
         [896],
         [419],
         [742],
         [307],
         [86],
         [254],
         [252],
         [902],
         [197],
         [887],
         [28],
         [228],
         [151],
         [233],
         [79],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [70],
         [738],
         [294],
         [871],
         [885],
         [207],
         [276],
         [276],
         [541],
         [936],
         [269],
         [123],
         [724],
         [128],
         [455],
         [374],
         [83],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [5],
       [3],
       [3],
       [4],
       [4],
       [4],
       [3],
       [3],
       [2],
       [5],
       [5],
       [3],
       [5],
       [4],
       [3],
       [4],
       [2],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [748],
         [281],
         [225],
         [558],
         [881],
         [521],
         [52],
         [134],
         [91],
         [70],
         [610],
         [832],
         [98],
         [250],
         [308],
         [482],
         [724],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [181],
         [438],
         [57],
         [604],
         [126],
         [833],
         [90],
         [492],
         [405],
         [642],
         [807],
         [450],
         [392],
         [500],
         [281],
         [747],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [1],
       [4],
       [3],
       [4],
       [5],
       [4],
       [5],
       [3],
       [2],
       [2],
       [3],
       [2],
       [5],
       [4],
       [1],
       [5],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [483],
         [1218],
         [864],
         [1016],
         [151],
         [510],
         [363],
         [38],
         [22],
         [286],
         [183],
         [135],
         [12],
         [9],
         [1176],
         [768],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [900],
         [405],
         [825],
         [268],
         [128],
         [488],
         [289],
         [826],
         [123],
         [580],
         [25],
         [374],
         [72],
         [27],
         [557],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [5],
       [3],
       [3],
       [3],
       [4],
       [3],
       [3],
       [4],
       [4],
       [4],
       [4],
       [5],
       [4],
       [5],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [603],
         [127],
         [107],
         [297],
         [289],
         [181],
         [323],
         [504],
         [544],
         [658],
         [318],
         [576],
         [1501],
         [356],
         [539],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [201],
         [938],
         [416],
         [733],
         [909],
         [779],
         [396],
         [269],
         [221],
         [286],
         [776],
         [543],
         [655],
         [347],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [5],
       [5],
       [3],
       [3],
       [5],
       [4],
       [4],
       [4],
       [5],
       [4],
       [4],
       [3],
       [5],
       [2],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [678],
         [203],
         [94],
         [100],
         [363],
         [65],
         [240],
         [275],
         [176],
         [232],
         [506],
         [327],
         [212],
         [21],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [605],
         [458],
         [727],
         [237],
         [619],
         [116],
         [457],
         [321],
         [292],
         [178],
         [312],
         [656],
         [889],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [1],
       [5],
       [4],
       [5],
       [2],
       [2],
       [3],
       [4],
       [5],
       [5],
       [4],
       [2],
       [2],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [472],
         [173],
         [781],
         [845],
         [483],
         [491],
         [428],
         [174],
         [197],
         [237],
         [272],
         [268],
         [196],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [727],
         [28],
         [788],
         [768],
         [354],
         [59],
         [5],
         [215],
         [25],
         [215],
         [126],
         [834],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [2],
       [3],
       [3],
       [2],
       [4],
       [4],
       [5],
       [4],
       [3],
       [4],
       [3],
       [3],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [204],
         [810],
         [675],
         [1037],
         [8],
         [1225],
         [88],
         [302],
         [550],
         [95],
         [288],
         [1073],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [847],
         [222],
         [405],
         [476],
         [758],
         [94],
         [43],
         [915],
         [524],
         [334],
         [314],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [2],
       [1],
       [1],
       [5],
       [3],
       [5],
       [4],
       [3],
       [3],
       [5],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [699],
         [151],
         [541],
         [137],
         [216],
         [211],
         [118],
         [258],
         [284],
         [197],
         [504],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [342],
         [536],
         [586],
         [537],
         [665],
         [339],
         [7],
         [811],
         [825],
         [561],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [3],
       [3],
       [4],
       [4],
       [5],
       [2],
       [5],
       [3],
       [4],
       [5],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [79],
         [474],
         [748],
         [444],
         [568],
         [117],
         [844],
         [221],
         [196],
         [754],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [320],
         [830],
         [204],
         [5],
         [102],
         [398],
         [658],
         [463],
         [379],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [5],
       [1],
       [2],
       [2],
       [4],
       [3],
       [5],
       [4],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [780],
         [150],
         [211],
         [312],
         [153],
         [59],
         [70],
         [514],
         [223],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [222],
         [608],
         [350],
         [239],
         [629],
         [645],
         [829],
         [232],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [3],
       [3],
       [2],
       [2],
       [5],
       [5],
       [4],
       [4],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [13],
         [134],
         [754],
         [405],
         [566],
         [582],
         [381],
         [393],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [569],
         [313],
         [404],
         [932],
         [95],
         [875],
         [95],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [3],
       [5],
       [3],
       [4],
       [2],
       [5],
       [4],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [52],
         [823],
         [259],
         [69],
         [663],
         [1124],
         [728],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [457],
         [276],
         [688],
         [280],
         [13],
         [716],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [3],
       [5],
       [4],
       [5],
       [3],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [892],
         [286],
         [611],
         [1050],
         [150],
         [1091],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [676],
         [788],
         [321],
         [344],
         [382],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [5],
       [4],
       [3],
       [2],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [629],
         [875],
         [217],
         [181],
         [622],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [495],
         [755],
         [551],
         [409],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [3],
       [1],
       [1],
       [4],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [498],
         [168],
         [76],
         [833],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [868],
         [405],
         [788],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [3],
       [1],
       [3],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [127],
         [411],
         [98],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [805],
         [332],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [3],
       [4],
       [4],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [68],
         [1285],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [715],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [4],
       [3],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         [343],
         ...
       ]
     >,
     "user_input" => #Nx.Tensor<
       s64[1000][1]
       [
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       [2],
       ...
     ]
   >},
  {%{
     "movie_input" => #Nx.Tensor<
       s64[1000][1]
       [
         ...
       ]
     >,
     ...
   },
   #Nx.Tensor<
     s64[1000][1]
     [
       ...
     ]
   >},
  {%{...}, ...},
  {...},
  ...
]

Now I define the model: 2 inputs, each of which goes through an embedding layer of 50 "factors".

Then the user and movie factors are multiplied element-wise and summed:

n_factors = 50

user_input =
  Axon.input("user_input", shape: {batch_size, 1})
  |> Axon.embedding(size, n_factors)

movie_input =
  Axon.input("movie_input", shape: {batch_size, 1})
  |> Axon.embedding(size, n_factors)

model =
  Axon.multiply(user_input, movie_input)
  |> Axon.nx(&Nx.sum(&1))
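
Both embeddings above are created with `size` rows, and the display below reports each kernel as `f32[100000][50]`. The ids are used directly as row indices, so each table only needs `max_id + 1` rows. MovieLens 100k has 943 users and 1682 movies, so two much smaller tables would be enough. Smaller tables would also leave each optimizer step with far less to update, which may be part of why a batch size of 64 felt so slow. Below is a minimal sketch of deriving the sizes from the ratings. It assumes Explorer is available from the setup cell, and `ratings` is a hypothetical toy table standing in for `df`:

```elixir
# Assumes Explorer is available from the Mix.install cell at the top of the notebook.
alias Explorer.DataFrame
alias Explorer.Series

# Hypothetical toy ratings table standing in for the notebook's `df`.
ratings = DataFrame.new(user: [1, 5, 3], movie: [10, 2, 7], rating: [4, 3, 5])

# Ids index embedding rows directly, so each table needs max_id + 1 rows
# (row 0 is simply never looked up when ids start at 1).
n_users = Series.max(DataFrame.pull(ratings, "user")) + 1
n_movies = Series.max(DataFrame.pull(ratings, "movie")) + 1
```

With the real `df`, these values would take the place of `size`, as in `Axon.embedding(n_users, n_factors)` and `Axon.embedding(n_movies, n_factors)`.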

Axon.Display.as_table(model, Nx.template({batch_size, 1}, :s64)) |> IO.puts()
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|                                                                                   Model                                                                                   |
+==========================================================+==================================+================================+==================+=========================+
| Layer                                                    | Input Shape                      | Output Shape                   | Options          | Parameters              |
+==========================================================+==================================+================================+==================+=========================+
| user_input ( input )                                     | []                               | {1000, 1}                      | shape: {1000, 1} |                         |
|                                                          |                                  |                                | optional: false  |                         |
+----------------------------------------------------------+----------------------------------+--------------------------------+------------------+-------------------------+
| embedding_0 ( embedding["user_input"] )                  | [{1000, 1}]                      | {1000, 1, 50}                  |                  | kernel: f32[100000][50] |
+----------------------------------------------------------+----------------------------------+--------------------------------+------------------+-------------------------+
| movie_input ( input )                                    | []                               | {1000, 1}                      | shape: {1000, 1} |                         |
|                                                          |                                  |                                | optional: false  |                         |
+----------------------------------------------------------+----------------------------------+--------------------------------+------------------+-------------------------+
| embedding_1 ( embedding["movie_input"] )                 | [{1000, 1}]                      | {1000, 1, 50}                  |                  | kernel: f32[100000][50] |
+----------------------------------------------------------+----------------------------------+--------------------------------+------------------+-------------------------+
| container_0 ( container {"embedding_0", "embedding_1"} ) | {}                               | {{1000, 1, 50}, {1000, 1, 50}} |                  |                         |
+----------------------------------------------------------+----------------------------------+--------------------------------+------------------+-------------------------+
| multiply_0 ( multiply["container_0"] )                   | [{{1000, 1, 50}, {1000, 1, 50}}] | {1000, 1, 50}                  |                  |                         |
+----------------------------------------------------------+----------------------------------+--------------------------------+------------------+-------------------------+
| nx_0 ( nx["multiply_0"] )                                | [{1000, 1, 50}]                  | {}                             |                  |                         |
+----------------------------------------------------------+----------------------------------+--------------------------------+------------------+-------------------------+
Total Parameters: 10000000
Total Parameters Memory: 40000000 bytes

:ok
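
The last row of the table shows `nx_0` producing shape `{}`, while the rating targets are `{1000, 1}`. `Nx.sum/1` without an `:axes` option reduces every axis, including the batch axis. The model therefore predicts a single number per batch instead of one dot product per (user, movie) pair, which looks like a likely contributor to the high loss. The small sketch below uses made-up tensors shaped like the embedding outputs (`{batch, 1, n_factors}`) and assumes Nx from the setup cell:

```elixir
# Assumes Nx is available from the Mix.install cell at the top of the notebook.
# Made-up stand-ins for the embedding_0 / embedding_1 outputs:
# shape {batch, 1, n_factors} with batch = 2 and n_factors = 3.
users = Nx.tensor([[[1.0, 2.0, 3.0]], [[0.5, 0.5, 0.5]]])
movies = Nx.tensor([[[1.0, 0.0, 1.0]], [[2.0, 2.0, 2.0]]])

product = Nx.multiply(users, movies)

# No :axes => every axis is reduced, batch included: a single scalar (7.0).
total = Nx.sum(product)

# Reducing only the last (factor) axis keeps one dot product per example:
# shape {2, 1} with values 4.0 and 3.0.
per_example = Nx.sum(product, axes: [-1])
```

In the model, the corresponding change would be `Axon.nx(&Nx.sum(&1, axes: [-1]))`, which should give `nx_0` an output shape of `{1000, 1}`, matching the targets.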

Then comes the training...

loop = Axon.Loop.trainer(model, :mean_squared_error, Axon.Optimizers.sgd(0.029), log: 10)
params = Axon.Loop.run(loop, train_inputs, %{}, epochs: 5)
Epoch: 0, Batch: 70, loss: 3.2608242
Epoch: 1, Batch: 70, loss: 2.2175949
Epoch: 2, Batch: 70, loss: 1.8950887
Epoch: 3, Batch: 70, loss: 1.7377074
Epoch: 4, Batch: 70, loss: 1.6443866
%{
  "embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[100000][50]
      EXLA.Backend<host:0, 0.4086315873.1394475020.251918>
      [
        [0.0017787455581128597, -0.0033857631497085094, -0.007207505404949188, -5.064010474598035e-5, -0.0015526985516771674, -0.0015923524042591453, -0.00791232567280531, -0.00362397194840014, 0.0036139844451099634, -0.0049926540814340115, 0.0013311314396560192, 0.004891934338957071, 0.009271780960261822, 0.0012215184979140759, 0.004105870611965656, -0.008875987492501736, 0.00799651350826025, 0.005906536243855953, 0.009425980970263481, -1.5794753562659025e-4, -0.007067639846354723, 0.00678034033626318, 0.009312136098742485, -0.008945650421082973, -0.0035132288467139006, 0.0015603804495185614, -0.00865252036601305, 0.009704163298010826, -0.004609372466802597, -6.457519484683871e-4, -8.309435797855258e-4, -0.006967253517359495, -0.006095163524150848, -0.006676304154098034, -5.000352975912392e-4, -0.006253594998270273, 0.008416946046054363, -0.004622812382876873, -0.00874529592692852, -0.008357870392501354, -0.004108376335352659, -0.0014787721447646618, 1.672506332397461e-4, 0.0015589212998747826, 0.008024237118661404, -0.0015777706867083907, 0.006632501725107431, -5.642509204335511e-4, ...],
        ...
      ]
    >
  },
  "embedding_1" => %{
    "kernel" => #Nx.Tensor<
      f32[100000][50]
      EXLA.Backend<host:0, 0.4086315873.1394475020.251919>
      [
        [-3.9393422775901854e-4, -4.997325013391674e-4, 0.008873040787875652, -0.00902603566646576, -0.0016105007380247116, -0.009117345325648785, 0.0011610769433900714, -0.005720951594412327, 0.005600349977612495, -0.001416308805346489, 0.0029332065023481846, -0.00965881533920765, 0.009536738507449627, 8.892511832527816e-4, -0.0064435601234436035, -0.0017142700962722301, -0.006117257755249739, -0.0052512288093566895, -0.0037650822196155787, -2.108240150846541e-4, -0.006255025509744883, 0.006375038530677557, -0.007711143232882023, 0.002668146975338459, 8.393215830437839e-4, 0.00433955667540431, -0.007989172823727131, 0.0019231723854318261, 0.005344254896044731, 0.002423608209937811, -0.005921106319874525, -7.3070521466434e-5, 0.0034095309674739838, -0.004935376346111298, -0.007925936952233315, 0.008986014872789383, 0.006540043279528618, 0.0019592808093875647, -0.007518403232097626, -0.00981221441179514, -0.0044106170535087585, -0.001608817488886416, 0.00894960667937994, -0.004569544922560453, 0.005322928540408611, 0.0023723600897938013, 5.017256480641663e-4, ...],
        ...
      ]
    >
  }
}

The loss is pretty bad compared to what Jeremy Howard gets in his notebook.

I don't understand why; our models seem similar to me. I tried a batch size of 64, but it was way too slow. I also tried tuning the learning rate: with a value like his, 5e-3, the loss is even worse.
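
The `DotProduct` model in the fastai collaborative-filtering lesson differs in a few ways. It sums only the factor dimension (`.sum(dim=1)`), then squashes each prediction into `y_range = (0, 5.5)` with fastai's `sigmoid_range`. It is also trained with `fit_one_cycle` using fastai's default optimizer, Adam, whereas this notebook uses plain SGD, so a learning rate of 5e-3 isn't directly comparable. Here is a sketch of `sigmoid_range` in Nx. It is a hypothetical helper whose name and arguments mirror fastai; it is not an Nx or Axon API:

```elixir
# Assumes Nx is available from the Mix.install cell at the top of the notebook.
# Hypothetical helper mirroring fastai's sigmoid_range: maps any real-valued
# prediction into the open interval (low, high).
sigmoid_range = fn x, low, high ->
  x
  |> Nx.sigmoid()
  |> Nx.multiply(high - low)
  |> Nx.add(low)
end

raw = Nx.tensor([-10.0, 0.0, 10.0])
# Very negative inputs land near 0.0, zero maps to the midpoint 2.75,
# and very positive inputs land just below 5.5.
squashed = sigmoid_range.(raw, 0.0, 5.5)
```

In the Axon model it could be appended after a per-example sum with `Axon.nx(&sigmoid_range.(&1, 0.0, 5.5))`. Swapping `Axon.Optimizers.sgd(0.029)` for `Axon.Optimizers.adam(5.0e-3)` would bring the setup closer to fastai's, though without the one-cycle schedule.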

And if I try my model on an example:

require Explorer.DataFrame

DataFrame.filter(df, col("movie") == 1)
#Explorer.DataFrame<
  Polars[452 x 4]
  user integer [308, 287, 148, 280, 66, ...]
  movie integer [1, 1, 1, 1, 1, ...]
  rating integer [4, 5, 4, 4, 3, ...]
  timestamp integer [887736532, 875334088, 877019411, 891700426, 883601324, ...]
>

So user 308 rated movie 1 with a 4; let's see how the model behaves:

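# Build the inference function and score user 308 on movie 1.
# Inputs are shaped {1, 1} to match the {batch_size, 1} layout used in training.
{_, predict_fn} = Axon.build(model)
predict_fn.(params, %{"user_input" => Nx.tensor([[308]]), "movie_input" => Nx.tensor([[1]])})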
#Nx.Tensor<
  f32
  EXLA.Backend<host:0, 0.4086315873.1394475020.247481>
  0.017103202641010284
>

Very far from the expected rating of 4...
