Skip to content

Instantly share code, notes, and snippets.

@cpcloud
Created June 26, 2023 19:18
Show Gist options
  • Save cpcloud/84a4e3eb4df56812db1fc488ff120cb8 to your computer and use it in GitHub Desktop.
Save cpcloud/84a4e3eb4df56812db1fc488ff120cb8 to your computer and use it in GitHub Desktop.
Ibis, DuckDB and PyTorch
import ibis.expr.datatypes as dt
import torch
import torch.nn as nn
import tqdm
import pyarrow as pa
class LinearRegression(nn.Module):
    """A single affine layer, ``y = x @ W.T + b``, packaged as a module.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of outputs produced per sample.
    """

    def __init__(self, input_dim, output_dim):
        super(LinearRegression, self).__init__()
        # The attribute is deliberately named ``linear`` so that the
        # state_dict keys remain ``linear.weight`` / ``linear.bias``.
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, distances):
        """Apply the affine layer; the last dim of ``distances`` must equal ``input_dim``."""
        layer = self.linear
        return layer(distances)
class PredictCabFare:
    """Fit a one-feature linear model of fare vs. trip distance and predict with it.

    Call :meth:`train` to fit the model; until then it predicts with randomly
    initialized weights. Instances are callable: they take a
    ``pyarrow.ChunkedArray`` of distances and return a ``pyarrow.Array`` of
    predicted fares.
    """

    def __init__(self, data, learning_rate: float = 0.01, epochs: int = 100) -> None:
        """Store the training data and hyperparameters and build an untrained model.

        Args:
            data: Training data indexable by ``"trip_distance"`` and
                ``"fare_amount"``. Each value must be accepted by
                ``torch.as_tensor`` (e.g. a tensor or NumPy array).
                NOTE(review): presumably one entry per trip in matching
                order — confirm against the caller.
            learning_rate: Step size for the SGD optimizer used by :meth:`train`.
            epochs: Number of full-batch gradient steps taken by :meth:`train`.
        """
        # One feature (distance) in, one target (fare) out.
        input_dim = 1
        output_dim = 1
        self.data = data
        # Create a linear regression model instance (weights start untrained).
        self.model = LinearRegression(input_dim, output_dim)
        self.learning_rate = learning_rate
        self.epochs = epochs

    def _as_model_input(self, values):
        """Convert *values* to a tensor matching the model's parameter dtype and device."""
        # nn.Linear raises on dtype mismatch (e.g. float64 data vs. float32
        # weights). as_tensor is a no-op when dtype/device already match.
        param = next(self.model.parameters())
        return torch.as_tensor(values, dtype=param.dtype, device=param.device)

    def train(self):
        """Fit ``self.model`` to ``self.data`` with full-batch SGD on mean-squared error.

        Runs ``self.epochs`` optimization steps. Every step uses all rows at once.
        """
        distances = self._as_model_input(self.data["trip_distance"]).reshape(-1, 1)
        fares = self._as_model_input(self.data["fare_amount"]).reshape(-1, 1)
        criterion = nn.MSELoss()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
        for _ in tqdm.trange(self.epochs):
            # Forward pass and loss over the whole dataset.
            y_pred = self.model(distances)
            loss = criterion(y_pred, fares)
            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def predict(self, input):
        """Return the model's output for *input* without tracking gradients.

        *input* is first converted to the model's dtype/device. Its last
        dimension must be 1, because the model has a single input feature.
        """
        with torch.no_grad():
            return self.model(self._as_model_input(input))

    def __call__(self, input: pa.ChunkedArray):
        """Predict fares for a pyarrow column of distances and return a ``pyarrow.Array``."""
        # .copy() because to_numpy() may return a read-only (zero-copy) array,
        # and torch.from_numpy warns that such a tensor has undefined behavior.
        distances = torch.from_numpy(input.to_numpy().copy())[:, None]
        predicted = self.predict(distances).ravel()
        # .cpu() is a no-op on CPU; it is needed if the model was moved to a GPU.
        return pa.array(predicted.cpu().numpy())
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment