PyTorch — это библиотека машинного обучения с открытым исходным кодом, которая в основном используется для компьютерного зрения и обработки естественного языка в Python. Также мы рассмотрим различные примеры, связанные с Jax и PyTorch.
JAX означает «Сразу после выполнения». Это библиотека машинного обучения, разработанная DeepMind. Jax — это JIT-компилятор(Just In Time), ориентированный на управление максимальным количеством FLOPS, создающих оптимизированный код при использовании Python.
Jax — это язык программирования, показывающий и преобразующий числовые программы. Он также способен компилировать числовые программы для процессора или ускоряющего графического процессора.
Jax включает код Numpy ON не только для процессора, но также для графического процессора и TPU.
Код:
В следующем коде мы импортируем все необходимые библиотеки, такие как импорт jax.numpy как jnp, импорт grad, jit, vmap из jax и импорт случайных чисел из jax.
- rd = random.PRNGKey(0) используется для генерации случайных данных, а случайное состояние описывается двумя 32-битными целыми числами без знака, которые мы называем ключом.
- y = random.normal(rd,(8,)) используется для генерации выборки чисел, взятых из нормального распределения.
- print(y) используется для печати значений y с помощью функции print().
- size = 2899 используется для указания размера.
- random.normal(rd,(size, size), dtype=jnp.float32) используется для генерации выборки чисел, взятых из нормального распределения.
- %timeit jnp.dot(y, yT).block_until_ready() запускается на графическом процессоре, когда графический процессор доступен, в противном случае он запускается на центральном процессоре.
# Importing libraries import jax.numpy as jnp from jax import grad, jit, vmap from jax import random # Generating random data rd = random.PRNGKey(0) y = random.normal(rd,(8,)) print(y) # Multiply two matrices size = 2899 y = random.normal(rd,(size, size), dtype=jnp.float32) # runs on the GPU %timeit jnp.dot(y, y.T).block_until_ready()
Выход:
После запуска приведенного выше кода мы получаем следующий вывод, в котором видим, что умножение двух матриц выводится на экран.
Введение
PyTorch — это библиотека машинного обучения с открытым исходным кодом, которая в основном используется для компьютерного зрения и обработки естественного языка в Python. Он разработан исследовательской лабораторией искусственного интеллекта Facebook.
Это программное обеспечение, выпущенное под модифицированной лицензией BSD. Он построен на основе Python, который поддерживает расчет тензоров на графическом процессоре(GPU).
PyTorch прост в использовании, эффективно использует память, имеет динамический вычислительный граф, является гибким и позволяет создавать осуществимые коды, которые увеличивают скорость обработки. PyTorch — наиболее рекомендуемая библиотека для глубокого обучения и искусственного интеллекта.
Код:
В следующем коде мы импортируем все необходимые библиотеки, такие как import torch и import math.
- y = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) используется для создания случайных входных и выходных данных.
- m = torch.randn((), device=device, dtype=dtype) используется для случайной инициализации весов.
- z_pred = m + n * y + o * y ** 2 + p * y ** 3 используется в качестве прямого прохода для вычисления прогнозируемого Z.
- loss =(z_pred – z).pow(2).sum().item() используется для вычисления потерь.
- print(int, loss) используется для печати потерь.
- grad_m = grad_z_pred.sum(): здесь мы применяем обратное распространение ошибки для вычисления градиентов m, n, o и p относительно потерь.
- m -= Learning_rate * grad_m используется для обновления весов с использованием градиентного спуска.
- print(f’Result: z = {m.item()} + {n.item()} y + {o.item()} y^2 + {p.item()} y^3′) используется для печати результата с помощью функции print().
# Importing libraries import torch import math # Device Used dtype = torch.float device = torch.device("cpu") # Create random input and output data y = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) z = torch.sin(y) # Randomly initialize weights m = torch.randn((), device=device, dtype=dtype) n = torch.randn((), device=device, dtype=dtype) o = torch.randn((), device=device, dtype=dtype) p = torch.randn((), device=device, dtype=dtype) learning_rate = 1e-6 for i in range(2000): # Forward pass: compute predicted z z_pred = m + n * y + o * y ** 2 + p * y ** 3 # Compute and print loss loss =(z_pred - z).pow(2).sum().item() if i % 100 == 99: print(int, loss) # Backprop to compute gradients of m, n, o, p with respect to loss grad_z_pred = 2.0 *(z_pred - z) grad_m = grad_z_pred.sum() grad_n =(grad_z_pred * y).sum() grad_o =(grad_z_pred * y ** 2).sum() grad_p =(grad_z_pred * y ** 3).sum() # Update weights using gradient descent m -= learning_rate * grad_m n -= learning_rate * grad_n o -= learning_rate * grad_o p -= learning_rate * grad_p # Print the result print(f'Result: z = {m.item()} + {n.item()} y + {o.item()} y^2 + {p.item()} y^3')
Выход:
В приведенном ниже выводе вы можете видеть, что результат выполнения элементов выводится на экран.
Различиях между Jax и PyTorch
Jax | PyTorch |
Jax был выпущен в декабре 2018 года. | PyTorch был выпущен в октябре 2016 года. |
Jax разработан Google | PyTorch разработан Facebook |
Создание его графика является статическим. | Создание графика является динамическим. |
Целевая аудитория – исследователи | Целевая аудитория — исследователи и разработчики. |
Реализация Jax имеет линейную сложность во время выполнения. | Реализация PyTorch имеет сложность квадратичного времени. |
Jax более гибок, чем PyTorch, поскольку позволяет определять функции, а затем автоматически вычислять производную этих функций. | PyTorch является гибким. |
Стадия разработки развивается(v0.1.55) | Стадия разработки — зрелая(v1.8.0). |
Jax более эффективен, чем PyTorch, поскольку он может автоматически распараллеливать наш код на нескольких процессорах. | PyTorch эффективен. |
Jax и PyTorch и TensorFlow
В этом разделе мы узнаем о ключевых различиях между Jax, PyTorch и TensorFlow в Python.
Jax | PyTorch | TensorFlow |
Jax разработан Google. | PyTorch разработан Facebook. | TensorFlow разработан Google. |
Jax гибкий. | PyTorch является гибким. | TensorFlow не является гибким. |
Целевая аудитория Jax — исследователи. | Целевая аудитория PyTorch — исследователи и разработчики. | Целевая аудитория TensorFlow — исследователи и разработчики. |
Создает статические графики | PyTorch создает динамические графики | TensorFlow создает как статические, так и динамические графики. |
Jax имеет как высокоуровневый, так и низкоуровневый API. | PyTorch имеет низкоуровневый API. | TensorFlow имеет API высокого уровня. |
Jax более эффективен, чем PyTorch и TensorFlow. | PyTorch менее эффективен, чем Jax | Tensorflow также менее эффективен, чем Jax. |
Стадия разработки Jax — «Разработка»(v0.1.55). | Стадия разработки PyTorch – зрелая(v.1.8.0). | Стадия разработки TensorFlow — зрелая(v2.4.1). |
Тест
Jax — это библиотека машинного обучения для изменения числовых функций. Он может собирать числовые программы для процессоров или ускорителей графического процессора.
Код:
- В следующем коде мы импортируем все необходимые библиотеки, такие как импорт jax.numpy как jnp, импорт grad, jit, vmap из jax и импорт случайных чисел из jax.
- m = random.PRNGKey(0) используется для генерации случайных данных, а случайное состояние описывается двумя 32-битными целыми числами без знака, которые мы называем ключом.
i = random.normal(rd,(8,)) используется для генерации выборки чисел, взятых из нормального распределения. - print(y) используется для печати значений y с помощью функции print().
size = 2899 используется для указания размера.
random.normal(rd,(size, size), dtype=jnp.float32) используется для генерации выборки чисел, взятых из нормального распределения.
%timeit jnp.dot(y, yT).block_until_ready() запускается на графическом процессоре, когда графический процессор доступен, в противном случае он запускается на центральном процессоре. - %timeit jnp.dot(i, iT).block_until_ready(): Здесь мы используем block_until_ready, потому что jax использует асинхронное выполнение.
- i = num.random.normal(size=(siz, siz)).astype(num.float32) используется для передачи данных на графический процессор
# Importing Libraries import jax.numpy as jnp from jax import grad, jit, vmap from jax import random # Multiplying Matrices m = random.PRNGKey(0) i = random.normal(m,(10,)) print(i) # Multiply two big matrices siz = 2800 i = random.normal(m,(siz, siz), dtype=jnp.float32) %timeit jnp.dot(i, i.T).block_until_ready() # Jax Numpy function work on regular numpy array import numpy as num i = num.random.normal(size=(siz, siz)).astype(num.float32) %timeit jnp.dot(i, i.T).block_until_ready() # Transfer the data to GPU from jax import device_put i = num.random.normal(size=(siz, siz)).astype(num.float32) i = device_put(i) %timeit jnp.dot(i, i.T).block_until_ready()
Выход:
После запуска приведенного выше кода мы получаем следующий вывод, в котором видим, что умножение матриц с использованием Jax выполняется на экране.
Тест PyTorch помогает нам убедиться, что наш код соответствует ожиданиям по производительности, и сравнить различные подходы к решению проблем.
В следующем коде мы импортируем все необходимые библиотеки, такие как import torch и import timeit.
- return m.mul(n).sum(-1) используется для расчета пакетной точки путем умножения и суммирования.
- m = m.reshape(-1, 1, m.shape[-1]) используется для расчета пакетной точки путем уменьшения до bmm.
- i = torch.randn(1000, 62) используется в качестве входных данных для сравнительного анализа.
- print(f’multiply_sum(i, i): {j.timeit(100) / 100 * 1e6:>5.1f} us’) используется для печати значений умножения и суммирования.
- print(f’bmm(i, i): {j1.timeit(100) / 100 * 1e6:>5.1f} us’) используется для печати значений bmm.
# Import library import torch import timeit # Define the Model def batcheddot_multiply_sum(m, n): # Calculates batched dot by multiplying and sum return m.mul(n).sum(-1) def batcheddot_bmm(m, n): #Calculates batched dot by reducing to bmm m = m.reshape(-1, 1, m.shape[-1]) n = n.reshape(-1, n.shape[-1], 1) return torch.bmm(m, n).flatten(-3) # Input for benchmarking i = torch.randn(1000, 62) # Ensure that both functions compute the same output assert batcheddot_multiply_sum(i, i).allclose(batcheddot_bmm(i, i)) # Using timeit.Timer() method j = timeit.Timer( stmt='batcheddot_multiply_sum(i, i)', setup='from __main__ import batcheddot_multiply_sum', globals={'i': i}) j1 = timeit.Timer( stmt='batcheddot_bmm(i, i)', setup='from __main__ import batcheddot_bmm', globals={'i': i}) print(f'multiply_sum(i, i): {j.timeit(100) / 100 * 1e6:>5.1f} us') print(f'bmm(i, i): {j1.timeit(100) / 100 * 1e6:>5.1f} us')
Выход:
После запуска приведенного выше кода мы получаем следующий вывод, в котором мы видим, что значение умножения и суммы с использованием теста PyTorch выводится на экран.