学习笔记 – MaPLe代码精读

CoCoOp给文本编码器带来了可变的prompt,而Maple给视觉编码器也带来了可变的prompt,不觉得这很酷吗,我觉得这真是太酷了,满足了我对CLIP最终形态的想象。

前言

上一篇读过的部分,这次会简单带过,主要看MaPLe引入的部分,另外就是阅读的顺序会有所变动,尽可能让阅读的体验更好一些。

Config

本次使用的config是vit_b16_c2_ep5_batch4_2ctx

class MaPLe(TrainerX):

MaPLe主训练代码的改动

总的来说,基本是没有什么变动的,只有一处有变化:

        print("Turning off gradients in both the image and the text encoder")
        name_to_update = "prompt_learner"

        for name, param in self.model.named_parameters():
            if name_to_update not in name:
                # Make sure that VPT prompts are updated
                if "VPT" in name:
                    param.requires_grad_(True)
                else:
                    param.requires_grad_(False)

该段用于限制参数更新,除了prompt_learner和包含了VPT的参数外,均冻结。

prompt_learner表示提示学习,VPT表示视觉提示。

class CustomCLIP(nn.Module)

自定义CLIP模块

MaPLe的CustomCLIP类初始化定义部分与CoCoOp是一样的:

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        self.prompt_learner = MultiModalPromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

不过里面却大有不同,为了理解起来方便,我们先搁置定义内的结构,而是按照前向传播的顺序进行介绍:

    def forward(self, image, label=None):
        tokenized_prompts = self.tokenized_prompts
        logit_scale = self.logit_scale.exp()

        prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
        text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
        image_features = self.image_encoder(image.type(self.dtype), shared_ctx, deep_compound_prompts_vision)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logits = logit_scale * image_features @ text_features.t()

        if self.prompt_learner.training:
            return F.cross_entropy(logits, label)

        return logits

看过CoCoOp的话,这里应该就比较眼熟了吧?

首先是tokenized_prompts,这里的形状是(50,77),其中50是类的数量,77是序列长度,首先是将序列化后的提示取出,具体是如何进行序列化的我们可以稍后在MultiModalPromptLearner当中看到。

接下来的是将提示,上下文向量和深层文本提示,深层视觉提示取出:

        prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()

带着输出内容的疑问进入MultiModalPromptLearner类,我们需要分别取出:

token化的提示tokenized_prompts,形状为(50,77)

原本的提示prompts,形状为(50,77,512)

共享的上下文向量shared_ctx,形状为(2,768)

文本分支的深层提示deep_compound_prompts_text,长度为8的列表,每层形状为(2,512)

视觉分支的深层提示deep_compound_prompts_vision,长度为8的列表,每层形状为(2,768)

根据512和768两种隐藏维,我们可以猜测到其属于视觉分支还是文本分支

class MultiModalPromptLearner(nn.Module)

从提示学习先开始看起,解答先前的问题

首先是初始化部分,切片着看

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.MAPLE.N_CTX
        ctx_init = cfg.TRAINER.MAPLE.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]

到目前为止,基本都一致,没有什么差异,接下来引入了这次MaPLe新增的部分

        assert cfg.TRAINER.MAPLE.PROMPT_DEPTH >= 1, "For MaPLe, PROMPT_DEPTH should be >= 1"
        self.compound_prompts_depth = cfg.TRAINER.MAPLE.PROMPT_DEPTH  # max=12, but will create 11 such shared prompts
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

第一行是检查配置文件当中定义的提示学习深度,设计上来说不允许低于1,因为MaPLe的设计当中,同时插入提示到Transformer的浅层和深层是核心的设计理念。论文根据消融实验,9为最佳值,所以配置文件中设定深度为9层,这里传入的深度就是9了。

        if ctx_init and (n_ctx) <= 4:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = n_ctx
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)
        print('MaPLe design: Multi-modal Prompt Learning')
        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of MaPLe context words (tokens): {n_ctx}")

这一段虽然相较CoCoOp略有修改,但是基本没什么区别,就不多说了。

接下来看这次加入的映射层

        # These below, related to the shallow prompts
        # Linear layer so that the tokens will project to 512 and will be initialized from 768
        self.proj = nn.Linear(ctx_dim, 768)
        self.proj.half()
        self.ctx = nn.Parameter(ctx_vectors)

