import java.util.Scanner;
public class Main {
    /** Prime modulus applied to the final answer and to all modular exponentiation. */
    static final long mod = 998244353;

    /**
     * Reads two integers {@code a} and {@code b} from standard input and prints
     * {@link #solve(long, long)} for them.
     */
    public static void main(String[] args) {
        try (Scanner scanner = new Scanner(System.in)) {
            long a = scanner.nextLong();
            long b = scanner.nextLong();
            System.out.println(solve(a, b));
        }
    }

    /**
     * Returns {@code phi(a) * a^(b-1) mod 998244353}, i.e. Euler's totient of {@code a^b}
     * reduced modulo {@link #mod}.
     *
     * <p>For {@code a == 1} the result is {@code 0}; this special case is kept from the
     * original solution (presumably a problem-specific convention — NOTE(review): confirm
     * against the task statement, since the textbook value is phi(1) = 1).
     *
     * @param a base, assumed positive
     * @param b exponent, assumed {@code >= 1}
     * @return the answer modulo {@link #mod}
     * @throws IllegalArgumentException if {@code a != 1} and {@code b < 1}
     */
    static long solve(long a, long b) {
        if (a == 1) {
            return 0;
        }
        // phi(a) <= a, so it is computed exactly and only reduced here, at the end.
        return phi(a) % mod * Powered_by(a % mod, b - 1) % mod;
    }

    /**
     * Computes Euler's totient of {@code n} exactly by trial-division factorisation.
     *
     * <p>No modular reduction is done inside the loop: every {@code res / p} must be an
     * exact division, which only holds while {@code res} is the true partial product.
     *
     * @param n value to factor, assumed positive
     * @return phi(n)
     */
    static long phi(long n) {
        long res = n;
        long x = n;
        // long index and division-based bound avoid the int overflow of i*i for large n.
        for (long i = 2; i <= x / i; i++) {
            if (x % i == 0) {
                while (x % i == 0) {
                    x /= i;
                }
                res = res / i * (i - 1);
            }
        }
        if (x > 1) {
            // Remaining x is a prime factor larger than sqrt of the reduced value.
            res = res / x * (x - 1);
        }
        return res;
    }

    /**
     * Computes {@code a^b mod 998244353} by binary exponentiation.
     *
     * @param a base; reduced modulo {@link #mod} first so {@code a * a} cannot overflow
     * @param b non-negative exponent
     * @return {@code a^b} modulo {@link #mod}
     * @throws IllegalArgumentException if {@code b} is negative (the shift loop would
     *     otherwise never terminate)
     */
    private static long Powered_by(long a, long b) {
        if (b < 0) {
            throw new IllegalArgumentException("negative exponent: " + b);
        }
        long base = a % mod;
        if (base < 0) {
            base += mod;
        }
        long res = 1;
        while (b != 0) {
            if ((b & 1) != 0) {
                res = res * base % mod;
            }
            base = base * base % mod;
            b >>= 1;
        }
        return res;
    }
}