Skip to content

Instantly share code, notes, and snippets.

@buttercutter
Last active May 13, 2024 03:45
Show Gist options
  • Save buttercutter/b3331ca1fd9e2f5871b0eded6b758f39 to your computer and use it in GitHub Desktop.
Save buttercutter/b3331ca1fd9e2f5871b0eded6b758f39 to your computer and use it in GitHub Desktop.
Mamba: Linear-Time Sequence Modeling with Selective State Spaces
# [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
from einops import rearrange, repeat
from tqdm import tqdm
import math
import os
import urllib.request
from zipfile import ZipFile
from transformers import AutoTokenizer
# Debug switch: when truthy, train()/evaluate() print tensor statistics and
# gradient norms (it is passed to them as DEBUGGING_IS_ON).
debugging_is_on = 0
# Autograd anomaly detection adds extra checks and traceback recording to every
# backward pass and slows training considerably, so only enable it while debugging.
torch.autograd.set_detect_anomaly(bool(debugging_is_on))
def print_tensor_info(tensor_name, tensor):
    """Print ``tensor`` followed by one ``key: value`` line per summary statistic.

    The statistics are the shape, the (min, max) pair, the mean and the std.
    Mean and std are taken on a float copy when ``tensor`` is not already
    floating point; min/max use the tensor as given. Every statistic is
    computed before anything is printed.
    """
    numeric = tensor if tensor.is_floating_point() else tensor.float()
    lowest, highest = tensor.min().item(), tensor.max().item()
    summary = [
        ("shape", tuple(tensor.shape)),
        ("min/max", (lowest, highest)),
        ("mean", numeric.mean().item()),
        ("std", numeric.std().item()),
    ]
    print(f"{tensor_name} = {tensor}")
    for label, stat in summary:
        print(f"{label}: {stat}")
# Model selection: exactly one of USE_MAMBA / USE_TRANSFORMER is truthy.
USE_MAMBA = 1
# Logical NOT. The previous `~USE_MAMBA` was bitwise NOT (~1 == -2, ~0 == -1),
# which is truthy for either value of USE_MAMBA.
USE_TRANSFORMER = int(not USE_MAMBA)
# Experimental: keep an S6 hidden state across batches (used by S6.forward and train()).
DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# User hyperparameters
d_model = 16
state_size = 64  # Example state size
seq_len = 100  # Example sequence length
batch_size = 128  # Example batch size
class S6(nn.Module):
    """Selective state-space layer (S6), following Algorithm 2 of the Mamba paper.

    The input ``x`` is projected into three input-dependent quantities:
    a step size ``delta`` (``fc1`` then softplus), an input matrix ``B``
    (``fc2``) and an output matrix ``C`` (``fc3``). The diagonal state matrix
    ``A = -exp(A_log)`` is discretized (see ``discretization``). The recurrence

        h_t = dA_t * h_{t-1} + dB_t * x_t
        y_t = sum_n C_t[n] * h_t[:, n]

    is then evaluated along the sequence dimension.

    Args:
        seq_len: sequence length. It is only used to size ``self.h`` for the
            experimental DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM mode.
        d_model: number of channels in the last dimension of ``x``.
        state_size: number of state entries per channel.
        device: device on which parameters and buffers are created.
    """

    def __init__(self, seq_len, d_model, state_size, device):
        super(S6, self).__init__()
        self.fc1 = nn.Linear(d_model, d_model, device=device)  # -> delta
        self.fc2 = nn.Linear(d_model, state_size, device=device)  # -> B
        self.fc3 = nn.Linear(d_model, state_size, device=device)  # -> C
        self.seq_len = seq_len
        self.d_model = d_model
        self.state_size = state_size

        # S4D real initialization, MAMBA removed imaginary portions for S4D-Inv and S4D-Lin initialization schemes
        # described in [On the Parameterization and Initialization of Diagonal State Space Models](https://arxiv.org/abs/2206.11893)
        # https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/modules/mamba_simple.py#L103-L108C23
        A = repeat(
            torch.arange(1, state_size + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_model,
        ).contiguous()
        # Learn log(A) so that A = -exp(A_log) (see discretization) stays strictly negative.
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        self.A = torch.zeros_like(self.A_log)

        # B, C, delta, dA, dB and y are all recomputed in forward()/discretization()
        # before they are read, so they are not pre-allocated here. The previous
        # [batch, seq_len, d_model, state_size] buffers were overwritten unused.
        self.B = None
        self.C = None
        self.delta = None
        self.dA = None
        self.dB = None
        self.y = None

        # Persistent state for DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM only,
        # sized with the module-level `batch_size`.
        self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)

    def inverse_softplus(self, y):
        """Return ``z`` such that ``softplus(z) == y`` (for ``y > 0``).

        This is computed as ``y + log(-expm1(-y))``, which is algebraically equal
        to ``log(exp(y) - 1)``, so that large ``y`` does not overflow ``exp``.
        """
        return y + torch.log(-torch.expm1(-y))

    def discretization(self):
        """Discretize (A, B) with step ``self.delta``; cache and return ``(dA, dB)``.

        ``dA = exp(delta * A)`` is the zero-order-hold rule for the diagonal
        ``A``. ``dB = delta * B`` is the simplified first-order form used
        instead of the exact ZOH expression. See the Mamba paper, Section C,
        "Mechanics of Selective SSMs", and
        https://studywolf.wordpress.com/tag/zero-order-hold/ for ZOH.

        ``delta = softplus(fc1(x))`` is set in ``forward``. Because it depends on
        the input, the discrete transition is input-dependent ("selective"):
        a large delta lets the current input overwrite the state, while a
        small delta preserves it.

        Returns:
            (dA, dB), each shaped [batch, seq_len, d_model, state_size]
            per the einsum subscripts below.
        """
        # For numerical stability during training process
        self.A = -torch.exp(self.A_log.float())  # (d_model, state_size)
        self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B)
        # https://github.com/state-spaces/mamba/blob/0131c1e94a46fc9f70bcfc9d57962963bb2f0b9e/mamba_ssm/modules/mamba_simple.py#L240
        self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A))
        return self.dA, self.dB

    @staticmethod
    def _selective_scan(dA, dB, C, x):
        """Run ``h_t = dA_t * h_{t-1} + dB_t * x_t`` and ``y_t = C_t . h_t`` over time.

        Args:
            dA, dB: [batch, seq_len, d_model, state_size]
            C: [batch, seq_len, state_size]
            x: [batch, seq_len, d_model]

        Returns:
            y: [batch, seq_len, d_model]. The state ``h`` starts at zero on every call.
        """
        batch, length, channels, state = dA.shape
        if length == 0:
            return x.new_zeros((batch, 0, channels))
        h = dA.new_zeros((batch, channels, state))
        outputs = []
        for t in range(length):
            # Rebind `h` instead of updating it in place so autograd can backprop
            # through every step (and into A_log via dA).
            h = dA[:, t] * h + dB[:, t] * x[:, t].unsqueeze(-1)
            outputs.append(torch.einsum('bn,bdn->bd', C[:, t], h))
        return torch.stack(outputs, dim=1)

    def forward(self, x):
        """Apply the selective SSM to ``x`` ([batch, seq_len, d_model]).

        Returns ``y`` with the same shape as ``x``.
        """
        # Refer to Algorithm 2 in the MAMBA paper
        self.B = self.fc2(x)
        self.C = self.fc3(x)
        # "a large ∆ resets the state `h` and focuses on the current input `x`,
        # while a small ∆ persists the state and ignores the current input."
        self.delta = F.softplus(self.fc1(x))
        # Uses ZOH as in MAMBA, Hungry Hippo still uses bilinear transform for discretization
        self.discretization()

        if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
            # Experimental mode: `self.h` is updated once per call and carried
            # across *batches* (train() is meant to copy `temp_buffer` back into
            # it). It is not a scan over time steps.
            global current_batch_size
            current_batch_size = x.shape[0]
            # Slice the persistent state when this batch is smaller than `batch_size`.
            h_prev = self.h[:current_batch_size, ...]
            h_new = self.dA * h_prev + rearrange(x, "b l d -> b l d 1") * self.dB
            # y needs to have a shape of [batch_size, seq_len, d_model]
            self.y = torch.einsum('bln,bldn->bld', self.C, h_new)
            global temp_buffer
            temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone()
            return self.y

        # Default: the true selective scan along the sequence dimension.
        # (Previously h was zeros, so dA * h vanished: no recurrence and no gradient to A_log.)
        return self._selective_scan(self.dA, self.dB, self.C, x)
class MambaBlock(nn.Module):
    """One Mamba block (Figure 3 of the paper).

    Pipeline: RMSNorm -> input projection -> causal depthwise Conv1d over time
    -> SiLU -> S6. The result is gated by ``SiLU(D(x))`` and then projected
    back to ``d_model``.

    The convolution runs along the sequence axis with one filter per channel
    and left-only padding, so the output at position t depends only on
    positions <= t. The previous ``Conv1d(seq_len, seq_len)`` treated time
    steps as channels and mixed in future tokens. Note that the conv weight
    shape differs from that version, so old checkpoints will not load into
    this block.

    Args:
        seq_len: sequence length, forwarded to S6.
        d_model: model width; the inner width is ``2 * d_model``.
        state_size: S6 state size.
        device: device on which submodules are created.
        conv_kernel_size: temporal kernel size of the causal conv (default 3).
    """

    def __init__(self, seq_len, d_model, state_size, device, conv_kernel_size=3):
        super(MambaBlock, self).__init__()
        inner = 2 * d_model
        self.inp_proj = nn.Linear(d_model, inner, device=device)
        self.out_proj = nn.Linear(inner, d_model, device=device)
        # Gate branch ("residual skip connection" with multiplicative nonlinearity)
        self.D = nn.Linear(d_model, inner, device=device)
        # Set _no_weight_decay attribute on bias
        self.out_proj.bias._no_weight_decay = True
        # Initialize bias to a small constant value
        nn.init.constant_(self.out_proj.bias, 1.0)
        self.S6 = S6(seq_len, inner, state_size, device)
        # Depthwise conv over time. Padding k-1 on both sides plus trimming the tail
        # in _causal_conv yields a causal convolution.
        self.conv = nn.Conv1d(
            inner, inner,
            kernel_size=conv_kernel_size,
            padding=conv_kernel_size - 1,
            groups=inner,
            device=device,
        )
        # rmsnorm
        self.norm = RMSNorm(d_model, device=device)

    def _causal_conv(self, x):
        """Apply ``self.conv`` causally along the sequence axis of ``x`` [batch, seq, channels]."""
        length = x.size(1)
        # Conv1d expects [batch, channels, seq]; keep the first `length` outputs,
        # each of which only sees inputs at the same or earlier positions.
        y = self.conv(x.transpose(1, 2))[..., :length]
        return y.transpose(1, 2)

    def forward(self, x, attention_mask=None):
        """Map ``x`` [batch, seq, d_model] to a tensor of the same shape.

        ``attention_mask`` (optional, [batch, seq]) zeroes masked positions
        before normalization.
        """
        if attention_mask is not None:
            x = x * attention_mask.unsqueeze(-1)
        x = self.norm(x)
        x_proj = self.inp_proj(x)  # [batch, seq, 2*d_model]
        x_conv_act = F.silu(self._causal_conv(x_proj))
        x_ssm = self.S6(x_conv_act)
        # Multiplicative gate from the normalized input
        x_residual = F.silu(self.D(x))
        return self.out_proj(x_ssm * x_residual)
class Mamba(nn.Module):
    """Three stacked MambaBlocks followed by a linear output head.

    The head maps ``d_model`` features to ``vocab_size`` outputs; when
    ``vocab_size`` is None it maps back to ``d_model``. The optional
    ``attention_mask`` is handed unchanged to every block.
    """

    def __init__(self, seq_len, d_model, state_size, vocab_size, device):
        super().__init__()
        out_features = d_model if vocab_size is None else vocab_size
        self.mamba_block1 = MambaBlock(seq_len, d_model, state_size, device)
        self.mamba_block2 = MambaBlock(seq_len, d_model, state_size, device)
        self.mamba_block3 = MambaBlock(seq_len, d_model, state_size, device)
        self.final_proj = nn.Linear(d_model, out_features, device=device)

    def forward(self, x, attention_mask=None):
        hidden = x
        for block in (self.mamba_block1, self.mamba_block2, self.mamba_block3):
            hidden = block(hidden, attention_mask)
        return self.final_proj(hidden)
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    Computes ``x / sqrt(mean(x**2, -1) + eps) * weight``.

    Args:
        d_model: size of the last dimension (length of ``weight``).
        eps: added to the mean square before the reciprocal square root.
        device: device for ``weight``. When None, CUDA is used if available and
            CPU otherwise. The previous hard-coded ``'cuda'`` default crashed on
            CPU-only machines.
    """

    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5,
                 device: "str | torch.device | None" = None):
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model, device=device))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
# Example usage: forward-only smoke test of the Mamba model on a random input.
if USE_MAMBA:
    x = torch.rand(batch_size, seq_len, d_model, device=device)
    # Create the Mamba model
    mamba = Mamba(seq_len, d_model, state_size, None, device)
    # rmsnorm -- created on the same device as `x` (its default device may differ)
    norm = RMSNorm(d_model, device=device)
    # no_grad: this is only a shape check, so do not keep an autograd graph
    # alive in these module-level variables for the rest of the run.
    with torch.no_grad():
        x = norm(x)
        # Forward pass
        test_output = mamba(x)
    print(f"test_output.shape = {test_output.shape}")  # Should be [batch_size, seq_len, d_model]
class Enwiki8Dataset(Dataset):
    """Map-style dataset over a dict of tensors that share their first dimension.

    ``len()`` is the length of ``data['encoded_inputs']``. Indexing returns a
    dict with the same keys, holding a detached copy of row ``idx`` of each
    tensor, so callers may modify items without touching ``data``.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data['encoded_inputs'])

    def __getitem__(self, idx):
        sample = {}
        for field, column in self.data.items():
            sample[field] = column[idx].clone().detach()
        return sample
# Define a function for padding
def pad_sequences_3d(sequences, max_len=None, pad_value=0):
    """Right-pad a batch of sequences along dimension 1 (the sequence axis).

    Handles ``(batch, seq_len)`` token ids, ``(batch, seq_len, features)``
    embeddings, and any further trailing dimensions the same way.

    Args:
        sequences: tensor with at least 2 dims; dim 1 is the sequence axis.
        max_len: target length of dim 1. Defaults to ``seq_len + 1``, leaving
            room for the one position dropped by the next-token shift.
        pad_value: fill value for the padded positions.

    Returns:
        A new tensor of shape ``(batch, max_len, *rest)`` with the same dtype
        and device as ``sequences``.

    Raises:
        ValueError: if ``sequences`` has fewer than 2 dims or ``max_len < seq_len``.
    """
    if sequences.ndim < 2:
        raise ValueError(f"expected at least 2 dims (batch, seq_len), got shape {tuple(sequences.shape)}")
    batch_size, seq_len, *rest = sequences.shape
    if max_len is None:
        max_len = seq_len + 1
    if max_len < seq_len:
        raise ValueError(f"max_len={max_len} is shorter than the sequence length {seq_len}")
    # Initialize padded_sequences with the pad_value, then copy each sequence to the front
    padded_sequences = torch.full(
        (batch_size, max_len, *rest),
        fill_value=pad_value,
        dtype=sequences.dtype,
        device=sequences.device,
    )
    padded_sequences[:, :seq_len] = sequences
    return padded_sequences
def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):
    """Run one training epoch and return the mean per-batch loss.

    Each batch dict supplies ``'input_ids'`` (token ids, used as targets),
    ``'encoded_inputs'`` (the model inputs) and ``'attention_mask'``.
    Inputs and targets are shifted by one so position i is trained to predict
    token i + 1. Both are then padded back to the original length with
    ``tokenizer.pad_token_id``.

    Gradients are clipped one parameter at a time to ``max_grad_norm``,
    skipping parameters whose name contains ``'out_proj.bias'``.

    Returns:
        Mean of ``loss.item()`` over the batches (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in data_loader:
        optimizer.zero_grad()
        original_data = batch['input_ids'].to(device)  # data without downsized dimension
        input_data = batch['encoded_inputs'].to(device)  # data with downsized dimension for Mamba model
        attention_mask = batch['attention_mask'].to(device)
        # Next-token prediction: the target at position i is the token at i + 1.
        target = original_data[:, 1:]
        input_data = input_data[:, :-1]
        # Pad all the sequences in the batch back to the original length
        input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
        target = pad_sequences_3d(target, max_len=original_data.size(1), pad_value=tokenizer.pad_token_id)
        # For Mamba model, it can only accept downsized `input_data` due to RAM memory restriction
        # and already have a final_proj layer to upsize the `output` dimension to be the same as `target`
        output = model(input_data, attention_mask)
        # Use the model's own output width instead of the module-level vocab_size.
        loss = criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))
        # A fresh graph is built every step; retain_graph=True only kept buffers alive.
        loss.backward()
        # Clip gradients per parameter (in place), leaving out_proj biases unclipped
        for name, param in model.named_parameters():
            if 'out_proj.bias' not in name:
                torch.nn.utils.clip_grad_norm_(param, max_norm=max_grad_norm)
        if DEBUGGING_IS_ON:
            print("DEBUGGING IS ON !!!")
            print_tensor_info("output", output)
            print_tensor_info("target", target)
            for name, parameter in model.named_parameters():
                if parameter.grad is not None:
                    print(f"{name} gradient: {parameter.grad.data.norm(2)}")
                else:
                    print(f"{name} has no gradient")
        if USE_MAMBA and DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
            # update self.h from temp_buffer
            # NOTE(review): Mamba exposes its blocks as mamba_block1..3 and has no `S6`
            # attribute, so this line would raise AttributeError if the flag were enabled;
            # also `temp_buffer` only holds the last S6 call's state -- confirm intent.
            model.S6.h[:current_batch_size, ...].copy_(temp_buffer)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)
def evaluate(model, data_loader, criterion, device, DEBUGGING_IS_ON=False, pad_token_id=None):
    """Return the mean per-batch loss of ``model`` on ``data_loader``, without gradients.

    Batches are prepared as in ``train``. Inputs are shifted so position i
    predicts token i + 1, and inputs and targets are padded back to the
    original length.

    Args:
        pad_token_id: id used for that padding. Defaults to the module-level
            ``tokenizer.pad_token_id``, which this function previously read
            implicitly.

    Returns:
        Mean of ``loss.item()`` over the batches (0.0 for an empty loader).
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in data_loader:
            original_data = batch['input_ids'].to(device)  # data without downsized dimension
            input_data = batch['encoded_inputs'].to(device)  # data with downsized dimension for Mamba model
            attention_mask = batch['attention_mask'].to(device)
            # Next-token prediction: the target at position i is the token at i + 1.
            target = original_data[:, 1:]
            input_data = input_data[:, :-1]
            # Pad all the sequences in the batch back to the original length
            input_data = pad_sequences_3d(input_data, pad_value=pad_token_id)
            target = pad_sequences_3d(target, max_len=original_data.size(1), pad_value=pad_token_id)
            output = model(input_data, attention_mask)
            # Use the model's own output width instead of the module-level vocab_size.
            loss = criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))
            total_loss += loss.item()
            num_batches += 1
            if DEBUGGING_IS_ON:
                print("DEBUGGING IS ON !!!")
                print_tensor_info("output", output)
                print_tensor_info("target", target)
    return total_loss / max(num_batches, 1)
