In [ ]:
import torch
from torch import nn
import torch.nn.functional as F
from functools import partial
def batch_norm(x):
    # Normalize each feature over the batch dimension (dim 0).
    mean = x.mean(0, keepdim=True)
    var = x.var(0, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm
def layer_norm(x):
    # Normalize each sample over the feature dimension (dim 1).
    mean = x.mean(1, keepdim=True)
    var = x.var(1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm
def group_norm(x, num_groups):
    # Split the features into num_groups groups and normalize within each
    # group, per sample. C must be divisible by num_groups.
    N, C = x.shape
    x = x.view(N, num_groups, -1)
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    x_norm = x_norm.view(N, C)
    return x_norm
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, norm_func):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.norm_func = norm_func
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Linear -> normalization -> ReLU -> linear.
        x = self.linear1(x)
        x = self.norm_func(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x
# Create a random tensor with size (batch_size, input_dim)
x = torch.randn(32, 100)
# Create the MLP models with batch norm, layer norm, and group norm
model_bn = MLP(100, 64, 10, batch_norm)
model_ln = MLP(100, 64, 10, layer_norm)
model_gn = MLP(100, 64, 10, partial(group_norm, num_groups=4))
# Pass the input tensor through the models
output_bn = model_bn(x)
output_ln = model_ln(x)
output_gn = model_gn(x)
# Print the output shapes
print("Output shape with batch norm:", output_bn.shape)
print("Output shape with layer norm:", output_ln.shape)
print("Output shape with group norm:", output_gn.shape)
In [ ]:
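# A minimal sketch (not from the original) of the same MLP built with PyTorch's
# built-in normalization modules instead of the handwritten functions.
# nn.Module instances are callable, so they slot in directly as norm_func;
# unlike the functions above, they also carry learnable affine weight/bias
# parameters (and, for BatchNorm1d, running statistics used at eval time).
model_bn_builtin = MLP(100, 64, 10, nn.BatchNorm1d(64))
model_ln_builtin = MLP(100, 64, 10, nn.LayerNorm(64))
model_gn_builtin = MLP(100, 64, 10, nn.GroupNorm(4, 64))
print(model_bn_builtin(x).shape)  # expected: torch.Size([32, 10])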