之前说过,CLIP默认的文本编码器维度是512维,但是当需要处理视觉分支部分时就会遇到768维的输入,所以如果想同时将上下文向量输入到视觉编码器当中时,会出现维度不一致的问题。那么MaPLe就引入了一个MLP映射层,在需要的时候将其从512维映射到了768维。

不过这里声明的上下文向量仍旧是512维度。

接下来是引入了文本分支的深层提示学习,也就是输出之一deep_compound_prompts_text

        # These below parameters related to the shared prompts
        # Define the compound prompts for the deeper layers

        # Minimum can be 1, which defaults to shallow MaPLe
        # compound prompts
        self.compound_prompts_text = nn.ParameterList([nn.Parameter(torch.empty(n_ctx, 512))
                                                      for _ in range(self.compound_prompts_depth - 1)])
        for single_para in self.compound_prompts_text:
            nn.init.normal_(single_para, std=0.02)

CoCoOp当中的上下文向量只会应用在输入层,而MaPLe提出了一种将其插入到Transformer多层之间的思想,称为深层提示(Deep Prompting),过往仅用于输入层的则称为浅层提示(Shallow Prompting)。

首先是初始化深层提示的上下文向量,然后根据设定的深度,生成一个具有多组张量的列表,数量需减去浅层,也即第一层输入层(self.ctx)合计8组,此处n_ctx是2,512隐藏维代表这里是文本分支。

关于为什么此处上下文向量的长度只有2,而不是4~16,是因为MaPLe是每层插入2,而非仅在输入层应用提示,所以上下文向量数量是累加的形式,具体可以参考下论文的消融实验,token数量在2和3的时候有比较好的泛化表现:

源论文 Figure 4 提示学习超参数消融实验

当然,初始的空值在接下来的迭代中用nn.init.normal_(single_para, std=0.02)方法,对其进行正态分布初始化,标准差为0.02,此时文本分支每层的可学习深层提示就初始化完毕了。

接下来做视觉分支的映射,定义了一个512维度到768维度的MLP,和之前的声明一致,这里不使用之前的self.proj而单独初始化,应该是为了可读性?因为用了_get_clones的方法的话,应该是直接可以用之前定义好的MLP用于克隆的。

        # Also make corresponding projection layers, for each prompt
        single_layer = nn.Linear(ctx_dim, 768)
        self.compound_prompt_projections = _get_clones(single_layer, self.compound_prompts_depth - 1)

一共克隆了8层,对应深层提示的层数

接下来就是将提示进行嵌入化了,将类名,类数量,固定提示词传入,组合在一起,然后token化:

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])  # (n_cls, n_tkn)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

其中prompt_prefix是固定提示词'a photo of a',将类名拼接在后面,最后token化,这个在CoCoOp已经见过相同的处理方式了,最后就会得到之前我们想知道的tokenized_promptstokenized_prompts是不进行嵌入化这一步的。

接下来的初始化部分和提示的组成都是一致的了,跳过这个部分,直接到前向传播部分。

    def forward(self):
        ctx = self.ctx

        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix
        prompts = self.construct_prompts(ctx, prefix, suffix)

注意到, 在第五行,MaPLe对ctx对齐图像的处理方式与CoCoOp不同,回忆下,CoCoOp的ctx对齐是:

        ctx = self.ctx                     # (n_ctx, ctx_dim)
        bias = self.meta_net(im_features)  # (batch, ctx_dim)
        bias = bias.unsqueeze(1)           # (batch, 1, ctx_dim)
        ctx = ctx.unsqueeze(0)             # (1, n_ctx, ctx_dim)
        ctx_shifted = ctx + bias           # (batch, n_ctx, ctx_dim)
        
        # Use instance-conditioned context tokens for all classes
        prompts = []
        for ctx_shifted_i in ctx_shifted:
            ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
            pts_i = self.construct_prompts(ctx_i, prefix, suffix)  # (n_cls, n_tkn, ctx_dim)
            prompts.append(pts_i)
        prompts = torch.stack(prompts)

