题目
思路
部分分
对于m==2或者m==3的情况
我们可以直接定义状态
\(dp_{n,i,j,k}\)表示前n中烹饪方法第一种主要食材用了i次,第二种主要食材使用了j次,第三种主要食材用了k次的所有方案数
\(dp_{n,i,j,k}=dp_{n-1,i,j,k}+dp_{n-1,i-1,j,k}+dp_{n-1,i,j-1,k}+dp_{n-1,i,j,k-1}\)
初始化就是
\(dp_{0,0,0,0}=1\)
考虑到可能出现负数的情况
所以笔者选择从小向大转移
代码
#include<iostream>
using namespace std;
const long long mod=998244353;
long long n,m;
long long s=0;
long long a[45][45];
long long dp1[45][45][45];
long long dp2[45][45][45][45];
int main()
{
cin>>n>>m;
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
cin>>a[i][j];
if(m==3||m==2)
{
dp2[0][0][0][0]=1;
for(int z=0;z<n;z++)
{
for(int i=0;i<=n;i++)
{
for(int j=0;j<=n;j++)
{
for(int k=0;k<=n;k++)
{
dp2[z+1][i][j][k]=(dp2[z+1][i][j][k]+dp2[z][i][j][k])%mod;
dp2[z+1][i+1][j][k]=(dp2[z+1][i+1][j][k]+dp2[z][i][j][k]*a[z+1][1])%mod;
dp2[z+1][i][j+1][k]=(dp2[z+1][i][j+1][k]+dp2[z][i][j][k]*a[z+1][2])%mod;
dp2[z+1][i][j][k+1]=(dp2[z+1][i][j][k+1]+dp2[z][i][j][k]*a[z+1][3])%mod;
}
}
}
}
for(int i=0;i<=n;i++)
{
for(int j=0;j<=n;j++)
{
for(int k=0;k<=n;k++)
{
int t=(i+j+k)/2;
if(!t)
continue;
if(i<=t&&j<=t&&k<=t)
{
s=(s+dp2[n][i][j][k])%mod;
}
}
}
}
cout<<s;
return 0;
}
cout<<"fjdaklsfjldasfe";
return 0;
}
正解
要求的数量可以表示为所有的数量减去不符合条件的数量
我们将超过\(k\over2\)的食材称为特殊食材
将其他食材一起合在一起成为超级食材
用\(dp_{i,j}\)表示表示第i种特殊食材比超级食材多j道
\(dp_{i,j}\)可以转移到
\(\begin{cases}dp_{i+1,j}=dp_{i,j-1}*a[i][col]+dp_{i,j}+dp_{i,j+1}*(s[i]-a[i][col])\\dp_{i+1,j+1}=dp_{i,j}*a[i][col]+dp_{i,j+1}+dp_{i,j+2}*(s[i]-a[i][col])\\dp_{i+1,j-1}=dp_{i,j-2}*a[i][col]+dp_{i,j-1}+dp_{i,j}*(s[i]-a[i][col])\\\end{cases}\)
代码
#include<iostream>
#include<cstring>
using namespace std;
const long long mod=998244353;
long long n,m;
long long ans;
long long s[105];
long long a[105][2005];
long long dp[105][205];
void solve_DP(int col)
{
memset(dp,0,sizeof(dp));
dp[0][n]=1;
for(int i=0;i<=n;i++)
{
for(int j=0;j<=2*n;j++)
{
dp[i+1][j]=(dp[i+1][j]+dp[i][j])%mod;
dp[i+1][j+1]=(dp[i+1][j+1]+dp[i][j]*a[i+1][col])%mod;
if(j)
dp[i+1][j-1]=(dp[i+1][j-1]+dp[i][j]*(s[i+1]-a[i+1][col]))%mod;
}
}
for(int i=n+1;i<=2*n;i++)
ans=((ans-dp[n][i])%mod+mod)%mod;
}
int main()
{
cin>>n>>m;
ans=1;
for(int i=1;i<=n;i++)
{
for(int j=1;j<=m;j++)
{
cin>>a[i][j];
s[i]=(s[i]+a[i][j])%mod;
}
ans=(ans*(s[i]+1))%mod;
}
ans=((ans-1)%mod+mod)%mod;
for(int col=1;col<=m;col++)
solve_DP(col);
cout<<ans;
return 0;
}