From be43d75219622f1381352a49b94779fd5ef35bb7 Mon Sep 17 00:00:00 2001
From: Xiang Zhizheng <50263991+Xzzit@users.noreply.github.com>
Date: Tue, 13 Jun 2023 09:49:12 +0900
Subject: [PATCH] update readme and fix a bug for cpu user
---
README.md | 49 ++++++++++++++++++++++++------------------------
requirements.txt | 6 ++----
stylize_arg.py | 2 +-
3 files changed, 27 insertions(+), 30 deletions(-)
diff --git a/README.md b/README.md
index 913622b..23e14c8 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# Fast Neural Style Transfer in Pytorch :art: :rocket:
-A Pytorch implementation of paper [**Perceptual Losses for Real-Time Style Transfer and Super-Resolution**](https://arxiv.org/abs/1603.08155) by *Justin Johnson, Alexandre Alahi, and Fei-Fei Li*. ***Note that*** the original paper proposes the algorithm to conduct 1) neural style transfer task and 2) image super-resolution task. This implementation can only be used to stylize images with arbitrary artistic style.
+A PyTorch implementation of the paper [**Perceptual Losses for Real-Time Style Transfer and Super-Resolution**](https://arxiv.org/abs/1603.08155) by *Justin Johnson, Alexandre Alahi, and Fei-Fei Li*. ***Note that*** the original paper proposes the algorithm for two tasks: 1) neural style transfer and 2) image super-resolution. This implementation can only be used for 1) stylizing images with an arbitrary artistic style.
The idea of 'neural style transfer' was proposed by *Leon A. Gatys, Alexander S. Ecker, Matthias Bethge* in the paper [**Image Style Transfer Using Convolutional Neural Networks**](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf), where content features are represented as the outputs of selected VGG-16 layers and style features as their Gram matrices.
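As a minimal sketch of that style representation (the function name and normalization here are my own assumptions, not code from this repo), the Gram matrix of a layer's activations can be computed as:

```python
import torch

def gram_matrix(feat):
    # feat: (B, C, H, W) activations from a selected VGG-16 layer
    b, c, h, w = feat.shape
    f = feat.view(b, c, h * w)
    # (B, C, C) channel-correlation matrix, normalized by element count
    return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)
```

The style loss then compares the Gram matrices of the stylized image and the style image, which discards spatial layout and keeps only texture statistics.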
@@ -9,58 +9,57 @@ This repo is based on the code [**fast-neural-style-pytorch**](https://github.co
## Dependencies
Tested With:
* Windows 10/11 || Mac M1 chip || Ubuntu 22.04 (Recommended)
-* Python 3.7
-* Pytorch 1.10
+* Python 3.10
+* PyTorch 2.0.1
```
-conda create -n fst python=3.7
+conda create -n fst python=3.10
conda activate fst
pip install -r requirements.txt
```
+Then install the latest [**PyTorch**](https://pytorch.org/) following the instructions on its site.
## Example Output
```
-python stylize_arg.py --content-image ./pretrained_models/bear.jpg --model ./pretrained_models/Fauvism_André-Derain_Pier.pth
+python stylize_arg.py --c ./pretrained_models/bear.jpg --m ./pretrained_models/Fauvism_André-Derain_Pier.pth
```
-
+
## Usage
***Train the model*** :hammer_and_wrench:
```
-python train_arg.py --dataset --style-image
+python train_arg.py --d --i
```
-- `--dataset`: path to training content images folder, I use Train images [118K/18GB] in [COCO 2017](https://cocodataset.org/#download).
-- `--style-image`: path to style-image.
-- `--save-model-dir`: path to folder where trained model will be saved.
-- `--model-name`: name of saved model.
-- `--content-weight`: weight for content-loss, default is 1e5.
-- `--style-weightt`: weight for style-loss, default is 1e10.
-- `--consistency-weight`: weight for consistency-loss, default is 1e1.
+- `--d`: path to the folder of training content images; I use the Train images [118K/18GB] from [COCO 2017](https://cocodataset.org/#download).
+- `--i`: path to the style image.
- `--mps`: add it to run on the macOS GPU.
-- `--model-type`: architecture for stylization network. including: 1. ae: Autoencoder; 2. bo: bottleneck; 3. res: resNext.
+- `--save-model-dir`: path to folder where trained model will be saved.
+- `--c`: weight for content-loss, default is 1e5.
+- `--s`: weight for style-loss, default is 1e10.
+- `--cs`: weight for the consistency loss, default is 1e0.
+- `--tv`: weight for the total variation loss, default is 1e0.
-Refer to `train_arg.py` for other command line arguments. Refer to `models` folder for details of neural network architecture.
-For training new models you might have to tune the values of `--content-weight`, `--style-weight` and `--consistency-weight`.
+To learn about additional command line arguments, please refer to `train_arg.py`. For more information on the neural network architecture, please see the `models` folder.
+If you're training new models, you may need to adjust the values of `--c`, `--s`, `--cs`, and `--tv`.
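The four weights combine into a single training objective, roughly as follows (a hypothetical sketch; the variable names in `train_arg.py` may differ):

```python
def total_loss(content, style, consistency, tv,
               c=1e5, s=1e10, cs=1e0, tv_w=1e0):
    # Weighted sum mirroring the --c / --s / --cs / --tv defaults above.
    # Raising c preserves more of the content image; raising s pushes
    # the output harder toward the style image's texture statistics.
    return c * content + s * style + cs * consistency + tv_w * tv
```

Because the default scales differ by several orders of magnitude, small relative changes to `--s` tend to dominate the result when tuning.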
***Stylize the image*** :paintbrush:
```
-python stylize_arg.py --content-image --model
+python stylize_arg.py --c --m
```
-- `--content-image`: path to content image you want to stylize.
-- `--model`: saved model to be used for stylizing the image (eg: `mosaic.pth`)
-- `--output-path`: path for saving the output image.
-- `--output-name`: name of output image.
-- `--content-scale`: factor for scaling down the content image if memory is an issue (eg: value of 2 will halve the height and width of content-image)
+- `--c`: path to content image you want to stylize.
+- `--m`: saved model to be used for stylizing the image (e.g. `mosaic.pth`).
- `--mps`: add it to run on the macOS GPU.
-- `--model-type`: architecture for stylization network. including: 1. ae: Autoencoder; 2. bo: bottleneck; 3. res: resNext.
+- `--output-path`: path for saving the output image; default is the current path.
+- `--output-name`: name of the output image; default is `stylized.jpg`.
+- `--content-scale`: factor for scaling down the content image if memory is an issue (e.g. a value of 2 will halve the height and width of the content image).
Make sure that the stylization network has the same `model-type` as the pre-trained model.
diff --git a/requirements.txt b/requirements.txt
index 49cac6e..182dc58 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,2 @@
-numpy>=1.21.2
-Pillow>=8.4.0
-torch>=1.13+cu117
-torchvision>=0.14+cu117
+numpy>=1.24.3
+Pillow>=9.5.0
\ No newline at end of file
diff --git a/stylize_arg.py b/stylize_arg.py
index 00b1d85..245dd95 100644
--- a/stylize_arg.py
+++ b/stylize_arg.py
@@ -46,7 +46,7 @@ def stylize(args):
print('Error: invalid selected architecture')
sys.exit()
- state_dict = torch.load(args.model)
+ state_dict = torch.load(args.model, map_location=device)
# remove saved deprecated running_* keys in InstanceNorm from the checkpoint
for k in list(state_dict.keys()):
if re.search(r'in\d+\.running_(mean|var)$', k):
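The `map_location` change above is the CPU bug fix named in the subject line: without it, `torch.load` tries to restore tensors onto the CUDA device they were saved from and fails on CPU-only machines. A self-contained sketch of the loading logic (the helper name is my own, not from the repo):

```python
import re
import torch

def load_checkpoint(path):
    # Remap saved tensors onto whatever device is available, so a model
    # trained on CUDA still loads on a CPU-only machine.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = torch.load(path, map_location=device)
    # Drop deprecated InstanceNorm running stats, as stylize_arg.py does.
    for k in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del state_dict[k]
    return state_dict, device
```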