同样是在(n_ctx, ctx_dim),CoCoOp采用了截然不同的方式去处理对齐,通过扩展一个维度,和声明一种基于图像的偏移量,对齐到与batch相等的维度,之后在循环当中对齐类别的数量。总的来说,CoCoOp是利用偏移量动态的对齐每个图像对应的文本提示。

而MaPLe在这个部分则是延续了以往CoOp的方式,为每个类别分配一个静态的上下文向量。

利用.unsqueeze(0).expand(self.n_cls, -1, -1),将原先形状为(n_ctx, ctx_dim)变换为了(n_cls ,n_ctx, ctx_dim),已知n_cls其实就是类别数量,所以这里就是让上下文向量与类别数量对齐,所有类别共享相同的提示。

最后的prompts就是将起始符[SOS],形状为(50,1,512),拼接上浅层提示的上下文学习向量[CTX],形状为(50,2,512)然后是长度为74的序列,形状为(50,74,512),最终就是(50,77,512)了。

后续我们还会见到替换这里的CTX为深层提示,但仍旧会保持这个格式,也就是[SOS]+[CTX]+....+[EOS]

继续向下读:

        # Before returning, need to transform
        # prompts to 768 for the visual side
        visual_deep_prompts = []
        for index, layer in enumerate(self.compound_prompt_projections):
            visual_deep_prompts.append(layer(self.compound_prompts_text[index]))
        # Now the other way around
        # We will project the textual prompts from 512 to 768
        return prompts, self.proj(self.ctx), self.compound_prompts_text, visual_deep_prompts   # pass here original, as for visual 768 is required

这里声明了视觉分支的深层提示,其中的compound_prompt_projections就是之前克隆的映射层了,共计8层的线性层。

本处迭代器访问了两个元素,一个是index,用于访问层数,另一个是layer方法,利用layer方法,我们可以将形状为(2,512)的文本分支深层提示向量self.compound_prompts_text[index]输入到线性层(512,768)当中,经过线性变换得到适用于视觉分支的输出,新的形状为(2,768)

那么到此,这个部分的代码就全部结束了,还差一个返回

        # Now the other way around
        # We will project the textual prompts from 512 to 768
        return prompts, self.proj(self.ctx), self.compound_prompts_text, visual_deep_prompts   # pass here original, as for visual 768 is required

第一个返回值prompts我们已经清楚了,形状是(50,77,512)

第二个返回值是用了一个映射层,将上下文向量从(2,512)映射到了(2,768),从变量名shared_ctx我们可以猜测这是文本分支的上下文向量,同时共享给了视觉分支,具体是否是这个用途,我们可以先按下不表,到ViT的时候再予以验证,只需要记住这里的原始内容是由self.ctx映射得到的就好了。

第三个返回值是文本分支的深层提示,形状是(8,2,512),或者说长度为8的列表,每层为(2,512)

第四个是视觉分支的深层提示,是用文本分支的深层提示从512维映射到768维得到的,形状是(8,2,768),或者说长度为8的列表,每层为(2,768)

那么提示模块就看完了,输出的来源也得知了,唯一等待关注的就是shared_ctx

回到CustomCLIP类,继续向下看。

        text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
        image_features = self.image_encoder(image.type(self.dtype), shared_ctx, deep_compound_prompts_vision)

短短两行将先前的输出都用上了,可以看到,512维的都进入了文本编码器,768维的都进入了视觉编码器,我们先按照顺序来看文本编码器的部分。

文本编码器传入了三个变量,分别是形状为(50,77,512)的prompts提示,token化的提示,文本分支的深层提示。

提示的内容已经很清楚了,50是类,77是SOS/CTX/…+EOS,512维度对齐文本编码器隐藏维。

tokenized_prompts的形状是(50,77),没有进行嵌入化也没有插入上下文向量的原始prompts,内容为预设的提示词和类名的token化。

deep_compound_prompts_text就是深层提示了,形状为(8,2,512)

进入文本编码器,看一下是怎么处理这三个输入并得到文本特征的。


class TextEncoder(nn.Module):

文本编码器部分

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

