14 lines
352 B
Python
14 lines
352 B
Python
import torch.nn as nn
|
|
|
|
class ProjectionHead(nn.Module):
    """Two-layer MLP projection head: Linear -> ReLU -> Linear.

    With default arguments this maps ``proj_dim`` input features to
    ``proj_dim`` output features through a hidden layer of the same width.
    The optional ``in_dim`` and ``hidden_dim`` arguments decouple the input
    and hidden widths, e.g. to project features whose size differs from
    ``proj_dim``.

    Args:
        proj_dim: Number of output features (default 256).
        in_dim: Number of input features; defaults to ``proj_dim``.
        hidden_dim: Width of the hidden layer; defaults to ``proj_dim``.
    """

    def __init__(self, proj_dim=256, *, in_dim=None, hidden_dim=None):
        super().__init__()
        in_dim = proj_dim if in_dim is None else in_dim
        hidden_dim = proj_dim if hidden_dim is None else hidden_dim
        # Keep the same nn.Sequential layout (indices 0, 1, 2) as the original
        # so state_dict keys ("proj.0.*", "proj.2.*") are unchanged.
        self.proj = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, proj_dim),
        )

    def forward(self, x):
        """Apply the projection MLP to ``x``.

        The Linear layers act on the last dimension of ``x``, which must have
        size ``in_dim`` (``proj_dim`` by default); the output's last
        dimension has size ``proj_dim``.
        """
        return self.proj(x)