CoCoOp给文本编码器带来了可变的prompt,而Maple给视觉编码器也带来了可变的prompt,不觉得这很酷吗,我觉得这真是太酷了,满足了我对CLIP最终形态的想象。
前言
上一篇读过的部分,这次会简单带过,主要看MaPLe引入的部分,另外就是阅读的顺序会有所变动,尽可能让阅读的体验更好一些。
Config
本次使用的config是vit_b16_c2_ep5_batch4_2ctx
class MaPLe(TrainerX):
MaPLe主训练代码的改动
总的来说,基本是没有什么变动的,只有一处有变化:
print("Turning off gradients in both the image and the text encoder")
name_to_update = "prompt_learner"
for name, param in self.model.named_parameters():
if name_to_update not in name:
# Make sure that VPT prompts are updated
if "VPT" in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
该段用于限制参数更新,除了prompt_learner和包含了VPT的参数外,均冻结。
prompt_learner表示提示学习,VPT
表示视觉提示。
class CustomCLIP(nn.Module)
自定义CLIP模块
MaPLe的CustomCLIP类初始化定义部分与CoCoOp是一样的:
def __init__(self, cfg, classnames, clip_model):
super().__init__()
self.prompt_learner = MultiModalPromptLearner(cfg, classnames, clip_model)
self.tokenized_prompts = self.prompt_learner.tokenized_prompts
self.image_encoder = clip_model.visual
self.text_encoder = TextEncoder(clip_model)
self.logit_scale = clip_model.logit_scale
self.dtype = clip_model.dtype
不过里面却大有不同,为了理解起来方便,我们先搁置定义内的结构,而是按照前向传播的顺序进行介绍:
def forward(self, image, label=None):
tokenized_prompts = self.tokenized_prompts
logit_scale = self.logit_scale.exp()
prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
image_features = self.image_encoder(image.type(self.dtype), shared_ctx, deep_compound_prompts_vision)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logits = logit_scale * image_features @ text_features.t()
if self.prompt_learner.training:
return F.cross_entropy(logits, label)
return logits
看过CoCoOp的话,这里应该就比较眼熟了吧?
首先是tokenized_prompts,这里的形状是(50,77),其中50是类的数量,77是序列长度,首先是将序列化后的提示取出,具体是如何进行序列化的我们可以稍后在MultiModalPromptLearner当中看到。
接下来的是将提示,上下文向量和深层文本提示,深层视觉提示取出:
prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
带着输出内容的疑问进入MultiModalPromptLearner类,我们需要分别取出:
token化的提示tokenized_prompts,形状为(50,77)
原本的提示prompts,形状为(50,77,512)
共享的上下文向量shared_ctx,形状为(2,768)
文本分支的深层提示deep_compound_prompts_text,长度为8的列表,每层形状为(2,512)
视觉分支的深层提示deep_compound_prompts_vision,长度为8的列表,每层形状为(2,768)
根据512和768两种隐藏维,我们可以猜测到其属于视觉分支还是文本分支
class MultiModalPromptLearner(nn.Module)
从提示学习先开始看起,解答先前的问题
首先是初始化部分,切片着看
def __init__(self, cfg, classnames, clip_model):
super().__init__()
n_cls = len(classnames)
n_ctx = cfg.TRAINER.MAPLE.N_CTX
ctx_init = cfg.TRAINER.MAPLE.CTX_INIT
dtype = clip_model.dtype
ctx_dim = clip_model.ln_final.weight.shape[0]
clip_imsize = clip_model.visual.input_resolution
cfg_imsize = cfg.INPUT.SIZE[0]
到目前为止,基本都一致,没有什么差异,接下来引入了这次MaPLe新增的部分
assert cfg.TRAINER.MAPLE.PROMPT_DEPTH >= 1, "For MaPLe, PROMPT_DEPTH should be >= 1"
self.compound_prompts_depth = cfg.TRAINER.MAPLE.PROMPT_DEPTH # max=12, but will create 11 such shared prompts
assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"
第一行是检查配置文件当中定义的提示学习深度,设计上来说不允许低于1,因为MaPLe的设计当中,同时插入提示到Transformer的浅层和深层是核心的设计理念。论文根据消融实验,9为最佳值,所以配置文件中设定深度为9层,这里传入的深度就是9了。
if ctx_init and (n_ctx) <= 4:
# use given words to initialize context vectors
ctx_init = ctx_init.replace("_", " ")
n_ctx = n_ctx
prompt = clip.tokenize(ctx_init)
with torch.no_grad():
embedding = clip_model.token_embedding(prompt).type(dtype)
ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
prompt_prefix = ctx_init
else:
# random initialization
ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
print('MaPLe design: Multi-modal Prompt Learning')
print(f'Initial context: "{prompt_prefix}"')
print(f"Number of MaPLe context words (tokens): {n_ctx}")
这一段虽然相较CoCoOp略有修改,但是基本没什么区别,就不多说了。
接下来看这次加入的映射层
# These below, related to the shallow prompts
# Linear layer so that the tokens will project to 512 and will be initialized from 768
self.proj = nn.Linear(ctx_dim, 768)
self.proj.half()
self.ctx = nn.Parameter(ctx_vectors)
之前说过,CLIP默认的文本编码器维度是512维,但是当需要处理视觉分支部分时就会遇到768维的输入,所以如果想同时将上下文向量输入到视觉编码器当中时,会出现维度不一致的问题。那么MaPLe就引入了一个MLP映射层,在需要的时候将其从512维映射到了768维。
不过这里声明的上下文向量仍旧是512维度。
接下来是引入了文本分支的深层提示学习,也就是输出之一deep_compound_prompts_text
# These below parameters related to the shared prompts
# Define the compound prompts for the deeper layers
# Minimum can be 1, which defaults to shallow MaPLe
# compound prompts
self.compound_prompts_text = nn.ParameterList([nn.Parameter(torch.empty(n_ctx, 512))
for _ in range(self.compound_prompts_depth - 1)])
for single_para in self.compound_prompts_text:
nn.init.normal_(single_para, std=0.02)
CoCoOp当中的上下文向量只会应用在输入层,而MaPLe提出了一种将其插入到Transformer多层之间的思想,称为深层提示(Deep Prompting),过往仅用于输入层的则称为浅层提示(Shallow Prompting)。
首先是初始化深层提示的上下文向量,然后根据设定的深度,生成一个具有多组张量的列表,数量需减去浅层,也即第一层输入层(self.ctx)合计8组,此处n_ctx是2,512隐藏维代表这里是文本分支。
关于为什么此处上下文向量的长度只有2,而不是4~16,是因为MaPLe是每层插入2,而非仅在输入层应用提示,所以上下文向量数量是累加的形式,具体可以参考下论文的消融实验,token数量在2和3的时候有比较好的泛化表现:
当然,初始的空值在接下来的迭代中用nn.init.normal_(single_para, std=0.02)方法,对其进行正态分布初始化,标准差为0.02,此时文本分支每层的可学习深层提示就初始化完毕了。
接下来做视觉分支的映射,定义了一个512维度到768维度的MLP,和之前的声明一致,这里不使用之前的self.proj而单独初始化,应该是为了可读性?因为用了_get_clones的方法的话,应该是直接可以用之前定义好的MLP用于克隆的。
# Also make corresponding projection layers, for each prompt
single_layer = nn.Linear(ctx_dim, 768)
self.compound_prompt_projections = _get_clones(single_layer, self.compound_prompts_depth - 1)
一共克隆了8层,对应深层提示的层数
接下来就是将提示进行嵌入化了,将类名,类数量,固定提示词传入,组合在一起,然后token化:
classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
其中prompt_prefix是固定提示词'a photo of a',将类名拼接在后面,最后token化,这个在CoCoOp已经见过相同的处理方式了,最后就会得到之前我们想知道的tokenized_prompts,tokenized_prompts是不进行嵌入化这一步的。
接下来的初始化部分和提示的组成都是一致的了,跳过这个部分,直接到前向传播部分。
def forward(self):
ctx = self.ctx
if ctx.dim() == 2:
ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
prefix = self.token_prefix
suffix = self.token_suffix
prompts = self.construct_prompts(ctx, prefix, suffix)
注意到, 在第五行,MaPLe对ctx对齐图像的处理方式与CoCoOp不同,回忆下,CoCoOp的ctx对齐是:
ctx = self.ctx # (n_ctx, ctx_dim)
bias = self.meta_net(im_features) # (batch, ctx_dim)
bias = bias.unsqueeze(1) # (batch, 1, ctx_dim)
ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim)
ctx_shifted = ctx + bias # (batch, n_ctx, ctx_dim)
# Use instance-conditioned context tokens for all classes
prompts = []
for ctx_shifted_i in ctx_shifted:
ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
pts_i = self.construct_prompts(ctx_i, prefix, suffix) # (n_cls, n_tkn, ctx_dim)
prompts.append(pts_i)
prompts = torch.stack(prompts)
同样是在(n_ctx, ctx_dim),CoCoOp采用了截然不同的方式去处理对齐,通过扩展一个维度,和声明一种基于图像的偏移量,对齐到与batch相等的维度,之后在循环当中对齐类别的数量。总的来说,CoCoOp是利用偏移量动态的对齐每个图像对应的文本提示。
而MaPLe在这个部分则是延续了以往CoOp的方式,为每个类别分配一个静态的上下文向量。
利用.unsqueeze(0).expand(self.n_cls, -1, -1),将原先形状为(n_ctx, ctx_dim)变换为了(n_cls ,n_ctx, ctx_dim),已知n_cls其实就是类别数量,所以这里就是让上下文向量与类别数量对齐,所有类别共享相同的提示。
最后的prompts就是将起始符[SOS],形状为(50,1,512),拼接上浅层提示的上下文学习向量[CTX],形状为(50,2,512)然后是长度为74的序列,形状为(50,74,512),最终就是(50,77,512)了。
后续我们还会见到替换这里的CTX为深层提示,但仍旧会保持这个格式,也就是[SOS]+[CTX]+....+[EOS]
继续向下读:
# Before returning, need to transform
# prompts to 768 for the visual side
visual_deep_prompts = []
for index, layer in enumerate(self.compound_prompt_projections):
visual_deep_prompts.append(layer(self.compound_prompts_text[index]))
# Now the other way around
# We will project the textual prompts from 512 to 768
return prompts, self.proj(self.ctx), self.compound_prompts_text, visual_deep_prompts # pass here original, as for visual 768 is required
这里声明了视觉分支的深层提示,其中的compound_prompt_projections就是之前克隆的映射层了,共计8层的线性层。
本处迭代器访问了两个元素,一个是index,用于访问层数,另一个是layer方法,利用layer方法,我们可以将形状为(2,512)的文本分支深层提示向量self.compound_prompts_text[index]输入到线性层(512,768)当中,经过线性变换得到适用于视觉分支的输出,新的形状为(2,768)。
那么到此,这个部分的代码就全部结束了,还差一个返回
# Now the other way around
# We will project the textual prompts from 512 to 768
return prompts, self.proj(self.ctx), self.compound_prompts_text, visual_deep_prompts # pass here original, as for visual 768 is required
第一个返回值prompts我们已经清楚了,形状是(50,77,512)。
第二个返回值是用了一个映射层,将上下文向量从(2,512)映射到了(2,768),从变量名shared_ctx我们可以猜测这是文本分支的上下文向量,同时共享给了视觉分支,具体是否是这个用途,我们可以先按下不表,到ViT的时候再予以验证,只需要记住这里的原始内容是由self.ctx映射得到的就好了。
第三个返回值是文本分支的深层提示,形状是(8,2,512),或者说长度为8的列表,每层为(2,512)
第四个是视觉分支的深层提示,是用文本分支的深层提示从512维映射到768维得到的,形状是(8,2,768),或者说长度为8的列表,每层为(2,768)
那么提示模块就看完了,输出的来源也得知了,唯一等待关注的就是shared_ctx。
回到CustomCLIP类,继续向下看。
text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
image_features = self.image_encoder(image.type(self.dtype), shared_ctx, deep_compound_prompts_vision)
短短两行将先前的输出都用上了,可以看到,512维的都进入了文本编码器,768维的都进入了视觉编码器,我们先按照顺序来看文本编码器的部分。
文本编码器传入了三个变量,分别是形状为(50,77,512)的prompts提示,token化的提示,文本分支的深层提示。
提示的内容已经很清楚了,50是类,77是SOS/CTX/…+EOS,512维度对齐文本编码器隐藏维。
tokenized_prompts的形状是(50,77),没有进行嵌入化也没有插入上下文向量的原始prompts,内容为预设的提示词和类名的token化。
deep_compound_prompts_text就是深层提示了,形状为(8,2,512)
进入文本编码器,看一下是怎么处理这三个输入并得到文本特征的。
class TextEncoder(nn.Module):
文本编码器部分
def __init__(self, clip_model):
super().__init__()
self.transformer = clip_model.transformer
self.positional_embedding = clip_model.positional_embedding
self.ln_final = clip_model.ln_final
self.text_projection = clip_model.text_projection
self.dtype = clip_model.dtype
声明部分一笔带过,主要是看接下来的前向传播部分,关注下文本编码器是怎么从我们输入的三个变量中进行提取特征的:
def forward(self, prompts, tokenized_prompts, compound_prompts_deeper_text):
x = prompts + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
# Pass as the list, as nn.sequential cannot process multiple arguments in the forward pass
combined = [x, compound_prompts_deeper_text, 0] # third argument is the counter which denotes depth of prompt
outputs = self.transformer(combined)
x = outputs[0] # extract the x back from here
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
基本上与CoCoOp是一样的,首先是x = prompts + self.positional_embedding.type(self.dtype)对提示进行位置编码,接下来是维度变换为Transformer的格式,然后将提示组合。
MaPLe新加入部分的就是这里的提示组合:
combined = [x, compound_prompts_deeper_text, 0] # third argument is the counter which denotes depth of prompt
实际上就是一个列表,第一位是已经变换了形状的提示,第二位是传入的深层提示,第三位是计数器。
接下来直接输入到修改的transformer当中:
outputs = self.transformer(combined)
先看下输出,和输入的格式是相同的:
第一位是对应的提示,形状为(77,50,512)
第二位是对应的深层提示,形状为(8,2,512)
第三位是计数器,这里是8,因为过了八层
最后是经过一个映射层,将全局信息提取出来
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
return x
映射层会提取出长度为77的嵌入序列当中的[EOS],即实际的最后一位包含信息的嵌入,这个嵌入包含了整个序列的全局信息
最后提取得到的这个嵌入形状为(50,512),即50个分类,每个序列长度77,提取每个类当中的[EOS]
回到MultiModalPromptLearner类,我们接着看,紧接着是视觉分支
image_features = self.image_encoder(image.type(self.dtype), shared_ctx, deep_compound_prompts_vision)
回忆一下,传入到视觉编码器的参数一共是三个,第一个是输入图像,第二个是共享的上下文向量,也就是映射到768维度的浅层提示,第三个是从文本深层提示克隆到768维的。
我们进入视觉编码器,看看是如何提取特征的
class VisionTransformer_MaPLe(nn.Module)
视觉编码器部分
首先看初始化部分
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,
design_details):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
self.VPT_shallow = True
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
# hyper-parameter if need to add prompt embeddings inside to the input
# of transformer block or not:
self.prompt_till_layer_visual = 0
self.transformer = Transformer(width, layers, heads, design_details=design_details)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
第六行self.conv1声明了一个卷积层,其卷积核为(16,16),步长为(16,16),输入通道为3。这个卷积层用于将图像嵌入化
第七行,声明VPT_shallow浅层提示为True
第九行self.class_embedding初始化了CLS标识符,按照VIT的官方说法,CLS主要是用于对齐Ttransformer的,本身没有重要含义[1]
由于图像嵌入本身是没有位置信息的,所以第十行引入了位置编码的实现,形状是(197,768),其中序列长度为(\frac{224}{16})^2+1=197,宽度为768,与ViT维度对齐。197就是序列的长度,包括图像patch的长度和CLS标签,一共197。
第十四行定义了插入提示嵌入的设定,默认为0,即不插入。
该行在本处主要是声明作用,在VisionTransformer类中会影响VPT的启用,prompt_till_layer_visual将由字典design_details["vision_depth"]赋值,当为0时,浅层提示self.VPT_shallow = False。
而design_details字典则在第十五行出现,用于传入到Transformer类当中,该参数为:
'trainer' = 'MaPLe'
'vision_depth' = 0
'language_depth' = 0
'vision_ctx' = 0
'language_ctx' = 0
'maple_length' = 2
vision_depth是0,意味此处传入Transformer是关闭了浅层提示的,同时出了maple_length都是0,本处实际是不设固定值,而是选择以counter动态添加的形式进行。
这里的Transformer模块MaPLe有大幅修改,以使深层提示可以插入到编码器的每一层,但是这里我们先向下接着看前向传播,有点长,切一部分:
def forward(self, x: torch.Tensor, shared_ctx, compound_deeper_prompts):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat(
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
这个部分是CLIP对ViT的实现,先前没有读过,所以这里也稍微讲下
首先是经过self.conv1(x)卷积层,将其patch化,形状是(2,768,14,14),其中2是batch,14是patch大小。具体来说,经过卷积层后图像会按照16步长,16卷积核的大小被切成14×14的patch,合计196。
所以接下来的形状转换就变成了(2,768,196)。
再经过一次变换适配ViT的格式后,进行了一次分类标记的拼接,x的形状从(2,196,768)变为了(2,197,768),多了一个分类标记。
接下来x = x + self.positional_embedding.to(x.dtype)给patch广播加上了位置编码。
然后是这次MaPLe引入的部分
# After positional embeddings, we will attach prompts with the model, remember only those
# are trainable parameters here in whole image encoder.
if self.VPT_shallow:
visual_ctx = shared_ctx.expand(x.shape[0], -1, -1).half()
x = torch.cat([x, visual_ctx], dim=1)
else:
assert self.prompt_till_layer_visual == 0
首先做了一个启用浅层提示的判断,当为启用的时候会将浅层提示拼接在x的后面。在此之前首先会用shared_ctx.expand(x.shape[0], -1, -1).half() 做一个尺寸对齐,这里两者都是batch等于2,没有什么变化。
然后使用torch.cat([x, visual_ctx], dim=1)拼接上x和浅层提示(2,2,768),x的形状变为了(2,199,768)
接下来就是与原版ViT一致的处理,先前在CoCoOp有所说明
# Normal code as before
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
# Again combine the inputs, so nn.sequential can work
outputs = self.transformer([x, compound_deeper_prompts, 0]) # third argument is counter
x = outputs[0]
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
这里的outputs来自于Transformer,那么传入的三个参数在文本分支的时候我们有说过,这里也是一样的,区别是隐藏维变成了768
经过Transformer以后x的形状是(199,2,768),与输入前一致。
关于Transformer是如何处理输入的patch和深层提示的,我们可以进入ViT的Transformer部分一探究竟
class Transformer(nn.Module)
引入了修改的残差注意力层
首先看向初始化,值得注意的是,除了MaPLe,这里还有IVLP的实现,不过这就不是我们这次需要关注的部分了。
截取掉不需要的部分,只关注MaPLe的执行代码:
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, prompts_needed=0,
text_layer=False, design_details=None):
super().__init__()
self.width = width
self.layers = layers
# Implements respective encoder blocks for a given design choice
current_trainer = design_details['trainer']
elif current_trainer == 'MaPLe':
self.resblocks = nn.Sequential(
*[ResidualAttentionBlock_MaPLe(width, heads, attn_mask, design_details, text_layer, i)
for i in range(layers)])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
这里传入的参数分别是:
width: 768
heads: 12
attn_mask: None
design_details:如之前
text_layer: False
i: 12
也就是ViT-16/b的规格,输入768维度,共计12个头,共计12层
看下来,这里主要是配置的定义初始化,即设定好宽度和配置文件等信息,而我们的图像patch会在前向传播的时候传入到残差注意力层当中。
带着这些参数和输入,进入残差注意力层看一下,另外深层提示我们还未看到是如何处理的
class ResidualAttentionBlock_MaPLe(nn.Module):
修改的残差注意力块
首先是初始化新增的部分:
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details=None,
# For the first iteration i, we do not need to add the learnable parameters here
# as it will be added in the beginning, for both text and the vision branch
self.text_layer = text_layer
self.attn_mask = attn_mask
# This must be consistent with the config file prompt
self.compound_prompt_nctx = design_details['maple_length']
if i == 0:
self.first_layer = True
else:
self.first_layer = False
首先是新增的三个传入参数:
design_details=None
text_layer=False
i=0
design_details已经知道了,不做解释。
text_layer是层层递进一直传递下来的参数,用来区分视觉模态和文本模态,两者之间作特征提取的时候是不同的,视觉分支下进行特征提取采用卷积层,而文本模态的文本嵌入就不能这么干了,具体在CLIP调用的时候就会体现这个参数的用途,当CLIP用来提取文本嵌入的时候,此处就会从false改为true,不指定情况下则默认为false。
i用来标记目前所在的残差层数的计数器,默认0为第一层,每层构建时都会增加计数器,利用这个计数器就可以分辨当前是位于第一层(0)还是更深层。
接下来,到修改的前向传播部分,这段改动相当多,我们一点点切:
def forward(self, inputs):
# For the first layer, we do not need to add any duplicate, as it is already added
# as the shallow version
x = inputs[0]
compound_prompts_deeper = inputs[1]
counter = inputs[2]
首先是inputs当中的内容,很显然,就是我们从文本/视觉编码器传递下来的输入了,不论是文本还是视觉,都是三项:x,深层提示,计数器
首先这里是将三项取出来,定义好局部变量方便后续调用
接下来是做层数判断,区分目前位于浅层(输入层)还是深层
if not self.first_layer:
if len(compound_prompts_deeper) > 0:
这么做的原因很简单,不在浅层重复添加已经添加过了的浅层提示。
具体在哪里添加的呢?如果已经忘记了的话,以视觉分支为例,返回到之前的VisionTransformer_MaPLe类,其中在前向传播中有一段:
# After positional embeddings, we will attach prompts with the model, remember only those
# are trainable parameters here in whole image encoder.
if self.VPT_shallow:
visual_ctx = shared_ctx.expand(x.shape[0], -1, -1).half()
x = torch.cat([x, visual_ctx], dim=1)
else:
assert self.prompt_till_layer_visual == 0
注意看这行:
visual_ctx = shared_ctx.expand(x.shape[0], -1, -1).half()
此处插入了浅层提示,此时的图像张量x形状是(2,197, 768),也就是有2个batch,将x的batch形状取出,保持原本ctx的形状,复制到batch上,得到有两个相同ctx的visual_ctx,也就是最终的浅层提示。
将浅层提示拼接(torch.cat)到x的尾部,得到最终的x:
x = torch.cat([x, visual_ctx], dim=1)
此时形状是(2,199,768),作为输入,传入到transformer模块当中:
outputs = self.transformer([x, compound_deeper_prompts, 0]) # third argument is counter
同时一并传入的还有深层提示compound_deeper_prompts,此时的深层提示形状为(2,768),序列长度为8。回忆一下,配置文件当中的提示总插入深度为PROMPT_DEPTH: 9。本处长度为8是深层提示的插入数量,由总计9层插入,但减去了第一层输入层,也即浅层提示的部分而得到的,也即总计插入了1层浅层提示和8层深层提示,浅层由x = torch.cat([x, visual_ctx], dim=1)拼接传入,深层由compound_deeper_prompts传入。
深层提示形状(2,768)其中的2是上下文向量长度,768对齐隐藏层维度,如果是文本分支,那么这里应该是512。
继续回到残差注意力块,首先从视觉分支的深层提示处理部分看起:
if not self.text_layer:
# First check if the ith layer needs compound prompts or not
if not (counter > len(compound_prompts_deeper) - 1):
# Remove the outputs produced by learnable tokens of previous layer
prefix = x[0:x.shape[0] - self.compound_prompt_nctx, :, :]
# Create/configure learnable tokens of this layer
visual_context = compound_prompts_deeper[counter] # extract the correct index
visual_context = visual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
# Add the learnable tokens of this layer with the input, by replacing previous
# layer learnable tokens
x = torch.cat([prefix, visual_context], dim=0)
# Once done, update the counter, so that the next time, it does not use same learnable tokens
counter += 1
首先x[0:x.shape[0] - self.compound_prompt_nctx, :, :]是将嵌入和先前的浅层提示分开,x的形状是(199,2,768),所以x.shape[0]就是199了,接下来减掉先前拼接上去的浅层提示,得到拼接前的197,其他不变,所以prefix的形状就是(197,2,768)
接下来是深层提示的声明,这里原版的注释也说明了用途:
# Create/configure learnable tokens of this layer
visual_context = compound_prompts_deeper[counter] # extract the correct index
这行代码的意味就是定义了深层提示,其中compound_prompts_deeper的长度就是8,每层形状是(2,768)也就是每一层的可学习深层提示,每一层的深层提示都不同,前面提到,counter是每一层的计数器,而compound_prompts_deeper实质上就是一个包含了8层深层提示的列表,所以这里会根据层数,选择定义用于插入的深度提示。
接下来是对齐batch_size
visual_context = visual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
然后将深度提示拼接上输入后面
# Add the learnable tokens of this layer with the input, by replacing previous
# layer learnable tokens
x = torch.cat([prefix, visual_context], dim=0)
相当于我们再次做了一次浅层提示的插入行为,不过区别是这里插入的是深层提示,但是操作是一致的,后续还会根据层数再删除掉本次添加的深层提示,后拼接上新的深层提示,原因是每层我们初始化的可学习提示的向量都是不同的。
这样输入x又变回了带有可学习的提示的199维:(199,2,768)
最后触发计数器
# Once done, update the counter, so that the next time, it does not use same learnable tokens
counter += 1
另一条分支是文本分支:
# First check if the ith layer needs compound prompts or not
if not (counter > len(compound_prompts_deeper) - 1):
# Appending the learnable tokens in different way
# x -> [77, NCLS, DIM]
# First remove the learnable tokens from previous layer
prefix = x[:1, :, :]
suffix = x[1 + self.compound_prompt_nctx:, :, :]
# Create/configure learnable tokens of this layer
textual_context = compound_prompts_deeper[counter]
textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
# Add the learnable tokens of this layer with the input, replaced by previous
# layer learnable tokens
x = torch.cat([prefix, textual_context, suffix], dim=0)
# Once done, update the counter, so that the next time, it does not use same learnable tokens
counter += 1
与视觉分支第一个不同的地方就是嵌入的处理方式,因为在插入浅层提示时的方式是不同的,所以处理起来也不同。在视觉分支时,方法是将提示拼接在嵌入之后,而文本则是有序列起始符,CTX和填充的序列之别,浅层提示CTX是拼接在起始符之后与填充内容之前的,如果忘记了的话可以回到MultiModalPromptLearner的文本分支部分再看下。
首先文本嵌入的形状是(77,50,512),其中序列的第一位是起始符,将起始符取出:
prefix = x[:1, :, :]
接下来再取出起始符和先前插入的可学习的浅层提示,当然,如果是第二次的话那替换掉的就是上一层的深层提示了,这与视觉分支处理的逻辑是一致的。
保留原本的序列,此处compound_prompt_nctx为2,故形状为(74,50,512):
suffix = x[1 + self.compound_prompt_nctx:, :, :]
然后相同的方法定义适用于文本分支的深层提示:
textual_context = compound_prompts_deeper[counter]
textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
接下来插入到起始符和原本内容之间,实现深层提示插入:
# Add the learnable tokens of this layer with the input, replaced by previous
# layer learnable tokens
x = torch.cat([prefix, textual_context, suffix], dim=0)
然后是计数器的递增
最后,两个分支都会进行Transformer的残差自注意力实现,唯一改动是返回时变为列表的形式,多了深层提示列表和计数器,用于适配pytorch:
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return [x, compound_prompts_deeper, counter] # return again as a list, so that nn.seq can work
后言
对于数据经过每一层的变化和具体的形状含义,在阅读全部代码的时候或许会有些懵,我打算单独开一篇,将MaPLe的形状详细变化和输入输出的数据重新整理下,因为我认为在有些地方的解释尚有欠缺,尤其是提示的组成部分。
你好大佬! 你的文章非常非常清晰,让我这个小白也能看懂,但是其中关于方法中prefix 是[0, :1, …], suffix是[0, 1 + n_ctx:, :],那中间原有的文本嵌入的1~n_ctx是直接被可学习的文本给替代了嘛?原有的那部分就不会使用了嘛?
嗯嗯,原有的嵌入会被替代掉,不会使用了,每层的CTX是不传递的。具体原因可以参考论文的实验部分,我印象里是有针对这一问题的实验的,传递会导致性能下降