18 State Dict: Save and Load
Trained models need to be saved so they can be reloaded and reused later; state_dict makes this easy.
18.1 What is state_dict?
A state_dict is a dictionary mapping parameter names to their values:
model = Linear(4, 3)
print(model.state_dict())
# {
#     'weight': array([[...], [...], [...], [...]]),  # shape (4, 3)
#     'bias': array([...])                             # shape (3,)
# }

18.2 Implementing state_dict
class Module:
    def state_dict(self):
        """Return all parameters as a dictionary."""
        state = {}
        # Add own parameters
        for name, param in self._parameters.items():
            state[name] = param.data.copy()
        # Add parameters from submodules (with prefix)
        for name, module in self._modules.items():
            sub_state = module.state_dict()
            for key, value in sub_state.items():
                state[f"{name}.{key}"] = value
        return state

    def load_state_dict(self, state_dict):
        """Load parameters from a dictionary."""
        # Load own parameters
        for name, param in self._parameters.items():
            if name in state_dict:
                param.data = state_dict[name].copy()
        # Load into submodules
        for name, module in self._modules.items():
            prefix = f"{name}."
            sub_dict = {
                k[len(prefix):]: v
                for k, v in state_dict.items()
                if k.startswith(prefix)
            }
            module.load_state_dict(sub_dict)
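Because state_dict copies each array (param.data.copy()), mutating the returned dictionary never touches the live model. A quick sketch, reusing the Linear layer from above:

model = Linear(4, 3)
state = model.state_dict()
state['weight'][:] = 0.0       # zero out the copied array in place

fresh = model.state_dict()     # take a second snapshot from the model
print(np.allclose(fresh['weight'], 0.0))
# False: the model's own weight is untouched, because state_dict returned copies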
18.3 Example: state_dict Structure
class MLP(Module):
    def __init__(self):
        super().__init__()
        self.fc1 = Linear(4, 8)
        self.fc2 = Linear(8, 3)

model = MLP()
state = model.state_dict()
print(state.keys())
# dict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

Nested modules use dot notation for names.
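The prefixes compose recursively, so deeper nesting simply produces longer dotted names. A small sketch (the Block and Net classes here are made up for illustration):

class Block(Module):
    def __init__(self):
        super().__init__()
        self.fc1 = Linear(4, 8)

class Net(Module):
    def __init__(self):
        super().__init__()
        self.block = Block()       # nested module: its keys gain a 'block.' prefix
        self.out = Linear(8, 3)

net = Net()
print(net.state_dict().keys())
# dict_keys(['block.fc1.weight', 'block.fc1.bias', 'out.weight', 'out.bias'])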
18.4 Saving Models
Using NumPy’s npz format:
import numpy as np

def save_model(model, filepath):
    """Save model parameters to file."""
    state = model.state_dict()
    np.savez(filepath, **state)
    print(f"Saved to {filepath}")

def load_model(model, filepath):
    """Load model parameters from file."""
    state = dict(np.load(filepath))
    model.load_state_dict(state)
    print(f"Loaded from {filepath}")

Usage:
# Train model
model = IrisClassifier()
train(model)

# Save
save_model(model, 'iris_model.npz')

# Later: load into a new model
model2 = IrisClassifier()
load_model(model2, 'iris_model.npz')
# model2 now has the trained weights!
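One way to confirm the round trip is to compare predictions before and after reloading. A sketch, assuming IrisClassifier takes four input features and that calling the model runs a forward pass:

x = np.random.randn(5, 4)          # a dummy batch (hypothetical input shape)
before = model(x)

restored = IrisClassifier()
load_model(restored, 'iris_model.npz')
after = restored(x)

print(np.allclose(before, after))  # True if every parameter survived the save/load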
18.5 Saving with Pickle
For more complex cases, such as checkpointing the optimizer state alongside the model, pickle is convenient:
import pickle

def save_checkpoint(model, optimizer, epoch, filepath):
    """Save full training state."""
    checkpoint = {
        'model_state': model.state_dict(),
        'optimizer_state': optimizer_state_dict(optimizer),
        'epoch': epoch
    }
    with open(filepath, 'wb') as f:
        pickle.dump(checkpoint, f)

def load_checkpoint(model, optimizer, filepath):
    """Load full training state."""
    with open(filepath, 'rb') as f:
        checkpoint = pickle.load(f)
    model.load_state_dict(checkpoint['model_state'])
    load_optimizer_state(optimizer, checkpoint['optimizer_state'])
    return checkpoint['epoch']

18.6 Optimizer State
Optimizers have state too (momentum, Adam moments):
def optimizer_state_dict(optimizer):
    """Get optimizer state."""
    return {
        'lr': optimizer.lr,
        'm': [m.copy() for m in optimizer.m],  # Adam first moment
        'v': [v.copy() for v in optimizer.v],  # Adam second moment
        't': optimizer.t
    }

def load_optimizer_state(optimizer, state):
    """Load optimizer state."""
    optimizer.lr = state['lr']
    optimizer.m = state['m']
    optimizer.v = state['v']
    optimizer.t = state['t']
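Together with save_checkpoint and load_checkpoint from the previous section, this makes it possible to resume an interrupted run. A minimal sketch, assuming the training helpers used elsewhere in this chapter:

import os

model = IrisClassifier()
optimizer = Adam(model.parameters(), lr=0.01)

# Resume from the last checkpoint if one exists, otherwise start from epoch 0
start_epoch = 0
if os.path.exists('checkpoint.pkl'):
    start_epoch = load_checkpoint(model, optimizer, 'checkpoint.pkl') + 1

for epoch in range(start_epoch, 100):
    loss = train_epoch(model, optimizer, train_data)
    save_checkpoint(model, optimizer, epoch, 'checkpoint.pkl')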
18.7 Full Training with Checkpoints
model = IrisClassifier()
optimizer = Adam(model.parameters(), lr=0.01)
best_loss = float('inf')
checkpoint_path = 'best_model.npz'

for epoch in range(100):
    model.train()
    loss = train_epoch(model, optimizer, train_data)

    model.eval()
    val_loss = evaluate(model, val_data)

    # Save best model
    if val_loss < best_loss:
        best_loss = val_loss
        save_model(model, checkpoint_path)
        print(f"Epoch {epoch}: New best! val_loss={val_loss:.4f}")

# Load best model for final evaluation
load_model(model, checkpoint_path)
test_acc = evaluate(model, test_data)
print(f"Test accuracy: {test_acc:.2%}")

18.8 Partial Loading
Sometimes you want to load only some parameters:
def load_state_dict(self, state_dict, strict=True):
    """
    Load parameters from a dictionary.

    Args:
        state_dict: Dictionary of parameter values
        strict: If True, all keys must match
    """
    own_state = self.state_dict()
    missing = []
    unexpected = []

    for name in own_state:
        if name in state_dict:
            # _set_param writes the value into the (possibly nested) parameter
            self._set_param(name, state_dict[name])
        else:
            missing.append(name)

    for name in state_dict:
        if name not in own_state:
            unexpected.append(name)

    if strict and (missing or unexpected):
        raise ValueError(f"Missing: {missing}, Unexpected: {unexpected}")
    return missing, unexpected

18.9 Transfer Learning
Load pretrained weights, change output layer:
# Pretrained model for 10 classes
pretrained = MLP(input_size=784, hidden_size=256, output_size=10)
load_model(pretrained, 'pretrained_mnist.npz')

# New model for 5 classes
new_model = MLP(input_size=784, hidden_size=256, output_size=5)

# Load only the shared layers
state = pretrained.state_dict()
del state['fc2.weight']  # Remove output layer
del state['fc2.bias']
new_model.load_state_dict(state, strict=False)
# Now fc1 has pretrained weights while fc2 stays randomly initialized

18.10 Summary
- state_dict(): Get all parameters as a dictionary
- load_state_dict(): Load parameters from a dictionary
- Nested modules use dot notation for names: fc1.weight
- Save with np.savez() or pickle
- Checkpoint training for resume capability
- Enable transfer learning with partial loading
Next: Dataset and DataLoader for organized data handling.