1use crate::attrs::AttributeSpanWrapper;
2use crate::field::Field;
3use crate::model::Model;
4use crate::util::{inner_of_option_ty, is_option_ty, wrap_in_dummy_mod};
5use proc_macro2::TokenStream;
6use quote::quote;
7use quote::quote_spanned;
8use syn::parse_quote;
9use syn::spanned::Spanned as _;
10use syn::{DeriveInput, Expr, Path, Result, Type};
11
12pub fn derive(item: DeriveInput) -> Result<TokenStream> {
13 let model = Model::from_item(&item, false, true)?;
14
15 let tokens = model
16 .table_names()
17 .iter()
18 .map(|table_name| derive_into_single_table(&item, &model, table_name))
19 .collect::<Result<Vec<_>>>()?;
20
21 Ok(wrap_in_dummy_mod(quote! {
22 #(#tokens)*
23 }))
24}
25
26fn derive_into_single_table(
27 item: &DeriveInput,
28 model: &Model,
29 table_name: &Path,
30) -> Result<TokenStream> {
31 let treat_none_as_default_value = model.treat_none_as_default_value();
32 let struct_name = &item.ident;
33
34 let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
35
36 let mut generate_borrowed_insert = true;
37
38 let mut direct_field_ty = Vec::with_capacity(model.fields().len());
39 let mut direct_field_assign = Vec::with_capacity(model.fields().len());
40 let mut ref_field_ty = Vec::with_capacity(model.fields().len());
41 let mut ref_field_assign = Vec::with_capacity(model.fields().len());
42
43 for field in model.fields() {
44 if field.skip_insertion() {
46 continue;
47 }
48 let treat_none_as_default_value = match &field.treat_none_as_default_value {
50 Some(attr) => {
51 if let Some(embed) = &field.embed {
52 return Err(syn::Error::new(
53 embed.attribute_span,
54 "`embed` and `treat_none_as_default_value` are mutually exclusive",
55 ));
56 }
57
58 if !is_option_ty(&field.ty) {
59 return Err(syn::Error::new(
60 field.ty.span(),
61 "expected `treat_none_as_default_value` field to be of type `Option<_>`",
62 ));
63 }
64
65 attr.item
66 }
67 None => treat_none_as_default_value,
68 };
69
70 match (field.serialize_as.as_ref(), field.embed()) {
71 (None, true) => {
72 direct_field_ty.push(field_ty_embed(field, None));
73 direct_field_assign.push(field_expr_embed(field, None));
74 ref_field_ty.push(field_ty_embed(field, Some(quote!(&'insert))));
75 ref_field_assign.push(field_expr_embed(field, Some(quote!(&))));
76 }
77 (None, false) => {
78 direct_field_ty.push(field_ty(
79 field,
80 table_name,
81 None,
82 treat_none_as_default_value,
83 )?);
84 direct_field_assign.push(field_expr(
85 field,
86 table_name,
87 None,
88 treat_none_as_default_value,
89 )?);
90 ref_field_ty.push(field_ty(
91 field,
92 table_name,
93 Some(quote!(&'insert)),
94 treat_none_as_default_value,
95 )?);
96 ref_field_assign.push(field_expr(
97 field,
98 table_name,
99 Some(quote!(&)),
100 treat_none_as_default_value,
101 )?);
102 }
103 (Some(AttributeSpanWrapper { item: ty, .. }), false) => {
104 direct_field_ty.push(field_ty_serialize_as(
105 field,
106 table_name,
107 ty,
108 treat_none_as_default_value,
109 )?);
110 direct_field_assign.push(field_expr_serialize_as(
111 field,
112 table_name,
113 ty,
114 treat_none_as_default_value,
115 )?);
116
117 generate_borrowed_insert = false; }
119 (Some(AttributeSpanWrapper { attribute_span, .. }), true) => {
120 return Err(syn::Error::new(
121 *attribute_span,
122 "`#[diesel(embed)]` cannot be combined with `#[diesel(serialize_as)]`",
123 ));
124 }
125 }
126 }
127
128 let insert_owned = quote! {
129 impl #impl_generics diesel::insertable::Insertable<#table_name::table> for #struct_name #ty_generics
130 #where_clause
131 {
132 type Values = <(#(#direct_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values;
133
134 fn values(self) -> <(#(#direct_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values {
135 diesel::insertable::Insertable::<#table_name::table>::values((#(#direct_field_assign,)*))
136 }
137 }
138 };
139
140 let insert_borrowed = if generate_borrowed_insert {
141 let mut impl_generics = item.generics.clone();
142 impl_generics.params.push(parse_quote!('insert));
143 let (impl_generics, ..) = impl_generics.split_for_impl();
144
145 quote! {
146 impl #impl_generics diesel::insertable::Insertable<#table_name::table>
147 for &'insert #struct_name #ty_generics
148 #where_clause
149 {
150 type Values = <(#(#ref_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values;
151
152 fn values(self) -> <(#(#ref_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values {
153 diesel::insertable::Insertable::<#table_name::table>::values((#(#ref_field_assign,)*))
154 }
155 }
156 }
157 } else {
158 quote! {}
159 };
160
161 Ok(quote! {
162 #[allow(unused_qualifications)]
163 #insert_owned
164
165 #[allow(unused_qualifications)]
166 #insert_borrowed
167
168 impl #impl_generics diesel::internal::derives::insertable::UndecoratedInsertRecord<#table_name::table>
169 for #struct_name #ty_generics
170 #where_clause
171 {
172 }
173 })
174}
175
176fn field_ty_embed(field: &Field, lifetime: Option<TokenStream>) -> TokenStream {
177 let field_ty = &field.ty;
178 let span = field.span;
179 quote_spanned!(span=> #lifetime #field_ty)
180}
181
182fn field_expr_embed(field: &Field, lifetime: Option<TokenStream>) -> TokenStream {
183 let field_name = &field.name;
184 quote!(#lifetime self.#field_name)
185}
186
187fn field_ty_serialize_as(
188 field: &Field,
189 table_name: &Path,
190 ty: &Type,
191 treat_none_as_default_value: bool,
192) -> Result<TokenStream> {
193 let column_name = field.column_name()?.to_ident()?;
194 let span = field.span;
195 if treat_none_as_default_value {
196 let inner_ty = inner_of_option_ty(ty);
197
198 Ok(quote_spanned! {span=>
199 std::option::Option<diesel::dsl::Eq<
200 #table_name::#column_name,
201 #inner_ty,
202 >>
203 })
204 } else {
205 Ok(quote_spanned! {span=>
206 diesel::dsl::Eq<
207 #table_name::#column_name,
208 #ty,
209 >
210 })
211 }
212}
213
214fn field_expr_serialize_as(
215 field: &Field,
216 table_name: &Path,
217 ty: &Type,
218 treat_none_as_default_value: bool,
219) -> Result<TokenStream> {
220 let field_name = &field.name;
221 let column_name = field.column_name()?.to_ident()?;
222 let column = quote!(#table_name::#column_name);
223 if treat_none_as_default_value {
224 if is_option_ty(ty) {
225 Ok(
226 quote!(::std::convert::Into::<#ty>::into(self.#field_name).map(|v| diesel::ExpressionMethods::eq(#column, v))),
227 )
228 } else {
229 Ok(
230 quote!(std::option::Option::Some(diesel::ExpressionMethods::eq(#column, ::std::convert::Into::<#ty>::into(self.#field_name)))),
231 )
232 }
233 } else {
234 Ok(
235 quote!(diesel::ExpressionMethods::eq(#column, ::std::convert::Into::<#ty>::into(self.#field_name))),
236 )
237 }
238}
239
240fn field_ty(
241 field: &Field,
242 table_name: &Path,
243 lifetime: Option<TokenStream>,
244 treat_none_as_default_value: bool,
245) -> Result<TokenStream> {
246 let column_name = field.column_name()?.to_ident()?;
247 let span = field.span;
248 if treat_none_as_default_value {
249 let inner_ty = inner_of_option_ty(&field.ty);
250
251 Ok(quote_spanned! {span=>
252 std::option::Option<diesel::dsl::Eq<
253 #table_name::#column_name,
254 #lifetime #inner_ty,
255 >>
256 })
257 } else {
258 let inner_ty = &field.ty;
259
260 Ok(quote_spanned! {span=>
261 diesel::dsl::Eq<
262 #table_name::#column_name,
263 #lifetime #inner_ty,
264 >
265 })
266 }
267}
268
269fn field_expr(
270 field: &Field,
271 table_name: &Path,
272 lifetime: Option<TokenStream>,
273 treat_none_as_default_value: bool,
274) -> Result<TokenStream> {
275 let field_name = &field.name;
276 let column_name = field.column_name()?.to_ident()?;
277
278 let column: Expr = parse_quote!(#table_name::#column_name);
279 if treat_none_as_default_value {
280 if is_option_ty(&field.ty) {
281 if lifetime.is_some() {
282 Ok(
283 quote!(self.#field_name.as_ref().map(|x| diesel::ExpressionMethods::eq(#column, x))),
284 )
285 } else {
286 Ok(quote!(self.#field_name.map(|x| diesel::ExpressionMethods::eq(#column, x))))
287 }
288 } else {
289 Ok(
290 quote!(std::option::Option::Some(diesel::ExpressionMethods::eq(#column, #lifetime self.#field_name))),
291 )
292 }
293 } else {
294 Ok(quote!(diesel::ExpressionMethods::eq(#column, #lifetime self.#field_name)))
295 }
296}