手写二分
用时一万ms,容易出错
#include <bits/stdc++.h>
using namespace std;
const int N=5e6;
const int M=7.5e6+10;
struct str{
int a,b,c;
}st[M];
int idx=0;
bool cmp(str x,str y)
{
if(x.a<y.a) return true;
else if(x.a>y.a) return false;
else if(x.b<y.b) return true;
else if(x.b>y.b) return false;
else return x.c<y.c;
}
int main()
{
int n;
cin >> n;
for(int i=0;i*i*2<=N;i++)
for(int j=i;j*j<=N;j++)
{
st[idx++]={i*i+j*j,i,j};
}
sort(st,st+idx,cmp);
for(int i=0;i*i*4<=N;i++)
for(int j=i;(i*i+j*j)*2<=N;j++)
{
int k=n-i*i-j*j;
// str temp={k,0,0};
// int t=lower_bound(st,st+idx,temp,cmp)-st;
int l=0,r=idx-1;
while(l<r)
{
int mid=(l+r+1)/2;
if(st[mid].a<=k) l=mid;
else r=mid-1;
}
if(st[l].a==k)
{
cout<<i<<' '<<j<<' '<<st[l].b<<' '<<st[l].c;
return 0;
}
}
}
利用STL二分函数
1万毫秒 代码短不容易出错
#include <bits/stdc++.h>
using namespace std;
const int N=5e6;
const int M=7.5e6+10;
struct str{
int a,b,c;
}st[M];
int idx=0;
bool cmp(str x,str y)
{
if(x.a<y.a) return true;
else if(x.a>y.a) return false;
else if(x.b<y.b) return true;
else if(x.b>y.b) return false;
else return x.c<y.c;
}
int main()
{
int n;
cin >> n;
for(int i=0;i*i*2<=N;i++)
for(int j=i;j*j<=N;j++)
{
st[idx++]={i*i+j*j,i,j};
}
sort(st,st+idx,cmp);
for(int i=0;i*i*4<=N;i++)
for(int j=i;(i*i+j*j)*2<=N;j++)
{
int k=n-i*i-j*j;
str temp={k,0,0};
int t=lower_bound(st,st+idx,temp,cmp)-st;
if(t==n) continue;
if(st[t].a==k)
{
cout<<i<<' '<<j<<' '<<st[t].b<<' '<<st[t].c;
return 0;
}
}
}
哈希写法
500ms 代码短思路简单用时最短
#include <bits/stdc++.h>
using namespace std;
const int N=5e6;
const int M=7.5e6+10;
int st[M];
int main()
{
int n;
cin >> n;
for(int i=0;i*i<=N;i++) st[i*i]=-1;
for(int i=1;i*i*2<=N;i++)
for(int j=i;j*j<=N;j++)
{
int t=i*i+j*j;
if(st[t]==0) st[t]=i;
}
for(int i=0;i*i*4<=N;i++)
for(int j=i;(i*i+j*j)*2<=N;j++)
{
int k=n-i*i-j*j;
if(st[k]!=0)
{
if(st[k]==-1) st[k]=0;
cout<<i<<' '<<j<<' '<<st[k]<<' '<<(int)sqrt(k-st[k]*st[k]);
return 0;
}
}
}
暴力写法
1400ms 暴力出奇迹,比哈希快麻了
#include <bits/stdc++.h>
using namespace std;
int n;
int main()
{
scanf("%d",&n);
for(int i = 0; i<n ; ++i)
for(int j=i;(i*i+j*j)*2<n;++j )
{
for (int k = j; k*k<=n-i*i-j*j-k*k ;++k)
{
int z = sqrt(n - i * i - j * j - k * k);
if (i * i + j * j + k * k + z * z == n)
{
printf("%d %d %d %d",i,j,k,z);
return 0;
}
}
}
}