18  State Dict: Save and Load

Trained models need to be saved so they can be reused later without retraining. state_dict makes this easy.

18.1 What is state_dict?

A dictionary mapping each parameter's name to a copy of its values (a NumPy array):

model = Linear(4, 3)
params = model.state_dict()
print(params)
# Prints something like (values abbreviated):
# {
#   'weight': array([[...], [...], [...], [...]]),  # shape (4, 3)
#   'bias': array([...])                           # shape (3,)
# }

18.2 Implementing state_dict

class Module:
    def state_dict(self):
        """Collect every parameter of this module tree into one flat dict.

        Parameters owned directly by this module are stored under their bare
        name. Parameters of a child module are stored under
        "<child name>.<key>", recursively. Values are copies of each
        parameter's ``.data``, so mutating the result leaves the model intact.
        """
        collected = {key: p.data.copy() for key, p in self._parameters.items()}
        collected.update(
            (f"{child_name}.{sub_key}", sub_value)
            for child_name, child in self._modules.items()
            for sub_key, sub_value in child.state_dict().items()
        )
        return collected

    def load_state_dict(self, state_dict):
        """Copy values from ``state_dict`` into this module tree.

        Keys use the same naming as ``state_dict()``. Keys with no matching
        parameter are ignored, and parameters without a matching key keep
        their current value. Each loaded value is copied, so the caller's
        dict is not aliased by the model.
        """
        for key, p in self._parameters.items():
            if key not in state_dict:
                continue
            p.data = state_dict[key].copy()

        for child_name, child in self._modules.items():
            head = child_name + "."
            cut = len(head)
            # Strip "<child name>." so the child sees its own local names.
            child_state = {}
            for full_key, value in state_dict.items():
                if full_key.startswith(head):
                    child_state[full_key[cut:]] = value
            child.load_state_dict(child_state)

18.3 Example: state_dict Structure

class MLP(Module):
    """Two-layer perceptron used to illustrate nested state_dict keys.

    Its sub-modules ``fc1`` and ``fc2`` appear in ``state_dict()`` as
    'fc1.weight', 'fc1.bias', 'fc2.weight' and 'fc2.bias'.

    Args:
        input_size: Number of input features (default 4).
        hidden_size: Width of the hidden layer (default 8).
        output_size: Number of outputs (default 3).

    The defaults reproduce the original ``MLP()``; the keyword arguments let
    the same class build differently sized models (e.g. for transfer learning).
    """

    def __init__(self, input_size=4, hidden_size=8, output_size=3):
        super().__init__()
        self.fc1 = Linear(input_size, hidden_size)
        self.fc2 = Linear(hidden_size, output_size)

model = MLP()
state = model.state_dict()
key_view = state.keys()
print(key_view)
# Expected output:
# dict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

Parameters of nested modules are named with dot notation: the submodule's attribute name, a dot, then the parameter name (e.g. fc1.weight).

18.4 Saving Models

Using NumPy’s npz format:

def save_model(model, filepath):
    """Save model parameters to file.

    Writes ``model.state_dict()`` as an uncompressed NumPy ``.npz`` archive
    with one array per parameter name.

    Args:
        model: Object with a ``state_dict()`` method returning a dict of
            name -> array.
        filepath: Destination path. The file is written exactly at this
            path, so ``load_model(model, filepath)`` reads it back.
    """
    state = model.state_dict()
    # Passing an open file (not a path string) prevents np.savez from
    # silently appending ".npz" to extension-less paths, which would make
    # the message below wrong and break load_model(model, filepath).
    with open(filepath, 'wb') as f:
        np.savez(f, **state)
    print(f"Saved to {filepath}")

def load_model(model, filepath):
    """Load model parameters from file.

    Reads an ``.npz`` archive (as written by ``save_model``) into a plain
    dict and passes it to ``model.load_state_dict``.

    Args:
        model: Object with a ``load_state_dict(dict)`` method.
        filepath: Path of the ``.npz`` archive to read.
    """
    # np.load on an .npz returns an NpzFile that keeps the file open;
    # the context manager closes it once every array has been read.
    with np.load(filepath) as data:
        state = dict(data)
    model.load_state_dict(state)
    print(f"Loaded from {filepath}")

Usage:

weights_file = 'iris_model.npz'

# Train a model, then persist its learned parameters
model = IrisClassifier()
train(model)
save_model(model, weights_file)

# Later: rebuild the same architecture and restore the saved parameters
model2 = IrisClassifier()
load_model(model2, weights_file)

# model2 now has trained weights!

18.5 Saving with Pickle

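For more complex cases, such as saving the optimizer state and current epoch alongside the weights so training can be resumed, pickle a whole checkpoint dictionary. Only unpickle files you trust: pickle.load can execute arbitrary code embedded in the file.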

import pickle

def save_checkpoint(model, optimizer, epoch, filepath):
    """Pickle the complete training state into a single file.

    The stored object is a dict with keys 'model_state' (the model's
    state_dict), 'optimizer_state' (from optimizer_state_dict) and 'epoch'.
    load_checkpoint() reads the same layout back.
    """
    payload = dict(
        model_state=model.state_dict(),
        optimizer_state=optimizer_state_dict(optimizer),
        epoch=epoch,
    )
    with open(filepath, 'wb') as out_file:
        pickle.dump(payload, out_file)

def load_checkpoint(model, optimizer, filepath):
    """Restore model and optimizer from a file written by save_checkpoint().

    Returns:
        The epoch number stored in the checkpoint.

    Security: pickle.load can run arbitrary code embedded in the file, so
    only load checkpoints you created yourself or otherwise trust.
    """
    with open(filepath, 'rb') as in_file:
        saved = pickle.load(in_file)

    model.load_state_dict(saved['model_state'])
    opt_state = saved['optimizer_state']
    load_optimizer_state(optimizer, opt_state)
    epoch = saved['epoch']
    return epoch

18.6 Optimizer State

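Optimizers have state too (momentum buffers, Adam's first and second moments). The helpers below assume an Adam-style optimizer that exposes lr, m, v and t: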

def optimizer_state_dict(optimizer):
    """Snapshot an Adam-style optimizer's state.

    Returns a dict with the learning rate ('lr'), copies of the first- and
    second-moment lists ('m', 'v') and the step counter ('t'). The moment
    arrays are copied so later optimizer steps cannot change the snapshot.
    """
    snapshot = {'lr': optimizer.lr}
    snapshot['m'] = [buf.copy() for buf in optimizer.m]  # Adam first moment
    snapshot['v'] = [buf.copy() for buf in optimizer.v]  # Adam second moment
    snapshot['t'] = optimizer.t
    return snapshot

def load_optimizer_state(optimizer, state):
    """Load optimizer state.

    Restores the learning rate ('lr'), the per-parameter moment lists
    ('m', 'v') and the step counter ('t') from a dict laid out like the
    one optimizer_state_dict() returns.

    The moment arrays are copied, mirroring optimizer_state_dict(). Without
    the copy, later in-place updates by the optimizer would silently modify
    ``state`` (e.g. an in-memory checkpoint that is restored more than once).
    """
    optimizer.lr = state['lr']
    optimizer.m = [m.copy() for m in state['m']]
    optimizer.v = [v.copy() for v in state['v']]
    optimizer.t = state['t']

18.7 Full Training with Checkpoints

model = IrisClassifier()
optimizer = Adam(model.parameters(), lr=0.01)

best_loss = float('inf')
checkpoint_path = 'best_model.npz'

for epoch in range(100):
    model.train()
    loss = train_epoch(model, optimizer, train_data)

    model.eval()
    val_loss = evaluate(model, val_data)

    # Save best model (lowest validation loss seen so far)
    if val_loss < best_loss:
        best_loss = val_loss
        save_model(model, checkpoint_path)
        print(f"Epoch {epoch}: New best! val_loss={val_loss:.4f}")

# Load best model for final evaluation.
# evaluate() is used above as a loss (lower is better), so its result on the
# test set is reported as a loss, not formatted as an accuracy percentage.
load_model(model, checkpoint_path)
model.eval()
test_loss = evaluate(model, test_data)
print(f"Test loss: {test_loss:.4f}")

18.8 Partial Loading

Sometimes you want to load only some parameters. A strict flag controls whether missing or unexpected keys are treated as an error:

def load_state_dict(self, state_dict, strict=True):
    """
    Load parameters from dictionary.

    Keys are checked against ``self.state_dict()`` before anything is
    written, so a strict failure leaves the model untouched instead of
    half-loaded.

    Args:
        state_dict: Dictionary of parameter values, keyed like
            ``self.state_dict()`` (dot-separated for nested modules).
        strict: If True, all keys must match.

    Returns:
        (missing, unexpected): the model's keys absent from ``state_dict``,
        and the ``state_dict`` keys the model does not have. Both lists are
        reported whether or not ``strict`` is set, so partial loads can
        see what was skipped.

    Raises:
        ValueError: if ``strict`` is True and any key is missing or
            unexpected.
    """
    own_state = self.state_dict()
    missing = [name for name in own_state if name not in state_dict]
    unexpected = [name for name in state_dict if name not in own_state]

    # Validate first so a strict mismatch never leaves a partially loaded model.
    if strict and (missing or unexpected):
        raise ValueError(f"Missing: {missing}, Unexpected: {unexpected}")

    for name in own_state:
        if name in state_dict:
            self._set_param(name, state_dict[name])

    return missing, unexpected

18.9 Transfer Learning

Load pretrained weights, change output layer:

# Source model: pretrained on a 10-class task
pretrained = MLP(input_size=784, hidden_size=256, output_size=10)
load_model(pretrained, 'pretrained_mnist.npz')

# Target model: same hidden layer, new 5-class output layer
new_model = MLP(input_size=784, hidden_size=256, output_size=5)

# Take every pretrained tensor except the output layer, whose shape differs
state = pretrained.state_dict()
for head_key in ('fc2.weight', 'fc2.bias'):
    del state[head_key]
new_model.load_state_dict(state, strict=False)

# fc1 now holds pretrained weights; fc2 keeps its random initialization

18.10 Summary

  • state_dict(): Get all parameters as dictionary
  • load_state_dict(): Load parameters from dictionary
  • Nested modules use dot notation: fc1.weight
  • Save with np.savez() or pickle
  • Checkpoint model weights, optimizer state, and epoch so training can be resumed
  • Enable transfer learning with partial loading

Next: Dataset and DataLoader for organized data handling.