[BUG] TSDS/TFT grouped multivariate error with large batch size #1730

Open
Loggy48 opened this issue Dec 17, 2024 · 1 comment
Labels
bug Something isn't working

Loggy48 commented Dec 17, 2024

testGroup.csv
testGroup.log

Describe the bug

A large batch size of 3 × 7200 = 21600 leads to a mismatch between the original groups and the groups as mapped in x['groups']. This shows up in the number of observations per group, which is [7200, 7200, 7200] in the input but [7201, 7201, 7198] in the x tensor.
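A minimal sketch of the comparison described above, using only pandas and NumPy. It assumes the group labels are integer codes in both the input column and `x['groups']` (the dataset encodes the `group` column, so raw labels and codes may differ). `compare_group_counts` is a hypothetical helper, and the arrays in the usage lines are synthetic stand-ins that reproduce the reported counts, not real output.

```python
import numpy as np
import pandas as pd


def compare_group_counts(input_groups, batch_groups):
    """Tabulate per-group counts in the input column versus the batch's group codes.

    batch_groups may be 1-D or 2-D (e.g. shape (n_samples, 1)); it is flattened,
    so this assumes a single group_ids column.
    """
    expected = pd.Series(np.asarray(input_groups)).value_counts()
    observed = pd.Series(np.asarray(batch_groups).ravel()).value_counts()
    table = pd.DataFrame({"input": expected, "batch": observed}).fillna(0).astype(int)
    table["diff"] = table["batch"] - table["input"]
    return table.sort_index()


# Synthetic stand-ins: 3 groups x 7200 input rows, and a batch with the reported counts.
input_groups = np.repeat([0, 1, 2], 7200)
batch_groups = np.repeat([0, 1, 2], [7201, 7201, 7198]).reshape(-1, 1)
print(compare_group_counts(input_groups, batch_groups))
```

With real data the inputs would be `data['group']` and `x['groups'].cpu().numpy()`, provided the codes line up with the raw labels.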

To Reproduce
The script below, run with the attached testGroup.csv, demonstrates the error. On inspection, the first 7131 observations for each group appear to be correct, but the top-up at the end of the groups tensor is wrong.
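The 7131 figure above can be compared with the window settings used in the script (`min_encoder_length = max_encoder_length // 2 = 60`, `min_prediction_length = max_prediction_length // 2 = 10`). Assuming, without having confirmed it against the library, that each sample is a window of at least `min_encoder_length + min_prediction_length` consecutive time steps and that a window may start at every position in a group, a group of `n` rows admits `n - (60 + 10) + 1` start positions. A minimal sketch of that arithmetic; `count_min_length_windows` is a hypothetical helper, not a pytorch_forecasting function:

```python
def count_min_length_windows(n_rows: int, min_encoder_length: int, min_prediction_length: int) -> int:
    """Number of start positions for a window of min_encoder_length + min_prediction_length
    consecutive rows inside a group of n_rows rows (0 if the group is too short)."""
    window = min_encoder_length + min_prediction_length
    return max(n_rows - window + 1, 0)


# Parameters from the reproduction script: 7200 rows per group, 120 // 2 and 20 // 2.
print(count_min_length_windows(7200, 120 // 2, 20 // 2))  # 7131
```

If the assumption holds, 7200 - 70 + 1 = 7131 matches the number of entries per group that look correct, which may be worth keeping in mind when interpreting the per-group counts in x['groups'].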

# Test program which reads testGroup.csv to show TSDS/TFT groups are not set correctly for batch_size 21600
# This will place a copy of the source and of the log in a directory run/{processID}
# The program overrides the built-in print and stores a source copy and the print output in the run directory
# TSDS and TFT constructor arguments are printed automatically for validation; inspect() prints the object structure (and data if requested)
# Features are already scaled, so they are passed through unchanged; group-based scaling is enabled
# The test targets are x0 ~ 1.6, x1 ~ 10.8 and x2 ~ 150, so they can easily be distinguished
# The initial group table is printed at line 265 while the model group table is printed at line 304
# At input the groups are all 7200 long, after processing the group lengths are 7201, 7201 and 7198
import sys, os, time
import traceback
import logging
import threading as th # in case threads used in libraries!
import datetime as dt
import pandas as pd
import numpy as np
import torch
from torch.nn import MSELoss
import pytorch_forecasting as pf
from pytorch_forecasting import TimeSeriesDataSet as OrigTSDS
from pytorch_forecasting.models import TemporalFusionTransformer
from pytorch_forecasting.metrics import MultiLoss
from pytorch_forecasting.data.encoders import TorchNormalizer, GroupNormalizer, MultiNormalizer
import lightning.pytorch as pl # This is V2 DO NOT USE the old pytorch_lightning
threads = 32
os.environ['TZ'] = 'UTC'
os.environ['NUMEXPR_MAX_THREADS'] = str(threads)
os.environ['NUMEXPR_NUM_THREADS'] = str(threads)
os.environ['OMP_NUM_THREADS'] = str(threads)
os.environ['MKL_NUM_THREADS'] = str(threads)
os.environ['OMP_DYNAMIC'] = 'FALSE'
time.tzset()  # This line is only necessary for Unix-based systems
processID = str(os.getpid())
while os.path.exists("run/"+processID):
	processID = processID+"X" # just in case
runDIR = "run/"+processID
inheritDIR = ""
os.makedirs(runDIR)
log_file = False
pgm = sys.argv[0].split(".")
pgm = ".".join(pgm[0:len(pgm)-1])
log_file = runDIR+"/"+pgm+".log"
log_file_path = log_file
os.system("cp -p "+sys.argv[0]+" "+runDIR+"/"+pgm+".py")
log_file = open(log_file, "w", encoding="utf-8")
bprint = __builtins__.print
#------------------- function  print --------------------###
def	print(*args, sep=" ", nl="\n", raw=False): #, file=None): # , **kwargs):
# Construct the message
	message_parts = [str(arg) for arg in args]
	message = sep.join(message_parts)
	if len(args)>0 and not np.isscalar(args[0]):
		message = nl+message
# Add timestamp and line number
	if not raw:
		line_no = f"@ {traceback.extract_stack(None, 2)[0].lineno}: "
		timestamp = dt.datetime.now().strftime("%d/%m/%y %H:%M:%S") + ": "
		message = line_no + timestamp + message
	bprint(message)
	if log_file:
		log_file.write(message+"\n")
		log_file.flush()
__builtins__.print = print

sys.path.append('/home/jl/FX/API/fxcm/ForexConnect')
np.set_printoptions(edgeitems = 30, linewidth = 100000, formatter = {"float": lambda x: "%.6g" % x})
pd.set_option('display.width', 100000)
pd.set_option('display.min_rows', 20)
pd.options.display.float_format = '{:.6f}'.format
logging.basicConfig(filename=log_file_path,filemode='a',format='%(asctime)s - %(levelname)s - %(message)s',level=logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logging.getLogger().addHandler(console_handler)

def recursive_explore(object, attr_name = "", indent = "", printIt = False):
	if isinstance(object, (list, tuple)):
		print(f"{indent}{attr_name}: {type(object)}, Length: {len(object)}")
		for i,item in enumerate(object):
			recursive_explore(item, f"{attr_name}[{i}]", indent + "	", printIt = printIt)
	elif isinstance(object, dict):
		print(f"{indent}{attr_name}: {type(object)}, Keys: {list(object.keys())}")
		for key, value in object.items():
			recursive_explore(value, f"{attr_name}[{key}]", indent + "	", printIt = printIt)
	elif isinstance(object, pd.DataFrame):
		print(f"{indent}{attr_name}: pandas DataFrame, Shape: {object.shape}")
		if printIt:
			print(object)
	elif isinstance(object, torch.Tensor):
		print(f"{indent}{attr_name}: torch.Tensor, Shape: {object.shape}")
		if printIt:
			print(object)
	elif isinstance(object, pd.Series):
		print(f"{indent}{attr_name}: pandas Series, Length: {len(object)}")
		if printIt:
			print(object)
	elif isinstance(object, (str, int, float, bool)):
# Primitive types - just print them
		if printIt:
			print(f"{indent}{attr_name}: {type(object)}, Value: {object}")
		else:
			print(f"{indent}{attr_name}: {type(object)}")
	elif hasattr(object, '__dict__'):
# If the object is a class instance, explore its __dict__
		print(f"{indent}{attr_name}: {type(object)} (class instance)")
		for key, value in vars(object).items():
			recursive_explore(value, key, indent + "	", printIt = printIt)
	else:
		print(f"{indent}{attr_name}: {type(object)}")
		if printIt:
			print(object)

###-------------------- function inspect --------------------###
def inspect(object, printIt = False):
	line_no = f" @ line {traceback.extract_stack(None, 2)[0].lineno}"
	name = ""
	if hasattr(object, "__name__"):
		name = object.__name__
	print(f"Exploring object structure called{line_no}: {name}")
	recursive_explore(object, printIt = printIt)

####--------------------	 class PassThroughNormalizer --------------------###
class	PassThroughNormalizer(TorchNormalizer):
# Normalizer that does not change the input values.

###--------------------method __init__ --------------------###
	def	__init__(self):
		super().__init__()  # Call the __init__ method of TorchNormalizer
# Add center_ and scale_ attributes
		self.center_ = 0.0
		self.scale_ = 1.0
		self.transformation = None

###--------------------method decode --------------------###
	def	decode(self, tensor: torch.Tensor) -> torch.Tensor:
# Directly return the tensor without modifications
		return tensor

###--------------------method encode --------------------###
	def	encode(self, values: pd.Series) -> torch.Tensor:
# Convert the values directly to tensor without modifications
		return torch.tensor(values.values, dtype = torch.float)

###--------------------method fit --------------------###
	def	fit(self, y):
		return self

###--------------------method transform --------------------###
	def	transform(self, y, return_norm = False, target_scale = None):
# Convert DataFrame to numpy array
		y_array = y.values if isinstance(y, pd.DataFrame) else y
		y_tensor = torch.tensor(y_array)
		if return_norm:
# Return a dummy scale tensor of ones with the same shape as y
# Add an extra dimension to make it a 2D tensor
			return y_tensor, torch.ones_like(y_tensor)[None, :]
		return y_tensor

###--------------------method inverse_transform --------------------###
	def	inverse_transform(self, y, target_scale = None):
		return y

class CustomMultiNormalizer(MultiNormalizer):
	def is_fitted(self):
# Check if all internal normalizers are fitted based on presence of `center` and `scale`
		return all(hasattr(normalizer, 'center') and hasattr(normalizer, 'scale')
				   for normalizer in self.normalizers)

def debugmethod(func): # add to class
	def wrapper(self, *args, **kwargs):
		print(f"Calling function: {func.__name__}")
		if len(args) > 0:
			print("Positional arguments:")
			for i, arg in enumerate(args):
				print(f"  Arg {i+1}:\n{arg}",raw=True)
		if len(kwargs) > 0:
			print("Keyword arguments:")
			for key, value in kwargs.items():
				print(f"  {key}: {value}",raw=True)
		return func(self, *args, **kwargs)
	return wrapper

class TimeSeriesDataSet(OrigTSDS):
	@debugmethod
	def __init__(self, data, time_idx, target, group_ids, **kwargs):
		super().__init__(data, time_idx, target, group_ids, **kwargs)

class MyTFTModel(TemporalFusionTransformer):
	@debugmethod
	def __init__(self, *args, **kwargs):
		super().__init__(*args, **kwargs)  # Let the parent class handle most of the initialization
		self.loss_fn=MSELoss()
		self.learning_rate = learning_rate or 1e-3  # default learning rate if not provided
		self.weights=1

	@classmethod
	def from_predict_df(cls, learning_rate=None, **kwargs):
# Instantiate the loss function dynamically based on the predict_df/quantiles
		loss_fn = MSELoss()
		multi_loss = MultiLoss([loss_fn])
# Instantiate the model using the parent class's from_predict_df method
		model = cls(loss=multi_loss, **kwargs)
# Additional model-specific configuration
		model.loss_fn = loss_fn
		return model

	def training_step(self, batch, batch_idx):
		x, y = batch
		self.weights = y[1].detach().clone()[:,0]  # Extract weights
		if self.weights.sum() < 5: # Ignore (almost) empty sequences
			return None
		y = y[0][0].flatten()  # Flatten to 1D tensor
# torch.nn.MSELoss takes only (input, target), so no scale argument is passed
		y_hat = self(x)[0][0].flatten()  # Flatten to 1D tensor
		losses = self.loss_fn(y_hat, y)
		self.log("loss", losses, prog_bar=True, on_epoch=True)
		return {"loss": losses}

	def validation_step(self, batch, batch_idx):
		x, y = batch
		self.weights = y[1].detach().clone()[:,0]  # Extract weights
		if self.weights.sum() < 5:
			return None
		y = y[0][0].flatten()  # Flatten to 1D tensor
# torch.nn.MSELoss takes only (input, target), so no scale argument is passed
		y_hat = self(x)[0][0].flatten()  # Flatten to 1D tensor
		losses = self.loss_fn(y_hat, y)
		self.log("val_loss", losses, prog_bar=True, on_epoch=True)
		return {"val_loss":losses}

	def predict_step(self, batch, batch_idx=None):
		x, _ = batch
		with torch.no_grad():
			return self(x)[0][0].squeeze(2)
#			return self(x)[0][0].flatten()
# Return, for each distinct value in lst (in sorted order), the positions where it occurs
def group_indices(lst):
	positions = {}
	for i, value in enumerate(lst):
		positions.setdefault(value, []).append(i)
	ordered = sorted(positions)
	return [positions[value] for value in ordered]

# Initial program code
executable = sys.executable.split("/")
executable = executable[len(executable)-1]
print(f'Running {pgm}.py dumping source in', runDIR+'/'+pgm+'.py', f'log file {log_file.name}')
print('python', sys.version.replace("\n", " "), f' pandas {pd.__version__}, numpy {np.__version__}')
print(f'torch {torch.__version__}, pytorch_forecasting {pf.__version__}, lightning.pytorch {pl.__version__}')
#print(subprocess.run(['pip3', 'list', '--outdated', '--format = columns'], capture_output = True, text = True))
print(f'********************************************************************** Active threads: {th.active_count()}')

# Read the data: 3 groups x 7200 rows = 21600 rows
data = pd.read_csv("testGroup.csv",index_col=False)
known_reals = ['scaled_mins', 'sin_mins', 'cos_mins'] #, 'hour']
unknown_reals = []
static_reals = ['encoder_length','R_center', 'R_scale']
target_normalizer = CustomMultiNormalizer(normalizers = [GroupNormalizer(groups = ["group"], scale_by_group = True)])
scalers = {col: PassThroughNormalizer() for col in unknown_reals}
batch_size = 21600
learning_rate=1e-3
embedding_sizes={}
epochs = 1
max_encoder_length=120
max_prediction_length=20
print(data['group'].value_counts())
prepare = TimeSeriesDataSet(
	data,
	time_idx = "time_idx",
	target = ['R'],
	group_ids = ["group"],
	max_encoder_length = max_encoder_length,
	max_prediction_length = max_prediction_length,
	min_encoder_length = max_encoder_length//2,
	min_prediction_length = max_prediction_length//2,
	static_categoricals = ["group"],
	static_reals = static_reals,
	time_varying_known_categoricals = [],
	time_varying_known_reals = known_reals,
	time_varying_unknown_reals = unknown_reals,
	add_relative_time_idx = True,
	add_target_scales = True,
	add_encoder_length = True, # make sure max == min
	target_normalizer = target_normalizer,
	scalers = scalers, # use PTN for all others
)
dataloader = prepare.to_dataloader(train = False, batch_size = batch_size, shuffle = False, sampler = None, num_workers = threads)
lstm_hidden_size = 10
lstm_layers = 2 # generally 2 or 3
model = MyTFTModel.from_predict_df(
	prepare,
	hidden_size = lstm_hidden_size,
	lstm_layers = lstm_layers,
	attention_head_size = 4, # Number of attention heads try 2, 4, 8 see also the hidden sizes
	dropout = 0.1, # Dropout rate
	hidden_continuous_size = 16, # Size of continuous variables embeddings again try 16, 32
	output_size = 1,
	embedding_sizes = embedding_sizes,
)
# Run the model
model.eval()
x,_ = next(iter(dataloader))
inspect(x)
groups=x['groups']
print(pd.DataFrame(x['groups']).value_counts())
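To see which entries make up the "top-up" at the end of each group, each sample in the batch can be keyed by its group code and the first time index of its decoder window, and repeated keys flagged. A minimal sketch, assuming `groups` has shape `(n_samples, 1)` and `decoder_time_idx` has shape `(n_samples, prediction_length)` as NumPy arrays (for the batch above, presumably `x['groups'].cpu().numpy()` and `x['decoder_time_idx'].cpu().numpy()`, assuming the batch exposes that key). `summarize_samples` is a hypothetical helper and the arrays below are synthetic.

```python
import numpy as np
import pandas as pd


def summarize_samples(groups, decoder_time_idx):
    """One row per sample: its group code and the first time index of its decoder window.

    groups: array of shape (n_samples, 1) or (n_samples,)
    decoder_time_idx: array of shape (n_samples, prediction_length)
    """
    g = np.asarray(groups)
    keys = pd.DataFrame({
        "group": g.reshape(g.shape[0], -1)[:, 0],
        "first_decoder_time_idx": np.asarray(decoder_time_idx)[:, 0],
    })
    # Mark every sample whose (group, first decoder time) key occurs more than once
    keys["duplicate"] = keys.duplicated(subset=["group", "first_decoder_time_idx"], keep=False)
    return keys


# Synthetic batch: 4 samples, the last one repeats the window of the second.
groups = np.array([[0], [0], [1], [0]])
decoder_time_idx = np.array([[60, 61], [61, 62], [60, 61], [61, 62]])
summary = summarize_samples(groups, decoder_time_idx)
print(summary)
print(summary.groupby("group").size())
```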

Expected behavior

I expect the groups in x['groups'], together with their associated data, to match the data in the input dataframe so that I can extract the corresponding results. I want to do this in a single batch call for efficiency and simplicity; it will be far faster than running the prediction separately for each minute. Results are in the attachment.
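For the single-batch extraction described above, one way to line results up with the input dataframe is to key each prediction by group and time index and merge on those columns. A minimal sketch, assuming the group codes match the values in the input `group` column and that each prediction is keyed by a single time index; `attach_predictions` and the toy data are hypothetical. `validate="one_to_one"` makes pandas raise `MergeError` if any (group, time_idx) key appears more than once, which would surface surplus samples instead of silently duplicating rows.

```python
import pandas as pd


def attach_predictions(df, groups, time_idx, predictions, group_col="group", time_col="time_idx"):
    """Left-join one prediction per (group, time_idx) key onto the input rows."""
    keyed = pd.DataFrame({group_col: groups, time_col: time_idx, "prediction": predictions})
    return df.merge(keyed, on=[group_col, time_col], how="left", validate="one_to_one")


# Toy input: two groups, three time steps each.
df = pd.DataFrame({
    "group": [0, 0, 0, 1, 1, 1],
    "time_idx": [0, 1, 2, 0, 1, 2],
    "R": [1.6, 1.7, 1.5, 10.8, 10.9, 10.7],
})
result = attach_predictions(df, groups=[0, 0, 1], time_idx=[1, 2, 2], predictions=[1.65, 1.55, 10.75])
print(result)
```

Rows without a matching prediction keep `NaN` in the `prediction` column.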

Additional context

The correction should work for any number of groups, for any batch_size and for a list of targets as well as a grouped target.
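The script only inspects the first batch (`next(iter(dataloader))`). One way to test the batch-size requirement is to accumulate group counts over every batch and check that the totals do not change with the batch size. A minimal sketch with synthetic sample-level group codes sliced into batches; with the real dataloader the loop would presumably iterate `for x, _ in dataloader` and pass `x['groups']` converted to NumPy. `total_group_counts` is a hypothetical helper.

```python
from collections import Counter

import numpy as np


def total_group_counts(batches):
    """Sum per-group counts over an iterable of group-code arrays (one per batch)."""
    totals = Counter()
    for batch_groups in batches:
        totals.update(np.asarray(batch_groups).ravel().tolist())
    return dict(sorted(totals.items()))


# Synthetic sample-level group codes, sliced into batches of different sizes.
codes = np.repeat([0, 1, 2], [5, 7, 3])
for batch_size in (1, 4, 15, 100):
    batches = [codes[i:i + batch_size] for i in range(0, len(codes), batch_size)]
    print(batch_size, total_group_counts(batches))
```

Every batch size should print the same totals; a dependence on `batch_size` would point at the batching rather than at the data.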

Versions

System:
python: 3.8.10 (default, Nov 7 2024, 13:10:47) [GCC 9.4.0]
executable: /usr/bin/python3
machine: Linux-5.4.0-193-generic-x86_64-with-glibc2.29

Python dependencies:
pip: 20.0.2
sktime: 0.29.1
sklearn: 1.3.2
skbase: 0.7.8
numpy: 1.24.3
scipy: 1.10.1
pandas: 2.0.3
matplotlib: 3.7.5
joblib: 1.4.2
numba: None
statsmodels: 0.14.1
pmdarima: None
statsforecast: None
tsfresh: None
tslearn: None
torch: 2.4.1+cu121
tensorflow: 2.13.1
tensorflow_probability: 0.21.0

The program output shows the exact torch, pytorch_forecasting and lightning.pytorch versions.

Loggy48 added the bug label (Something isn't working) on Dec 17, 2024
github-project-automation bot moved this to Needs triage & validation in Bugfixing - pytorch-forecasting on Dec 17, 2024

Loggy48 commented Dec 18, 2024

The TimeSeriesDataSet output shows the groups correctly: pd.Series(prepare.data['groups'].flatten()).value_counts() reports 7200 entries per group. The issue therefore must be in TemporalFusionTransformer.
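The check above flattens `prepare.data['groups']`, which gives per-group counts when there is a single `group_ids` column, as in this script. Since the fix should work for any number of groups, note that with several group columns (assuming the array holds one column per group id), flattening pools codes from different columns, whereas counting whole rows keeps each combination separate. A minimal sketch with a synthetic code array; `row_counts` and `flat_counts` are hypothetical helpers.

```python
from collections import Counter

import numpy as np


def row_counts(codes):
    """Count each distinct row (one code per group_ids column) of a 2-D array."""
    return Counter(map(tuple, np.asarray(codes).tolist()))


def flat_counts(codes):
    """Count individual codes after flattening, pooling all columns together."""
    return Counter(np.asarray(codes).ravel().tolist())


codes = np.array([[0, 0], [0, 1], [1, 0], [0, 0]])
print(row_counts(codes))   # Counter({(0, 0): 2, (0, 1): 1, (1, 0): 1})
print(flat_counts(codes))  # Counter({0: 6, 1: 2})
```

For a single column both approaches agree, which is the case in the reproduction script.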
