JAX 的技能之一 JIT

zidea 2021-08-15 21:13:51 阅读数:10

本文一共[544]字,预计阅读时长:1分钟~
jax 技能 jit

在這一節中,我們將深究一下 JAX 是如何運行的。將聊一聊 JAX 的 jax.jit()轉換,將對一個 JAX 的 Python 函數進行及時(JIT)編譯,以便在 XLA 中有效地執行。

在之前有關 JAX 分享中,我們已經了解到了 JAX 可以將 Python 函數進行轉換得到一個新的函數。這是通過首先將 Python 函數轉換為一種叫做 jaxpr 的簡單的中間語言來實現的。然後,轉換在 jaxpr 的錶示上工作。

接下來使用jax.make_jaxpr來展示一個函數的 jaxpr 錶示一個 python 函數。

從概念上講,把 JAX 轉換首先要做的是的 Python 函數化為一個輕量級、具有良好錶現形式的中間形式,這個過程可以理解為特定的 trace,Jaxpr 經過內部解釋器執行變換。JAX 能够在如此小巧的軟件包中塞入這麼多的功能,其中原因是不僅從一個熟悉的、靈活的編程接口(帶有 NumPy 的Python)開始,並使用實際的 Python 解釋器來完成大部分繁重的工作,將計算的本質提煉成一種簡單的靜態類型的錶達式語言,具有有限的高階特征。這種語言就是 jaxpr 語言。

import jax
import jax.numpy as jnp
global_list = []
def log2(x):
global_list.append(x)
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
print(jax.make_jaxpr(log2)(3.0))
複制代碼
{ lambda ; a.
let b = log a
c = log 2.0
d = div b c
in (d,) }
複制代碼

文檔中的 "理解Jaxprs "部分提供了關於上述輸出含義的更多信息。

重要的是,請注意jaxpr 並沒有對函數的副作用進行 trace:在轉換得到中 jaxpr 中沒有找到global_list.append(x)的內容。這是一個特點,而不是一個 error 。JAX 的設計是為了理解無副作用(也就是純函數)的代碼。

JAX 內部錶示是純函數式的,但考慮到 Python 語言高度動態性特點,對用戶使用上有一些編程限制。比如 JAX 自動微分的 Python 函數只支持純函數,要求用戶自行保證這一點。如用戶代碼寫了副作用,可能經過 JAX 變換生成的函數執行結果不符合期望。因 JAX trace 函數為純函數,當全局變量、配置信息發生變化,可能需要重新 trace。

trace 過程中,JAX 用追踪器對象(tracer object)來包裹每個參數,然後這些追踪器記錄了在函數調用過程中對參數進行的所有 JAX 操作(這發生在普通的 Python 中)。然後,JAX 使用追踪器的記錄來重構整個函數。這個重構的輸出是就是中間的 jaxpr。因為追踪器不會記錄 Python 的副作用,副作用的代碼不會出現在 jaxpr 中。其中在跟踪過程中,副作用仍然發生。

def log2_with_print(x):
print("printed x:", x)
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
print(jax.make_jaxpr(log2_with_print)(3.))
複制代碼

注意:Python 的 print() 函數也並非存函數,因為文本輸出輸入 IO 操作,可以看成副作用,所以 print 也並非純函數。因此,任何print() 並不會出現在 jaxpr 中。

printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a.
let b = log a
c = log 2.0
d = div b c
in (d,) }
複制代碼

看到打印出來的 x 是一個跟踪對象了嗎?那是 JAX 內部的工作。Python 代碼至少運行一次這一事實嚴格來說是一個實現細節,所以不應該被依賴。然而,理解這一點很有用,因為你可以在調試時使用來打印出計算的中間值。

版权声明:本文为[zidea]所创,转载请带上原文链接,感谢。 https://gsmany.com/2021/08/20210815211248105c.html