#30DaysOfJAX Day 4

So far, I've learned how jax.numpy calls are converted to XLA ops via the lower-level lax API

But where does jit fit into all of this?

[this is cool but hard, grab a coffee! β˜•οΈ]

1/12🧡
Even though JAX operations are compiled to run with XLA, they are still dispatched and executed one at a time, in sequence.

Can we do better?
Yes: by compiling the whole function!🀯

2/12🧡
Compilers are great pieces of software, and one of the magic tricksπŸͺ„ they perform is looking at your code and creating an optimized, faster version by, for example:

β€’ fusing ops βž•
β€’ not allocating temp variables πŸš«πŸ—‘οΈ

3/12🧡
Python is not a compiled language! And this is where Just-In-Time (JIT) compilation comes in!

When a function is decorated with the jit decorator, that part of the code is compiled the first time it's executed!

Here is a 3.5x improvement!
⚑️⚑️⚑️
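A minimal sketch of what a jitted function looks like (the function and numbers here are illustrative, not the ones from the image):

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # several element-wise ops that XLA can fuse into a single kernel
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)

x = jnp.arange(1_000_000, dtype=jnp.float32)
y = selu_jit(x)        # first call: traces + compiles, then runs
y.block_until_ready()  # wait for the async computation to finish
# later calls with the same shape/dtype reuse the compiled version
```

When you time it, remember to call block_until_ready(): JAX dispatches work asynchronously, so without it you'd only be timing the dispatch.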

4/12🧡
One question I always ask when I learn about this kind of speed improvement is:

-> Why don't we use it for every function ever and all the time???

jax.jit has some limitations: all array shapes must be static, i.e. known at compile time

Can you point out the error in the image?
πŸ€”β“β”

5/12🧡
The problem with the get_negatives function is that the shape of the returned array is not known at compile time! It changes with the values of the input!
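A small sketch reproducing the problem (assuming get_negatives uses boolean-mask indexing, which is the usual version of this example):

```python
import jax
import jax.numpy as jnp

@jax.jit
def get_negatives(x):
    # boolean-mask indexing: the output length depends on the *values* in x,
    # so the result shape can't be known while tracing
    return x[x < 0]

try:
    get_negatives(jnp.array([1.0, -2.0, 3.0, -4.0]))
except Exception as err:
    # a tracing error about non-concrete boolean indexing
    print(type(err).__name__)
```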

Let's understand that a little better and introduce two important concepts: tracing and static variables

6/12🧡
When you use the jit decorator, JAX will run (trace) your function with abstract tracer objects, recording what it does to inputs of a given shape and dtype

For example, you can see in the code below that the print output disappears on later calls: side effects like printing only happen during tracing and are discarded from the compiled code
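A minimal sketch of that behavior (my own toy function, not the one from the image):

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    # this print runs only while tracing; it is not part of the compiled code
    print("tracing with:", x)
    return x * 2.0

double(jnp.array([1.0, 2.0]))  # prints an abstract Traced<...> value
double(jnp.array([3.0, 4.0]))  # same shape/dtype: cached version runs, no print
```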

7/12🧡
Next time, when your function is run again with the same input shape and type, it will use the compiled version.

If you call it with a different type or shape, the function is traced and compiled again for the new signature
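One way to see this re-tracing (a sketch: trace_count is my own instrumentation, incremented only when the Python body actually runs):

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def add_one(x):
    global trace_count
    trace_count += 1  # runs only during tracing, not on cached calls
    return x + 1

add_one(jnp.ones(3))
add_one(jnp.ones(3))  # same shape & dtype: cache hit, no re-trace
add_one(jnp.ones(4))  # new shape: re-traced and re-compiled
print(trace_count)    # 2
```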

8/12🧡
You can even inspect the intermediate representation (jaxpr) generated by tracing using make_jaxpr
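For example (a minimal sketch with a toy function of mine):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

# make_jaxpr traces f and returns its jaxpr,
# JAX's intermediate representation of the computation
print(jax.make_jaxpr(f)(jnp.float32(1.0)))
```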

9/12🧡
But what can you do when you don't want your variable to be traced?
You can mark it as static

It will be left un-traced, but its concrete value will be used at compile time
When this value changes, the function is re-compiled
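A sketch of marking an argument static with jit's static_argnums (my own toy example):

```python
from functools import partial

import jax
import jax.numpy as jnp

# n is marked static: its concrete Python value is available at trace time,
# and a new version of the function is compiled whenever n changes
@partial(jax.jit, static_argnums=(1,))
def repeat(x, n):
    return jnp.tile(x, n)

print(repeat(jnp.array([1, 2]), 3))  # [1 2 1 2 1 2]
```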

10/12🧡
When you need operations whose shapes or dtypes keep changing, use plain NumPy; otherwise, use jax.numpy

This is one of the reasons you won't simply do: import jax.numpy as np!!!
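One concrete difference between the two (a small sketch): NumPy arrays are mutable on the host, while JAX arrays are immutable and updated functionally:

```python
import numpy as np
import jax.numpy as jnp

host = np.arange(5)       # plain NumPy: mutable host array
host[0] = 99              # in-place update is fine

dev = jnp.arange(5)       # JAX array: immutable
# dev[0] = 99 would raise; use the functional .at[].set() update instead,
# which returns a new array with the first element changed
updated = dev.at[0].set(99)
```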

11/12🧡
Summary: jit can make your code much faster by taking advantage of the power of compilation, but sometimes it won't work! Understanding how tracing and static variables behave is key to leveraging jit!

JAX, Jit, tracing! I'm starting to understand why people like it! πŸ§πŸ€“

12/12🧡
For the full code of this thread and the previous one:
colab.sandbox.google.com/github/google/…

I highly recommend doing it and making changes to really get the details!
