diff --git a/notebooks/simple_model_scripted.pt b/notebooks/simple_model_scripted.pt
new file mode 100644
index 000000000..54f81b11c
Binary files /dev/null and b/notebooks/simple_model_scripted.pt differ
diff --git a/notebooks/torchscript_example.ipynb b/notebooks/torchscript_example.ipynb
new file mode 100644
index 000000000..557204722
--- /dev/null
+++ b/notebooks/torchscript_example.ipynb
@@ -0,0 +1,92 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.optim as optim\n",
+    "import torch.utils.data as data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class SimpleModel(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super(SimpleModel, self).__init__()\n",
+    "        self.fc = nn.Linear(10, 5)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        return self.fc(x)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = SimpleModel()\n",
+    "scripted_model = torch.jit.script(model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "scripted_model.save(\"simple_model_scripted.pt\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[ 0.4575,  0.6755,  0.1485, -0.5884, -1.2903]],\n",
+      "       grad_fn=<AddmmBackward0>)\n"
+     ]
+    }
+   ],
+   "source": [
+    "loaded_model = torch.jit.load(\"simple_model_scripted.pt\")\n",
+    "x = torch.randn(1, 10)\n",
+    "output = loaded_model(x)\n",
+    "print(output)\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "bot",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.19"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tests/torchscript_example_test.py b/tests/torchscript_example_test.py
new file mode 100644
index 000000000..e5bf5adbd
--- /dev/null
+++ b/tests/torchscript_example_test.py
@@ -0,0 +1,60 @@
+import torch
+import unittest
+
+class TestTorchScriptModel(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        # Load the TorchScript model saved by the example notebook
+        cls.model = torch.jit.load('notebooks/simple_model_scripted.pt')
+        cls.model.eval()  # Set the model to evaluation mode
+
+    def test_model_output_shape(self):
+        """Test if the model outputs the correct shape."""
+        input_tensor = torch.randn(1, 10)  # SimpleModel wraps nn.Linear(10, 5), so it expects 10 input features
+        output_tensor = self.model(input_tensor)
+        self.assertEqual(output_tensor.shape, (1, 5), "Output shape mismatch")
+
+    def test_model_output_values(self):
+        """Test if the model output values are finite."""
+        input_tensor = torch.randn(1, 10)
+        output_tensor = self.model(input_tensor)
+        # An untrained linear layer has no fixed output range, so check for NaN/Inf instead
+        self.assertTrue(torch.isfinite(output_tensor).all(),
+                        "Output contains NaN or Inf values")
+
+    def test_model_with_different_inputs(self):
+        """Test the model with various types of inputs to ensure robustness."""
+        inputs = [
+            torch.zeros(1, 10),
+            torch.ones(1, 10),
+            torch.randn(1, 10),
+            torch.full((1, 10), 0.5)
+        ]
+        for input_tensor in inputs:
+            output_tensor = self.model(input_tensor)
+            self.assertEqual(output_tensor.shape, (1, 5), "Output shape mismatch with different inputs")
+
+    def test_model_gradients(self):
+        """Test if gradients flow back to the input."""
+        input_tensor = torch.randn(1, 10, requires_grad=True)
+        output_tensor = self.model(input_tensor)
+        output_tensor.sum().backward()
+        self.assertIsNotNone(input_tensor.grad, "Gradients were not computed")
+
+    def test_scripted_model_serialization(self):
+        """Test if the scripted model can be reloaded and produce consistent outputs."""
+        input_tensor = torch.randn(1, 10)
+        output_original = self.model(input_tensor)
+
+        # Save and reload the scripted model
+        torch.jit.save(self.model, 'test_scripted_model.pt')
+        reloaded_model = torch.jit.load('test_scripted_model.pt')
+        reloaded_model.eval()
+
+        output_reloaded = reloaded_model(input_tensor)
+        self.assertTrue(torch.allclose(output_original, output_reloaded),
+                        "Outputs differ after reloading the scripted model")
+
+if __name__ == '__main__':
+    unittest.main()
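
For reference, a minimal sketch of exercising the committed artifact outside the notebook or test suite. It assumes the script is run from the repository root and that notebooks/simple_model_scripted.pt exists (i.e. the notebook has been executed); the batch size of 4 is arbitrary.

```python
import torch

# Load the scripted model that the example notebook saved
# (path assumed relative to the repository root)
model = torch.jit.load("notebooks/simple_model_scripted.pt")
model.eval()

# SimpleModel wraps nn.Linear(10, 5), so a batch of 10-feature rows maps to 5 outputs
batch = torch.randn(4, 10)
with torch.no_grad():
    out = model(batch)

print(out.shape)  # torch.Size([4, 5])
```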