Skip to content

Commit

Permalink
Prepare for release
Browse files Browse the repository at this point in the history
  • Loading branch information
SmilingWolf committed Oct 21, 2022
1 parent 588c5bf commit 1d3f947
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 10 deletions.
51 changes: 50 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Repo for my Tensorflow/Keras CV experiments. Mostly revolving around the Danboor

---

Framework: TF/Keras 2.7
Framework: TF/Keras 2.10

Training SQLite DB built using fire-egg's tools: https://github.com/fire-eggs/Danbooru2019

Expand All @@ -15,6 +15,55 @@ Anonymous, The Danbooru Community, & Gwern Branwen; “Danbooru2021: A Large-Sca
----

## Journal

**21/10/2022**:
New release today, after a lot of time, many failed experiments, some successes and a whopping final 200 epochs and 24 days of training, TPU time courtesy of TRC.
ViT in particular was a pain, turns out it is quite sensitive to learning rate. I'm still not 100% sure I nailed it, but right now it works(TM).
While it might look like ConvNext does better, the whole story is a bit more interesting.
Full results below:

All 5500 tags:
| run_name | definition_name | params_human | image_size | thres | F1 | F2 |
|:---------------------------------|:------------------|:---------------|-------------:|--------:|-------:|-------:|
| ensenble_october | - | - | 448 | 0.3611 | 0.7044 | 0.7044 |
| ConvNextBV1_09_25_2022_05h13m55s | B | 93.2M | 448 | 0.3673 | 0.6941 | 0.6941 |
| ViTB16_09_25_2022_04h53m38s | B16 | 90.5M | 448 | 0.3663 | 0.6918 | 0.6918 |

All tags, starting from #2380 and below (sorted by most to least popular):
| run_name | definition_name | params_human | image_size | thres | F1 | F2 |
|:---------------------------------|:------------------|:---------------|-------------:|--------:|-------:|-------:|
| ensenble_october | - | - | 448 | 0.3611 | 0.6107 | 0.5588 |
| ConvNextBV1_09_25_2022_05h13m55s | B | 93.2M | 448 | 0.3673 | 0.5932 | 0.5425 |
| ViTB16_09_25_2022_04h53m38s | B16 | 90.5M | 448 | 0.3663 | 0.5993 | 0.5529 |

General tags (category 0), no character or series tags:
| run_name | definition_name | params_human | image_size | thres | F1 | F2 |
|:---------------------------------|:------------------|:---------------|-------------:|--------:|-------:|-------:|
| ensenble_october | - | - | 448 | 0.3618 | 0.6878 | 0.6878 |
| ConvNextBV1_09_25_2022_05h13m55s | B | 93.2M | 448 | 0.3682 | 0.6774 | 0.6774 |
| ViTB16_09_25_2022_04h53m38s | B16 | 90.5M | 448 | 0.3672 | 0.6748 | 0.6748 |

General tags (category 0), no character or series tags, starting from #2000 and below (sorted by most to least popular):
| run_name | definition_name | params_human | image_size | thres | F1 | F2 |
|:---------------------------------|:------------------|:---------------|-------------:|--------:|-------:|-------:|
| ensenble_october | - | - | 448 | 0.3618 | 0.4515 | 0.3976 |
| ConvNextBV1_09_25_2022_05h13m55s | B | 93.2M | 448 | 0.3682 | 0.4320 | 0.3804 |
| ViTB16_09_25_2022_04h53m38s | B16 | 90.5M | 448 | 0.3672 | 0.4416 | 0.3936 |

The numbers are obtained using tools/analyze_metrics.py to first find the point where P ≈ R, then using that threshold to check what scores I get on the less popular tags.
ViT blazes past ConvNext when it comes to rarer tags, so you might want to consider that when choosing what model to use.
Personally, I ensemble them if I don't have time constraints. That would be ensenble_october in the tables above. Quite some gains.

Next, I'll be finetuning at least one of these models on the latest tags, and adding NSFW images and tags to the training set, so that it can be used in tandem with Waifu Diffusion.

