강얼쥐와 함께 즐겁게 읽는 AI

구글 데이터 분석하기

영웅*^%&$ 2023. 2. 9. 13:18
728x90

import numpy as np

import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import Dense, LSTM
import matplotlib.pyplot as plt

 

# Fraction of rows, in file order, used for training; the remainder is test data.
# NOTE(review): assumes GOOGL.csv rows are already in chronological order.
TRAIN_FRACTION = 0.8

# Load the Google stock data into a pandas dataframe.
# No date_parser is passed: pandas expects a callable there (True is not valid)
# and the argument is deprecated since pandas 2.0. Only the 'Close' column is
# used below, so no date parsing is needed.
df = pd.read_csv('GOOGL.csv')
close_prices = df[['Close']].to_numpy(dtype=float)

# Index of the first test row; rows are not shuffled, so this is a sequential split.
split_index = int(close_prices.shape[0] * TRAIN_FRACTION)

# Normalize to [0, 1] with MinMaxScaler. The scaler is fit on the training rows
# only and then applied to every row. Fitting on the full series would leak the
# test period's min/max into training. Test values may fall slightly outside
# [0, 1] as a result, which is expected.
scaler = MinMaxScaler(feature_range=(0, 1))
scaler.fit(close_prices[:split_index])
scaled_prices = scaler.transform(close_prices)

# Split the scaled data into training and testing sets (both 2-D: rows x 1).
train_data = scaled_prices[:split_index, :]
test_data = scaled_prices[split_index:, :]

 

# Number of past time steps fed to the LSTM for each prediction.
LOOKBACK = 60


def _make_windows(series, lookback):
    """Build sliding-window samples from column 0 of a 2-D array.

    Args:
        series: 2-D array of shape (rows, features); only column 0 is used.
        lookback: number of consecutive rows in each input window.

    Returns:
        (x, y), where x has shape (n, lookback), y has shape (n,), and
        n = max(rows - lookback, 0). For each k, x[k] is
        series[k:k + lookback, 0] and y[k] is the following value,
        series[k + lookback, 0]. With too few rows, both arrays are empty
        but keep their 2-D / 1-D shapes so a later 3-D reshape still works.
    """
    n_windows = max(series.shape[0] - lookback, 0)
    x = np.array([series[i:i + lookback, 0] for i in range(n_windows)])
    # np.array([]) is 1-D; force (n, lookback) so the empty case keeps its rank.
    x = x.reshape(n_windows, lookback)
    y = np.array(series[lookback:lookback + n_windows, 0])
    return x, y


# Training samples: every full window that fits inside the training period.
x_train, y_train = _make_windows(train_data, LOOKBACK)

# Test samples: prepend the last LOOKBACK training rows so the very first test
# row already has a full history window. Without this, the first LOOKBACK test
# targets would never be predicted. A test set of LOOKBACK rows or fewer would
# also yield no samples at all and fail the later reshape to 3-D.
test_inputs = np.concatenate([train_data[-LOOKBACK:], test_data])
x_test, y_test = _make_windows(test_inputs, LOOKBACK)

 

def _to_lstm_input(windows):
    """Return *windows* of shape (samples, time steps) as (samples, time steps, 1).

    The trailing axis of size 1 is the per-step feature dimension the LSTM
    layer consumes (one value per time step here).
    """
    n_samples = windows.shape[0]
    n_steps = windows.shape[1]
    return windows.reshape((n_samples, n_steps, 1))


# Reshape the windowed data into the 3-D layout used by the LSTM model.
x_train = _to_lstm_input(x_train)
x_test = _to_lstm_input(x_test)

 

# Stacked LSTM regressor, built from an ordered layer list:
#   - a 50-unit LSTM that returns the full sequence, so the next LSTM can consume it;
#   - a second 50-unit LSTM that returns only its last output;
#   - a single-unit Dense layer producing one value per sample.
# The input is one feature per time step, with x_train.shape[1] time steps.
model = Sequential([
    LSTM(units=50, return_sequences=True, input_shape=(x_train.shape[1], 1)),
    LSTM(units=50),
    Dense(1),
])

# Compile with the Adam optimizer and a mean-squared-error loss.
model.compile(optimizer='adam', loss='mean_squared_error')

# Train for 30 epochs in mini-batches of 16 samples, keeping the training history.
history = model.fit(x_train, y_train, batch_size=16, epochs=30)

# Evaluate on the held-out test windows (the returned loss is not stored).
model.evaluate(x_test, y_test)

 

# Make predictions using the trained model. The outputs are in the scaled
# space produced by `scaler`, one value per test window.
predictions = model.predict(x_test)

# Convert both series back to price units so the plot really shows stock
# prices, as the y-axis label says. The scaler was fit on a single 'Close'
# column, so inverse_transform expects 2-D (n, 1) input.
actual_prices = scaler.inverse_transform(np.asarray(y_test).reshape(-1, 1))
predicted_prices = scaler.inverse_transform(np.asarray(predictions).reshape(-1, 1))

# Plot the predictions against the actual stock prices.
plt.plot(actual_prices, color='red', label='Actual Google Stock Price')
plt.plot(predicted_prices, color='blue', label='Predicted Google Stock Price')
plt.title('Google Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('Stock Price')
plt.legend()
plt.show()
728x90