PyTorch — это библиотека машинного обучения с открытым исходным кодом, которая в основном используется для компьютерного зрения и обработки естественного языка в Python. Также мы рассмотрим различные примеры, связанные с Jax и PyTorch.

JAX означает «Сразу после выполнения». Это библиотека машинного обучения, разработанная DeepMind. Jax — это JIT-компилятор(Just In Time), ориентированный на управление максимальным количеством FLOPS, создающих оптимизированный код при использовании Python.

Jax — это язык программирования, показывающий и преобразующий числовые программы. Он также способен компилировать числовые программы для процессора или ускоряющего графического процессора.

Jax включает код Numpy ON не только для процессора, но также для графического процессора и TPU.

Код:

В следующем коде мы импортируем все необходимые библиотеки, такие как импорт jax.numpy как jnp, импорт grad, jit, vmap из jax и импорт случайных чисел из jax.

  • rd = random.PRNGKey(0) используется для генерации случайных данных, а случайное состояние описывается двумя 32-битными целыми числами без знака, которые мы называем ключом.
  • y = random.normal(rd,(8,)) используется для генерации выборки чисел, взятых из нормального распределения.
  • print(y) используется для печати значений y с помощью функции print().
  • size = 2899 используется для указания размера.
  • random.normal(rd,(size, size), dtype=jnp.float32) используется для генерации выборки чисел, взятых из нормального распределения.
  • %timeit jnp.dot(y, yT).block_until_ready() запускается на графическом процессоре, когда графический процессор доступен, в противном случае он запускается на центральном процессоре.
# Importing libraries
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# Generating random data
rd = random.PRNGKey(0)
y = random.normal(rd,(8,))
print(y)
# Multiply two matrices
size = 2899
y = random.normal(rd,(size, size), dtype=jnp.float32)
# runs on the GPU
%timeit jnp.dot(y, y.T).block_until_ready()

Выход:

После запуска приведенного выше кода мы получаем следующий вывод, в котором видим, что умножение двух матриц выводится на экран.

Введение в JAX

Содержание

Введение

PyTorch — это библиотека машинного обучения с открытым исходным кодом, которая в основном используется для компьютерного зрения и обработки естественного языка в Python. Он разработан исследовательской лабораторией искусственного интеллекта Facebook.

Это программное обеспечение, выпущенное под модифицированной лицензией BSD. Он построен на основе Python, который поддерживает расчет тензоров на графическом процессоре(GPU).

PyTorch прост в использовании, эффективно использует память, имеет динамический вычислительный граф, является гибким и позволяет создавать осуществимые коды, которые увеличивают скорость обработки. PyTorch — наиболее рекомендуемая библиотека для глубокого обучения и искусственного интеллекта.

Код:

В следующем коде мы импортируем все необходимые библиотеки, такие как import torch и import math.

  • y = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) используется для создания случайных входных и выходных данных.
  • m = torch.randn((), device=device, dtype=dtype) используется для случайной инициализации весов.
  • z_pred = m + n * y + o * y ** 2 + p * y ** 3 используется в качестве прямого прохода для вычисления прогнозируемого Z.
  • loss =(z_pred – z).pow(2).sum().item() используется для вычисления потерь.
  • print(int, loss) используется для печати потерь.
  • grad_m = grad_z_pred.sum(): здесь мы применяем обратное распространение ошибки для вычисления градиентов m, n, o и p относительно потерь.
  • m -= Learning_rate * grad_m используется для обновления весов с использованием градиентного спуска.
  • print(f’Result: z = {m.item()} + {n.item()} y + {o.item()} y^2 + {p.item()} y^3′) используется для печати результата с помощью функции print().
# Importing libraries
import torch 
import math 

# Device Used
dtype = torch.float 
device = torch.device("cpu") 

# Create random input and output data 
y = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) 
z = torch.sin(y)  

# Randomly initialize weights 
m = torch.randn((), device=device, dtype=dtype) 
n = torch.randn((), device=device, dtype=dtype) 
o = torch.randn((), device=device, dtype=dtype) 
p = torch.randn((), device=device, dtype=dtype) 


learning_rate = 1e-6 
for i in range(2000): 
    # Forward pass: compute predicted z 
    z_pred = m + n * y + o * y ** 2 + p * y ** 3 
 
    # Compute and print loss 
    loss =(z_pred - z).pow(2).sum().item() 
    if i % 100 == 99: 
        print(int, loss) 

    # Backprop to compute gradients of m, n, o, p with respect to loss 
    grad_z_pred = 2.0 *(z_pred - z) 
    grad_m = grad_z_pred.sum() 
    grad_n =(grad_z_pred * y).sum() 
    grad_o =(grad_z_pred * y ** 2).sum() 
    grad_p =(grad_z_pred * y ** 3).sum() 
 
    # Update weights using gradient descent 
    m -= learning_rate * grad_m 
    n -= learning_rate * grad_n 
    o -= learning_rate * grad_o 
    p -= learning_rate * grad_p 
 
# Print the result
print(f'Result: z = {m.item()} + {n.item()} y + {o.item()} y^2 + {p.item()} y^3')

Выход:

В приведенном ниже выводе вы можете видеть, что результат выполнения элементов выводится на экран.

Введение в PyTorch

Различиях между Jax и PyTorch

Jax PyTorch
Jax был выпущен в декабре 2018 года. PyTorch был выпущен в октябре 2016 года.
Jax разработан Google PyTorch разработан Facebook
Создание его графика является статическим. Создание графика является динамическим.
Целевая аудитория – исследователи Целевая аудитория — исследователи и разработчики.
Реализация Jax имеет линейную сложность во время выполнения. Реализация PyTorch имеет сложность квадратичного времени.
Jax более гибок, чем PyTorch, поскольку позволяет определять функции, а затем автоматически вычислять производную этих функций. PyTorch является гибким.
Стадия разработки развивается(v0.1.55) Стадия разработки — зрелая(v1.8.0).
Jax более эффективен, чем PyTorch, поскольку он может автоматически распараллеливать наш код на нескольких процессорах. PyTorch эффективен.