声明部分一笔带过,主要是看接下来的前向传播部分,关注下文本编码器是怎么从我们输入的三个变量中进行提取特征的:

    def forward(self, prompts, tokenized_prompts, compound_prompts_deeper_text):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        # Pass as the list, as nn.sequential cannot process multiple arguments in the forward pass
        combined = [x, compound_prompts_deeper_text, 0]  # third argument is the counter which denotes depth of prompt
        outputs = self.transformer(combined)
        x = outputs[0]  # extract the x back from here
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

基本上与CoCoOp是一样的,首先是x = prompts + self.positional_embedding.type(self.dtype)对提示进行位置编码,接下来是维度变换为Transformer的格式,然后将提示组合。

MaPLe新加入部分的就是这里的提示组合:

        combined = [x, compound_prompts_deeper_text, 0]  # third argument is the counter which denotes depth of prompt

实际上就是一个列表,第一位是已经变换了形状的提示,第二位是传入的深层提示,第三位是计数器。

接下来直接输入到修改的transformer当中:

        outputs = self.transformer(combined)

先看下输出,和输入的格式是相同的:

第一位是对应的提示,形状为(77,50,512)

第二位是对应的深层提示,形状为(8,2,512)

第三位是计数器,这里是8,因为过了八层

最后是经过一个映射层,将全局信息提取出来

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x

映射层会提取出长度为77的嵌入序列当中的[EOS],即实际的最后一位包含信息的嵌入,这个嵌入包含了整个序列的全局信息

最后提取得到的这个嵌入形状为(50,512),即50个分类,每个序列长度77,提取每个类当中的[EOS]


回到MultiModalPromptLearner类,我们接着看,紧接着是视觉分支

        image_features = self.image_encoder(image.type(self.dtype), shared_ctx, deep_compound_prompts_vision)

回忆一下,传入到视觉编码器的参数一共是三个,第一个是输入图像,第二个是共享的上下文向量,也就是映射到768维度的浅层提示,第三个是从文本深层提示克隆到768维的。

我们进入视觉编码器,看看是如何提取特征的


class VisionTransformer_MaPLe(nn.Module)

视觉编码器部分

首先看初始化部分

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,
                 design_details):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
        self.VPT_shallow = True
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)
        # hyper-parameter if need to add prompt embeddings inside to the input
        # of transformer block or not:
        self.prompt_till_layer_visual = 0
        self.transformer = Transformer(width, layers, heads, design_details=design_details)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

第六行self.conv1声明了一个卷积层,其卷积核为(16,16),步长为(16,16),输入通道为3。这个卷积层用于将图像嵌入化

源论文 Fig. 1 ViT Overview

第七行,声明VPT_shallow浅层提示为True

第九行self.class_embedding初始化了CLS标识符,按照VIT的官方说法,CLS主要是用于对齐Ttransformer的,本身没有重要含义[1]

由于图像嵌入本身是没有位置信息的,所以第十行引入了位置编码的实现,形状是(197,768),其中序列长度为(\frac{224}{16})^2+1=197,宽度为768,与ViT维度对齐。197就是序列的长度,包括图像patch的长度和CLS标签,一共197

第十四行定义了插入提示嵌入的设定,默认为0,即不插入。

该行在本处主要是声明作用,在VisionTransformer类中会影响VPT的启用,prompt_till_layer_visual将由字典design_details["vision_depth"]赋值,当为0时,浅层提示self.VPT_shallow = False

design_details字典则在第十五行出现,用于传入到Transformer类当中,该参数为:

'trainer' = 'MaPLe'
'vision_depth' = 0
'language_depth' = 0
'vision_ctx' = 0
'language_ctx' = 0
'maple_length' = 2

vision_depth是0,意味此处传入Transformer是关闭了浅层提示的,同时出了maple_length都是0,本处实际是不设固定值,而是选择以counter动态添加的形式进行。

这里的Transformer模块MaPLe有大幅修改,以使深层提示可以插入到编码器的每一层,但是这里我们先向下接着看前向传播,有点长,切一部分:

    def forward(self, x: torch.Tensor, shared_ctx, compound_deeper_prompts):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

这个部分是CLIP对ViT的实现,先前没有读过,所以这里也稍微讲下

首先是经过self.conv1(x)卷积层,将其patch化,形状是(2,768,14,14),其中2是batch,14是patch大小。具体来说,经过卷积层后图像会按照16步长,16卷积核的大小被切成14×14的patch,合计196。

