Distributed training of a two-tower recommendation model using Lightning

This notebook shows how to build a two-tower recommendation model using the PyTorch Lightning Trainer API, with training distributed across 8 H100 GPUs on a single node.

Reminders:

  • For this demo, connect to 8xH100 GPU compute to take advantage of multi-GPU distributed training.
  • The @distributed decorator from the serverless_gpu Python library distributes the PyTorch Lightning training function across 8 H100 GPUs.

To get started, configure the notebook to use serverless GPU:

  1. Click the Connect drop-down at the top to open the compute selector.
  2. Select Serverless GPU.
  3. Open the Environment panel on the right.
  4. Select 8xH100 as the accelerator.
  5. Select AI v5 as the environment. Click Apply, then Confirm.

Your notebook is now connected to serverless GPU compute. The @distributed decorator handles launching training across all 8 GPUs.

Prerequisites

Before running this demo, configure the widget variables at the top of this notebook:

  • uc_catalog: the Unity Catalog catalog where the trained model is registered.
  • uc_schema: the Unity Catalog schema, within the catalog above, where the trained model is registered.
  • uc_volume: the Unity Catalog volume, within the schema above, where the dataset and checkpoints are stored.

The dataset is downloaded automatically when the notebook runs. The model is saved to <uc_catalog>.<uc_schema>.<model_name_in_registry>.
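
As a minimal sketch of how the three-level name above is assembled, the helper below joins the parts with dots. The function name uc_model_name and its validation rule are hypothetical, added only for illustration; the notebook itself builds the same string inline with an f-string.

```python
def uc_model_name(catalog: str, schema: str, model: str) -> str:
    """Join the three parts of a Unity Catalog model name with dots."""
    parts = [catalog, schema, model]
    for part in parts:
        # Hypothetical sanity check: an empty part or an embedded dot would
        # produce a name with the wrong number of levels.
        if not part or "." in part:
            raise ValueError(f"invalid name part: {part!r}")
    return ".".join(parts)


# Widget defaults from this notebook plus the model name used in step 8.
print(uc_model_name("main", "default", "two_tower_model"))
```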

Two-tower recommendation model

For more information about the two-tower recommendation model, see the following resources:

Instructions:

The code below shows how to:

  1. Install packages
  2. Download and prepare the dataset
  3. Set the required training configurations
  4. Define the two-tower recommendation model
  5. Create the main training function
  6. Train the two-tower model
  7. Run inference
  8. Register the model in MLflow for management

1) Install packages

The Databricks AI v5 environment includes most of the libraries needed for this example. Run the following cell to install the additional packages that are not yet part of the environment.

%pip install --no-cache-dir --force-reinstall --no-deps --index-url https://download.pytorch.org/whl/cu129 torchaudio==2.9.0 fbgemm-gpu==1.4.0
%pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cu129 torchrec==1.4.0
dbutils.library.restartPython()
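
After the Python restart, it can be useful to confirm that the pinned versions were actually picked up. The sketch below is one possible check using only the standard library; the helper name installed_versions is hypothetical, and it assumes the distribution names match the names used in the %pip cells.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

# Pins copied from the %pip cells above.
PINNED = {"torchaudio": "2.9.0", "fbgemm-gpu": "1.4.0", "torchrec": "1.4.0"}


def installed_versions(pins: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each pinned package, or None if missing."""
    found: Dict[str, Optional[str]] = {}
    for name in pins:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, got in installed_versions(PINNED).items():
    want = PINNED[name]
    # Local build tags such as "+cu129" are ignored when comparing.
    ok = got is not None and got.split("+")[0] == want
    print(f"{name}: wanted {want}, found {got}, {'OK' if ok else 'MISMATCH'}")
```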

The following cell consolidates all imports used in this example.

# General Imports
import os
import urllib.request
import zipfile
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

# Data Processing Imports
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# Databricks Specific Imports
import mlflow
from mlflow.models.signature import infer_signature
from mlflow.pyfunc import PythonModel

# Torch Specific Imports
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import AUROC

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, DeviceStatsMonitor
from pytorch_lightning.loggers import MLFlowLogger

# TorchRec Specific Imports
from torchrec.datasets.utils import Batch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.mlp import MLP
from torchrec.optim.keyed import KeyedOptimizerWrapper
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

2) Download and prepare the dataset

Download the Learning from Sets dataset, preprocess it, and split it into training/validation/test sets.

dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "recsys")

uc_catalog = dbutils.widgets.get("uc_catalog")
uc_schema = dbutils.widgets.get("uc_schema")
uc_volume = dbutils.widgets.get("uc_volume")
DATASET_URL = "https://files.grouplens.org/datasets/learning-from-sets-2019/learning-from-sets-2019.zip"

DATASET_PATH = f"/Volumes/{uc_catalog}/{uc_schema}/{uc_volume}/dataset"
ZIP_PATH = f"{DATASET_PATH}/learning-from-sets-2019.zip"
CSV_PATH = f"{DATASET_PATH}/learning-from-sets-2019/item_ratings.csv"

# Download and extract
if not os.path.exists(CSV_PATH):
    os.makedirs(DATASET_PATH, exist_ok=True)
    print("Downloading dataset...")
    urllib.request.urlretrieve(DATASET_URL, ZIP_PATH)
    with zipfile.ZipFile(ZIP_PATH, "r") as zf:
        zf.extractall(DATASET_PATH)
    print("Download complete.")

# Load and preprocess
df = pd.read_csv(CSV_PATH)
df = df.sort_values(["userId", "movieId"]).head(100_000)

# Encode userId to contiguous integers
user_encoder = LabelEncoder()
df["userId"] = user_encoder.fit_transform(df["userId"])

# Binarize ratings: 1 if >= mean, else 0
mean_rating = df["rating"].mean()
df["label"] = (df["rating"] >= mean_rating).astype(np.int64)
df = df[["userId", "movieId", "label"]]

# Compute embedding table sizes from data
num_users = int(df["userId"].nunique())
num_movies = int(df["movieId"].nunique())
print(f"Dataset: {len(df)} rows, {num_users} users, {num_movies} movies")

# Split: 70% train, ~20% validation, ~10% test (test_size=0.33 of the 30% holdout)
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.33, random_state=42)
print(f"Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")

class RecDataset(Dataset):
    """Wraps a DataFrame with columns [userId, movieId, label] as a PyTorch Dataset."""
    def __init__(self, dataframe: pd.DataFrame):
        self.users = dataframe["userId"].values.astype(np.int64)
        self.movies = dataframe["movieId"].values.astype(np.int64)
        self.labels = dataframe["label"].values.astype(np.int64)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> dict:
        return {"userId": self.users[idx], "movieId": self.movies[idx], "label": self.labels[idx]}


def get_dataloader(dataframe: pd.DataFrame, batch_size: int = 1024, shuffle: bool = True) -> DataLoader:
    return DataLoader(RecDataset(dataframe), batch_size=batch_size, shuffle=shuffle, num_workers=2, pin_memory=True)
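
The two chained train_test_split calls in the cell above can be checked with a little arithmetic. The sketch below, with a hypothetical helper named split_fractions, converts the two test_size values (0.3, then 0.33 of the holdout) into fractions of the full dataset.

```python
def split_fractions(holdout: float, test_share_of_holdout: float):
    """Return (train, validation, test) fractions of the full dataset."""
    train = 1.0 - holdout
    test = holdout * test_share_of_holdout
    validation = holdout - test
    return train, validation, test


train, validation, test = split_fractions(holdout=0.3, test_share_of_holdout=0.33)
print(f"train={train:.1%}, validation={validation:.1%}, test={test:.1%}")
```

With these values the split works out to roughly 70% train, 20.1% validation, and 9.9% test.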

3) Required training configurations

All the arguments and settings needed for this training example are consolidated in the following cell. You can modify any of them to suit your use case.

@dataclass
class Args:
    epochs: int = 3
    embedding_dim: int = 128
    layer_sizes: List[int] = field(default_factory=lambda: [128, 64])
    learning_rate: float = 0.01
    batch_size: int = 1024

cat_cols = ["userId", "movieId"]
emb_counts = [num_users, num_movies]  # computed from data in section 2
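
As a sketch of how such a configuration can be adjusted per run, the example below uses a hypothetical ExampleArgs class that mirrors the fields of Args. It shows overriding selected fields with dataclasses.replace, and why the list field uses default_factory: each instance gets its own list.

```python
from dataclasses import dataclass, field, replace
from typing import List


@dataclass
class ExampleArgs:
    """Hypothetical stand-in mirroring the fields of Args above."""
    epochs: int = 3
    embedding_dim: int = 128
    layer_sizes: List[int] = field(default_factory=lambda: [128, 64])
    learning_rate: float = 0.01
    batch_size: int = 1024


base = ExampleArgs()
# Override only what changes; every other field keeps its current value.
quick = replace(base, epochs=1, batch_size=512)

# default_factory builds a fresh list per instance, so this edit stays local.
deeper = ExampleArgs()
deeper.layer_sizes.append(32)

print(quick)
print(base.layer_sizes, deeper.layer_sizes)
```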

4) Define the two-tower recommendation model

This section defines the model using PyTorch Lightning. For more information, see the documentation:

class TwoTowerModel(nn.Module):
    def __init__(
        self,
        embedding_bag_collection: EmbeddingBagCollection,
        layer_sizes: List[int],
        device: Optional[torch.device] = None
    ) -> None:
        super().__init__()
        assert len(embedding_bag_collection.embedding_bag_configs()) == 2, "Expected two EmbeddingBags in the two tower model"
        assert embedding_bag_collection.embedding_bag_configs()[0].embedding_dim == embedding_bag_collection.embedding_bag_configs()[1].embedding_dim, "Both EmbeddingBagConfigs must have the same dimension"
        embedding_dim = embedding_bag_collection.embedding_bag_configs()[0].embedding_dim
        self._feature_names_query: List[str] = embedding_bag_collection.embedding_bag_configs()[0].feature_names
        self._candidate_feature_names: List[str] = embedding_bag_collection.embedding_bag_configs()[1].feature_names
        self.ebc = embedding_bag_collection
        self.query_proj = MLP(in_size=embedding_dim, layer_sizes=layer_sizes, device=device)
        self.candidate_proj = MLP(in_size=embedding_dim, layer_sizes=layer_sizes, device=device)

    def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        pooled_embeddings = self.ebc(kjt)
        query_embedding: torch.Tensor = self.query_proj(
            torch.cat(
                [pooled_embeddings[feature] for feature in self._feature_names_query],
                dim=1,
            )
        )
        candidate_embedding: torch.Tensor = self.candidate_proj(
            torch.cat(
                [pooled_embeddings[feature] for feature in self._candidate_feature_names],
                dim=1,
            )
        )
        return query_embedding, candidate_embedding

class LitTwoTower(pl.LightningModule):
    """
    PyTorch Lightning module wrapping a TwoTowerModel.
    Uses torchmetrics AUROC for train/val metrics.
    """
    def __init__(
        self,
        two_tower: nn.Module,
        device: torch.device,
        emb_counts: Optional[List[int]],
        cat_cols: List[str],
        lr: float = 1e-3,
    ) -> None:
        super().__init__()
        self.two_tower = two_tower
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.train_auroc = AUROC(task="binary")
        self.val_auroc = AUROC(task="binary")
        self.lr = lr

        # Store metadata used in batch transform
        self.emb_counts = emb_counts
        self.cat_cols = cat_cols

        self.save_hyperparameters(ignore=["two_tower", "device"])

    def forward(self, batch: Dict[str, Any]) -> torch.Tensor:
        kjt_batch = self._transform_to_torchrec_batch(batch, self.emb_counts)
        query_embedding, candidate_embedding = self.two_tower(kjt_batch.sparse_features)
        logits = (query_embedding * candidate_embedding).sum(dim=1).squeeze()
        return logits

    def _loss(self, outputs: torch.Tensor, batch: Dict[str, Any]) -> torch.Tensor:
        labels = self._get_batch_labels(batch)
        return self.loss_fn(outputs, labels)

    def _update_metric(self, batch: Dict[str, Any], outputs: Optional[torch.Tensor], metric: AUROC) -> None:
        if outputs is None:
            outputs = self.forward(batch)
        preds = torch.sigmoid(outputs)
        labels = self._get_batch_labels(batch)
        metric.update(preds, labels)

    def training_step(self, batch: Dict[str, Any], batch_idx: int):
        logits = self.forward(batch)
        loss = self._loss(logits, batch)

        # Metric update
        self._update_metric(batch, logits, self.train_auroc)

        # Log both step and epoch loss series; enable sync_dist for multi-GPU/DDP
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        self.log("train_auroc", self.train_auroc, on_step=False, on_epoch=True, prog_bar=True,
             logger=True, sync_dist=True)

        return loss

    def validation_step(self, batch: Dict[str, Any], batch_idx: int):
        logits = self.forward(batch)
        loss = self._loss(logits, batch)

        self._update_metric(batch, logits, self.val_auroc)

        # Typically only epoch-level val metrics are needed for monitoring
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_auroc", self.val_auroc, on_step=False, on_epoch=True, prog_bar=True,
             logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = KeyedOptimizerWrapper(
            dict(self.two_tower.named_parameters()),
            lambda params: torch.optim.Adam(params, lr=self.lr),
        )
        return optimizer

    def _get_batch_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
        return batch["label"].to(dtype=torch.float32, device=self.device)

    def _transform_to_torchrec_batch(
        self,
        batch: Dict[str, Any],
        num_embeddings_per_feature: Optional[List[int]],
    ) -> Batch:
        kjt_values_list = []
        kjt_lengths_list = []
        for col_idx, col_name in enumerate(self.cat_cols):
            values = batch[col_name]
            num_emb = num_embeddings_per_feature[col_idx]
            kjt_values_list.append(values % num_emb)
            kjt_lengths_list.append(torch.ones(len(values), dtype=torch.int64))

        values_t = torch.cat(kjt_values_list).to(dtype=torch.int64, device=self.device)
        lengths_t = torch.cat(kjt_lengths_list).to(device=self.device)

        sparse_features = KeyedJaggedTensor.from_lengths_sync(
            self.cat_cols,
            values_t,
            lengths_t,
        )

        labels = batch["label"].to(dtype=torch.int64, device=self.device)

        return Batch(
            dense_features=torch.zeros(1, device=self.device),
            sparse_features=sparse_features,
            labels=labels,
        )

def create_two_tower_model(args, device, cat_cols, emb_counts) -> LitTwoTower:
    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=args.embedding_dim,
            num_embeddings=emb_counts[feature_idx],
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(cat_cols)
    ]
    ebc = EmbeddingBagCollection(tables=eb_configs, device=device)
    base = TwoTowerModel(
        embedding_bag_collection=ebc, layer_sizes=args.layer_sizes, device=device
    )
    lit = LitTwoTower(
        base, cat_cols=cat_cols, emb_counts=emb_counts, device=device, lr=args.learning_rate
    )
    return lit
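
To make the batch transform easier to follow, here is a plain-Python sketch of what _transform_to_torchrec_batch feeds into KeyedJaggedTensor.from_lengths_sync: one flat list of IDs laid out feature by feature, and a matching list of lengths that are all 1. The helper name flatten_for_kjt and the toy table sizes are hypothetical; the notebook does the same work with torch tensors.

```python
from typing import Dict, List, Tuple


def flatten_for_kjt(
    batch: Dict[str, List[int]],
    cat_cols: List[str],
    emb_counts: List[int],
) -> Tuple[List[int], List[int]]:
    """Build the flat values and lengths lists, feature by feature."""
    values: List[int] = []
    lengths: List[int] = []
    for col, num_emb in zip(cat_cols, emb_counts):
        # Modulo keeps every ID inside its embedding table, at the cost of
        # collisions when raw IDs exceed the table size.
        values.extend(v % num_emb for v in batch[col])
        # Each row carries exactly one ID per feature.
        lengths.extend([1] * len(batch[col]))
    return values, lengths


# Toy batch of three rows; the table sizes (4 and 5) are made up.
toy_batch = {"userId": [0, 3, 5], "movieId": [7, 2, 11]}
values, lengths = flatten_for_kjt(toy_batch, ["userId", "movieId"], [4, 5])
print(values)   # [0, 3, 1, 2, 2, 1]
print(lengths)  # [1, 1, 1, 1, 1, 1]
```

Because movieId is not re-encoded to contiguous integers in step 2, the modulo is what keeps raw movie IDs within the bounds of the embedding table.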

5) Create the main training function

Next, use the @distributed decorator from the serverless_gpu library, together with helper functions and the PyTorch Lightning Trainer API, to launch training across multiple GPUs.

# setup mlflow experiment
username = spark.sql("SELECT current_user()").first()['current_user()']
experiment_path = f'/Users/{username}/sgc-torchrec-example'
experiment = mlflow.set_experiment(experiment_path)
os.environ["MLFLOW_EXPERIMENT_NAME"] = experiment_path

from serverless_gpu import distributed

# setup arguments for training function
args = Args(epochs=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = f"/Volumes/{uc_catalog}/{uc_schema}/{uc_volume}/checkpoints"

@distributed(gpus=8, gpu_type="H100")
def training_function(args=args, cat_cols=cat_cols, emb_counts=emb_counts, device=device,
                      train_data=train_df, val_data=val_df, checkpoint_path=CHECKPOINT_PATH):
    mlflow.pytorch.autolog()
    model = create_two_tower_model(args, device=device, cat_cols=cat_cols, emb_counts=emb_counts)
    train_dataloader = get_dataloader(train_data, batch_size=args.batch_size, shuffle=True)
    eval_dataloader = get_dataloader(val_data, batch_size=args.batch_size, shuffle=False)

    mlflow_logger = MLFlowLogger(
        experiment_name=experiment_path,
        log_model="all",
    )

    ckpt_cb = ModelCheckpoint(
        dirpath=checkpoint_path,
        monitor="val_auroc",
        mode="max",
        save_top_k=1,
        save_last=True,                        # enables last_model_path
        filename="{epoch}-{val_auroc:.4f}",
    )

    callbacks = [
        LearningRateMonitor(logging_interval="step"),
        DeviceStatsMonitor(),
        ckpt_cb,
    ]

    trainer = Trainer(
        max_epochs=args.epochs,
        accelerator="gpu",
        strategy="ddp",
        devices=8,
        log_every_n_steps=20,
        logger=mlflow_logger,
        callbacks=callbacks,
    )
    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader
    )

    # Return run_id and best checkpoint path
    result = {
        "run_id": trainer.logger.run_id,                   # MLflow run id
        "best_model_checkpoint": ckpt_cb.best_model_path,  # best checkpoint path
        "last_model_checkpoint": ckpt_cb.last_model_path   # last checkpoint path
    }
    return result
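
For a rough sense of what the 8-GPU setup means for each epoch, the sketch below estimates rows per GPU, steps per epoch, and the effective global batch size. It assumes the Trainer shards the training data with a distributed sampler that pads rather than drops rows (an assumption about the default behavior, not something stated in this notebook); the helper name ddp_steps_per_epoch is hypothetical.

```python
import math


def ddp_steps_per_epoch(num_rows: int, world_size: int, per_gpu_batch: int) -> dict:
    """Estimate per-rank rows, steps per epoch, and the global batch size."""
    # A distributed sampler without drop_last pads so every rank gets the same count.
    rows_per_rank = math.ceil(num_rows / world_size)
    steps = math.ceil(rows_per_rank / per_gpu_batch)
    return {
        "rows_per_rank": rows_per_rank,
        "steps_per_epoch": steps,
        "global_batch_size": per_gpu_batch * world_size,
    }


# 70,000 rows is the 70% training share of the 100,000-row sample.
print(ddp_steps_per_epoch(num_rows=70_000, world_size=8, per_gpu_batch=1024))
```

Under these assumptions each epoch has only about nine steps per GPU, so log_every_n_steps=20 may produce few or no step-level log points; lowering it is worth considering if step-level curves matter.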

6) Train the two-tower model using the serverless GPU distributed training API

result = training_function.distributed()

7) Test the best model checkpoint

Retrieve the best model checkpoint and run it against the test set to verify the results.

print(f"Experiment Name: {experiment.name}")
print(f"Experiment ID: {experiment.experiment_id}")
print(f"Artifact Location: {experiment.artifact_location}")
print(f"Lifecycle_stage: {experiment.lifecycle_stage}")

ranked_checkpoints = mlflow.search_logged_models(
  experiment_ids=[experiment.experiment_id],
  output_format="list",
  # Rank by the validation metric that the training loop logs and checkpoints on
  order_by=[{"field_name": "metrics.val_auroc", "ascending": False}]
)

best_checkpoint: mlflow.entities.LoggedModel = ranked_checkpoints[0]
print(best_checkpoint.metrics[0])
run_id = best_checkpoint.source_run_id
artifact_path = best_checkpoint.artifact_location
model_uri = f"runs:/{run_id}/{artifact_path}"
two_tower_model = mlflow.pytorch.load_model(model_uri)

num_batches = 5 # Number of batches to print out at a time
batch_size = 1 # Print out each individual row

test_dataloader = iter(get_dataloader(test_df, batch_size=batch_size, shuffle=False))

device = torch.device("cuda:0")
two_tower_model.to(device)
two_tower_model.eval()

for _ in range(num_batches):
    next_batch = next(test_dataloader)
    expected_result = next_batch["label"][0]

    actual_result = two_tower_model(next_batch)
    actual_result = torch.sigmoid(actual_result)
    print(f"Expected Result: {expected_result}; Actual Result: {actual_result.round().item()}")
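
The loop above turns each model output into a 0/1 prediction in three steps: a dot product between the two tower embeddings gives a logit, a sigmoid maps it to a probability, and rounding applies a 0.5 threshold. The plain-Python sketch below illustrates that chain with made-up two-dimensional embeddings; the function names are hypothetical and no torch is involved.

```python
import math
from typing import List


def dot(query: List[float], candidate: List[float]) -> float:
    """Dot product of two equal-length embedding vectors (the raw logit)."""
    if len(query) != len(candidate):
        raise ValueError("embeddings must have the same dimension")
    return sum(q * c for q, c in zip(query, candidate))


def sigmoid(x: float) -> float:
    """Map a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def predict_label(query: List[float], candidate: List[float], threshold: float = 0.5) -> int:
    """Return 1 when the predicted probability exceeds the threshold."""
    return int(sigmoid(dot(query, candidate)) > threshold)


user = [0.5, -1.0]        # made-up query-tower output
liked = [2.0, 0.5]        # logit = 0.5, probability about 0.62
disliked = [-2.0, 0.5]    # logit = -1.5, probability about 0.18
print(predict_label(user, liked), predict_label(user, disliked))  # 1 0
```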

8) Register the model in MLflow for management

Once the model from the previous step produces correct results, use the corresponding run_id from the most recent version to register the model. To simplify this, create a PyFunc that wraps the two-tower model so that it accepts a simpler input: (Dict[str, List] -> List[float]).

class TwoTowerWrapper(PythonModel):
    """
    MLflow PythonModel wrapper for TwoTower model that handles dictionary input and returns list outputs
    """
    def __init__(self, two_tower_model):
        self.two_tower_model = two_tower_model

    def predict(self, model_input: Dict[str, List]) -> List[float]:
        batch = {key: torch.tensor(value) for key, value in model_input.items()}
        if "label" not in batch:
            batch["label"] = torch.zeros(len(next(iter(batch.values()))))
        with torch.no_grad():
            output = self.two_tower_model(batch).cpu()
        output = torch.sigmoid(output)
        return output.tolist()


def preprocess_data(batch):
    # turn the example test dataset from Dict[str, Tensor] to Dict[str, List] and remove the label
    return {key: tensor.tolist() for key, tensor in batch.items() if key != "label"}

def add_and_get_model_signature(two_tower_model, test_dataloader):
    current_batch = preprocess_data(next(test_dataloader))

    pyfunc_two_tower_model = TwoTowerWrapper(two_tower_model)
    current_output = pyfunc_two_tower_model.predict(current_batch)
    signature = infer_signature(current_batch, current_output)
    logged_model = mlflow.pyfunc.log_model(
        artifact_path="two_tower_pyfunc",
        python_model=pyfunc_two_tower_model,
        signature=signature,
        input_example=current_batch
    )
    return signature, logged_model

signature, logged_model = add_and_get_model_signature(two_tower_model, test_dataloader)
model_name = "two_tower_model"
uc_model_version = mlflow.register_model(
    f"models:/{logged_model.model_id}",
    name=f"{uc_catalog}.{uc_schema}.{model_name}"
)
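
To illustrate the input and output contract of the wrapper (Dict[str, List] -> List[float]) without loading a model, the sketch below scores column-oriented input row by row. Both predict_from_columns and fake_score are hypothetical stand-ins; the real TwoTowerWrapper converts the columns to tensors and calls the trained model.

```python
from typing import Callable, Dict, List


def predict_from_columns(
    model_input: Dict[str, List[int]],
    score_row: Callable[[int, int], float],
) -> List[float]:
    """Score each (userId, movieId) pair and return one float per row."""
    users = model_input["userId"]
    movies = model_input["movieId"]
    if len(users) != len(movies):
        raise ValueError("userId and movieId must have the same length")
    return [score_row(u, m) for u, m in zip(users, movies)]


def fake_score(user_id: int, movie_id: int) -> float:
    """Hypothetical scorer; the real wrapper calls the trained model."""
    return 0.9 if (user_id + movie_id) % 2 == 0 else 0.1


request = {"userId": [1, 2, 3], "movieId": [11, 11, 12]}
print(predict_from_columns(request, fake_score))  # [0.9, 0.1, 0.1]
```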
