Created
February 29, 2024 21:37
-
-
Save looselycoupled/428946520969112468437f496be8f9ac to your computer and use it in GitHub Desktop.
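Proof of concept that reimplements weighted random choice on top of Cloud Spanner; because the selection cannot be done entirely in the query, the rows are fetched and the pick is made client-side in Python.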
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import random
from pprint import pprint

from google.cloud import spanner
# Initialize the Spanner client. Each identifier can be overridden through an
# environment variable so the script can target another project, instance or
# database without being edited; the fallbacks are the original hard-coded
# values, so behavior is unchanged when the variables are unset.
spanner_client = spanner.Client(
    project=os.environ.get('SPANNER_PROJECT_ID', 'test-project')
)
instance = spanner_client.instance(
    os.environ.get('SPANNER_INSTANCE_ID', 'test-instance')
)
database = instance.database(
    os.environ.get('SPANNER_DATABASE_ID', 'test-database')
)
def create_table():
    """Create the ``employees`` table in the configured Spanner database.

    Submits one DDL statement through ``database.update_ddl`` and blocks until
    the returned operation completes. The statement uses ``IF NOT EXISTS`` so
    that calling this function again once the table is in place is a no-op
    rather than an error. Note that an already existing ``employees`` table
    with a different schema is left untouched.
    """
    ddl_statements = [
        """
        CREATE TABLE IF NOT EXISTS employees (
            employee_id INT64 NOT NULL,
            name STRING(1024),
            role STRING(1024),
            start_date DATE,
            salary INT64,
            weight FLOAT64
        ) PRIMARY KEY (employee_id)
        """
    ]
    # update_ddl hands back an operation object; result() waits for it to
    # finish before returning to the caller.
    database.update_ddl(ddl_statements).result()
# Insert the sample data into the table
def insert():
    """Write the three sample employee rows to ``employees`` in one batch.

    Uses ``insert_or_update`` so the function can be re-run: rows whose
    ``employee_id`` already exists are overwritten instead of making the
    batch fail, which is what the "inserted/updated" message promises.
    """
    with database.batch() as batch:
        batch.insert_or_update(
            table='employees',
            columns=(
                'employee_id', 'name', 'role', 'start_date', 'salary', 'weight',
            ),
            values=[
                (1, 'Alice', 'Engineer', '2021-05-01', 70000, 1.0),
                (2, 'Bob', 'Manager', '2021-06-15', 85000, 2.0),
                (3, 'Charlie', 'Manager', '2021-06-15', 85000, 2.0),
            ],
        )
    # Report success only after the batch context has exited without raising,
    # so the message is never printed for a write that did not go through.
    print("Table updated and records inserted/updated successfully.")
# Perform weighted random selection
def weighted_random_selection(weights):
    """Pick one item at random, with probability proportional to its weight.

    Args:
        weights: Iterable of ``(item, weight)`` pairs (any two-element
            sequence works, e.g. a ``[employee_id, weight]`` result row).
            It may be a one-shot iterator; it is materialized internally.

    Returns:
        The chosen item, or ``None`` when there is nothing to choose from
        (no pairs, or no pair with a weight greater than zero).

    Items whose weight is zero or negative are never selected and do not
    contribute to the total.
    """
    # The pairs are traversed twice (total, then scan), so copy them first;
    # otherwise an iterator would be exhausted by the summation.
    pairs = [(item, weight) for item, weight in weights]
    total = sum(weight for _, weight in pairs if weight > 0)
    if total <= 0:
        return None

    r = random.uniform(0, total)
    upto = 0
    chosen = None
    for item, weight in pairs:
        if weight <= 0:
            # Skipping here keeps a zero-weight item from winning when r
            # happens to land exactly on a boundary (e.g. r == 0).
            continue
        chosen = item
        upto += weight
        if upto >= r:
            return item
    # Floating-point rounding can leave the running sum marginally below r;
    # fall back to the last eligible item instead of returning None.
    return chosen
def main(num_draws=10):
    """Fetch employee weights from Spanner and print weighted random picks.

    Args:
        num_draws: How many independent selections to print (default 10,
            the count the script originally hard-coded).
    """
    # Query to fetch employees and their weights. The rows are materialized
    # into a list while the snapshot is still open so they can be reused for
    # every draw below.
    with database.snapshot() as snapshot:
        results = list(
            snapshot.execute_sql("SELECT employee_id, weight FROM employees")
        )
    pprint(results)

    # Perform weighted random selection
    for _ in range(num_draws):
        random_employee = weighted_random_selection(results)
        print(f"Randomly selected employee: {random_employee}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment