Update README and fix a bug for CPU users
Xzzit committed Jun 13, 2023
1 parent 9194eac commit be43d75
Showing 3 changed files with 27 additions and 30 deletions.
49 changes: 24 additions & 25 deletions README.md
@@ -1,6 +1,6 @@
# Fast Neural Style Transfer in Pytorch :art: :rocket:

A Pytorch implementation of paper [**Perceptual Losses for Real-Time Style Transfer and Super-Resolution**](https://arxiv.org/abs/1603.08155) by *Justin Johnson, Alexandre Alahi, and Fei-Fei Li*. ***Note that*** the original paper proposes the algorithm to conduct 1) neural style transfer task and 2) image super-resolution task. This implementation can only be used to stylize images with arbitrary artistic style.
A PyTorch implementation of the paper [**Perceptual Losses for Real-Time Style Transfer and Super-Resolution**](https://arxiv.org/abs/1603.08155) by *Justin Johnson, Alexandre Alahi, and Fei-Fei Li*. ***Note that*** the original paper proposes the algorithm for both 1) the neural style transfer task and 2) the image super-resolution task; this implementation can only be used for 1), stylizing images with an arbitrary artistic style.

The idea of neural style transfer was proposed by *Leon A. Gatys, Alexander S. Ecker, and Matthias Bethge* in the paper [**Image Style Transfer Using Convolutional Neural Networks**](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf), where content features are represented as the outputs of selected VGG-16 layers and style features are represented by their Gram matrices.

@@ -9,58 +9,57 @@ This repo is based on the code [**fast-neural-style-pytorch**](https://github.co
## Dependencies
Tested With:
* Windows 10/11 || Mac M1 chip || Ubuntu 22.04 (Recommended)
* Python 3.7
* Pytorch 1.10
* Python 3.10
* Pytorch 2.0.1

```
conda create -n fst python=3.7
conda create -n fst python=3.10
conda activate fst
pip install -r requirements.txt
```
Then download the latest [**PyTorch**](https://pytorch.org/).
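
For reference, a typical CPU-only install can be as simple as the command below; the exact command depends on your OS and CUDA version, so use the selector on [pytorch.org](https://pytorch.org/) to get the one that matches your setup:
```
pip install torch torchvision
```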

## Example Output
```
python stylize_arg.py --content-image ./pretrained_models/bear.jpg --model ./pretrained_models/Fauvism_André-Derain_Pier.pth
python stylize_arg.py --c ./pretrained_models/bear.jpg --m ./pretrained_models/Fauvism_André-Derain_Pier.pth
```

<p align="center">
<div style="display: flex; justify-content: center;">
<img src="pretrained_models/bear.jpg" height="300px" title="content image">
<img src="pretrained_models/Fauvism_André-Derain_Pier.jpg" height="300px" title="style image">
<img src="pretrained_models/stylized.jpg" height="300px" title="generated image">
</p>
</div>

## Usage
***Train the model*** :hammer_and_wrench:

```
python train_arg.py --dataset <path/to/content/images/folder> --style-image <path/to/style/image/file>
python train_arg.py --d <path/to/content/images/folder> --i <path/to/style/image/file>
```

- `--dataset`: path to training content images folder, I use Train images [118K/18GB] in [COCO 2017](https://cocodataset.org/#download).
- `--style-image`: path to style-image.
- `--save-model-dir`: path to folder where trained model will be saved.
- `--model-name`: name of saved model.
- `--content-weight`: weight for content-loss, default is 1e5.
- `--style-weight`: weight for style-loss, default is 1e10.
- `--consistency-weight`: weight for consistency-loss, default is 1e1.
- `--d`: path to the folder of training content images; I use the Train images [118K/18GB] split of [COCO 2017](https://cocodataset.org/#download).
- `--i`: path to the style image.
- `--mps`: add this flag to run on the macOS GPU (MPS backend).
- `--model-type`: architecture of the stylization network: `ae` (autoencoder), `bo` (bottleneck), or `res` (ResNeXt).
- `--save-model-dir`: path to the folder where the trained model will be saved.
- `--c`: weight for the content loss, default is 1e5.
- `--s`: weight for the style loss, default is 1e10.
- `--cs`: weight for the consistency loss, default is 1e0.
- `--tv`: weight for the total variation loss, default is 1e0.

Refer to `train_arg.py` for other command line arguments. Refer to `models` folder for details of neural network architecture.
For training new models you might have to tune the values of `--content-weight`, `--style-weight` and `--consistency-weight`.
To learn about additional command line arguments, please refer to `train_arg.py`. For more information on the neural network architecture, please see the `models` folder.
If you're training new models, you may need to adjust the values of `--c`, `--s`, `--cs`, and `--tv`.
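
As an illustration only (the dataset and style-image paths below are placeholders, and the weight values are not tuned recommendations), a full training invocation combining these flags might look like:
```
python train_arg.py --d ./coco2017/train2017 --i ./styles/mosaic.jpg --model-type ae --c 1e5 --s 1e10 --cs 1e0 --tv 1e0 --save-model-dir ./saved_models
```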

***Stylize the image*** :paintbrush:

```
python stylize_arg.py --content-image <path/to/content/image/file> --model <path/to/saved/model>
python stylize_arg.py --c <path/to/content/image/file> --m <path/to/saved/model>
```

- `--content-image`: path to content image you want to stylize.
- `--model`: saved model to be used for stylizing the image (eg: `mosaic.pth`)
- `--output-path`: path for saving the output image.
- `--output-name`: name of output image.
- `--content-scale`: factor for scaling down the content image if memory is an issue (eg: value of 2 will halve the height and width of content-image)
- `--c`: path to the content image you want to stylize.
- `--m`: saved model to be used for stylizing the image (e.g. `mosaic.pth`).
- `--mps`: add this flag to run on the macOS GPU (MPS backend).
- `--model-type`: architecture of the stylization network: `ae` (autoencoder), `bo` (bottleneck), or `res` (ResNeXt).
- `--output-path`: path where the output image is saved, default is the current path.
- `--output-name`: name of the output image, default is `stylized.jpg`.
- `--content-scale`: factor for scaling down the content image if memory is an issue (e.g. a value of 2 halves the height and width of the content image).

Make sure the `--model-type` passed to `stylize_arg.py` matches the architecture the pre-trained model was trained with.
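
Putting the stylization flags together, a hypothetical run (file names are placeholders) could look like the following, where `ae` must match the architecture of the checkpoint being loaded:
```
python stylize_arg.py --c ./images/bear.jpg --m ./saved_models/mosaic.pth --model-type ae --output-path ./outputs --output-name bear_stylized.jpg
```
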
6 changes: 2 additions & 4 deletions requirements.txt
@@ -1,4 +1,2 @@
numpy>=1.21.2
Pillow>=8.4.0
torch>=1.13+cu117
torchvision>=0.14+cu117
numpy>=1.24.3
Pillow>=9.5.0
2 changes: 1 addition & 1 deletion stylize_arg.py
@@ -46,7 +46,7 @@ def stylize(args):
print('Error: invalid selected architecture')
sys.exit()

state_dict = torch.load(args.model)
state_dict = torch.load(args.model, map_location=device)
# remove saved deprecated running_* keys in InstanceNorm from the checkpoint
for k in list(state_dict.keys()):
if re.search(r'in\d+\.running_(mean|var)$', k):
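
For context on the CPU fix above: `torch.load` without `map_location` tries to restore tensors onto the device they were saved from, so a checkpoint trained on a GPU fails to load on a CPU-only machine. A minimal standalone sketch of the pattern (the device-selection line and the checkpoint path are assumptions for illustration, not the repository's exact code):
```
import torch

# Choose whichever device is available; on a CPU-only machine this resolves to "cpu".
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location remaps the saved tensors onto the chosen device at load time,
# so a checkpoint saved on a GPU also loads cleanly on CPU-only machines.
state_dict = torch.load('saved_models/mosaic.pth', map_location=device)
```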
