Skip to content

Commit

Permalink
feat: Add CSGO-2
Browse files Browse the repository at this point in the history
  • Loading branch information
Flowtter committed Sep 2, 2023
1 parent 81d35ae commit 92f2224
Show file tree
Hide file tree
Showing 18 changed files with 192 additions and 91 deletions.
62 changes: 61 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,63 @@ The following settings are adjustable:
- second-before: Seconds of gameplay included before the highlight.
- second-after: Seconds of gameplay included after the highlight.
- second-between-kills: Transition time between highlights. If the time between two highlights is less than this value, the both highlights will be merged.
- game: Chosen game (either "valorant" or "overwatch")
- game: Chosen game (either "valorant", "overwatch" or "csgo2")

### Recommended settings

I recommend you to use the trials and errors method to find the best settings for your videos.\
Here are some settings that I found to work well for me:

#### Valorant

```json
{
"neural-network": {
"confidence": 0.8
},
"clip": {
"framerate": 8,
"second-before": 4,
"second-after": 0.5,
"second-between-kills": 3
},
"game": "valorant"
}
```

#### Overwatch

```json
{
"neural-network": {
"confidence": 0.6
},
"clip": {
"framerate": 8,
"second-before": 4,
"second-after": 3,
"second-between-kills": 5
},
"game": "overwatch"
}
```

#### CSGO2

```json
{
"neural-network": {
"confidence": 0.7
},
"clip": {
"framerate": 8,
"second-before": 4,
"second-after": 1,
"second-between-kills": 3
},
"game": "csgo2"
}
```

## Run

Expand Down Expand Up @@ -139,3 +195,7 @@ Now `pre-commit` will run on every `git commit`.

- `cd crispy-frontend && yarn && yarn dev`
- `cd crispy-backend && pip install -Ir requirements-dev.txt && python -m api`

## Test

- `cd crispy-api && pytest`
12 changes: 8 additions & 4 deletions crispy-api/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,24 @@
from montydb import MontyClient, set_storage
from pydantic.json import ENCODERS_BY_TYPE

from api.config import DATABASE_PATH, DEBUG, GAME, MUSICS, VIDEOS
from api.config import DATABASE_PATH, DEBUG, FRAMERATE, GAME, MUSICS, VIDEOS
from api.tools.AI.network import NeuralNetwork
from api.tools.enums import SupportedGames
from api.tools.filters import apply_filters # noqa
from api.tools.setup import handle_highlights, handle_musics

ENCODERS_BY_TYPE[ObjectId] = str

neural_network = NeuralNetwork(GAME)

if GAME == SupportedGames.OVERWATCH:
neural_network = NeuralNetwork([10000, 120, 15, 2])
neural_network.load("./assets/overwatch.npy")
elif GAME == SupportedGames.VALORANT:
neural_network = NeuralNetwork([4000, 120, 15, 2], 0.01)
neural_network.load("./assets/valorant.npy")
elif GAME == SupportedGames.CSGO2:
neural_network.load("./assets/csgo2.npy")
else:
raise ValueError(f"game {GAME} not supported")


logging.getLogger("PIL").setLevel(logging.ERROR)
Expand Down Expand Up @@ -62,7 +66,7 @@ def is_tool_installed(ffmpeg_tool: str) -> None:
@app.on_event("startup")
async def setup_crispy() -> None:
await handle_musics(MUSICS)
await handle_highlights(VIDEOS, GAME, framerate=8)
await handle_highlights(VIDEOS, GAME, framerate=FRAMERATE)


@app.exception_handler(HTTPException)
Expand Down
53 changes: 48 additions & 5 deletions crispy-api/api/__main__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,35 @@
import argparse
import asyncio
import os
import sys

import uvicorn

from api import init_database
from api.config import DEBUG, HOST, PORT
from api.config import (
DATASET_CSV_PATH,
DATASET_CSV_TEST_PATH,
DEBUG,
HOST,
NETWORK_OUTPUTS_PATH,
PORT,
)
from api.tools.AI.trainer import Trainer, test, train
from api.tools.dataset import create_dataset
from api.tools.enums import SupportedGames

_parser = argparse.ArgumentParser()
# Dataset
_parser.add_argument("--dataset", action="store_true")

# Trainer
_parser.add_argument("--train", help="Train the network", action="store_true")
_parser.add_argument("--test", help="Test the network", action="store_true")
_parser.add_argument("--epoch", help="Number of epochs", type=int, default=1000)
_parser.add_argument("--load", help="Load a trained network", action="store_true")
_parser.add_argument("--path", help="Path to the network", type=str)

# Game
_parser.add_argument(
"--game", type=str, choices=[game.value for game in SupportedGames]
)
Expand All @@ -26,11 +46,34 @@ async def generate_dataset(game: SupportedGames) -> None:


if __name__ == "__main__":
if not _args.dataset:
if not _args.dataset and not _args.train and not _args.test:
uvicorn.run("api:app", host=HOST, port=PORT, reload=DEBUG, proxy_headers=True)
else:
game = SupportedGames(_args.game)
if not game:
raise ValueError("Game not supported")
if _args.dataset:
if not game:
raise ValueError("Game not supported")

asyncio.run(generate_dataset(game))
else:
trainer = Trainer(game, 0.01)

