The where method receives:
• A condition (a < 0)
• The value to use where the condition is True (x)
• The "else" value (y)
Only the condition is required; x and y are optional, but they must be passed together.
2/5🧵
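A minimal sketch of the three-argument call, assuming NumPy is imported as np. The array a and its values are made up for illustration:

```python
import numpy as np

# Hypothetical input, chosen only for illustration.
a = np.array([3, -1, 0, -7, 5])

# Where a < 0 is True take x (here 0), otherwise take y (here a itself).
clipped = np.where(a < 0, 0, a)
print(clipped)  # [3 0 0 0 5]
```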
The return value has two possibilities:
• If only the condition is passed, it returns a tuple of index arrays with the coordinates where the condition is True.
• If x and y are passed, it returns a new array with values from x where the condition is True and from y elsewhere.
3/5🧵
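Both return shapes can be seen side by side on a small matrix. This is only a sketch: the matrix m is a hypothetical example, not something from the thread.

```python
import numpy as np

# Hypothetical 2x3 matrix used only for illustration.
m = np.array([[1, -2, 3],
              [-4, 5, -6]])

# Condition only: a tuple with one index array per dimension (rows, cols).
rows, cols = np.where(m < 0)
print(rows)  # [0 1 1]
print(cols)  # [1 0 2]

# With x and y: a new array of the same shape; m itself is not modified.
flipped = np.where(m < 0, -m, m)
print(flipped)
```

Reading rows and cols in pairs gives the coordinates (0, 1), (1, 0) and (1, 2), the three negative entries.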
With only the condition, it does the same as calling np.asarray(condition).nonzero(), which is the preferred way!
It returns a tuple with the positions in the input where the condition is True.
You can use it to count how many values were found or to index the array.
4/5🧵
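A short sketch of the condition-only case, showing the nonzero() equivalent plus counting and indexing. The array a is a hypothetical example:

```python
import numpy as np

a = np.array([4, -1, 9, -3, -8])  # hypothetical data
cond = a < 0

idx = np.asarray(cond).nonzero()  # same result as np.where(cond)
assert all(np.array_equal(i, j) for i, j in zip(idx, np.where(cond)))

print(idx)          # a 1-tuple holding the index array [1 3 4]
print(idx[0].size)  # 3
print(a[idx])       # [-1 -3 -8]
```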
The where method in particular is helping me in many #AdventOfCode challenges, as it makes it faster and easier to find special values in matrices!
Have you used it before? Where?
5/5🧵
• • •
So far I've learned how jax.numpy operations are converted to XLA ops through the lower-level lax API.
But where does jit fit into all of this?
[This is cool but hard, grab a coffee! ☕️]
1/12🧵
Even with JAX operations being compiled to run using XLA, they are still executed one at a time in a sequence.
Can we do better?
Yes, by compiling it! 🤯
2/12🧵
Compilers are great pieces of software, and one of the magic tricks they do is to look at your code and create an optimized, faster version by, for example:
• fusing ops
• not allocating temp variables
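As a rough sketch of what this looks like in practice, jax.jit can wrap a small function so that XLA sees all of its ops at once. The function scale_and_shift is a made-up example, and whether the ops actually get fused is up to the compiler:

```python
import jax
import jax.numpy as jnp

def scale_and_shift(x):
    # Run eagerly, each op (multiply, add, tanh) is dispatched on its own.
    return jnp.tanh(x * 2.0 + 1.0)

# jax.jit traces the function once and hands the whole computation to XLA,
# which is then free to fuse the ops and skip intermediate buffers.
fast = jax.jit(scale_and_shift)

x = jnp.arange(5.0)
print(jnp.allclose(fast(x), scale_and_shift(x)))  # True
```

The compiled and eager versions return the same values; only the way the work is scheduled changes.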
JAX is Autograd and XLA, brought together for high-performance numerical computing and ML research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, JIT compile to GPU/TPU, and more.