所以接下来的形状转换就变成了(2,768,196)

再经过一次变换适配ViT的格式后,进行了一次分类标记的拼接,x的形状从(2,196,768)变为了(2,197,768),多了一个分类标记。

接下来x = x + self.positional_embedding.to(x.dtype)给patch广播加上了位置编码。

然后是这次MaPLe引入的部分

        # After positional embeddings, we will attach prompts with the model, remember only those
        # are trainable parameters here in whole image encoder.
        if self.VPT_shallow:
            visual_ctx = shared_ctx.expand(x.shape[0], -1, -1).half()   
            x = torch.cat([x, visual_ctx], dim=1)
        else:
            assert self.prompt_till_layer_visual == 0

首先做了一个启用浅层提示的判断,当为启用的时候会将浅层提示拼接在x的后面。在此之前首先会用shared_ctx.expand(x.shape[0], -1, -1).half() 做一个尺寸对齐,这里两者都是batch等于2,没有什么变化。

然后使用torch.cat([x, visual_ctx], dim=1)拼接上x和浅层提示(2,2,768),x的形状变为了(2,199,768)

接下来就是与原版ViT一致的处理,先前在CoCoOp有所说明

        # Normal code as before
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        # Again combine the inputs, so nn.sequential can work
        outputs = self.transformer([x, compound_deeper_prompts, 0])  # third argument is counter
        x = outputs[0]
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x

这里的outputs来自于Transformer,那么传入的三个参数在文本分支的时候我们有说过,这里也是一样的,区别是隐藏维变成了768

经过Transformer以后x的形状是(199,2,768),与输入前一致。

关于Transformer是如何处理输入的patch和深层提示的,我们可以进入ViT的Transformer部分一探究竟


class Transformer(nn.Module)

引入了修改的残差注意力层

首先看向初始化,值得注意的是,除了MaPLe,这里还有IVLP的实现,不过这就不是我们这次需要关注的部分了。

截取掉不需要的部分,只关注MaPLe的执行代码:

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, prompts_needed=0,
                 text_layer=False, design_details=None):
        super().__init__()
        self.width = width
        self.layers = layers
        # Implements respective encoder blocks for a given design choice
        current_trainer = design_details['trainer']
        elif current_trainer == 'MaPLe':
            self.resblocks = nn.Sequential(
                *[ResidualAttentionBlock_MaPLe(width, heads, attn_mask, design_details, text_layer, i)
                  for i in range(layers)])
    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

这里传入的参数分别是:

width: 768
heads: 12
attn_mask: None
design_details:如之前
text_layer: False
i: 12

也就是ViT-16/b的规格,输入768维度,共计12个头,共计12

看下来,这里主要是配置的定义初始化,即设定好宽度和配置文件等信息,而我们的图像patch会在前向传播的时候传入到残差注意力层当中。

带着这些参数和输入,进入残差注意力层看一下,另外深层提示我们还未看到是如何处理的


class ResidualAttentionBlock_MaPLe(nn.Module):

修改的残差注意力块

首先是初始化新增的部分:

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details=None,
        # For the first iteration i, we do not need to add the learnable parameters here
        # as it will be added in the beginning, for both text and the vision branch
        self.text_layer = text_layer
        self.attn_mask = attn_mask
        # This must be consistent with the config file prompt
        self.compound_prompt_nctx = design_details['maple_length']
        if i == 0:
            self.first_layer = True
        else:
            self.first_layer = False

首先是新增的三个传入参数:

design_details=None
text_layer=False
i=0

design_details已经知道了,不做解释。

text_layer是层层递进一直传递下来的参数,用来区分视觉模态和文本模态,两者之间作特征提取的时候是不同的,视觉分支下进行特征提取采用卷积层,而文本模态的文本嵌入就不能这么干了,具体在CLIP调用的时候就会体现这个参数的用途,当CLIP用来提取文本嵌入的时候,此处就会从false改为true,不指定情况下则默认为false

i用来标记目前所在的残差层数的计数器,默认0为第一层,每层构建时都会增加计数器,利用这个计数器就可以分辨当前是位于第一层(0)还是更深层。