def calculate_perplexity(loss):
    """Return ``exp(loss)``, the perplexity for a mean cross-entropy loss in nats.

    Returns ``float('inf')`` instead of raising ``OverflowError`` when the
    loss is too large for ``math.exp`` (for example early in training or
    after divergence).
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return float('inf')
def load_enwiki8_dataset():
    """Return the UTF-8-decoded ``enwik8`` member of ``enwik8.zip``.

    The archive is downloaded from mattmahoney.net into the working directory
    only if ``enwik8.zip`` is not already there, so reruns skip the download.
    The download is written to a ``.part`` file and renamed only on success,
    so an interrupted download does not leave a truncated ``enwik8.zip`` that
    later runs would try to reuse.
    """
    print(f"Download and extract enwiki8 data")
    url = "http://mattmahoney.net/dc/enwik8.zip"
    zip_path = "enwik8.zip"
    if not os.path.exists(zip_path):
        part_path = zip_path + ".part"
        urllib.request.urlretrieve(url, part_path)
        os.replace(part_path, zip_path)
    with ZipFile(zip_path) as f:
        data = f.read("enwik8").decode("utf-8")
    return data
# Tokenize and encode the dataset
def encode_dataset(tokenizer, text_data):
    """Tokenize ``text_data`` and build the tensors used for training.

    ``text_data`` is sliced into chunks of 1000 elements (characters if it is
    a ``str``, items if it is a list of strings). Each chunk is passed to
    ``tokenizer`` with truncation and padding to the module-level ``seq_len``.
    As a side effect, the module-level ``vocab_size`` is set to
    ``len(tokenizer.vocab)``.

    Returns:
        ``(encoded_inputs, attention_mask, input_ids)``, where:

        - ``input_ids`` is the concatenated token-id tensor.
        - ``attention_mask`` is 1 wherever ``input_ids`` differs from
          ``tokenizer.pad_token_id``.
        - ``encoded_inputs`` is either:
            - (USE_MAMBA) ``input_ids`` looked up in a freshly initialized,
              untrained ``nn.Embedding(vocab_size, d_model)``, as float; or
            - (USE_TRANSFORMER) ``input_ids`` cast to long.

    Raises:
        ValueError: if ``text_data`` is empty or neither model flag is set.
    """
    if len(text_data) == 0:
        raise ValueError("text_data is empty; nothing to encode")

    def batch_encode(tokenizer, text_data, batch_size=1000):
        # Tokenize one slice at a time and concatenate the resulting id rows.
        batched_input_ids = []
        for i in range(0, len(text_data), batch_size):
            batch = text_data[i:i + batch_size]
            inputs = tokenizer(batch, add_special_tokens=True, truncation=True,
                               padding='max_length', max_length=seq_len,
                               return_tensors='pt')
            batched_input_ids.append(inputs['input_ids'])
        return torch.cat(batched_input_ids)

    # Encode the argument. This previously read the global `enwiki8_data`,
    # silently ignoring `text_data`.
    input_ids = batch_encode(tokenizer, text_data)

    # vocab_size is the number of unique tokens in the tokenizer's vocabulary
    global vocab_size
    vocab_size = len(tokenizer.vocab)  # Note that for some tokenizers, we might access the vocab directly
    print(f"vocab_size = {vocab_size}")

    # Embedding table local to this function (not part of the model), so the
    # embeddings are fixed random features of size d_model.
    embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def batch_embedding_calls(input_ids, embedding_layer, batch_size=256):
        # Embed `input_ids` in row chunks to bound peak memory. The result equals
        # embedding_layer(input_ids) computed in a single call.
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids, dtype=torch.long)
        with torch.no_grad():  # No need gradients for this operation
            chunks = [embedding_layer(rows) for rows in input_ids.split(batch_size)]
        return torch.cat(chunks, dim=0)

    if USE_MAMBA:
        # Chunking only bounds memory: 256 rows per call gives identical embeddings
        # and is far faster than the previous one-row-per-call loop.
        encoded_inputs = batch_embedding_calls(input_ids, embedding_layer, batch_size=256).float()
    elif USE_TRANSFORMER:
        encoded_inputs = input_ids.long()  # Cast input_ids to long if necessary
    else:
        raise ValueError("Neither USE_MAMBA nor USE_TRANSFORMER is set")

    attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.dtype)
    return encoded_inputs, attention_mask, input_ids
# Load a pretrained tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
# Use an existing special token as the padding token.
#tokenizer.pad_token = tokenizer.eos_token

# Cache of the tokenized dataset. It now stores encoded_inputs, attention_mask
# and input_ids together. The old format saved only encoded_inputs, which left
# `input_ids` / `attention_mask` undefined on a cache hit. Old-format (or
# incomplete) cache files are therefore ignored and the data is re-tokenized.
if USE_MAMBA:
    encoded_inputs_file = 'encoded_inputs_mamba.pt'
elif USE_TRANSFORMER:
    encoded_inputs_file = 'encoded_inputs_transformer.pt'
_cached = torch.load(encoded_inputs_file) if os.path.exists(encoded_inputs_file) else None
if isinstance(_cached, dict) and {'encoded_inputs', 'attention_mask', 'input_ids'} <= _cached.keys():
    print("Loading pre-tokenized data...")
    encoded_inputs = _cached['encoded_inputs']
    attention_mask = _cached['attention_mask']
    input_ids = _cached['input_ids']
else:
    print("Tokenizing raw data...")
    enwiki8_data = load_enwiki8_dataset()
    encoded_inputs, attention_mask, input_ids = encode_dataset(tokenizer, enwiki8_data)
    torch.save(
        {'encoded_inputs': encoded_inputs, 'attention_mask': attention_mask, 'input_ids': input_ids},
        encoded_inputs_file,
    )
    print(f"finished tokenizing data")
del _cached
# Bundle the per-sample tensors so they are split and indexed together.
data = dict(
    input_ids=input_ids,
    encoded_inputs=encoded_inputs,
    attention_mask=attention_mask,
)

# Sequential split: the first 80% of samples train, the remaining 20% validate.
total_size = len(data['encoded_inputs'])
train_size = int(total_size * 0.8)
train_data = {name: tensor[:train_size] for name, tensor in data.items()}
val_data = {name: tensor[train_size:] for name, tensor in data.items()}

train_dataset, val_dataset = Enwiki8Dataset(train_data), Enwiki8Dataset(val_data)
# Only the training loader shuffles; validation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Build the model selected by USE_MAMBA / USE_TRANSFORMER.
if USE_MAMBA:
    model = Mamba(seq_len, d_model, state_size, vocab_size, device).to(device)
elif USE_TRANSFORMER:
    from transformers import AutoModel

    # Pretrained tiny BERT encoder used as the transformer baseline.
    bert_model = AutoModel.from_pretrained("prajjwal1/bert-tiny").to(device)
    print(f"bert_model.config.hidden_size = {bert_model.config.hidden_size}")

    class NextTokenPredictor(nn.Module):
        """BERT encoder followed by a linear head producing per-position vocabulary logits.

        NOTE(review): BERT-style encoders attend in both directions, so a position
        may see later tokens; next-token losses from this baseline may be optimistic -- confirm.
        """

        def __init__(self, bert_model, vocab_size):
            super().__init__()
            self.bert = bert_model
            self.predictor = nn.Linear(bert_model.config.hidden_size, vocab_size)

        def forward(self, input_ids, attention_mask):
            hidden_states = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
            return self.predictor(hidden_states)

    model = NextTokenPredictor(bert_model, vocab_size).to(device)
# Define the loss function and optimizer.
# train()/evaluate() pad the shifted targets with tokenizer.pad_token_id. Ignoring
# that id keeps padding positions out of the loss and the reported perplexity.
# -100 (CrossEntropyLoss's default ignore_index) is used if the tokenizer has no pad token.
_pad_id = tokenizer.pad_token_id
criterion = nn.CrossEntropyLoss(ignore_index=_pad_id if _pad_id is not None else -100)
optimizer = optim.AdamW(model.parameters(), lr=5e-6)
# Training loop
num_epochs = 25  # Number of epochs to train for
for epoch in tqdm(range(num_epochs)):  # loop over the dataset multiple times
    train_loss = train(model, tokenizer, train_loader, optimizer, criterion, device, max_grad_norm=10.0, DEBUGGING_IS_ON=debugging_is_on)
    val_loss = evaluate(model, val_loader, criterion, device, DEBUGGING_IS_ON=debugging_is_on)
    val_perplexity = calculate_perplexity(val_loss)
    print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Perplexity: {val_perplexity:.4f}')
    # Cross-entropy is never negative, so a negative or NaN loss signals a numerical
    # problem: turn on the verbose tensor/gradient dumps for the following epochs.
    if not (train_loss >= 0 and val_loss >= 0):
        debugging_is_on = 1
test_output.shape = torch.Size([256, 100, 8])
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning:
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
warnings.warn(
tokenizer_config.json: 100%
28.0/28.0 [00:00<00:00, 474B/s]
config.json: 100%
570/570 [00:00<00:00, 15.4kB/s]
vocab.txt: 100%
232k/232k [00:00<00:00, 2.79MB/s]
tokenizer.json: 100%
466k/466k [00:00<00:00, 4.03MB/s]
Tokenizing raw data...
Download and extract enwiki8 data
vocab_size = 30522
finished tokenizing data
4%|▍ | 1/25 [01:19<31:51, 79.63s/it]
Streaming output truncated to the last 5000 lines.
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0247, 1.0138, 0.9932, ..., 1.0348, 1.0349, 0.9925],
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0247, 1.0138, 0.9932, ..., 1.0348, 1.0349, 0.9925],
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
...,
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0247, 1.0138, 0.9932, ..., 1.0348, 1.0349, 0.9925],
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0247, 1.0138, 0.9932, ..., 1.0348, 1.0349, 0.9925],
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9967, ..., 1.0035, 1.0033, 0.9953],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0032, 0.9992, ..., 1.0050, 1.0097, 0.9956],
[1.0247, 1.0138, 0.9932, ..., 1.0348, 1.0350, 0.9925],
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.963894784450531, 1.03498375415802)
mean: 1.0009433031082153
std: 0.007903832010924816
target = tensor([[[ 0.3699, 0.2790, 0.5842, ..., 1.5279, 0.0930, -1.2546],
[ 3.0019, -0.4078, -0.7056, ..., 1.3097, 0.8935, 1.5305],
[-0.0055, 1.4312, -0.2068, ..., 0.2403, 0.8108, -0.4160],
...,
[ 0.2078, 0.4916, -0.6117, ..., -0.0424, -0.4392, -1.6947],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 1.4162, 0.2253, -0.0672, ..., -0.9196, 0.7513, 0.9457],
[-0.7880, 0.3277, -0.4625, ..., 1.0912, 0.8847, 0.0261],
[-1.0726, -0.8486, 0.7417, ..., 0.2901, 0.5678, 0.3142],
...,
[ 0.1903, 0.7261, -1.3328, ..., -1.6171, 0.1211, -0.1400],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.4741, -0.1311, 2.4631, ..., 1.1667, 0.0434, 2.2398],
[-0.0729, -1.5737, 0.1047, ..., -1.7538, 0.6804, -0.4289],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
[ 2.8159, 0.9576, -0.3607, ..., 2.3174, -0.3391, -0.0629],
...,
[-1.0850, -0.6438, 0.2743, ..., -1.1642, -0.8742, -0.2776],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.9667, -1.2356, -0.0224, ..., -0.2647, -0.8683, 1.7923],
[-2.1527, -0.0821, -0.2856, ..., 0.0990, 1.7970, 0.9253],
[ 2.0079, 0.2899, 0.1462, ..., 0.0117, 0.1548, 0.9144],
...,
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.8051, -1.0119, 0.1091, ..., 1.4222, 0.1646, 0.0119],
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
[ 1.6541, 0.4031, -0.3804, ..., 0.1305, 0.6855, -0.8260],
...,
[ 0.9954, 0.6389, 0.7271, ..., -0.3038, 0.5158, -1.5865],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 3.9631638526916504)
mean: -0.039700813591480255
std: 0.99717116355896
mamba_block1.inp_proj.weight gradient: 6.656166988250334e-06
mamba_block1.inp_proj.bias gradient: 9.812881216930691e-06
mamba_block1.out_proj.weight gradient: 8.017977961571887e-05
mamba_block1.out_proj.bias gradient: 0.0019019206520169973
mamba_block1.D.weight gradient: 2.1973764887661673e-05
mamba_block1.D.bias gradient: 2.58996915363241e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 2.925547278209706e-06
mamba_block1.S6.fc1.bias gradient: 4.201373030809918e-06
mamba_block1.S6.fc2.weight gradient: 1.4111941709415987e-05
mamba_block1.S6.fc2.bias gradient: 2.2652086045127362e-05
mamba_block1.S6.fc3.weight gradient: 1.3704809134651441e-05
mamba_block1.S6.fc3.bias gradient: 2.2110638383310288e-05
mamba_block1.conv.weight gradient: 4.779999653692357e-05
mamba_block1.conv.bias gradient: 8.194298061425798e-06
mamba_block1.conv_linear.weight gradient: 1.5478439308935776e-05
mamba_block1.conv_linear.bias gradient: 5.2600811613956466e-05
mamba_block1.norm.weight gradient: 4.226156761433231e-06
mamba_block2.inp_proj.weight gradient: 0.007911593653261662
mamba_block2.inp_proj.bias gradient: 0.002791037317365408
mamba_block2.out_proj.weight gradient: 0.008094603195786476
mamba_block2.out_proj.bias gradient: 0.020823726430535316
mamba_block2.D.weight gradient: 0.004664940293878317
mamba_block2.D.bias gradient: 0.0016456706216558814
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0010565894190222025
mamba_block2.S6.fc1.bias gradient: 0.0013118531787768006
mamba_block2.S6.fc2.weight gradient: 0.0029101786203682423
mamba_block2.S6.fc2.bias gradient: 0.003688430180773139
mamba_block2.S6.fc3.weight gradient: 0.0027016273234039545
mamba_block2.S6.fc3.bias gradient: 0.0033678270410746336
mamba_block2.conv.weight gradient: 0.011834848672151566
mamba_block2.conv.bias gradient: 0.0008222330361604691
mamba_block2.conv_linear.weight gradient: 0.008866168558597565
mamba_block2.conv_linear.bias gradient: 0.005379044450819492
mamba_block2.norm.weight gradient: 0.0031838142313063145
mamba_block3.inp_proj.weight gradient: 0.08878087252378464
mamba_block3.inp_proj.bias gradient: 0.03128563240170479
mamba_block3.out_proj.weight gradient: 0.04247088357806206
mamba_block3.out_proj.bias gradient: 9.189469096781977e-08
mamba_block3.D.weight gradient: 0.04408084601163864
mamba_block3.D.bias gradient: 0.015569724142551422
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004259995650500059
mamba_block3.S6.fc1.bias gradient: 0.004275763873010874
mamba_block3.S6.fc2.weight gradient: 0.01696809008717537
mamba_block3.S6.fc2.bias gradient: 0.01751020923256874
mamba_block3.S6.fc3.weight gradient: 0.017805999144911766
mamba_block3.S6.fc3.bias gradient: 0.018047377467155457
mamba_block3.conv.weight gradient: 0.1599825918674469
mamba_block3.conv.bias gradient: 0.015678398311138153
mamba_block3.conv_linear.weight gradient: 0.0732274278998375
mamba_block3.conv_linear.bias gradient: 0.04735743626952171
mamba_block3.norm.weight gradient: 0.02320939488708973
DEBUGGING IS ON !!!
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0248, 1.0139, 0.9932, ..., 1.0349, 1.0351, 0.9925],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9952],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0248, 1.0138, 0.9932, ..., 1.0349, 1.0351, 0.9925],
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9952],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0248, 1.0138, 0.9932, ..., 1.0349, 1.0351, 0.9925],
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
...,
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0248, 1.0139, 0.9932, ..., 1.0349, 1.0351, 0.9925],
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0248, 1.0138, 0.9932, ..., 1.0349, 1.0351, 0.9925],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9952],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0248, 1.0138, 0.9932, ..., 1.0349, 1.0350, 0.9925],
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9638505578041077, 1.0351523160934448)
mean: 1.000943899154663
std: 0.007913809269666672
target = tensor([[[-8.3893e-01, -8.6427e-01, 4.2425e-01, ..., 5.8477e-01,
1.5457e+00, -4.3527e-01],
[-2.3506e+00, -9.3173e-01, -1.7008e-01, ..., -1.3117e+00,
1.3262e+00, -2.5985e-02],
[ 2.0079e+00, 2.8988e-01, 1.4619e-01, ..., 1.1677e-02,
1.5477e-01, 9.1439e-01],
...,
[-9.1471e-02, 1.2755e-01, 7.2934e-01, ..., 1.1558e+00,
-3.6694e-01, -2.0441e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 4.9737e-01, 7.9059e-01, -1.0028e+00, ..., 1.2410e+00,
-1.4670e+00, -1.0270e+00],
[ 1.6537e-02, 5.9942e-01, -1.0490e+00, ..., -1.0667e+00,
-1.8011e-01, -2.0437e-01],
[ 2.0079e+00, 2.8988e-01, 1.4619e-01, ..., 1.1677e-02,
1.5477e-01, 9.1439e-01],
...,
[ 6.8187e-01, 6.5033e-01, -7.8437e-02, ..., 1.2032e+00,
4.3723e-01, 5.0549e-02],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 9.2321e-01, 2.1170e-01, -5.1829e-01, ..., 7.5898e-01,
1.7760e+00, 9.7635e-01],
[-2.8159e-02, 8.7647e-01, 3.6170e-01, ..., -8.5379e-01,
5.3774e-01, -1.6134e+00],
[ 6.4561e-01, -1.7245e+00, -5.6855e-01, ..., -4.0166e-01,
-1.8768e+00, -1.1828e+00],
...,
[ 1.3604e+00, 8.3413e-01, 9.7125e-01, ..., -9.8477e-02,
-2.4212e-01, 6.4055e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
...,
[[ 6.4588e-01, -9.7711e-01, 1.4713e-01, ..., -1.7452e+00,
2.4286e-02, 1.4304e-02],
[ 4.1251e-01, 8.6066e-01, -2.1138e-01, ..., -5.0017e-03,
-4.6324e-02, -1.4117e+00],
[ 2.1325e+00, 6.8348e-02, 1.1581e+00, ..., 1.2571e+00,
4.6634e-01, -7.2127e-01],
...,
[ 2.1325e+00, 6.8348e-02, 1.1581e+00, ..., 1.2571e+00,
4.6634e-01, -7.2127e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01,
-3.2478e-01, -6.1987e-01],
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01,
-3.2478e-01, -6.1987e-01],
[-1.5360e+00, -4.3045e-01, -2.2538e-02, ..., -8.2286e-01,
2.1251e-01, 1.5091e-01],
...,
[-3.8774e-01, 3.0244e-01, -1.1404e+00, ..., 2.0661e+00,
-6.1905e-01, -9.3546e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 8.4960e-01, -4.3197e-01, 5.5274e-01, ..., -2.7416e-01,
2.0447e+00, -5.1754e-01],
[ 6.3896e-02, 3.2472e-04, -1.2828e+00, ..., -1.0525e+00,
-1.3741e+00, -1.5745e+00],
[ 6.8187e-01, 6.5033e-01, -7.8437e-02, ..., 1.2032e+00,
4.3723e-01, 5.0549e-02],
...,
[ 1.8655e-01, -3.5074e-01, 6.4411e-02, ..., 9.5573e-01,
1.1114e+00, -1.9372e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]]], device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 4.1066203117370605)
mean: -0.044864702969789505
std: 0.9986171722412109
mamba_block1.inp_proj.weight gradient: 6.592638783331495e-06
mamba_block1.inp_proj.bias gradient: 1.3115262845531106e-05
mamba_block1.out_proj.weight gradient: 8.0315483501181e-05
mamba_block1.out_proj.bias gradient: 0.001918964902870357
mamba_block1.D.weight gradient: 2.1023452063673176e-05
mamba_block1.D.bias gradient: 2.4453845981042832e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 2.9185762286942918e-06
mamba_block1.S6.fc1.bias gradient: 4.370540409581736e-06
mamba_block1.S6.fc2.weight gradient: 1.885054552985821e-05
mamba_block1.S6.fc2.bias gradient: 2.9134898795746267e-05
mamba_block1.S6.fc3.weight gradient: 1.7872369426186197e-05
mamba_block1.S6.fc3.bias gradient: 2.769949242065195e-05
mamba_block1.conv.weight gradient: 4.985975465388037e-05
mamba_block1.conv.bias gradient: 8.862216418492608e-06
mamba_block1.conv_linear.weight gradient: 1.7178435882669874e-05
mamba_block1.conv_linear.bias gradient: 5.5030737712513655e-05
mamba_block1.norm.weight gradient: 6.020811724738451e-06
mamba_block2.inp_proj.weight gradient: 0.008193454705178738
mamba_block2.inp_proj.bias gradient: 0.002890463685616851
mamba_block2.out_proj.weight gradient: 0.00791140180081129
mamba_block2.out_proj.bias gradient: 0.019019681960344315
mamba_block2.D.weight gradient: 0.0049378108233213425
mamba_block2.D.bias gradient: 0.0017419286305084825
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0010571378516033292
mamba_block2.S6.fc1.bias gradient: 0.0013405927456915379
mamba_block2.S6.fc2.weight gradient: 0.0029663534369319677
mamba_block2.S6.fc2.bias gradient: 0.003960576374083757
mamba_block2.S6.fc3.weight gradient: 0.002755584428086877
mamba_block2.S6.fc3.bias gradient: 0.00362204248085618
mamba_block2.conv.weight gradient: 0.011379457078874111
mamba_block2.conv.bias gradient: 0.0008499903487972915
mamba_block2.conv_linear.weight gradient: 0.009108995087444782
mamba_block2.conv_linear.bias gradient: 0.005850357934832573
mamba_block2.norm.weight gradient: 0.0032401597127318382
mamba_block3.inp_proj.weight gradient: 0.08919057995080948
mamba_block3.inp_proj.bias gradient: 0.03143063187599182
mamba_block3.out_proj.weight gradient: 0.04368537664413452
mamba_block3.out_proj.bias gradient: 8.888643066029545e-08
mamba_block3.D.weight gradient: 0.04513612762093544
mamba_block3.D.bias gradient: 0.015944337472319603
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004324209876358509
mamba_block3.S6.fc1.bias gradient: 0.004632828291505575
mamba_block3.S6.fc2.weight gradient: 0.017361413687467575
mamba_block3.S6.fc2.bias gradient: 0.016848554834723473
mamba_block3.S6.fc3.weight gradient: 0.018281958997249603
mamba_block3.S6.fc3.bias gradient: 0.017590906471014023
mamba_block3.conv.weight gradient: 0.16209356486797333
mamba_block3.conv.bias gradient: 0.015879524871706963
mamba_block3.conv_linear.weight gradient: 0.0732278898358345
mamba_block3.conv_linear.bias gradient: 0.04939649626612663
mamba_block3.norm.weight gradient: 0.022089680656790733
DEBUGGING IS ON !!!
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0249, 1.0138, 0.9933, ..., 1.0351, 1.0353, 0.9925],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949],
...,
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0249, 1.0138, 0.9932, ..., 1.0351, 1.0353, 0.9925],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9953],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0249, 1.0139, 0.9932, ..., 1.0350, 1.0352, 0.9925],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
...,
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9952],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0123, 1.0032, 0.9992, ..., 1.0050, 1.0098, 0.9956],
[1.0249, 1.0138, 0.9932, ..., 1.0351, 1.0353, 0.9925],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0033, 0.9952],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0249, 1.0139, 0.9932, ..., 1.0350, 1.0353, 0.9925],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]],
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0033, 0.9952],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0249, 1.0138, 0.9932, ..., 1.0351, 1.0353, 0.9925],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9638057351112366, 1.035346508026123)
mean: 1.0009446144104004
std: 0.00792401097714901
target = tensor([[[-1.1622e-01, 1.4332e+00, 8.4441e-01, ..., 6.9435e-04,
-6.6773e-02, 1.0834e-01],
[ 5.7104e-01, -4.6999e-01, 1.1255e+00, ..., -8.9141e-01,
1.4730e+00, -9.9213e-02],
[ 1.0452e+00, 7.1647e-01, 6.4485e-02, ..., 1.4146e-01,
1.8992e-01, -1.2258e+00],
...,
[-4.3337e-01, -1.1911e-01, 1.6830e+00, ..., 1.7715e+00,
2.0065e-01, -1.6473e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 5.0676e-01, -9.3567e-01, 5.5966e-01, ..., -2.6422e-01,
1.4782e+00, 2.3104e+00],
[ 1.4765e+00, -7.6907e-01, 4.0878e-01, ..., 9.7170e-01,
7.2011e-01, 5.5136e-01],
[ 6.6372e-01, 1.7408e-01, -3.5581e-01, ..., -1.4354e+00,
1.1672e+00, 4.4820e-02],
...,
[ 6.6372e-01, 1.7408e-01, -3.5581e-01, ..., -1.4354e+00,
1.1672e+00, 4.4820e-02],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-3.9287e-01, 1.2244e+00, 1.4819e+00, ..., -8.4328e-01,
-1.3749e+00, -6.7026e-01],
[ 4.9737e-01, 7.9059e-01, -1.0028e+00, ..., 1.2410e+00,
-1.4670e+00, -1.0270e+00],
[ 1.1287e+00, -4.2922e-01, -6.2596e-01, ..., 4.3149e-03,
-1.7797e+00, -1.4768e+00],
...,
[ 6.8187e-01, 6.5033e-01, -7.8437e-02, ..., 1.2032e+00,
4.3723e-01, 5.0549e-02],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
...,
[[-5.5460e-01, -1.8933e-02, 8.0374e-01, ..., 9.7693e-01,
4.5635e-01, -1.4246e+00],
[ 1.5424e+00, -8.6155e-01, -1.6940e+00, ..., -1.3017e+00,
-4.4700e-01, -1.3483e+00],
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
...,
[-4.3280e-01, 4.2621e-01, -7.8516e-01, ..., 3.9015e-01,
8.2322e-01, -1.1738e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-6.0796e-03, 7.3052e-02, 1.9578e-01, ..., -5.9691e-01,
-9.9734e-01, -2.2435e+00],
[-3.7331e-01, 1.8360e+00, -1.2402e+00, ..., 1.2983e+00,
-6.1130e-01, -2.7833e-01],
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
...,
[-3.0289e-01, -7.8487e-01, 6.5365e-01, ..., 2.1631e-02,
-5.1024e-02, 1.3417e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00,
1.8470e+00, 7.8808e-01],
...,
[-8.3236e-01, -4.5072e-01, 2.3980e-01, ..., 7.7698e-01,
-1.6973e+00, -1.6883e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]]], device='cuda:0')
shape: (256, 100, 8)
min/max: (-3.7171173095703125, 4.436890602111816)
mean: -0.04337029531598091
std: 0.996184229850769
mamba_block1.inp_proj.weight gradient: 8.927928320190404e-06
mamba_block1.inp_proj.bias gradient: 1.3252603821456432e-05
mamba_block1.out_proj.weight gradient: 6.981792103033513e-05
mamba_block1.out_proj.bias gradient: 0.0018748645670711994
mamba_block1.D.weight gradient: 1.9445253201411106e-05
mamba_block1.D.bias gradient: 2.4798053345875815e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 2.8577871944435174e-06
mamba_block1.S6.fc1.bias gradient: 4.1480857362330426e-06
mamba_block1.S6.fc2.weight gradient: 2.404704355285503e-05
mamba_block1.S6.fc2.bias gradient: 3.7825808249181136e-05
mamba_block1.S6.fc3.weight gradient: 2.3223990865517408e-05
mamba_block1.S6.fc3.bias gradient: 3.674719846458174e-05
mamba_block1.conv.weight gradient: 5.1158986025257036e-05
mamba_block1.conv.bias gradient: 1.1530260053405073e-05
mamba_block1.conv_linear.weight gradient: 1.925104697875213e-05
mamba_block1.conv_linear.bias gradient: 6.093188130762428e-05
mamba_block1.norm.weight gradient: 5.640730250888737e-06
mamba_block2.inp_proj.weight gradient: 0.008388090878725052
mamba_block2.inp_proj.bias gradient: 0.00295907910913229
mamba_block2.out_proj.weight gradient: 0.008359096944332123
mamba_block2.out_proj.bias gradient: 0.018886201083660126
mamba_block2.D.weight gradient: 0.005083959549665451
mamba_block2.D.bias gradient: 0.0017934782663360238
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0011655916459858418
mamba_block2.S6.fc1.bias gradient: 0.0014712277334183455
mamba_block2.S6.fc2.weight gradient: 0.0031938902102410793
mamba_block2.S6.fc2.bias gradient: 0.004056679084897041
mamba_block2.S6.fc3.weight gradient: 0.0029626914765685797
mamba_block2.S6.fc3.bias gradient: 0.0037037297151982784
mamba_block2.conv.weight gradient: 0.01159916166216135
mamba_block2.conv.bias gradient: 0.00081581249833107
mamba_block2.conv_linear.weight gradient: 0.009459671564400196
mamba_block2.conv_linear.bias gradient: 0.0063721355982124805
mamba_block2.norm.weight gradient: 0.0032281361054629087
mamba_block3.inp_proj.weight gradient: 0.09379231184720993
mamba_block3.inp_proj.bias gradient: 0.03305402398109436
mamba_block3.out_proj.weight gradient: 0.0430489145219326
mamba_block3.out_proj.bias gradient: 5.539216729744112e-08
mamba_block3.D.weight gradient: 0.03742552548646927
mamba_block3.D.bias gradient: 0.013216960243880749
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.00466371001675725
mamba_block3.S6.fc1.bias gradient: 0.005200730636715889
mamba_block3.S6.fc2.weight gradient: 0.01829967088997364
mamba_block3.S6.fc2.bias gradient: 0.018736790865659714
mamba_block3.S6.fc3.weight gradient: 0.01904943399131298
mamba_block3.S6.fc3.bias gradient: 0.019463086500763893
mamba_block3.conv.weight gradient: 0.16161637008190155
mamba_block3.conv.bias gradient: 0.01583053544163704
mamba_block3.conv_linear.weight gradient: 0.07311218231916428
mamba_block3.conv_linear.bias gradient: 0.043246448040008545
mamba_block3.norm.weight gradient: 0.02208569645881653
DEBUGGING IS ON !!!
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0250, 1.0138, 0.9933, ..., 1.0352, 1.0355, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9992, ..., 1.0050, 1.0098, 0.9956],
[1.0250, 1.0139, 0.9933, ..., 1.0352, 1.0354, 0.9926],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0250, 1.0139, 0.9933, ..., 1.0351, 1.0354, 0.9925],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]],
...,
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0250, 1.0139, 0.9932, ..., 1.0352, 1.0354, 0.9925],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0031, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0250, 1.0139, 0.9933, ..., 1.0352, 1.0354, 0.9926],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956],
[1.0250, 1.0139, 0.9933, ..., 1.0351, 1.0354, 0.9925],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9637707471847534, 1.0355018377304077)
mean: 1.0009452104568481
std: 0.00793476589024067
target = tensor([[[ 1.2582e+00, 6.2747e-01, -1.9484e+00, ..., -7.7599e-01,
1.0496e+00, 5.3618e-01],
[-1.3360e+00, 1.4383e-01, 1.7031e+00, ..., -1.1077e+00,
8.4779e-01, -3.4812e-01],
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
...,
[-1.0850e+00, -6.4378e-01, 2.7434e-01, ..., -1.1642e+00,
-8.7424e-01, -2.7755e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 1.1027e+00, 7.7860e-01, 1.2513e+00, ..., -2.4502e-01,
3.2866e-01, -1.6867e+00],
[ 2.5271e+00, 3.8280e-01, 4.4642e-01, ..., 1.7231e-01,
-5.7369e-01, 2.5980e+00],
[ 1.2286e+00, -1.6981e+00, 1.1041e-01, ..., -1.3688e+00,
-4.3559e-01, 2.2583e-01],
...,
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01,
-3.2478e-01, -6.1987e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 5.0676e-01, -9.3567e-01, 5.5966e-01, ..., -2.6422e-01,
1.4782e+00, 2.3104e+00],
[-5.8164e-01, 7.9681e-03, 1.8231e+00, ..., -1.1851e+00,
4.1620e-01, -2.9570e-03],
[ 9.3353e-01, 1.8774e-01, -2.0042e+00, ..., -1.1503e+00,
-1.7980e+00, -5.6396e-01],
...,
[ 6.3467e-01, -6.0116e-01, 3.4803e-01, ..., 1.5082e+00,
-9.4524e-01, 2.0558e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
...,
[[-4.1120e-01, 2.8092e-02, 5.1873e-01, ..., 6.3312e-01,
3.6938e-01, 1.5776e-01],
[ 2.9367e-01, 2.7959e+00, -1.3492e+00, ..., -1.4478e+00,
-5.1723e-01, 8.9243e-01],
[ 2.6099e-02, -1.0027e-01, -1.5132e+00, ..., -3.9709e-02,
1.1886e-01, 1.3587e+00],
...,
[-1.0028e+00, 1.3193e+00, 1.1326e+00, ..., 1.1135e+00,
-2.1063e+00, -1.4438e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 4.9737e-01, 7.9059e-01, -1.0028e+00, ..., 1.2410e+00,
-1.4670e+00, -1.0270e+00],
[-1.0945e-01, 1.1213e+00, 2.1538e-01, ..., -6.1082e-01,
1.7132e-01, -1.0861e+00],
[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00,
-9.0730e-01, 4.8345e-02],
...,
[-3.8774e-01, 3.0244e-01, -1.1404e+00, ..., 2.0661e+00,
-6.1905e-01, -9.3546e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 9.2342e-01, -9.4206e-01, 1.0451e+00, ..., -4.7812e-01,
4.2050e-02, 2.9336e-01],
[-4.0796e-01, 1.0457e+00, 1.2001e-02, ..., -1.2754e-01,
2.3795e+00, 1.2947e-01],
[-8.9576e-01, -1.3298e+00, 4.7374e-01, ..., -2.1709e-01,
-5.5530e-02, -4.8000e-01],
...,
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]]], device='cuda:0')
shape: (256, 100, 8)
min/max: (-3.8195555210113525, 4.436890602111816)
mean: -0.03474504500627518
std: 0.9982456564903259
mamba_block1.inp_proj.weight gradient: 6.786232461308828e-06
mamba_block1.inp_proj.bias gradient: 1.2180601515865419e-05
mamba_block1.out_proj.weight gradient: 6.749451131327078e-05
mamba_block1.out_proj.bias gradient: 0.0017811341676861048
mamba_block1.D.weight gradient: 1.8621594790602103e-05
mamba_block1.D.bias gradient: 2.298756589880213e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 2.4946195935626747e-06
mamba_block1.S6.fc1.bias gradient: 3.2899588404688984e-06
mamba_block1.S6.fc2.weight gradient: 1.3449737707560416e-05
mamba_block1.S6.fc2.bias gradient: 2.101602149195969e-05
mamba_block1.S6.fc3.weight gradient: 1.3065160601399839e-05
mamba_block1.S6.fc3.bias gradient: 2.0496405340963975e-05
mamba_block1.conv.weight gradient: 4.914818055112846e-05
mamba_block1.conv.bias gradient: 8.672555850353092e-06
mamba_block1.conv_linear.weight gradient: 1.5568120943498798e-05
mamba_block1.conv_linear.bias gradient: 4.648379763239063e-05
mamba_block1.norm.weight gradient: 3.9241208469320554e-06
mamba_block2.inp_proj.weight gradient: 0.008072205819189548
mamba_block2.inp_proj.bias gradient: 0.002847641473636031
mamba_block2.out_proj.weight gradient: 0.008008691482245922
mamba_block2.out_proj.bias gradient: 0.020124541595578194
mamba_block2.D.weight gradient: 0.0048057264648377895
mamba_block2.D.bias gradient: 0.0016953115118667483
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0010728761553764343
mamba_block2.S6.fc1.bias gradient: 0.0013542993692681193
mamba_block2.S6.fc2.weight gradient: 0.003033427754417062
mamba_block2.S6.fc2.bias gradient: 0.004091166891157627
mamba_block2.S6.fc3.weight gradient: 0.002817986300215125
mamba_block2.S6.fc3.bias gradient: 0.003747538896277547
mamba_block2.conv.weight gradient: 0.011370806023478508
mamba_block2.conv.bias gradient: 0.0008210184169001877
mamba_block2.conv_linear.weight gradient: 0.008997836150228977
mamba_block2.conv_linear.bias gradient: 0.006177668925374746
mamba_block2.norm.weight gradient: 0.003070503007620573
mamba_block3.inp_proj.weight gradient: 0.09553761035203934
mamba_block3.inp_proj.bias gradient: 0.033675868064165115
mamba_block3.out_proj.weight gradient: 0.043427761644124985
mamba_block3.out_proj.bias gradient: 6.981348121826159e-08
mamba_block3.D.weight gradient: 0.044211406260728836
mamba_block3.D.bias gradient: 0.01561672892421484
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.00466552609577775
mamba_block3.S6.fc1.bias gradient: 0.005186178721487522
mamba_block3.S6.fc2.weight gradient: 0.01796438917517662
mamba_block3.S6.fc2.bias gradient: 0.020860623568296432
mamba_block3.S6.fc3.weight gradient: 0.018770582973957062
mamba_block3.S6.fc3.bias gradient: 0.02151082083582878
mamba_block3.conv.weight gradient: 0.16399376094341278
mamba_block3.conv.bias gradient: 0.015978528186678886
mamba_block3.conv_linear.weight gradient: 0.07053054869174957
mamba_block3.conv_linear.bias gradient: 0.04882459342479706
mamba_block3.norm.weight gradient: 0.022343263030052185
DEBUGGING IS ON !!!
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9925],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9926],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9926],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]],
...,
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9992, ..., 1.0050, 1.0098, 0.9956],
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9926],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0031, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9637242555618286, 1.0356793403625488)
mean: 1.0009459257125854
std: 0.007945802994072437
target = tensor([[[-1.0850e+00, -6.4378e-01, 2.7434e-01, ..., -1.1642e+00,
-8.7424e-01, -2.7755e-01],
[-3.2981e-01, -6.2568e-01, 7.4563e-01, ..., -2.8829e+00,
-2.6204e+00, 1.0786e+00],
[ 8.9414e-01, -2.4687e+00, 5.5291e-01, ..., 1.8136e-02,
2.4835e-01, 5.5237e-02],
...,
[ 4.7191e-01, 3.6167e-01, -3.5786e-01, ..., -3.8691e-01,
1.6128e+00, 2.4838e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00,
-9.0730e-01, 4.8345e-02],
[ 1.3781e-01, -3.8981e-01, 4.6194e-01, ..., 1.9883e-01,
-3.7158e-01, 3.5527e-01],
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
...,
[-7.6430e-01, -1.8293e+00, 3.4729e-01, ..., -1.8000e-02,
-4.8519e-01, -4.4253e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-2.2150e-02, -1.3493e+00, -6.6053e-01, ..., 9.5004e-01,
-2.8410e-01, 1.1236e-01],
[ 7.7266e-01, -1.2528e-01, -5.1251e-01, ..., -9.5071e-01,
1.0857e+00, 6.4368e-01],
[-2.7640e-01, 1.4894e+00, 1.4303e-01, ..., -2.2086e-01,
2.4025e+00, 8.1037e-01],
...,
[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00,
1.8470e+00, 7.8808e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
...,
[[ 2.9367e-01, 2.7959e+00, -1.3492e+00, ..., -1.4478e+00,
-5.1723e-01, 8.9243e-01],
[-5.3807e-01, 1.8558e+00, -1.3125e+00, ..., -2.1141e+00,
-5.7919e-01, -8.3718e-02],
[-1.0806e-01, 8.7904e-01, 6.7809e-01, ..., -5.8664e-01,
-1.6239e-01, 4.4618e-01],
...,
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01,
-3.2478e-01, -6.1987e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-9.1818e-01, 9.8258e-02, -4.3746e-01, ..., 1.4176e-01,
-7.2111e-01, -1.5051e+00],
[-2.6851e-01, -6.7224e-01, -6.0742e-01, ..., 3.7681e-01,
-1.0639e+00, -1.6735e+00],
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
...,
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01,
-3.2478e-01, -6.1987e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 8.2675e-01, -7.5806e-01, -1.7703e+00, ..., -1.0994e+00,
5.3112e-02, -9.7968e-01],
[ 1.1250e+00, 1.1728e+00, 5.4517e-01, ..., -1.0478e+00,
4.3682e-01, 1.5019e+00],
[-1.8971e-01, 2.4852e-01, 1.0079e+00, ..., -1.3113e-01,
-7.4732e-01, 1.3381e+00],
...,
[ 2.4386e-01, -1.2999e-01, -1.2611e+00, ..., -5.8786e-01,
-1.3674e-02, -1.0314e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]]], device='cuda:0')
shape: (256, 100, 8)
min/max: (-3.7811129093170166, 3.9631638526916504)
mean: -0.037468329071998596
std: 0.9972341060638428
mamba_block1.inp_proj.weight gradient: 7.708286830165889e-06
mamba_block1.inp_proj.bias gradient: 1.3073607078695204e-05
mamba_block1.out_proj.weight gradient: 8.224073826568201e-05
mamba_block1.out_proj.bias gradient: 0.0018965421477332711
mamba_block1.D.weight gradient: 2.0551829948090017e-05
mamba_block1.D.bias gradient: 2.6780233383760788e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.2340644793293905e-06
mamba_block1.S6.fc1.bias gradient: 4.6128470785333775e-06
mamba_block1.S6.fc2.weight gradient: 1.912422703753691e-05
mamba_block1.S6.fc2.bias gradient: 3.044418008357752e-05
mamba_block1.S6.fc3.weight gradient: 1.824636638048105e-05
mamba_block1.S6.fc3.bias gradient: 2.918856989708729e-05
mamba_block1.conv.weight gradient: 4.986266867490485e-05
mamba_block1.conv.bias gradient: 8.942976819525938e-06
mamba_block1.conv_linear.weight gradient: 1.743154825817328e-05
mamba_block1.conv_linear.bias gradient: 6.066836431273259e-05
mamba_block1.norm.weight gradient: 6.70700228511123e-06
mamba_block2.inp_proj.weight gradient: 0.007989523932337761
mamba_block2.inp_proj.bias gradient: 0.0028184521943330765
mamba_block2.out_proj.weight gradient: 0.008060919120907784
mamba_block2.out_proj.bias gradient: 0.018983684480190277
mamba_block2.D.weight gradient: 0.004776414018124342
mamba_block2.D.bias gradient: 0.0016849525272846222
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0010103124659508467
mamba_block2.S6.fc1.bias gradient: 0.0012939375592395663
mamba_block2.S6.fc2.weight gradient: 0.002913458738476038
mamba_block2.S6.fc2.bias gradient: 0.003868406405672431
mamba_block2.S6.fc3.weight gradient: 0.0027066958136856556
mamba_block2.S6.fc3.bias gradient: 0.00352851883508265
mamba_block2.conv.weight gradient: 0.011430115438997746
mamba_block2.conv.bias gradient: 0.0008375911857001483
mamba_block2.conv_linear.weight gradient: 0.008923036977648735
mamba_block2.conv_linear.bias gradient: 0.005675981752574444
mamba_block2.norm.weight gradient: 0.003188737900927663
mamba_block3.inp_proj.weight gradient: 0.09013670682907104
mamba_block3.inp_proj.bias gradient: 0.03176071494817734
mamba_block3.out_proj.weight gradient: 0.04090064764022827
mamba_block3.out_proj.bias gradient: 8.47963974592858e-08
mamba_block3.D.weight gradient: 0.04176799952983856
mamba_block3.D.bias gradient: 0.014754664152860641
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.00422002375125885
mamba_block3.S6.fc1.bias gradient: 0.0044675562530756
mamba_block3.S6.fc2.weight gradient: 0.0172658059746027
mamba_block3.S6.fc2.bias gradient: 0.019037457183003426
mamba_block3.S6.fc3.weight gradient: 0.01806436851620674
mamba_block3.S6.fc3.bias gradient: 0.019686652347445488
mamba_block3.conv.weight gradient: 0.160682812333107
mamba_block3.conv.bias gradient: 0.01587325893342495
mamba_block3.conv_linear.weight gradient: 0.06943379342556
mamba_block3.conv_linear.bias gradient: 0.05003580451011658
mamba_block3.norm.weight gradient: 0.02130027487874031
DEBUGGING IS ON !!!
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9992, ..., 1.0050, 1.0098, 0.9956],
[1.0252, 1.0139, 0.9933, ..., 1.0354, 1.0358, 0.9926],
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0252, 1.0138, 0.9933, ..., 1.0354, 1.0358, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0031, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0252, 1.0139, 0.9933, ..., 1.0354, 1.0358, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
...,
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0252, 1.0138, 0.9933, ..., 1.0354, 1.0358, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0252, 1.0139, 0.9933, ..., 1.0354, 1.0358, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0252, 1.0139, 0.9933, ..., 1.0355, 1.0358, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9636823534965515, 1.0358315706253052)
mean: 1.0009465217590332
std: 0.007956922054290771
target = tensor([[[-0.4988, -0.6918, -0.5971, ..., -1.3568, -0.3844, 0.6915],
[ 0.3156, -1.5684, -0.7855, ..., -0.0484, -0.9211, -0.2853],
[-0.3282, -0.4495, 0.3974, ..., 0.9546, -0.4394, -0.2031],
...,
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.0574, 0.2834, -0.0056, ..., 1.8479, 0.3408, -0.3568],
[ 0.2411, 1.3111, 0.5789, ..., -0.3322, 1.1244, -1.1123],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
...,
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.0383, 0.7346, 1.1566, ..., -1.2615, -1.3133, 0.8579],
[-0.7210, 1.5826, 0.4122, ..., 0.3692, 1.2578, 0.0504],
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044],
...,
[-2.0846, -0.2999, 0.0431, ..., -0.1785, 1.3174, 1.8029],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.1019, -0.3633, 1.4970, ..., -1.1070, 1.8470, 0.7881],
[-0.7269, -0.3677, -0.0543, ..., 0.4443, 0.2045, 0.0918],
[-0.9685, 0.8548, -0.1369, ..., 0.6784, 0.1392, -0.7722],
...,
[ 0.3352, 1.3790, -1.4903, ..., 0.1442, 0.8230, -0.7261],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 1.3890, 1.7612, -1.5885, ..., 0.6402, 1.1234, 0.6314],
[ 0.5255, -0.6071, -0.6983, ..., -0.1975, 0.2420, 0.5352],
[-0.3851, -1.0689, 0.9486, ..., -0.4575, 0.1463, -0.2335],
...,
[ 1.8217, 0.4270, 0.9168, ..., -0.5362, 0.0306, -0.0278],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777],
[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777],
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
...,
[ 1.1470, 1.1976, -0.3732, ..., -0.1076, 2.3560, -0.8394],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 3.9631638526916504)
mean: -0.036761801689863205
std: 0.9940363168716431
mamba_block1.inp_proj.weight gradient: 7.884400474722497e-06
mamba_block1.inp_proj.bias gradient: 1.4247218132368289e-05
mamba_block1.out_proj.weight gradient: 7.718842971371487e-05
mamba_block1.out_proj.bias gradient: 0.00190842489246279
mamba_block1.D.weight gradient: 2.390151894360315e-05
mamba_block1.D.bias gradient: 2.5844841729849577e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.1179840789263835e-06
mamba_block1.S6.fc1.bias gradient: 4.432289642863907e-06
mamba_block1.S6.fc2.weight gradient: 2.273116297146771e-05
mamba_block1.S6.fc2.bias gradient: 3.4517765016062185e-05
mamba_block1.S6.fc3.weight gradient: 2.1624193323077634e-05
mamba_block1.S6.fc3.bias gradient: 3.2941166864475235e-05
mamba_block1.conv.weight gradient: 4.8649115342414007e-05
mamba_block1.conv.bias gradient: 1.0308258424629457e-05
mamba_block1.conv_linear.weight gradient: 1.780458478606306e-05
mamba_block1.conv_linear.bias gradient: 5.851773312315345e-05
mamba_block1.norm.weight gradient: 6.214230779733043e-06
mamba_block2.inp_proj.weight gradient: 0.008193948306143284
mamba_block2.inp_proj.bias gradient: 0.002890530275180936
mamba_block2.out_proj.weight gradient: 0.008393185213208199
mamba_block2.out_proj.bias gradient: 0.019226713106036186
mamba_block2.D.weight gradient: 0.0050740428268909454
mamba_block2.D.bias gradient: 0.00178993318695575
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.001113535021431744
mamba_block2.S6.fc1.bias gradient: 0.0014299751492217183
mamba_block2.S6.fc2.weight gradient: 0.003115260973572731
mamba_block2.S6.fc2.bias gradient: 0.004271494224667549
mamba_block2.S6.fc3.weight gradient: 0.002891583601012826
mamba_block2.S6.fc3.bias gradient: 0.003908565733581781
mamba_block2.conv.weight gradient: 0.011489537544548512
mamba_block2.conv.bias gradient: 0.0008398296195082366
mamba_block2.conv_linear.weight gradient: 0.009319150820374489
mamba_block2.conv_linear.bias gradient: 0.006470021326094866
mamba_block2.norm.weight gradient: 0.003238578559830785
mamba_block3.inp_proj.weight gradient: 0.09595136344432831
mamba_block3.inp_proj.bias gradient: 0.033816881477832794
mamba_block3.out_proj.weight gradient: 0.044870056211948395
mamba_block3.out_proj.bias gradient: 6.642654426514127e-08
mamba_block3.D.weight gradient: 0.04112648218870163
mamba_block3.D.bias gradient: 0.014526760205626488
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004705202765762806
mamba_block3.S6.fc1.bias gradient: 0.0051134019158780575
mamba_block3.S6.fc2.weight gradient: 0.01586112007498741
mamba_block3.S6.fc2.bias gradient: 0.013769936747848988
mamba_block3.S6.fc3.weight gradient: 0.016735132783651352
mamba_block3.S6.fc3.bias gradient: 0.014533820562064648
mamba_block3.conv.weight gradient: 0.16413094103336334
mamba_block3.conv.bias gradient: 0.016221188008785248
mamba_block3.conv_linear.weight gradient: 0.07475219666957855
mamba_block3.conv_linear.bias gradient: 0.04689887911081314
mamba_block3.norm.weight gradient: 0.0209010262042284
DEBUGGING IS ON !!!
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0253, 1.0139, 0.9934, ..., 1.0355, 1.0359, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956],
[1.0253, 1.0139, 0.9933, ..., 1.0356, 1.0360, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0253, 1.0139, 0.9934, ..., 1.0356, 1.0360, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
...,
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956],
[1.0253, 1.0139, 0.9934, ..., 1.0356, 1.0359, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956],
[1.0253, 1.0139, 0.9934, ..., 1.0355, 1.0359, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0031, 0.9993, ..., 1.0051, 1.0098, 0.9956],
[1.0253, 1.0139, 0.9934, ..., 1.0355, 1.0359, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9636375904083252, 1.0359997749328613)
mean: 1.00094735622406
std: 0.007967358455061913
target = tensor([[[-0.4072, -0.5642, 0.8817, ..., 0.7706, -0.4521, -0.3770],
[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777],
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883],
...,
[-0.6713, -0.7524, -0.7726, ..., -0.4873, 0.0152, 1.0856],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.0967, -0.0082, -0.8923, ..., -1.5443, -0.6645, -0.7764],
[ 0.1676, 0.6545, -0.4603, ..., 0.9874, 1.3225, 0.1617],
[ 1.1250, 1.1728, 0.5452, ..., -1.0478, 0.4368, 1.5019],
...,
[ 0.3847, -0.3358, -0.6223, ..., -0.8391, 0.2528, 0.4785],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.1551, 0.3942, 1.8174, ..., -0.4964, -0.1678, 0.8586],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[ 0.0378, -0.2052, -0.5975, ..., -0.1757, 0.5491, -1.5124],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[-0.8834, -1.0507, -0.6479, ..., -0.5122, 0.4084, -0.2526],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[-0.9680, -1.2858, -1.1414, ..., -0.5307, -0.5660, -0.2579],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.7408, 2.1198, 0.1103, ..., 0.3524, 1.0912, -0.2684],
[ 1.0274, 0.0333, 0.1309, ..., -0.9998, -0.7694, 0.0330],
[ 1.9510, -1.3107, 1.1983, ..., 1.4472, -0.9458, 0.6607],
...,
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.9022, -1.0460, -0.8212, ..., -0.6692, -0.9739, -0.3634],
[ 0.6606, 0.6995, -1.1284, ..., 0.8394, -0.4208, -0.3543],
[-1.4874, 0.0543, 1.0052, ..., 1.6346, -0.3576, 0.3655],
...,
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.0426435470581055, 3.9631638526916504)
mean: -0.041232164949178696
std: 0.9984654784202576
mamba_block1.inp_proj.weight gradient: 7.257808647409547e-06
mamba_block1.inp_proj.bias gradient: 1.2536821486719418e-05
mamba_block1.out_proj.weight gradient: 7.684003503527492e-05
mamba_block1.out_proj.bias gradient: 0.0019513132283464074
mamba_block1.D.weight gradient: 2.0352830688352697e-05
mamba_block1.D.bias gradient: 2.5774028472369537e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 2.904744633269729e-06
mamba_block1.S6.fc1.bias gradient: 4.214916316414019e-06
mamba_block1.S6.fc2.weight gradient: 1.7458878573961556e-05
mamba_block1.S6.fc2.bias gradient: 2.8126347388024442e-05
mamba_block1.S6.fc3.weight gradient: 1.7111855413531885e-05
mamba_block1.S6.fc3.bias gradient: 2.7683992811944336e-05
mamba_block1.conv.weight gradient: 4.96259490319062e-05
mamba_block1.conv.bias gradient: 7.881315468694083e-06
mamba_block1.conv_linear.weight gradient: 1.8037038898910396e-05
mamba_block1.conv_linear.bias gradient: 5.335741661838256e-05
mamba_block1.norm.weight gradient: 5.022363893658621e-06
mamba_block2.inp_proj.weight gradient: 0.008337417617440224
mamba_block2.inp_proj.bias gradient: 0.0029411145951598883
mamba_block2.out_proj.weight gradient: 0.00849019642919302
mamba_block2.out_proj.bias gradient: 0.018002500757575035
mamba_block2.D.weight gradient: 0.005273017566651106
mamba_block2.D.bias gradient: 0.001860112533904612
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0011189623037353158
mamba_block2.S6.fc1.bias gradient: 0.001455896534025669
mamba_block2.S6.fc2.weight gradient: 0.0032008709385991096
mamba_block2.S6.fc2.bias gradient: 0.004261344205588102
mamba_block2.S6.fc3.weight gradient: 0.002974285278469324
mamba_block2.S6.fc3.bias gradient: 0.003896206384524703
mamba_block2.conv.weight gradient: 0.011668318882584572
mamba_block2.conv.bias gradient: 0.0008345782989636064
mamba_block2.conv_linear.weight gradient: 0.009310065768659115
mamba_block2.conv_linear.bias gradient: 0.0065779616124928
mamba_block2.norm.weight gradient: 0.003354104468598962
mamba_block3.inp_proj.weight gradient: 0.09024792164564133
mamba_block3.inp_proj.bias gradient: 0.03179505467414856
mamba_block3.out_proj.weight gradient: 0.043840229511260986
mamba_block3.out_proj.bias gradient: 7.639116006430413e-08
mamba_block3.D.weight gradient: 0.041635192930698395
mamba_block3.D.bias gradient: 0.014703379012644291
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004581212531775236
mamba_block3.S6.fc1.bias gradient: 0.004985094536095858
mamba_block3.S6.fc2.weight gradient: 0.016831787303090096
mamba_block3.S6.fc2.bias gradient: 0.015875842422246933
mamba_block3.S6.fc3.weight gradient: 0.017684169113636017
mamba_block3.S6.fc3.bias gradient: 0.016509560868144035
mamba_block3.conv.weight gradient: 0.16424451768398285
mamba_block3.conv.bias gradient: 0.016224239021539688
mamba_block3.conv_linear.weight gradient: 0.07360465824604034
mamba_block3.conv_linear.bias gradient: 0.04532546550035477
mamba_block3.norm.weight gradient: 0.02086501196026802
DEBUGGING IS ON !!!
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956],
[1.0254, 1.0139, 0.9934, ..., 1.0357, 1.0362, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956],
[1.0254, 1.0139, 0.9934, ..., 1.0357, 1.0361, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0254, 1.0138, 0.9934, ..., 1.0357, 1.0362, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
...,
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0254, 1.0139, 0.9934, ..., 1.0357, 1.0361, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956],
[1.0254, 1.0139, 0.9934, ..., 1.0357, 1.0361, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]],
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956],
[1.0254, 1.0139, 0.9934, ..., 1.0357, 1.0361, 0.9926],
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9636093378067017, 1.036202073097229)
mean: 1.0009483098983765
std: 0.007977578788995743
target = tensor([[[-1.1048, -0.1699, 0.3172, ..., 0.6925, 0.7191, 1.4389],
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
...,
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.0580, -0.5538, -0.9253, ..., 0.6467, 1.4621, -0.3138],
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
[-0.7571, -1.6050, -0.0124, ..., -0.9880, -0.9499, 0.8033],
...,
[ 1.7151, 1.0070, 0.6890, ..., -2.3825, -0.5136, 0.5498],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.6873, 0.3348, 0.9381, ..., -0.6962, 0.4933, -0.2609],
[-1.0015, -0.2747, -0.3922, ..., 0.6551, 0.1457, 1.8843],
[-0.2476, -0.3332, -0.2145, ..., -0.8714, 0.4179, -0.0367],
...,
[ 0.6761, -0.8634, -0.7832, ..., 0.2734, -0.3206, -0.2002],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.0383, 0.7346, 1.1566, ..., -1.2615, -1.3133, 0.8579],
[ 1.0334, -0.8128, -0.3230, ..., 0.2623, 2.1819, 0.4262],
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
...,
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.5255, -0.6071, -0.6983, ..., -0.1975, 0.2420, 0.5352],
[ 0.2451, -0.2897, 0.8116, ..., -0.1863, 0.8451, -1.3344],
[ 1.2114, 0.3140, -1.4007, ..., 1.1863, 0.0090, 0.1881],
...,
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.7789, 0.5709, 1.3162, ..., 0.9926, 0.0632, -1.1557],
[ 1.2286, -1.6981, 0.1104, ..., -1.3688, -0.4356, 0.2258],
[-0.8187, 1.7120, 1.2602, ..., 0.9032, -1.0293, -0.3666],
...,
[-1.2926, -0.1770, 0.0189, ..., 0.3937, -0.4130, 1.5345],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-3.895033836364746, 4.436890602111816)
mean: -0.04351688176393509
std: 0.9954912662506104
mamba_block1.inp_proj.weight gradient: 9.100136594497599e-06
mamba_block1.inp_proj.bias gradient: 1.3978528841107618e-05
mamba_block1.out_proj.weight gradient: 8.495710790157318e-05
mamba_block1.out_proj.bias gradient: 0.0020170381758362055
mamba_block1.D.weight gradient: 2.5585110051906668e-05
mamba_block1.D.bias gradient: 2.804783252940979e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.4274560221092543e-06
mamba_block1.S6.fc1.bias gradient: 4.966110736859264e-06
mamba_block1.S6.fc2.weight gradient: 2.0056253561051562e-05
mamba_block1.S6.fc2.bias gradient: 3.251164889661595e-05
mamba_block1.S6.fc3.weight gradient: 1.9456374502624385e-05
mamba_block1.S6.fc3.bias gradient: 3.162693974445574e-05
mamba_block1.conv.weight gradient: 5.1763261581072584e-05
mamba_block1.conv.bias gradient: 8.400884325965308e-06
mamba_block1.conv_linear.weight gradient: 1.8744514818536118e-05
mamba_block1.conv_linear.bias gradient: 6.390304770320654e-05
mamba_block1.norm.weight gradient: 7.944043318275362e-06
mamba_block2.inp_proj.weight gradient: 0.008640581741929054
mamba_block2.inp_proj.bias gradient: 0.00304804602637887
mamba_block2.out_proj.weight gradient: 0.008689355105161667
mamba_block2.out_proj.bias gradient: 0.020420508459210396
mamba_block2.D.weight gradient: 0.005131551064550877
mamba_block2.D.bias gradient: 0.0018101985333487391
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0011199985165148973
mamba_block2.S6.fc1.bias gradient: 0.0014004047261551023
mamba_block2.S6.fc2.weight gradient: 0.0031374646350741386
mamba_block2.S6.fc2.bias gradient: 0.004075405187904835
mamba_block2.S6.fc3.weight gradient: 0.002909448929131031
mamba_block2.S6.fc3.bias gradient: 0.0037189736030995846
mamba_block2.conv.weight gradient: 0.012041107751429081
mamba_block2.conv.bias gradient: 0.0009172795107588172
mamba_block2.conv_linear.weight gradient: 0.009620931930840015
mamba_block2.conv_linear.bias gradient: 0.00613889517262578
mamba_block2.norm.weight gradient: 0.0034102771896868944
mamba_block3.inp_proj.weight gradient: 0.0934479609131813
mamba_block3.inp_proj.bias gradient: 0.03292842581868172
mamba_block3.out_proj.weight gradient: 0.044350285083055496
mamba_block3.out_proj.bias gradient: 1.3198520321111573e-07
mamba_block3.D.weight gradient: 0.04143233224749565
mamba_block3.D.bias gradient: 0.01463618129491806
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004638255573809147
mamba_block3.S6.fc1.bias gradient: 0.004756370093673468
mamba_block3.S6.fc2.weight gradient: 0.01848497800529003
mamba_block3.S6.fc2.bias gradient: 0.017832543700933456
mamba_block3.S6.fc3.weight gradient: 0.01909303106367588
mamba_block3.S6.fc3.bias gradient: 0.018696237355470657
mamba_block3.conv.weight gradient: 0.1660355180501938
mamba_block3.conv.bias gradient: 0.01661163568496704
mamba_block3.conv_linear.weight gradient: 0.07300027459859848
mamba_block3.conv_linear.bias gradient: 0.04526618868112564
mamba_block3.norm.weight gradient: 0.023428840562701225
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0255, 1.0139, 0.9934, ..., 1.0358, 1.0363, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9964]],
[[1.0054, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0255, 1.0139, 0.9934, ..., 1.0358, 1.0363, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9964]],
[[1.0054, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0255, 1.0138, 0.9935, ..., 1.0358, 1.0363, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9964]],
...,
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0031, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0255, 1.0139, 0.9934, ..., 1.0358, 1.0363, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0255, 1.0139, 0.9934, ..., 1.0358, 1.0363, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9963]],
[[1.0054, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956],
[1.0255, 1.0138, 0.9934, ..., 1.0358, 1.0363, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9964]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9635644555091858, 1.0363487005233765)
mean: 1.0009490251541138
std: 0.007988139055669308
target = tensor([[[ 0.1019, -0.3633, 1.4970, ..., -1.1070, 1.8470, 0.7881],
[ 0.4941, 0.0856, 0.3690, ..., 1.3915, 2.5161, 0.3218],
[ 1.7297, 0.7489, 0.7269, ..., -0.3836, 0.6932, 0.7111],
...,
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.5515, 0.7887, -1.5313, ..., 1.2504, 0.1500, -1.8818],
[ 0.8201, 1.6476, 0.4960, ..., -0.2201, 0.8857, 0.0669],
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
...,
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.3433, -0.6291, 0.8468, ..., -1.2711, -1.2323, -0.1769],
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483],
...,
[ 0.5068, -0.9357, 0.5597, ..., -0.2642, 1.4782, 2.3104],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.0519, -1.6652, 0.4465, ..., 0.6680, 0.7076, -0.9326],
[ 1.4714, 0.0124, -0.2384, ..., -0.2375, 1.1155, 0.2285],
[ 0.1385, -0.2701, 0.1457, ..., 0.4512, -1.1078, -0.2718],
...,
[ 0.0368, 0.7456, -1.4815, ..., 0.9900, 1.4748, -0.2182],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.8941, -2.4687, 0.5529, ..., 0.0181, 0.2483, 0.0552],
[-1.1859, 0.6207, 1.1728, ..., 0.3623, 0.6124, 0.1387],
[-0.5582, -1.8473, -0.1892, ..., -1.3669, 1.0029, -0.2609],
...,
[-0.2494, 1.6643, 1.0550, ..., 0.0960, -0.4710, 0.4718],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.4882, 0.1704, -1.3309, ..., -1.1073, 0.2595, -0.9865],
[ 1.0589, -0.6273, -1.0979, ..., -1.3877, -0.8624, 0.4007],
[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213],
...,
[ 0.5853, 0.2439, -0.6474, ..., 0.7711, -0.5776, 0.1493],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.203547954559326, 3.9631638526916504)
mean: -0.03762376680970192
std: 0.998776376247406
mamba_block1.inp_proj.weight gradient: 8.837168934405781e-06
mamba_block1.inp_proj.bias gradient: 1.3773204045719467e-05
mamba_block1.out_proj.weight gradient: 8.079419058049098e-05
mamba_block1.out_proj.bias gradient: 0.001937799621373415
mamba_block1.D.weight gradient: 2.2958665795158595e-05
mamba_block1.D.bias gradient: 2.63482506852597e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.0723977033630945e-06
mamba_block1.S6.fc1.bias gradient: 4.301163698983146e-06
mamba_block1.S6.fc2.weight gradient: 1.544903170724865e-05
mamba_block1.S6.fc2.bias gradient: 2.3952608898980543e-05
mamba_block1.S6.fc3.weight gradient: 1.4936680599930696e-05
mamba_block1.S6.fc3.bias gradient: 2.318737460882403e-05
mamba_block1.conv.weight gradient: 5.149357093614526e-05
mamba_block1.conv.bias gradient: 8.877683285390958e-06
mamba_block1.conv_linear.weight gradient: 1.7505970390629955e-05
mamba_block1.conv_linear.bias gradient: 5.392322418629192e-05
mamba_block1.norm.weight gradient: 4.662304036173737e-06
mamba_block2.inp_proj.weight gradient: 0.008749652653932571
mamba_block2.inp_proj.bias gradient: 0.003086492419242859
mamba_block2.out_proj.weight gradient: 0.008550377562642097
mamba_block2.out_proj.bias gradient: 0.0215672105550766
mamba_block2.D.weight gradient: 0.0049406080506742
mamba_block2.D.bias gradient: 0.001742832944728434
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0011518856044858694
mamba_block2.S6.fc1.bias gradient: 0.0014069563476368785
mamba_block2.S6.fc2.weight gradient: 0.0031139287166297436
mamba_block2.S6.fc2.bias gradient: 0.003985110204666853
mamba_block2.S6.fc3.weight gradient: 0.002889266237616539
mamba_block2.S6.fc3.bias gradient: 0.003641492687165737
mamba_block2.conv.weight gradient: 0.011927427724003792
mamba_block2.conv.bias gradient: 0.0008801336516626179
mamba_block2.conv_linear.weight gradient: 0.009689133614301682
mamba_block2.conv_linear.bias gradient: 0.005984927993267775
mamba_block2.norm.weight gradient: 0.0032904285471886396
mamba_block3.inp_proj.weight gradient: 0.09736455976963043
mamba_block3.inp_proj.bias gradient: 0.034317947924137115
mamba_block3.out_proj.weight gradient: 0.04285677894949913
mamba_block3.out_proj.bias gradient: 4.898083716398105e-08
mamba_block3.D.weight gradient: 0.0413573682308197
mamba_block3.D.bias gradient: 0.014606538228690624
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004640404600650072
mamba_block3.S6.fc1.bias gradient: 0.005210277624428272
mamba_block3.S6.fc2.weight gradient: 0.017779408022761345
mamba_block3.S6.fc2.bias gradient: 0.019206494092941284
mamba_block3.S6.fc3.weight gradient: 0.018545418977737427
mamba_block3.S6.fc3.bias gradient: 0.019921310245990753
mamba_block3.conv.weight gradient: 0.16301649808883667
mamba_block3.conv.bias gradient: 0.01633545197546482
mamba_block3.conv_linear.weight gradient: 0.07117603719234467
mamba_block3.conv_linear.bias gradient: 0.04827188700437546
mamba_block3.norm.weight gradient: 0.023814095184206963
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0256, 1.0139, 0.9935, ..., 1.0359, 1.0365, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0256, 1.0140, 0.9934, ..., 1.0359, 1.0364, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0256, 1.0139, 0.9935, ..., 1.0359, 1.0364, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]],
...,
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0256, 1.0139, 0.9935, ..., 1.0359, 1.0365, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0256, 1.0139, 0.9935, ..., 1.0359, 1.0365, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0031, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0256, 1.0139, 0.9935, ..., 1.0359, 1.0364, 0.9926],
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9635269641876221, 1.0365265607833862)
mean: 1.0009498596191406
std: 0.007998868823051453
target = tensor([[[-1.2912e-01, 2.7761e-01, 6.5007e-01, ..., 5.3681e-01,
1.4878e+00, -6.7947e-01],
[ 6.0857e-01, 1.4627e+00, 2.5454e-01, ..., 1.6538e+00,
-1.0191e+00, 8.4912e-01],
[ 4.9737e-01, 7.9059e-01, -1.0028e+00, ..., 1.2410e+00,
-1.4670e+00, -1.0270e+00],
...,
[-7.9787e-01, 6.7782e-01, 1.4350e-01, ..., 3.0334e-01,
6.2231e-01, -9.4687e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 1.5454e+00, 9.2277e-01, 3.0021e-01, ..., -8.3794e-01,
7.2716e-01, -1.8499e+00],
[ 1.3767e+00, 1.0128e-01, 3.5444e-01, ..., 7.6632e-02,
-1.6822e+00, -1.4354e+00],
[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00,
1.8470e+00, 7.8808e-01],
...,
[-9.7426e-01, -8.7755e-01, 1.9398e-01, ..., -3.6643e-01,
1.9255e-03, 2.0825e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 4.1016e-01, -3.8302e-03, -1.2295e-01, ..., -2.5800e-01,
1.4403e+00, -2.4625e-01],
[ 6.1667e-02, -2.4054e-02, 1.9664e+00, ..., -1.5273e+00,
2.2778e-01, -4.6371e-01],
[ 2.0079e+00, 2.8988e-01, 1.4619e-01, ..., 1.1677e-02,
1.5477e-01, 9.1439e-01],
...,
[ 5.0676e-01, -9.3567e-01, 5.5966e-01, ..., -2.6422e-01,
1.4782e+00, 2.3104e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
...,
[[-9.1950e-02, -6.5555e-01, 1.6096e+00, ..., -1.5558e+00,
6.1454e-01, 1.4055e+00],
[-1.3817e+00, 1.9420e-01, -1.2593e+00, ..., 5.9164e-02,
4.6726e-01, 3.5826e-01],
[-1.4729e+00, -2.1761e+00, 9.2319e-01, ..., 4.0713e-01,
-1.6731e+00, 1.1180e+00],
...,
[-6.8272e-01, -2.8986e-01, -8.1461e-02, ..., 3.9673e-01,
2.5136e-01, 6.9517e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 1.1350e+00, -1.6268e+00, 1.5461e+00, ..., -1.8916e+00,
-1.9114e+00, 1.1798e+00],
[-5.3807e-01, 1.8558e+00, -1.3125e+00, ..., -2.1141e+00,
-5.7919e-01, -8.3718e-02],
[ 5.2551e-01, -6.0715e-01, -6.9834e-01, ..., -1.9748e-01,
2.4198e-01, 5.3519e-01],
...,
[-1.5331e-01, 5.9299e-01, 1.3224e-01, ..., -1.6126e+00,
-6.2350e-01, -1.3132e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 3.8803e-01, -1.8462e-01, 3.3831e-01, ..., -6.0497e-01,
1.2007e-01, -1.0940e+00],
[-2.9351e-01, 9.6275e-01, -1.5990e+00, ..., -1.4156e+00,
-1.0206e+00, -1.0802e+00],
[ 6.3896e-02, 3.2472e-04, -1.2828e+00, ..., -1.0525e+00,
-1.3741e+00, -1.5745e+00],
...,
[ 1.2286e+00, -1.6981e+00, 1.1041e-01, ..., -1.3688e+00,
-4.3559e-01, 2.2583e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]]], device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 3.9631638526916504)
mean: -0.03548859432339668
std: 0.9919624924659729
mamba_block1.inp_proj.weight gradient: 7.336026101256721e-06
mamba_block1.inp_proj.bias gradient: 1.0591419595584739e-05
mamba_block1.out_proj.weight gradient: 7.542931416537613e-05
mamba_block1.out_proj.bias gradient: 0.00191353855188936
mamba_block1.D.weight gradient: 1.9951721696997993e-05
mamba_block1.D.bias gradient: 2.5469640604569577e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 2.71867838819162e-06
mamba_block1.S6.fc1.bias gradient: 3.859155185637064e-06
mamba_block1.S6.fc2.weight gradient: 1.6416681319242343e-05
mamba_block1.S6.fc2.bias gradient: 2.4024193407967687e-05
mamba_block1.S6.fc3.weight gradient: 1.577912007633131e-05
mamba_block1.S6.fc3.bias gradient: 2.3290005628950894e-05
mamba_block1.conv.weight gradient: 4.8139336286112666e-05
mamba_block1.conv.bias gradient: 7.358869879681151e-06
mamba_block1.conv_linear.weight gradient: 1.663083821767941e-05
mamba_block1.conv_linear.bias gradient: 5.175057958695106e-05
mamba_block1.norm.weight gradient: 5.831683665746823e-06
mamba_block2.inp_proj.weight gradient: 0.008543197065591812
mamba_block2.inp_proj.bias gradient: 0.003013620851561427
mamba_block2.out_proj.weight gradient: 0.008505250327289104
mamba_block2.out_proj.bias gradient: 0.020350750535726547
mamba_block2.D.weight gradient: 0.004964698106050491
mamba_block2.D.bias gradient: 0.0017513002967461944
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.001128288102336228
mamba_block2.S6.fc1.bias gradient: 0.0014180107973515987
mamba_block2.S6.fc2.weight gradient: 0.0031073689460754395
mamba_block2.S6.fc2.bias gradient: 0.004030835349112749
mamba_block2.S6.fc3.weight gradient: 0.0028815357945859432
mamba_block2.S6.fc3.bias gradient: 0.0036812785547226667
mamba_block2.conv.weight gradient: 0.011601514182984829
mamba_block2.conv.bias gradient: 0.0008414683397859335
mamba_block2.conv_linear.weight gradient: 0.00955396518111229
mamba_block2.conv_linear.bias gradient: 0.006414241157472134
mamba_block2.norm.weight gradient: 0.0032158917747437954
mamba_block3.inp_proj.weight gradient: 0.08929342031478882
mamba_block3.inp_proj.bias gradient: 0.03145138546824455
mamba_block3.out_proj.weight gradient: 0.04311970993876457
mamba_block3.out_proj.bias gradient: 4.8321194157097125e-08
mamba_block3.D.weight gradient: 0.04716562479734421
mamba_block3.D.bias gradient: 0.01665896736085415
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004517058376222849
mamba_block3.S6.fc1.bias gradient: 0.004935705102980137
mamba_block3.S6.fc2.weight gradient: 0.019176244735717773
mamba_block3.S6.fc2.bias gradient: 0.018631301820278168
mamba_block3.S6.fc3.weight gradient: 0.019713152199983597
mamba_block3.S6.fc3.bias gradient: 0.019606487825512886
mamba_block3.conv.weight gradient: 0.1666284203529358
mamba_block3.conv.bias gradient: 0.016464872285723686
mamba_block3.conv_linear.weight gradient: 0.07397904992103577
mamba_block3.conv_linear.bias gradient: 0.04956180974841118
mamba_block3.norm.weight gradient: 0.0232034083455801
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0257, 1.0139, 0.9935, ..., 1.0360, 1.0366, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0257, 1.0139, 0.9935, ..., 1.0361, 1.0366, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0257, 1.0139, 0.9935, ..., 1.0361, 1.0367, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]],
...,
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9956],
[1.0257, 1.0139, 0.9935, ..., 1.0361, 1.0366, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0257, 1.0139, 0.9935, ..., 1.0361, 1.0366, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0257, 1.0139, 0.9935, ..., 1.0361, 1.0367, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9634785056114197, 1.0367209911346436)
mean: 1.000950574874878
std: 0.008009559474885464
target = tensor([[[-1.6873, 1.0246, -0.1206, ..., 1.1435, -0.6533, -0.3542],
[ 1.1250, 1.1728, 0.5452, ..., -1.0478, 0.4368, 1.5019],
[-0.0696, -0.1048, 0.0242, ..., 0.4035, -0.3938, -1.4395],
...,
[-0.1848, -1.6540, -1.1411, ..., -0.4016, -0.8012, 2.9020],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.1664, 0.1732, -0.6635, ..., -0.2567, -0.0699, -0.5926],
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883],
[-1.9692, 0.7451, 1.1040, ..., 1.9508, 0.4760, 0.4680],
...,
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 1.1350, -1.6268, 1.5461, ..., -1.8916, -1.9114, 1.1798],
[-0.3877, 0.3024, -1.1404, ..., 2.0661, -0.6191, -0.9355],
[-0.0229, 1.0938, -0.0923, ..., -0.2372, -0.9342, -0.0119],
...,
[ 0.1846, -0.2507, 0.2757, ..., -0.6224, -2.2640, 0.0596],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 1.2858, 1.6970, -0.7869, ..., -0.6545, -0.6808, -1.0335],
[-0.0877, -0.1239, -1.0846, ..., -0.3003, -0.8652, 2.1638],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
...,
[ 0.6566, -2.1274, -0.8276, ..., 0.4022, -0.3414, 0.9617],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.5679, 1.1681, -0.2152, ..., 0.4324, -0.3278, 0.3071],
[-2.0951, -1.0155, 0.1259, ..., 0.1168, 0.8176, -1.6148],
[-0.6088, 0.1729, 0.2571, ..., 1.8095, 0.2413, 1.2040],
...,
[ 1.6005, 0.5397, -0.2096, ..., 0.3994, 1.0095, 0.0273],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.3876, 0.3656, 0.3301, ..., 0.5791, -0.6306, 0.7447],
[-1.5113, 1.0783, 0.5030, ..., -0.0460, -0.4079, 0.7232],
[ 2.0079, 0.2899, 0.1462, ..., 0.0117, 0.1548, 0.9144],
...,
[-0.4137, -0.8622, -1.2035, ..., -1.6588, 1.6848, 1.1585],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.0426435470581055, 3.9631638526916504)
mean: -0.045278389006853104
std: 1.0026400089263916
mamba_block1.inp_proj.weight gradient: 7.799080776749179e-06
mamba_block1.inp_proj.bias gradient: 1.4282920346886385e-05
mamba_block1.out_proj.weight gradient: 7.392667612293735e-05
mamba_block1.out_proj.bias gradient: 0.001936393091455102
mamba_block1.D.weight gradient: 2.073771793220658e-05
mamba_block1.D.bias gradient: 2.6325606086174957e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 2.736860324148438e-06
mamba_block1.S6.fc1.bias gradient: 4.0264562812808435e-06
mamba_block1.S6.fc2.weight gradient: 1.7693038898869418e-05
mamba_block1.S6.fc2.bias gradient: 2.8277174351387657e-05
mamba_block1.S6.fc3.weight gradient: 1.71951214724686e-05
mamba_block1.S6.fc3.bias gradient: 2.7498386771185324e-05
mamba_block1.conv.weight gradient: 5.321425123838708e-05
mamba_block1.conv.bias gradient: 1.0530870895308908e-05
mamba_block1.conv_linear.weight gradient: 1.8596581867313944e-05
mamba_block1.conv_linear.bias gradient: 5.6252407375723124e-05
mamba_block1.norm.weight gradient: 3.8478110582218505e-06
mamba_block2.inp_proj.weight gradient: 0.008295131847262383
mamba_block2.inp_proj.bias gradient: 0.0029260965529829264
mamba_block2.out_proj.weight gradient: 0.008724554441869259
mamba_block2.out_proj.bias gradient: 0.020028317347168922
mamba_block2.D.weight gradient: 0.00506933219730854
mamba_block2.D.bias gradient: 0.001788186258636415
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0011186042102053761
mamba_block2.S6.fc1.bias gradient: 0.0013990309089422226
mamba_block2.S6.fc2.weight gradient: 0.0030988024082034826
mamba_block2.S6.fc2.bias gradient: 0.003995432984083891
mamba_block2.S6.fc3.weight gradient: 0.0028782477602362633
mamba_block2.S6.fc3.bias gradient: 0.003649807535111904
mamba_block2.conv.weight gradient: 0.011995462700724602
mamba_block2.conv.bias gradient: 0.0008671989198774099
mamba_block2.conv_linear.weight gradient: 0.009438985958695412
mamba_block2.conv_linear.bias gradient: 0.006164009682834148
mamba_block2.norm.weight gradient: 0.003288773586973548
mamba_block3.inp_proj.weight gradient: 0.09197264909744263
mamba_block3.inp_proj.bias gradient: 0.032406117767095566
mamba_block3.out_proj.weight gradient: 0.04633298143744469
mamba_block3.out_proj.bias gradient: 7.247580668945375e-08
mamba_block3.D.weight gradient: 0.04627760872244835
mamba_block3.D.bias gradient: 0.01634623482823372
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004637931939214468
mamba_block3.S6.fc1.bias gradient: 0.004829443525522947
mamba_block3.S6.fc2.weight gradient: 0.018521852791309357
mamba_block3.S6.fc2.bias gradient: 0.017666202038526535
mamba_block3.S6.fc3.weight gradient: 0.019322099164128304
mamba_block3.S6.fc3.bias gradient: 0.01848628558218479
mamba_block3.conv.weight gradient: 0.1661185473203659
mamba_block3.conv.bias gradient: 0.016495849937200546
mamba_block3.conv_linear.weight gradient: 0.07894043624401093
mamba_block3.conv_linear.bias gradient: 0.055515777319669724
mamba_block3.norm.weight gradient: 0.02291417308151722
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0258, 1.0139, 0.9935, ..., 1.0362, 1.0368, 0.9926],
[1.0152, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0258, 1.0139, 0.9935, ..., 1.0362, 1.0368, 0.9927],
[1.0152, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0258, 1.0139, 0.9935, ..., 1.0362, 1.0368, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]],
...,
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0258, 1.0139, 0.9935, ..., 1.0362, 1.0368, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0258, 1.0139, 0.9935, ..., 1.0362, 1.0368, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0258, 1.0139, 0.9936, ..., 1.0362, 1.0368, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9634485244750977, 1.0368773937225342)
mean: 1.0009514093399048
std: 0.008020208217203617
target = tensor([[[-7.2930e-02, -1.5876e+00, -1.7188e-01, ..., 1.2421e+00,
7.0656e-01, 4.5039e-01],
[-3.9398e-01, -3.2687e-01, -2.5453e+00, ..., 2.5721e-01,
-2.0661e-01, 1.2287e-01],
[-2.1060e-01, 8.8564e-01, -2.6071e-01, ..., -3.1098e-01,
-9.9241e-01, 2.3331e-01],
...,
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01,
-3.2478e-01, -6.1987e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-4.0608e-02, 8.3209e-01, -5.6969e-01, ..., 1.5675e-01,
-2.0986e+00, 1.0620e+00],
[-6.5043e-01, -1.3069e+00, -1.5379e-01, ..., 7.5101e-01,
-1.4239e+00, -2.3928e-01],
[-1.0840e+00, 6.8404e-01, 1.2655e-03, ..., -2.0117e-01,
7.5984e-01, -4.4518e-01],
...,
[ 5.0676e-01, -9.3567e-01, 5.5966e-01, ..., -2.6422e-01,
1.4782e+00, 2.3104e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
[-1.3817e+00, 1.9420e-01, -1.2593e+00, ..., 5.9164e-02,
4.6726e-01, 3.5826e-01],
...,
[-1.3817e+00, 1.9420e-01, -1.2593e+00, ..., 5.9164e-02,
4.6726e-01, 3.5826e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
...,
[[ 6.8187e-01, 6.5033e-01, -7.8437e-02, ..., 1.2032e+00,
4.3723e-01, 5.0549e-02],
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
...,
[ 1.6843e+00, -9.7226e-01, -1.0947e+00, ..., 2.9206e-01,
8.7524e-01, -1.8599e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-8.9061e-01, -1.8222e-01, -8.0707e-01, ..., 9.1797e-01,
5.8479e-01, 8.0782e-01],
[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00,
-9.0730e-01, 4.8345e-02],
[-5.9222e-01, 1.0597e+00, -8.2489e-01, ..., 3.3105e-01,
5.1061e-01, -1.4595e-01],
...,
[ 2.5486e+00, 1.0393e-01, 1.4986e+00, ..., -1.2783e+00,
-7.6003e-01, -8.4845e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-9.1471e-02, 1.2755e-01, 7.2934e-01, ..., 1.1558e+00,
-3.6694e-01, -2.0441e-01],
[ 4.9737e-01, 7.9059e-01, -1.0028e+00, ..., 1.2410e+00,
-1.4670e+00, -1.0270e+00],
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01,
-3.2478e-01, -6.1987e-01],
...,
[ 5.2773e-01, -1.5574e+00, 5.2337e-02, ..., -4.6493e-01,
-4.2155e-02, 2.6858e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]]], device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.335270404815674, 4.436890602111816)
mean: -0.04498155415058136
std: 1.001204490661621
mamba_block1.inp_proj.weight gradient: 7.761200322420336e-06
mamba_block1.inp_proj.bias gradient: 1.3567365385824814e-05
mamba_block1.out_proj.weight gradient: 8.328901458298787e-05
mamba_block1.out_proj.bias gradient: 0.001959998393431306
mamba_block1.D.weight gradient: 2.5373388780280948e-05
mamba_block1.D.bias gradient: 2.6117280867765658e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 2.946576842077775e-06
mamba_block1.S6.fc1.bias gradient: 4.046666617796291e-06
mamba_block1.S6.fc2.weight gradient: 1.7208309145644307e-05
mamba_block1.S6.fc2.bias gradient: 2.567360206739977e-05
mamba_block1.S6.fc3.weight gradient: 1.6707210306776688e-05
mamba_block1.S6.fc3.bias gradient: 2.5138038836303167e-05
mamba_block1.conv.weight gradient: 5.01254471600987e-05
mamba_block1.conv.bias gradient: 8.776757567829918e-06
mamba_block1.conv_linear.weight gradient: 1.8366961739957333e-05
mamba_block1.conv_linear.bias gradient: 5.426110510597937e-05
mamba_block1.norm.weight gradient: 5.980205969535746e-06
mamba_block2.inp_proj.weight gradient: 0.008871041238307953
mamba_block2.inp_proj.bias gradient: 0.003129237564280629
mamba_block2.out_proj.weight gradient: 0.008679620921611786
mamba_block2.out_proj.bias gradient: 0.021518703550100327
mamba_block2.D.weight gradient: 0.005031412001699209
mamba_block2.D.bias gradient: 0.0017748060636222363
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.00114446971565485
mamba_block2.S6.fc1.bias gradient: 0.0014250640524551272
mamba_block2.S6.fc2.weight gradient: 0.0031288242898881435
mamba_block2.S6.fc2.bias gradient: 0.0038880067877471447
mamba_block2.S6.fc3.weight gradient: 0.00290935137309134
mamba_block2.S6.fc3.bias gradient: 0.0035403394140303135
mamba_block2.conv.weight gradient: 0.011712766252458096
mamba_block2.conv.bias gradient: 0.0008701256010681391
mamba_block2.conv_linear.weight gradient: 0.009876110590994358
mamba_block2.conv_linear.bias gradient: 0.0061364308930933475
mamba_block2.norm.weight gradient: 0.0033124934416264296
mamba_block3.inp_proj.weight gradient: 0.09452962875366211
mamba_block3.inp_proj.bias gradient: 0.033306971192359924
mamba_block3.out_proj.weight gradient: 0.04634234309196472
mamba_block3.out_proj.bias gradient: 6.166452948264123e-08
mamba_block3.D.weight gradient: 0.05371489003300667
mamba_block3.D.bias gradient: 0.018968436866998672
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.0050024171359837055
mamba_block3.S6.fc1.bias gradient: 0.005564017221331596
mamba_block3.S6.fc2.weight gradient: 0.01854153349995613
mamba_block3.S6.fc2.bias gradient: 0.01928705908358097
mamba_block3.S6.fc3.weight gradient: 0.019480036571621895
mamba_block3.S6.fc3.bias gradient: 0.020197639241814613
mamba_block3.conv.weight gradient: 0.16645151376724243
mamba_block3.conv.bias gradient: 0.01688520796597004
mamba_block3.conv_linear.weight gradient: 0.07389499992132187
mamba_block3.conv_linear.bias gradient: 0.04942871257662773
mamba_block3.norm.weight gradient: 0.025095799937844276
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0259, 1.0139, 0.9936, ..., 1.0363, 1.0370, 0.9927],
[1.0152, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955],
[1.0259, 1.0139, 0.9936, ..., 1.0363, 1.0370, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0259, 1.0139, 0.9936, ..., 1.0363, 1.0370, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
...,
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0259, 1.0139, 0.9936, ..., 1.0363, 1.0370, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0259, 1.0139, 0.9936, ..., 1.0363, 1.0370, 0.9927],
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0259, 1.0138, 0.9936, ..., 1.0363, 1.0370, 0.9927],
[1.0152, 1.0008, 1.0024, ..., 1.0067, 1.0128, 0.9963]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9633966684341431, 1.0370630025863647)
mean: 1.0009517669677734
std: 0.008031157776713371
target = tensor([[[ 0.0443, 0.8938, -0.9553, ..., -1.4890, -0.4225, -0.0190],
[ 0.6640, -0.1311, -0.5633, ..., -0.9016, 0.9427, -2.0075],
[ 0.0461, -0.2495, 0.1279, ..., 0.0830, 0.3020, -0.2346],
...,
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 1.1027, 0.7786, 1.2513, ..., -0.2450, 0.3287, -1.6867],
[ 0.4430, 1.5878, 1.1566, ..., 0.9944, -1.4975, -0.3028],
[ 0.1019, -0.3633, 1.4970, ..., -1.1070, 1.8470, 0.7881],
...,
[-1.6522, -1.8433, -0.7053, ..., -0.3111, 1.1393, -0.4392],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.4081, -1.4504, -0.2482, ..., 1.1472, -0.2720, 1.4572],
[-0.3877, 0.3024, -1.1404, ..., 2.0661, -0.6191, -0.9355],
[ 1.7248, -1.1906, -0.3196, ..., 0.4742, 2.1333, 0.7659],
...,
[-0.6860, -0.8323, 1.5771, ..., 0.9945, -0.1853, 0.8436],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[-0.3277, -0.7792, 0.8676, ..., 0.2718, -1.9822, 0.4135],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.9245, 1.2509, 0.0326, ..., 0.1467, 1.1677, -0.9162],
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
[ 1.0915, -1.1170, -0.3247, ..., -0.9190, 1.1993, 0.6716],
...,
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
[ 0.0290, -0.8582, -0.3564, ..., 1.0203, 0.7273, 0.1357],
...,
[-0.5816, 0.0080, 1.8231, ..., -1.1851, 0.4162, -0.0030],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.0426435470581055, 3.9631638526916504)
mean: -0.04280959814786911
std: 0.9980056881904602
mamba_block1.inp_proj.weight gradient: 9.450679499423131e-06
mamba_block1.inp_proj.bias gradient: 1.5937728676362894e-05
mamba_block1.out_proj.weight gradient: 7.879294571466744e-05
mamba_block1.out_proj.bias gradient: 0.001937422202900052
mamba_block1.D.weight gradient: 2.1698680939152837e-05
mamba_block1.D.bias gradient: 2.7504667741595767e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.608612132666167e-06
mamba_block1.S6.fc1.bias gradient: 5.172527835384244e-06
mamba_block1.S6.fc2.weight gradient: 2.2741836801287718e-05
mamba_block1.S6.fc2.bias gradient: 3.544551145751029e-05
mamba_block1.S6.fc3.weight gradient: 2.2189064111444168e-05
mamba_block1.S6.fc3.bias gradient: 3.465290501480922e-05
mamba_block1.conv.weight gradient: 5.530563066713512e-05
mamba_block1.conv.bias gradient: 8.797837836027611e-06
mamba_block1.conv_linear.weight gradient: 2.0959831090294756e-05
mamba_block1.conv_linear.bias gradient: 6.4122024923563e-05
mamba_block1.norm.weight gradient: 6.733871941833058e-06
mamba_block2.inp_proj.weight gradient: 0.008612259291112423
mamba_block2.inp_proj.bias gradient: 0.0030379192903637886
mamba_block2.out_proj.weight gradient: 0.008520158939063549
mamba_block2.out_proj.bias gradient: 0.021844662725925446
mamba_block2.D.weight gradient: 0.0048363665118813515
mamba_block2.D.bias gradient: 0.0017059911042451859
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0011253735283389688
mamba_block2.S6.fc1.bias gradient: 0.0013490341370925307
mamba_block2.S6.fc2.weight gradient: 0.0030624449718743563
mamba_block2.S6.fc2.bias gradient: 0.0040188198909163475
mamba_block2.S6.fc3.weight gradient: 0.0028361633885651827
mamba_block2.S6.fc3.bias gradient: 0.003677345346659422
mamba_block2.conv.weight gradient: 0.01234238687902689
mamba_block2.conv.bias gradient: 0.000914952193852514
mamba_block2.conv_linear.weight gradient: 0.009588218294084072
mamba_block2.conv_linear.bias gradient: 0.0058879912830889225
mamba_block2.norm.weight gradient: 0.0032451574224978685
mamba_block3.inp_proj.weight gradient: 0.09518136829137802
mamba_block3.inp_proj.bias gradient: 0.03354042023420334
mamba_block3.out_proj.weight gradient: 0.04517391696572304
mamba_block3.out_proj.bias gradient: 1.0457387844553523e-07
mamba_block3.D.weight gradient: 0.047621291130781174
mamba_block3.D.bias gradient: 0.016819199547171593
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.0044674440287053585
mamba_block3.S6.fc1.bias gradient: 0.0047797891311347485
mamba_block3.S6.fc2.weight gradient: 0.01866152696311474
mamba_block3.S6.fc2.bias gradient: 0.019014324992895126
mamba_block3.S6.fc3.weight gradient: 0.01939181052148342
mamba_block3.S6.fc3.bias gradient: 0.020019695162773132
mamba_block3.conv.weight gradient: 0.16745711863040924
mamba_block3.conv.bias gradient: 0.01679954305291176
mamba_block3.conv_linear.weight gradient: 0.07562368363142014
mamba_block3.conv_linear.bias gradient: 0.054438941180706024
mamba_block3.norm.weight gradient: 0.024364715442061424
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0260, 1.0139, 0.9936, ..., 1.0364, 1.0372, 0.9927],
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0260, 1.0139, 0.9936, ..., 1.0365, 1.0372, 0.9927],
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0260, 1.0139, 0.9936, ..., 1.0364, 1.0372, 0.9927],
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9962]],
...,
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0260, 1.0139, 0.9936, ..., 1.0365, 1.0372, 0.9927],
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0260, 1.0138, 0.9936, ..., 1.0365, 1.0372, 0.9927],
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950],
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0260, 1.0139, 0.9936, ..., 1.0365, 1.0372, 0.9927],
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.963355541229248, 1.0372397899627686)
mean: 1.0009523630142212
std: 0.008042216300964355
target = tensor([[[ 6.4588e-01, -9.7711e-01, 1.4713e-01, ..., -1.7452e+00,
2.4286e-02, 1.4304e-02],
[ 1.2286e+00, -1.6981e+00, 1.1041e-01, ..., -1.3688e+00,
-4.3559e-01, 2.2583e-01],
[ 5.7543e-01, -1.2123e+00, 1.6030e+00, ..., -8.3098e-01,
-2.7845e+00, -2.1074e-02],
...,
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01,
-3.2478e-01, -6.1987e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-1.7117e+00, 7.8182e-01, 1.7233e-01, ..., 1.2136e+00,
3.9576e-01, -4.3173e-01],
[ 4.4899e-01, -2.4530e+00, -1.6500e-01, ..., -1.3791e-01,
-4.4953e-02, 3.7787e-01],
[-9.2340e-01, 1.3582e+00, 1.4513e+00, ..., -3.5925e-01,
-1.2063e+00, -1.5141e-01],
...,
[ 1.2389e+00, 1.4293e-01, -9.1111e-01, ..., -5.7567e-02,
9.1207e-01, 5.4976e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00,
-9.0730e-01, 4.8345e-02],
[ 1.0462e+00, 9.3372e-01, 9.1681e-01, ..., 4.8498e-01,
5.8902e-01, -9.3716e-02],
[ 5.0676e-01, -9.3567e-01, 5.5966e-01, ..., -2.6422e-01,
1.4782e+00, 2.3104e+00],
...,
[ 1.0110e+00, 5.1812e-01, -9.6063e-01, ..., 6.9258e-01,
2.0773e-01, -8.0699e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
...,
[[-4.0720e-01, -5.6419e-01, 8.8167e-01, ..., 7.7059e-01,
-4.5208e-01, -3.7696e-01],
[-2.8159e-02, 8.7647e-01, 3.6170e-01, ..., -8.5379e-01,
5.3774e-01, -1.6134e+00],
[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00,
1.8470e+00, 7.8808e-01],
...,
[-3.2770e-01, -7.7916e-01, 8.6764e-01, ..., 2.7178e-01,
-1.9822e+00, 4.1346e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-1.5331e-01, 5.9299e-01, 1.3224e-01, ..., -1.6126e+00,
-6.2350e-01, -1.3132e+00],
[-7.7284e-01, 7.0358e-01, -2.8840e-02, ..., -1.3084e+00,
-3.0288e-01, -8.2964e-01],
[ 1.9111e-01, -1.2776e+00, -1.7906e-01, ..., -1.6976e-01,
-3.4747e-01, 1.2224e+00],
...,
[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00,
1.8470e+00, 7.8808e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00,
1.8470e+00, 7.8808e-01],
[-1.7242e+00, 4.8190e-01, 1.8281e+00, ..., 4.0987e-01,
-2.7694e-01, -1.8146e-01],
[-7.2930e-02, -1.5876e+00, -1.7188e-01, ..., 1.2421e+00,
7.0656e-01, 4.5039e-01],
...,
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]]], device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 4.436890602111816)
mean: -0.045185498893260956
std: 1.000666618347168
mamba_block1.inp_proj.weight gradient: 7.601716788485646e-06
mamba_block1.inp_proj.bias gradient: 1.3704358934774064e-05
mamba_block1.out_proj.weight gradient: 8.511666965205222e-05
mamba_block1.out_proj.bias gradient: 0.0020013893954455853
mamba_block1.D.weight gradient: 2.2376692868419923e-05
mamba_block1.D.bias gradient: 2.6539159080130048e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.140715534755145e-06
mamba_block1.S6.fc1.bias gradient: 4.521216851571808e-06
mamba_block1.S6.fc2.weight gradient: 1.8029242710326798e-05
mamba_block1.S6.fc2.bias gradient: 2.9236758564366028e-05
mamba_block1.S6.fc3.weight gradient: 1.7625890905037522e-05
mamba_block1.S6.fc3.bias gradient: 2.8570417271112092e-05
mamba_block1.conv.weight gradient: 5.065084042144008e-05
mamba_block1.conv.bias gradient: 8.323479960381519e-06
mamba_block1.conv_linear.weight gradient: 1.7244665286852978e-05
mamba_block1.conv_linear.bias gradient: 5.684297138941474e-05
mamba_block1.norm.weight gradient: 5.099446298117982e-06
mamba_block2.inp_proj.weight gradient: 0.00868897046893835
mamba_block2.inp_proj.bias gradient: 0.0030649492982774973
mamba_block2.out_proj.weight gradient: 0.008559616282582283
mamba_block2.out_proj.bias gradient: 0.020069818943738937
mamba_block2.D.weight gradient: 0.005039901006966829
mamba_block2.D.bias gradient: 0.0017777711618691683
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0011198658030480146
mamba_block2.S6.fc1.bias gradient: 0.0014025933342054486
mamba_block2.S6.fc2.weight gradient: 0.003106378484517336
mamba_block2.S6.fc2.bias gradient: 0.004144839011132717
mamba_block2.S6.fc3.weight gradient: 0.002879718318581581
mamba_block2.S6.fc3.bias gradient: 0.003787883324548602
mamba_block2.conv.weight gradient: 0.011892883107066154
mamba_block2.conv.bias gradient: 0.0009201067732647061
mamba_block2.conv_linear.weight gradient: 0.0095598753541708
mamba_block2.conv_linear.bias gradient: 0.006426714826375246
mamba_block2.norm.weight gradient: 0.003324520541355014
mamba_block3.inp_proj.weight gradient: 0.09168072789907455
mamba_block3.inp_proj.bias gradient: 0.03229833021759987
mamba_block3.out_proj.weight gradient: 0.04626988619565964
mamba_block3.out_proj.bias gradient: 7.785879319044398e-08
mamba_block3.D.weight gradient: 0.04269688203930855
mamba_block3.D.bias gradient: 0.015081201680004597
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004488199949264526
mamba_block3.S6.fc1.bias gradient: 0.004849399905651808
mamba_block3.S6.fc2.weight gradient: 0.015625562518835068
mamba_block3.S6.fc2.bias gradient: 0.013996592722833157
mamba_block3.S6.fc3.weight gradient: 0.016473641619086266
mamba_block3.S6.fc3.bias gradient: 0.014759422279894352
mamba_block3.conv.weight gradient: 0.16740712523460388
mamba_block3.conv.bias gradient: 0.01699674315750599
mamba_block3.conv_linear.weight gradient: 0.07384508848190308
mamba_block3.conv_linear.bias gradient: 0.049471717327833176
mamba_block3.norm.weight gradient: 0.02215823344886303
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0262, 1.0139, 0.9936, ..., 1.0366, 1.0374, 0.9927],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0261, 1.0139, 0.9937, ..., 1.0366, 1.0374, 0.9927],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0261, 1.0139, 0.9936, ..., 1.0366, 1.0373, 0.9927],
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9962]],
...,
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0261, 1.0139, 0.9936, ..., 1.0366, 1.0374, 0.9927],
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0261, 1.0139, 0.9936, ..., 1.0366, 1.0373, 0.9927],
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0261, 1.0139, 0.9936, ..., 1.0366, 1.0373, 0.9927],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.963310956954956, 1.037413239479065)
mean: 1.0009527206420898
std: 0.008052636869251728
target = tensor([[[-0.0729, -1.5876, -0.1719, ..., 1.2421, 0.7066, 0.4504],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
...,
[ 1.0504, 1.2879, 1.0797, ..., -0.6484, -1.7104, 0.3437],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640],
[ 0.9420, -0.5706, -0.0884, ..., -0.8730, 1.7595, -1.3956],
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483],
...,
[ 0.7351, -1.4100, 0.0052, ..., 0.4583, 1.4485, -0.0438],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.0589, 0.7468, 0.2544, ..., 0.0787, 0.4220, 0.6745],
[-1.0592, 0.1051, -0.3675, ..., -0.1518, -0.8563, -1.1461],
[ 1.0504, 1.2879, 1.0797, ..., -0.6484, -1.7104, 0.3437],
...,
[-0.1842, -1.0219, 1.0257, ..., -0.1165, -0.2031, -0.5445],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[-0.1081, 0.8790, 0.6781, ..., -0.5866, -0.1624, 0.4462],
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
...,
[-1.1794, -1.5216, 0.0929, ..., 0.0363, 1.0894, 2.1755],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.3699, 0.2790, 0.5842, ..., 1.5279, 0.0930, -1.2546],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[-0.0282, 0.8765, 0.3617, ..., -0.8538, 0.5377, -1.6134],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.1884, 0.9860, -0.6278, ..., 0.4238, -1.8099, -0.7295],
[-0.2976, -2.1353, -0.2941, ..., -0.8635, 0.5327, 0.7513],
[-0.7571, -1.6050, -0.0124, ..., -0.9880, -0.9499, 0.8033],
...,
[-0.7124, -0.1175, 0.4958, ..., 0.8150, 0.6772, 1.4007],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 4.436890602111816)
mean: -0.04132794588804245
std: 0.9972973465919495
mamba_block1.inp_proj.weight gradient: 8.05390391178662e-06
mamba_block1.inp_proj.bias gradient: 1.3198721717344597e-05
mamba_block1.out_proj.weight gradient: 8.135655662044883e-05
mamba_block1.out_proj.bias gradient: 0.0019478988833725452
mamba_block1.D.weight gradient: 2.6859646823140793e-05
mamba_block1.D.bias gradient: 2.88037299469579e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.693293592732516e-06
mamba_block1.S6.fc1.bias gradient: 5.4035172070143744e-06
mamba_block1.S6.fc2.weight gradient: 2.4369408492930233e-05
mamba_block1.S6.fc2.bias gradient: 3.828832268482074e-05
mamba_block1.S6.fc3.weight gradient: 2.356448931095656e-05
mamba_block1.S6.fc3.bias gradient: 3.708050280692987e-05
mamba_block1.conv.weight gradient: 5.2881990995956585e-05
mamba_block1.conv.bias gradient: 8.851877282722853e-06
mamba_block1.conv_linear.weight gradient: 1.8447050024406053e-05
mamba_block1.conv_linear.bias gradient: 6.762157136108726e-05
mamba_block1.norm.weight gradient: 6.673816642432939e-06
mamba_block2.inp_proj.weight gradient: 0.008638471364974976
mamba_block2.inp_proj.bias gradient: 0.0030471261125057936
mamba_block2.out_proj.weight gradient: 0.00904333870857954
mamba_block2.out_proj.bias gradient: 0.022471558302640915
mamba_block2.D.weight gradient: 0.005017615854740143
mamba_block2.D.bias gradient: 0.0017699210438877344
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0011227426584810019
mamba_block2.S6.fc1.bias gradient: 0.0013828028459101915
mamba_block2.S6.fc2.weight gradient: 0.0031488672830164433
mamba_block2.S6.fc2.bias gradient: 0.004178797360509634
mamba_block2.S6.fc3.weight gradient: 0.002923778258264065
mamba_block2.S6.fc3.bias gradient: 0.0038226121105253696
mamba_block2.conv.weight gradient: 0.012409028597176075
mamba_block2.conv.bias gradient: 0.0009441959555260837
mamba_block2.conv_linear.weight gradient: 0.00955372303724289
mamba_block2.conv_linear.bias gradient: 0.006237642839550972
mamba_block2.norm.weight gradient: 0.003364289877936244
mamba_block3.inp_proj.weight gradient: 0.1019379273056984
mamba_block3.inp_proj.bias gradient: 0.03592952340841293
mamba_block3.out_proj.weight gradient: 0.048958100378513336
mamba_block3.out_proj.bias gradient: 4.5786606506226235e-08
mamba_block3.D.weight gradient: 0.05141298845410347
mamba_block3.D.bias gradient: 0.018162697553634644
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004572192206978798
mamba_block3.S6.fc1.bias gradient: 0.005084225907921791
mamba_block3.S6.fc2.weight gradient: 0.016733428463339806
mamba_block3.S6.fc2.bias gradient: 0.018023964017629623
mamba_block3.S6.fc3.weight gradient: 0.017616454511880875
mamba_block3.S6.fc3.bias gradient: 0.018776189535856247
mamba_block3.conv.weight gradient: 0.16808335483074188
mamba_block3.conv.bias gradient: 0.017074599862098694
mamba_block3.conv_linear.weight gradient: 0.08281093835830688
mamba_block3.conv_linear.bias gradient: 0.060874998569488525
mamba_block3.norm.weight gradient: 0.024700812995433807
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0262, 1.0139, 0.9937, ..., 1.0367, 1.0375, 0.9927],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0262, 1.0140, 0.9936, ..., 1.0367, 1.0375, 0.9927],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0262, 1.0139, 0.9937, ..., 1.0367, 1.0376, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]],
...,
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0262, 1.0139, 0.9937, ..., 1.0367, 1.0375, 0.9927],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0013, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0262, 1.0139, 0.9937, ..., 1.0367, 1.0375, 0.9927],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0262, 1.0139, 0.9937, ..., 1.0367, 1.0375, 0.9927],
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9962]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9632824659347534, 1.0375888347625732)
mean: 1.000953197479248
std: 0.008062897250056267
target = tensor([[[ 0.2981, -0.4210, -1.5597, ..., -2.1300, -0.6522, 1.3287],
[ 0.1265, 0.3060, -1.2604, ..., 1.1243, -0.3889, -0.2856],
[-1.0045, -0.8447, 0.0927, ..., -0.8352, -1.6738, 0.2916],
...,
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
...,
[-0.5858, 2.4877, 0.2696, ..., -0.1860, 0.7473, 0.5435],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.8496, -0.4320, 0.5527, ..., -0.2742, 2.0447, -0.5175],
[-0.2976, -2.1353, -0.2941, ..., -0.8635, 0.5327, 0.7513],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[ 1.8766, 0.0779, -2.8239, ..., -0.9667, -0.3084, 1.0684],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[-0.1949, 1.6447, -0.3521, ..., -1.4622, 0.0887, 0.7248],
[ 0.2176, -0.9511, 0.3012, ..., -0.8989, -0.1549, 0.8165],
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
...,
[-0.4027, 1.6658, -0.0122, ..., -0.5772, 2.0100, -0.6190],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.6138, 0.4046, 0.1097, ..., 0.0617, -0.5610, 0.3161],
[-1.9648, 2.5084, 1.4522, ..., -1.1336, -1.4860, 1.4592],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[ 1.2397, -0.0055, 0.5789, ..., -1.6370, -0.4645, -1.3456],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.1453, 2.1408, 0.6240, ..., 0.6000, 0.3859, 1.1016],
[ 0.0083, -1.4979, -0.0571, ..., -0.1176, 1.0814, 0.6415],
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
...,
[-0.9827, -0.1144, 2.1513, ..., 0.4412, -1.5209, -0.6943],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 3.9631638526916504)
mean: -0.04515944421291351
std: 0.9984781742095947
mamba_block1.inp_proj.weight gradient: 8.875967978383414e-06
mamba_block1.inp_proj.bias gradient: 1.4417765669350047e-05
mamba_block1.out_proj.weight gradient: 8.700188482180238e-05
mamba_block1.out_proj.bias gradient: 0.0019073453731834888
mamba_block1.D.weight gradient: 2.5550840291543864e-05
mamba_block1.D.bias gradient: 2.8005353669868782e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.4992071960004978e-06
mamba_block1.S6.fc1.bias gradient: 4.9961317927227356e-06
mamba_block1.S6.fc2.weight gradient: 2.349440001125913e-05
mamba_block1.S6.fc2.bias gradient: 3.7002042517997324e-05
mamba_block1.S6.fc3.weight gradient: 2.271077573823277e-05
mamba_block1.S6.fc3.bias gradient: 3.5931770980823785e-05
mamba_block1.conv.weight gradient: 5.5383403378073126e-05
mamba_block1.conv.bias gradient: 1.014738609228516e-05
mamba_block1.conv_linear.weight gradient: 1.983829861273989e-05
mamba_block1.conv_linear.bias gradient: 6.198877235874534e-05
mamba_block1.norm.weight gradient: 7.389627626253059e-06
mamba_block2.inp_proj.weight gradient: 0.008638354018330574
mamba_block2.inp_proj.bias gradient: 0.0030470658093690872
mamba_block2.out_proj.weight gradient: 0.008586056530475616
mamba_block2.out_proj.bias gradient: 0.021609429270029068
mamba_block2.D.weight gradient: 0.005029033869504929
mamba_block2.D.bias gradient: 0.0017739012837409973
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0011420527007430792
mamba_block2.S6.fc1.bias gradient: 0.0013999291695654392
mamba_block2.S6.fc2.weight gradient: 0.003110809251666069
mamba_block2.S6.fc2.bias gradient: 0.004045133478939533
mamba_block2.S6.fc3.weight gradient: 0.0028845700435340405
mamba_block2.S6.fc3.bias gradient: 0.003693824866786599
mamba_block2.conv.weight gradient: 0.012067034840583801
mamba_block2.conv.bias gradient: 0.0008926084847189486
mamba_block2.conv_linear.weight gradient: 0.009714269079267979
mamba_block2.conv_linear.bias gradient: 0.00610923208296299
mamba_block2.norm.weight gradient: 0.003266611136496067
mamba_block3.inp_proj.weight gradient: 0.09592823684215546
mamba_block3.inp_proj.bias gradient: 0.033799685537815094
mamba_block3.out_proj.weight gradient: 0.04734755679965019
mamba_block3.out_proj.bias gradient: 1.0258085580971965e-07
mamba_block3.D.weight gradient: 0.051352519541978836
mamba_block3.D.bias gradient: 0.018139084801077843
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004822483751922846
mamba_block3.S6.fc1.bias gradient: 0.005282685160636902
mamba_block3.S6.fc2.weight gradient: 0.017948182299733162
mamba_block3.S6.fc2.bias gradient: 0.017465978860855103
mamba_block3.S6.fc3.weight gradient: 0.01869887299835682
mamba_block3.S6.fc3.bias gradient: 0.018225835636258125
mamba_block3.conv.weight gradient: 0.16913020610809326
mamba_block3.conv.bias gradient: 0.01695968210697174
mamba_block3.conv_linear.weight gradient: 0.07513276487588882
mamba_block3.conv_linear.bias gradient: 0.04860156029462814
mamba_block3.norm.weight gradient: 0.024453913792967796
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0263, 1.0139, 0.9937, ..., 1.0369, 1.0377, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0263, 1.0139, 0.9937, ..., 1.0369, 1.0377, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0263, 1.0139, 0.9937, ..., 1.0369, 1.0377, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
...,
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0263, 1.0139, 0.9937, ..., 1.0369, 1.0377, 0.9927],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0263, 1.0139, 0.9937, ..., 1.0369, 1.0377, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0263, 1.0139, 0.9937, ..., 1.0368, 1.0377, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9632494449615479, 1.037784218788147)
mean: 1.0009536743164062
std: 0.008073143661022186
target = tensor([[[ 1.4736, 0.5671, 0.4209, ..., -0.5206, -0.6041, 1.2744],
[ 0.4655, -0.2961, -0.1109, ..., 0.1105, 0.1356, -0.1565],
[ 1.1350, -1.6268, 1.5461, ..., -1.8916, -1.9114, 1.1798],
...,
[ 0.0083, -1.4979, -0.0571, ..., -0.1176, 1.0814, 0.6415],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 2.4458, -0.4609, -0.6933, ..., 0.2152, 0.6763, -1.1608],
[ 0.6606, 0.6995, -1.1284, ..., 0.8394, -0.4208, -0.3543],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[ 0.9285, 0.5199, 1.0481, ..., 2.5334, 0.8552, -1.4535],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213],
[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132],
[ 1.4308, 0.4979, -3.0519, ..., -0.6231, 0.7584, -0.9699],
...,
[ 1.1250, 1.1728, 0.5452, ..., -1.0478, 0.4368, 1.5019],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.8267, -0.7581, -1.7703, ..., -1.0994, 0.0531, -0.9797],
[ 0.1753, -0.7625, 0.1469, ..., 0.3183, 0.1690, -0.1184],
[ 1.1250, 1.1728, 0.5452, ..., -1.0478, 0.4368, 1.5019],
...,
[-0.3877, 0.3024, -1.1404, ..., 2.0661, -0.6191, -0.9355],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.1214, 0.1502, -1.9141, ..., -0.7317, 0.2875, 0.2514],
[-0.7564, -0.0087, -1.0106, ..., 0.9032, -1.1468, -0.9196],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.2905, -0.4199, -0.1905, ..., -1.0879, 0.5756, 1.0819],
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
...,
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 3.9631638526916504)
mean: -0.03772750869393349
std: 0.9975149035453796
mamba_block1.inp_proj.weight gradient: 7.80315986048663e-06
mamba_block1.inp_proj.bias gradient: 1.3307480912772007e-05
mamba_block1.out_proj.weight gradient: 8.432415779680014e-05
mamba_block1.out_proj.bias gradient: 0.0019840870518237352
mamba_block1.D.weight gradient: 2.4237377147073857e-05
mamba_block1.D.bias gradient: 2.91150627163006e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.429665866860887e-06
mamba_block1.S6.fc1.bias gradient: 4.9839036364573985e-06
mamba_block1.S6.fc2.weight gradient: 1.875735142675694e-05
mamba_block1.S6.fc2.bias gradient: 2.9167822503950447e-05
mamba_block1.S6.fc3.weight gradient: 1.8030044884653762e-05
mamba_block1.S6.fc3.bias gradient: 2.8111589926993474e-05
mamba_block1.conv.weight gradient: 5.296375456964597e-05
mamba_block1.conv.bias gradient: 9.648112609283999e-06
mamba_block1.conv_linear.weight gradient: 1.8288330466020852e-05
mamba_block1.conv_linear.bias gradient: 6.101214967202395e-05
mamba_block1.norm.weight gradient: 6.388846941263182e-06
mamba_block2.inp_proj.weight gradient: 0.008736726827919483
mamba_block2.inp_proj.bias gradient: 0.003081721253693104
mamba_block2.out_proj.weight gradient: 0.008913129568099976
mamba_block2.out_proj.bias gradient: 0.021359387785196304
mamba_block2.D.weight gradient: 0.005379557143896818
mamba_block2.D.bias gradient: 0.0018975320272147655
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0012010665377601981
mamba_block2.S6.fc1.bias gradient: 0.0015100709861144423
mamba_block2.S6.fc2.weight gradient: 0.0034038915764540434
mamba_block2.S6.fc2.bias gradient: 0.0046056401915848255
mamba_block2.S6.fc3.weight gradient: 0.0031575202010571957
mamba_block2.S6.fc3.bias gradient: 0.0042189848609268665
mamba_block2.conv.weight gradient: 0.012530266307294369
mamba_block2.conv.bias gradient: 0.0009156710002571344
mamba_block2.conv_linear.weight gradient: 0.009922015480697155
mamba_block2.conv_linear.bias gradient: 0.0066932328045368195
mamba_block2.norm.weight gradient: 0.003500598017126322
mamba_block3.inp_proj.weight gradient: 0.10002566874027252
mamba_block3.inp_proj.bias gradient: 0.03524862602353096
mamba_block3.out_proj.weight gradient: 0.046485137194395065
mamba_block3.out_proj.bias gradient: 5.290740290320173e-08
mamba_block3.D.weight gradient: 0.0468018613755703
mamba_block3.D.bias gradient: 0.016533298417925835
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.005081711336970329
mamba_block3.S6.fc1.bias gradient: 0.005625939462333918
mamba_block3.S6.fc2.weight gradient: 0.018596788868308067
mamba_block3.S6.fc2.bias gradient: 0.019362082704901695
mamba_block3.S6.fc3.weight gradient: 0.019523371011018753
mamba_block3.S6.fc3.bias gradient: 0.020279493182897568
mamba_block3.conv.weight gradient: 0.16787144541740417
mamba_block3.conv.bias gradient: 0.017073169350624084
mamba_block3.conv_linear.weight gradient: 0.07526125758886337
mamba_block3.conv_linear.bias gradient: 0.04669470712542534
mamba_block3.norm.weight gradient: 0.024019045755267143
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0264, 1.0139, 0.9937, ..., 1.0370, 1.0379, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0264, 1.0139, 0.9937, ..., 1.0370, 1.0379, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0264, 1.0139, 0.9937, ..., 1.0370, 1.0379, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
...,
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0264, 1.0139, 0.9938, ..., 1.0370, 1.0379, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0264, 1.0139, 0.9937, ..., 1.0370, 1.0379, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0264, 1.0139, 0.9937, ..., 1.0370, 1.0379, 0.9928],
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9632093906402588, 1.0379583835601807)
mean: 1.000954270362854
std: 0.008083177730441093
target = tensor([[[-0.3148, -2.4389, -0.7981, ..., 1.4565, 0.6902, -2.8516],
[-0.8489, -1.0928, -0.8596, ..., -0.1898, -0.6665, -1.0761],
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
...,
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.6416, -0.2457, -2.1230, ..., -0.0060, -1.1015, 1.9065],
[ 0.6251, -0.3997, -0.4391, ..., 0.7783, -1.3073, -0.5255],
[-0.4625, 0.4049, -0.4079, ..., 0.6291, 1.8454, 0.2429],
...,
[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.1575, -1.4267, 1.2486, ..., -0.2827, 0.5434, -0.3321],
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640],
[-0.3470, 1.6160, -1.1352, ..., 1.0317, 1.0726, 0.2802],
...,
[-1.0850, -0.6438, 0.2743, ..., -1.1642, -0.8742, -0.2776],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.9105, 0.0598, -0.7111, ..., 0.9642, -0.3206, 0.5715],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[-2.0443, 0.7522, -0.2560, ..., 0.3880, 0.9740, 0.8830],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.3006, -1.3258, 0.1337, ..., 0.5020, -1.0170, -1.4881],
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
[-1.3397, -0.5167, 0.8265, ..., 0.2521, -0.3263, 0.4133],
...,
[ 2.2013, -0.1434, -0.3354, ..., 0.7899, -1.2002, 0.6800],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.8830, -1.9559, 0.9161, ..., -0.2516, -1.0361, -0.5355],
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640],
[-1.7728, -0.2004, -0.4214, ..., -0.8403, 0.5624, 1.3858],
...,
[ 0.1098, 1.6962, 1.1069, ..., 0.4857, 0.8313, 2.2824],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.293727397918701, 4.436890602111816)
mean: -0.03683827444911003
std: 0.9964061379432678
mamba_block1.inp_proj.weight gradient: 9.050143489730544e-06
mamba_block1.inp_proj.bias gradient: 1.3294705240696203e-05
mamba_block1.out_proj.weight gradient: 8.745616651140153e-05
mamba_block1.out_proj.bias gradient: 0.002077836310490966
mamba_block1.D.weight gradient: 2.5103221560129896e-05
mamba_block1.D.bias gradient: 2.983676859003026e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.631815843618824e-06
mamba_block1.S6.fc1.bias gradient: 5.111124210088747e-06
mamba_block1.S6.fc2.weight gradient: 1.814823190215975e-05
mamba_block1.S6.fc2.bias gradient: 2.881790351239033e-05
mamba_block1.S6.fc3.weight gradient: 1.7504753486718982e-05
mamba_block1.S6.fc3.bias gradient: 2.7879737899638712e-05
mamba_block1.conv.weight gradient: 5.321612843545154e-05
mamba_block1.conv.bias gradient: 8.180650183930993e-06
mamba_block1.conv_linear.weight gradient: 1.8778771845973097e-05
mamba_block1.conv_linear.bias gradient: 6.445188046200201e-05
mamba_block1.norm.weight gradient: 6.032561486790655e-06
mamba_block2.inp_proj.weight gradient: 0.009084882214665413
mamba_block2.inp_proj.bias gradient: 0.003204511944204569
mamba_block2.out_proj.weight gradient: 0.009371430613100529
mamba_block2.out_proj.bias gradient: 0.02175474539399147
mamba_block2.D.weight gradient: 0.005289588589221239
mamba_block2.D.bias gradient: 0.0018657789332792163
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0011827623238787055
mamba_block2.S6.fc1.bias gradient: 0.00147047801874578
mamba_block2.S6.fc2.weight gradient: 0.00330118159763515
mamba_block2.S6.fc2.bias gradient: 0.004464610014110804
mamba_block2.S6.fc3.weight gradient: 0.0030671891290694475
mamba_block2.S6.fc3.bias gradient: 0.00408173305913806
mamba_block2.conv.weight gradient: 0.012712618336081505
mamba_block2.conv.bias gradient: 0.0009551901021040976
mamba_block2.conv_linear.weight gradient: 0.010156966745853424
mamba_block2.conv_linear.bias gradient: 0.0067415423691272736
mamba_block2.norm.weight gradient: 0.0035081824753433466
mamba_block3.inp_proj.weight gradient: 0.09881533682346344
mamba_block3.inp_proj.bias gradient: 0.03481940180063248
mamba_block3.out_proj.weight gradient: 0.04826899245381355
mamba_block3.out_proj.bias gradient: 8.839717935416047e-08
mamba_block3.D.weight gradient: 0.04611104354262352
mamba_block3.D.bias gradient: 0.0162859745323658
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004664601292461157
mamba_block3.S6.fc1.bias gradient: 0.004930058494210243
mamba_block3.S6.fc2.weight gradient: 0.017418786883354187
mamba_block3.S6.fc2.bias gradient: 0.01641864702105522
mamba_block3.S6.fc3.weight gradient: 0.018391642719507217
mamba_block3.S6.fc3.bias gradient: 0.01718025654554367
mamba_block3.conv.weight gradient: 0.1693553924560547
mamba_block3.conv.bias gradient: 0.017287174239754677
mamba_block3.conv_linear.weight gradient: 0.08135779947042465
mamba_block3.conv_linear.bias gradient: 0.05725223198533058
mamba_block3.norm.weight gradient: 0.024166064336895943
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0266, 1.0139, 0.9938, ..., 1.0371, 1.0381, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0266, 1.0140, 0.9937, ..., 1.0371, 1.0381, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0266, 1.0140, 0.9938, ..., 1.0371, 1.0380, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
...,
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0265, 1.0139, 0.9938, ..., 1.0371, 1.0381, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0265, 1.0139, 0.9938, ..., 1.0371, 1.0381, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0265, 1.0139, 0.9938, ..., 1.0372, 1.0381, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9631598591804504, 1.0381572246551514)
mean: 1.0009552240371704
std: 0.008093072101473808
target = tensor([[[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483],
[-0.2070, -0.1024, -0.5238, ..., 0.6950, -0.0898, 0.7767],
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
...,
[ 1.0039, 1.2015, 1.3542, ..., 0.8332, 0.5095, 0.6952],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 1.2464, 0.3930, 0.7058, ..., -0.5867, 0.7455, -0.8427],
[ 1.2286, -1.6981, 0.1104, ..., -1.3688, -0.4356, 0.2258],
...,
[-0.1288, 0.0194, 0.3021, ..., -0.5487, -0.8879, 0.6104],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.1324, -1.4891, -1.6448, ..., -0.2209, -0.6961, 0.3296],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[-1.2696, 1.7538, -0.7169, ..., -0.5047, 0.6277, 1.0967],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[-1.4777, -0.0545, 0.2544, ..., 0.3233, 0.7367, 0.1191],
[-0.7127, 0.5620, -2.2520, ..., 0.6136, -1.2390, -0.4233],
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
...,
[-0.6169, -0.1545, -0.1991, ..., 1.8318, 0.8822, -0.0214],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[-0.0338, 1.4944, -0.6408, ..., -0.5996, 1.3481, -0.3070],
...,
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[-0.1816, -0.4553, -1.1590, ..., -0.4902, -0.3588, 1.5264],
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
...,
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 4.436890602111816)
mean: -0.03783148527145386
std: 0.999710202217102
mamba_block1.inp_proj.weight gradient: 9.031292393046897e-06
mamba_block1.inp_proj.bias gradient: 1.4736494449607562e-05
mamba_block1.out_proj.weight gradient: 8.168919157469645e-05
mamba_block1.out_proj.bias gradient: 0.0020221523009240627
mamba_block1.D.weight gradient: 2.2245831132750027e-05
mamba_block1.D.bias gradient: 2.855477032426279e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.2786444990051677e-06
mamba_block1.S6.fc1.bias gradient: 4.6202048906707205e-06
mamba_block1.S6.fc2.weight gradient: 2.992124791489914e-05
mamba_block1.S6.fc2.bias gradient: 4.635922232409939e-05
mamba_block1.S6.fc3.weight gradient: 2.8813263270421885e-05
mamba_block1.S6.fc3.bias gradient: 4.471436113817617e-05
mamba_block1.conv.weight gradient: 5.269802568363957e-05
mamba_block1.conv.bias gradient: 9.273887371819e-06
mamba_block1.conv_linear.weight gradient: 2.0372070139274e-05
mamba_block1.conv_linear.bias gradient: 6.995351577643305e-05
mamba_block1.norm.weight gradient: 6.127910182840424e-06
mamba_block2.inp_proj.weight gradient: 0.008746661245822906
mamba_block2.inp_proj.bias gradient: 0.003085183212533593
mamba_block2.out_proj.weight gradient: 0.008963636122643948
mamba_block2.out_proj.bias gradient: 0.021550053730607033
mamba_block2.D.weight gradient: 0.005240651313215494
mamba_block2.D.bias gradient: 0.0018485068576410413
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.001189987058751285
mamba_block2.S6.fc1.bias gradient: 0.0014902404509484768
mamba_block2.S6.fc2.weight gradient: 0.003241603495553136
mamba_block2.S6.fc2.bias gradient: 0.0039949361234903336
mamba_block2.S6.fc3.weight gradient: 0.0030070182401686907
mamba_block2.S6.fc3.bias gradient: 0.003639199770987034
mamba_block2.conv.weight gradient: 0.012088305316865444
mamba_block2.conv.bias gradient: 0.0008949778275564313
mamba_block2.conv_linear.weight gradient: 0.010031497105956078
mamba_block2.conv_linear.bias gradient: 0.0062803542241454124
mamba_block2.norm.weight gradient: 0.0034341691061854362
mamba_block3.inp_proj.weight gradient: 0.09730257093906403
mamba_block3.inp_proj.bias gradient: 0.03428181633353233
mamba_block3.out_proj.weight gradient: 0.047032639384269714
mamba_block3.out_proj.bias gradient: 6.363734428305179e-08
mamba_block3.D.weight gradient: 0.0502166673541069
mamba_block3.D.bias gradient: 0.017735188826918602
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004838158842176199
mamba_block3.S6.fc1.bias gradient: 0.0052208430133759975
mamba_block3.S6.fc2.weight gradient: 0.019451884552836418
mamba_block3.S6.fc2.bias gradient: 0.018285999074578285
mamba_block3.S6.fc3.weight gradient: 0.02031179703772068
mamba_block3.S6.fc3.bias gradient: 0.01904214359819889
mamba_block3.conv.weight gradient: 0.16995052993297577
mamba_block3.conv.bias gradient: 0.01724228635430336
mamba_block3.conv_linear.weight gradient: 0.07708907872438431
mamba_block3.conv_linear.bias gradient: 0.05492333322763443
mamba_block3.norm.weight gradient: 0.02481512911617756
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0266, 1.0139, 0.9938, ..., 1.0372, 1.0382, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0267, 1.0139, 0.9938, ..., 1.0372, 1.0382, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0266, 1.0139, 0.9938, ..., 1.0373, 1.0383, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
...,
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955],
[1.0267, 1.0139, 0.9938, ..., 1.0372, 1.0383, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0013, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0266, 1.0139, 0.9938, ..., 1.0373, 1.0383, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0013, 0.9948],
...,
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0266, 1.0139, 0.9938, ..., 1.0372, 1.0383, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9631233811378479, 1.0383379459381104)
mean: 1.0009559392929077
std: 0.008103198371827602
target = tensor([[[-0.3860, -0.5861, 0.8306, ..., 0.6317, -1.0193, -0.9245],
[-1.0569, -0.2186, -1.6387, ..., -1.4346, 0.8052, 0.0375],
[-0.2675, 0.8185, -0.1607, ..., -0.9674, -0.5626, 1.1895],
...,
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.5150, -0.6664, 2.2355, ..., 1.6450, -1.1488, -1.9170],
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
...,
[-0.1322, -1.6870, -0.4999, ..., 0.5103, 0.1246, -0.4105],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 1.1350, -1.6268, 1.5461, ..., -1.8916, -1.9114, 1.1798],
[ 1.9510, -1.3107, 1.1983, ..., 1.4472, -0.9458, 0.6607],
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044],
...,
[ 0.2615, 0.2679, -0.0044, ..., -0.9232, 0.1088, -2.0702],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.9234, -0.9421, 1.0451, ..., -0.4781, 0.0420, 0.2934],
[ 0.8607, 0.1229, -0.0035, ..., 0.5764, -2.2611, 1.0230],
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044],
...,
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 1.2071, 0.1010, 1.8911, ..., 1.6783, 0.7741, 0.0761],
[-0.5735, -0.7252, 0.3188, ..., 0.7167, 0.8917, 1.2515],
[ 0.8172, 2.1389, 1.0939, ..., 0.7351, -1.6642, 1.7776],
...,
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.5068, -0.9357, 0.5597, ..., -0.2642, 1.4782, 2.3104],
[ 0.6347, -0.6012, 0.3480, ..., 1.5082, -0.9452, 2.0558],
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640],
...,
[ 0.5068, -0.9357, 0.5597, ..., -0.2642, 1.4782, 2.3104],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-3.895033836364746, 3.9631638526916504)
mean: -0.03932111710309982
std: 0.9978443384170532
mamba_block1.inp_proj.weight gradient: 9.64433729677694e-06
mamba_block1.inp_proj.bias gradient: 1.7220121662830934e-05
mamba_block1.out_proj.weight gradient: 8.797573536867276e-05
mamba_block1.out_proj.bias gradient: 0.0020824200473725796
mamba_block1.D.weight gradient: 2.7073385354015045e-05
mamba_block1.D.bias gradient: 2.999591379193589e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.646446202765219e-06
mamba_block1.S6.fc1.bias gradient: 5.3650101108360104e-06
mamba_block1.S6.fc2.weight gradient: 2.464946919644717e-05
mamba_block1.S6.fc2.bias gradient: 3.9769645809428766e-05
mamba_block1.S6.fc3.weight gradient: 2.3926391804707237e-05
mamba_block1.S6.fc3.bias gradient: 3.8730660889996216e-05
mamba_block1.conv.weight gradient: 5.6322707678191364e-05
mamba_block1.conv.bias gradient: 1.1165049727424048e-05
mamba_block1.conv_linear.weight gradient: 2.2284859369392507e-05
mamba_block1.conv_linear.bias gradient: 7.244812877615914e-05
mamba_block1.norm.weight gradient: 6.758013569196919e-06
mamba_block2.inp_proj.weight gradient: 0.008704984560608864
mamba_block2.inp_proj.bias gradient: 0.003070454578846693
mamba_block2.out_proj.weight gradient: 0.00883528869599104
mamba_block2.out_proj.bias gradient: 0.021135663613677025
mamba_block2.D.weight gradient: 0.005147555842995644
mamba_block2.D.bias gradient: 0.0018156523583456874
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.001187782734632492
mamba_block2.S6.fc1.bias gradient: 0.0014818564523011446
mamba_block2.S6.fc2.weight gradient: 0.003225495107471943
mamba_block2.S6.fc2.bias gradient: 0.004239422734826803
mamba_block2.S6.fc3.weight gradient: 0.002991206245496869
mamba_block2.S6.fc3.bias gradient: 0.0038775706198066473
mamba_block2.conv.weight gradient: 0.012525566853582859
mamba_block2.conv.bias gradient: 0.0008841158705763519
mamba_block2.conv_linear.weight gradient: 0.009744850918650627
mamba_block2.conv_linear.bias gradient: 0.006499310024082661
mamba_block2.norm.weight gradient: 0.0034381034784018993
mamba_block3.inp_proj.weight gradient: 0.09548863023519516
mamba_block3.inp_proj.bias gradient: 0.03364047408103943
mamba_block3.out_proj.weight gradient: 0.04658520221710205
mamba_block3.out_proj.bias gradient: 7.18145756195554e-08
mamba_block3.D.weight gradient: 0.044484060257673264
mamba_block3.D.bias gradient: 0.015709370374679565
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004670510068535805
mamba_block3.S6.fc1.bias gradient: 0.0048215556889772415
mamba_block3.S6.fc2.weight gradient: 0.016718773171305656
mamba_block3.S6.fc2.bias gradient: 0.014307697303593159
mamba_block3.S6.fc3.weight gradient: 0.017529338598251343
mamba_block3.S6.fc3.bias gradient: 0.014922278933227062
mamba_block3.conv.weight gradient: 0.17111019790172577
mamba_block3.conv.bias gradient: 0.017431458458304405
mamba_block3.conv_linear.weight gradient: 0.07662271708250046
mamba_block3.conv_linear.bias gradient: 0.04832540079951286
mamba_block3.norm.weight gradient: 0.023405499756336212
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0268, 1.0139, 0.9938, ..., 1.0374, 1.0385, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0268, 1.0139, 0.9938, ..., 1.0374, 1.0385, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0268, 1.0140, 0.9938, ..., 1.0374, 1.0385, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
...,
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0267, 1.0139, 0.9939, ..., 1.0374, 1.0385, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0268, 1.0140, 0.9938, ..., 1.0374, 1.0384, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0268, 1.0139, 0.9939, ..., 1.0374, 1.0385, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9630730152130127, 1.0385202169418335)
mean: 1.0009565353393555
std: 0.008112873882055283
target = tensor([[[ 0.9143, 0.0766, -1.7929, ..., -0.3747, -0.3347, 1.5366],
[ 0.5731, -1.5050, -1.4184, ..., 1.9338, -1.1914, -0.8985],
[-1.1330, 1.9570, 0.5161, ..., 0.3537, 0.1684, 0.5828],
...,
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.6279, 2.0758, -0.0780, ..., 1.0208, -0.5319, -0.6121],
[ 0.6215, -0.0394, 0.0192, ..., -0.9825, 0.1665, -1.2019],
[-0.5596, -0.3416, 0.7026, ..., 0.5689, -0.4135, -0.8946],
...,
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213],
[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132],
[ 0.7243, -0.4449, -0.2085, ..., -0.3937, 0.7526, -0.2379],
...,
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[-1.4081, -1.4504, -0.2482, ..., 1.1472, -0.2720, 1.4572],
[-0.3877, 0.3024, -1.1404, ..., 2.0661, -0.6191, -0.9355],
[ 1.5206, 0.1764, 1.2191, ..., 0.9333, -0.5523, -0.3989],
...,
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132],
[-0.6782, -1.3975, 0.4929, ..., 0.3559, -0.3167, -1.3198],
[ 0.4125, 0.8607, -0.2114, ..., -0.0050, -0.0463, -1.4117],
...,
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.0061, 0.0731, 0.1958, ..., -0.5969, -0.9973, -2.2435],
[-0.0282, 0.8765, 0.3617, ..., -0.8538, 0.5377, -1.6134],
[ 1.2286, -1.6981, 0.1104, ..., -1.3688, -0.4356, 0.2258],
...,
[ 1.9510, -1.3107, 1.1983, ..., 1.4472, -0.9458, 0.6607],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 3.9631638526916504)
mean: -0.03662058338522911
std: 0.9983726143836975
mamba_block1.inp_proj.weight gradient: 7.730282050033566e-06
mamba_block1.inp_proj.bias gradient: 1.4676887076348066e-05
mamba_block1.out_proj.weight gradient: 8.539092959836125e-05
mamba_block1.out_proj.bias gradient: 0.002003992907702923
mamba_block1.D.weight gradient: 2.3007403797237203e-05
mamba_block1.D.bias gradient: 2.977289659611415e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.472130174486665e-06
mamba_block1.S6.fc1.bias gradient: 4.907219135930063e-06
mamba_block1.S6.fc2.weight gradient: 2.2677684683003463e-05
mamba_block1.S6.fc2.bias gradient: 3.604851008276455e-05
mamba_block1.S6.fc3.weight gradient: 2.215725544374436e-05
mamba_block1.S6.fc3.bias gradient: 3.528478919179179e-05
mamba_block1.conv.weight gradient: 5.533275179914199e-05
mamba_block1.conv.bias gradient: 9.216477337758988e-06
mamba_block1.conv_linear.weight gradient: 1.9760547729674727e-05
mamba_block1.conv_linear.bias gradient: 6.318444502539933e-05
mamba_block1.norm.weight gradient: 5.745764610765036e-06
mamba_block2.inp_proj.weight gradient: 0.009204614907503128
mamba_block2.inp_proj.bias gradient: 0.0032466622069478035
mamba_block2.out_proj.weight gradient: 0.009073683060705662
mamba_block2.out_proj.bias gradient: 0.021660299971699715
mamba_block2.D.weight gradient: 0.00532105565071106
mamba_block2.D.bias gradient: 0.0018768273293972015
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0012535410933196545
mamba_block2.S6.fc1.bias gradient: 0.0015219313791021705
mamba_block2.S6.fc2.weight gradient: 0.00339127192273736
mamba_block2.S6.fc2.bias gradient: 0.004481762647628784
mamba_block2.S6.fc3.weight gradient: 0.0031439471058547497
mamba_block2.S6.fc3.bias gradient: 0.004101179540157318
mamba_block2.conv.weight gradient: 0.01293809525668621
mamba_block2.conv.bias gradient: 0.0009332753252238035
mamba_block2.conv_linear.weight gradient: 0.010446547530591488
mamba_block2.conv_linear.bias gradient: 0.006634898949414492
mamba_block2.norm.weight gradient: 0.0034382217563688755
mamba_block3.inp_proj.weight gradient: 0.10265932977199554
mamba_block3.inp_proj.bias gradient: 0.03616949915885925
mamba_block3.out_proj.weight gradient: 0.04556775465607643
mamba_block3.out_proj.bias gradient: 6.957547071806403e-08
mamba_block3.D.weight gradient: 0.0451095886528492
mamba_block3.D.bias gradient: 0.015929805114865303
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004865700379014015
mamba_block3.S6.fc1.bias gradient: 0.005076521076261997
mamba_block3.S6.fc2.weight gradient: 0.01791907660663128
mamba_block3.S6.fc2.bias gradient: 0.015313433483242989
mamba_block3.S6.fc3.weight gradient: 0.01880076713860035
mamba_block3.S6.fc3.bias gradient: 0.016120247542858124
mamba_block3.conv.weight gradient: 0.1731170117855072
mamba_block3.conv.bias gradient: 0.017587197944521904
mamba_block3.conv_linear.weight gradient: 0.08034130185842514
mamba_block3.conv_linear.bias gradient: 0.048483993858098984
mamba_block3.norm.weight gradient: 0.023936070501804352
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0268, 1.0140, 0.9939, ..., 1.0375, 1.0386, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0269, 1.0140, 0.9939, ..., 1.0375, 1.0386, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0269, 1.0139, 0.9939, ..., 1.0375, 1.0387, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
...,
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0269, 1.0139, 0.9939, ..., 1.0375, 1.0386, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0269, 1.0139, 0.9939, ..., 1.0375, 1.0386, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0269, 1.0139, 0.9939, ..., 1.0375, 1.0386, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.963013768196106, 1.0387181043624878)
mean: 1.0009573698043823
std: 0.0081221554428339
target = tensor([[[-0.1301, -0.9992, 0.0145, ..., -0.3068, 0.9142, -2.1975],
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
[ 0.5060, 2.4795, 0.5965, ..., -0.9060, -0.1549, 0.4353],
...,
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.3053, -0.2988, -1.8848, ..., 1.0646, -0.5250, -0.6723],
[ 0.1819, -1.5667, -1.6287, ..., -1.8686, 0.2220, 0.7203],
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
...,
[-0.0729, -1.5876, -0.1719, ..., 1.2421, 0.7066, 0.4504],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213],
[-0.6782, -1.3975, 0.4929, ..., 0.3559, -0.3167, -1.3198],
[ 0.4125, 0.8607, -0.2114, ..., -0.0050, -0.0463, -1.4117],
...,
[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132],
[-0.7129, -0.4076, -0.2963, ..., 1.9239, -1.4047, 0.4096],
[ 0.4125, 0.8607, -0.2114, ..., -0.0050, -0.0463, -1.4117],
...,
[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.8332, -0.0094, 0.1628, ..., -0.1734, -0.2614, -1.1598],
[-1.0951, 0.4374, 1.1074, ..., 0.7239, 0.9897, 0.3390],
[-0.8718, 0.4088, -0.5637, ..., 0.8139, 0.3387, -0.3325],
...,
[-0.0729, -1.5876, -0.1719, ..., 1.2421, 0.7066, 0.4504],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.4021, 1.1737, -0.5641, ..., -0.0357, 0.1684, 0.1134],
[ 2.2047, -2.3811, 1.1213, ..., 0.0741, 0.3054, 0.0921],
[ 0.3262, -0.7823, -0.1636, ..., 0.3129, -0.0835, 0.3686],
...,
[ 0.3753, -0.1169, 0.4159, ..., 0.8816, -0.7008, 1.1613],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.0426435470581055, 3.970458745956421)
mean: -0.03180283308029175
std: 0.9980447292327881
mamba_block1.inp_proj.weight gradient: 8.23084428702714e-06
mamba_block1.inp_proj.bias gradient: 1.3557030797528569e-05
mamba_block1.out_proj.weight gradient: 8.309278200613335e-05
mamba_block1.out_proj.bias gradient: 0.0020105468574911356
mamba_block1.D.weight gradient: 2.385676998528652e-05
mamba_block1.D.bias gradient: 2.8827647838625126e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.5073132949037245e-06
mamba_block1.S6.fc1.bias gradient: 5.155464805284282e-06
mamba_block1.S6.fc2.weight gradient: 1.3048370419710409e-05
mamba_block1.S6.fc2.bias gradient: 2.1127018044353463e-05
mamba_block1.S6.fc3.weight gradient: 1.2743626029987354e-05
mamba_block1.S6.fc3.bias gradient: 2.0642875824705698e-05
mamba_block1.conv.weight gradient: 5.38126150786411e-05
mamba_block1.conv.bias gradient: 9.490980119153392e-06
mamba_block1.conv_linear.weight gradient: 1.9674673239933327e-05
mamba_block1.conv_linear.bias gradient: 5.9852962294826284e-05
mamba_block1.norm.weight gradient: 5.154767222848022e-06
mamba_block2.inp_proj.weight gradient: 0.009204426780343056
mamba_block2.inp_proj.bias gradient: 0.0032465667463839054
mamba_block2.out_proj.weight gradient: 0.00896873977035284
mamba_block2.out_proj.bias gradient: 0.021650521084666252
mamba_block2.D.weight gradient: 0.005361825693398714
mamba_block2.D.bias gradient: 0.0018912054365500808
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.001220333855599165
mamba_block2.S6.fc1.bias gradient: 0.0015079170698300004
mamba_block2.S6.fc2.weight gradient: 0.003332579042762518
mamba_block2.S6.fc2.bias gradient: 0.004285544157028198
mamba_block2.S6.fc3.weight gradient: 0.0030965039040893316
mamba_block2.S6.fc3.bias gradient: 0.003915925044566393
mamba_block2.conv.weight gradient: 0.012820214033126831
mamba_block2.conv.bias gradient: 0.0009201719076372683
mamba_block2.conv_linear.weight gradient: 0.01025470346212387
mamba_block2.conv_linear.bias gradient: 0.00649257143959403
mamba_block2.norm.weight gradient: 0.0035417950712144375
mamba_block3.inp_proj.weight gradient: 0.09837187826633453
mamba_block3.inp_proj.bias gradient: 0.034655362367630005
mamba_block3.out_proj.weight gradient: 0.04687173664569855
mamba_block3.out_proj.bias gradient: 8.309151411367566e-08
mamba_block3.D.weight gradient: 0.053607795387506485
mamba_block3.D.bias gradient: 0.01893215999007225
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.0047625149600207806
mamba_block3.S6.fc1.bias gradient: 0.005159389693289995
mamba_block3.S6.fc2.weight gradient: 0.0174313522875309
mamba_block3.S6.fc2.bias gradient: 0.01905071921646595
mamba_block3.S6.fc3.weight gradient: 0.0183222908526659
mamba_block3.S6.fc3.bias gradient: 0.019951820373535156
mamba_block3.conv.weight gradient: 0.17245595157146454
mamba_block3.conv.bias gradient: 0.017512831836938858
mamba_block3.conv_linear.weight gradient: 0.07903292030096054
mamba_block3.conv_linear.bias gradient: 0.05545603483915329
mamba_block3.norm.weight gradient: 0.02469041757285595
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9965, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0270, 1.0139, 0.9939, ..., 1.0377, 1.0388, 0.9929],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9965, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0270, 1.0139, 0.9939, ..., 1.0377, 1.0388, 0.9929],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0270, 1.0140, 0.9939, ..., 1.0376, 1.0388, 0.9928],
[1.0154, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9961]],
...,
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0270, 1.0139, 0.9939, ..., 1.0377, 1.0388, 0.9929],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0270, 1.0139, 0.9939, ..., 1.0377, 1.0389, 0.9929],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0270, 1.0139, 0.9939, ..., 1.0376, 1.0388, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9629841446876526, 1.0388917922973633)
mean: 1.0009583234786987
std: 0.008131702430546284
target = tensor([[[-2.2013, 1.4751, 0.8977, ..., -1.7997, -1.3911, -0.1680],
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483],
[ 1.0801, -0.3887, -0.2138, ..., 0.7030, -1.7206, 0.2015],
...,
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.0850, -0.6438, 0.2743, ..., -1.1642, -0.8742, -0.2776],
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
[-0.4621, 0.4395, 1.3246, ..., -0.5279, 0.6105, 2.4551],
...,
[ 0.6628, 2.2315, 0.2679, ..., 0.4018, 1.3974, 0.2715],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.3699, 0.2790, 0.5842, ..., 1.5279, 0.0930, -1.2546],
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044],
[-0.1491, -0.8928, -1.1765, ..., -0.9342, 2.1916, 0.8451],
...,
[-0.3843, 0.2086, -1.3855, ..., 0.5185, 0.1296, 0.5115],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.4125, 0.8607, -0.2114, ..., -0.0050, -0.0463, -1.4117],
[-0.8696, -0.4503, 0.3101, ..., 1.0256, -0.7886, -1.2446],
[-0.2637, 0.3004, 0.0593, ..., -1.0608, 0.3555, 0.8021],
...,
[ 1.4968, 0.2999, 0.0651, ..., -0.6530, -1.8364, 0.4741],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.0221, -1.3493, -0.6605, ..., 0.9500, -0.2841, 0.1124],
[-1.9098, -0.9699, -1.8455, ..., 0.6946, 1.9096, 0.4540],
[ 0.4838, -0.5874, -1.1409, ..., -0.1160, -0.5902, 0.5632],
...,
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.4428, 0.8034, -0.7362, ..., -0.6413, 2.3065, -0.3966],
[ 0.5679, 1.1681, -0.2152, ..., 0.4324, -0.3278, 0.3071],
[-0.6504, -1.3069, -0.1538, ..., 0.7510, -1.4239, -0.2393],
...,
[ 0.1019, -0.3633, 1.4970, ..., -1.1070, 1.8470, 0.7881],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.203547954559326, 4.436890602111816)
mean: -0.03840525448322296
std: 1.0000741481781006
mamba_block1.inp_proj.weight gradient: 8.991964023152832e-06
mamba_block1.inp_proj.bias gradient: 1.7170854334835894e-05
mamba_block1.out_proj.weight gradient: 9.231397416442633e-05
mamba_block1.out_proj.bias gradient: 0.00212532258592546
mamba_block1.D.weight gradient: 2.5002524125739e-05
mamba_block1.D.bias gradient: 3.0523628083756194e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.370713557160343e-06
mamba_block1.S6.fc1.bias gradient: 4.826616986974841e-06
mamba_block1.S6.fc2.weight gradient: 1.8368005839874968e-05
mamba_block1.S6.fc2.bias gradient: 2.9517530492739752e-05
mamba_block1.S6.fc3.weight gradient: 1.795031494111754e-05
mamba_block1.S6.fc3.bias gradient: 2.8881711841677316e-05
mamba_block1.conv.weight gradient: 5.597508788923733e-05
mamba_block1.conv.bias gradient: 9.704062904347666e-06
mamba_block1.conv_linear.weight gradient: 2.1619767721858807e-05
mamba_block1.conv_linear.bias gradient: 6.322540866676718e-05
mamba_block1.norm.weight gradient: 5.060106559540145e-06
mamba_block2.inp_proj.weight gradient: 0.009407997131347656
mamba_block2.inp_proj.bias gradient: 0.003318359376862645
mamba_block2.out_proj.weight gradient: 0.00935858953744173
mamba_block2.out_proj.bias gradient: 0.022997798398137093
mamba_block2.D.weight gradient: 0.00555233983322978
mamba_block2.D.bias gradient: 0.001958364387974143
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0012591216946020722
mamba_block2.S6.fc1.bias gradient: 0.0015584889333695173
mamba_block2.S6.fc2.weight gradient: 0.003501052735373378
mamba_block2.S6.fc2.bias gradient: 0.0046097771264612675
mamba_block2.S6.fc3.weight gradient: 0.0032506112474948168
mamba_block2.S6.fc3.bias gradient: 0.004216910805553198
mamba_block2.conv.weight gradient: 0.012959428131580353
mamba_block2.conv.bias gradient: 0.0009631924913264811
mamba_block2.conv_linear.weight gradient: 0.01051102951169014
mamba_block2.conv_linear.bias gradient: 0.006955367047339678
mamba_block2.norm.weight gradient: 0.0037049155216664076
mamba_block3.inp_proj.weight gradient: 0.10352003574371338
mamba_block3.inp_proj.bias gradient: 0.03647966310381889
mamba_block3.out_proj.weight gradient: 0.048568934202194214
mamba_block3.out_proj.bias gradient: 5.696314886449727e-08
mamba_block3.D.weight gradient: 0.048673514276742935
mamba_block3.D.bias gradient: 0.017192283645272255
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004940532147884369
mamba_block3.S6.fc1.bias gradient: 0.00535177206620574
mamba_block3.S6.fc2.weight gradient: 0.017920760437846184
mamba_block3.S6.fc2.bias gradient: 0.01845845952630043
mamba_block3.S6.fc3.weight gradient: 0.01888885162770748
mamba_block3.S6.fc3.bias gradient: 0.019463254138827324
mamba_block3.conv.weight gradient: 0.17237348854541779
mamba_block3.conv.bias gradient: 0.01762760430574417
mamba_block3.conv_linear.weight gradient: 0.08468104898929596
mamba_block3.conv_linear.bias gradient: 0.054214172065258026
mamba_block3.norm.weight gradient: 0.025246795266866684
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0271, 1.0139, 0.9939, ..., 1.0378, 1.0390, 0.9929],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0271, 1.0139, 0.9940, ..., 1.0378, 1.0390, 0.9928],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0271, 1.0140, 0.9939, ..., 1.0378, 1.0390, 0.9929],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
...,
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0271, 1.0140, 0.9940, ..., 1.0378, 1.0390, 0.9929],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0130, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9966, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0271, 1.0140, 0.9940, ..., 1.0378, 1.0390, 0.9929],
[1.0154, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951],
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0271, 1.0139, 0.9940, ..., 1.0378, 1.0390, 0.9929],
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9629332423210144, 1.0391098260879517)
mean: 1.0009595155715942
std: 0.008141160011291504
target = tensor([[[-1.8136, -0.6341, -0.1093, ..., 0.4384, -0.3662, -0.9972],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
[ 0.7458, 1.6577, 0.0364, ..., 1.2313, -0.1711, 0.0749],
...,
[ 1.4597, 1.8877, -0.5288, ..., -0.6553, -0.6894, -0.6879],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213],
[-0.6782, -1.3975, 0.4929, ..., 0.3559, -0.3167, -1.3198],
[ 0.4125, 0.8607, -0.2114, ..., -0.0050, -0.0463, -1.4117],
...,
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.8389, -0.8643, 0.4242, ..., 0.5848, 1.5457, -0.4353],
[-0.8738, -0.8262, 0.1785, ..., -0.7729, 0.3997, 1.0206],
[-0.0377, -0.0993, -0.3354, ..., -0.4587, 2.1620, -1.0884],
...,
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.9234, -0.9421, 1.0451, ..., -0.4781, 0.0420, 0.2934],
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
...,
[ 0.8195, 1.3160, 0.7905, ..., 0.3638, -0.4126, 0.3174],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.5068, -0.9357, 0.5597, ..., -0.2642, 1.4782, 2.3104],
[ 0.6347, -0.6012, 0.3480, ..., 1.5082, -0.9452, 2.0558],
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640],
...,
[-0.3877, 0.3024, -1.1404, ..., 2.0661, -0.6191, -0.9355],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.8417, 0.6482, 2.4525, ..., -0.0736, -1.3844, -1.5417],
[-1.6539, -0.2380, 1.2548, ..., -0.2029, -0.3846, -1.4885],
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
...,
[ 0.5459, -1.0019, 1.6465, ..., -0.7943, 1.1101, 1.4487],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 3.9631638526916504)
mean: -0.043465666472911835
std: 0.9981724619865417
mamba_block1.inp_proj.weight gradient: 9.746233445184771e-06
mamba_block1.inp_proj.bias gradient: 1.613192944205366e-05
mamba_block1.out_proj.weight gradient: 9.361501724924892e-05
mamba_block1.out_proj.bias gradient: 0.002090686932206154
mamba_block1.D.weight gradient: 2.397046955593396e-05
mamba_block1.D.bias gradient: 2.898801540140994e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.917689809895819e-06
mamba_block1.S6.fc1.bias gradient: 5.5823461480031256e-06
mamba_block1.S6.fc2.weight gradient: 2.6315385184716433e-05
mamba_block1.S6.fc2.bias gradient: 4.1125862480839714e-05
mamba_block1.S6.fc3.weight gradient: 2.5236464352929033e-05
mamba_block1.S6.fc3.bias gradient: 3.954580824938603e-05
mamba_block1.conv.weight gradient: 5.663771298713982e-05
mamba_block1.conv.bias gradient: 1.0728360393841285e-05
mamba_block1.conv_linear.weight gradient: 2.2806825654697604e-05
mamba_block1.conv_linear.bias gradient: 7.061885844450444e-05
mamba_block1.norm.weight gradient: 6.072532869438874e-06
mamba_block2.inp_proj.weight gradient: 0.009219370782375336
mamba_block2.inp_proj.bias gradient: 0.003251793095842004
mamba_block2.out_proj.weight gradient: 0.009203227236866951
mamba_block2.out_proj.bias gradient: 0.022098751738667488
mamba_block2.D.weight gradient: 0.00542529858648777
mamba_block2.D.bias gradient: 0.0019135409966111183
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0012004896998405457
mamba_block2.S6.fc1.bias gradient: 0.0015050115762278438
mamba_block2.S6.fc2.weight gradient: 0.0033298747148364782
mamba_block2.S6.fc2.bias gradient: 0.0042640469036996365
mamba_block2.S6.fc3.weight gradient: 0.0030927204061299562
mamba_block2.S6.fc3.bias gradient: 0.0038947444409132004
mamba_block2.conv.weight gradient: 0.012392286211252213
mamba_block2.conv.bias gradient: 0.0009247513953596354
mamba_block2.conv_linear.weight gradient: 0.010260899551212788
mamba_block2.conv_linear.bias gradient: 0.006446880754083395
mamba_block2.norm.weight gradient: 0.0036036702804267406
mamba_block3.inp_proj.weight gradient: 0.09711047261953354
mamba_block3.inp_proj.bias gradient: 0.034207314252853394
mamba_block3.out_proj.weight gradient: 0.04665215313434601
mamba_block3.out_proj.bias gradient: 8.800382289564368e-08
mamba_block3.D.weight gradient: 0.04661861062049866
mamba_block3.D.bias gradient: 0.016462555155158043
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004771151579916477
mamba_block3.S6.fc1.bias gradient: 0.004979342687875032
mamba_block3.S6.fc2.weight gradient: 0.017470696941018105
mamba_block3.S6.fc2.bias gradient: 0.018941262736916542
mamba_block3.S6.fc3.weight gradient: 0.018347127363085747
mamba_block3.S6.fc3.bias gradient: 0.01973014324903488
mamba_block3.conv.weight gradient: 0.176979660987854
mamba_block3.conv.bias gradient: 0.017966141924262047
mamba_block3.conv_linear.weight gradient: 0.0760805532336235
mamba_block3.conv_linear.bias gradient: 0.0522383488714695
mamba_block3.norm.weight gradient: 0.024708587676286697
DEBUGGING IS ON !!!
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0272, 1.0140, 0.9940, ..., 1.0379, 1.0392, 0.9929],
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9966, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0272, 1.0139, 0.9940, ..., 1.0379, 1.0392, 0.9929],
[1.0155, 1.0009, 1.0024, ..., 1.0066, 1.0130, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9966, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0271, 1.0139, 0.9940, ..., 1.0379, 1.0391, 0.9929],
[1.0155, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
...,
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0013, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955],
[1.0272, 1.0139, 0.9940, ..., 1.0379, 1.0392, 0.9929],
[1.0155, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0272, 1.0139, 0.9940, ..., 1.0379, 1.0392, 0.9929],
[1.0155, 1.0009, 1.0024, ..., 1.0066, 1.0130, 0.9961]],
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0272, 1.0140, 0.9940, ..., 1.0379, 1.0392, 0.9929],
[1.0155, 1.0009, 1.0024, ..., 1.0066, 1.0130, 0.9961]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9628950357437134, 1.0392659902572632)
mean: 1.0009607076644897
std: 0.008150782436132431
target = tensor([[[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00,
-9.0730e-01, 4.8345e-02],
[ 1.3781e-01, -3.8981e-01, 4.6194e-01, ..., 1.9883e-01,
-3.7158e-01, 3.5527e-01],
[-1.8160e-01, -4.5527e-01, -1.1590e+00, ..., -4.9020e-01,
-3.5882e-01, 1.5264e+00],
...,
[-1.3817e+00, 1.9420e-01, -1.2593e+00, ..., 5.9164e-02,
4.6726e-01, 3.5826e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-1.8078e-01, 2.4252e-01, -6.3844e-01, ..., -7.8935e-02,
6.0249e-01, -5.8976e-01],
[-1.5419e+00, 1.8212e+00, 1.8157e+00, ..., -9.8702e-03,
-7.1342e-01, -2.5576e-01],
[ 6.8187e-01, 6.5033e-01, -7.8437e-02, ..., 1.2032e+00,
4.3723e-01, 5.0549e-02],
...,
[ 2.0079e+00, 2.8988e-01, 1.4619e-01, ..., 1.1677e-02,
1.5477e-01, 9.1439e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 1.1616e+00, -1.3193e+00, -6.1403e-01, ..., 1.9209e+00,
1.3262e+00, 5.9860e-02],
[ 3.5665e-01, -3.1154e-01, -1.2586e+00, ..., -9.5706e-01,
-2.0711e+00, -1.0293e+00],
[-6.5829e-01, -2.2524e-01, 2.0800e+00, ..., 7.8087e-01,
7.4104e-01, -1.9717e+00],
...,
[-5.4846e-01, -1.0681e+00, -1.4576e+00, ..., -9.8265e-01,
2.3030e+00, 1.3365e+00],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
...,
[[-1.2654e-01, -6.2001e-01, -2.4523e+00, ..., -1.2879e+00,
-2.9765e-01, 1.6772e+00],
[ 2.0127e-01, 3.0885e-01, 1.0572e+00, ..., 2.7429e-01,
-7.5508e-01, 3.9383e-01],
[ 6.6372e-01, 1.7408e-01, -3.5581e-01, ..., -1.4354e+00,
1.1672e+00, 4.4820e-02],
...,
[ 3.2237e-01, 5.8461e-01, -3.4121e-03, ..., -2.5780e-01,
-1.3302e+00, -5.8217e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 8.9742e-01, -8.8120e-01, 7.8961e-01, ..., -8.5716e-01,
-1.6618e+00, -3.6012e-01],
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02,
-6.2687e-01, -4.9886e-01],
...,
[ 1.7525e-01, -7.6253e-01, 1.4695e-01, ..., 3.1829e-01,
1.6904e-01, -1.1843e-01],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[-1.5811e-01, 1.7191e-01, 1.5170e+00, ..., -4.7812e-01,
4.5400e-02, 1.0040e+00],
[-7.9920e-01, 1.3202e+00, -1.7626e-01, ..., 1.0812e-01,
-7.9432e-02, -3.9323e-01],
[-3.4205e-01, 1.4121e+00, 2.6875e+00, ..., 2.3489e-01,
1.2428e-01, -6.1681e-02],
...,
[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00,
-9.0730e-01, 4.8345e-02],
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01,
1.3140e+00, 2.3112e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]]], device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.203547954559326, 4.436890602111816)
mean: -0.04427793249487877
std: 0.9962211847305298
mamba_block1.inp_proj.weight gradient: 9.651295840740204e-06
mamba_block1.inp_proj.bias gradient: 1.6675130609655753e-05
mamba_block1.out_proj.weight gradient: 9.673438034951687e-05
mamba_block1.out_proj.bias gradient: 0.002231176942586899
mamba_block1.D.weight gradient: 2.4628068786114454e-05
mamba_block1.D.bias gradient: 3.227106572012417e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.960161848226562e-06
mamba_block1.S6.fc1.bias gradient: 5.747951036028098e-06
mamba_block1.S6.fc2.weight gradient: 2.7824979042634368e-05
mamba_block1.S6.fc2.bias gradient: 4.349627124611288e-05
mamba_block1.S6.fc3.weight gradient: 2.6914129193755798e-05
mamba_block1.S6.fc3.bias gradient: 4.2169736843788996e-05
mamba_block1.conv.weight gradient: 5.887298539164476e-05
mamba_block1.conv.bias gradient: 1.2363690075289924e-05
mamba_block1.conv_linear.weight gradient: 2.250147554150317e-05
mamba_block1.conv_linear.bias gradient: 7.393556734314188e-05
mamba_block1.norm.weight gradient: 7.986875061760657e-06
mamba_block2.inp_proj.weight gradient: 0.00992958340793848
mamba_block2.inp_proj.bias gradient: 0.003502229694277048
mamba_block2.out_proj.weight gradient: 0.00934018474072218
mamba_block2.out_proj.bias gradient: 0.021850064396858215
mamba_block2.D.weight gradient: 0.005657918751239777
mamba_block2.D.bias gradient: 0.0019955879542976618
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0012419703416526318
mamba_block2.S6.fc1.bias gradient: 0.0015505834016948938
mamba_block2.S6.fc2.weight gradient: 0.0034699649550020695
mamba_block2.S6.fc2.bias gradient: 0.0044254641979932785
mamba_block2.S6.fc3.weight gradient: 0.0032270450610667467
mamba_block2.S6.fc3.bias gradient: 0.0040438235737383366
mamba_block2.conv.weight gradient: 0.01314946822822094
mamba_block2.conv.bias gradient: 0.000992308254353702
mamba_block2.conv_linear.weight gradient: 0.01082244236022234
mamba_block2.conv_linear.bias gradient: 0.006649897433817387
mamba_block2.norm.weight gradient: 0.003846563631668687
mamba_block3.inp_proj.weight gradient: 0.09649848937988281
mamba_block3.inp_proj.bias gradient: 0.033982012420892715
mamba_block3.out_proj.weight gradient: 0.04800761863589287
mamba_block3.out_proj.bias gradient: 5.8161077731710975e-08
mamba_block3.D.weight gradient: 0.04944797605276108
mamba_block3.D.bias gradient: 0.017467014491558075
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004728706553578377
mamba_block3.S6.fc1.bias gradient: 0.004932350944727659
mamba_block3.S6.fc2.weight gradient: 0.017793580889701843
mamba_block3.S6.fc2.bias gradient: 0.018198303878307343
mamba_block3.S6.fc3.weight gradient: 0.018742457032203674
mamba_block3.S6.fc3.bias gradient: 0.01912083476781845
mamba_block3.conv.weight gradient: 0.17541822791099548
mamba_block3.conv.bias gradient: 0.01787564344704151
mamba_block3.conv_linear.weight gradient: 0.07878080010414124
mamba_block3.conv_linear.bias gradient: 0.056361325085163116
mamba_block3.norm.weight gradient: 0.024798491969704628
DEBUGGING IS ON !!!
output = tensor([[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0273, 1.0139, 0.9940, ..., 1.0381, 1.0394, 0.9929],
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0273, 1.0140, 0.9940, ..., 1.0380, 1.0393, 0.9929],
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0273, 1.0140, 0.9940, ..., 1.0381, 1.0394, 0.9929],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
...,
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9948],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0273, 1.0140, 0.9940, ..., 1.0381, 1.0394, 0.9929],
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0273, 1.0140, 0.9940, ..., 1.0381, 1.0394, 0.9929],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0273, 1.0140, 0.9940, ..., 1.0380, 1.0394, 0.9929],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9628592133522034, 1.0394682884216309)
mean: 1.0009617805480957
std: 0.008160555735230446
target = tensor([[[ 0.1249, 0.1479, 0.4132, ..., -0.2172, -0.6020, 0.3062],
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
...,
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 2.3047, 0.6387, -1.2971, ..., -0.9008, -0.7687, -0.1274],
...,
[-2.1382, -0.4375, 0.1092, ..., -0.3279, 0.5643, 0.2475],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.0803, -0.3617, -0.1729, ..., -0.1031, 1.7060, -0.8089],
[ 0.4838, -0.5874, -1.1409, ..., -0.1160, -0.5902, 0.5632],
[-0.0287, -1.1162, -0.5596, ..., 1.2069, -1.2071, 0.0189],
...,
[-0.5417, 0.3940, 0.3822, ..., -0.3933, -1.1325, -0.0510],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 1.1773, 0.2168, 0.4060, ..., -0.1085, 0.1342, 0.5152],
[ 0.1019, -0.3633, 1.4970, ..., -1.1070, 1.8470, 0.7881],
[-0.4072, 0.9312, 1.0190, ..., -0.9175, 0.1262, -0.9890],
...,
[ 0.8941, -2.4687, 0.5529, ..., 0.0181, 0.2483, 0.0552],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 1.0778, -0.1816, -0.6237, ..., 0.5324, -0.4506, -0.2228],
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
...,
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417],
[-0.3334, 1.1392, 0.6457, ..., -0.8796, -1.0417, -0.8816],
...,
[ 0.0445, -1.4234, -0.5175, ..., 1.3668, -0.1756, 0.9730],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.0426435470581055, 3.9631638526916504)
mean: -0.03936820104718208
std: 0.9987026453018188
mamba_block1.inp_proj.weight gradient: 7.3608516686363146e-06
mamba_block1.inp_proj.bias gradient: 1.3926341125625186e-05
mamba_block1.out_proj.weight gradient: 9.16551289265044e-05
mamba_block1.out_proj.bias gradient: 0.0020849842112511396
mamba_block1.D.weight gradient: 2.415824019408319e-05
mamba_block1.D.bias gradient: 2.9187509426265024e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.2261905289487913e-06
mamba_block1.S6.fc1.bias gradient: 4.610864834830863e-06
mamba_block1.S6.fc2.weight gradient: 1.9801318558165804e-05
mamba_block1.S6.fc2.bias gradient: 3.303638368379325e-05
mamba_block1.S6.fc3.weight gradient: 1.9410605091252364e-05
mamba_block1.S6.fc3.bias gradient: 3.262169411755167e-05
mamba_block1.conv.weight gradient: 5.535763193620369e-05
mamba_block1.conv.bias gradient: 1.002511453407351e-05
mamba_block1.conv_linear.weight gradient: 2.053168100246694e-05
mamba_block1.conv_linear.bias gradient: 6.233315070858225e-05
mamba_block1.norm.weight gradient: 6.316646249615587e-06
mamba_block2.inp_proj.weight gradient: 0.009400204755365849
mamba_block2.inp_proj.bias gradient: 0.003315509995445609
mamba_block2.out_proj.weight gradient: 0.009219340980052948
mamba_block2.out_proj.bias gradient: 0.021291621029376984
mamba_block2.D.weight gradient: 0.005432978272438049
mamba_block2.D.bias gradient: 0.0019162470707669854
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0012191222049295902
mamba_block2.S6.fc1.bias gradient: 0.00152978312689811
mamba_block2.S6.fc2.weight gradient: 0.003414183622226119
mamba_block2.S6.fc2.bias gradient: 0.004413718823343515
mamba_block2.S6.fc3.weight gradient: 0.003169463714584708
mamba_block2.S6.fc3.bias gradient: 0.004032541066408157
mamba_block2.conv.weight gradient: 0.012530333362519741
mamba_block2.conv.bias gradient: 0.0009309174492955208
mamba_block2.conv_linear.weight gradient: 0.010265973396599293
mamba_block2.conv_linear.bias gradient: 0.0068965875543653965
mamba_block2.norm.weight gradient: 0.0036086307372897863
mamba_block3.inp_proj.weight gradient: 0.09822450578212738
mamba_block3.inp_proj.bias gradient: 0.034595269709825516
mamba_block3.out_proj.weight gradient: 0.04629985988140106
mamba_block3.out_proj.bias gradient: 7.177833083460428e-08
mamba_block3.D.weight gradient: 0.04594907537102699
mamba_block3.D.bias gradient: 0.01622462458908558
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.004941641353070736
mamba_block3.S6.fc1.bias gradient: 0.005250551737844944
mamba_block3.S6.fc2.weight gradient: 0.01758144423365593
mamba_block3.S6.fc2.bias gradient: 0.015057405456900597
mamba_block3.S6.fc3.weight gradient: 0.018423432484269142
mamba_block3.S6.fc3.bias gradient: 0.015630599111318588
mamba_block3.conv.weight gradient: 0.17396844923496246
mamba_block3.conv.bias gradient: 0.018070943653583527
mamba_block3.conv_linear.weight gradient: 0.07747268676757812
mamba_block3.conv_linear.bias gradient: 0.04878745600581169
mamba_block3.norm.weight gradient: 0.023946603760123253
DEBUGGING IS ON !!!
output = tensor([[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0274, 1.0140, 0.9941, ..., 1.0382, 1.0396, 0.9929],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0274, 1.0140, 0.9941, ..., 1.0382, 1.0396, 0.9930],
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0274, 1.0140, 0.9941, ..., 1.0382, 1.0395, 0.9929],
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
...,
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0274, 1.0139, 0.9941, ..., 1.0382, 1.0396, 0.9929],
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0274, 1.0139, 0.9941, ..., 1.0382, 1.0396, 0.9929],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0274, 1.0140, 0.9941, ..., 1.0382, 1.0396, 0.9929],
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9628267288208008, 1.0396476984024048)
mean: 1.000962495803833
std: 0.008170326240360737
target = tensor([[[-1.5137, 0.0657, -0.9680, ..., 1.6269, -0.2294, 0.1420],
[-0.0282, 0.8765, 0.3617, ..., -0.8538, 0.5377, -1.6134],
[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132],
...,
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
[-0.3501, 0.5389, -0.7310, ..., -0.0815, 0.4691, 0.4229],
[ 1.8674, -0.6901, -1.5037, ..., 0.8689, 1.6506, 0.1824],
...,
[ 1.7151, 1.0070, 0.6890, ..., -2.3825, -0.5136, 0.5498],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 1.1616, -1.3193, -0.6140, ..., 1.9209, 1.3262, 0.0599],
[ 0.3868, -1.0279, -0.3675, ..., -0.6507, 0.5047, -0.5453],
[-0.0137, -1.7578, 0.7203, ..., -0.7771, 1.8718, -0.1505],
...,
[-0.8427, 1.6156, 1.2061, ..., 0.4317, 1.9322, 0.3907],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.8524, -0.2206, 0.9268, ..., -0.7495, -0.6237, 0.3975],
[-1.7090, -1.0052, -1.0034, ..., -0.9609, -1.5528, -0.8253],
[ 1.1688, 0.4326, 0.6992, ..., -0.6485, -0.1625, 1.0952],
...,
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.3843, -1.3493, -1.1372, ..., -1.0553, 0.6164, 1.1378],
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
...,
[-1.8815, 0.8164, 0.5484, ..., -0.3336, 0.1990, 1.8763],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-2.3747, -0.7223, 0.2557, ..., 0.5687, -0.0835, -2.1125],
[-1.0612, -0.9659, 0.0180, ..., 2.1914, -2.7829, 1.4622],
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583],
...,
[-1.6361, 0.3128, -1.7070, ..., 1.2532, -1.0657, 0.4411],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-3.7949037551879883, 3.9631638526916504)
mean: -0.04261881858110428
std: 0.9969062209129333
mamba_block1.inp_proj.weight gradient: 1.1255859135417268e-05
mamba_block1.inp_proj.bias gradient: 1.5913874449324794e-05
mamba_block1.out_proj.weight gradient: 8.910077303880826e-05
mamba_block1.out_proj.bias gradient: 0.0021355818025767803
mamba_block1.D.weight gradient: 2.525844865886029e-05
mamba_block1.D.bias gradient: 3.0218390747904778e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.48461890098406e-06
mamba_block1.S6.fc1.bias gradient: 4.768642611452378e-06
mamba_block1.S6.fc2.weight gradient: 1.4115617887000553e-05
mamba_block1.S6.fc2.bias gradient: 2.272063647978939e-05
mamba_block1.S6.fc3.weight gradient: 1.4028752957528923e-05
mamba_block1.S6.fc3.bias gradient: 2.2550859284820035e-05
mamba_block1.conv.weight gradient: 5.71208875044249e-05
mamba_block1.conv.bias gradient: 1.0420963008073159e-05
mamba_block1.conv_linear.weight gradient: 2.3161373974289745e-05
mamba_block1.conv_linear.bias gradient: 6.114652205724269e-05
mamba_block1.norm.weight gradient: 3.6366443509905366e-06
mamba_block2.inp_proj.weight gradient: 0.009616881608963013
mamba_block2.inp_proj.bias gradient: 0.0033918858971446753
mamba_block2.out_proj.weight gradient: 0.009563383646309376
mamba_block2.out_proj.bias gradient: 0.02152535691857338
mamba_block2.D.weight gradient: 0.005726755131036043
mamba_block2.D.bias gradient: 0.0020198312122374773
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0013312151422724128
mamba_block2.S6.fc1.bias gradient: 0.0016568585997447371
mamba_block2.S6.fc2.weight gradient: 0.0035273043904453516
mamba_block2.S6.fc2.bias gradient: 0.004576574545353651
mamba_block2.S6.fc3.weight gradient: 0.0032715476118028164
mamba_block2.S6.fc3.bias gradient: 0.0041853212751448154
mamba_block2.conv.weight gradient: 0.012951940298080444
mamba_block2.conv.bias gradient: 0.0009483825415372849
mamba_block2.conv_linear.weight gradient: 0.010837584733963013
mamba_block2.conv_linear.bias gradient: 0.007166210561990738
mamba_block2.norm.weight gradient: 0.0036460179835557938
mamba_block3.inp_proj.weight gradient: 0.09701595455408096
mamba_block3.inp_proj.bias gradient: 0.034169118851423264
mamba_block3.out_proj.weight gradient: 0.047714706510305405
mamba_block3.out_proj.bias gradient: 3.2355920609461464e-08
mamba_block3.D.weight gradient: 0.05303538590669632
mamba_block3.D.bias gradient: 0.018727842718362808
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.005017312243580818
mamba_block3.S6.fc1.bias gradient: 0.0054448735900223255
mamba_block3.S6.fc2.weight gradient: 0.016375340521335602
mamba_block3.S6.fc2.bias gradient: 0.01577238366007805
mamba_block3.S6.fc3.weight gradient: 0.017361776903271675
mamba_block3.S6.fc3.bias gradient: 0.016659947112202644
mamba_block3.conv.weight gradient: 0.17826542258262634
mamba_block3.conv.bias gradient: 0.018036192283034325
mamba_block3.conv_linear.weight gradient: 0.0783582255244255
mamba_block3.conv_linear.bias gradient: 0.04873950034379959
mamba_block3.norm.weight gradient: 0.023995602503418922
DEBUGGING IS ON !!!
output = tensor([[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0276, 1.0140, 0.9941, ..., 1.0383, 1.0398, 0.9930],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0276, 1.0140, 0.9941, ..., 1.0383, 1.0398, 0.9930],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0275, 1.0140, 0.9941, ..., 1.0383, 1.0398, 0.9929],
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
...,
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0276, 1.0140, 0.9941, ..., 1.0383, 1.0398, 0.9930],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0275, 1.0140, 0.9941, ..., 1.0383, 1.0398, 0.9930],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0275, 1.0139, 0.9941, ..., 1.0383, 1.0398, 0.9929],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9627799391746521, 1.0398532152175903)
mean: 1.0009632110595703
std: 0.008179647848010063
target = tensor([[[ 0.1461, -1.5373, 1.8414, ..., 2.1830, 2.1411, -0.5229],
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883],
[-0.1043, 1.6985, 0.2488, ..., 0.7312, 1.5784, 2.1510],
...,
[ 0.6566, -1.6146, 0.8445, ..., -0.1293, -0.1222, -0.1538],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.4719, 0.3617, -0.3579, ..., -0.3869, 1.6128, 0.2484],
[ 0.4135, 0.7320, -0.6458, ..., 1.4670, 0.7813, -1.1558],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.8389, -0.8643, 0.4242, ..., 0.5848, 1.5457, -0.4353],
[-0.1063, -1.2993, 1.8509, ..., -1.6608, 1.7566, -1.0788],
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
...,
[ 0.8352, 0.9417, -0.3653, ..., -0.0158, -0.0074, 0.4276],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.4102, -0.0038, -0.1229, ..., -0.2580, 1.4403, -0.2463],
[-0.9544, -0.8690, 0.8258, ..., -1.0243, 1.2432, 0.9506],
[ 2.0079, 0.2899, 0.1462, ..., 0.0117, 0.1548, 0.9144],
...,
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505],
[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777],
...,
[ 1.6601, -0.1799, 0.7201, ..., 0.6700, 0.4782, -0.1476],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.8417, 0.6482, 2.4525, ..., -0.0736, -1.3844, -1.5417],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989],
...,
[ 0.0948, -1.3712, -1.2927, ..., 0.6679, 0.6076, 0.3466],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-4.062912464141846, 3.9631638526916504)
mean: -0.0383220836520195
std: 0.9961226582527161
mamba_block1.inp_proj.weight gradient: 1.0743613529484719e-05
mamba_block1.inp_proj.bias gradient: 1.3487114301824477e-05
mamba_block1.out_proj.weight gradient: 9.076731657842174e-05
mamba_block1.out_proj.bias gradient: 0.0021569449454545975
mamba_block1.D.weight gradient: 2.2917021851753816e-05
mamba_block1.D.bias gradient: 2.967903492390178e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.3216970223293174e-06
mamba_block1.S6.fc1.bias gradient: 4.622655069397297e-06
mamba_block1.S6.fc2.weight gradient: 2.1657657271134667e-05
mamba_block1.S6.fc2.bias gradient: 3.360168557264842e-05
mamba_block1.S6.fc3.weight gradient: 2.1000263586756773e-05
mamba_block1.S6.fc3.bias gradient: 3.272424510214478e-05
mamba_block1.conv.weight gradient: 5.6427117669954896e-05
mamba_block1.conv.bias gradient: 9.617741852707695e-06
mamba_block1.conv_linear.weight gradient: 2.0858946299995296e-05
mamba_block1.conv_linear.bias gradient: 6.406073953257874e-05
mamba_block1.norm.weight gradient: 7.72745261201635e-06
mamba_block2.inp_proj.weight gradient: 0.009284038096666336
mamba_block2.inp_proj.bias gradient: 0.0032744971103966236
mamba_block2.out_proj.weight gradient: 0.009324286133050919
mamba_block2.out_proj.bias gradient: 0.022122588008642197
mamba_block2.D.weight gradient: 0.00526660680770874
mamba_block2.D.bias gradient: 0.0018575439462438226
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0011835441691800952
mamba_block2.S6.fc1.bias gradient: 0.0014753035502508283
mamba_block2.S6.fc2.weight gradient: 0.0032423110678792
mamba_block2.S6.fc2.bias gradient: 0.004185546655207872
mamba_block2.S6.fc3.weight gradient: 0.0030104692559689283
mamba_block2.S6.fc3.bias gradient: 0.0038207483012229204
mamba_block2.conv.weight gradient: 0.012847086414694786
mamba_block2.conv.bias gradient: 0.0009625973762013018
mamba_block2.conv_linear.weight gradient: 0.010192912071943283
mamba_block2.conv_linear.bias gradient: 0.006669995374977589
mamba_block2.norm.weight gradient: 0.003584688063710928
mamba_block3.inp_proj.weight gradient: 0.10175695270299911
mamba_block3.inp_proj.bias gradient: 0.03584553673863411
mamba_block3.out_proj.weight gradient: 0.04716596007347107
mamba_block3.out_proj.bias gradient: 7.46511830129748e-08
mamba_block3.D.weight gradient: 0.04783427715301514
mamba_block3.D.bias gradient: 0.016893167048692703
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.0046433014795184135
mamba_block3.S6.fc1.bias gradient: 0.004779241979122162
mamba_block3.S6.fc2.weight gradient: 0.016801055520772934
mamba_block3.S6.fc2.bias gradient: 0.016340354457497597
mamba_block3.S6.fc3.weight gradient: 0.017698541283607483
mamba_block3.S6.fc3.bias gradient: 0.017167871817946434
mamba_block3.conv.weight gradient: 0.17702537775039673
mamba_block3.conv.bias gradient: 0.01801113784313202
mamba_block3.conv_linear.weight gradient: 0.08223015069961548
mamba_block3.conv_linear.bias gradient: 0.05006503686308861
mamba_block3.norm.weight gradient: 0.02444547973573208
DEBUGGING IS ON !!!
output = tensor([[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0033, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0277, 1.0139, 0.9942, ..., 1.0385, 1.0400, 0.9930],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0033, 0.9993, ..., 1.0052, 1.0101, 0.9955],
[1.0276, 1.0140, 0.9941, ..., 1.0385, 1.0399, 0.9930],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0101, 0.9955],
[1.0276, 1.0139, 0.9941, ..., 1.0385, 1.0400, 0.9930],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
...,
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0101, 0.9955],
[1.0276, 1.0140, 0.9941, ..., 1.0385, 1.0400, 0.9930],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955],
[1.0276, 1.0140, 0.9941, ..., 1.0385, 1.0399, 0.9930],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]],
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950],
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951],
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947],
...,
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0101, 0.9955],
[1.0276, 1.0140, 0.9941, ..., 1.0385, 1.0400, 0.9930],
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]]],
device='cuda:0', grad_fn=<ViewBackward0>)
shape: (256, 100, 8)
min/max: (0.9627537727355957, 1.0400527715682983)
mean: 1.0009636878967285
std: 0.008188863284885883
target = tensor([[[ 0.7047, -0.1636, -1.4103, ..., 0.0981, 0.1269, 0.2884],
[-0.0752, -0.2943, -0.5152, ..., -1.0968, 0.3245, -0.6512],
[ 2.5271, 0.3828, 0.4464, ..., 0.1723, -0.5737, 2.5980],
...,
[-0.3420, -0.8499, 1.2070, ..., -0.0510, -0.3249, -0.1682],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.0061, 0.0731, 0.1958, ..., -0.5969, -0.9973, -2.2435],
[ 1.1250, 1.1728, 0.5452, ..., -1.0478, 0.4368, 1.5019],
[-1.0631, -0.8864, 0.3796, ..., -0.0263, -1.2731, -1.4496],
...,
[-0.1816, -0.4553, -1.1590, ..., -0.4902, -0.3588, 1.5264],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.9234, -0.9421, 1.0451, ..., -0.4781, 0.0420, 0.2934],
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483],
[ 0.2567, 1.9508, 2.2958, ..., 0.1782, 0.4551, -1.1158],
...,
[ 0.8881, -0.0437, -1.5893, ..., -0.5971, -0.4100, 1.8774],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[-0.6347, 1.8580, -0.0971, ..., 1.7939, 0.2032, -0.1249],
[-0.2951, -0.1044, -1.3054, ..., -0.6431, -0.4934, 0.8809],
[ 0.4335, 0.5802, 1.1978, ..., 1.2420, 0.1373, -0.8186],
...,
[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.3804, -0.0668, -0.8042, ..., 0.1360, 0.0233, -1.5833],
[ 1.5597, -0.4917, -0.4323, ..., -1.5830, -0.4509, -0.0552],
[-0.6078, 0.9015, 0.9592, ..., -0.3502, -0.7853, 1.1148],
...,
[ 0.6637, 0.1741, -0.3558, ..., -1.4354, 1.1672, 0.0448],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[-0.0729, -1.5876, -0.1719, ..., 1.2421, 0.7066, 0.4504],
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199],
...,
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270],
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')
shape: (256, 100, 8)
min/max: (-3.77933669090271, 3.9631638526916504)
mean: -0.04762903228402138
std: 1.0014238357543945
mamba_block1.inp_proj.weight gradient: 8.953470569394995e-06
mamba_block1.inp_proj.bias gradient: 1.7052547264029272e-05
mamba_block1.out_proj.weight gradient: 9.578504250384867e-05
mamba_block1.out_proj.bias gradient: 0.0022079788614064455
mamba_block1.D.weight gradient: 2.5630950403865427e-05
mamba_block1.D.bias gradient: 3.2377221941715106e-05
mamba_block1.S6.A_log gradient: 0.0
mamba_block1.S6.fc1.weight gradient: 3.6499291127256583e-06
mamba_block1.S6.fc1.bias gradient: 5.3718267736257985e-06
mamba_block1.S6.fc2.weight gradient: 2.8647153158090077e-05
mamba_block1.S6.fc2.bias gradient: 4.695907045970671e-05
mamba_block1.S6.fc3.weight gradient: 2.7555784981814213e-05
mamba_block1.S6.fc3.bias gradient: 4.5302771468413994e-05
mamba_block1.conv.weight gradient: 5.673680789186619e-05
mamba_block1.conv.bias gradient: 1.0635959370119963e-05
mamba_block1.conv_linear.weight gradient: 2.1541560272453353e-05
mamba_block1.conv_linear.bias gradient: 7.382209150819108e-05
mamba_block1.norm.weight gradient: 6.72381838739966e-06
mamba_block2.inp_proj.weight gradient: 0.01039169728755951
mamba_block2.inp_proj.bias gradient: 0.0036651124246418476
mamba_block2.out_proj.weight gradient: 0.009969900362193584
mamba_block2.out_proj.bias gradient: 0.022011322900652885
mamba_block2.D.weight gradient: 0.005989375524222851
mamba_block2.D.bias gradient: 0.0021124156191945076
mamba_block2.S6.A_log gradient: 0.0
mamba_block2.S6.fc1.weight gradient: 0.0013773165410384536
mamba_block2.S6.fc1.bias gradient: 0.001692585414275527
mamba_block2.S6.fc2.weight gradient: 0.0037395297549664974
mamba_block2.S6.fc2.bias gradient: 0.004679436329752207
mamba_block2.S6.fc3.weight gradient: 0.0034694750793278217
mamba_block2.S6.fc3.bias gradient: 0.004271483514457941
mamba_block2.conv.weight gradient: 0.013511927798390388
mamba_block2.conv.bias gradient: 0.0010048352414742112
mamba_block2.conv_linear.weight gradient: 0.011564495973289013
mamba_block2.conv_linear.bias gradient: 0.007294516544789076
mamba_block2.norm.weight gradient: 0.003915737848728895
mamba_block3.inp_proj.weight gradient: 0.10254763066768646
mamba_block3.inp_proj.bias gradient: 0.03612254932522774
mamba_block3.out_proj.weight gradient: 0.048949189484119415
mamba_block3.out_proj.bias gradient: 7.861819284471494e-08
mamba_block3.D.weight gradient: 0.049613695591688156
mamba_block3.D.bias gradient: 0.017519617453217506
mamba_block3.S6.A_log gradient: 0.0
mamba_block3.S6.fc1.weight gradient: 0.005079771392047405
mamba_block3.S6.fc1.bias gradient: 0.005594966001808643
mamba_block3.S6.fc2.weight gradient: 0.01893662102520466
mamba_block3.S6.fc2.bias gradient: 0.0198878962546587
mamba_block3.S6.fc3.weight gradient: 0.019843893125653267
mamba_block3.S6.fc3.bias gradient: 0.02052515186369419
mamba_block3.conv.weight gradient: 0.17641305923461914
mamba_block3.conv.bias gradient: 0.01822531968355179
mamba_block3.conv_linear.weight gradient: 0.08087118715047836
mamba_block3.conv_linear.bias gradient: 0.057623837143182755
mamba_block3.norm.weight gradient: 0.024983780458569527
@minaadel
Copy link

I want to thank you for writing out this implementation, which seems to follow the description of the algorithm in the paper more closely than the actual implementation provided in the official repository (Mamba SSM Official Repository) . I would be interested to hear your opinion in where you see your implementation and the official repository diverging and why?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment