Commit c32c427: add fsa

1 parent 773b7ca commit c32c427

29 files changed, +4877 −6 lines changed

README.md

Lines changed: 54 additions & 2 deletions

Removed:

# InhibitoryAttention

Brain-inspired inhibitory neurons and Attention

Added:

# Win-Take-All Self-Attention

This folder contains the implementation of Win-Take-All Self-Attention, based on the DeiT and PVT models, for image classification.
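As an illustration of the winner-take-all idea behind the name (this is a minimal single-head sketch, not the repo's actual FSA implementation; the function name and the `top_k` parameter are hypothetical): for each query, only the `top_k` largest attention scores survive the softmax, and all other keys receive zero weight.

```python
import numpy as np

def wta_attention(q, k, v, top_k=4):
    """Single-head attention where, per query, only the top_k largest
    scores survive the softmax; the rest are masked to -inf and so
    receive exactly zero attention weight."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                      # (n_q, n_k)
    # threshold = the top_k-th largest score in each row
    thresh = np.sort(scores, axis=-1)[:, -top_k][:, None]
    masked = np.where(scores >= thresh, scores, -np.inf)
    w = np.exp(masked - masked.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)              # rows sum to 1
    return w @ v

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16))
k = rng.standard_normal((8, 16))
v = rng.standard_normal((8, 16))
out = wta_attention(q, k, v, top_k=4)
print(out.shape)  # (8, 16)
```

Masking to `-inf` before the softmax (rather than zeroing weights afterwards) keeps the surviving weights properly normalized.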
4+
5+
## Dependencies

- Python 3.9
- PyTorch == 1.11.0
- torchvision == 0.12.0
- numpy
- timm == 0.4.12
- einops
- yacs

```shell
pipx install uv
uv sync
```
19+
20+
## Data preparation

The ImageNet dataset should be prepared as follows:

```
$ tree imagenet
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
```
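A quick way to sanity-check this layout before launching a long training run (a hypothetical helper, not part of the repo): verify that both splits exist and contain the same number of class sub-directories.

```python
import tempfile
from pathlib import Path

def check_imagenet_layout(root):
    """Sanity-check the train/val folder layout above: both splits must
    exist and contain one sub-directory per class."""
    root = Path(root)
    counts = {}
    for split in ("train", "val"):
        d = root / split
        if not d.is_dir():
            raise FileNotFoundError(f"missing split directory: {d}")
        counts[split] = sum(1 for c in d.iterdir() if c.is_dir())
    if counts["train"] != counts["val"]:
        raise ValueError(f"class-count mismatch: {counts}")
    return counts

# demo on a throwaway two-class layout
base = Path(tempfile.mkdtemp())
for split in ("train", "val"):
    for cls in ("class1", "class2"):
        (base / split / cls).mkdir(parents=True)
print(check_imagenet_layout(base))  # {'train': 2, 'val': 2}
```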
46+
47+
## Train Models from Scratch

- To train `FSA-DeiT`/`FSA-PVT` on ImageNet from scratch, run:

```shell
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path>
```

cfgs/fsa_deit_b.yaml

Lines changed: 4 additions & 0 deletions

```yaml
MODEL:
  TYPE: fsa_deit_base
  NAME: fsa_deit_base
  DROP_PATH_RATE: 0.3
```

cfgs/fsa_deit_s.yaml

Lines changed: 4 additions & 0 deletions

```yaml
MODEL:
  TYPE: fsa_deit_small
  NAME: fsa_deit_small
  DROP_PATH_RATE: 0.1
```

cfgs/fsa_deit_t.yaml

Lines changed: 6 additions & 0 deletions

```yaml
DATA:
  BATCH_SIZE: 512
MODEL:
  TYPE: fsa_deit_tiny
  NAME: fsa_deit_tiny
  DROP_PATH_RATE: 0.0
```
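The three DeiT configs differ mainly in `DROP_PATH_RATE` (0.3 / 0.1 / 0.0). In timm-style vision transformers this value is the maximum stochastic-depth rate, spread linearly across the blocks; a sketch of that convention, assuming the standard DeiT depth of 12 blocks (how this repo consumes the field is an assumption):

```python
def drop_path_schedule(drop_path_rate, depth):
    """Per-block stochastic-depth rates, increasing linearly from 0 at
    the first block to drop_path_rate at the last (timm's convention,
    equivalent to torch.linspace(0, drop_path_rate, depth))."""
    if depth == 1:
        return [drop_path_rate]
    return [drop_path_rate * i / (depth - 1) for i in range(depth)]

# fsa_deit_base-style schedule: depth 12, DROP_PATH_RATE 0.3
rates = drop_path_schedule(0.3, 12)
print([round(r, 3) for r in rates])
```

Earlier blocks are dropped less often than later ones, which is why a single scalar in the config is enough to describe the whole schedule.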

cfgs/fsa_pvt_b.yaml

Lines changed: 20 additions & 0 deletions

```yaml
DATA:
  IMG_SIZE: 224
  BATCH_SIZE: 64

TRAIN:
  WEIGHT_DECAY: 0.05
  EPOCHS: 300
  WARMUP_EPOCHS: 5
  COOLDOWN_EPOCHS: 10
  BASE_LR: 5e-4
  WARMUP_LR: 1e-6
  MIN_LR: 1e-5
  CLIP_GRAD: 1.0

MODEL:
  TYPE: fsa_pvt_large
  NAME: fsa_pvt_large
  DROP_PATH_RATE: 0.3
  FSA:
    ATTN_TYPE: FFFF
```

cfgs/fsa_pvt_m.yaml

Lines changed: 20 additions & 0 deletions

```yaml
DATA:
  IMG_SIZE: 224
  BATCH_SIZE: 128

TRAIN:
  WEIGHT_DECAY: 0.05
  EPOCHS: 300
  WARMUP_EPOCHS: 5
  COOLDOWN_EPOCHS: 10
  BASE_LR: 5e-4
  WARMUP_LR: 1e-6
  MIN_LR: 1e-5
  CLIP_GRAD: 1.0

MODEL:
  TYPE: fsa_pvt_medium
  NAME: fsa_pvt_medium
  DROP_PATH_RATE: 0.3
  FSA:
    ATTN_TYPE: FFFF
```

cfgs/fsa_pvt_s.yaml

Lines changed: 20 additions & 0 deletions

```yaml
DATA:
  IMG_SIZE: 224
  BATCH_SIZE: 128

TRAIN:
  WEIGHT_DECAY: 0.05
  EPOCHS: 300
  WARMUP_EPOCHS: 5
  COOLDOWN_EPOCHS: 10
  BASE_LR: 5e-4
  WARMUP_LR: 1e-6
  MIN_LR: 1e-5
  CLIP_GRAD: None

MODEL:
  TYPE: fsa_pvt_small
  NAME: fsa_pvt_small
  DROP_PATH_RATE: 0.1
  FSA:
    ATTN_TYPE: FFFF
```

cfgs/fsa_pvt_t.yaml

Lines changed: 20 additions & 0 deletions

```yaml
DATA:
  IMG_SIZE: 224
  BATCH_SIZE: 128

TRAIN:
  WEIGHT_DECAY: 0.05
  EPOCHS: 300
  WARMUP_EPOCHS: 5
  COOLDOWN_EPOCHS: 10
  BASE_LR: 5e-4
  WARMUP_LR: 1e-6
  MIN_LR: 1e-5
  CLIP_GRAD: None

MODEL:
  TYPE: fsa_pvt_tiny
  NAME: fsa_pvt_tiny
  DROP_PATH_RATE: 0.1
  FSA:
    ATTN_TYPE: FFFF
```
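The `TRAIN` block in the PVT configs describes a warmup-then-cosine learning-rate schedule: linear warmup from `WARMUP_LR` to `BASE_LR` over `WARMUP_EPOCHS`, then cosine decay down to `MIN_LR` by `EPOCHS` (with `COOLDOWN_EPOCHS` typically holding `MIN_LR` afterwards). A sketch of that interpretation; the repo's exact scheduler may differ in details:

```python
import math

def lr_at_epoch(epoch, base_lr=5e-4, warmup_lr=1e-6, min_lr=1e-5,
                epochs=300, warmup_epochs=5):
    """Learning rate at a given epoch under linear warmup followed by
    cosine decay, using the field names from the configs above."""
    if epoch < warmup_epochs:
        # linear warmup from warmup_lr up to base_lr
        return warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs
    # cosine decay from base_lr down to min_lr over the remaining epochs
    t = (epoch - warmup_epochs) / (epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))

print(lr_at_epoch(0))    # = WARMUP_LR
print(lr_at_epoch(5))    # = BASE_LR (warmup just finished)
print(lr_at_epoch(300))  # = MIN_LR (end of cosine decay)
```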