接下来,到修改的前向传播部分,这段改动相当多,我们一点点切:

    def forward(self, inputs):
        # For the first layer, we do not need to add any duplicate, as it is already added
        # as the shallow version
        x = inputs[0]
        compound_prompts_deeper = inputs[1]
        counter = inputs[2]

首先是inputs当中的内容,很显然,就是我们从文本/视觉编码器传递下来的输入了,不论是文本还是视觉,都是三项:x,深层提示,计数器

首先这里是将三项取出来,定义好局部变量方便后续调用

接下来是做层数判断,区分目前位于浅层(输入层)还是深层

if not self.first_layer:
            if len(compound_prompts_deeper) > 0:

这么做的原因很简单,不在浅层重复添加已经添加过了的浅层提示。

具体在哪里添加的呢?如果已经忘记了的话,以视觉分支为例,返回到之前的VisionTransformer_MaPLe类,其中在前向传播中有一段:

        # After positional embeddings, we will attach prompts with the model, remember only those
        # are trainable parameters here in whole image encoder.
        if self.VPT_shallow:
            visual_ctx = shared_ctx.expand(x.shape[0], -1, -1).half()
            x = torch.cat([x, visual_ctx], dim=1)
        else:
            assert self.prompt_till_layer_visual == 0

注意看这行:

visual_ctx = shared_ctx.expand(x.shape[0], -1, -1).half()

此处插入了浅层提示,此时的图像张量x形状是(2,197, 768),也就是有2个batch,将xbatch形状取出,保持原本ctx的形状,复制到batch上,得到有两个相同ctxvisual_ctx,也就是最终的浅层提示。

将浅层提示拼接(torch.cat)到x的尾部,得到最终的x

x = torch.cat([x, visual_ctx], dim=1)

此时形状是(2,199,768),作为输入,传入到transformer模块当中:

        outputs = self.transformer([x, compound_deeper_prompts, 0])  # third argument is counter

同时一并传入的还有深层提示compound_deeper_prompts,此时的深层提示形状为(2,768),序列长度为8。回忆一下,配置文件当中的提示总插入深度为PROMPT_DEPTH: 9。本处长度为8是深层提示的插入数量,由总计9层插入,但减去了第一层输入层,也即浅层提示的部分而得到的,也即总计插入了1层浅层提示和8层深层提示,浅层由x = torch.cat([x, visual_ctx], dim=1)拼接传入,深层由compound_deeper_prompts传入。

深层提示形状(2,768)其中的2是上下文向量长度,768对齐隐藏层维度,如果是文本分支,那么这里应该是512

继续回到残差注意力块,首先从视觉分支的深层提示处理部分看起:

                if not self.text_layer:
                    # First check if the ith layer needs compound prompts or not
                    if not (counter > len(compound_prompts_deeper) - 1):
                        # Remove the outputs produced by learnable tokens of previous layer
                        prefix = x[0:x.shape[0] - self.compound_prompt_nctx, :, :]
                        # Create/configure learnable tokens of this layer
                        visual_context = compound_prompts_deeper[counter]  # extract the correct index
                        visual_context = visual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
                        # Add the learnable tokens of this layer with the input, by replacing previous
                        # layer learnable tokens
                        x = torch.cat([prefix, visual_context], dim=0)

                        # Once done, update the counter, so that the next time, it does not use same learnable tokens
                        counter += 1

首先x[0:x.shape[0] - self.compound_prompt_nctx, :, :]是将嵌入和先前的浅层提示分开,x的形状是(199,2,768),所以x.shape[0]就是199了,接下来减掉先前拼接上去的浅层提示,得到拼接前的197,其他不变,所以prefix的形状就是(197,2,768)

接下来是深层提示的声明,这里原版的注释也说明了用途:

# Create/configure learnable tokens of this layer
visual_context = compound_prompts_deeper[counter]  # extract the correct index

这行代码的意味就是定义了深层提示,其中compound_prompts_deeper的长度就是8,每层形状是(2,768)也就是每一层的可学习深层提示,每一层的深层提示都不同,前面提到,counter是每一层的计数器,而compound_prompts_deeper实质上就是一个包含了8层深层提示的列表,所以这里会根据层数,选择定义用于插入的深度提示。

