Skip to content

Commit

Permalink
feat: 添加过滤周期后显示图标
Browse files Browse the repository at this point in the history
  • Loading branch information
StevenChen16 committed Dec 4, 2024
1 parent f2a8cde commit 31abd9e
Showing 1 changed file with 68 additions and 19 deletions.
87 changes: 68 additions & 19 deletions ftt/ftt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.fft import fft, fftfreq
from scipy.fft import fft, ifft, fftfreq
from scipy.signal import hilbert, find_peaks
import pywt
from datetime import datetime, timedelta
Expand Down Expand Up @@ -47,8 +47,8 @@ def compute_basic_metrics(self):
# 计算波动率
self.df['vol_21'] = self.df['log_return'].rolling(window=21).std() * np.sqrt(252)

def perform_fft(self):
"""执行傅里叶变换"""
def perform_fft(self, filter_threshold=None):
"""执行傅里叶变换,可选择性地过滤高频成分"""
# 准备数据
returns = self.df['log_return'].values
n = len(returns)
Expand All @@ -57,6 +57,19 @@ def perform_fft(self):
fft_result = fft(returns)
freqs = fftfreq(n, d=1)

# 如果指定了过滤阈值,过滤高频成分
if filter_threshold is not None:
# 创建低通滤波器
filter_mask = np.abs(freqs) < filter_threshold
fft_result_filtered = fft_result * filter_mask

# 执行逆傅里叶变换获取过滤后的收益率
filtered_returns = np.real(ifft(fft_result_filtered))
self.df['filtered_returns'] = filtered_returns

# 从过滤后的收益率重建价格序列
self.df['filtered_price'] = self.df['Close'].iloc[0] * np.exp(filtered_returns.cumsum())

# 计算功率谱
power_spectrum = np.abs(fft_result)**2

Expand All @@ -67,6 +80,16 @@ def perform_fft(self):

return self.periods, self.power_spectrum

def filter_high_frequency(self, cutoff_period=21):
"""
过滤高频成分
参数:
cutoff_period: 截止周期(天),高于此频率的成分将被过滤
"""
filter_threshold = 1/cutoff_period # 将周期转换为频率
self.perform_fft(filter_threshold=filter_threshold)
return self.df['filtered_price']

def find_significant_periods(self, n_peaks=5):
"""找出显著周期"""
peaks, _ = find_peaks(self.power_spectrum)
Expand Down Expand Up @@ -108,10 +131,14 @@ def detect_regime_changes(self):
change_dates = [self.df.index[i] for i in regime_changes]
return change_dates, energy, threshold

def plot_comprehensive_analysis(self):
"""绘制综合分析图"""
def plot_comprehensive_analysis(self, show_filtered=True):
"""绘制综合分析图,包括过滤后的价格"""
# 准备数据
self.perform_fft()
if show_filtered:
self.filter_high_frequency() # 默认使用21天作为截止周期
else:
self.perform_fft()

self.find_significant_periods()
self.wavelet_analysis()
self.hilbert_phase_analysis()
Expand All @@ -122,10 +149,13 @@ def plot_comprehensive_analysis(self):

# 1. 价格和移动平均线
ax1 = plt.subplot(511)
ax1.plot(self.df['Date'], self.df['Close'], label='价格')
ax1.plot(self.df['Date'], self.df['Close'], label='原始价格', alpha=0.7)
if show_filtered and 'filtered_price' in self.df.columns:
ax1.plot(self.df['Date'], self.df['filtered_price'],
label='过滤后价格', color='red', linewidth=2)
ax1.plot(self.df['Date'], self.df['MA21'], label='21日均线')
ax1.plot(self.df['Date'], self.df['MA63'], label='63日均线')
ax1.set_title('价格走势和移动平均')
ax1.set_title('价格走势比较')
ax1.legend()
ax1.grid(True)

Expand Down Expand Up @@ -158,6 +188,26 @@ def plot_comprehensive_analysis(self):

plt.tight_layout()
return plt

def analyze_with_different_filters(self, periods=[5, 21, 63]):
"""使用不同的过滤周期进行分析"""
plt.figure(figsize=(15, 8))

# 绘制原始价格
plt.plot(self.df['Date'], self.df['Close'],
label='原始价格', alpha=0.5, color='gray')

# 使用不同的过滤周期
colors = ['blue', 'green', 'red']
for period, color in zip(periods, colors):
filtered_prices = self.filter_high_frequency(cutoff_period=period)
plt.plot(self.df['Date'], filtered_prices,
label=f'过滤周期 {period}天', color=color)

plt.title('不同过滤周期的价格对比')
plt.legend()
plt.grid(True)
return plt

def get_trading_signals(self):
"""生成交易信号"""
Expand All @@ -182,6 +232,10 @@ def get_trading_signals(self):

def print_analysis_summary(self):
"""打印分析摘要"""
# 确保先执行FFT分析和寻找显著周期
self.perform_fft()
self.find_significant_periods()

print("\n=== 股票分析摘要 ===")

# 基本统计
Expand Down Expand Up @@ -244,20 +298,15 @@ def main():
# 使用示例
analyzer = StockSpectralAnalysis(data)

# 执行分析
analyzer.perform_fft()
analyzer.find_significant_periods()

# 打印分析摘要
# 基本分析
analyzer.print_analysis_summary()

# 生成交易信号
signals = analyzer.get_trading_signals()
print("\n最新交易信号:")
print(signals.iloc[-1])
# 使用不同的过滤周期进行分析
plt = analyzer.analyze_with_different_filters(periods=[5, 21, 63])
plt.show()

# 绘制分析图表
plt = analyzer.plot_comprehensive_analysis()
# 显示综合分析图
plt = analyzer.plot_comprehensive_analysis(show_filtered=True)
plt.show()

if __name__ == "__main__":
Expand Down

0 comments on commit 31abd9e

Please sign in to comment.