import numpy as np
from numpy.random import normal
from scipy.special import expit as sigmoid
from sklearn.preprocessing import StandardScaler
import time

class Fm(object):
    """Second-order factorization machine for binary classification.

    The raw score of a (standardised) sample ``x`` is

        w_0 + <w, x> + 0.5 * sum_f ((x @ v)_f ** 2 - (x ** 2 @ v ** 2)_f)

    and the parameters are learned with per-sample stochastic gradient
    descent on the logistic loss ``log(1 + exp(-y * score))``.  That loss
    only makes sense for labels encoded as -1 / +1.
    """

    def __init__(self):
        # Parameters stay None until fit() has run; predict() relies on that.
        self._w = None    # linear weights, shape (n_features, 1)
        self._w_0 = None  # bias term
        self._v = None    # latent factors, shape (n_features, feature_potential)
        self.scaler = StandardScaler()

    def fit(self, X, y, feature_potential=8, alpha=0.01, iter=100):
        """Standardise ``X`` and train the model with SGD.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).
            y: 1-D array-like (list, ndarray, pandas Series, ...) holding one
                label per row of ``X``; expected to be encoded as -1 / +1.
            feature_potential: Number of latent factors per feature (k).
            alpha: SGD learning rate.
            iter: Number of passes over the whole training set.

        Raises:
            ValueError: If ``X`` and ``y`` disagree on the number of samples.

        Side effects:
            Fits ``self.scaler``, stores the learned parameters on the
            instance and prints the training duration.
        """
        start_time = time.time()

        # Plain ndarrays instead of np.matrix: ``np.mat`` no longer exists in
        # NumPy 2.x and ``@`` expresses the same products.
        data = np.asarray(self.scaler.fit_transform(X), dtype=float)
        # np.asarray accepts a pandas Series as well as lists and arrays.
        labels = np.asarray(y, dtype=float).reshape(-1)
        m, n = data.shape
        if labels.shape[0] != m:
            raise ValueError(
                f"X has {m} samples but y has {labels.shape[0]} labels")

        k = feature_potential
        w = np.zeros(n)
        w_0 = 0.0
        v = normal(0, 0.2, size=(n, k))

        for _ in range(iter):
            # Visit every training sample (not a fixed number of rows), so
            # small data sets work and large ones are fully used.
            for row, label in zip(data, labels):
                inter_1 = row @ v                # shape (k,)
                inter_2 = (row * row) @ (v * v)  # shape (k,)
                interaction = np.sum(inter_1 * inter_1 - inter_2) / 2.0
                score = w_0 + row @ w + interaction

                # d(loss)/d(score) for the logistic loss, pre-multiplied by
                # the learning rate.
                step = alpha * (sigmoid(label * score) - 1.0) * label

                w_0 -= step
                w -= step * row
                # Each v[i, j] depends only on itself and on inter_1 (computed
                # before the update), so updating all entries at once matches
                # an element-by-element loop; rows with a zero feature get a
                # zero update automatically.
                v -= step * (np.outer(row, inter_1) - v * (row * row)[:, None])

        # Keep the linear weights as a column vector, shape (n_features, 1).
        self._w_0, self._w, self._v = w_0, w.reshape(-1, 1), v

        total_time = time.time() - start_time
        print(f"Training completed in {total_time:.6f} seconds")

    def predict(self, X):
        """Return the predicted probability of the positive class per row.

        Args:
            X: 2-D array-like with the same number of features used in fit().

        Returns:
            A list with one sigmoid-transformed score per row of ``X``.

        Raises:
            RuntimeError: If the model has not been fitted yet.

        Side effects:
            Prints the number of rows and timing statistics.
        """
        if self._w_0 is None or self._w is None or self._v is None:
            raise RuntimeError("Estimator not fitted, call `fit` first")

        X = np.asarray(self.scaler.transform(X), dtype=float)
        m = X.shape[0]

        start_time = time.time()

        # Same score as in fit(), evaluated for all rows at once.
        inter_1 = X @ self._v
        inter_2 = (X * X) @ (self._v * self._v)
        interaction = np.sum(inter_1 * inter_1 - inter_2, axis=1) / 2.0
        scores = self._w_0 + X @ np.ravel(self._w) + interaction
        result = list(sigmoid(scores))

        total_time = time.time() - start_time
        # Guard the average so an empty batch cannot divide by zero.
        avg_time_per_data = total_time / m if m else 0.0

        print(f"Processed {m} data points")
        print(f"Total elapsed time: {total_time:.6f} seconds")
        print(f"Average time per data point: {avg_time_per_data:.6f} seconds")

        return result