pytorch
- Version: 1.9.1
- Category: ai
- Cluster: Loki, Vali
Description
PyTorch is an open-source machine learning library originally developed by Facebook AI Research (now Meta AI). It provides tensors with GPU acceleration and automatic differentiation over dynamic (define-by-run) computation graphs, with a focus on flexibility and rapid experimentation.
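"Dynamic computation graph" means the graph is recorded while ordinary Python code runs, so control flow can depend on the data. The following minimal sketch shows a data-dependent branch and the gradient autograd computes for the branch that was actually taken. The function `f` is purely illustrative.

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # Plain Python control flow: the graph is recorded as this code executes,
    # so a different call can take a different branch.
    if x.sum() > 0:
        return (x ** 2).sum()
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = f(x)          # sum is positive, so y = 1 + 4 + 9 = 14
y.backward()      # autograd walks the graph recorded during this call
print(y.item())   # 14.0
print(x.grad)     # tensor([2., 4., 6.]), since d/dx of x**2 is 2x
```

Calling `f` on a tensor whose elements sum to a negative value takes the other branch, and `backward()` then differentiates `x**3` instead.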
Documentation
import torch
print(torch.__version__)
print("CUDA available:", torch.cuda.is_available())
# Define a basic model and run a forward pass on random input (no training here)
model = torch.nn.Linear(10, 1)
x = torch.randn(5, 10)
output = model(x)
print(output)
# Save/load a model
torch.save(model.state_dict(), "model.pt")
model.load_state_dict(torch.load("model.pt"))
# torch.compile was introduced in PyTorch 2.0; the check skips it on older versions
if hasattr(torch, "compile"):
    model = torch.compile(model)
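The save/load lines above store only the model weights. To resume training later, a common pattern is to also save the optimizer state. A minimal sketch, assuming a hypothetical file name `checkpoint.pt`, an illustrative `"epoch"` entry, and an arbitrary SGD configuration:

```python
import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Bundle everything needed to resume into one dict ("checkpoint.pt" is an example name)
torch.save(
    {
        "epoch": 5,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    },
    "checkpoint.pt",
)

# Later: rebuild the same objects first, then restore their saved state
restored_model = torch.nn.Linear(10, 1)
restored_opt = torch.optim.SGD(restored_model.parameters(), lr=0.01, momentum=0.9)
# map_location="cpu" lets a checkpoint written on a GPU node load on a CPU-only node
ckpt = torch.load("checkpoint.pt", map_location="cpu")
restored_model.load_state_dict(ckpt["model"])
restored_opt.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"] + 1
```

A state dict only holds tensors and settings, not the class definition, which is why the model and optimizer must be constructed again before loading.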
Examples/Usage
Load PyTorch (version 2.1.2):
$ module load ai/PyTorch/2.1.2-foss-2023b
Launch Python and run a quick GPU check:
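import torch
# get_device_name() raises an error when no GPU is visible, so check first
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No CUDA device visible")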
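The check above only reports the GPU name. Scripts usually pick a device once and place tensors (and models) on it, so the same code runs on GPU and CPU nodes. A minimal sketch; the matrix sizes are arbitrary:

```python
import torch

# Use the first GPU if one is visible to this job, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

a = torch.randn(1000, 1000, device=device)  # allocate directly on the chosen device
b = torch.randn(1000, 1000).to(device)      # or move an existing CPU tensor
c = a @ b                                   # executes on the GPU when device is "cuda"
print(c.device)
```

Models are moved the same way with `model.to(device)`; all tensors taking part in one operation must be on the same device.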
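Train a simple model (fit y = 2x + 1 from noisy samples):
import torch
x = torch.randn(100, 1)
y = 2 * x + 1 + 0.1 * torch.randn(100, 1)
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
for _ in range(200):
    pred = model(x)
    loss = loss_fn(pred, y)
    loss.backward()        # compute gradients
    optimizer.step()       # update the weights
    optimizer.zero_grad()  # clear gradients before the next iteration
print(model.weight.item(), model.bias.item())  # should end up close to 2 and 1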
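After training, predictions are normally made in evaluation mode without gradient tracking. The following self-contained sketch repeats the training example with a fixed random seed (the value 0 is arbitrary) and then runs inference on a few hand-picked inputs:

```python
import torch

torch.manual_seed(0)  # make the random data, and therefore the result, reproducible

# Same synthetic problem as above: y = 2x + 1 plus a little noise
x = torch.randn(100, 1)
y = 2 * x + 1 + 0.1 * torch.randn(100, 1)

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

model.train()
for _ in range(200):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

# eval() switches layers such as dropout/batch-norm to inference behaviour;
# no_grad() skips recording the autograd graph, saving memory and time
model.eval()
with torch.no_grad():
    x_new = torch.tensor([[0.0], [1.0], [2.0]])
    preds = model(x_new)
    print(preds)  # values close to 1, 3 and 5

print(f"weight={model.weight.item():.3f}, bias={model.bias.item():.3f}")
```

Calling `zero_grad()` at the start of each iteration is equivalent to calling it at the end, as in the loop above; both keep gradients from accumulating across steps.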
Unload the module:
$ module unload ai/PyTorch/2.1.2-foss-2023b
Installation
Source: https://github.com/pytorch/pytorch/releases/tag/v1.12.0