Skip to content

Instantly share code, notes, and snippets.

HloModule jit__slice_and_replicate, is_scheduled=true, entry_computation_layout={(f32[256,512,256]{2,1,0:T(8,128)}, s32[]{:T(128)})->f32[524288]{0:T(1024)}}, allow_spmd_sharding_propagation_to_parameters={false,true}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=4
ENTRY main.11_spmd {
constant = s32[]{:T(128)} constant(0)
constant.1 = s32[]{:T(128)} constant(134217728)
param = f32[256,512,256]{2,1,0:T(8,128)} parameter(0), sharding={devices=[4,1,1]<=[4]}, metadata={op_name="in_array"}
param.1 = s32[]{:T(128)} parameter(1), sharding={replicated}, metadata={op_name="offset"}
bitcast.3 = f32[16384,2,8,128]{3,2,1,0:T(8,128)} bitcast(param)
copy.2 = f32[16384,2,8,128]{3,1,2,0:T(2,128)S(3)} copy(bitcast.3)
bitcast.2 = f32[33554432]{0:T(1024)S(3)} bitcast(copy.2)
@dlwh
dlwh / coordinator.log
Created March 4, 2024 06:07
TPU coordination service crash
(maybe coordinator?)
2024-03-04 05:14:56.199355: E external/tsl/tsl/distributed_runtime/coordination/coordination_service.cc:584] /job:jax_worker/replica:0/task:11 unexpectedly tried to connect with a different incarnation. It has likely restarted.
2024-03-04 05:14:56.199447: E external/tsl/tsl/distributed_runtime/coordination/coordination_service.cc:992] /job:jax_worker/replica:0/task:11 has been set to ERROR in coordination service: ABORTED: /job:jax_worker/replica:0/task:11 unexpectedly tried to connect with a different incarnation. It has likely restarted. [type.googleapis.com/tensorflow.CoordinationServiceError='\"\x0e\n\njax_worker\x10\x0b']
2024-03-04 05:14:56.199460: E external/tsl/tsl/distributed_runtime/coordination/coordination_service.cc:828] Stopping coordination service as there is no service-to-client connection, but we encountered an error: ABORTED: /job:jax_worker/replica:0/task:11 unexpectedly tried to connect with a different incarnation. It has likely restarted. [type.googleapis.com/tenso
@dlwh
dlwh / run_clm.py
Created August 12, 2022 18:11
bad eval output
from itertools import chain
from typing import Optional
import numpy as np
import datasets
import torch
import transformers
[debug]
[debug] Initial source changes:
[debug] removed:Set()
[debug] added: Set()
[debug] modified: Set(/Users/dlwh/src/breeze/math/src/main/scala/breeze/linalg/operators/DenseMatrixOps.scala)
[debug] Invalidated products: Set()
[debug] External API changes: API Changes: Set()
[debug] Modified binary dependencies: Set()
[debug] Initial directly invalidated classes: Set(breeze.linalg.operators.DenseMatrixMultOps, breeze.linalg.operators.LowPriorityDenseMatrix1.SetDMVOp, breeze.linalg.operators.DenseMatrixFloatMultiplyStuff, breeze.linalg.operators.LowPriorityDenseMatrix.SetMSOp, breeze.linalg.operators.DenseMatrix_OrderingOps, breeze.linalg.operators.DenseMatrixFloatMultiplyStuff.implOpSolveMatrixBy_DMF_DVF_eq_DVF, breeze.linalg.operators.LowPriorityDenseMatrix.SetDMDVOp, breeze.linalg.operators.DenseMatrixFloatMultiplyStuff.implOpMulMatrix_DMF_DVF_eq_DVF, breeze.linalg.operators.DenseMatrixOps, breeze.linalg.operators.DenseMatrixMultiplyStuff.implOpMulMatrix_DMD_DMD_eq_DMD, breeze.linalg.operators.Dense
@dlwh
dlwh / build.sbt
Created January 21, 2015 23:14
JavaCL sbt error
organization := "org.scalanlp"
name := "qqq"
version := "0.1-SNAPSHOT"
scalaVersion := "2.11.4"
libraryDependencies ++= Seq(
"com.nativelibs4java" % "javacl" % "1.0-SNAPSHOT"
@dlwh
dlwh / scEnrichColl.scala
Created September 3, 2014 01:38
toMultiMap
implicit class scEnrichColl[Coll <: Traversable[(_,_)]](val __this: Coll) extends AnyVal {
def toMultiMap[Result, A, B](implicit view: Coll <:< Traversable[(A, B)], cbf: CanBuildFrom[Coll, B, Result]): Map[A, Result] = {
var result = collection.mutable.Map[A, mutable.Builder[B, Result]]()
result = result.withDefault { a => val r = cbf(__this); result.update(a, r); r}
for((a,b) <- view(__this)) {
result(a) += b
}
result.mapValues(_.result()).toMap
@dlwh
dlwh / UpdateSerializedObjects.scala
Created August 31, 2014 06:39
For updating model files
import java.io._
import java.util.zip.GZIPInputStream
import breeze.util.SerializableLogging
/**
* Class that reads in objects serialized with [[breeze.util.writeObject]], ignoring their serialversionuids,
* and then writes them to the same file.
*
* @author dlwh
@dlwh
dlwh / gist:8c9a1ef767a905bb8e98
Created July 14, 2014 01:40
epic dependency graph
[info] org.scalanlp:epic_2.10:0.2-SNAPSHOT [S]
[info] +-de.jflex:jflex:1.4.3
[info] +-org.mapdb:mapdb:0.9.2
[info] +-org.scala-lang:scala-library:2.10.2 (evicted by: 2.10.4)
[info] +-org.scala-lang:scala-library:2.10.3 (evicted by: 2.10.4)
[info] +-org.scalanlp:breeze-config_2.10:0.8 (evicted by: 0.8.1-SNAPSHOT)
[info] +-org.scalanlp:breeze-config_2.10:0.8.1-SNAPSHOT [S]
[info] | +-com.thoughtworks.paranamer:paranamer:2.2
[info] | +-org.scala-lang:scala-reflect:2.10.4 [S]
[info] |
import breeze.linalg._
import gust.linalg.cuda._
import jcuda.jcublas._
// Create a cuBLAS handle; it is declared implicit so it can satisfy implicit
// parameters later in the session (which calls consume it is not visible here).
implicit val handle = new cublasHandle()
// NOTE(review): presumably initializes the cuBLAS context behind `handle` —
// confirm against the JCublas documentation.
JCublas2.cublasCreate(handle)
// an extract from a REPL session
// Build a matrix from four rows of four Float literals with breeze's DenseMatrix,
// then hand it to CuMatrix.fromDense.
val A = CuMatrix.fromDense(DenseMatrix((1.0f, 1.0f, 2.0f, 1.0f), (1.0f, 2.0f, 1.0f, -2.0f), (3.0f, -1.0f, 3.0f, -2.0f), (-2.0f, 3.0f, -1.0f, 1.0f)))
// NOTE(review): presumably the row/column count of A (4) — confirm where N is used.
val N = 4
@dlwh
dlwh / NMF.scala
Last active August 29, 2015 13:57
client code
object NMF {
def supervised(W: CuMatrix[Float], X: CuMatrix[Float], iters: Int = 200, eps: Float = 1E-6f) = {
require(W.rows == X.rows)
import W.blas
val n = X.rows
val m = X.cols
val r = W.cols
var H = CuMatrix.ones[Float](r, m)