**21/04/2022**:
Checkpointing sweeps and conclusions so far:
- NFNets: ECA and SiLU are well worth the extra computational cost. ECA is incredibly cheap on parameters side, too. Using both ECA and SiLU.
- NFNets: tested MixUp with alpha = 0.2 and alpha = 0.3, found no particular reason to use alpha 0.3. Using alpha = 0.2.
- ConvNext: focal loss: tested a few parameter combinations from the original paper, defaults are good. Using alpha = 0.25 gamma = 2.
- ConvNext: losses: tested a few parameter combinations. Best results achieved with ASL with gamma_neg = gamma_pos = clip = 0, which boils down to BCE with the sum of the per-class losses instead of the average. Using ASL with gamma_neg = gamma_pos = clip = 0.
- ConvNext: tested cutout_rate = 0.0, cutout_rate = 0.25, cutout_rate = 0.5. Even training for 300 epochs, neither cutout_rate > 0 run ever displayed any advantage against the cutout_rate = 0 run, both overall and at the single class level. Using cutout_rate = 0.0.

**05/04/2022**:
So, I trained a bunch of ConvNexts in the past month.
Learned a few things too:
Expand Down
34 changes: 25 additions & 9 deletions tools/analyze_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import numpy as np
import pandas as pd

import re

def filter_old_new(img_tags, img_probs, is_new=False):
"""
is_new = True if model has been trained on danbooru2021
is_new = False if model has been trained on danbooru2020
"""
df2020 = pd.read_csv("2020_0000_0599/selected_tags.csv")
df2021 = pd.read_csv("2021_0000_0899/selected_tags.csv")
df2021 = pd.read_csv("2021_0000_0899_5500/selected_tags.csv")

names2020 = df2020["name"].tolist()
names2021 = df2021["name"].tolist()
Expand All @@ -32,8 +32,8 @@ def filter_old_new(img_tags, img_probs, is_new=False):


def calc_metrics(img_tags, img_probs, thresh):
yz = (img_tags > 0).astype(np.uint)
pos = (img_probs > thresh).astype(np.uint)
yz = (img_tags > 0).astype(np.uint8)
pos = (img_probs > thresh).astype(np.uint8)
pct = pos + 2 * yz

TN = np.sum(pct == 0).astype(np.float32)
Expand Down Expand Up @@ -62,6 +62,14 @@ def calc_metrics(img_tags, img_probs, thresh):
help="Slice files along axis=1 starting from this index",
)

parser.add_argument(
"-c",
"--category",
type=int,
default=-1,
help="Only analyze tags of this category (-1 = all)",
)

thresh_group = parser.add_mutually_exclusive_group()
thresh_group.add_argument(
"-a",
Expand All @@ -81,7 +89,13 @@ def calc_metrics(img_tags, img_probs, thresh):
args = parser.parse_args()

img_probs = np.load(args.dump)
img_tags = np.load("2021_0000_0899/encoded_tags_test.npy")
img_tags = np.load("2021_0000_0899_5500/encoded_tags_test.npy")

if args.category > -1:
df = pd.read_csv("2021_0000_0899_5500/selected_tags.csv")
indexes = np.where(df["category"] == args.category)[0]
img_probs = img_probs[:, indexes]
img_tags = img_tags[:, indexes]

# img_tags, img_probs = filter_old_new(img_tags, img_probs, True)

Expand All @@ -95,7 +109,7 @@ def calc_metrics(img_tags, img_probs, thresh):

recall = 0.0
precision = 1.0
while round(recall, 4) != round(precision, 4):
while not np.isclose(recall, precision):
threshold = (threshold_max + threshold_min) / 2
precision, recall = calc_metrics(img_tags, img_probs, threshold)
if precision > recall:
Expand All @@ -107,8 +121,8 @@ def calc_metrics(img_tags, img_probs, thresh):
else:
threshold = args.threshold

pos = (img_probs > threshold).astype(np.uint)
yz = (img_tags > 0).astype(np.uint)
pos = (img_probs > threshold).astype(np.uint8)
yz = (img_tags > 0).astype(np.uint8)
pct = pos + 2 * yz

TN = np.sum(pct == 0).astype(np.float32)
Expand All @@ -125,6 +139,8 @@ def calc_metrics(img_tags, img_probs, thresh):

MCC = ((TP * TN) - (FP * FN)) / np.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))

model_name = re.sub(".*tags_probs_", "", args.dump)
model_name = model_name.replace(".npy", "")
d = {
"thres": threshold,
"F1": round(F1, 4),
Expand All @@ -134,4 +150,4 @@ def calc_metrics(img_tags, img_probs, thresh):
"R": round(recall, 4),
"P": round(precision, 4),
}
print(d)
print(f"{model_name}: {str(d)}")

0 comments on commit 1d3f947

Please sign in to comment.