求助FFT
查看原帖
求助FFT
347839
Daniel_7216楼主2021/12/28 13:01

RT,只有很小的数相乘时才能得到正确答案,为什么答案算出来都那么离谱,是精度问题吗?

#include <cstdio>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
using namespace std;
char A[1000001], B[1000001];
int len1, len2, len, ans[2000001];
double a[1000001], b[1000001];
const double PI = acos(-1.0);
struct a_bi{
	double x, y;
	a_bi (double x_ = 0.0, double y_ = 0.0){
		x = x_;
		y = y_;
	}
	a_bi operator -(const a_bi &a) const{
		return a_bi(x - a.x, y - a.y);
	}
	a_bi operator +(const a_bi &a) const{
		return a_bi(x + a.x, y + a.y);
	}
	a_bi operator *(const a_bi &a) const{
		return a_bi(x * a.x - y * a.y, x * a.y + y * a.x);
	}
};
//令A(x)=a0+a1x+a2x^2+a3x^3+.....+a4x^4,如果令x=10,则这个多项式可以表示一个大整数。
//那么,现在只用知道新的多项式的系数即可。 
a_bi x1[1000001], x2[1000001];
void reverse(a_bi F[]){
	for (int i = 1, j = len / 2, k; i < len - 1; i++){
		if (i < j) swap(F[i], F[j]);
		k= len / 2;
		while (j >= k){
			j-= k, k /= 2;
		}
		if (j < k) j += k;
	}
} 
void fft(a_bi F[], int dft){
	reverse(F);
	for (int i = 2; i <= len; i *= 2){
		a_bi omega_n(cos(2 * PI / i), sin(dft * 2 * PI / i));
		for (int j = 0; j < len; j += i){
			a_bi w(1, 0);
			for (int k = j; k < j + i / 2; k++){
				a_bi X = F[k];
				a_bi Y = w * F[i / 2 + k];
				F[k] = X + Y;
				F[i / 2 + k] = X - Y;
				w = w * omega_n;
			}
		}
	}
    if (dft == -1){
        for (int i = 0; i < len; i++){
            F[i].x /= len;
        }
    }
}
int main(){
	scanf("%s", A);
	scanf("%s", B);
	len1 = strlen(A);
	len2 = strlen(B);
	len = 2;
	for (int i = 1; (1 << i) < len1 + len2 - 1; i++) len *= 2;
	for (int i = 0; i < len1; i++){
		a[i] = (double)(A[len1 - 1 - i] - '0');
	}
	for (int i = 0; i < len2; i++){
		b[i] = (double)(B[len2 - 1 - i] - '0');
	}
	for (int i = 0; i < len; i++){
		x1[i] = a_bi(a[i], 0);
	}
	for (int i = 0; i < len; i++){
		x2[i] = a_bi(b[i], 0);
	}
	fft(x1, 1);
	fft(x2, 1);
	for (int i = 0; i < len; i++){
        x1[i] = x1[i] * x2[i];
    }
    fft(x1, -1);
	for (int i = 0; i < len; i++){
		ans[i] = (int)(x1[i].x + 0.5);
		ans[i + 1] += ans[i] / 10;
		ans[i] = ans[i] % 10;
	}
    int i;
    for (i = len1 + len2; i >= 0 && !ans[i]; i--);
	for (; i >= 0; i--){
		printf("%d", ans[i]);
	}
	return 0;
} 
2021/12/28 13:01
加载中...