Jax и PyTorch и TensorFlow

В этом разделе мы узнаем о ключевых различиях между Jax, PyTorch и TensorFlow в Python.

Jax PyTorch TensorFlow
Jax разработан Google. PyTorch разработан Facebook. TensorFlow разработан Google.
Jax гибкий. PyTorch является гибким. TensorFlow не является гибким.
Целевая аудитория Jax — исследователи. Целевая аудитория PyTorch — исследователи и разработчики. Целевая аудитория TensorFlow — исследователи и разработчики.
Создает статические графики PyTorch создает динамические графики TensorFlow создает как статические, так и динамические графики.
Jax имеет как высокоуровневый, так и низкоуровневый API. PyTorch имеет низкоуровневый API. TensorFlow имеет API высокого уровня.
Jax более эффективен, чем PyTorch и TensorFlow. PyTorch менее эффективен, чем Jax Tensorflow также менее эффективен, чем Jax.
Стадия разработки Jax — «Разработка»(v0.1.55). Стадия разработки PyTorch – зрелая(v.1.8.0). Стадия разработки TensorFlow — зрелая(v2.4.1).

Тест

Jax — это библиотека машинного обучения для изменения числовых функций. Он может собирать числовые программы для процессоров или ускорителей графического процессора.

Код:

  • В следующем коде мы импортируем все необходимые библиотеки, такие как импорт jax.numpy как jnp, импорт grad, jit, vmap из jax и импорт случайных чисел из jax.
  • m = random.PRNGKey(0) используется для генерации случайных данных, а случайное состояние описывается двумя 32-битными целыми числами без знака, которые мы называем ключом.
    i = random.normal(rd,(8,)) используется для генерации выборки чисел, взятых из нормального распределения.
  • print(y) используется для печати значений y с помощью функции print().
    size = 2899 используется для указания размера.
    random.normal(rd,(size, size), dtype=jnp.float32) используется для генерации выборки чисел, взятых из нормального распределения.
    %timeit jnp.dot(y, yT).block_until_ready() запускается на графическом процессоре, когда графический процессор доступен, в противном случае он запускается на центральном процессоре.
  • %timeit jnp.dot(i, iT).block_until_ready(): Здесь мы используем block_until_ready, потому что jax использует асинхронное выполнение.
  • i = num.random.normal(size=(siz, siz)).astype(num.float32) используется для передачи данных на графический процессор
# Importing Libraries
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Multiplying Matrices
m = random.PRNGKey(0)
i = random.normal(m,(10,))
print(i)

# Multiply two big matrices
siz = 2800
i = random.normal(m,(siz, siz), dtype=jnp.float32)
%timeit jnp.dot(i, i.T).block_until_ready()  

# Jax Numpy function work on regular numpy array
import numpy as num
i = num.random.normal(size=(siz, siz)).astype(num.float32)
%timeit jnp.dot(i, i.T).block_until_ready()

# Transfer the data to GPU
from jax import device_put

i = num.random.normal(size=(siz, siz)).astype(num.float32)
i = device_put(i)
%timeit jnp.dot(i, i.T).block_until_ready()

Выход:

После запуска приведенного выше кода мы получаем следующий вывод, в котором видим, что умножение матриц с использованием Jax выполняется на экране.

Тест

Тест PyTorch помогает нам убедиться, что наш код соответствует ожиданиям по производительности, и сравнить различные подходы к решению проблем.

В следующем коде мы импортируем все необходимые библиотеки, такие как import torch и import timeit.

  • return m.mul(n).sum(-1) используется для расчета пакетной точки путем умножения и суммирования.
  • m = m.reshape(-1, 1, m.shape[-1]) используется для расчета пакетной точки путем уменьшения до bmm.
  • i = torch.randn(1000, 62) используется в качестве входных данных для сравнительного анализа.
  • print(f’multiply_sum(i, i): {j.timeit(100) / 100 * 1e6:>5.1f} us’) используется для печати значений умножения и суммирования.
  • print(f’bmm(i, i): {j1.timeit(100) / 100 * 1e6:>5.1f} us’) используется для печати значений bmm.
# Import library
import torch
import timeit


# Define the Model
def batcheddot_multiply_sum(m, n):
    # Calculates batched dot by multiplying and sum
    return m.mul(n).sum(-1)


def batcheddot_bmm(m, n):
    #Calculates batched dot by reducing to bmm
    m = m.reshape(-1, 1, m.shape[-1])
    n = n.reshape(-1, n.shape[-1], 1)
    return torch.bmm(m, n).flatten(-3)


# Input for benchmarking
i = torch.randn(1000, 62)

# Ensure that both functions compute the same output
assert batcheddot_multiply_sum(i, i).allclose(batcheddot_bmm(i, i))


# Using timeit.Timer() method
j = timeit.Timer(
    stmt='batcheddot_multiply_sum(i, i)',
    setup='from __main__ import batcheddot_multiply_sum',
    globals={'i': i})

j1 = timeit.Timer(
    stmt='batcheddot_bmm(i, i)',
    setup='from __main__ import batcheddot_bmm',
    globals={'i': i})

print(f'multiply_sum(i, i):  {j.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(i, i):      {j1.timeit(100) / 100 * 1e6:>5.1f} us')

Выход:

После запуска приведенного выше кода мы получаем следующий вывод, в котором мы видим, что значение умножения и суммы с использованием теста PyTorch выводится на экран.

Тест Jax против PyTorch

Добавить комментарий