This repository has been archived by the owner on Sep 18, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotting.py
96 lines (71 loc) · 2.81 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""Result plotting helpers
Copyright 2018 Spectre Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import ABCMeta, abstractmethod
import json
from numbers import Number
from typing import Collection, Dict, List
import numpy as np
def _by_dict(object_):
    """Encode an object that :mod:`json` cannot serialize natively.

    Used as the ``default`` hook of ``json.dumps`` (see ``_SERIALIZER_OPTIONS``).
    Plain objects are serialized through their instance ``__dict__``; numpy
    values are converted to their built-in Python equivalents first.

    Raises:
        TypeError: if the object is neither a numpy value nor has a
            ``__dict__`` -- the error type the ``default`` hook contract expects.
    """
    if isinstance(object_, np.generic):
        # numpy scalars such as np.int64 / np.float32 / np.bool_ are not
        # int/float/bool subclasses and have no __dict__; use the Python value.
        return object_.item()
    if isinstance(object_, np.ndarray):
        return object_.tolist()
    try:
        return object_.__dict__
    except AttributeError:
        raise TypeError(
            "Object of type %s is not JSON serializable" % type(object_).__name__
        ) from None


# Shared json.dumps keyword arguments: custom objects go through _by_dict and
# keys are sorted so the output is deterministic.
_SERIALIZER_OPTIONS = {"default": _by_dict, "sort_keys": True}
class Trace(metaclass=ABCMeta):
    """Abstract base class for a single plot trace.

    Concrete traces define their own ``__init__`` and keep their fields as
    instance attributes; ``to_json`` hands the instance to ``json.dumps``
    with the module's shared serializer options.
    """

    @abstractmethod
    def __init__(self):
        # Abstract so Trace itself cannot be instantiated; the explicit raise
        # also catches a subclass that delegates here through super().
        message = Trace.__name__ + ' is abstract'
        raise NotImplementedError(message)

    def to_json(self):
        """Return this trace encoded as a JSON string."""
        return json.dumps(self, **_SERIALIZER_OPTIONS)

    # String conversion and the JSON hooks looked up by some third-party
    # encoders (e.g. simplejson's ``for_json``) all produce the same text.
    __str__ = __repr__ = __json__ = for_json = to_json
class Plot:
    """A plot: a list of traces plus a layout mapping.

    Serialized with the module's shared options, so the ``data`` and
    ``layout`` instance attributes are what ``to_json`` encodes.
    """

    def __init__(self, data: List[Trace], layout: Dict = None):
        self.data = data
        # Any falsy layout (None or an empty mapping) becomes a fresh dict.
        self.layout = {} if not layout else layout

    def to_json(self):
        """Return this plot encoded as a JSON string."""
        return json.dumps(self, **_SERIALIZER_OPTIONS)

    # String conversion and third-party JSON hooks share one implementation.
    __str__ = __repr__ = __json__ = for_json = to_json
# Sized, iterable collection of numbers: lists, tuples, numpy arrays, ...
ArrayLike = Collection[Number]


class Scatter2d(Trace):
    """Marker-only two-dimensional scatter trace (``type='scatter'``)."""

    def __init__(self, x: ArrayLike, y: ArrayLike):
        """Build the trace from matching x and y coordinates.

        numpy arrays of any shape are flattened with ``ravel`` and converted
        to plain Python numbers (so the trace always JSON-serializes); other
        collections are copied into a list unchanged.

        Raises:
            ValueError: if ``x`` and ``y`` hold different numbers of points.
        """
        # Flatten before comparing: len() of a multi-dimensional array counts
        # only its first axis, which can differ from the raveled point count.
        self.x = x.ravel().tolist() if isinstance(x, np.ndarray) else list(x)
        self.y = y.ravel().tolist() if isinstance(y, np.ndarray) else list(y)
        if len(self.x) != len(self.y):
            raise ValueError(
                "len(x) != len(y); %i != %i" % (len(self.x), len(self.y)))
        self.mode = 'markers'
        self.type = 'scatter'
        self.marker = {"size": 2}
class Scatter3d(Trace):
    """Marker-only three-dimensional scatter trace (``type='scatter3d'``)."""

    def __init__(self, x: ArrayLike, y: ArrayLike, z: ArrayLike):
        """Build the trace from matching x, y and z coordinates.

        numpy arrays of any shape are flattened with ``ravel`` and converted
        to plain Python numbers (so the trace always JSON-serializes); other
        collections are copied into a list unchanged.

        Raises:
            ValueError: if ``y`` or ``z`` holds a different number of points
                than ``x``.
        """
        # Flatten before comparing: len() of a multi-dimensional array counts
        # only its first axis, which can differ from the raveled point count.
        self.x = x.ravel().tolist() if isinstance(x, np.ndarray) else list(x)
        self.y = y.ravel().tolist() if isinstance(y, np.ndarray) else list(y)
        self.z = z.ravel().tolist() if isinstance(z, np.ndarray) else list(z)
        if len(self.x) != len(self.y):
            raise ValueError(
                "len(x) != len(y); %i != %i" % (len(self.x), len(self.y)))
        if len(self.x) != len(self.z):
            raise ValueError(
                "len(x) != len(z); %i != %i" % (len(self.x), len(self.z)))
        self.mode = 'markers'
        self.type = 'scatter3d'
        self.marker = {"size": 2}
def as_scatter_plot(observations: np.ndarray) -> Plot:
    """Wrap a ``(n_points, n_dims)`` table of observations in a scatter plot.

    Each column becomes one coordinate: column 0 is x, column 1 is y and,
    for three columns, column 2 is z. Any array-like accepted by
    ``np.asarray`` may be passed.

    Raises:
        ValueError: if ``observations`` is not two-dimensional, or has a
            number of columns other than 2 or 3.
    """
    observations = np.asarray(observations)
    # Guard explicitly: shape[1] would raise IndexError on 1-D input, and
    # higher-rank input would be mis-split by the transpose-unpacking below.
    if observations.ndim != 2:
        raise ValueError(
            "Expected a 2D array of shape (n_points, n_dims). Was: %iD"
            % observations.ndim)
    dimensions = observations.shape[1]
    trace_types = {2: Scatter2d, 3: Scatter3d}
    if dimensions not in trace_types:
        raise ValueError("Supports only 2D and 3D data. Was: %i" % dimensions)
    # Transposing turns columns into rows, so unpacking yields x, y(, z).
    return Plot([trace_types[dimensions](*observations.T)])