@lampinen-dm
Last active May 29, 2024 15:28
feature_complexity_bias_demo.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"name": "feature_complexity_bias_demo.ipynb",
"authorship_tag": "ABX9TyOn2XMcXioLAW16Ta0z5Zwi",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/lampinen-dm/b6541019ef4cf2988669ab44aa82460b/easy_vs_hard_feature_bias_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"Copyright 2024 Google LLC.\n",
"\n",
"SPDX-License-Identifier: Apache-2.0"
],
"metadata": {
"id": "yKrsMQs7aPqP"
}
},
{
"cell_type": "markdown",
"source": [
"# Representations are biased by feature complexity\n",
"\n",
"This colab is intended to provide a basic demonstration of the feature complexity bias we describe in \"Learned feature representations are biased by complexity, learning order, position, and more\" (https://arxiv.org/abs/2405.05847). It contains a simple implementation that captures most key features from the first binary feature + MLP experiments. Note that this colab does not attempt to reproduce those experiments exactly (e.g., this colab uses the Adam optimizer for quick experimentation, while the first experiments in the paper used SGD + early stopping), but rather it is intended to provide a simple illustrative example and starting point for further investigation.\n",
"\n",
"Some notes:\n",
"\n",
"* The Sum(W,X,Y,Z) % 2 function we describe in the paper is equivalent to the xor_xor_xor = XOR(XOR(W,X), XOR(Y,Z)) function used here.\n",
"* To plot the representation variance curves comparably to Fig. 2 in the paper, multiply the R^2 for each feature at step *t* by the ratio total_variance_*t* / total_variance_*final* to normalize by the final variance explained.\n",
"\n",
"\n",
"\n"
],
"metadata": {
"id": "mYkfQlD1ZkSl"
}
},
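{
"cell_type": "markdown",
"source": [
"The cell below is a small illustrative sketch of the two notes above (it is not part of the original experiments): it first verifies by brute force that Sum(W,X,Y,Z) % 2 equals XOR(XOR(W,X), XOR(Y,Z)), and then defines a hypothetical `normalize_r2` helper for the plotting normalization, where `r2_t`, `total_variance_t`, and `total_variance_final` stand for the per-step R^2 and total representation variances logged by the training loop later in this notebook."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import itertools\n",
"\n",
"# Brute-force check: the parity of four bits equals the nested XOR used for 'xor_xor_xor'.\n",
"for w, x, y, z in itertools.product([0, 1], repeat=4):\n",
"  assert (w + x + y + z) % 2 == ((w ^ x) ^ (y ^ z))\n",
"\n",
"\n",
"# Hypothetical helper (names are illustrative): rescale a step-t R^2 by the ratio of the\n",
"# total representation variance at step t to the final total variance, as described above.\n",
"def normalize_r2(r2_t, total_variance_t, total_variance_final):\n",
"  return r2_t * total_variance_t / total_variance_final"
],
"metadata": {},
"execution_count": null,
"outputs": []
},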
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vM9PavxxezSa"
},
"outputs": [],
"source": [
"import flax.linen as nn\n",
"from flax.training import train_state\n",
"import jax, jax.numpy as jnp\n",
"import numpy as np\n",
"import optax\n",
"from sklearn import linear_model"
]
},
{
"cell_type": "code",
"source": [
"NUM_SEEDS = 5\n",
"FEATURES = ('linear', 'xor_xor_xor') # features can be 'linear' or combine exactly three logical ops from ('and', 'or', 'xor'), e.g. 'and_and_xor' or 'xor_or_or'\n",
"TRAIN_DATASET_SIZE = 8192 # in which case the order is top_left_right for the function top(left(A,B), right(C,D)) where A-D are boolean inputs\n",
"TEST_DATASET_SIZE = 1024\n",
"BATCH_SIZE = 1024"
],
"metadata": {
"id": "3Q_fnX8NtGdi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# datasets"
],
"metadata": {
"id": "weNlnpQ3j3zt"
}
},
{
"cell_type": "code",
"source": [
"OP_FNS = {'and': np.logical_and, 'or': np.logical_or, 'xor': np.logical_xor}\n",
"\n",
"def make_easy_hard_multi_feature_dataset(features=('linear', 'xor_xor_xor'), input_units_per=16, num_examples=128, seed=123):\n",
" np.random.seed(seed)\n",
" inputs = np.random.binomial(1, 0.5, (num_examples, len(features) * input_units_per))\n",
" outputs = []\n",
" for feature_i, feature_type in enumerate(features):\n",
" these_inputs = inputs[:, feature_i * input_units_per:(feature_i + 1) * input_units_per,]\n",
" if feature_type == 'linear':\n",
" these_inputs[:, :4] = these_inputs[:, :1] # copy easy feature across 4 input units to match number of relevant inputs\n",
" these_outputs = these_inputs[:, :1]\n",
" else:\n",
" these_outputs = np.zeros_like(these_inputs[:, :1])\n",
" def label_fn(inputs):\n",
" top_op, left_op, right_op = [OP_FNS[op] for op in feature_type.split('_')]\n",
" return top_op(left_op(inputs[0], inputs[1]), right_op(inputs[2], inputs[3]))\n",
" for ex_number, ex_index in enumerate(np.random.permutation(num_examples)):\n",
" while label_fn(these_inputs[ex_index]) != ex_number % 2: # rejection sample\n",
" these_inputs[ex_index] = np.random.binomial(1, 0.5, input_units_per)\n",
" these_outputs[ex_index] = label_fn(these_inputs[ex_index])\n",
" outputs.append(these_outputs)\n",
" return {'inputs': inputs, 'labels': np.concatenate(outputs, axis=-1)}"
],
"metadata": {
"id": "eEgEH_Ey5GGp"
},
"execution_count": null,
"outputs": []
},
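{
"cell_type": "markdown",
"source": [
"As a quick sanity check (an illustrative sketch, not part of the original experiments), the cell below builds a tiny dataset and confirms that the labels follow the top(left(A,B), right(C,D)) naming convention described with `FEATURES` above. It assumes the import and dataset cells above have been run; the feature string `'and_or_xor'` and the names `demo_features` and `demo_ds` are hypothetical choices for this example."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"demo_features = ('linear', 'and_or_xor')  # hypothetical feature choice for illustration\n",
"demo_ds = make_easy_hard_multi_feature_dataset(features=demo_features, num_examples=64, seed=0)\n",
"ins, labs = demo_ds['inputs'], demo_ds['labels']\n",
"\n",
"# The 'linear' feature simply copies the first input unit of its block.\n",
"assert np.array_equal(labs[:, 0], ins[:, 0])\n",
"\n",
"# 'and_or_xor' means AND(OR(A,B), XOR(C,D)) on the first four units of the second block.\n",
"a, b, c, d = ins[:, 16], ins[:, 17], ins[:, 18], ins[:, 19]\n",
"expected = np.logical_and(np.logical_or(a, b), np.logical_xor(c, d))\n",
"assert np.array_equal(labs[:, 1].astype(bool), expected)\n",
"\n",
"# The rejection sampling exactly balances the labels of the logical feature.\n",
"print('label means:', labs.mean(axis=0))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},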
{
"cell_type": "markdown",
"source": [
"# network"
],
"metadata": {
"id": "9qaL7VzHQZyb"
}
},
{
"cell_type": "code",
"source": [
"class MLP(nn.Module):\n",
" layer_sizes: list[int]\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" for layer_size in self.layer_sizes[:-1]:\n",
" x = nn.leaky_relu(nn.Dense(layer_size, kernel_init=nn.initializers.variance_scaling(scale=1., mode='fan_avg', distribution='truncated_normal'))(x))\n",
" output = nn.Dense(self.layer_sizes[-1], kernel_init=nn.initializers.variance_scaling(scale=1., mode='fan_avg', distribution='truncated_normal'))(x)\n",
" return output, x"
],
"metadata": {
"id": "GTRIB9dWjE-A"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"@jax.jit\n",
"def apply_model(state, inputs, labels):\n",
" def loss_fn(params):\n",
" logits, penultimate_reps = state.apply_fn({'params': params}, inputs)\n",
" loss_array = np.mean(optax.sigmoid_binary_cross_entropy(logits=logits, labels=labels), axis=0)\n",
" return jnp.mean(loss_array), (logits, penultimate_reps, loss_array)\n",
" grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
" (loss, (logits, penultimate_reps, feature_losses)), grads = grad_fn(state.params)\n",
" accuracies = jnp.mean((logits > 0) * 1. == labels, axis=0)\n",
" return grads, loss, accuracies, penultimate_reps, feature_losses\n",
"\n",
"\n",
"@jax.jit\n",
"def update_model(state, grads):\n",
" return state.apply_gradients(grads=grads)\n",
"\n",
"\n",
"def train_epoch(state, train_ds):\n",
" this_data_order = np.random.permutation(train_ds['inputs'].shape[0])\n",
" train_ds = {k: v[this_data_order, :] for k, v in train_ds.items()}\n",
" for batch_index in range(0, TRAIN_DATASET_SIZE, BATCH_SIZE):\n",
" grads, *_ = apply_model(state, train_ds['inputs'][batch_index:batch_index + BATCH_SIZE], train_ds['labels'][batch_index:batch_index + BATCH_SIZE])\n",
" state = update_model(state, grads)\n",
" return state\n",
"\n",
"\n",
"def create_train_state(rng, num_features, num_input_units_per_feature=16):\n",
" network = MLP(layer_sizes=[256, 128, 64, 64, num_features])\n",
" params = network.init(rng, jnp.ones([1, num_input_units_per_feature * num_features]))['params']\n",
" return train_state.TrainState.create(apply_fn=network.apply, params=params, tx=optax.adam(1e-3))"
],
"metadata": {
"id": "EJYZk2Hflfhj"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# analyze"
],
"metadata": {
"id": "L_hhdDN6keFd"
}
},
{
"cell_type": "code",
"source": [
"def analyze_rep_var_explained(fit_reps, fit_labels, test_reps, test_labels):\n",
" scores = []\n",
" total_variance = np.sum(np.var(test_reps, axis=0)) # useful to normalize\n",
" for feat in range(fit_labels.shape[-1]):\n",
" regr = linear_model.LinearRegression()\n",
" regr.fit(fit_labels[:, feat:feat+1], fit_reps)\n",
" scores.append(regr.score(test_labels[:, feat:feat+1], test_reps))\n",
" return scores, total_variance"
],
"metadata": {
"id": "EqBuSB1KEKb5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# run"
],
"metadata": {
"id": "Zeg9TOhoRzD1"
}
},
{
"cell_type": "code",
"source": [
"features_str = '(' + ', '.join([str(x) for x in FEATURES]) + ')'\n",
"with open('./results.csv', 'w') as outfile:\n",
" output_labels = ['seed', 'epoch', 'features','train-loss', 'test-loss', 'total-variance'] + ['%s_feature%i-%s' % (stat, i, str(f)) for stat in ('train-acc', 'test-acc', 'train-loss', 'test-loss', 'rep-R2') for i,f in enumerate(FEATURES)]\n",
" outfile.write(', '.join(output_labels) + '\\n')\n",
" output_formats = ['%d', '%d', '%s'] + ['%.8f'] * (len(output_labels) - 3)\n",
" print_format = ', '.join([x + ': ' + y for x, y in zip(output_labels, output_formats)]).replace('%.8f', '%.4f')\n",
" output_format = ', '.join(output_formats) + '\\n'\n",
" for seed in range(NUM_SEEDS):\n",
" state = create_train_state(jax.random.key(123 + seed), num_features=len(FEATURES))\n",
" rng = np.random.default_rng(123 + seed)\n",
" train_ds = make_easy_hard_multi_feature_dataset(features=FEATURES, num_examples=TRAIN_DATASET_SIZE, seed=123 + seed)\n",
" val_rep_ds = make_easy_hard_multi_feature_dataset(features=FEATURES, num_examples=TEST_DATASET_SIZE, seed=1234 + seed)\n",
" test_ds = make_easy_hard_multi_feature_dataset(features=FEATURES, num_examples=TEST_DATASET_SIZE, seed=12345 + seed)\n",
" def _do_eval():\n",
" _, train_loss, train_accuracies, train_reps, train_feature_losses = apply_model(state, train_ds['inputs'], train_ds['labels'])\n",
" _, _, _, extra_reps, _ = apply_model(state, val_rep_ds['inputs'], val_rep_ds['labels'])\n",
" _, test_loss, test_accuracies, test_reps, test_feature_losses = apply_model(state, test_ds['inputs'], test_ds['labels'])\n",
" variance_scores, total_variance = analyze_rep_var_explained(extra_reps, val_rep_ds['labels'], test_reps, test_ds['labels'])\n",
" these_results = (seed, epoch, features_str, train_loss, test_loss, total_variance, *train_accuracies, *test_accuracies, *train_feature_losses, *test_feature_losses, *variance_scores)\n",
" print(print_format % these_results, flush=True)\n",
" outfile.write(output_format % these_results)\n",
" return variance_scores\n",
" for epoch in range(0, 101):\n",
" variance_scores = _do_eval()\n",
" state = train_epoch(state, train_ds)\n",
" print(f\"Seed {seed} representation variance at end of training:\\n\" + \"\\n\".join([f'{fe}: {sc}' for fe, sc in zip(FEATURES, variance_scores)]))\n"
],
"metadata": {
"id": "0DcKuTvWR1FV"
},
"execution_count": null,
"outputs": []
}
]
}