接下来是对齐batch_size

visual_context = visual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()

然后将深度提示拼接上输入后面

# Add the learnable tokens of this layer with the input, by replacing previous
# layer learnable tokens
x = torch.cat([prefix, visual_context], dim=0)

相当于我们再次做了一次浅层提示的插入行为,不过区别是这里插入的是深层提示,但是操作是一致的,后续还会根据层数再删除掉本次添加的深层提示,后拼接上新的深层提示,原因是每层我们初始化的可学习提示的向量都是不同的。

这样输入x又变回了带有可学习的提示的199维:(199,2,768)

最后触发计数器

# Once done, update the counter, so that the next time, it does not use same learnable tokens
counter += 1

另一条分支是文本分支:

                    # First check if the ith layer needs compound prompts or not
                    if not (counter > len(compound_prompts_deeper) - 1):
                        # Appending the learnable tokens in different way
                        # x -> [77, NCLS, DIM]
                        # First remove the learnable tokens from previous layer
                        prefix = x[:1, :, :]
                        suffix = x[1 + self.compound_prompt_nctx:, :, :]
                        # Create/configure learnable tokens of this layer
                        textual_context = compound_prompts_deeper[counter]
                        textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
                        # Add the learnable tokens of this layer with the input, replaced by previous
                        # layer learnable tokens
                        x = torch.cat([prefix, textual_context, suffix], dim=0)
                        # Once done, update the counter, so that the next time, it does not use same learnable tokens
                        counter += 1

与视觉分支第一个不同的地方就是嵌入的处理方式,因为在插入浅层提示时的方式是不同的,所以处理起来也不同。在视觉分支时,方法是将提示拼接在嵌入之后,而文本则是有序列起始符,CTX和填充的序列之别,浅层提示CTX是拼接在起始符之后与填充内容之前的,如果忘记了的话可以回到MultiModalPromptLearner的文本分支部分再看下。

首先文本嵌入的形状是(77,50,512),其中序列的第一位是起始符,将起始符取出:

prefix = x[:1, :, :]

接下来再取出起始符和先前插入的可学习的浅层提示,当然,如果是第二次的话那替换掉的就是上一层的深层提示了,这与视觉分支处理的逻辑是一致的。

保留原本的序列,此处compound_prompt_nctx为2,故形状为(74,50,512):

suffix = x[1 + self.compound_prompt_nctx:, :, :]

然后相同的方法定义适用于文本分支的深层提示:

textual_context = compound_prompts_deeper[counter]
textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()

接下来插入到起始符和原本内容之间,实现深层提示插入:

# Add the learnable tokens of this layer with the input, replaced by previous
# layer learnable tokens
x = torch.cat([prefix, textual_context, suffix], dim=0)

然后是计数器的递增

最后,两个分支都会进行Transformer的残差自注意力实现,唯一改动是返回时变为列表的形式,多了深层提示列表和计数器,用于适配pytorch:

        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return [x, compound_prompts_deeper, counter]  # return again as a list, so that nn.seq can work

后言

对于数据经过每一层的变化和具体的形状含义,在阅读全部代码的时候或许会有些懵,我打算单独开一篇,将MaPLe的形状详细变化和输入输出的数据重新整理下,因为我认为在有些地方的解释尚有欠缺,尤其是提示的组成部分。

Ref

  1. https://github.com/google-research/vision_transformer/issues/61#issuecomment-802233921
  2. https://www.pinecone.io/learn/series/image-search/vision-transformers/

评论

  1. 头像
    magicrane
    1 天前
    2025-3-12 11:46:52

    你好大佬! 你的文章非常非常清晰,让我这个小白也能看懂,但是其中关于方法中prefix 是[0, :1, …], suffix是[0, 1 + n_ctx:, :],那中间原有的文本嵌入的1~n_ctx是直接被可学习的文本给替代了嘛?原有的那部分就不会使用了嘛?

    • xiaohuo
      博主
      magicrane
      16 小时前
      2025-3-13 6:33:01

      嗯嗯,原有的嵌入会被替代掉,不会使用了,每层的CTX是不传递的。具体原因可以参考论文的实验部分,我印象里是有针对这一问题的实验的,传递会导致性能下降

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