What does watershed do in the precompute_freq_cis function? #73

Open
yjhong89 opened this issue Jun 17, 2024 · 1 comment

@yjhong89

Hi. Thanks for sharing this great work!

I wonder what the role of scale_watershed is here:

    scale_watershed: float = 1.0,

@ChrisLiu6
Contributor

ChrisLiu6 commented Jun 18, 2024

In short, it is a watershed w.r.t. the timestep: before it, the position embedding is linearly scaled, and after it, the position embedding is NTK-scaled.

More details: to make a model trained at 1k resolution generate images at 1.5k or higher resolutions, the position embedding (i.e. RoPE in Lumina) has to be extrapolated. We find that linear RoPE scaling leads to good global structure and composition, but nearby pixels tend not to be harmonious; in contrast, NTK scaling produces good local texture, but the global structure is usually unreasonable. Therefore, we combine the two: linear scaling in the initial diffusion steps defines the global composition (intuitively, like drawing a draft), and we then switch to NTK scaling for high-quality texture. This follows the same intuition as the method introduced in Sec. 2.2 of the Lumina-Next paper, but usually behaves more stably.
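To make the linear-vs-NTK contrast concrete, here is a minimal standard-library sketch (not the Lumina code) that computes per-axis RoPE frequencies with the same exponent layout as the snippet in this thread, then compares how each scaling changes the highest and the lowest frequency. The helper name rope_freqs and the values dim = 64, theta = 10000 and a scale factor of 2 are illustrative assumptions. Like the snippet, it applies NTK scaling by multiplying theta by the scale factor directly.

```python
import math


def rope_freqs(dim: int, theta: float = 10000.0, linear_factor: float = 1.0,
               ntk_factor: float = 1.0) -> list[float]:
    """Hypothetical helper: per-axis RoPE frequencies using exponents
    0, 4, 8, ... divided by dim (dim // 4 values), as in the thread's snippet."""
    base = theta * ntk_factor  # NTK scaling enlarges the RoPE base
    # Linear scaling divides every frequency by the same factor.
    return [1.0 / base ** (k / dim) / linear_factor for k in range(0, dim, 4)][: dim // 4]


dim, s = 64, 2.0
base = rope_freqs(dim)
linear = rope_freqs(dim, linear_factor=s)
ntk = rope_freqs(dim, ntk_factor=s)

for name, f in [("linear", linear), ("ntk", ntk)]:
    high = f[0] / base[0]   # ratio at the highest frequency (fine, local detail)
    low = f[-1] / base[-1]  # ratio at the lowest frequency (coarse, global layout)
    print(f"{name}: highest-freq ratio {high:.3f}, lowest-freq ratio {low:.3f}")
```

Linear scaling shrinks every frequency by the same factor (ratio 0.5 at both ends), so all positions are squeezed into the range seen during training, which plausibly explains the coherent global layout. NTK scaling leaves the highest frequency untouched (ratio 1.0) and compresses the lowest one by about 0.52, so short-range relations between neighboring pixels keep their trained resolution, which fits the observation that it gives better local texture.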

The implementation is very simple:

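import torch

# Before the watershed: linear scaling (all frequencies divided by scale_factor).
# After the watershed: NTK scaling (the RoPE base theta is enlarged instead).
if timestep < scale_watershed:
    linear_factor = scale_factor
    ntk_factor = 1.0
else:
    linear_factor = 1.0
    ntk_factor = scale_factor
theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim)) / linear_factor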
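Below is a hedged, self-contained version of the same branch for experimenting without a GPU or the Lumina code base. It uses plain Python instead of torch, handles a single axis only, and the names select_factors and precompute_freqs_cis_1d are hypothetical. It also assumes, as the timestep < scale_watershed test and the description of linear scaling in the "initial diffusion steps" together imply, that timestep grows from 0 toward 1 during sampling.

```python
import cmath


def select_factors(timestep: float, scale_watershed: float,
                   scale_factor: float) -> tuple[float, float]:
    """Return (linear_factor, ntk_factor) for one sampling step,
    following the branch in the snippet above."""
    if timestep < scale_watershed:
        return scale_factor, 1.0  # early steps: linear scaling ("draft")
    return 1.0, scale_factor      # later steps: NTK scaling ("texture")


def precompute_freqs_cis_1d(dim: int, end: int, timestep: float, scale_factor: float,
                            scale_watershed: float, theta: float = 10000.0) -> list[list[complex]]:
    """Hypothetical single-axis helper: cis(pos * freq) for positions 0..end-1."""
    linear_factor, ntk_factor = select_factors(timestep, scale_watershed, scale_factor)
    base = theta * ntk_factor
    freqs = [1.0 / base ** (k / dim) / linear_factor for k in range(0, dim, 4)][: dim // 4]
    # Unit-magnitude complex rotations, one row per position.
    return [[cmath.exp(1j * pos * f) for f in freqs] for pos in range(end)]


# Illustrative values only: which scaling each sampling step would use.
scale_factor, scale_watershed = 1.5, 0.3
for t in [0.0, 0.2, 0.4, 0.8]:
    lin, ntk = select_factors(t, scale_watershed, scale_factor)
    print(f"t={t:.1f}: linear_factor={lin}, ntk_factor={ntk}")
```

With these illustrative values, the steps at t = 0.0 and 0.2 use linear scaling and the steps at t = 0.4 and 0.8 switch to NTK scaling. Under the same assumption about the timestep direction, the default scale_watershed = 1.0 keeps linear scaling for every step whose timestep is below 1.0.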
