PyTorch — это библиотека машинного обучения с открытым исходным кодом, которая в основном используется для компьютерного зрения и обработки естественного языка в Python. Также мы рассмотрим различные примеры, связанные с Jax и PyTorch.
JAX означает «Сразу после выполнения». Это библиотека машинного обучения, разработанная DeepMind. Jax — это JIT-компилятор(Just In Time), ориентированный на управление максимальным количеством FLOPS, создающих оптимизированный код при использовании Python.
Jax — это язык программирования, показывающий и преобразующий числовые программы. Он также способен компилировать числовые программы для процессора или ускоряющего графического процессора.
Jax включает код Numpy ON не только для процессора, но также для графического процессора и TPU.
Код:
В следующем коде мы импортируем все необходимые библиотеки, такие как импорт jax.numpy как jnp, импорт grad, jit, vmap из jax и импорт случайных чисел из jax.
- rd = random.PRNGKey(0) используется для генерации случайных данных, а случайное состояние описывается двумя 32-битными целыми числами без знака, которые мы называем ключом.
- y = random.normal(rd,(8,)) используется для генерации выборки чисел, взятых из нормального распределения.
- print(y) используется для печати значений y с помощью функции print().
- size = 2899 используется для указания размера.
- random.normal(rd,(size, size), dtype=jnp.float32) используется для генерации выборки чисел, взятых из нормального распределения.
- %timeit jnp.dot(y, yT).block_until_ready() запускается на графическом процессоре, когда графический процессор доступен, в противном случае он запускается на центральном процессоре.
# Importing libraries import jax.numpy as jnp from jax import grad, jit, vmap from jax import random # Generating random data rd = random.PRNGKey(0) y = random.normal(rd,(8,)) print(y) # Multiply two matrices size = 2899 y = random.normal(rd,(size, size), dtype=jnp.float32) # runs on the GPU %timeit jnp.dot(y, y.T).block_until_ready()
Выход:
После запуска приведенного выше кода мы получаем следующий вывод, в котором видим, что умножение двух матриц выводится на экран.

Введение
PyTorch — это библиотека машинного обучения с открытым исходным кодом, которая в основном используется для компьютерного зрения и обработки естественного языка в Python. Он разработан исследовательской лабораторией искусственного интеллекта Facebook.
Это программное обеспечение, выпущенное под модифицированной лицензией BSD. Он построен на основе Python, который поддерживает расчет тензоров на графическом процессоре(GPU).
PyTorch прост в использовании, эффективно использует память, имеет динамический вычислительный граф, является гибким и позволяет создавать осуществимые коды, которые увеличивают скорость обработки. PyTorch — наиболее рекомендуемая библиотека для глубокого обучения и искусственного интеллекта.
Код:
В следующем коде мы импортируем все необходимые библиотеки, такие как import torch и import math.
- y = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) используется для создания случайных входных и выходных данных.
- m = torch.randn((), device=device, dtype=dtype) используется для случайной инициализации весов.
- z_pred = m + n * y + o * y ** 2 + p * y ** 3 используется в качестве прямого прохода для вычисления прогнозируемого Z.
- loss =(z_pred – z).pow(2).sum().item() используется для вычисления потерь.
- print(int, loss) используется для печати потерь.
- grad_m = grad_z_pred.sum(): здесь мы применяем обратное распространение ошибки для вычисления градиентов m, n, o и p относительно потерь.
- m -= Learning_rate * grad_m используется для обновления весов с использованием градиентного спуска.
- print(f’Result: z = {m.item()} + {n.item()} y + {o.item()} y^2 + {p.item()} y^3′) используется для печати результата с помощью функции print().
# Importing libraries
import torch
import math
# Device Used
dtype = torch.float
device = torch.device("cpu")
# Create random input and output data
y = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
z = torch.sin(y)
# Randomly initialize weights
m = torch.randn((), device=device, dtype=dtype)
n = torch.randn((), device=device, dtype=dtype)
o = torch.randn((), device=device, dtype=dtype)
p = torch.randn((), device=device, dtype=dtype)
learning_rate = 1e-6
for i in range(2000):
# Forward pass: compute predicted z
z_pred = m + n * y + o * y ** 2 + p * y ** 3
# Compute and print loss
loss =(z_pred - z).pow(2).sum().item()
if i % 100 == 99:
print(int, loss)
# Backprop to compute gradients of m, n, o, p with respect to loss
grad_z_pred = 2.0 *(z_pred - z)
grad_m = grad_z_pred.sum()
grad_n =(grad_z_pred * y).sum()
grad_o =(grad_z_pred * y ** 2).sum()
grad_p =(grad_z_pred * y ** 3).sum()
# Update weights using gradient descent
m -= learning_rate * grad_m
n -= learning_rate * grad_n
o -= learning_rate * grad_o
p -= learning_rate * grad_p
# Print the result
print(f'Result: z = {m.item()} + {n.item()} y + {o.item()} y^2 + {p.item()} y^3')
Выход:
В приведенном ниже выводе вы можете видеть, что результат выполнения элементов выводится на экран.