if _args.load:
trainer.load(_args.path)
else:
trainer.initialize_weights()

print(trainer)
if _args.train:
if not os.path.exists(NETWORK_OUTPUTS_PATH):
os.makedirs(NETWORK_OUTPUTS_PATH)
train(
_args.epoch, trainer, DATASET_CSV_PATH, True, NETWORK_OUTPUTS_PATH
)

if _args.test:
if not _args.load and not _args.train:
print("You need to load a trained network")
sys.exit(1)

asyncio.run(generate_dataset(game))
sys.exit(not test(trainer, DATASET_CSV_TEST_PATH))
8 changes: 6 additions & 2 deletions crispy-api/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

ASSETS = "assets"
SILENCE_PATH = os.path.join(ASSETS, "silence.mp3")
DOT_PATH = os.path.join(ASSETS, "dot.png")
VALORANT_MASK_PATH = os.path.join(ASSETS, "valorant-mask.png")
CSGO2_MASK_PATH = os.path.join(ASSETS, "csgo2-mask.png")

BACKUP = "backup"

Expand All @@ -24,7 +25,10 @@
MUSICS = os.path.join(RESOURCES, "musics")

DATASET_PATH = "dataset"
DATASET_VALUES_PATH = "dataset-values.json"
DATASET_VALUES_PATH = os.path.join(DATASET_PATH, "dataset-values.json")
DATASET_CSV_PATH = os.path.join(DATASET_PATH, "result.csv")
DATASET_CSV_TEST_PATH = os.path.join(DATASET_PATH, "test.csv")
NETWORK_OUTPUTS_PATH = "outputs"

DATABASE_PATH = ".data"

Expand Down
26 changes: 22 additions & 4 deletions crispy-api/api/models/highlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from mongo_thingy import Thingy
from PIL import Image, ImageFilter, ImageOps

from api.config import BACKUP, DOT_PATH
from api.config import BACKUP, CSGO2_MASK_PATH, VALORANT_MASK_PATH
from api.models.filter import Filter
from api.models.segment import Segment
from api.tools.audio import silence_if_no_audio
from api.tools.enums import SupportedGames
from api.tools.ffmpeg import merge_videos

logger = logging.getLogger("uvicorn")
valorant_mask = Image.open(VALORANT_MASK_PATH)
csgo2_mask = Image.open(CSGO2_MASK_PATH)


class Highlight(Thingy):
Expand Down Expand Up @@ -130,9 +132,7 @@ def _apply_filter_and_do_operations(

image = image.crop((1, 1, image.width - 2, image.height - 2))

dot = Image.open(DOT_PATH)

image.paste(dot, (0, 0), dot)
image.paste(valorant_mask, (0, 0), valorant_mask)

left = image.crop((0, 0, 25, 60))
right = image.crop((95, 0, 120, 60))
Expand Down Expand Up @@ -162,13 +162,31 @@ def post_process(image: Image) -> Image:
post_process, (899, 801, 122, 62), framerate=framerate
)

async def extract_csgo2_images(self, framerate: int = 4) -> bool:
def post_process(image: Image) -> Image:
image = ImageOps.grayscale(
image.filter(ImageFilter.FIND_EDGES).filter(
ImageFilter.EDGE_ENHANCE_MORE
)
)
final = Image.new("RGB", (100, 100))
final.paste(image, (0, 0))
final.paste(csgo2_mask, (0, 0), csgo2_mask)
return final

return await self.extract_images(
post_process, (930, 925, 100, 100), framerate=framerate
)

async def extract_images_from_game(
self, game: SupportedGames, framerate: int = 4
) -> bool:
if game == SupportedGames.OVERWATCH:
return await self.extract_overwatch_images(framerate)
elif game == SupportedGames.VALORANT:
return await self.extract_valorant_images(framerate)
elif game == SupportedGames.CSGO2:
return await self.extract_csgo2_images(framerate)
else:
raise NotImplementedError

Expand Down
14 changes: 11 additions & 3 deletions crispy-api/api/tools/AI/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,22 @@
import numpy as np
import scipy.special

from api.tools.enums import SupportedGames

NetworkResolution = {
SupportedGames.VALORANT: [4000, 120, 15, 2],
SupportedGames.OVERWATCH: [10000, 120, 15, 2],
SupportedGames.CSGO2: [10000, 120, 15, 2],
}


class NeuralNetwork:
"""
Neural network to predict if a kill is on the image
"""

def __init__(self, nodes: List[int], learning_rate: float = 0.01) -> None:
self.nodes = nodes
def __init__(self, game: SupportedGames, learning_rate: float = 0.01) -> None:
self.nodes = NetworkResolution[game]
self.learning_rate = learning_rate
self.weights: List[Any] = []
self.activation_function = lambda x: scipy.special.expit(x)
Expand Down Expand Up @@ -55,7 +63,7 @@ def _train(self, inputs: List[float], targets: Any) -> Tuple[int, int, int]:
for i in range(len(self.nodes) - 1 - 1, 0, -1):
errors.insert(0, np.dot(self.weights[i].T, errors[0]))

# ten times more likely to be not be kill
# five times more likely to be not be kill
# so we mitigate the error
if expected == 0:
errors = [e / 5 for e in errors]
Expand Down
Loading

0 comments on commit 92f2224

Please sign in to comment.