diesel_derives/
insertable.rs

1use crate::attrs::AttributeSpanWrapper;
2use crate::field::Field;
3use crate::model::Model;
4use crate::util::{inner_of_option_ty, is_option_ty, wrap_in_dummy_mod};
5use proc_macro2::TokenStream;
6use quote::quote;
7use quote::quote_spanned;
8use syn::parse_quote;
9use syn::spanned::Spanned as _;
10use syn::{DeriveInput, Expr, Path, Result, Type};
11
12pub fn derive(item: DeriveInput) -> Result<TokenStream> {
13    let model = Model::from_item(&item, false, true)?;
14
15    let tokens = model
16        .table_names()
17        .iter()
18        .map(|table_name| derive_into_single_table(&item, &model, table_name))
19        .collect::<Result<Vec<_>>>()?;
20
21    Ok(wrap_in_dummy_mod(quote! {
22        #(#tokens)*
23    }))
24}
25
26fn derive_into_single_table(
27    item: &DeriveInput,
28    model: &Model,
29    table_name: &Path,
30) -> Result<TokenStream> {
31    let treat_none_as_default_value = model.treat_none_as_default_value();
32    let struct_name = &item.ident;
33
34    let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
35
36    let mut generate_borrowed_insert = true;
37
38    let mut direct_field_ty = Vec::with_capacity(model.fields().len());
39    let mut direct_field_assign = Vec::with_capacity(model.fields().len());
40    let mut ref_field_ty = Vec::with_capacity(model.fields().len());
41    let mut ref_field_assign = Vec::with_capacity(model.fields().len());
42
43    for field in model.fields() {
44        // skip this field while generating the insertion
45        if field.skip_insertion() {
46            continue;
47        }
48        // Use field-level attr. with fallback to the struct-level one.
49        let treat_none_as_default_value = match &field.treat_none_as_default_value {
50            Some(attr) => {
51                if let Some(embed) = &field.embed {
52                    return Err(syn::Error::new(
53                        embed.attribute_span,
54                        "`embed` and `treat_none_as_default_value` are mutually exclusive",
55                    ));
56                }
57
58                if !is_option_ty(&field.ty) {
59                    return Err(syn::Error::new(
60                        field.ty.span(),
61                        "expected `treat_none_as_default_value` field to be of type `Option<_>`",
62                    ));
63                }
64
65                attr.item
66            }
67            None => treat_none_as_default_value,
68        };
69
70        match (field.serialize_as.as_ref(), field.embed()) {
71            (None, true) => {
72                direct_field_ty.push(field_ty_embed(field, None));
73                direct_field_assign.push(field_expr_embed(field, None));
74                ref_field_ty.push(field_ty_embed(field, Some(quote!(&'insert))));
75                ref_field_assign.push(field_expr_embed(field, Some(quote!(&))));
76            }
77            (None, false) => {
78                direct_field_ty.push(field_ty(
79                    field,
80                    table_name,
81                    None,
82                    treat_none_as_default_value,
83                )?);
84                direct_field_assign.push(field_expr(
85                    field,
86                    table_name,
87                    None,
88                    treat_none_as_default_value,
89                )?);
90                ref_field_ty.push(field_ty(
91                    field,
92                    table_name,
93                    Some(quote!(&'insert)),
94                    treat_none_as_default_value,
95                )?);
96                ref_field_assign.push(field_expr(
97                    field,
98                    table_name,
99                    Some(quote!(&)),
100                    treat_none_as_default_value,
101                )?);
102            }
103            (Some(AttributeSpanWrapper { item: ty, .. }), false) => {
104                direct_field_ty.push(field_ty_serialize_as(
105                    field,
106                    table_name,
107                    ty,
108                    treat_none_as_default_value,
109                )?);
110                direct_field_assign.push(field_expr_serialize_as(
111                    field,
112                    table_name,
113                    ty,
114                    treat_none_as_default_value,
115                )?);
116
117                generate_borrowed_insert = false; // as soon as we hit one field with #[diesel(serialize_as)] there is no point in generating the impl of Insertable for borrowed structs
118            }
119            (Some(AttributeSpanWrapper { attribute_span, .. }), true) => {
120                return Err(syn::Error::new(
121                    *attribute_span,
122                    "`#[diesel(embed)]` cannot be combined with `#[diesel(serialize_as)]`",
123                ));
124            }
125        }
126    }
127
128    let insert_owned = quote! {
129        impl #impl_generics diesel::insertable::Insertable<#table_name::table> for #struct_name #ty_generics
130            #where_clause
131        {
132            type Values = <(#(#direct_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values;
133
134            fn values(self) -> <(#(#direct_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values {
135                diesel::insertable::Insertable::<#table_name::table>::values((#(#direct_field_assign,)*))
136            }
137        }
138    };
139
140    let insert_borrowed = if generate_borrowed_insert {
141        let mut impl_generics = item.generics.clone();
142        impl_generics.params.push(parse_quote!('insert));
143        let (impl_generics, ..) = impl_generics.split_for_impl();
144
145        quote! {
146            impl #impl_generics diesel::insertable::Insertable<#table_name::table>
147                for &'insert #struct_name #ty_generics
148            #where_clause
149            {
150                type Values = <(#(#ref_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values;
151
152                fn values(self) -> <(#(#ref_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values {
153                    diesel::insertable::Insertable::<#table_name::table>::values((#(#ref_field_assign,)*))
154                }
155            }
156        }
157    } else {
158        quote! {}
159    };
160
161    Ok(quote! {
162        #[allow(unused_qualifications)]
163        #insert_owned
164
165        #[allow(unused_qualifications)]
166        #insert_borrowed
167
168        impl #impl_generics diesel::internal::derives::insertable::UndecoratedInsertRecord<#table_name::table>
169                for #struct_name #ty_generics
170            #where_clause
171        {
172        }
173    })
174}
175
176fn field_ty_embed(field: &Field, lifetime: Option<TokenStream>) -> TokenStream {
177    let field_ty = &field.ty;
178    let span = field.span;
179    quote_spanned!(span=> #lifetime #field_ty)
180}
181
182fn field_expr_embed(field: &Field, lifetime: Option<TokenStream>) -> TokenStream {
183    let field_name = &field.name;
184    quote!(#lifetime self.#field_name)
185}
186
187fn field_ty_serialize_as(
188    field: &Field,
189    table_name: &Path,
190    ty: &Type,
191    treat_none_as_default_value: bool,
192) -> Result<TokenStream> {
193    let column_name = field.column_name()?.to_ident()?;
194    let span = field.span;
195    if treat_none_as_default_value {
196        let inner_ty = inner_of_option_ty(ty);
197
198        Ok(quote_spanned! {span=>
199            std::option::Option<diesel::dsl::Eq<
200                #table_name::#column_name,
201                #inner_ty,
202            >>
203        })
204    } else {
205        Ok(quote_spanned! {span=>
206            diesel::dsl::Eq<
207                #table_name::#column_name,
208                #ty,
209            >
210        })
211    }
212}
213
214fn field_expr_serialize_as(
215    field: &Field,
216    table_name: &Path,
217    ty: &Type,
218    treat_none_as_default_value: bool,
219) -> Result<TokenStream> {
220    let field_name = &field.name;
221    let column_name = field.column_name()?.to_ident()?;
222    let column = quote!(#table_name::#column_name);
223    if treat_none_as_default_value {
224        if is_option_ty(ty) {
225            Ok(
226                quote!(::std::convert::Into::<#ty>::into(self.#field_name).map(|v| diesel::ExpressionMethods::eq(#column, v))),
227            )
228        } else {
229            Ok(
230                quote!(std::option::Option::Some(diesel::ExpressionMethods::eq(#column, ::std::convert::Into::<#ty>::into(self.#field_name)))),
231            )
232        }
233    } else {
234        Ok(
235            quote!(diesel::ExpressionMethods::eq(#column, ::std::convert::Into::<#ty>::into(self.#field_name))),
236        )
237    }
238}
239
240fn field_ty(
241    field: &Field,
242    table_name: &Path,
243    lifetime: Option<TokenStream>,
244    treat_none_as_default_value: bool,
245) -> Result<TokenStream> {
246    let column_name = field.column_name()?.to_ident()?;
247    let span = field.span;
248    if treat_none_as_default_value {
249        let inner_ty = inner_of_option_ty(&field.ty);
250
251        Ok(quote_spanned! {span=>
252            std::option::Option<diesel::dsl::Eq<
253                #table_name::#column_name,
254                #lifetime #inner_ty,
255            >>
256        })
257    } else {
258        let inner_ty = &field.ty;
259
260        Ok(quote_spanned! {span=>
261            diesel::dsl::Eq<
262                #table_name::#column_name,
263                #lifetime #inner_ty,
264            >
265        })
266    }
267}
268
269fn field_expr(
270    field: &Field,
271    table_name: &Path,
272    lifetime: Option<TokenStream>,
273    treat_none_as_default_value: bool,
274) -> Result<TokenStream> {
275    let field_name = &field.name;
276    let column_name = field.column_name()?.to_ident()?;
277
278    let column: Expr = parse_quote!(#table_name::#column_name);
279    if treat_none_as_default_value {
280        if is_option_ty(&field.ty) {
281            if lifetime.is_some() {
282                Ok(
283                    quote!(self.#field_name.as_ref().map(|x| diesel::ExpressionMethods::eq(#column, x))),
284                )
285            } else {
286                Ok(quote!(self.#field_name.map(|x| diesel::ExpressionMethods::eq(#column, x))))
287            }
288        } else {
289            Ok(
290                quote!(std::option::Option::Some(diesel::ExpressionMethods::eq(#column, #lifetime self.#field_name))),
291            )
292        }
293    } else {
294        Ok(quote!(diesel::ExpressionMethods::eq(#column, #lifetime self.#field_name)))
295    }
296}