model.backbones package

Submodules

model.backbones.Attention module

class model.backbones.Attention.ATT

Bases: object

Adaptive part-cropping functions.

Also used in another project.

NMS_crop(attention_maps, input_image)
attention_crop(attention_maps, input_image)
attention_crop_drop(attention_maps, input_image)
attention_drop(attention_maps, input_image)
calc_iou(pred, gt)
calc_mask_iou(pred_mask, gt_mask)
feature_crop(attention_maps, input_image)
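The cropping methods above are documented only by signature. The sketch below shows one common way attention-guided cropping can work: threshold an attention map, take the bounding box of the strong region, pad it, and resize the crop back to the input size. The function name attention_crop_sketch and the parameters theta and padding_ratio are hypothetical. Averaging all part maps into one map is an assumption made to keep the sketch deterministic; ATT.attention_crop may select maps differently.

```python
import torch
import torch.nn.functional as F


def attention_crop_sketch(attention_maps, images, theta=0.5, padding_ratio=0.1):
    """Hypothetical attention crop.

    attention_maps: (B, N_part, h, w), assumed non-negative.
    images:         (B, C, H, W).
    Returns crops resized back to (B, C, H, W).
    """
    batch, _, height, width = images.shape
    crops = []
    for i in range(batch):
        # Assumption: average all part maps of this sample into a single map
        att = attention_maps[i].mean(dim=0, keepdim=True).unsqueeze(0)  # (1, 1, h, w)
        att = F.interpolate(att, size=(height, width), mode="bilinear",
                            align_corners=False)[0, 0].clamp(min=0)
        # Keep pixels whose attention is at least theta times the peak value
        mask = att >= theta * att.max()
        coords = torch.nonzero(mask)  # (K, 2) rows of (y, x); never empty after clamp
        y0, x0 = coords.min(dim=0).values.tolist()
        y1, x1 = coords.max(dim=0).values.tolist()
        # Pad the box so the part is not cut too tightly, staying inside the image
        pad_h, pad_w = int(padding_ratio * height), int(padding_ratio * width)
        y0, x0 = max(y0 - pad_h, 0), max(x0 - pad_w, 0)
        y1, x1 = min(y1 + pad_h, height - 1), min(x1 + pad_w, width - 1)
        crop = images[i:i + 1, :, y0:y1 + 1, x0:x1 + 1]
        # Resize the crop back to the original resolution
        crops.append(F.interpolate(crop, size=(height, width), mode="bilinear",
                                   align_corners=False))
    return torch.cat(crops, dim=0)


imgs = torch.randn(2, 3, 64, 64)
att = torch.rand(2, 8, 8, 8)
print(attention_crop_sketch(att, imgs).shape)
```

An attention "drop" variant would typically reuse the same thresholded mask but zero out the strong region instead of cropping to it.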
class model.backbones.Attention.GatedTensorBankbuilter(n_part=512)

Bases: torch.nn.modules.module.Module

Builds a gated tensor bank for learning. The normalization functions can be further tuned for better performance.

n_part: number of input parts. Defaults to 512, which gave the best performance.

alpha: learnable spatial weight that adapts to the feature size; pooled to 28 × 28.

Inputs:

attention: tensor of shape (b, N_part, w, h)

feature: tensor of shape (b, c, w, h)

Returns:

tensorbank: tensor of shape (b, c, N_part)

forward(feature, attention)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
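The description of GatedTensorBankbuilter mentions a learnable spatial weight alpha pooled to 28 × 28 and an output of shape (b, c, N_part), but not the exact gating or normalization. The sketch below is one hypothetical way to combine those pieces: both inputs are adaptively pooled to a fixed 28 × 28 grid so that a single alpha fits any input resolution, a sigmoid of alpha gates every spatial position of the attention maps, and the bank is an attention-weighted sum over space. The class name, the sigmoid gate, and the L2 normalization are assumptions, not the module's actual implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedTensorBankSketch(nn.Module):
    """Hypothetical gated tensor bank: feature (b, c, w, h) + attention (b, N_part, w, h) -> (b, c, N_part)."""

    def __init__(self, n_part=512, pool_size=28):
        super().__init__()
        self.n_part = n_part
        self.pool_size = pool_size
        # Assumed form of alpha: one learnable logit per pooled spatial position
        self.alpha = nn.Parameter(torch.zeros(1, 1, pool_size, pool_size))

    def forward(self, feature, attention):
        assert attention.size(1) == self.n_part, "unexpected number of parts"
        # Pool both inputs to a fixed grid so alpha fits any input resolution
        feature = F.adaptive_avg_pool2d(feature, self.pool_size)
        attention = F.adaptive_avg_pool2d(attention, self.pool_size)
        # Sigmoid gate in (0, 1) reweights every spatial position of every part map
        gated = attention * torch.sigmoid(self.alpha)
        # bank[b, c, n] = sum over (w, h) of feature[b, c, w, h] * gated[b, n, w, h]
        bank = torch.einsum("bcwh,bnwh->bcn", feature, gated)
        bank = bank / (self.pool_size ** 2)
        # One possible normalization: L2 over channels for each part descriptor
        return F.normalize(bank, dim=1)


module = GatedTensorBankSketch(n_part=4)
bank = module(torch.randn(2, 8, 14, 14), torch.rand(2, 4, 14, 14))
print(bank.shape)  # torch.Size([2, 8, 4])
```

Pooling to a fixed grid is what lets alpha have a fixed shape; the trade-off is that small inputs are upsampled and large ones are averaged down before gating.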

class model.backbones.Attention.Tensorbuilter

Bases: torch.nn.modules.module.Module

Naive tensor bank builder; performs slightly worse than GatedTensorBankbuilter.

Inputs:

attention: tensor of shape (b, N_part, w, h)

feature: tensor of shape (b, c, w, h)

Returns:

tensorbank: tensor of shape (b, c, N_part)

forward(feature, attention)
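Given only the input and output shapes above, a naive tensor bank is most naturally an attention-weighted spatial average: each of the N_part attention maps pools the feature map into one c-dimensional descriptor. The sketch below assumes exactly that; the function name naive_tensor_bank and the averaging over w · h are assumptions, since Tensorbuilter's internals are not documented here.

```python
import torch


def naive_tensor_bank(feature, attention):
    """Hypothetical naive builder: (b, c, w, h) and (b, N_part, w, h) -> (b, c, N_part).

    Each part descriptor is the attention-weighted spatial average of the feature map.
    """
    assert feature.shape[0] == attention.shape[0], "batch sizes must match"
    assert feature.shape[2:] == attention.shape[2:], "spatial sizes must match"
    w, h = feature.shape[2:]
    # bank[b, c, n] = (1 / (w*h)) * sum over (i, j) of feature[b, c, i, j] * attention[b, n, i, j]
    return torch.einsum("bcij,bnij->bcn", feature, attention) / (w * h)


bank = naive_tensor_bank(torch.ones(1, 2, 3, 3), torch.ones(1, 4, 3, 3))
print(bank.shape)  # torch.Size([1, 2, 4]); every entry is 1.0
```

The same result can be computed as a batched matrix product of the flattened tensors, (b, c, w·h) @ (b, w·h, N_part), which is what einsum effectively does here.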

model.backbones.resnet module

class model.backbones.resnet.BasicBlock(inplanes, planes, stride=1, downsample=None)

Bases: torch.nn.modules.module.Module

expansion = 1
forward(x)
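BasicBlock's signature (inplanes, planes, stride, downsample) and expansion = 1 match the standard two-layer residual block of torchvision's ResNet, where the block outputs planes * expansion channels. The sketch below is a torchvision-style version under that assumption; the class name BasicBlockSketch is hypothetical and the module's own layer ordering may differ.

```python
import torch
import torch.nn as nn


class BasicBlockSketch(nn.Module):
    expansion = 1  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Two 3x3 convolutions; the first one carries the stride
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional projection so the shortcut matches the new shape
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        # Residual addition followed by the final ReLU
        return self.relu(out + identity)


block = BasicBlockSketch(64, 64)
print(block(torch.randn(2, 64, 16, 16)).shape)  # torch.Size([2, 64, 16, 16])
```

A downsample module is only needed when stride != 1 or inplanes != planes * expansion; otherwise the input can be added directly.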

class model.backbones.resnet.Bottleneck(inplanes, planes, stride=1, downsample=None, use_cross=False)

Bases: torch.nn.modules.module.Module

expansion = 4
forward(x)
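With expansion = 4, a Bottleneck block reduces channels with a 1x1 convolution, processes them with a 3x3 convolution, and expands them back to planes * 4 channels. The sketch below follows the torchvision-style layout (stride on the 3x3 convolution) and is an assumption about this module; in particular it omits the use_cross option, whose behavior is not documented here. The class name BottleneckSketch is hypothetical.

```python
import torch
import torch.nn as nn


class BottleneckSketch(nn.Module):
    expansion = 4  # block outputs planes * 4 channels

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        # Project the shortcut when the shape changes, otherwise pass x through
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + identity)


# First block of layer1 in a ResNet-50: 64 -> 256 channels needs a projection
down = nn.Sequential(nn.Conv2d(64, 256, 1, bias=False), nn.BatchNorm2d(256))
block = BottleneckSketch(64, 64, downsample=down)
print(block(torch.randn(2, 64, 16, 16)).shape)  # torch.Size([2, 256, 16, 16])
```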

class model.backbones.resnet.Bottleneck_old(inplanes, planes, stride=1, downsample=None)

Bases: torch.nn.modules.module.Module

expansion = 4
forward(x)

class model.backbones.resnet.ResNet(last_stride=2, block=<class 'model.backbones.resnet.Bottleneck'>, use_cross=False, layers=[3, 4, 6, 3])

Bases: torch.nn.modules.module.Module

forward(x)

get_layers()
load_param(model_path)
random_init()
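The ResNet defaults (block=Bottleneck, layers=[3, 4, 6, 3]) correspond to a 50-layer network, and last_stride presumably sets the stride of the last stage (layer4). The arithmetic below assumes the standard ResNet layout (stride-2 7x7 stem convolution, stride-2 max-pool, stride 1 in layer1, stride 2 at the start of layer2 and layer3) and shows how last_stride=1 doubles the final feature-map resolution. The helper names are hypothetical.

```python
import math


def resnet_depth(layers, convs_per_block):
    """Weighted-layer count: stem conv + all convs in the residual blocks + final fc."""
    return 1 + convs_per_block * sum(layers) + 1


def feature_map_size(height, width, last_stride=2):
    """Spatial size of the layer4 output, assuming a standard ResNet layout.

    Strides: stem conv 2, max-pool 2, layer1 1, layer2 2, layer3 2, layer4 last_stride.
    With the usual padding, each stride-s stage maps n -> ceil(n / s).
    """
    for s in (2, 2, 1, 2, 2, last_stride):
        height, width = math.ceil(height / s), math.ceil(width / s)
    return height, width


print(resnet_depth([3, 4, 6, 3], convs_per_block=3))  # 50 (Bottleneck has 3 convs)
print(feature_map_size(256, 128, last_stride=2))      # (8, 4)
print(feature_map_size(256, 128, last_stride=1))      # (16, 8)
```

With BasicBlock (2 convs per block) and the same layers, the count is 34, matching ResNet-34.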
model.backbones.resnet.conv3x3(in_planes, out_planes, stride=1)

3x3 convolution with padding
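A 3x3 convolution "with padding" in ResNet code is conventionally padding=1 with no bias, since BatchNorm follows it. The sketch below assumes that convention (it mirrors torchvision's helper of the same purpose); the name conv3x3_sketch is hypothetical. With padding=1 the spatial size is kept at stride 1 and halved, rounding up, at stride 2.

```python
import torch
import torch.nn as nn


def conv3x3_sketch(in_planes, out_planes, stride=1):
    """3x3 convolution; padding=1 preserves size at stride 1, ceil(n / 2) at stride 2."""
    # bias=False because a BatchNorm layer is assumed to follow
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


x = torch.randn(1, 16, 15, 15)
print(conv3x3_sketch(16, 32)(x).shape)            # torch.Size([1, 32, 15, 15])
print(conv3x3_sketch(16, 32, stride=2)(x).shape)  # torch.Size([1, 32, 8, 8])
```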

model.backbones.resnet.resnet50_crosslevel(pretrained=True, **kwargs)

Module contents