Различиях между Jax и PyTorch
| Jax | PyTorch |
| Jax был выпущен в декабре 2018 года. | PyTorch был выпущен в октябре 2016 года. |
| Jax разработан Google | PyTorch разработан Facebook |
| Создание его графика является статическим. | Создание графика является динамическим. |
| Целевая аудитория – исследователи | Целевая аудитория — исследователи и разработчики. |
| Реализация Jax имеет линейную сложность во время выполнения. | Реализация PyTorch имеет сложность квадратичного времени. |
| Jax более гибок, чем PyTorch, поскольку позволяет определять функции, а затем автоматически вычислять производную этих функций. | PyTorch является гибким. |
| Стадия разработки развивается(v0.1.55) | Стадия разработки — зрелая(v1.8.0). |
| Jax более эффективен, чем PyTorch, поскольку он может автоматически распараллеливать наш код на нескольких процессорах. | PyTorch эффективен. |
Jax и PyTorch и TensorFlow
В этом разделе мы узнаем о ключевых различиях между Jax, PyTorch и TensorFlow в Python.
| Jax | PyTorch | TensorFlow |
| Jax разработан Google. | PyTorch разработан Facebook. | TensorFlow разработан Google. |
| Jax гибкий. | PyTorch является гибким. | TensorFlow не является гибким. |
| Целевая аудитория Jax — исследователи. | Целевая аудитория PyTorch — исследователи и разработчики. | Целевая аудитория TensorFlow — исследователи и разработчики. |
| Создает статические графики | PyTorch создает динамические графики | TensorFlow создает как статические, так и динамические графики. |
| Jax имеет как высокоуровневый, так и низкоуровневый API. | PyTorch имеет низкоуровневый API. | TensorFlow имеет API высокого уровня. |
| Jax более эффективен, чем PyTorch и TensorFlow. | PyTorch менее эффективен, чем Jax | Tensorflow также менее эффективен, чем Jax. |
| Стадия разработки Jax — «Разработка»(v0.1.55). | Стадия разработки PyTorch – зрелая(v.1.8.0). | Стадия разработки TensorFlow — зрелая(v2.4.1). |
Тест
Jax — это библиотека машинного обучения для изменения числовых функций. Он может собирать числовые программы для процессоров или ускорителей графического процессора.
Код:
- В следующем коде мы импортируем все необходимые библиотеки, такие как импорт jax.numpy как jnp, импорт grad, jit, vmap из jax и импорт случайных чисел из jax.
- m = random.PRNGKey(0) используется для генерации случайных данных, а случайное состояние описывается двумя 32-битными целыми числами без знака, которые мы называем ключом.
i = random.normal(rd,(8,)) используется для генерации выборки чисел, взятых из нормального распределения. - print(y) используется для печати значений y с помощью функции print().
size = 2899 используется для указания размера.
random.normal(rd,(size, size), dtype=jnp.float32) используется для генерации выборки чисел, взятых из нормального распределения.
%timeit jnp.dot(y, yT).block_until_ready() запускается на графическом процессоре, когда графический процессор доступен, в противном случае он запускается на центральном процессоре. - %timeit jnp.dot(i, iT).block_until_ready(): Здесь мы используем block_until_ready, потому что jax использует асинхронное выполнение.
- i = num.random.normal(size=(siz, siz)).astype(num.float32) используется для передачи данных на графический процессор
# Importing Libraries import jax.numpy as jnp from jax import grad, jit, vmap from jax import random # Multiplying Matrices m = random.PRNGKey(0) i = random.normal(m,(10,)) print(i) # Multiply two big matrices siz = 2800 i = random.normal(m,(siz, siz), dtype=jnp.float32) %timeit jnp.dot(i, i.T).block_until_ready() # Jax Numpy function work on regular numpy array import numpy as num i = num.random.normal(size=(siz, siz)).astype(num.float32) %timeit jnp.dot(i, i.T).block_until_ready() # Transfer the data to GPU from jax import device_put i = num.random.normal(size=(siz, siz)).astype(num.float32) i = device_put(i) %timeit jnp.dot(i, i.T).block_until_ready()
Выход:
После запуска приведенного выше кода мы получаем следующий вывод, в котором видим, что умножение матриц с использованием Jax выполняется на экране.

Тест PyTorch помогает нам убедиться, что наш код соответствует ожиданиям по производительности, и сравнить различные подходы к решению проблем.
В следующем коде мы импортируем все необходимые библиотеки, такие как import torch и import timeit.
- return m.mul(n).sum(-1) используется для расчета пакетной точки путем умножения и суммирования.
- m = m.reshape(-1, 1, m.shape[-1]) используется для расчета пакетной точки путем уменьшения до bmm.
- i = torch.randn(1000, 62) используется в качестве входных данных для сравнительного анализа.
- print(f’multiply_sum(i, i): {j.timeit(100) / 100 * 1e6:>5.1f} us’) используется для печати значений умножения и суммирования.
- print(f’bmm(i, i): {j1.timeit(100) / 100 * 1e6:>5.1f} us’) используется для печати значений bmm.
# Import library
import torch
import timeit
# Define the Model
def batcheddot_multiply_sum(m, n):
# Calculates batched dot by multiplying and sum
return m.mul(n).sum(-1)
def batcheddot_bmm(m, n):
#Calculates batched dot by reducing to bmm
m = m.reshape(-1, 1, m.shape[-1])
n = n.reshape(-1, n.shape[-1], 1)
return torch.bmm(m, n).flatten(-3)
# Input for benchmarking
i = torch.randn(1000, 62)
# Ensure that both functions compute the same output
assert batcheddot_multiply_sum(i, i).allclose(batcheddot_bmm(i, i))
# Using timeit.Timer() method
j = timeit.Timer(
stmt='batcheddot_multiply_sum(i, i)',
setup='from __main__ import batcheddot_multiply_sum',
globals={'i': i})
j1 = timeit.Timer(
stmt='batcheddot_bmm(i, i)',
setup='from __main__ import batcheddot_bmm',
globals={'i': i})
print(f'multiply_sum(i, i): {j.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(i, i): {j1.timeit(100) / 100 * 1e6:>5.1f} us')
Выход:
После запуска приведенного выше кода мы получаем следующий вывод, в котором мы видим, что значение умножения и суммы с использованием теста PyTorch выводится на экран.
