Skip to content

Commit

Permalink
make sure dtype is correct PCA
Browse files Browse the repository at this point in the history
  • Loading branch information
Intron7 committed Nov 25, 2024
1 parent fb8c825 commit c65585d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def fit(self, x):
def transform(self, X):
def _transform(X_part, mean_, components_):
pre_mean = mean_ @ components_.T
mean_impact = cp.ones((X_part.shape[0], 1)) @ pre_mean.reshape(1, -1)
mean_impact = cp.ones(
(X_part.shape[0], 1), dtype=X_part.dtype
) @ pre_mean.reshape(1, -1)
X_transformed = X_part.dot(components_.T) - mean_impact
return X_transformed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def fit(self, x):

def transform(self, X):
precomputed_mean_impact = self.mean_ @ self.components_.T
mean_impact = cp.ones((X.shape[0], 1)) @ precomputed_mean_impact.reshape(1, -1)
mean_impact = cp.ones(
(X.shape[0], 1), dtype=cp.float32
) @ precomputed_mean_impact.reshape(1, -1)
X_transformed = X.dot(self.components_.T) - mean_impact
# X = X - self.mean_
# X_transformed = X.dot(self.components_.T)
Expand Down

0 comments on commit c65585d

Please sign in to comment.